1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/importexport/functiondef_import.h"
17
18 #include <string>
19
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringMap.h"
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/IR/OperationSupport.h" // from @llvm-project
27 #include "tensorflow/core/framework/attr_value.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/ir/dialect.h"
31 #include "tensorflow/core/ir/importexport/convert_attributes.h"
32 #include "tensorflow/core/ir/importexport/convert_types.h"
33 #include "tensorflow/core/ir/ops.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/status.h"
36 #include "tensorflow/core/platform/statusor.h"
37
38 using tensorflow::AttrValue;
39 using tensorflow::FunctionDef;
40 using tensorflow::NodeDef;
41 using tensorflow::OpDef;
42 using tensorflow::OpDef_AttrDef;
43 using tensorflow::Status;
44 using tensorflow::StatusOr;
45 using tensorflow::errors::InvalidArgument;
46 using tensorflow::protobuf::RepeatedPtrField;
47
48 #define DEBUG_TYPE "graphdef-to-mlir"
49
50 namespace mlir {
51 namespace tfg {
52 namespace {
53
54 class ValueMapManager {
55 public:
ValueMapManager(llvm::StringMap<llvm::StringMap<SmallVector<Value,1>>> & values_map,OpBuilder & builder,OperationName mlir_placeholder,Type placeholder_ty,Type control_ty,Location loc)56 ValueMapManager(
57 llvm::StringMap<llvm::StringMap<SmallVector<Value, 1>>>& values_map,
58 OpBuilder& builder, OperationName mlir_placeholder, Type placeholder_ty,
59 Type control_ty, Location loc)
60 : values_map_(values_map),
61 builder_(builder),
62 loc_(loc),
63 mlir_placeholder_(mlir_placeholder),
64 placeholder_ty_(placeholder_ty),
65 control_ty_(control_ty) {}
66
DefineOperation(Operation * op,StringRef node_name)67 Status DefineOperation(Operation* op, StringRef node_name) {
68 llvm::StringMap<SmallVector<Value, 1>>& op_info = values_map_[node_name];
69 SmallVector<Value, 1>& base_operation = op_info["^"];
70 // Replace placeholders.
71 if (!base_operation.empty()) {
72 Operation* placeholder = base_operation[0].getDefiningOp();
73 if (!placeholder ||
74 placeholder->getName().getStringRef() != "tfg.__mlir_placeholder")
75 return InvalidArgument(absl::StrCat(
76 "Duplicated node (or function argument) with the same name: `",
77 node_name.str(), "`"));
78
79 op->moveBefore(placeholder);
80 placeholder->replaceAllUsesWith(op);
81 placeholder->erase();
82 base_operation.clear();
83 }
84 base_operation.push_back(op->getResult(1));
85 base_operation.push_back(op->getResult(0));
86 return ::tensorflow::OkStatus();
87 }
88
GetValueOrCreatePlaceholder(StringRef full_name)89 Value GetValueOrCreatePlaceholder(StringRef full_name) {
90 StringRef node_name;
91 StringRef output_name = "";
92 bool is_control_dep = full_name[0] == '^';
93 int output_num = 0;
94 if (is_control_dep) full_name = full_name.drop_front();
95 {
96 size_t colon_sep = full_name.find_first_of(':');
97 if (colon_sep == StringRef::npos) {
98 node_name = full_name;
99 } else {
100 node_name = full_name.take_front(colon_sep);
101 output_name = full_name.drop_front(colon_sep + 1);
102 }
103 colon_sep = output_name.find_last_of(':');
104 if (colon_sep != StringRef::npos) {
105 // NOLINTNEXTLINE: type matching the API taking a reference.
106 unsigned long long value;
107 if (!llvm::getAsUnsignedInteger(output_name.drop_front(colon_sep + 1),
108 10, value))
109 output_num = value;
110 output_name = output_name.take_front(colon_sep);
111 }
112 }
113
114 llvm::StringMap<SmallVector<Value, 1>>& op_info = values_map_[node_name];
115 SmallVector<Value, 1>& base_operation = op_info["^"];
116 if (base_operation.empty()) {
117 OperationState state(loc_, mlir_placeholder_);
118 state.addAttribute(TFGraphDialect::getNameAttrKey(),
119 builder_.getStringAttr(node_name));
120 state.types.push_back(placeholder_ty_);
121 state.types.push_back(control_ty_);
122 Operation* placeholder = builder_.create(state);
123 base_operation.push_back(placeholder->getResult(1));
124 base_operation.push_back(placeholder->getResult(0));
125 }
126 if (is_control_dep) return base_operation[0];
127 SmallVector<Value, 1>& value_info = op_info[output_name];
128 if (value_info.size() <= output_num)
129 value_info.resize(output_num + 1, Value{});
130 if (!value_info[output_num]) {
131 // Create a tfg.get_result for this output.
132 value_info[output_num] = builder_.create<GetResultOp>(
133 loc_, base_operation[1], output_name, output_num);
134 }
135 return value_info[output_num];
136 }
137
138 private:
139 llvm::StringMap<llvm::StringMap<SmallVector<Value, 1>>>& values_map_;
140 OpBuilder& builder_;
141 Location loc_;
142 OperationName mlir_placeholder_;
143 Type placeholder_ty_;
144 Type control_ty_;
145 };
146
147 // Convert the list of `nodes` one by one into MLIR Operations using the
148 // provided OpBuilder.
149 // The provided `nodes_map` will be populated with a mapping from the node name
150 // to the result count and the Operation.
151 // The supplied `args_map` is looked up for Function arguments when an entry
152 // cannot be found in the nodes_map.
ImportNodes(ValueMapManager value_manager,const RepeatedPtrField<NodeDef> & nodes,OpBuilder & builder)153 Status ImportNodes(ValueMapManager value_manager,
154 const RepeatedPtrField<NodeDef>& nodes, OpBuilder& builder) {
155 Location unknown_loc = builder.getUnknownLoc();
156 MLIRContext* context = builder.getContext();
157
158 Type placeholder_ty = OpaqueTensorType::get(context);
159 Type control_ty = ControlType::get(context);
160 TFGraphDialect* tfgDialect =
161 cast<TFGraphDialect>(context->getLoadedDialect("tfg"));
162 StringAttr device_attr = tfgDialect->getDeviceAttrIdentifier();
163 StringAttr name_attr = tfgDialect->getNameAttrIdentifier();
164 StringAttr fulltype_attr = tfgDialect->getFullTypeAttrIdentifier();
165 // Process every node and create a matching MLIR operation
166 for (const NodeDef& node : nodes) {
167 DVLOG(1) << "Processing node " << node.name() << "\n";
168 if (node.op().empty()) return InvalidArgument("empty op type");
169 OperationState state(unknown_loc, absl::StrCat("tfg.", node.op()));
170 // Fetch the inputs, creating placeholder if an input hasn't been visited.
171 for (const std::string& input : node.input()) {
172 if (input.empty())
173 return InvalidArgument("Node '", node.name(), "' has an empty input");
174 state.operands.push_back(
175 value_manager.GetValueOrCreatePlaceholder(input));
176 }
177 // Retrieve the entry in the nodes_map for this node and infer the result
178 // count from what was inferred during the first traversal above.
179 state.types.push_back(placeholder_ty);
180 state.types.push_back(control_ty);
181 // Handle attributes.
182 for (const auto& namedAttr : node.attr()) {
183 const std::string& name = namedAttr.first;
184 const AttrValue& tf_attr = namedAttr.second;
185 TF_ASSIGN_OR_RETURN(Attribute attr,
186 ConvertAttributeValue(tf_attr, builder));
187 state.addAttribute(name, attr);
188 }
189 if (!node.device().empty())
190 state.addAttribute(device_attr, StringAttr::get(context, node.device()));
191 if (!node.name().empty())
192 state.addAttribute(name_attr, StringAttr::get(context, node.name()));
193 if (node.has_experimental_type()) {
194 TF_ASSIGN_OR_RETURN(tf_type::FullTypeAttr type,
195 ConvertAttribute(node.experimental_type(), builder));
196 state.addAttribute(fulltype_attr, type);
197 }
198
199 Operation* op = builder.create(state);
200
201 StringRef node_name = node.name();
202 {
203 size_t colon_sep = node_name.find_first_of(':');
204 if (colon_sep != StringRef::npos)
205 node_name = node_name.take_front(colon_sep);
206 }
207 TF_RETURN_IF_ERROR(value_manager.DefineOperation(op, node_name));
208 }
209 // We don't expect any placeholder left at this point, fail if any.
210 for (Operation& op : *builder.getInsertionBlock()) {
211 if (op.getName().getStringRef() == "tfg.__mlir_placeholder") {
212 return InvalidArgument(absl::StrCat(
213 "Couldn't import graph: placeholder left ",
214 op.getAttrOfType<StringAttr>(name_attr).getValue().str()));
215 }
216 }
217 return ::tensorflow::OkStatus();
218 }
219
ConvertArgDefAttributes(const OpDef::ArgDef & arg,Builder builder)220 tensorflow::StatusOr<NamedAttrList> ConvertArgDefAttributes(
221 const OpDef::ArgDef& arg, Builder builder) {
222 NamedAttrList input_attrs;
223 StringAttr arg_name = builder.getStringAttr(arg.name());
224 input_attrs.set("tfg.name", arg_name);
225 if (!arg.description().empty())
226 input_attrs.append("tfg.description",
227 builder.getStringAttr(arg.description()));
228
229 Type input_type;
230 if (arg.type() != tensorflow::DT_INVALID) {
231 TF_RETURN_IF_ERROR(ConvertDataType(arg.type(), builder, &input_type));
232 input_attrs.append("tfg.type", TypeAttr::get(input_type));
233 }
234 if (!arg.type_attr().empty())
235 input_attrs.append("tfg.type_attr", builder.getStringAttr(arg.type_attr()));
236 if (!arg.number_attr().empty())
237 input_attrs.append("tfg.number_attr",
238 builder.getStringAttr(arg.number_attr()));
239 if (!arg.type_list_attr().empty())
240 input_attrs.append("tfg.type_list_attr",
241 builder.getStringAttr(arg.type_list_attr()));
242 if (arg.handle_data_size()) {
243 TF_ASSIGN_OR_RETURN(Attribute handle_data,
244 ConvertHandleData(builder, arg.handle_data()));
245 input_attrs.append("tfg.handle_data", handle_data);
246 }
247 if (arg.is_ref()) input_attrs.append("tfg.is_ref", builder.getUnitAttr());
248 if (arg.has_experimental_full_type()) {
249 TF_ASSIGN_OR_RETURN(
250 tf_type::FullTypeAttr type,
251 ConvertAttribute(arg.experimental_full_type(), builder));
252 input_attrs.append("tfg.experimental_full_type", type);
253 }
254 return input_attrs;
255 }
256
257 // Import the given `func` and inser the resulting `GraphFunc`
258 // operation using the provided `builder`. The `nodes_map` and `args_map` are
259 // used as scratchpad for the import inside this function. The `gradients` maps
260 // is provided to
ImportGenericFunction(GraphFuncOp func_op,const FunctionDef & func,llvm::StringMap<llvm::StringMap<SmallVector<Value,1>>> & values_map,OpBuilder & builder)261 Status ImportGenericFunction(
262 GraphFuncOp func_op, const FunctionDef& func,
263 llvm::StringMap<llvm::StringMap<SmallVector<Value, 1>>>& values_map,
264 OpBuilder& builder) {
265 const OpDef& signature = func.signature();
266 Location unknown_loc = builder.getUnknownLoc();
267 MLIRContext* context = builder.getContext();
268
269 NamedAttrList attrs;
270 DictionaryAttr func_attrs = builder.getDictionaryAttr({});
271 if (signature.name().empty())
272 return InvalidArgument("generic function without a name");
273 attrs.append("sym_name", builder.getStringAttr(signature.name()));
274 attrs.append("generic", builder.getUnitAttr());
275 if (!signature.description().empty())
276 attrs.append("description", builder.getStringAttr(signature.description()));
277 if (signature.is_stateful())
278 attrs.append("is_stateful", builder.getUnitAttr());
279 if (signature.control_output_size()) {
280 SmallVector<Attribute> control_outputs;
281 for (const std::string& output : signature.control_output())
282 control_outputs.push_back(builder.getStringAttr(output));
283 attrs.append("control_output", builder.getArrayAttr(control_outputs));
284 }
285 {
286 NamedAttrList attr_defs;
287 for (const OpDef_AttrDef& attr : signature.attr()) {
288 NamedAttrList attr_def;
289 if (attr.name().empty())
290 return InvalidArgument("Missing name for function attribute");
291 if (!attr.type().empty())
292 attr_def.append(builder.getNamedAttr(
293 "function_type", builder.getStringAttr(attr.type())));
294 if (attr.has_default_value()) {
295 TF_ASSIGN_OR_RETURN(Attribute attr, ConvertAttributeValue(
296 attr.default_value(), builder));
297 attr_def.append(builder.getNamedAttr("default_value", attr));
298 }
299 if (!attr.description().empty())
300 attr_def.append(builder.getNamedAttr(
301 "description", builder.getStringAttr(attr.description())));
302 if (attr.has_minimum() || attr.minimum())
303 attr_def.append(builder.getNamedAttr(
304 "minimum", builder.getI32IntegerAttr(attr.minimum())));
305 if (attr.has_allowed_values()) {
306 TF_ASSIGN_OR_RETURN(
307 Attribute attr,
308 ConvertAttributeValue(attr.allowed_values(), builder));
309 attr_def.append(builder.getNamedAttr("allowed_values", attr));
310 }
311 attr_defs.append(builder.getNamedAttr(
312 attr.name(), attr_def.getDictionary(builder.getContext())));
313 }
314 if (!attr_defs.empty()) {
315 func_attrs = attr_defs.getDictionary(builder.getContext());
316 attrs.append("tfg.func_attrs", func_attrs);
317 }
318 }
319
320 // The resource_arg_unique_id is a list of `pair<int, int>`, we import it
321 // as two arrays of integer right now.
322 if (func.resource_arg_unique_id_size()) {
323 SmallVector<int32_t> resource_arg_unique_ids_keys;
324 SmallVector<int32_t> resource_arg_unique_ids_values;
325 for (const auto& unique_id : func.resource_arg_unique_id()) {
326 resource_arg_unique_ids_keys.push_back(unique_id.first);
327 resource_arg_unique_ids_values.push_back(unique_id.second);
328 }
329 attrs.append("resource_arg_unique_ids_keys",
330 builder.getI32TensorAttr(resource_arg_unique_ids_keys));
331 attrs.append("resource_arg_unique_ids_values",
332 builder.getI32TensorAttr(resource_arg_unique_ids_values));
333 }
334
335 // Import the function attributes with a `tf.` prefix to match the current
336 // infrastructure expectations.
337 for (const auto& namedAttr : func.attr()) {
338 if (namedAttr.first.empty())
339 return InvalidArgument("Invalid function attribute name");
340 const std::string& name = "tf." + namedAttr.first;
341 const AttrValue& tf_attr = namedAttr.second;
342 TF_ASSIGN_OR_RETURN(Attribute attr,
343 ConvertAttributeValue(tf_attr, builder));
344 attrs.append(name, attr);
345 }
346 SmallString<8> arg_or_res_attr_name;
347 SmallString<8> sub_arg_attr_name;
348 // Iterate of the input in the signature. Each input will correspond to
349 // potentially multiple arguments because of how the OpDef allows repeated
350 // arguments controlled by `number_attr` for example.
351 // We populate the `arg_names` vector with the name of each input at each
352 // position, and `arg_types` with the matching type.
353 int arg_num = 0;
354 SmallVector<StringRef> arg_names;
355 SmallVector<Type> arg_types;
356 SmallVector<Attribute> args_attrs;
357 SmallVector<Attribute> res_attrs;
358 for (const auto& enumerated_input : llvm::enumerate(signature.input_arg())) {
359 const OpDef::ArgDef& input = enumerated_input.value();
360 TF_ASSIGN_OR_RETURN(NamedAttrList input_attrs,
361 ConvertArgDefAttributes(input, builder));
362 auto it = func.arg_attr().find(enumerated_input.index());
363 if (it != func.arg_attr().end()) {
364 NamedAttrList arg_attr;
365 for (const auto& named_attr : it->second.attr()) {
366 TF_ASSIGN_OR_RETURN(Attribute attr,
367 ConvertAttributeValue(named_attr.second, builder));
368 arg_attr.append(named_attr.first, attr);
369 }
370 input_attrs.append("tfg.arg_attrs",
371 arg_attr.getDictionary(builder.getContext()));
372 }
373 arg_names.push_back(builder.getStringAttr(input.name()).getValue());
374 arg_types.push_back(OpaqueTensorType::get(context));
375 args_attrs.push_back(input_attrs.getDictionary(context));
376 args_attrs.push_back(NamedAttrList{}.getDictionary(context));
377 arg_num++;
378 }
379 attrs.push_back(
380 builder.getNamedAttr(function_interface_impl::getArgDictAttrName(),
381 builder.getArrayAttr(args_attrs)));
382
383 // Process the results attributes now.
384 int res_num = 0;
385 for (const OpDef::ArgDef& output : signature.output_arg()) {
386 TF_ASSIGN_OR_RETURN(NamedAttrList output_attrs,
387 ConvertArgDefAttributes(output, builder));
388 res_attrs.push_back(output_attrs.getDictionary(context));
389 ++res_num;
390 }
391 // Process the control output metadata and store them as attributes.
392 for (const std::string& output : signature.control_output()) {
393 NamedAttrList output_attrs;
394 output_attrs.append("tfg.name", builder.getStringAttr(output));
395 res_attrs.push_back(output_attrs.getDictionary(context));
396 ++res_num;
397 }
398 attrs.push_back(
399 builder.getNamedAttr(function_interface_impl::getResultDictAttrName(),
400 builder.getArrayAttr(res_attrs)));
401
402 values_map.clear();
403 Block* body = new Block();
404 func_op.body().push_back(body);
405 Type control_ty = ControlType::get(context);
406 // Create the block arguments and populate the `values_map` with the matching
407 // input names.
408 for (auto type_and_name : llvm::zip(arg_types, arg_names)) {
409 Value arg = body->addArgument(std::get<0>(type_and_name), unknown_loc);
410 llvm::StringMap<SmallVector<Value, 1>>& values =
411 values_map[std::get<1>(type_and_name)];
412 Value ctl = body->addArgument(control_ty, unknown_loc);
413 values[""].push_back(arg);
414 values["^"].push_back(ctl);
415 }
416
417 // Pre-populate the nodes_map with the needed slots for the return.
418 OpBuilder body_builder = OpBuilder::atBlockEnd(body);
419 // We use placeholders during the import to create "fake" operations to break
420 // cycles: we need operands to feed to the users.
421 OperationName mlir_placeholder("tfg.__mlir_placeholder", context);
422 Type placeholder_ty = OpaqueTensorType::get(context);
423 ValueMapManager value_manager(values_map, body_builder, mlir_placeholder,
424 placeholder_ty, control_ty, unknown_loc);
425
426 // Import the function body here, after this we have a function with all
427 // the nodes, and the nodes_map contains the mapping from node_name to actual
428 // MLIR Operations.
429 TF_RETURN_WITH_CONTEXT_IF_ERROR(
430 ImportNodes(value_manager, func.node_def(), body_builder),
431 " when importing function ", func.signature().name());
432
433 // After the body, the final part is to setup the return. It comes in two
434 // parts: the `ret` field from the FunctionDef for the regular output and the
435 // `control_ret` field for the control output.
436 //
437 // Because `ret` and `control_ret` aren't ordered, there is an indirection to
438 // the FunctionDef signature to retrieve the position of each `ret` and
439 // `control_ret` entry by name. We compute this mapping from the name of an
440 // output to the position in the result array first.
441 res_num = 0;
442 llvm::StringMap<int> output_name_to_position;
443 for (const OpDef::ArgDef& output : signature.output_arg()) {
444 if (output_name_to_position.count(output.name()))
445 return InvalidArgument("Duplicated output_arg entry", output.name());
446 output_name_to_position[output.name()] = res_num;
447 ++res_num;
448 }
449 res_num = 0;
450 llvm::StringMap<int> control_output_to_position;
451 for (const std::string& output : signature.control_output()) {
452 if (control_output_to_position.count(output))
453 return InvalidArgument("Duplicated control_output entry", output);
454 control_output_to_position[output] = res_num;
455 ++res_num;
456 }
457
458 // We pre-allocate the array of operands and populate it using the
459 // `output_name_to_position` and `control_output_to_position` populated
460 // previously.
461 SmallVector<Value> ret_vals(func.ret_size() + func.control_ret_size(),
462 Value());
463 for (const auto& ret_val : func.ret()) {
464 auto position = output_name_to_position.find(ret_val.first);
465 if (position == output_name_to_position.end()) {
466 return InvalidArgument(
467 "Can't import function, returned value references unknown output "
468 "argument ",
469 ret_val.first);
470 }
471 if (ret_val.second.empty()) {
472 return InvalidArgument("Function '", func.signature().name(),
473 "' has empty result name");
474 }
475 ret_vals[position->second] =
476 value_manager.GetValueOrCreatePlaceholder(ret_val.second);
477 }
478 for (const auto& ret_val : func.control_ret()) {
479 auto position = control_output_to_position.find(ret_val.first);
480 if (position == control_output_to_position.end()) {
481 return InvalidArgument(
482 "Can't import function, returned value references unknown output "
483 "argument ",
484 ret_val.first);
485 }
486 if (ret_val.second.empty()) {
487 return InvalidArgument("Function '", func.signature().name(),
488 "' has empty control result name");
489 }
490 Value result = value_manager.GetValueOrCreatePlaceholder(
491 (Twine("^") + ret_val.second).str());
492 if (!result.getType().isa<ControlType>())
493 return InvalidArgument("failed to map returned value ", ret_val.second,
494 ", isn't a control output");
495 ret_vals[func.ret_size() + position->second] = result;
496 }
497 // Check that all the of the return operands have been populated.
498 for (auto& indexed_val : llvm::enumerate(ret_vals)) {
499 if (indexed_val.value()) continue;
500 return InvalidArgument(
501 "Failed to import function, missing output for position ",
502 indexed_val.index());
503 }
504 MutableArrayRef<Value> operands = ret_vals;
505 ReturnOp ret_op = body_builder.create<ReturnOp>(
506 unknown_loc, operands.slice(0, func.ret_size()),
507 operands.slice(func.ret_size()));
508
509 // Now that we have all the types, set the function signature as the
510 // "function_type" attribute.
511 {
512 SmallVector<Type> arg_types_with_ctl;
513 for (Type type : arg_types) {
514 arg_types_with_ctl.push_back(type);
515 arg_types_with_ctl.push_back(control_ty);
516 }
517 attrs.append("function_type",
518 TypeAttr::get(builder.getFunctionType(
519 arg_types_with_ctl, ret_op.getOperandTypes())));
520 }
521 func_op->setAttrs(attrs);
522 return ::tensorflow::OkStatus();
523 }
524
525 } // namespace
526
ConvertGenericFunction(GraphFuncOp func_op,const FunctionDef & func,OpBuilder & builder)527 Status ConvertGenericFunction(GraphFuncOp func_op, const FunctionDef& func,
528 OpBuilder& builder) {
529 llvm::StringMap<llvm::StringMap<SmallVector<Value, 1>>> values_map;
530 return ImportGenericFunction(func_op, func, values_map, builder);
531 }
532
533 } // namespace tfg
534 } // namespace mlir
535