xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_
18 
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringMap.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
29 
30 namespace tensorflow {
31 
32 struct ArrayInfoBase {
33   // The node type when the input node is imported. Typically needs to be
34   // specified when passing arbitrary nodes (some node attributes are removed).
35   DataType imported_dtype;
36 
37   // Node "shape" attribute value.
38   TensorShapeProto shape;
39 };
40 
41 struct ArrayInfo : public ArrayInfoBase {
42   using SubTypeInfo = ArrayInfoBase;
43   // DT_RESOURCE and DT_VARIANT have subtypes
44   std::vector<SubTypeInfo> subtypes;
45 };
46 
47 struct GraphImportConfig {
48   // Returns string representation of config.
49   std::string str() const;
50 
51   using InputArrays =
52       llvm::MapVector<std::string, ArrayInfo, llvm::StringMap<unsigned>>;
53   // The name assigned to the function which is the import result of the given
54   // graph. If empty, a default one will be used.
55   std::string graph_func_name;
56   // Maps input node names to node data types and shapes.
57   InputArrays inputs;
58   // name:index strings for the data outputs.
59   std::vector<string> outputs;
60   // name strings for the control outputs.
61   std::vector<string> control_outputs;
62   // Setting prune_unused_nodes to true, would prune unreachable nodes if
63   // output_arrays is specified.
64   bool prune_unused_nodes = false;
65   // If true, inputs of type LegacyFedInput are replaced with Placeholder ops.
66   // LegacyFedInput ops have two outputs unlike Placeholder which has only one
67   // output, so if both outputs of the LegacyFedInput ops are used then returns
68   // an error.
69   bool convert_legacy_fed_inputs = false;
70   // If true, the main graph will be treated as a function.
71   bool graph_as_function = false;
72   // If true, upgrade legacy features of the graph (for instance, functionalize
73   // control-flow).
74   bool upgrade_legacy = false;
75   // If true, functionalization is restricted to nodes that will be
76   // XLA-compiled. This is only needed if
77   // - `upgrade_legacy` is true
78   // - upgrading legacy features of the graph (which includes functionalization)
79   //   runs before compilation cluster extraction (as for MLIR-based TPU bridge)
80   // - session runtime is used (session runtime has issues with function names
81   //   rewritten by functionalization).
82   // Otherwise, this parameter should be set to false.
83   bool restrict_functionalization_to_compiled_nodes = false;
84   // If true, enables shape inference on input.
85   // TODO(jpienaar): This will be removed shortly.
86   bool enable_shape_inference = true;
87   // _output_shapes is an unregistered attribute which is used during
88   // GraphConstructor::ConvertGraph to override shapes. It is unfortunately
89   // not always set correctly (which is undesirable and should be addressed)
90   // so make it opt-in to consider it unconditionally also when importing the
91   // graph.
92   bool unconditionally_use_set_output_shapes = false;
93 };
94 
// Switches controlling what is written out when exporting to a GraphDef.
struct GraphExportConfig {
  // Emit the "shape" attribute on each NodeDef of the GraphDef.
  bool export_shapes{true};
  // Fill in the library field of the GraphDef.
  bool export_library{true};
  // Emit debug info (original node names) on the NodeDefs of the GraphDef.
  bool export_debug_info{true};
  // Put the entry function into the function library rather than exporting it
  // as the graph itself.
  bool export_entry_func_to_flib{false};
};
106 
107 // Parses the command line flag strings to the specification of nodes in
108 // the Graph.
109 Status ParseOutputArrayInfo(absl::string_view array_names,
110                             std::vector<string>* outputs);
111 
112 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
113                             std::vector<string>* outputs);
114 
115 // Parses the command line flag strings to the specification of nodes in
116 // the Graph. `data_types` input string can be empty since the flag is optional.
117 Status ParseInputArrayInfo(absl::string_view array_names,
118                            absl::string_view data_types,
119                            absl::string_view shapes,
120                            GraphImportConfig::InputArrays* inputs);
121 
122 Status ParseInputArrayInfo(
123     const std::vector<string>& node_names,
124     const std::vector<string>& node_dtypes,
125     const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
126     GraphImportConfig::InputArrays* inputs);
127 
128 // Parses shapes from the given string into shapes_vector which is a structured
129 // format.
130 // NOTE: If shapes_str is empty, shapes_vector will also be empty.
131 Status ParseNodeShapes(
132     absl::string_view shapes_str,
133     std::vector<llvm::Optional<std::vector<int>>>& shapes_vector);
134 
135 // Parses names from the given string into the names_vector.
136 // NOTE: If names_str is empty, names_vector will also be empty.
137 Status ParseNodeNames(absl::string_view names_str,
138                       std::vector<std::string>& names_vector);
139 
140 // Parses data types from the given string into the data_type_vector.
141 // NOTE: If data_types_str is empty, data_type_vector will also be empty.
142 Status ParseNodeDataTypes(absl::string_view data_types_str,
143                           std::vector<std::string>& data_type_vector);
144 
145 }  // namespace tensorflow
146 
147 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_
148