1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/importexport/graphdef_export.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "absl/strings/str_cat.h"
22 #include "llvm/ADT/PointerUnion.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
28 #include "mlir/IR/Location.h" // from @llvm-project
29 #include "mlir/IR/MLIRContext.h" // from @llvm-project
30 #include "mlir/IR/Operation.h" // from @llvm-project
31 #include "mlir/IR/SymbolTable.h" // from @llvm-project
32 #include "mlir/IR/Threading.h" // from @llvm-project
33 #include "mlir/IR/Value.h" // from @llvm-project
34 #include "mlir/Support/LLVM.h" // from @llvm-project
35 #include "tensorflow/core/framework/attr_value.pb.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function.pb.h"
38 #include "tensorflow/core/framework/graph.pb.h"
39 #include "tensorflow/core/framework/node_def.pb.h"
40 #include "tensorflow/core/framework/op.h"
41 #include "tensorflow/core/framework/op_def.pb.h"
42 #include "tensorflow/core/framework/op_def_builder.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/ir/dialect.h"
46 #include "tensorflow/core/ir/importexport/convert_attributes.h"
47 #include "tensorflow/core/ir/importexport/convert_types.h"
48 #include "tensorflow/core/ir/importexport/functiondef_export.h"
49 #include "tensorflow/core/ir/ops.h"
50 #include "tensorflow/core/ir/types/dialect.h"
51 #include "tensorflow/core/platform/errors.h"
52 #include "tensorflow/core/platform/status.h"
53 #include "tensorflow/core/platform/statusor.h"
54
55 using tensorflow::AttrValue;
56 using tensorflow::DataType;
57 using tensorflow::FunctionDef;
58 using tensorflow::FunctionLibraryDefinition;
59 using tensorflow::GradientDef;
60 using tensorflow::GraphDef;
61 using tensorflow::NodeDef;
62 using tensorflow::OpDef;
63 using tensorflow::OpRegistrationData;
64 using tensorflow::OpRegistry;
65 using tensorflow::Status;
66 using tensorflow::StatusOr;
67 using tensorflow::VersionDef;
68 using tensorflow::errors::InvalidArgument;
69
70 namespace mlir {
71 namespace tfg {
72 namespace {
73 // This class implements an exporter for TFG directly to GraphDef.
74 class GraphDefExporter {
75 public:
GraphDefExporter(TFGraphDialect * dialect,const OpRegistry & registry,llvm::PointerUnion<SymbolTable *,const FunctionLibraryDefinition * > function_table)76 GraphDefExporter(
77 TFGraphDialect *dialect, const OpRegistry ®istry,
78 llvm::PointerUnion<SymbolTable *, const FunctionLibraryDefinition *>
79 function_table)
80 : ctx_(dialect->getContext()),
81 dialect_(dialect),
82 registry_(registry),
83 function_table_(function_table) {}
84
85 // Export a TFG module to GraphDef. The module may contain at most one GraphOp
86 // and only GraphFuncOp otherwise.
87 Status ExportToGraphDef(ModuleOp module, GraphDef *graph);
88
89 // Export a TFG graph function to a FunctionDef. If the function has a
90 // gradient, add it to the graph afterwards to preserve thread-safety.
91 StatusOr<Optional<GradientDef>> ExportFunction(GraphFuncOp func,
92 FunctionDef *def);
93
94 private:
95 // Export just the input and outputs of a function signature. When
96 // fully-qualifying result names, this must be done before any nodes are
97 // Convert argument attributes to an ArgDef.
98 StatusOr<OpDef::ArgDef> ConvertArgumentAttributes(DictionaryAttr attrs);
99
100 // Convert a TFG op to a node. When converting a function, fully-qualified
101 // result names must be used.
102 Status ConvertOperation(Operation *op, NodeDef *node, bool is_func);
103
104 // Get the name associated with a value.
105 StatusOr<std::string> GetEdgeName(Value value, bool is_func);
106
107 // Get the name and index of an output segment to fully qualify result names.
108 // This requires querying the op registry.
109 StatusOr<std::pair<StringRef, unsigned>> GetOutputSegment(OpResult result);
110
111 // Get the name of a function argument from a function in the symbol table.
112 StatusOr<StringRef> GetFunctionOutputName(unsigned result_idx,
113 const std::string &op_name,
114 SymbolTable &table);
115 // Get the name of a function argument from a function in the library.
116 static StatusOr<StringRef> GetFunctionOutputName(
117 unsigned result_idx, const std::string &op_name,
118 const FunctionLibraryDefinition &library);
119
120 // The current MLIR context.
121 MLIRContext *ctx_;
122 // The TFG dialect instance.
123 TFGraphDialect *dialect_;
124 // The TF op registry to use.
125 const OpRegistry ®istry_;
126 // A lookup table for functions.
127 llvm::PointerUnion<SymbolTable *, const FunctionLibraryDefinition *>
128 function_table_;
129 };
130 } // namespace
131
132 // Returns a validated graph to export. A TFG module is valid for export if it
133 // contains at most one graph operation and any number of graph functions.
134 // Otherwise, returns an error.
ValidateModuleForExport(ModuleOp module)135 static StatusOr<GraphOp> ValidateModuleForExport(ModuleOp module) {
136 GraphOp graph_op;
137 for (Operation &op : *module.getBody()) {
138 if (isa<GraphFuncOp>(op)) continue;
139 if (auto new_graph_op = dyn_cast<GraphOp>(op)) {
140 if (graph_op) {
141 return InvalidArgument(
142 "Can't export module with two different tfg.graph");
143 }
144 graph_op = new_graph_op;
145 continue;
146 }
147 return InvalidArgument(
148 "Can't export module with other ops than tfg.graph or tfg.func, has: ",
149 op.getName().getStringRef().str());
150 }
151 return graph_op;
152 }
153
154 // Converts a version attribute to VersionDef.
ExportVersionAttr(VersionAttr attr,VersionDef * version)155 static void ExportVersionAttr(VersionAttr attr, VersionDef *version) {
156 version->set_producer(attr.getProducer());
157 version->set_min_consumer(attr.getMinConsumer());
158 for (int32_t bad_consumer : attr.getBadConsumers())
159 version->add_bad_consumers(bad_consumer);
160 }
161
ExportToGraphDef(ModuleOp module,GraphDef * graph)162 Status GraphDefExporter::ExportToGraphDef(ModuleOp module, GraphDef *graph) {
163 TF_ASSIGN_OR_RETURN(GraphOp graph_op, ValidateModuleForExport(module));
164 if (graph_op) {
165 ExportVersionAttr(graph_op.version(), graph->mutable_versions());
166 for (Operation &op : *graph_op.getBody()) {
167 TF_RETURN_IF_ERROR(ConvertOperation(&op, graph->mutable_node()->Add(),
168 /*is_func=*/false));
169 }
170 }
171
172 const auto convert_func = [this](GraphFuncOp func, FunctionDef *def,
173 Optional<GradientDef> &gradient) {
174 // Generic functions are not on the hot path and skip the conversion to
175 // Graph so just call the existing exporter.
176 if (func.generic()) {
177 TF_ASSIGN_OR_RETURN(*def, ConvertGenericFunctionToFunctionDef(func));
178 } else {
179 TF_ASSIGN_OR_RETURN(gradient, ExportFunction(func, def));
180 }
181 return ::tensorflow::OkStatus();
182 };
183
184 // TODO(jeffniu): Don't export functions in parallel if there are too few or
185 // they are too small.
186 if (ctx_->isMultithreadingEnabled()) {
187 ctx_->enterMultiThreadedExecution();
188 auto exit =
189 llvm::make_scope_exit([this] { ctx_->exitMultiThreadedExecution(); });
190
191 // Prepare the arguments to parallel for each.
192 struct Argument {
193 GraphFuncOp func;
194 FunctionDef *def;
195 Status status;
196 Optional<GradientDef> gradient;
197 };
198 std::vector<Argument> args;
199 for (auto func : module.getOps<GraphFuncOp>())
200 args.push_back(Argument{func, graph->mutable_library()->add_function()});
201 const auto process_func = [&convert_func](Argument &arg) {
202 arg.status = convert_func(arg.func, arg.def, arg.gradient);
203 return success(arg.status.ok());
204 };
205
206 // Execute the exports in parallel.
207 if (failed(failableParallelForEach(ctx_, args, process_func))) {
208 Status result;
209 for (const Argument &arg : args) {
210 result.Update(arg.status);
211 }
212 return result;
213 }
214 } else {
215 for (auto func : module.getOps<GraphFuncOp>()) {
216 Optional<GradientDef> gradient;
217 TF_RETURN_IF_ERROR(convert_func(
218 func, graph->mutable_library()->add_function(), gradient));
219 if (gradient)
220 *graph->mutable_library()->add_gradient() = std::move(*gradient);
221 }
222 }
223
224 return ::tensorflow::OkStatus();
225 }
226
227 // The only dialect attributes allowed have the "tf." prefix. This is a slightly
228 // faster check that an attribute is a dialect attribute.
IsDialectAttr(const NamedAttribute & attr)229 static bool IsDialectAttr(const NamedAttribute &attr) {
230 return attr.getName().getValue().startswith("tf.");
231 }
232
233 // Export the given attribute list.
ConvertAttributes(tensorflow::protobuf::Map<std::string,AttrValue> * map,ArrayRef<NamedAttribute> attrs)234 static Status ConvertAttributes(
235 tensorflow::protobuf::Map<std::string, AttrValue> *map,
236 ArrayRef<NamedAttribute> attrs) {
237 for (const NamedAttribute &attr : attrs) {
238 if (!IsDialectAttr(attr)) continue;
239 StringRef name = attr.getName().strref().drop_front(/*strlen("tf.")=*/3);
240 TF_ASSIGN_OR_RETURN((*map)[name.str()], ConvertAttribute(attr.getValue()));
241 }
242 return ::tensorflow::OkStatus();
243 }
244
ExportFunction(GraphFuncOp func,FunctionDef * def)245 StatusOr<Optional<GradientDef>> GraphDefExporter::ExportFunction(
246 GraphFuncOp func, FunctionDef *def) {
247 std::string func_name = func.sym_name().str();
248
249 // TODO(jeffniu): Exploit the sorted order of the function attributes.
250
251 // Get a gradient, if there is one.
252 Optional<GradientDef> gradient;
253 if (Optional<StringRef> gradient_name = func.gradient()) {
254 gradient.emplace();
255 gradient->set_gradient_func(gradient_name->str());
256 gradient->set_function_name(func_name);
257 }
258
259 // Convert the first-class attributes.
260 OpDef *signature = def->mutable_signature();
261 signature->set_name(func_name);
262 if (Optional<StringRef> description = func.description())
263 signature->set_description(description->str());
264 signature->set_is_stateful(func.is_stateful());
265
266 if (DenseIntElementsAttr keys = func.resource_arg_unique_ids_keysAttr()) {
267 DenseIntElementsAttr values = func.resource_arg_unique_ids_valuesAttr();
268 if (!values) {
269 return InvalidArgument(
270 "'resource_arg_unique_ids_keys' is present but "
271 "'resource_arg_unique_ids_values' is missing");
272 }
273 if (keys.size() != values.size()) {
274 return InvalidArgument(
275 "'resource_arg_unique_ids_keys' is not the same size as "
276 "'resource_arg_unique_ids_values'");
277 }
278 auto *id_map = def->mutable_resource_arg_unique_id();
279 for (auto kv :
280 llvm::zip(keys.getValues<int32_t>(), values.getValues<int32_t>()))
281 (*id_map)[std::get<0>(kv)] = std::get<1>(kv);
282 }
283
284 // Convert other attributes with the "tf." prefix.
285 TF_RETURN_IF_ERROR(ConvertAttributes(def->mutable_attr(), func->getAttrs()));
286
287 // Convert the arguments.
288 for (int i = 0, e = func.getNumArguments(); i < e; i += 2) {
289 auto attrs = func.arg_attrs().getValue()[i].cast<DictionaryAttr>();
290 TF_ASSIGN_OR_RETURN(OpDef::ArgDef &arg = *signature->add_input_arg(),
291 ConvertArgumentAttributes(attrs));
292 DataType dtype;
293 TF_RETURN_IF_ERROR(ConvertToDataType(
294 func.getArgument(i).getType().cast<TensorType>().getElementType(),
295 &dtype));
296 arg.set_type(dtype);
297 // Convert the attributes.
298 if (llvm::any_of(attrs, [](const NamedAttribute &attr) {
299 return IsDialectAttr(attr);
300 })) {
301 auto *map = (*def->mutable_arg_attr())[i / 2].mutable_attr();
302 TF_RETURN_IF_ERROR(ConvertAttributes(map, attrs.getValue()));
303 }
304 }
305
306 // Convert the results.
307 auto return_op = cast<ReturnOp>(func.getBody()->getTerminator());
308 for (auto it :
309 llvm::zip(func.getResultTypes(),
310 func.getAllResultAttrs().getAsRange<DictionaryAttr>(),
311 TFOp(return_op).getNonControlOperands())) {
312 TF_ASSIGN_OR_RETURN(OpDef::ArgDef &arg = *signature->add_output_arg(),
313 ConvertArgumentAttributes(std::get<1>(it)));
314 DataType dtype;
315 TF_RETURN_IF_ERROR(ConvertToDataType(
316 std::get<0>(it).cast<TensorType>().getElementType(), &dtype));
317 arg.set_type(dtype);
318 // Map the result.
319 TF_ASSIGN_OR_RETURN((*def->mutable_ret())[arg.name()],
320 GetEdgeName(std::get<2>(it), /*is_func=*/true));
321 }
322
323 // Convert the control results.
324 for (auto it :
325 llvm::zip(return_op.control_ret_attrs().getAsRange<DictionaryAttr>(),
326 TFOp(return_op).getControlOperands())) {
327 // The control result attributes contain only the name.
328 DictionaryAttr attrs = std::get<0>(it);
329 if (attrs.empty())
330 return InvalidArgument("Control result is missing 'tfg.name'");
331 assert(attrs.begin()->getName() == dialect_->getTfgNameAttrIdentifier());
332 std::string name = attrs.begin()->getValue().cast<StringAttr>().str();
333 signature->add_control_output(name);
334 // Map the control result.
335 TF_ASSIGN_OR_RETURN(std::string value_name,
336 GetEdgeName(std::get<1>(it), /*is_func=*/true));
337 // Add the control result name without '^'.
338 def->mutable_control_ret()->insert({std::move(name), value_name.substr(1)});
339 }
340
341 // Convert the body.
342 for (Operation &op : func.getBody()->without_terminator())
343 TF_RETURN_IF_ERROR(
344 ConvertOperation(&op, def->add_node_def(), /*is_func=*/true));
345
346 return gradient;
347 }
348
ConvertArgumentAttributes(DictionaryAttr attrs)349 StatusOr<OpDef::ArgDef> GraphDefExporter::ConvertArgumentAttributes(
350 DictionaryAttr attrs) {
351 OpDef::ArgDef arg;
352 auto name = attrs.getAs<StringAttr>(dialect_->getTfgNameAttrIdentifier());
353 if (!name) return InvalidArgument("argument is missing 'tfg.name'");
354 arg.set_name(name.str());
355 if (auto description =
356 attrs.getAs<StringAttr>(dialect_->getTfgDescriptionAttrIdentifier()))
357 arg.set_description(description.str());
358 arg.set_is_ref(!!attrs.get(dialect_->getTfgIsRefAttrIdentifier()));
359 TF_RETURN_IF_ERROR(ConvertHandleData(
360 attrs.getAs<ArrayAttr>(dialect_->getTfgHandleDataAttrIdentifier()),
361 &arg));
362 if (auto full_type = attrs.getAs<tf_type::FullTypeAttr>(
363 dialect_->getTfgFullTypeAttrIdentifier())) {
364 TF_ASSIGN_OR_RETURN(*arg.mutable_experimental_full_type(),
365 ConvertAttribute(full_type));
366 }
367 return arg;
368 }
369
370 // Converts a location to the debug information for the node def, if we find
371 // supported location, that is a top-level NameLoc or any NameLoc nested inside
372 // a FusedLoc. Other kind of location are ignored. If a NameLoc is of the form
373 // "name@func" we parse it and import the two appropriately.
ExtractExperimentalDebugInfoFromLocation(Location inst_loc,NodeDef::ExperimentalDebugInfo * debug_info)374 static void ExtractExperimentalDebugInfoFromLocation(
375 Location inst_loc, NodeDef::ExperimentalDebugInfo *debug_info) {
376 auto add_name_loc = [&](mlir::NameLoc name_loc) {
377 StringRef node, func;
378 std::tie(node, func) = name_loc.getName().strref().split('@');
379 debug_info->add_original_node_names(node.str());
380 if (!func.empty()) debug_info->add_original_func_names(func.str());
381 };
382 if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
383 for (Location loc : fused.getLocations())
384 if (auto name_loc = loc.dyn_cast<mlir::NameLoc>()) add_name_loc(name_loc);
385 return;
386 }
387 if (auto name_loc = inst_loc.dyn_cast<mlir::NameLoc>())
388 add_name_loc(name_loc);
389 }
390
ConvertToNodeDef(Operation * op,NodeDef * node,TFGraphDialect * dialect,function_ref<StatusOr<std::string> (Value)> get_value_name)391 Status ConvertToNodeDef(
392 Operation *op, NodeDef *node, TFGraphDialect *dialect,
393 function_ref<StatusOr<std::string>(Value)> get_value_name) {
394 // Convert first-class attributes.
395 if (auto name =
396 op->getAttrOfType<StringAttr>(dialect->getNameAttrIdentifier()))
397 node->set_name(name.str());
398 if (auto device =
399 op->getAttrOfType<StringAttr>(dialect->getDeviceAttrIdentifier()))
400 node->set_device(device.str());
401 if (auto full_type = op->getAttrOfType<tf_type::FullTypeAttr>(
402 dialect->getFullTypeAttrIdentifier())) {
403 TF_ASSIGN_OR_RETURN(*node->mutable_experimental_type(),
404 ConvertAttribute(full_type));
405 }
406 {
407 if (auto assigned_device = op->getAttrOfType<StringAttr>(
408 dialect->getAssignedDeviceAttrIdentifier())) {
409 if (!assigned_device.getValue().empty()) {
410 (*node->mutable_attr())[dialect->getAssignedDeviceAttrKey().str()]
411 .set_s(assigned_device.str());
412 }
413 }
414 }
415 // Convert other attributes.
416 for (const NamedAttribute &attr : op->getAttrs()) {
417 if (attr.getName() == dialect->getAssignedDeviceAttrIdentifier() ||
418 attr.getName() == dialect->getDeviceAttrIdentifier() ||
419 attr.getName() == dialect->getFullTypeAttrIdentifier() ||
420 attr.getName() == dialect->getNameAttrIdentifier())
421 continue;
422 TF_ASSIGN_OR_RETURN((*node->mutable_attr())[attr.getName().str()],
423 ConvertAttribute(attr.getValue()));
424 }
425
426 // Set the op name.
427 node->set_op(op->getName().stripDialect().str());
428
429 // Set the input names.
430 for (Value operand : op->getOperands()) {
431 TF_ASSIGN_OR_RETURN(std::string input_name, get_value_name(operand));
432 node->add_input(std::move(input_name));
433 }
434
435 // Export the location as debug info.
436 if (!op->getLoc().isa<UnknownLoc>()) {
437 ExtractExperimentalDebugInfoFromLocation(
438 op->getLoc(), node->mutable_experimental_debug_info());
439 if (node->experimental_debug_info().original_node_names().empty())
440 node->clear_experimental_debug_info();
441 }
442
443 return ::tensorflow::OkStatus();
444 }
445
ConvertOperation(Operation * op,NodeDef * node,bool is_func)446 Status GraphDefExporter::ConvertOperation(Operation *op, NodeDef *node,
447 bool is_func) {
448 return ConvertToNodeDef(op, node, dialect_, [&](Value value) {
449 return GetEdgeName(value, is_func);
450 });
451 }
452
453 // Get the edge name of a value. If `get_output_segment` is specified, it means
454 // the name should be fully qualified if it is an operation result for exporting
455 // a function.
GetValueName(Value value,TFGraphDialect * dialect,function_ref<StatusOr<std::pair<StringRef,unsigned>> (OpResult)> get_output_segment)456 static StatusOr<std::string> GetValueName(
457 Value value, TFGraphDialect *dialect,
458 function_ref<StatusOr<std::pair<StringRef, unsigned>>(OpResult)>
459 get_output_segment) {
460 std::string name;
461 bool is_control = value.getType() == dialect->getControlType();
462
463 if (auto arg = value.dyn_cast<BlockArgument>()) {
464 auto func = dyn_cast<GraphFuncOp>(arg.getOwner()->getParentOp());
465 if (!func)
466 return InvalidArgument("Expected block argument owner to be tfg.func");
467 // If the block argument is a control token, use the attributes of the
468 // associated data argument (which preceeds it).
469 auto attrs = func.arg_attrs()
470 .getValue()[arg.getArgNumber() - is_control]
471 .cast<DictionaryAttr>();
472 auto name_attr =
473 attrs.getAs<StringAttr>(dialect->getTfgNameAttrIdentifier());
474 if (!name_attr) {
475 return InvalidArgument(
476 "Can't export graph with missing op-name for function parameter #",
477 arg.getArgNumber());
478 }
479 name.reserve(name_attr.size() + 1);
480 if (is_control) name.push_back('^');
481 name.append(name_attr.data(), name_attr.size());
482 return name;
483 }
484
485 auto result = value.cast<OpResult>();
486 auto name_attr = result.getOwner()->getAttrOfType<StringAttr>(
487 dialect->getNameAttrIdentifier());
488 if (!name_attr)
489 return InvalidArgument("Can't export graph with missing op-name");
490
491 if (is_control) {
492 name.reserve(1 + name_attr.size());
493 name.push_back('^');
494 name.append(name_attr.data(), name_attr.size());
495 return name;
496 }
497
498 if (!get_output_segment) {
499 name.reserve(name_attr.size() + 3);
500 name.append(name_attr.data(), name_attr.size());
501 if (result.getResultNumber()) {
502 name.push_back(':');
503 absl::StrAppend(&name, result.getResultNumber());
504 }
505 return name;
506 }
507
508 TF_ASSIGN_OR_RETURN(auto segment, get_output_segment(result));
509 name.reserve(name_attr.size() + segment.first.size() + 4);
510 name.append(name_attr.data(), name_attr.size());
511 name.push_back(':');
512 name.append(segment.first.data(), segment.first.size());
513 name.push_back(':');
514 absl::StrAppend(&name, segment.second);
515 return name;
516 }
517
GetValueName(Value value,TFGraphDialect * dialect)518 StatusOr<std::string> GetValueName(Value value, TFGraphDialect *dialect) {
519 return GetValueName(value, dialect, /*get_output_segment=*/nullptr);
520 }
521
GetEdgeName(Value value,bool is_func)522 StatusOr<std::string> GraphDefExporter::GetEdgeName(Value value, bool is_func) {
523 if (!is_func) return GetValueName(value, dialect_);
524 return GetValueName(value, dialect_, [&](OpResult result) {
525 return GetOutputSegment(result);
526 });
527 }
528
529 // Get the segment size of an op's output.
GetOutputSegmentSize(Operation * op,const OpDef::ArgDef & arg)530 static StatusOr<unsigned> GetOutputSegmentSize(Operation *op,
531 const OpDef::ArgDef &arg) {
532 if (!arg.type_list_attr().empty()) {
533 if (auto v = op->getAttr(arg.type_list_attr()).dyn_cast<ArrayAttr>())
534 return v.size();
535 return InvalidArgument("Type attr not found: ", arg.type_list_attr());
536 }
537 if (arg.number_attr().empty()) return 1;
538 if (auto v = op->getAttr(arg.number_attr()).dyn_cast<IntegerAttr>())
539 return v.getValue().getZExtValue();
540 return InvalidArgument("Type attr not found: ", arg.number_attr());
541 }
542
GetFunctionOutputName(unsigned result_idx,const std::string & op_name,SymbolTable & table)543 StatusOr<StringRef> GraphDefExporter::GetFunctionOutputName(
544 unsigned result_idx, const std::string &op_name, SymbolTable &table) {
545 if (auto func = table.lookup<GraphFuncOp>(op_name)) {
546 if (result_idx >= func.getNumResults()) {
547 return InvalidArgument("Result #", result_idx, " of function '", op_name,
548 "' is out of range");
549 }
550 if (auto name = func.getResultAttrOfType<StringAttr>(
551 result_idx, dialect_->getTfgNameAttrIdentifier())) {
552 return name.getValue();
553 }
554 return InvalidArgument("Function '", op_name, "' result #", result_idx,
555 "' is missing 'tfg.name'");
556 }
557 return InvalidArgument("Op '", op_name,
558 "' is neither registered nor a function");
559 }
560
561 // Get the name of a function argument from a function in the library.
GetFunctionOutputName(unsigned result_idx,const std::string & op_name,const FunctionLibraryDefinition & library)562 StatusOr<StringRef> GraphDefExporter::GetFunctionOutputName(
563 unsigned result_idx, const std::string &op_name,
564 const FunctionLibraryDefinition &library) {
565 if (const FunctionDef *function = library.Find(op_name)) {
566 if (result_idx >= function->signature().output_arg_size()) {
567 return InvalidArgument("Result #", result_idx, " of function '", op_name,
568 "' is out of range");
569 }
570 return {function->signature().output_arg(result_idx).name()};
571 }
572 return InvalidArgument("Op '", op_name,
573 "' is neither registered nor a function");
574 }
575
GetOutputSegment(OpResult result)576 StatusOr<std::pair<StringRef, unsigned>> GraphDefExporter::GetOutputSegment(
577 OpResult result) {
578 // TODO(jeffniu): OpRegistry::LookUp should accept `string_view`.
579 Operation *op = result.getOwner();
580 std::string op_name = op->getName().stripDialect().str();
581 unsigned result_idx = result.getResultNumber();
582 // Only edges in functions need to have fully-qualified names. Get the segment
583 // name using the op definition.
584 if (const OpRegistrationData *op_reg_data = registry_.LookUp(op_name)) {
585 const OpDef &op_def = op_reg_data->op_def;
586
587 for (const OpDef::ArgDef &arg : op_def.output_arg()) {
588 TF_ASSIGN_OR_RETURN(unsigned size, GetOutputSegmentSize(op, arg));
589 if (size > result_idx)
590 return std::pair<StringRef, unsigned>(arg.name(), result_idx);
591 result_idx -= size;
592 }
593 return InvalidArgument("Result #", result_idx, " of op '", op_name,
594 "' is out of range");
595 }
596 // Try to find a function for a legacy call. Function output segments have
597 // exactly one element each.
598 StringRef arg_name;
599 if (auto *table = function_table_.dyn_cast<SymbolTable *>()) {
600 TF_ASSIGN_OR_RETURN(arg_name,
601 GetFunctionOutputName(result_idx, op_name, *table));
602 } else {
603 TF_ASSIGN_OR_RETURN(
604 arg_name,
605 GetFunctionOutputName(
606 result_idx, op_name,
607 *function_table_.get<const FunctionLibraryDefinition *>()));
608 }
609 return std::pair<StringRef, unsigned>(arg_name, 0);
610 }
611
612 // Convert a TFG graph directly to GraphDef.
ConvertToGraphDef(ModuleOp module,tensorflow::GraphDef * graph)613 Status ConvertToGraphDef(ModuleOp module, tensorflow::GraphDef *graph) {
614 SymbolTable table(module);
615 GraphDefExporter exporter(
616 module.getContext()->getOrLoadDialect<TFGraphDialect>(),
617 *OpRegistry::Global(), &table);
618 return exporter.ExportToGraphDef(module, graph);
619 }
620
621 // Convert a single TFG function to a FunctionDef and add it to the function
622 // library. If a function with the same name already exists, replace it.
ConvertToFunctionDef(GraphFuncOp func,FunctionLibraryDefinition & library)623 Status ConvertToFunctionDef(GraphFuncOp func,
624 FunctionLibraryDefinition &library) {
625 GraphDefExporter exporter(func.getDialect(), *OpRegistry::Global(), &library);
626 FunctionDef def;
627 TF_ASSIGN_OR_RETURN(Optional<GradientDef> gradient,
628 exporter.ExportFunction(func, &def));
629 const std::string &name = def.signature().name();
630 if (library.Contains(name)) {
631 TF_RETURN_IF_ERROR(library.ReplaceFunction(name, def));
632 } else {
633 TF_RETURN_IF_ERROR(library.AddFunctionDef(def));
634 }
635 if (gradient) {
636 if (library.FindGradient(name).empty()) {
637 TF_RETURN_IF_ERROR(library.AddGradientDef(*gradient));
638 } else {
639 TF_RETURN_IF_ERROR(library.ReplaceGradient(*gradient));
640 }
641 }
642 return ::tensorflow::OkStatus();
643 }
644
645 } // namespace tfg
646 } // namespace mlir
647