1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
17 #define TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
18
#include <optional>

#include "absl/types/optional.h"
#include "mlir/IR/BuiltinOps.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"
23
24 namespace tensorflow {
25
// Result of the MLIR bridge rollout decision: whether the bridge runs, and
// whether that choice came from an explicit user request or from analyzing
// the features of the model's graph.
enum class MlirBridgeRolloutPolicy {
  // Explicit user request: the MLIR bridge is disabled and must not be run.
  kDisabledByUser = 0,
  // Explicit user request: the MLIR bridge is enabled and must be run. On an
  // MLIR bridge error, the fallback path should NOT be used.
  kEnabledByUser = 1,
  // No explicit user request; analysis of the model's features decided that
  // the MLIR bridge should not be run.
  kDisabledAfterGraphAnalysis = 2,
  // No explicit user request; analysis of the model's features decided that
  // the MLIR bridge should be run. On an MLIR bridge error, the fallback path
  // should be used whenever possible.
  kEnabledAfterGraphAnalysis = 3,
  // The bridge was enabled with fallback in a safe mode, and the graph passed
  // every graph analysis check.
  kEnabledAfterGraphAnalysisSafeModeFallback = 4,
};
43
44 // Analyzes the user requested policy as well as the contents of the graph and
45 // returns true when the MLIR Bridge should be run.
46 //
47 // If the user explicitly requests the bridge be enabled or disabled, this
48 // function will respect the request. If the user does not explicitly request
49 // enabled or disabled, it will decide whether or not to run the bridge.
50 //
51 // The config_proto param is a required input for all TF1 graphs but it is
52 // redundant for TF2 graphs.
53 // If getting rollout policy involves graph analysis, `record_stats` is used
54 // to decide whether to emit metrics on unsupported features of the graph.
55 MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
56 const tensorflow::Graph& graph,
57 const FunctionLibraryDefinition* function_library,
58 std::optional<tensorflow::ConfigProto> config_proto,
59 bool uses_uninitialized_resource_args, bool is_v1_compat,
60 bool record_stats);
61
GetMlirBridge2ndPhaseRolloutPolicy(mlir::ModuleOp module)62 static inline MlirBridgeRolloutPolicy GetMlirBridge2ndPhaseRolloutPolicy(
63 mlir::ModuleOp module) {
64 return MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis;
65 }
66
67 // Explicit Interface for when we want to log features vs test the validity of
68 // the graph for MLIR bridge processing. Note that right now the logging
69 // which is done in the logic used by GraphHasFeaturesUnsupportedByMlirBridge
70 // has diverged and logs supported features as well. Parameters are the same
71 // as for GetMlirBridgeRolloutPolicy with the exception of
72 // record_stats, which isn't needed because this interface will always record.
73 void LogGraphFeatures(const Graph& graph,
74 const FunctionLibraryDefinition* function_library,
75 std::optional<ConfigProto> config_proto,
76 bool uses_uninitialized_resource_args, bool is_v1_compat);
77
78 } // namespace tensorflow
79
80 #endif // TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
81