xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ir/importexport/graphdef_export.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ir/importexport/graphdef_export.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/strings/str_cat.h"
22 #include "llvm/ADT/PointerUnion.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
28 #include "mlir/IR/Location.h"  // from @llvm-project
29 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
32 #include "mlir/IR/Threading.h"  // from @llvm-project
33 #include "mlir/IR/Value.h"  // from @llvm-project
34 #include "mlir/Support/LLVM.h"  // from @llvm-project
35 #include "tensorflow/core/framework/attr_value.pb.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/framework/function.pb.h"
38 #include "tensorflow/core/framework/graph.pb.h"
39 #include "tensorflow/core/framework/node_def.pb.h"
40 #include "tensorflow/core/framework/op.h"
41 #include "tensorflow/core/framework/op_def.pb.h"
42 #include "tensorflow/core/framework/op_def_builder.h"
43 #include "tensorflow/core/framework/types.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/ir/dialect.h"
46 #include "tensorflow/core/ir/importexport/convert_attributes.h"
47 #include "tensorflow/core/ir/importexport/convert_types.h"
48 #include "tensorflow/core/ir/importexport/functiondef_export.h"
49 #include "tensorflow/core/ir/ops.h"
50 #include "tensorflow/core/ir/types/dialect.h"
51 #include "tensorflow/core/platform/errors.h"
52 #include "tensorflow/core/platform/status.h"
53 #include "tensorflow/core/platform/statusor.h"
54 
55 using tensorflow::AttrValue;
56 using tensorflow::DataType;
57 using tensorflow::FunctionDef;
58 using tensorflow::FunctionLibraryDefinition;
59 using tensorflow::GradientDef;
60 using tensorflow::GraphDef;
61 using tensorflow::NodeDef;
62 using tensorflow::OpDef;
63 using tensorflow::OpRegistrationData;
64 using tensorflow::OpRegistry;
65 using tensorflow::Status;
66 using tensorflow::StatusOr;
67 using tensorflow::VersionDef;
68 using tensorflow::errors::InvalidArgument;
69 
70 namespace mlir {
71 namespace tfg {
72 namespace {
73 // This class implements an exporter for TFG directly to GraphDef.
74 class GraphDefExporter {
75  public:
GraphDefExporter(TFGraphDialect * dialect,const OpRegistry & registry,llvm::PointerUnion<SymbolTable *,const FunctionLibraryDefinition * > function_table)76   GraphDefExporter(
77       TFGraphDialect *dialect, const OpRegistry &registry,
78       llvm::PointerUnion<SymbolTable *, const FunctionLibraryDefinition *>
79           function_table)
80       : ctx_(dialect->getContext()),
81         dialect_(dialect),
82         registry_(registry),
83         function_table_(function_table) {}
84 
85   // Export a TFG module to GraphDef. The module may contain at most one GraphOp
86   // and only GraphFuncOp otherwise.
87   Status ExportToGraphDef(ModuleOp module, GraphDef *graph);
88 
89   // Export a TFG graph function to a FunctionDef. If the function has a
90   // gradient, add it to the graph afterwards to preserve thread-safety.
91   StatusOr<Optional<GradientDef>> ExportFunction(GraphFuncOp func,
92                                                  FunctionDef *def);
93 
94  private:
95   // Export just the input and outputs of a function signature. When
96   // fully-qualifying result names, this must be done before any nodes are
97   // Convert argument attributes to an ArgDef.
98   StatusOr<OpDef::ArgDef> ConvertArgumentAttributes(DictionaryAttr attrs);
99 
100   // Convert a TFG op to a node. When converting a function, fully-qualified
101   // result names must be used.
102   Status ConvertOperation(Operation *op, NodeDef *node, bool is_func);
103 
104   // Get the name associated with a value.
105   StatusOr<std::string> GetEdgeName(Value value, bool is_func);
106 
107   // Get the name and index of an output segment to fully qualify result names.
108   // This requires querying the op registry.
109   StatusOr<std::pair<StringRef, unsigned>> GetOutputSegment(OpResult result);
110 
111   // Get the name of a function argument from a function in the symbol table.
112   StatusOr<StringRef> GetFunctionOutputName(unsigned result_idx,
113                                             const std::string &op_name,
114                                             SymbolTable &table);
115   // Get the name of a function argument from a function in the library.
116   static StatusOr<StringRef> GetFunctionOutputName(
117       unsigned result_idx, const std::string &op_name,
118       const FunctionLibraryDefinition &library);
119 
120   // The current MLIR context.
121   MLIRContext *ctx_;
122   // The TFG dialect instance.
123   TFGraphDialect *dialect_;
124   // The TF op registry to use.
125   const OpRegistry &registry_;
126   // A lookup table for functions.
127   llvm::PointerUnion<SymbolTable *, const FunctionLibraryDefinition *>
128       function_table_;
129 };
130 }  // namespace
131 
132 // Returns a validated graph to export. A TFG module is valid for export if it
133 // contains at most one graph operation and any number of graph functions.
134 // Otherwise, returns an error.
ValidateModuleForExport(ModuleOp module)135 static StatusOr<GraphOp> ValidateModuleForExport(ModuleOp module) {
136   GraphOp graph_op;
137   for (Operation &op : *module.getBody()) {
138     if (isa<GraphFuncOp>(op)) continue;
139     if (auto new_graph_op = dyn_cast<GraphOp>(op)) {
140       if (graph_op) {
141         return InvalidArgument(
142             "Can't export module with two different tfg.graph");
143       }
144       graph_op = new_graph_op;
145       continue;
146     }
147     return InvalidArgument(
148         "Can't export module with other ops than tfg.graph or tfg.func, has: ",
149         op.getName().getStringRef().str());
150   }
151   return graph_op;
152 }
153 
154 // Converts a version attribute to VersionDef.
ExportVersionAttr(VersionAttr attr,VersionDef * version)155 static void ExportVersionAttr(VersionAttr attr, VersionDef *version) {
156   version->set_producer(attr.getProducer());
157   version->set_min_consumer(attr.getMinConsumer());
158   for (int32_t bad_consumer : attr.getBadConsumers())
159     version->add_bad_consumers(bad_consumer);
160 }
161 
ExportToGraphDef(ModuleOp module,GraphDef * graph)162 Status GraphDefExporter::ExportToGraphDef(ModuleOp module, GraphDef *graph) {
163   TF_ASSIGN_OR_RETURN(GraphOp graph_op, ValidateModuleForExport(module));
164   if (graph_op) {
165     ExportVersionAttr(graph_op.version(), graph->mutable_versions());
166     for (Operation &op : *graph_op.getBody()) {
167       TF_RETURN_IF_ERROR(ConvertOperation(&op, graph->mutable_node()->Add(),
168                                           /*is_func=*/false));
169     }
170   }
171 
172   const auto convert_func = [this](GraphFuncOp func, FunctionDef *def,
173                                    Optional<GradientDef> &gradient) {
174     // Generic functions are not on the hot path and skip the conversion to
175     // Graph so just call the existing exporter.
176     if (func.generic()) {
177       TF_ASSIGN_OR_RETURN(*def, ConvertGenericFunctionToFunctionDef(func));
178     } else {
179       TF_ASSIGN_OR_RETURN(gradient, ExportFunction(func, def));
180     }
181     return ::tensorflow::OkStatus();
182   };
183 
184   // TODO(jeffniu): Don't export functions in parallel if there are too few or
185   // they are too small.
186   if (ctx_->isMultithreadingEnabled()) {
187     ctx_->enterMultiThreadedExecution();
188     auto exit =
189         llvm::make_scope_exit([this] { ctx_->exitMultiThreadedExecution(); });
190 
191     // Prepare the arguments to parallel for each.
192     struct Argument {
193       GraphFuncOp func;
194       FunctionDef *def;
195       Status status;
196       Optional<GradientDef> gradient;
197     };
198     std::vector<Argument> args;
199     for (auto func : module.getOps<GraphFuncOp>())
200       args.push_back(Argument{func, graph->mutable_library()->add_function()});
201     const auto process_func = [&convert_func](Argument &arg) {
202       arg.status = convert_func(arg.func, arg.def, arg.gradient);
203       return success(arg.status.ok());
204     };
205 
206     // Execute the exports in parallel.
207     if (failed(failableParallelForEach(ctx_, args, process_func))) {
208       Status result;
209       for (const Argument &arg : args) {
210         result.Update(arg.status);
211       }
212       return result;
213     }
214   } else {
215     for (auto func : module.getOps<GraphFuncOp>()) {
216       Optional<GradientDef> gradient;
217       TF_RETURN_IF_ERROR(convert_func(
218           func, graph->mutable_library()->add_function(), gradient));
219       if (gradient)
220         *graph->mutable_library()->add_gradient() = std::move(*gradient);
221     }
222   }
223 
224   return ::tensorflow::OkStatus();
225 }
226 
227 // The only dialect attributes allowed have the "tf." prefix. This is a slightly
228 // faster check that an attribute is a dialect attribute.
IsDialectAttr(const NamedAttribute & attr)229 static bool IsDialectAttr(const NamedAttribute &attr) {
230   return attr.getName().getValue().startswith("tf.");
231 }
232 
233 // Export the given attribute list.
ConvertAttributes(tensorflow::protobuf::Map<std::string,AttrValue> * map,ArrayRef<NamedAttribute> attrs)234 static Status ConvertAttributes(
235     tensorflow::protobuf::Map<std::string, AttrValue> *map,
236     ArrayRef<NamedAttribute> attrs) {
237   for (const NamedAttribute &attr : attrs) {
238     if (!IsDialectAttr(attr)) continue;
239     StringRef name = attr.getName().strref().drop_front(/*strlen("tf.")=*/3);
240     TF_ASSIGN_OR_RETURN((*map)[name.str()], ConvertAttribute(attr.getValue()));
241   }
242   return ::tensorflow::OkStatus();
243 }
244 
ExportFunction(GraphFuncOp func,FunctionDef * def)245 StatusOr<Optional<GradientDef>> GraphDefExporter::ExportFunction(
246     GraphFuncOp func, FunctionDef *def) {
247   std::string func_name = func.sym_name().str();
248 
249   // TODO(jeffniu): Exploit the sorted order of the function attributes.
250 
251   // Get a gradient, if there is one.
252   Optional<GradientDef> gradient;
253   if (Optional<StringRef> gradient_name = func.gradient()) {
254     gradient.emplace();
255     gradient->set_gradient_func(gradient_name->str());
256     gradient->set_function_name(func_name);
257   }
258 
259   // Convert the first-class attributes.
260   OpDef *signature = def->mutable_signature();
261   signature->set_name(func_name);
262   if (Optional<StringRef> description = func.description())
263     signature->set_description(description->str());
264   signature->set_is_stateful(func.is_stateful());
265 
266   if (DenseIntElementsAttr keys = func.resource_arg_unique_ids_keysAttr()) {
267     DenseIntElementsAttr values = func.resource_arg_unique_ids_valuesAttr();
268     if (!values) {
269       return InvalidArgument(
270           "'resource_arg_unique_ids_keys' is present but "
271           "'resource_arg_unique_ids_values' is missing");
272     }
273     if (keys.size() != values.size()) {
274       return InvalidArgument(
275           "'resource_arg_unique_ids_keys' is not the same size as "
276           "'resource_arg_unique_ids_values'");
277     }
278     auto *id_map = def->mutable_resource_arg_unique_id();
279     for (auto kv :
280          llvm::zip(keys.getValues<int32_t>(), values.getValues<int32_t>()))
281       (*id_map)[std::get<0>(kv)] = std::get<1>(kv);
282   }
283 
284   // Convert other attributes with the "tf." prefix.
285   TF_RETURN_IF_ERROR(ConvertAttributes(def->mutable_attr(), func->getAttrs()));
286 
287   // Convert the arguments.
288   for (int i = 0, e = func.getNumArguments(); i < e; i += 2) {
289     auto attrs = func.arg_attrs().getValue()[i].cast<DictionaryAttr>();
290     TF_ASSIGN_OR_RETURN(OpDef::ArgDef &arg = *signature->add_input_arg(),
291                         ConvertArgumentAttributes(attrs));
292     DataType dtype;
293     TF_RETURN_IF_ERROR(ConvertToDataType(
294         func.getArgument(i).getType().cast<TensorType>().getElementType(),
295         &dtype));
296     arg.set_type(dtype);
297     // Convert the attributes.
298     if (llvm::any_of(attrs, [](const NamedAttribute &attr) {
299           return IsDialectAttr(attr);
300         })) {
301       auto *map = (*def->mutable_arg_attr())[i / 2].mutable_attr();
302       TF_RETURN_IF_ERROR(ConvertAttributes(map, attrs.getValue()));
303     }
304   }
305 
306   // Convert the results.
307   auto return_op = cast<ReturnOp>(func.getBody()->getTerminator());
308   for (auto it :
309        llvm::zip(func.getResultTypes(),
310                  func.getAllResultAttrs().getAsRange<DictionaryAttr>(),
311                  TFOp(return_op).getNonControlOperands())) {
312     TF_ASSIGN_OR_RETURN(OpDef::ArgDef &arg = *signature->add_output_arg(),
313                         ConvertArgumentAttributes(std::get<1>(it)));
314     DataType dtype;
315     TF_RETURN_IF_ERROR(ConvertToDataType(
316         std::get<0>(it).cast<TensorType>().getElementType(), &dtype));
317     arg.set_type(dtype);
318     // Map the result.
319     TF_ASSIGN_OR_RETURN((*def->mutable_ret())[arg.name()],
320                         GetEdgeName(std::get<2>(it), /*is_func=*/true));
321   }
322 
323   // Convert the control results.
324   for (auto it :
325        llvm::zip(return_op.control_ret_attrs().getAsRange<DictionaryAttr>(),
326                  TFOp(return_op).getControlOperands())) {
327     // The control result attributes contain only the name.
328     DictionaryAttr attrs = std::get<0>(it);
329     if (attrs.empty())
330       return InvalidArgument("Control result is missing 'tfg.name'");
331     assert(attrs.begin()->getName() == dialect_->getTfgNameAttrIdentifier());
332     std::string name = attrs.begin()->getValue().cast<StringAttr>().str();
333     signature->add_control_output(name);
334     // Map the control result.
335     TF_ASSIGN_OR_RETURN(std::string value_name,
336                         GetEdgeName(std::get<1>(it), /*is_func=*/true));
337     // Add the control result name without '^'.
338     def->mutable_control_ret()->insert({std::move(name), value_name.substr(1)});
339   }
340 
341   // Convert the body.
342   for (Operation &op : func.getBody()->without_terminator())
343     TF_RETURN_IF_ERROR(
344         ConvertOperation(&op, def->add_node_def(), /*is_func=*/true));
345 
346   return gradient;
347 }
348 
ConvertArgumentAttributes(DictionaryAttr attrs)349 StatusOr<OpDef::ArgDef> GraphDefExporter::ConvertArgumentAttributes(
350     DictionaryAttr attrs) {
351   OpDef::ArgDef arg;
352   auto name = attrs.getAs<StringAttr>(dialect_->getTfgNameAttrIdentifier());
353   if (!name) return InvalidArgument("argument is missing 'tfg.name'");
354   arg.set_name(name.str());
355   if (auto description =
356           attrs.getAs<StringAttr>(dialect_->getTfgDescriptionAttrIdentifier()))
357     arg.set_description(description.str());
358   arg.set_is_ref(!!attrs.get(dialect_->getTfgIsRefAttrIdentifier()));
359   TF_RETURN_IF_ERROR(ConvertHandleData(
360       attrs.getAs<ArrayAttr>(dialect_->getTfgHandleDataAttrIdentifier()),
361       &arg));
362   if (auto full_type = attrs.getAs<tf_type::FullTypeAttr>(
363           dialect_->getTfgFullTypeAttrIdentifier())) {
364     TF_ASSIGN_OR_RETURN(*arg.mutable_experimental_full_type(),
365                         ConvertAttribute(full_type));
366   }
367   return arg;
368 }
369 
370 // Converts a location to the debug information for the node def, if we find
371 // supported location, that is a top-level NameLoc or any NameLoc nested inside
372 // a FusedLoc. Other kind of location are ignored. If a NameLoc is of the form
373 // "name@func" we parse it and import the two appropriately.
ExtractExperimentalDebugInfoFromLocation(Location inst_loc,NodeDef::ExperimentalDebugInfo * debug_info)374 static void ExtractExperimentalDebugInfoFromLocation(
375     Location inst_loc, NodeDef::ExperimentalDebugInfo *debug_info) {
376   auto add_name_loc = [&](mlir::NameLoc name_loc) {
377     StringRef node, func;
378     std::tie(node, func) = name_loc.getName().strref().split('@');
379     debug_info->add_original_node_names(node.str());
380     if (!func.empty()) debug_info->add_original_func_names(func.str());
381   };
382   if (auto fused = inst_loc.dyn_cast<mlir::FusedLoc>()) {
383     for (Location loc : fused.getLocations())
384       if (auto name_loc = loc.dyn_cast<mlir::NameLoc>()) add_name_loc(name_loc);
385     return;
386   }
387   if (auto name_loc = inst_loc.dyn_cast<mlir::NameLoc>())
388     add_name_loc(name_loc);
389 }
390 
ConvertToNodeDef(Operation * op,NodeDef * node,TFGraphDialect * dialect,function_ref<StatusOr<std::string> (Value)> get_value_name)391 Status ConvertToNodeDef(
392     Operation *op, NodeDef *node, TFGraphDialect *dialect,
393     function_ref<StatusOr<std::string>(Value)> get_value_name) {
394   // Convert first-class attributes.
395   if (auto name =
396           op->getAttrOfType<StringAttr>(dialect->getNameAttrIdentifier()))
397     node->set_name(name.str());
398   if (auto device =
399           op->getAttrOfType<StringAttr>(dialect->getDeviceAttrIdentifier()))
400     node->set_device(device.str());
401   if (auto full_type = op->getAttrOfType<tf_type::FullTypeAttr>(
402           dialect->getFullTypeAttrIdentifier())) {
403     TF_ASSIGN_OR_RETURN(*node->mutable_experimental_type(),
404                         ConvertAttribute(full_type));
405   }
406   {
407     if (auto assigned_device = op->getAttrOfType<StringAttr>(
408             dialect->getAssignedDeviceAttrIdentifier())) {
409       if (!assigned_device.getValue().empty()) {
410         (*node->mutable_attr())[dialect->getAssignedDeviceAttrKey().str()]
411             .set_s(assigned_device.str());
412       }
413     }
414   }
415   // Convert other attributes.
416   for (const NamedAttribute &attr : op->getAttrs()) {
417     if (attr.getName() == dialect->getAssignedDeviceAttrIdentifier() ||
418         attr.getName() == dialect->getDeviceAttrIdentifier() ||
419         attr.getName() == dialect->getFullTypeAttrIdentifier() ||
420         attr.getName() == dialect->getNameAttrIdentifier())
421       continue;
422     TF_ASSIGN_OR_RETURN((*node->mutable_attr())[attr.getName().str()],
423                         ConvertAttribute(attr.getValue()));
424   }
425 
426   // Set the op name.
427   node->set_op(op->getName().stripDialect().str());
428 
429   // Set the input names.
430   for (Value operand : op->getOperands()) {
431     TF_ASSIGN_OR_RETURN(std::string input_name, get_value_name(operand));
432     node->add_input(std::move(input_name));
433   }
434 
435   // Export the location as debug info.
436   if (!op->getLoc().isa<UnknownLoc>()) {
437     ExtractExperimentalDebugInfoFromLocation(
438         op->getLoc(), node->mutable_experimental_debug_info());
439     if (node->experimental_debug_info().original_node_names().empty())
440       node->clear_experimental_debug_info();
441   }
442 
443   return ::tensorflow::OkStatus();
444 }
445 
ConvertOperation(Operation * op,NodeDef * node,bool is_func)446 Status GraphDefExporter::ConvertOperation(Operation *op, NodeDef *node,
447                                           bool is_func) {
448   return ConvertToNodeDef(op, node, dialect_, [&](Value value) {
449     return GetEdgeName(value, is_func);
450   });
451 }
452 
453 // Get the edge name of a value. If `get_output_segment` is specified, it means
454 // the name should be fully qualified if it is an operation result for exporting
455 // a function.
GetValueName(Value value,TFGraphDialect * dialect,function_ref<StatusOr<std::pair<StringRef,unsigned>> (OpResult)> get_output_segment)456 static StatusOr<std::string> GetValueName(
457     Value value, TFGraphDialect *dialect,
458     function_ref<StatusOr<std::pair<StringRef, unsigned>>(OpResult)>
459         get_output_segment) {
460   std::string name;
461   bool is_control = value.getType() == dialect->getControlType();
462 
463   if (auto arg = value.dyn_cast<BlockArgument>()) {
464     auto func = dyn_cast<GraphFuncOp>(arg.getOwner()->getParentOp());
465     if (!func)
466       return InvalidArgument("Expected block argument owner to be tfg.func");
467     // If the block argument is a control token, use the attributes of the
468     // associated data argument (which preceeds it).
469     auto attrs = func.arg_attrs()
470                      .getValue()[arg.getArgNumber() - is_control]
471                      .cast<DictionaryAttr>();
472     auto name_attr =
473         attrs.getAs<StringAttr>(dialect->getTfgNameAttrIdentifier());
474     if (!name_attr) {
475       return InvalidArgument(
476           "Can't export graph with missing op-name for function parameter #",
477           arg.getArgNumber());
478     }
479     name.reserve(name_attr.size() + 1);
480     if (is_control) name.push_back('^');
481     name.append(name_attr.data(), name_attr.size());
482     return name;
483   }
484 
485   auto result = value.cast<OpResult>();
486   auto name_attr = result.getOwner()->getAttrOfType<StringAttr>(
487       dialect->getNameAttrIdentifier());
488   if (!name_attr)
489     return InvalidArgument("Can't export graph with missing op-name");
490 
491   if (is_control) {
492     name.reserve(1 + name_attr.size());
493     name.push_back('^');
494     name.append(name_attr.data(), name_attr.size());
495     return name;
496   }
497 
498   if (!get_output_segment) {
499     name.reserve(name_attr.size() + 3);
500     name.append(name_attr.data(), name_attr.size());
501     if (result.getResultNumber()) {
502       name.push_back(':');
503       absl::StrAppend(&name, result.getResultNumber());
504     }
505     return name;
506   }
507 
508   TF_ASSIGN_OR_RETURN(auto segment, get_output_segment(result));
509   name.reserve(name_attr.size() + segment.first.size() + 4);
510   name.append(name_attr.data(), name_attr.size());
511   name.push_back(':');
512   name.append(segment.first.data(), segment.first.size());
513   name.push_back(':');
514   absl::StrAppend(&name, segment.second);
515   return name;
516 }
517 
GetValueName(Value value,TFGraphDialect * dialect)518 StatusOr<std::string> GetValueName(Value value, TFGraphDialect *dialect) {
519   return GetValueName(value, dialect, /*get_output_segment=*/nullptr);
520 }
521 
GetEdgeName(Value value,bool is_func)522 StatusOr<std::string> GraphDefExporter::GetEdgeName(Value value, bool is_func) {
523   if (!is_func) return GetValueName(value, dialect_);
524   return GetValueName(value, dialect_, [&](OpResult result) {
525     return GetOutputSegment(result);
526   });
527 }
528 
529 // Get the segment size of an op's output.
GetOutputSegmentSize(Operation * op,const OpDef::ArgDef & arg)530 static StatusOr<unsigned> GetOutputSegmentSize(Operation *op,
531                                                const OpDef::ArgDef &arg) {
532   if (!arg.type_list_attr().empty()) {
533     if (auto v = op->getAttr(arg.type_list_attr()).dyn_cast<ArrayAttr>())
534       return v.size();
535     return InvalidArgument("Type attr not found: ", arg.type_list_attr());
536   }
537   if (arg.number_attr().empty()) return 1;
538   if (auto v = op->getAttr(arg.number_attr()).dyn_cast<IntegerAttr>())
539     return v.getValue().getZExtValue();
540   return InvalidArgument("Type attr not found: ", arg.number_attr());
541 }
542 
GetFunctionOutputName(unsigned result_idx,const std::string & op_name,SymbolTable & table)543 StatusOr<StringRef> GraphDefExporter::GetFunctionOutputName(
544     unsigned result_idx, const std::string &op_name, SymbolTable &table) {
545   if (auto func = table.lookup<GraphFuncOp>(op_name)) {
546     if (result_idx >= func.getNumResults()) {
547       return InvalidArgument("Result #", result_idx, " of function '", op_name,
548                              "' is out of range");
549     }
550     if (auto name = func.getResultAttrOfType<StringAttr>(
551             result_idx, dialect_->getTfgNameAttrIdentifier())) {
552       return name.getValue();
553     }
554     return InvalidArgument("Function '", op_name, "' result #", result_idx,
555                            "' is missing 'tfg.name'");
556   }
557   return InvalidArgument("Op '", op_name,
558                          "' is neither registered nor a function");
559 }
560 
561 // Get the name of a function argument from a function in the library.
GetFunctionOutputName(unsigned result_idx,const std::string & op_name,const FunctionLibraryDefinition & library)562 StatusOr<StringRef> GraphDefExporter::GetFunctionOutputName(
563     unsigned result_idx, const std::string &op_name,
564     const FunctionLibraryDefinition &library) {
565   if (const FunctionDef *function = library.Find(op_name)) {
566     if (result_idx >= function->signature().output_arg_size()) {
567       return InvalidArgument("Result #", result_idx, " of function '", op_name,
568                              "' is out of range");
569     }
570     return {function->signature().output_arg(result_idx).name()};
571   }
572   return InvalidArgument("Op '", op_name,
573                          "' is neither registered nor a function");
574 }
575 
GetOutputSegment(OpResult result)576 StatusOr<std::pair<StringRef, unsigned>> GraphDefExporter::GetOutputSegment(
577     OpResult result) {
578   // TODO(jeffniu): OpRegistry::LookUp should accept `string_view`.
579   Operation *op = result.getOwner();
580   std::string op_name = op->getName().stripDialect().str();
581   unsigned result_idx = result.getResultNumber();
582   // Only edges in functions need to have fully-qualified names. Get the segment
583   // name using the op definition.
584   if (const OpRegistrationData *op_reg_data = registry_.LookUp(op_name)) {
585     const OpDef &op_def = op_reg_data->op_def;
586 
587     for (const OpDef::ArgDef &arg : op_def.output_arg()) {
588       TF_ASSIGN_OR_RETURN(unsigned size, GetOutputSegmentSize(op, arg));
589       if (size > result_idx)
590         return std::pair<StringRef, unsigned>(arg.name(), result_idx);
591       result_idx -= size;
592     }
593     return InvalidArgument("Result #", result_idx, " of op '", op_name,
594                            "' is out of range");
595   }
596   // Try to find a function for a legacy call. Function output segments have
597   // exactly one element each.
598   StringRef arg_name;
599   if (auto *table = function_table_.dyn_cast<SymbolTable *>()) {
600     TF_ASSIGN_OR_RETURN(arg_name,
601                         GetFunctionOutputName(result_idx, op_name, *table));
602   } else {
603     TF_ASSIGN_OR_RETURN(
604         arg_name,
605         GetFunctionOutputName(
606             result_idx, op_name,
607             *function_table_.get<const FunctionLibraryDefinition *>()));
608   }
609   return std::pair<StringRef, unsigned>(arg_name, 0);
610 }
611 
612 // Convert a TFG graph directly to GraphDef.
ConvertToGraphDef(ModuleOp module,tensorflow::GraphDef * graph)613 Status ConvertToGraphDef(ModuleOp module, tensorflow::GraphDef *graph) {
614   SymbolTable table(module);
615   GraphDefExporter exporter(
616       module.getContext()->getOrLoadDialect<TFGraphDialect>(),
617       *OpRegistry::Global(), &table);
618   return exporter.ExportToGraphDef(module, graph);
619 }
620 
621 // Convert a single TFG function to a FunctionDef and add it to the function
622 // library. If a function with the same name already exists, replace it.
ConvertToFunctionDef(GraphFuncOp func,FunctionLibraryDefinition & library)623 Status ConvertToFunctionDef(GraphFuncOp func,
624                             FunctionLibraryDefinition &library) {
625   GraphDefExporter exporter(func.getDialect(), *OpRegistry::Global(), &library);
626   FunctionDef def;
627   TF_ASSIGN_OR_RETURN(Optional<GradientDef> gradient,
628                       exporter.ExportFunction(func, &def));
629   const std::string &name = def.signature().name();
630   if (library.Contains(name)) {
631     TF_RETURN_IF_ERROR(library.ReplaceFunction(name, def));
632   } else {
633     TF_RETURN_IF_ERROR(library.AddFunctionDef(def));
634   }
635   if (gradient) {
636     if (library.FindGradient(name).empty()) {
637       TF_RETURN_IF_ERROR(library.AddGradientDef(*gradient));
638     } else {
639       TF_RETURN_IF_ERROR(library.ReplaceGradient(*gradient));
640     }
641   }
642   return ::tensorflow::OkStatus();
643 }
644 
645 }  // namespace tfg
646 }  // namespace mlir
647