1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
17
18 #include <ostream>
19 #include <sstream>
20 #include <type_traits>
21 #include <utility>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/strings/str_split.h"
29 #include "llvm/ADT/Optional.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/types.h"
38
39 namespace tensorflow {
40
str() const41 std::string GraphImportConfig::str() const {
42 std::ostringstream ss;
43
44 ss << "graph_func_name: " << graph_func_name;
45 InputArrays inputs;
46 ss << "\ninputs: ";
47 for (auto& it : inputs) {
48 ss << "\n\t" << it.first << " -> "
49 << DataTypeString(it.second.imported_dtype) << " "
50 << it.second.shape.DebugString();
51 }
52 ss << "\noutputs:";
53 for (auto& output : outputs) ss << " " << output;
54 ss << "\ncontrol_outputs:";
55 for (auto& output : control_outputs) ss << " " << output;
56 ss << "\nprune_unused_nodes: " << prune_unused_nodes;
57 ss << "\nconvert_legacy_fed_inputs: " << convert_legacy_fed_inputs;
58 ss << "\ngraph_as_function: " << graph_as_function;
59 ss << "\nupgrade_legacy: " << upgrade_legacy;
60 ss << "\nrestrict_functionalization_to_compiled_nodes: "
61 << restrict_functionalization_to_compiled_nodes;
62 ss << "\nenable_shape_inference: " << enable_shape_inference;
63 ss << "\nunconditionally_use_set_output_shapes: "
64 << unconditionally_use_set_output_shapes;
65
66 return ss.str();
67 }
68
ParseOutputArrayInfo(absl::string_view array_names,std::vector<string> * outputs)69 Status ParseOutputArrayInfo(absl::string_view array_names,
70 std::vector<string>* outputs) {
71 TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs));
72 return OkStatus();
73 }
74
ParseOutputArrayInfo(const std::vector<string> & output_names,std::vector<string> * outputs)75 Status ParseOutputArrayInfo(const std::vector<string>& output_names,
76 std::vector<string>* outputs) {
77 for (auto& output_name : output_names) {
78 if (output_name.empty()) continue;
79 outputs->push_back(output_name);
80 }
81 return OkStatus();
82 }
83
ParseInputArrayInfo(absl::string_view array_names,absl::string_view data_types,absl::string_view shapes,GraphImportConfig::InputArrays * inputs)84 Status ParseInputArrayInfo(absl::string_view array_names,
85 absl::string_view data_types,
86 absl::string_view shapes,
87 GraphImportConfig::InputArrays* inputs) {
88 std::vector<string> node_names;
89 std::vector<string> node_dtypes;
90 std::vector<llvm::Optional<std::vector<int>>> node_shapes;
91 TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names));
92 TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes));
93 TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes));
94 return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs);
95 }
96
ParseShapeStr(absl::string_view node_shapes_str)97 static StatusOr<std::vector<int>> ParseShapeStr(
98 absl::string_view node_shapes_str) {
99 std::vector<int> dims;
100 for (absl::string_view dim_str : absl::StrSplit(node_shapes_str, ',')) {
101 // Treats empty input shape as scalar
102 if (dim_str.empty()) continue;
103 if (dim_str == "?") {
104 dims.push_back(-1);
105 continue;
106 }
107 int size;
108 TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size));
109 dims.push_back(size);
110 }
111 return dims;
112 }
113
HandleSubtype(absl::string_view subtype,ArrayInfo::SubTypeInfo * result)114 static Status HandleSubtype(absl::string_view subtype,
115 ArrayInfo::SubTypeInfo* result) {
116 std::vector<std::string> shape_and_type = absl::StrSplit(subtype, ':');
117
118 std::vector<int> dims;
119 if (shape_and_type.size() > 2) {
120 return errors::FailedPrecondition("Invalid argument: '", subtype,
121 "', expected a single shape and type pair"
122 " seperated with a ':'");
123 } else if (shape_and_type.size() == 2) {
124 const auto& shape_str = shape_and_type[0];
125 TF_ASSIGN_OR_RETURN(dims, ParseShapeStr(shape_str));
126 }
127
128 const auto& subtype_str = shape_and_type.back();
129 DataType subtype_dtype;
130 if (!DataType_Parse(subtype_str, &subtype_dtype)) {
131 return errors::FailedPrecondition(
132 absl::StrCat("Invalid type: '", subtype_str, "'"));
133 }
134
135 TensorShapeProto subtype_tensor_shape;
136 for (auto& dim : dims) {
137 subtype_tensor_shape.add_dim()->set_size(dim);
138 }
139 *result = {subtype_dtype, subtype_tensor_shape};
140 return OkStatus();
141 }
142
ParseInputArrayInfo(const std::vector<string> & node_names,const std::vector<string> & node_dtypes,const std::vector<llvm::Optional<std::vector<int>>> & node_shapes,GraphImportConfig::InputArrays * inputs)143 Status ParseInputArrayInfo(
144 const std::vector<string>& node_names,
145 const std::vector<string>& node_dtypes,
146 const std::vector<llvm::Optional<std::vector<int>>>& node_shapes,
147 GraphImportConfig::InputArrays* inputs) {
148 std::vector<std::string> used_node_dtypes;
149 if (node_dtypes.empty()) {
150 // Mark all the node dtypes Invalid, so the importer can handle them by
151 // using the type from the graph.
152 used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID));
153 } else if (node_names.size() == node_dtypes.size()) {
154 for (const auto& dtype : node_dtypes) {
155 if (dtype.empty()) {
156 used_node_dtypes.push_back(DataType_Name(DT_INVALID));
157 } else if (dtype != DataType_Name(DT_INVALID)) {
158 used_node_dtypes.push_back(dtype);
159 } else {
160 return errors::FailedPrecondition(
161 "Use '' if want to use the type from graph.");
162 }
163 }
164 } else {
165 return errors::InvalidArgument(absl::StrCat(
166 "Length of input node array and data type doesn't match (#arrays ",
167 node_names.size(), ", #data_types ", node_dtypes.size(), ")"));
168 }
169
170 if (!node_shapes.empty() && node_names.size() != node_shapes.size()) {
171 return errors::InvalidArgument(absl::StrCat(
172 "Length of input node array and data shape doesn't match (#arrays ",
173 node_names.size(), ", #input_shapes ", node_shapes.size(), ")"));
174 }
175
176 // StringMap doesn't support reserve else reserve input map size here.
177 for (int i = 0, end = node_names.size(); i < end; i++) {
178 auto& name = node_names[i];
179 const string& type = used_node_dtypes[i];
180 if (name.empty()) continue;
181
182 auto it_inserted_pair = inputs->insert({name, {}});
183 if (!it_inserted_pair.second)
184 return errors::FailedPrecondition(
185 absl::StrCat("tensor ", name, " is repeated in the arrays flag"));
186
187 ArrayInfo& info = it_inserted_pair.first->second;
188 // Splitting the type and subtype into parts
189 std::vector<std::string> parts =
190 absl::StrSplit(type, absl::ByAnyChar("()"));
191 // If type has subtypes then parts[0] = type, parts[1] = subtypes,
192 // parts[2] = ""
193 if (parts.size() != 3 && parts.size() != 1) {
194 return errors::InvalidArgument("Invalid type '", type, "'");
195 } else if (parts.size() == 3) {
196 // First part is the type, second is the subtype
197 ArrayInfo::SubTypeInfo subtype;
198 TF_RETURN_IF_ERROR(HandleSubtype(parts[1], &subtype));
199 info.subtypes.push_back(std::move(subtype));
200 }
201 if (!DataType_Parse(parts[0], &info.imported_dtype)) {
202 return errors::FailedPrecondition(
203 absl::StrCat("Invalid node type '", node_dtypes[i], "'"));
204 }
205
206 if (!node_shapes.empty()) {
207 if (!node_shapes[i].has_value()) {
208 info.shape.set_unknown_rank(true);
209 continue;
210 }
211 for (auto& dim : node_shapes[i].getValue()) {
212 info.shape.add_dim()->set_size(dim);
213 }
214 }
215 }
216 return OkStatus();
217 }
218
ParseNodeShapes(absl::string_view shapes_str,std::vector<llvm::Optional<std::vector<int>>> & shapes_vector)219 Status ParseNodeShapes(
220 absl::string_view shapes_str,
221 std::vector<llvm::Optional<std::vector<int>>>& shapes_vector) {
222 shapes_vector.clear();
223 if (!shapes_str.empty()) {
224 std::vector<string> node_shapes_str = absl::StrSplit(shapes_str, ':');
225 for (int i = 0; i < node_shapes_str.size(); i++) {
226 if (node_shapes_str[i] == "*") {
227 shapes_vector.push_back(llvm::None);
228 continue;
229 }
230 TF_ASSIGN_OR_RETURN(auto shape, ParseShapeStr(node_shapes_str[i]));
231 shapes_vector.push_back(std::move(shape));
232 }
233 }
234 return OkStatus();
235 }
236
ParseNodeNames(absl::string_view names_str,std::vector<std::string> & names_vector)237 Status ParseNodeNames(absl::string_view names_str,
238 std::vector<std::string>& names_vector) {
239 names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty());
240 return OkStatus();
241 }
242
ParseDTypesHelper(absl::string_view data_types_str)243 static StatusOr<std::vector<std::string>> ParseDTypesHelper(
244 absl::string_view data_types_str) {
245 bool inside_subtype = false;
246 int cur_pos = 0;
247 std::vector<std::string> dtypes;
248 for (auto& it : llvm::enumerate(data_types_str)) {
249 char c = it.value();
250 int i = it.index();
251 // Skip parsing the subtypes of a type
252 if (c == '(') {
253 if (inside_subtype) {
254 return errors::FailedPrecondition(
255 absl::StrCat("Syntax error: unexpected '(' in input data types: '",
256 data_types_str, "'"));
257 }
258 inside_subtype = true;
259 } else if (c == ')') {
260 if (!inside_subtype) {
261 return errors::FailedPrecondition(
262 absl::StrCat("Syntax error: unexpected ')' in input data types: '",
263 data_types_str, "'"));
264 }
265 inside_subtype = false;
266 }
267 if (inside_subtype) continue;
268 if (c == ',') {
269 dtypes.push_back(
270 std::string(data_types_str.substr(cur_pos, i - cur_pos)));
271 cur_pos = i + 1;
272 }
273 }
274 if (inside_subtype) {
275 return errors::FailedPrecondition(
276 absl::StrCat("Syntax error: expected a ')' in input data types '",
277 data_types_str, "'"));
278 }
279 if (!data_types_str.empty()) {
280 dtypes.push_back(std::string(
281 data_types_str.substr(cur_pos, data_types_str.size() - cur_pos)));
282 }
283 return dtypes;
284 }
285
ParseNodeDataTypes(absl::string_view data_types_str,std::vector<std::string> & data_type_vector)286 Status ParseNodeDataTypes(absl::string_view data_types_str,
287 std::vector<std::string>& data_type_vector) {
288 data_type_vector.clear();
289 if (!data_types_str.empty()) {
290 TF_ASSIGN_OR_RETURN(data_type_vector, ParseDTypesHelper(data_types_str));
291 }
292 return OkStatus();
293 }
294
295 } // namespace tensorflow
296