xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
17 
18 #include <ostream>
19 #include <sstream>
20 #include <type_traits>
21 #include <utility>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/strings/str_split.h"
29 #include "llvm/ADT/Optional.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace tensorflow {
40 
str() const41 std::string GraphImportConfig::str() const {
42   std::ostringstream ss;
43 
44   ss << "graph_func_name: " << graph_func_name;
45   InputArrays inputs;
46   ss << "\ninputs: ";
47   for (auto& it : inputs) {
48     ss << "\n\t" << it.first << " -> "
49        << DataTypeString(it.second.imported_dtype) << " "
50        << it.second.shape.DebugString();
51   }
52   ss << "\noutputs:";
53   for (auto& output : outputs) ss << " " << output;
54   ss << "\ncontrol_outputs:";
55   for (auto& output : control_outputs) ss << " " << output;
56   ss << "\nprune_unused_nodes: " << prune_unused_nodes;
57   ss << "\nconvert_legacy_fed_inputs: " << convert_legacy_fed_inputs;
58   ss << "\ngraph_as_function: " << graph_as_function;
59   ss << "\nupgrade_legacy: " << upgrade_legacy;
60   ss << "\nrestrict_functionalization_to_compiled_nodes: "
61      << restrict_functionalization_to_compiled_nodes;
62   ss << "\nenable_shape_inference: " << enable_shape_inference;
63   ss << "\nunconditionally_use_set_output_shapes: "
64      << unconditionally_use_set_output_shapes;
65 
66   return ss.str();
67 }
68 
ParseOutputArrayInfo(absl::string_view array_names,std::vector<string> * outputs)69 Status ParseOutputArrayInfo(absl::string_view array_names,
70                             std::vector<string>* outputs) {
71   TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs));
72   return OkStatus();
73 }
74 
ParseOutputArrayInfo(const std::vector<string> & output_names,std::vector<string> * outputs)75 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
76                             std::vector<string>* outputs) {
77   for (auto& output_name : output_names) {
78     if (output_name.empty()) continue;
79     outputs->push_back(output_name);
80   }
81   return OkStatus();
82 }
83 
ParseInputArrayInfo(absl::string_view array_names,absl::string_view data_types,absl::string_view shapes,GraphImportConfig::InputArrays * inputs)84 Status ParseInputArrayInfo(absl::string_view array_names,
85                            absl::string_view data_types,
86                            absl::string_view shapes,
87                            GraphImportConfig::InputArrays* inputs) {
88   std::vector<string> node_names;
89   std::vector<string> node_dtypes;
90   std::vector<llvm::Optional<std::vector<int>>> node_shapes;
91   TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
92   TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
93   TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
94   return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
95 }
96 
ParseShapeStr(absl::string_view node_shapes_str)97 static StatusOr<std::vector<int>> ParseShapeStr(
98     absl::string_view node_shapes_str) {
99   std::vector<int> dims;
100   for (absl::string_view dim_str : absl::StrSplit(node_shapes_str, ',')) {
101     // Treats empty input shape as scalar
102     if (dim_str.empty()) continue;
103     if (dim_str == "?") {
104       dims.push_back(-1);
105       continue;
106     }
107     int size;
108     TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
109     dims.push_back(size);
110   }
111   return dims;
112 }
113 
HandleSubtype(absl::string_view subtype,ArrayInfo::SubTypeInfo * result)114 static Status HandleSubtype(absl::string_view subtype,
115                             ArrayInfo::SubTypeInfo* result) {
116   std::vector<std::string> shape_and_type = absl::StrSplit(subtype, ':');
117 
118   std::vector<int> dims;
119   if (shape_and_type.size() > 2) {
120     return errors::FailedPrecondition("Invalid argument: '", subtype,
121                                       "', expected a single shape and type pair"
122                                       " seperated with a ':'");
123   } else if (shape_and_type.size() == 2) {
124     const auto& shape_str = shape_and_type[0];
125     TF_ASSIGN_OR_RETURN(dims, ParseShapeStr(shape_str));
126   }
127 
128   const auto& subtype_str = shape_and_type.back();
129   DataType subtype_dtype;
130   if (!DataType_Parse(subtype_str, &subtype_dtype)) {
131     return errors::FailedPrecondition(
132         absl::StrCat("Invalid type: '", subtype_str, "'"));
133   }
134 
135   TensorShapeProto subtype_tensor_shape;
136   for (auto& dim : dims) {
137     subtype_tensor_shape.add_dim()->set_size(dim);
138   }
139   *result = {subtype_dtype, subtype_tensor_shape};
140   return OkStatus();
141 }
142 
ParseInputArrayInfo(const std::vector<string> & node_names,const std::vector<string> & node_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & node_shapes,GraphImportConfig::InputArrays * inputs)143 Status ParseInputArrayInfo(
144     const std::vector<string>& node_names,
145     const std::vector<string>& node_dtypes,
146     const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
147     GraphImportConfig::InputArrays* inputs) {
148   std::vector<std::string> used_node_dtypes;
149   if (node_dtypes.empty()) {
150     // Mark all the node dtypes Invalid, so the importer can handle them by
151     // using the type from the graph.
152     used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID));
153   } else if (node_names.size() == node_dtypes.size()) {
154     for (const auto& dtype : node_dtypes) {
155       if (dtype.empty()) {
156         used_node_dtypes.push_back(DataType_Name(DT_INVALID));
157       } else if (dtype != DataType_Name(DT_INVALID)) {
158         used_node_dtypes.push_back(dtype);
159       } else {
160         return errors::FailedPrecondition(
161             "Use '' if want to use the type from graph.");
162       }
163     }
164   } else {
165     return errors::InvalidArgument(absl::StrCat(
166         "Length of input node array and data type doesn't match (#arrays ",
167         node_names.size(), ", #data_types ", node_dtypes.size(), ")"));
168   }
169 
170   if (!node_shapes.empty() && node_names.size() != node_shapes.size()) {
171     return errors::InvalidArgument(absl::StrCat(
172         "Length of input node array and data shape doesn't match (#arrays ",
173         node_names.size(), ", #input_shapes ", node_shapes.size(), ")"));
174   }
175 
176   // StringMap doesn't support reserve else reserve input map size here.
177   for (int i = 0, end = node_names.size(); i < end; i++) {
178     auto& name = node_names[i];
179     const string& type = used_node_dtypes[i];
180     if (name.empty()) continue;
181 
182     auto it_inserted_pair = inputs->insert({name, {}});
183     if (!it_inserted_pair.second)
184       return errors::FailedPrecondition(
185           absl::StrCat("tensor ", name, " is repeated in the arrays flag"));
186 
187     ArrayInfo& info = it_inserted_pair.first->second;
188     // Splitting the type and subtype into parts
189     std::vector<std::string> parts =
190         absl::StrSplit(type, absl::ByAnyChar("()"));
191     // If type has subtypes then parts[0] = type, parts[1] = subtypes,
192     // parts[2] = ""
193     if (parts.size() != 3 && parts.size() != 1) {
194       return errors::InvalidArgument("Invalid type '", type, "'");
195     } else if (parts.size() == 3) {
196       // First part is the type, second is the subtype
197       ArrayInfo::SubTypeInfo subtype;
198       TF_RETURN_IF_ERROR(HandleSubtype(parts[1], &subtype));
199       info.subtypes.push_back(std::move(subtype));
200     }
201     if (!DataType_Parse(parts[0], &info.imported_dtype)) {
202       return errors::FailedPrecondition(
203           absl::StrCat("Invalid node type '", node_dtypes[i], "'"));
204     }
205 
206     if (!node_shapes.empty()) {
207       if (!node_shapes[i].has_value()) {
208         info.shape.set_unknown_rank(true);
209         continue;
210       }
211       for (auto& dim : node_shapes[i].getValue()) {
212         info.shape.add_dim()->set_size(dim);
213       }
214     }
215   }
216   return OkStatus();
217 }
218 
ParseNodeShapes(absl::string_view shapes_str,std::vector<llvm::Optional<std::vector<int>>> & shapes_vector)219 Status ParseNodeShapes(
220     absl::string_view shapes_str,
221     std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
222   shapes_vector.clear();
223   if (!shapes_str.empty()) {
224     std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
225     for (int i = 0; i < node_shapes_str.size(); i++) {
226       if (node_shapes_str[i] == "*") {
227         shapes_vector.push_back(llvm::None);
228         continue;
229       }
230       TF_ASSIGN_OR_RETURN(auto shape, ParseShapeStr(node_shapes_str[i]));
231       shapes_vector.push_back(std::move(shape));
232     }
233   }
234   return OkStatus();
235 }
236 
ParseNodeNames(absl::string_view names_str,std::vector<std::string> & names_vector)237 Status ParseNodeNames(absl::string_view names_str,
238                       std::vector<std::string>& names_vector) {
239   names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty());
240   return OkStatus();
241 }
242 
ParseDTypesHelper(absl::string_view data_types_str)243 static StatusOr<std::vector<std::string>> ParseDTypesHelper(
244     absl::string_view data_types_str) {
245   bool inside_subtype = false;
246   int cur_pos = 0;
247   std::vector<std::string> dtypes;
248   for (auto& it : llvm::enumerate(data_types_str)) {
249     char c = it.value();
250     int i = it.index();
251     // Skip parsing the subtypes of a type
252     if (c == '(') {
253       if (inside_subtype) {
254         return errors::FailedPrecondition(
255             absl::StrCat("Syntax error: unexpected '(' in input data types: '",
256                          data_types_str, "'"));
257       }
258       inside_subtype = true;
259     } else if (c == ')') {
260       if (!inside_subtype) {
261         return errors::FailedPrecondition(
262             absl::StrCat("Syntax error: unexpected ')' in input data types: '",
263                          data_types_str, "'"));
264       }
265       inside_subtype = false;
266     }
267     if (inside_subtype) continue;
268     if (c == ',') {
269       dtypes.push_back(
270           std::string(data_types_str.substr(cur_pos, i - cur_pos)));
271       cur_pos = i + 1;
272     }
273   }
274   if (inside_subtype) {
275     return errors::FailedPrecondition(
276         absl::StrCat("Syntax error: expected a ')' in input data types '",
277                      data_types_str, "'"));
278   }
279   if (!data_types_str.empty()) {
280     dtypes.push_back(std::string(
281         data_types_str.substr(cur_pos, data_types_str.size() - cur_pos)));
282   }
283   return dtypes;
284 }
285 
ParseNodeDataTypes(absl::string_view data_types_str,std::vector<std::string> & data_type_vector)286 Status ParseNodeDataTypes(absl::string_view data_types_str,
287                           std::vector<std::string>& data_type_vector) {
288   data_type_vector.clear();
289   if (!data_types_str.empty()) {
290     TF_ASSIGN_OR_RETURN(data_type_vector, ParseDTypesHelper(data_types_str));
291   }
292   return OkStatus();
293 }
294 
295 }  // namespace tensorflow
296