xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ir/ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ir/ops.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <list>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/SMLoc.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
36 #include "mlir/IR/Dialect.h"  // from @llvm-project
37 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
38 #include "mlir/IR/FunctionImplementation.h"  // from @llvm-project
39 #include "mlir/IR/FunctionInterfaces.h"  // from @llvm-project
40 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
41 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
42 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
43 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
44 #include "mlir/IR/TypeRange.h"  // from @llvm-project
45 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
46 #include "mlir/IR/Value.h"  // from @llvm-project
47 #include "mlir/Interfaces/ControlFlowInterfaces.h"  // from @llvm-project
48 #include "mlir/Support/LLVM.h"  // from @llvm-project
49 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
50 #include "tensorflow/core/ir/dialect.h"
51 #include "tensorflow/core/ir/interfaces.h"
52 #include "tensorflow/core/ir/types/dialect.h"
53 #include "tensorflow/core/ir/utility.h"
54 
55 // Generated definitions.
56 #include "tensorflow/core/ir/dialect.cc.inc"
57 
58 namespace mlir {
59 namespace tfg {
60 
61 //===----------------------------------------------------------------------===//
62 // TFGraph dialect.
63 //===----------------------------------------------------------------------===//
64 
65 // Name operation results with the operation name, except control outputs which
66 // are named "ctl". MLIR automatically appends a numeric suffix to uniquify.
GenericGetAsmResultNames(Operation * op,OpAsmSetValueNameFn set_name_fn)67 static void GenericGetAsmResultNames(Operation *op,
68                                      OpAsmSetValueNameFn set_name_fn) {
69   // We only name the results when there are results to name, an op like `print`
70   // which does not have results will just use the `ctl` name for the control
71   // output.
72   if (op->getNumResults() > 1 && !op->getResult(0).getType().isa<ControlType>())
73     set_name_fn(op->getResult(0), op->getName().stripDialect());
74   for (Value result : op->getResults()) {
75     if (result.getType().isa<ControlType>()) {
76       set_name_fn(op->getResult(op->getNumResults() - 1), "ctl");
77       break;
78     }
79   }
80 }
81 
82 // TFGraph support for interacting with the AsmPrinter.
83 // Gives prettier names to SSA values.
84 struct TFGraphOpAsmInterface
85     : public OpAsmOpInterface::FallbackModel<TFGraphOpAsmInterface> {
classofmlir::tfg::TFGraphOpAsmInterface86   static bool classof(Operation *op) { return true; }
87 
getAsmResultNamesmlir::tfg::TFGraphOpAsmInterface88   void getAsmResultNames(Operation *op, OpAsmSetValueNameFn set_name_fn) const {
89     GenericGetAsmResultNames(op, set_name_fn);
90   }
getAsmBlockArgumentNamesmlir::tfg::TFGraphOpAsmInterface91   void getAsmBlockArgumentNames(Operation *op, Region &region,
92                                 OpAsmSetValueNameFn setNameFn) const {}
getAsmBlockNamesmlir::tfg::TFGraphOpAsmInterface93   void getAsmBlockNames(Operation *op,
94                         mlir::OpAsmSetBlockNameFn setNameFn) const {}
95 };
96 
97 // Dialect construction: there is one instance per context and it registers its
98 // operations, types, and interfaces here.
initialize()99 void TFGraphDialect::initialize() {
100   getContext()->getOrLoadDialect<tf_type::TFTypeDialect>();
101   addOperations<
102 #define GET_OP_LIST
103 #include "tensorflow/core/ir/ops.cc.inc"
104       >();
105   addAttributes<
106 #define GET_ATTRDEF_LIST
107 #include "tensorflow/core/ir/attributes.cc.inc"
108       >();
109 
110   // Support unknown operations because not all TensorFlow operations are
111   // registered.
112   allowUnknownOperations();
113 
114   // Create the fallback OpAsmOpInterface instance.
115   fallbackOpAsmInterface_ = new TFGraphOpAsmInterface;
116 
117   // Register the memory effects interface adaptor.
118   addInterfaces<StatefulMemoryEffectInterface>();
119 
120   // Initialized the cached operation names.
121 #define GET_OP_NAME_DEFS
122 #include "tensorflow/core/ir/tf_op_names.inc"
123 
124   // Caching some often used context-owned informations for fast-access.
125   name_key_ = StringAttr::get(getContext(), getNameAttrKey());
126   device_key_ = StringAttr::get(getContext(), getDeviceAttrKey());
127   assigned_device_key_ =
128       StringAttr::get(getContext(), getAssignedDeviceAttrKey());
129   fulltype_key_ = StringAttr::get(getContext(), getFullTypeAttrKey());
130   lifted_graph_func_name_ =
131       StringAttr::get(getContext(), getLiftedGraphFuncNameKey());
132   tfg_name_key_ = StringAttr::get(getContext(), getTfgNameAttrKey());
133   tfg_description_key_ =
134       StringAttr::get(getContext(), getTfgDescriptionAttrKey());
135   tfg_is_ref_key_ = StringAttr::get(getContext(), getTfgIsRefAttrKey());
136   tfg_handle_data_key_ =
137       StringAttr::get(getContext(), getTfgHandleDataAttrKey());
138   tfg_full_type_key_ = StringAttr::get(getContext(), getTfgFullTypeAttrKey());
139 
140   control_ty_ = ControlType::get(getContext());
141 }
142 
143 // Provides a hook for op interface.
getRegisteredInterfaceForOp(TypeID interface,OperationName opName)144 void *TFGraphDialect::getRegisteredInterfaceForOp(TypeID interface,
145                                                   OperationName opName) {
146   if (interface == TypeID::get<OpAsmOpInterface>()) {
147     return fallbackOpAsmInterface_;
148   }
149 
150   // Intrinsic operations explicitly implement intefaces.
151   if (opName.hasTrait<OpTrait::IntrinsicOperation>()) {
152     return nullptr;
153   }
154 
155   if (interface == TypeID::get<TensorFlowRegistryInterface>()) {
156     if (auto *instance =
157             getRegisteredInterface<TensorFlowRegistryInterfaceBase>()) {
158       // Important: cast to (Concept *) to shift the pointer off the vtable.
159       return static_cast<TensorFlowRegistryInterfaceBase::Concept *>(
160           const_cast<TensorFlowRegistryInterfaceBase *>(instance));
161     }
162   } else if (interface == TypeID::get<MemoryEffectOpInterface>()) {
163     auto *instance = getRegisteredInterface<StatefulMemoryEffectInterface>();
164     assert(instance && "expected the memory interface to be registered");
165     return static_cast<StatefulMemoryEffectInterface::Concept *>(
166         const_cast<StatefulMemoryEffectInterface *>(instance));
167   }
168 
169   return nullptr;
170 }
171 
~TFGraphDialect()172 TFGraphDialect::~TFGraphDialect() { delete fallbackOpAsmInterface_; }
173 
174 // The name of certain optional attributes.
175 static std::array<StringRef, 3> keyword_attrs{
176     "_mlir_device", "_mlir_assigned_device", "_mlir_name"};
177 
PrintKeywordAttributes(Operation * op,OpAsmPrinter & printer,ArrayRef<StringRef> elided_attrs={})178 static void PrintKeywordAttributes(Operation *op, OpAsmPrinter &printer,
179                                    ArrayRef<StringRef> elided_attrs = {}) {
180   // Handles the optional "device" and "name" attribute.
181   for (StringRef keyword : keyword_attrs) {
182     if (StringAttr value_attr = op->getAttrOfType<StringAttr>(keyword)) {
183       assert(!value_attr.getValue().empty());
184       printer << " " << keyword.drop_front(/*len(_mlir_)*/ 6) << "(\""
185               << value_attr.getValue() << "\")";
186     }
187   }
188 
189   // Print attributes (other than name and device).
190   SmallVector<StringRef> attrs_to_elide = llvm::to_vector(elided_attrs);
191   llvm::append_range(attrs_to_elide, keyword_attrs);
192   printer.printOptionalAttrDict(op->getAttrs(), attrs_to_elide);
193 }
194 
195 // Print an operation that belongs to this dialect, if unregistered.
196 // The general syntax is:
197 //   tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2]
198 //           name("<node_name>") device("<device>") { attribute-dict } :
199 //           (input types) -> (result_types)
printCustomTfOp(Operation * op,OpAsmPrinter & printer) const200 void TFGraphDialect::printCustomTfOp(Operation *op,
201                                      OpAsmPrinter &printer) const {
202   ControlType control_ty = getControlType();
203 
204   // Check that all control dependencies are after the regular values,
205   // otherwise print the generic form. We don't expect this to happen but
206   // we're defensive in the printer since this may happen in "hard-to-debug"
207   // issues.
208   {
209     bool has_control_dep = false;
210     for (Value operand : op->getOperands()) {
211       if (operand.getType() == control_ty) {
212         has_control_dep = true;
213         continue;
214       }
215       if (has_control_dep) {
216         printer.printGenericOp(op);
217         return;
218       }
219     }
220     has_control_dep = false;
221     for (Value result : op->getResults()) {
222       if (result.getType() == control_ty) {
223         has_control_dep = true;
224         continue;
225       }
226       if (has_control_dep) {
227         printer.printGenericOp(op);
228         return;
229       }
230     }
231   }
232 
233   // Print the inputs (other than the control dependencies), if any.
234   TFOp tfg_op(op);
235   OperandRange data = tfg_op.getNonControlOperands();
236   if (!data.empty()) printer << '(' << data << ')';
237   // Print the control dependencies (if any).
238   OperandRange ctls = tfg_op.getControlOperands();
239   if (!ctls.empty()) printer << " [" << ctls << ']';
240 
241   // Print the keyword attributes and optional attribute dictionary.
242   PrintKeywordAttributes(op, printer);
243 
244   // Print the type, but omit control dependencies.
245   // If there is a single control return, just print the list of input types,
246   // otherwise print the complete type in a "function-style" way: (operands)
247   // -> (results).
248   ResultRange results = tfg_op.getNonControlResults();
249   if (results.empty()) {
250     if (!data.empty()) printer << " : " << data.getTypes();
251   } else {
252     printer << " : (" << data.getTypes() << ") -> (" << results.getTypes()
253             << ")";
254   }
255 }
256 
257 // Print a custom TFG op.
PrintCustomTfOp(Operation * op,OpAsmPrinter & printer)258 static void PrintCustomTfOp(Operation *op, OpAsmPrinter &printer) {
259   cast<TFGraphDialect>(op->getDialect())->printCustomTfOp(op, printer);
260 }
261 
262 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation * op) const263 TFGraphDialect::getOperationPrinter(Operation *op) const {
264   return [this](Operation *op, OpAsmPrinter &printer) {
265     this->printCustomTfOp(op, printer);
266   };
267 }
268 
269 // Try to parse the optional keyword attributes `device`, `assigned_device`,
270 // and `name`, storing each under its name prefixed with `_mlir_`.
ParseKeywordAttributes(OpAsmParser & parser,OperationState & result)271 static ParseResult ParseKeywordAttributes(OpAsmParser &parser,
272                                           OperationState &result) {
273   for (const char *keyword : {"device", "assigned_device", "name"}) {
274     if (succeeded(parser.parseOptionalKeyword(keyword))) {
275       StringAttr value;
276       if (parser.parseLParen() ||
277           parser.parseAttribute<StringAttr>(
278               value, NoneType::get(parser.getContext())) ||
279           parser.parseRParen())
280         return failure();
281       result.addAttribute((Twine("_mlir_") + keyword).str(), value);
282     }
283   }
284   return parser.parseOptionalAttrDict(result.attributes);
285 }
286 
287 // Parse an operation that belongs to this dialect, if unregistered.
288 // The general syntax is:
289 //   tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2]
290 //           name("<node_name>") device("<device>") { attribute-dict } :
291 //           (input types) -> (result_types)
ParseCustomTfOp(OpAsmParser & parser,OperationState & result)292 static ParseResult ParseCustomTfOp(OpAsmParser &parser,
293                                    OperationState &result) {
294   MLIRContext *context = parser.getBuilder().getContext();
295   // Parse optional argument list
296   SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos;
297   if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalParen))
298     return failure();
299   unsigned numNonControlOperands = op_infos.size();
300   // Optional control list, in between brackets.
301   if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalSquare))
302     return failure();
303 
304   // Parse the optional keyword attributes and optional attribute dictionary.
305   if (ParseKeywordAttributes(parser, result)) return failure();
306 
307   // Parse the functional type.
308   SmallVector<Type> arg_types;
309   arg_types.reserve(op_infos.size());
310   llvm::SMLoc loc = parser.getCurrentLocation();
311   Type control_type = ControlType::get(context);
312   if (failed(parser.parseOptionalColonTypeList(arg_types))) return failure();
313   if (arg_types.size() == 1 && arg_types.front().isa<FunctionType>()) {
314     auto funcType = arg_types.front().cast<FunctionType>();
315     if (funcType.getNumInputs() != numNonControlOperands)
316       return parser.emitError(loc)
317              << "got " << numNonControlOperands
318              << " non-control operands, but the type defines "
319              << funcType.getNumInputs() << " input types";
320     arg_types.clear();
321     arg_types.append(funcType.getInputs().begin(), funcType.getInputs().end());
322     result.types.append(funcType.getResults().begin(),
323                         funcType.getResults().end());
324   }
325 
326   // The control input are elided from the type list, add them here.
327   arg_types.resize(op_infos.size(), control_type);
328   if (!arg_types.empty())
329     if (parser.resolveOperands(op_infos, arg_types, loc, result.operands))
330       return failure();
331   if (result.name.getStringRef() != "tfg.return")
332     result.types.push_back(control_type);
333   return success();
334 }
335 
getParseOperationHook(StringRef opName) const336 Optional<Dialect::ParseOpHook> TFGraphDialect::getParseOperationHook(
337     StringRef opName) const {
338   return ParseOpHook(ParseCustomTfOp);
339 }
340 
VerifyGenericTFGOperation(Operation & op)341 static bool VerifyGenericTFGOperation(Operation &op) {
342   TFGraphDialect *dialect = dyn_cast<TFGraphDialect>(op.getDialect());
343   if (!dialect) return true;
344   ControlType control_ty = dialect->getControlType();
345 
346   // verifies that control operands (or results) are always after regular
347   // inputs (or results).
348   auto check_ctl_at_end = [&](TypeRange types, StringRef input_or_output) {
349     int has_control_dep = -1;
350     for (auto &indexed_operand : llvm::enumerate(types)) {
351       if (indexed_operand.value() == control_ty) {
352         has_control_dep = indexed_operand.index();
353         continue;
354       }
355       if (has_control_dep != -1) {
356         op.emitOpError() << "found non-control " << input_or_output
357                          << " in position #" << indexed_operand.index()
358                          << " after control " << input_or_output
359                          << " in position #" << has_control_dep;
360         return false;
361       }
362     }
363     return true;
364   };
365   if (!check_ctl_at_end(op.getOperandTypes(), "input")) return false;
366   if (!check_ctl_at_end(op.getResultTypes(), "result")) return false;
367 
368   // Certain attributes are supposed to be inserted with non-empty value.
369   for (StringRef keyword : keyword_attrs) {
370     if (StringAttr value_attr = op.getAttrOfType<StringAttr>(keyword)) {
371       if (value_attr.getValue().empty()) {
372         op.emitOpError() << keyword
373                          << " has empty value. Only insert this attribute when "
374                             "it has a value";
375       }
376     }
377   }
378 
379   return true;
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // Graph Operation
384 //===----------------------------------------------------------------------===//
385 
verify()386 LogicalResult GraphOp::verify() {
387   GraphOp op = *this;
388   // Check all ops in the body.
389   if (!all_of(*op.getBody(), VerifyGenericTFGOperation)) return failure();
390 
391   return success();
392 }
393 //===----------------------------------------------------------------------===//
394 // Func Operation
395 //===----------------------------------------------------------------------===//
396 
isMarkedForCompilation()397 bool GraphFuncOp::isMarkedForCompilation() {
398   auto is_enabled = [this](StringRef attr_name) -> bool {
399     Attribute attr = (*this)->getAttr(attr_name);
400     if (!attr) return false;
401     if (auto bool_attr = attr.dyn_cast<BoolAttr>()) return bool_attr.getValue();
402     if (auto str_attr = attr.dyn_cast<StringAttr>())
403       return !str_attr.getValue().empty();
404     return false;
405   };
406   return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
407          is_enabled("_XlaMustCompile");
408 }
409 
410 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
411 // attribute is present and checks if it holds a function type. Ensures
412 // getType, getNumFuncArguments, and getNumFuncResults can be called safely.
verifyType()413 LogicalResult GraphFuncOp::verifyType() {
414   auto type = getFunctionTypeAttr().getValue();
415   if (!type.isa<FunctionType>())
416     return emitOpError("requires '" + getTypeAttrName() +
417                        "' attribute of function type");
418   return success();
419 }
420 
421 // Hook for OpTrait::FunctionLike, called after verifying the function
422 // type and the presence of the (potentially empty) function body.
verifyBody()423 LogicalResult GraphFuncOp::verifyBody() {
424   FunctionType type = getFunctionType();
425   // Check that the body is terminated with a tfg.return.
426   if (getRegion().empty() || getBody()->empty())
427     return emitOpError() << "expects a non empty body";
428 
429   if (getBody()->getNumArguments() != type.getNumInputs())
430     return emitOpError() << "function type indicated " << type.getNumInputs()
431                          << " args but block has "
432                          << getBody()->getNumArguments();
433 
434   for (auto &arg_types : llvm::enumerate(
435            llvm::zip(type.getInputs(), getBody()->getArgumentTypes()))) {
436     Type signature_arg = std::get<0>(arg_types.value());
437     Type block_arg = std::get<1>(arg_types.value());
438     if (signature_arg != block_arg)
439       return emitOpError() << "type mismatch for arg #" << arg_types.index()
440                            << ", signature defines " << signature_arg
441                            << " block arg is " << block_arg;
442   }
443 
444   if (!isa<ReturnOp>(getBody()->back()))
445     return emitOpError()
446            << "expects body to be terminated with a tfg.return, but got: "
447            << getBody()->back().getName().getStringRef();
448 
449   ReturnOp return_op = cast<ReturnOp>(getBody()->getTerminator());
450 
451   if (type.getNumResults() > return_op->getNumOperands())
452     return emitOpError() << "expects " << type.getNumResults()
453                          << " returned values but tfg.return has "
454                          << return_op->getNumOperands() << " operands";
455   for (auto &indexed_type : llvm::enumerate(type.getResults())) {
456     Type expected_type = indexed_type.value();
457     int res_num = indexed_type.index();
458     Type actual_type = return_op->getOperand(res_num).getType();
459     if (expected_type == actual_type) continue;
460     return emitOpError() << "type mismatch for returned value #" << res_num
461                          << ", expected " << expected_type << " but got "
462                          << actual_type;
463   }
464   Type control_type = getDialect()->getControlType();
465   for (auto &indexed_type : llvm::enumerate(llvm::drop_begin(
466            return_op->getOperandTypes(), type.getNumResults()))) {
467     Type actual_type = indexed_type.value();
468     if (actual_type != control_type) {
469       return emitOpError() << "returned value #" << indexed_type.index()
470                            << " overflow the expected " << type.getNumResults()
471                            << " returned value for function " << getName()
472                            << ", expected a ControlType but got "
473                            << actual_type;
474     }
475   }
476 
477   // Check all ops in the body.
478   if (!all_of(*getBody(), VerifyGenericTFGOperation)) return failure();
479 
480   return success();
481 }
482 
canonicalize(GraphFuncOp func_op,PatternRewriter & rewriter)483 LogicalResult GraphFuncOp::canonicalize(GraphFuncOp func_op,
484                                         PatternRewriter &rewriter) {
485   // Prune function body: the body is a graph where feeds/fetches a materialized
486   // with function arguments and returned values. As such any operation not
487   // reachable from the "fetches" can be pruned. The return statement also has
488   // control input so that side-effecting operations without results (print for
489   // example) aren't pruned.
490   bool changed = true;
491   while (changed) {
492     changed = false;
493     for (Operation &op :
494          llvm::make_early_inc_range(llvm::reverse(*func_op.getBody()))) {
495       if (isa<ReturnOp>(op)) continue;
496       if (op.getUses().empty()) {
497         rewriter.eraseOp(&op);
498         changed = true;
499       }
500     }
501   }
502   return failure();
503 }
504 
verify()505 LogicalResult GraphFuncOp::verify() {
506   GraphFuncOp func_op = *this;
507   if (func_op.getNumArguments() % 2)
508     return func_op.emitOpError() << "expects an even number of arguments";
509   ArrayAttr args_attrs = func_op.getAllArgAttrs();
510   if (args_attrs && args_attrs.size() != func_op.getNumArguments())
511     return func_op.emitOpError()
512            << "expects argument attributes for each argument ("
513            << args_attrs.size() << " vs " << func_op.getNumArguments() << ")";
514   ArrayAttr res_attrs = func_op.getAllResultAttrs();
515   if (res_attrs && res_attrs.size() != func_op.getNumResults())
516     return func_op.emitOpError()
517            << "expects results attributes for each result (" << res_attrs.size()
518            << " vs " << func_op.getNumResults() << ")";
519   return success();
520 }
521 
parse(OpAsmParser & parser,OperationState & result)522 ParseResult GraphFuncOp::parse(OpAsmParser &parser, OperationState &result) {
523   SmallVector<OpAsmParser::UnresolvedOperand> entry_args;
524   SmallVector<Attribute> arg_attrs;
525   SmallVector<Attribute> result_attrs;
526   SmallVector<Type> arg_types;
527   SmallVector<Type> result_types;
528   auto &builder = parser.getBuilder();
529   MLIRContext *context = builder.getContext();
530 
531   // Parse visibility.
532   StringRef visibility;
533   if (!parser.parseOptionalKeyword(&visibility,
534                                    {"public", "private", "nested"})) {
535     StringAttr visibility_attr = parser.getBuilder().getStringAttr(visibility);
536     result.attributes.push_back(parser.getBuilder().getNamedAttr(
537         SymbolTable::getVisibilityAttrName(), visibility_attr));
538   }
539 
540   if (succeeded(parser.parseOptionalKeyword("generic")))
541     result.addAttribute("generic", builder.getUnitAttr());
542 
543   // Parse the name as a symbol.
544   StringAttr name_attr;
545   if (parser.parseSymbolName(name_attr, SymbolTable::getSymbolAttrName(),
546                              result.attributes))
547     return failure();
548 
549   // Parse the function signature.
550   // The difference with usual functions, is that for every single argument
551   // parsed, we create two block arguments: one for the expected value and one
552   // for the control dependency.
553   if (parser.parseLParen()) return failure();
554   Type control_ty = ControlType::get(builder.getContext());
555   std::list<std::string> control_operand_names;
556 
557   // Helper to parse a single argument and its attributes.
558   auto parse_argument = [&]() -> ParseResult {
559     // Parse argument name if present.
560     entry_args.emplace_back();
561     arg_types.emplace_back();
562     if (parser.parseOperand(entry_args.back(), /*allowResultNumber=*/false) ||
563         parser.parseColonType(arg_types.back()))
564       return failure();
565 
566     // Parse any argument attributes.
567     NamedAttrList attrs;
568     if (parser.parseOptionalAttrDict(attrs)) return failure();
569     arg_attrs.push_back(attrs.getDictionary(context));
570 
571     // Define the control input: it's not printed but is added as block
572     // argument. Note the name computed here (suffixed ".ctl") is coupled to the
573     // implementation of:
574     //   TFGraphOpAsmInterface::getAsmBlockArgumentNames()
575     // at the top of this file.
576     OpAsmParser::UnresolvedOperand control_operand = entry_args.back();
577     control_operand_names.push_back((control_operand.name + ".ctl").str());
578     control_operand.name = control_operand_names.back();
579     entry_args.push_back(control_operand);
580     arg_types.push_back(control_ty);
581     arg_attrs.push_back(DictionaryAttr::get(context));
582     return success();
583   };
584 
585   // Parse the function arguments and their attributes.
586   if (failed(parser.parseOptionalRParen())) {
587     do {
588       if (parse_argument()) return failure();
589     } while (succeeded(parser.parseOptionalComma()));
590     if (parser.parseRParen()) return failure();
591   }
592 
593   // Parse the result types and their attributes.
594   if (succeeded(parser.parseOptionalArrow())) {
595     if (failed(parser.parseLParen())) return failure();
596     if (failed(parser.parseOptionalRParen())) {
597       // Parse individual function results.
598       do {
599         result_types.emplace_back();
600         NamedAttrList result_attr;
601         if (parser.parseType(result_types.back()) ||
602             parser.parseOptionalAttrDict(result_attr)) {
603           return failure();
604         }
605         result_attrs.push_back(builder.getDictionaryAttr(result_attr));
606       } while (succeeded(parser.parseOptionalComma()));
607       if (parser.parseRParen()) return failure();
608     }
609   }
610 
611   auto type = builder.getFunctionType(arg_types, result_types);
612   result.addAttribute(GraphFuncOp::getTypeAttrName(), TypeAttr::get(type));
613 
614   // If function attributes are present, parse them.
615   NamedAttrList parsed_attributes;
616   if (parser.parseOptionalAttrDictWithKeyword(parsed_attributes))
617     return failure();
618   result.attributes.append(parsed_attributes);
619 
620   // Add the attributes to the function arguments.
621   assert(arg_attrs.size() == arg_types.size());
622   assert(result_attrs.size() == result_types.size());
623   result.attributes.append(
624       builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
625                            builder.getArrayAttr(arg_attrs)));
626   result.attributes.append(
627       builder.getNamedAttr(FunctionOpInterface::getResultDictAttrName(),
628                            builder.getArrayAttr(result_attrs)));
629 
630   // Parse the function body.
631   auto *body = result.addRegion();
632   llvm::SMLoc loc = parser.getCurrentLocation();
633   SmallVector<OpAsmParser::Argument> args;
634   if (entry_args.size()) {
635     for (auto argAndType : llvm::zip(entry_args, arg_types)) {
636       auto &arg = args.emplace_back();
637       arg.ssaName = std::get<0>(argAndType);
638       arg.type = std::get<1>(argAndType);
639     }
640   }
641 
642   if (failed(parser.parseRegion(*body, args, /*enableNameShadowing=*/false)))
643     return failure();
644 
645   // Function body was parsed, make sure it's not empty.
646   if (body->empty())
647     return parser.emitError(loc, "expected non-empty function body");
648 
649   return success();
650 }
651 
print(OpAsmPrinter & p)652 void GraphFuncOp::print(OpAsmPrinter &p) {
653   // Print the operation and the function name.
654   Operation *op = *this;
655   p << " ";
656   int argIndentSize = op->getName().getStringRef().size() + 3;
657   StringRef visibility_attr_name = SymbolTable::getVisibilityAttrName();
658   if (auto visibility = op->getAttrOfType<StringAttr>(visibility_attr_name)) {
659     p << visibility.getValue() << ' ';
660     argIndentSize += visibility.getValue().size() + 1;
661   }
662   if (generic()) p << "generic ";
663   auto funcName =
664       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
665           .getValue();
666   p.printSymbolName(funcName);
667   argIndentSize += funcName.size();
668   std::string indent(argIndentSize, ' ');
669   FunctionType fnType = getFunctionType();
670   ArrayRef<Type> arg_types = fnType.getInputs();
671   ArrayRef<Type> result_types = fnType.getResults();
672   assert((arg_types.size() % 2) == 0);
673   // Print operand list with attributes.
674   p << '(';
675   ArrayAttr args_attr = getAllArgAttrs();
676   for (unsigned i = 0, e = arg_types.size(); i < e; i += 2) {
677     // Args come by pair: input+control.
678     p.printOperand(getArgument(i));
679     p << ": ";
680     p.printType(arg_types[i]);
681     if (auto arg_attrs = args_attr[i].dyn_cast<DictionaryAttr>())
682       p.printOptionalAttrDict(arg_attrs.getValue());
683     if (i != e - 2) {
684       p << ", ";
685       p.printNewline();
686       p << indent;
687     }
688   }
689   p << ')';
690 
691   // Print result types, if any.
692   if (!result_types.empty()) {
693     p.printNewline();
694     p.getStream() << "     -> (";
695     indent = std::string(9, ' ');
696     ArrayAttr results_attr = getAllResultAttrs();
697     for (int i = 0, e = result_types.size(); i < e; ++i) {
698       p.printType(result_types[i]);
699       if (auto result_attrs = results_attr[i].dyn_cast<DictionaryAttr>())
700         p.printOptionalAttrDict(result_attrs.getValue());
701       if (i != e - 1) {
702         p << ", ";
703         p.printNewline();
704         p << indent;
705       }
706     }
707     p << ")";
708   }
709   // Print attributes.
710   if (!op->getAttrs().empty()) {
711     p.printNewline();
712     function_interface_impl::printFunctionAttributes(
713         p, *this, fnType.getNumInputs(), fnType.getNumResults(),
714         {"generic", SymbolTable::getVisibilityAttrName()});
715   }
716   // Print body.
717   p << ' ';
718   p.printRegion(body(), /*printEntryBlockArgs=*/false);
719 }
720 
getCalledFunction(Operation * op,SymbolTable & symbol_table)721 GraphFuncOp GraphFuncOp::getCalledFunction(Operation *op,
722                                            SymbolTable &symbol_table) {
723   // Check if a node does indirect function call via PartitionedCallOp.
724   // TODO(aminim): consider replacing with isa<...> when possible.
725   if (op->getName().getStringRef() == "tfg.PartitionCall" ||
726       op->getName().getStringRef() == "tfg.StatefulPartitionedCall") {
727     auto func_attr = op->getAttrOfType<FuncAttr>("f");
728     if (!func_attr) return {};
729     GraphFuncOp callee = symbol_table.lookup<GraphFuncOp>(
730         func_attr.getName().getLeafReference());
731     if (callee) return callee;
732   }
733   return symbol_table.lookup<GraphFuncOp>(op->getName().stripDialect());
734 }
735 
getDataValueOf(BlockArgument ctl)736 BlockArgument GraphFuncOp::getDataValueOf(BlockArgument ctl) {
737   return ctl.getOwner()->getArgument(ctl.getArgNumber() - 1);
738 }
739 
getControlTokenOf(BlockArgument data)740 BlockArgument GraphFuncOp::getControlTokenOf(BlockArgument data) {
741   return data.getOwner()->getArgument(data.getArgNumber() + 1);
742 }
743 
getDataValue(Region & region,unsigned idx)744 BlockArgument GraphFuncOp::getDataValue(Region &region, unsigned idx) {
745   return region.getArgument(idx * 2);
746 }
747 
748 // This is naming block arguments for GraphFuncOp, we rely on the arg attributes
749 // for computing the names.
getAsmBlockArgumentNames(Region & region,OpAsmSetValueNameFn set_name_fn)750 void GraphFuncOp::getAsmBlockArgumentNames(Region &region,
751                                            OpAsmSetValueNameFn set_name_fn) {
752   ArrayRef<BlockArgument> args = getBody()->getArguments();
753   ControlType control_ty = ControlType::get(getContext());
754   // Sanity checking: this is verified by the op but this may be called before
755   // the verifier or in some diagnostic/debug context, let's not crash.
756   // We expect the function block operands to come as pair: tensor+control.
757   if (args.size() % 2) return;
758   for (unsigned i = 0, e = args.size(); i < e; i += 2)
759     if (args[i].getType() == control_ty || args[i + 1].getType() != control_ty)
760       return;
761 
762   // Name the values based on the `tfg.name` arg attribute retrieved from the
763   // func_op.
764   ArrayAttr args_attr = getAllArgAttrs();
765   if (!args_attr || args_attr.size() != args.size()) return;
766   for (int arg_num = 0, e = args.size(); arg_num < e; arg_num += 2) {
767     DictionaryAttr arg_attrs = args_attr[arg_num].dyn_cast<DictionaryAttr>();
768     if (!arg_attrs) continue;
769     if (auto strAttr = arg_attrs.getAs<StringAttr>("tfg.name")) {
770       set_name_fn(args[arg_num], strAttr.getValue());
771       set_name_fn(args[arg_num + 1], (strAttr.getValue() + ".ctl").str());
772     }
773   }
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // ReturnOp
778 //===----------------------------------------------------------------------===//
779 
verify()780 LogicalResult ReturnOp::verify() {
781   ReturnOp op = *this;
782   // If the control result attributes are present, there must be the same number
783   // of entries as control results.
784   if (op.control_ret_attrs().size() != TFOp(op).getControlOperands().size()) {
785     return op.emitOpError(
786         "expected as many control result attributes as there are control "
787         "operands");
788   }
789   return success();
790 }
791 
parse(OpAsmParser & parser,OperationState & result)792 ParseResult ReturnOp::parse(OpAsmParser &parser, OperationState &result) {
793   // ReturnOp has the same assembly format as generic TFG ops except that the
794   // control result attributes are embedded with the control operands:
795   // [%ctl {tfg.name = "foo"}, %ctl_0 {tfg.name = "bar"}]
796   SmallVector<OpAsmParser::UnresolvedOperand> operands;
797   if (parser.parseOperandList(operands, AsmParser::Delimiter::OptionalParen))
798     return failure();
799 
800   SmallVector<Attribute> control_ret_attrs;
801   if (succeeded(parser.parseOptionalLSquare())) {
802     OpAsmParser::UnresolvedOperand operand;
803     do {
804       NamedAttrList attrs;
805       OptionalParseResult parse_result = parser.parseOptionalOperand(operand);
806       if (!parse_result.hasValue()) break;
807       if (failed(parse_result.getValue())) return failure();
808       if (parser.parseOptionalAttrDict(attrs)) return failure();
809       control_ret_attrs.push_back(attrs.getDictionary(result.getContext()));
810       operands.push_back(std::move(operand));
811     } while (succeeded(parser.parseOptionalComma()));
812     if (parser.parseRSquare()) return failure();
813   }
814 
815   if (ParseKeywordAttributes(parser, result)) return failure();
816   result.addAttribute(ReturnOp::control_ret_attrsAttrName(result.name),
817                       ArrayAttr::get(result.getContext(), control_ret_attrs));
818 
819   SmallVector<Type> types;
820   if (parser.parseOptionalColonTypeList(types)) return failure();
821   types.resize(operands.size(), ControlType::get(result.getContext()));
822   if (parser.resolveOperands(operands, types, parser.getCurrentLocation(),
823                              result.operands))
824     return failure();
825   return success();
826 }
827 
print(OpAsmPrinter & printer)828 void ReturnOp::print(OpAsmPrinter &printer) {
829   TFOp tfg_op(*this);
830   OperandRange data = tfg_op.getNonControlOperands();
831   if (!data.empty()) printer << '(' << data << ')';
832 
833   OperandRange ctls = tfg_op.getControlOperands();
834   if (!ctls.empty()) {
835     printer << " [";
836     llvm::interleave(
837         llvm::zip(ctls, control_ret_attrs().getAsRange<DictionaryAttr>()),
838         printer,
839         [&](auto it) {
840           printer << std::get<0>(it);
841           if (!std::get<1>(it).empty()) printer << ' ' << std::get<1>(it);
842         },
843         ", ");
844     printer << ']';
845   }
846 
847   PrintKeywordAttributes(*this, printer, {"control_ret_attrs"});
848 
849   if (!data.empty()) printer << " : " << data.getTypes();
850 }
851 
build(OpBuilder & odsBuilder,OperationState & odsState,ValueRange operands,ValueRange control_operands)852 void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState,
853                      ValueRange operands, ValueRange control_operands) {
854   odsState.addOperands(operands);
855   odsState.addOperands(control_operands);
856   // Populate `control_ret_attrs` with empty dictionaries.
857   odsState.addAttribute(
858       ReturnOp::control_ret_attrsAttrName(odsState.name),
859       odsBuilder.getArrayAttr(SmallVector<Attribute>(
860           control_operands.size(), odsBuilder.getDictionaryAttr({}))));
861 }
862 
863 //===----------------------------------------------------------------------===//
864 // Concrete Ops
865 //===----------------------------------------------------------------------===//
866 
// TODO(jeffniu): The ODS definitions of TFG ops, as well as parts of their
// verifiers, could be autogenerated. These hand-written verifiers focus on
// checking the ops' operand and result types against the types of the
// functions they reference; that logic differs slightly between operations.
871 
872 // Verify that all control operands follow non-control operands, and return the
873 // subrange of non-control operands.
VerifyOperands(Operation * op)874 static FailureOr<TypeRange> VerifyOperands(Operation *op) {
875   ControlType control_ty =
876       cast<TFGraphDialect>(op->getDialect())->getControlType();
877   Operation::operand_type_iterator it =
878       llvm::find(op->getOperandTypes(), control_ty);
879   if (!std::all_of(it, op->operand_type_end(),
880                    [&](Type type) { return type == control_ty; })) {
881     return op->emitOpError(
882         "not all control tokens come after non-control operands");
883   }
884   return {Operation::operand_type_range(op->operand_type_begin(), it)};
885 }
886 
887 // Verify that the last result of an operation is the only control result, and
888 // return a subrange of the non-control results.
VerifyResults(Operation * op)889 static FailureOr<TypeRange> VerifyResults(Operation *op) {
890   ControlType control_ty =
891       cast<TFGraphDialect>(op->getDialect())->getControlType();
892   Operation::result_type_iterator it =
893       llvm::find(op->getResultTypes(), control_ty);
894   if (it == op->result_type_end())
895     return op->emitOpError("does not define a control result");
896   if (it != std::prev(op->result_type_end())) {
897     return op->emitOpError(
898         "must have a control token result as and only as its last result");
899   }
900   return {Operation::result_type_range(op->result_type_begin(), it)};
901 }
902 
903 // Verify that the signature of the function matches the operation's operands
904 // and results.
VerifySignature(GraphFuncOp func,Operation * op,TypeRange operands,TypeRange results,const Twine & func_name)905 static LogicalResult VerifySignature(GraphFuncOp func, Operation *op,
906                                      TypeRange operands, TypeRange results,
907                                      const Twine &func_name) {
908   auto attach_func = [&](InFlightDiagnostic diag) -> LogicalResult {
909     return diag.attachNote(func.getLoc()).appendOp(*func, OpPrintingFlags())
910            << "\nsee referenced function";
911   };
912 
913   ArrayRef<Type> arguments = func.getFunctionType().getInputs();
914   ArrayRef<Type> returns = func.getFunctionType().getResults();
915   if (operands.size() * 2 != arguments.size()) {
916     return attach_func(op->emitOpError(func_name)
917                        << " function has " << arguments.size() / 2
918                        << " arguments but was provided " << operands.size());
919   }
920   if (results.size() != returns.size()) {
921     return attach_func(op->emitOpError(func_name)
922                        << " function has " << returns.size()
923                        << " results but expected " << results.size());
924   }
925 
926   if (func.generic()) return success();
927 
928   for (auto &it : llvm::enumerate(operands)) {
929     Type arg_type = arguments[it.index() * 2];
930     Type op_type = it.value();
931     if (!tf_type::HasCompatibleElementTypes(arg_type, op_type)) {
932       return attach_func(
933           op->emitOpError(func_name)
934           << " function argument #" << it.index() << " type " << arg_type
935           << " is not compatible with corresponding operand type: " << op_type);
936     }
937   }
938   for (auto &it : llvm::enumerate(results)) {
939     Type ret_type = returns[it.index()];
940     Type res_type = it.value();
941     if (!tf_type::HasCompatibleElementTypes(ret_type, res_type)) {
942       return attach_func(
943           op->emitOpError(func_name)
944           << " function result #" << it.index() << " type " << ret_type
945           << " is not compatible with corresponding result type: " << res_type);
946     }
947   }
948   return success();
949 }
950 
951 // This function verifies that the types of `values`, which are either operands
952 // or results of `op`, match the types specified in `types`, which is expected
953 // to be an array of type attributes.
VerifyTypeArray(Operation * op,ValueRange values,ArrayAttr types,StringRef kind)954 static LogicalResult VerifyTypeArray(Operation *op, ValueRange values,
955                                      ArrayAttr types, StringRef kind) {
956   // Don't verify if the types are not present.
957   if (!types) return success();
958   if (values.size() != types.size()) {
959     return op->emitOpError("has ") << values.size() << " " << kind << "s but "
960                                    << types.size() << " " << kind << " types";
961   }
962   for (auto it :
963        llvm::zip(llvm::enumerate(values), types.getAsRange<TypeAttr>())) {
964     Type type = std::get<0>(it).value().getType();
965     Type dtype = std::get<1>(it).getValue();
966     if (!tf_type::HasCompatibleElementTypes(type,
967                                             UnrankedTensorType::get(dtype))) {
968       return op->emitOpError(kind)
969              << " #" << std::get<0>(it).index()
970              << " is incompatible with dtype " << dtype << ", got: " << type;
971     }
972   }
973   return success();
974 }
975 
976 namespace detail {
977 // Check if the op type has `T`.
978 template <typename OpT>
979 using has_T = decltype(std::declval<OpT>().T());
980 template <typename OpT>
981 using detect_has_T = llvm::is_detected<has_T, OpT>;
982 
983 // Get the input and output type arrays. If the op has a single type array,
984 // use it for both input and output. Otherwise, return separate type arrays.
985 template <typename OpT, bool = detect_has_T<OpT>::value>
986 struct GetTypeArray {
getInputTypesmlir::tfg::detail::GetTypeArray987   static ArrayAttr getInputTypes(OpT op) { return op.TinAttr(); }
getOutputTypesmlir::tfg::detail::GetTypeArray988   static ArrayAttr getOutputTypes(OpT op) { return op.ToutAttr(); }
989 };
990 template <typename OpT>
991 struct GetTypeArray<OpT, true> {
getInputTypesmlir::tfg::detail::GetTypeArray992   static ArrayAttr getInputTypes(OpT op) { return op.TAttr(); }
getOutputTypesmlir::tfg::detail::GetTypeArray993   static ArrayAttr getOutputTypes(OpT op) { return op.TAttr(); }
994 };
995 }  // namespace detail
996 
997 // Verify a functional op's inputs and outputs against its data type arrays. For
998 // loop ops, this also checks that the number of inputs and outputs match. This
999 // is guaranteed to be valid on import but may be violated by a transformation.
1000 template <typename OpT>
VerifyTypeArrayAttributes(OpT op)1001 static LogicalResult VerifyTypeArrayAttributes(OpT op) {
1002   using GetTypeArray = typename detail::GetTypeArray<OpT>;
1003   ValueRange args =
1004       SplitDataAndControlValues(op.args(), ControlType::get(op.getContext()))
1005           .first;
1006   return success(
1007       succeeded(VerifyTypeArray(op, args, GetTypeArray::getInputTypes(op),
1008                                 "argument")) &&
1009       succeeded(VerifyTypeArray(op, op.outs(), GetTypeArray::getOutputTypes(op),
1010                                 "result")));
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // If-Like Ops
1015 
1016 template <typename IfLikeOp>
VerifyIfLikeOp(IfLikeOp op,SymbolTableCollection & symbol_table)1017 static LogicalResult VerifyIfLikeOp(IfLikeOp op,
1018                                     SymbolTableCollection &symbol_table) {
1019   if (failed(op.verifyInvariants())) return failure();
1020   FailureOr<TypeRange> ins = VerifyOperands(op);
1021   if (failed(ins)) return failure();
1022   FailureOr<TypeRange> outs = VerifyResults(op);
1023   if (failed(outs)) return failure();
1024 
1025   // The first operand is the condition and is not passed to the functions.
1026   TypeRange func_args = ins->drop_front();
1027 
1028   auto then_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1029       op, op.then_branch().getName());
1030   if (then_func &&
1031       failed(VerifySignature(then_func, op, func_args, *outs, "then")))
1032     return failure();
1033 
1034   auto else_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1035       op, op.else_branch().getName());
1036   if (else_func &&
1037       failed(VerifySignature(else_func, op, func_args, *outs, "else")))
1038     return failure();
1039 
1040   return VerifyTypeArrayAttributes(op);
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // Case-Like Ops
1045 
1046 template <typename CaseLikeOp>
VerifyCaseLikeOp(CaseLikeOp op,SymbolTableCollection & symbol_table)1047 static LogicalResult VerifyCaseLikeOp(CaseLikeOp op,
1048                                       SymbolTableCollection &symbol_table) {
1049   if (failed(op.verifyInvariants())) return failure();
1050   FailureOr<TypeRange> ins = VerifyOperands(op);
1051   if (failed(ins)) return failure();
1052   FailureOr<TypeRange> outs = VerifyResults(op);
1053   if (failed(outs)) return failure();
1054 
1055   // The first operand is the branch index and is not passed to the functions.
1056   TypeRange func_args = ins->drop_front();
1057 
1058   for (auto &it : llvm::enumerate(op.branches())) {
1059     SymbolRefAttr func_name = it.value().template cast<FuncAttr>().getName();
1060     auto func =
1061         symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(op, func_name);
1062     if (func && failed(VerifySignature(func, op, func_args, *outs,
1063                                        "branch #" + Twine(it.index()))))
1064       return failure();
1065   }
1066 
1067   return VerifyTypeArrayAttributes(op);
1068 }
1069 
1070 //===----------------------------------------------------------------------===//
1071 // While-Like Ops
1072 
1073 template <typename WhileLikeOp>
VerifyWhileLikeOp(WhileLikeOp op,SymbolTableCollection & symbol_table)1074 static LogicalResult VerifyWhileLikeOp(WhileLikeOp op,
1075                                        SymbolTableCollection &symbol_table) {
1076   if (failed(op.verifyInvariants())) return failure();
1077   FailureOr<TypeRange> ins = VerifyOperands(op);
1078   if (failed(ins)) return failure();
1079   FailureOr<TypeRange> outs = VerifyResults(op);
1080   if (failed(outs)) return failure();
1081 
1082   SymbolRefAttr body_name = op.body().getName();
1083 
1084   auto cond_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1085       op, op.cond().getName());
1086   auto i1_type = UnrankedTensorType::get(Builder(op.getContext()).getI1Type());
1087   if (cond_func &&
1088       failed(VerifySignature(cond_func, op, *ins, i1_type, "cond")))
1089     return failure();
1090 
1091   auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1092       op, op.body().getName());
1093   if (body_func && failed(VerifySignature(body_func, op, *ins, *outs, "body")))
1094     return failure();
1095 
1096   return VerifyTypeArrayAttributes(op);
1097 }
1098 
1099 //===----------------------------------------------------------------------===//
1100 // ForOp
1101 
verifySymbolUses(SymbolTableCollection & symbolTable)1102 LogicalResult ForOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1103   if (failed(verifyInvariants())) return failure();
1104   FailureOr<TypeRange> ins = VerifyOperands(*this);
1105   if (failed(ins)) return failure();
1106   FailureOr<TypeRange> outs = VerifyResults(*this);
1107   if (failed(outs)) return failure();
1108 
1109   auto body_func =
1110       symbolTable.lookupNearestSymbolFrom<GraphFuncOp>(*this, body().getName());
1111   // The first three arguments are the for-loop indices, but the current loop
1112   // index is passed in.
1113   TypeRange func_args = llvm::drop_begin(*ins, /*N=*/2);
1114   if (body_func &&
1115       failed(VerifySignature(body_func, *this, func_args, *outs, "body")))
1116     return failure();
1117 
1118   return VerifyTypeArrayAttributes(*this);
1119 }
1120 
1121 //===----------------------------------------------------------------------===//
1122 // Region Ops and Terminators
1123 //===----------------------------------------------------------------------===//
1124 
1125 // If a region op has preserved attributes, verify that they match the number of
1126 // results and block arguments.
VerifyPreservedAttrs(Operation * op,ArrayRef<Attribute> preserved_attrs)1127 static LogicalResult VerifyPreservedAttrs(Operation *op,
1128                                           ArrayRef<Attribute> preserved_attrs) {
1129   assert(op->getNumRegions() == preserved_attrs.size());
1130   for (auto it : llvm::zip(preserved_attrs, op->getRegions())) {
1131     // Preserved attributes for a particular region may not exist.
1132     auto attrs = std::get<0>(it).dyn_cast_or_null<RegionAttr>();
1133     if (!attrs) continue;
1134     Region &region = std::get<1>(it);
1135 
1136     const auto emit_region_error = [&](StringRef msg) {
1137       return op->emitOpError("region #")
1138              << region.getRegionNumber() << " " << msg;
1139     };
1140 
1141     unsigned num_args = GetLoopRegionDataArgs(region).size();
1142     if (num_args != attrs.getArgAttrs().size()) {
1143       return emit_region_error("has ")
1144              << num_args << " argument(s) but preserved attributes has "
1145              << attrs.getArgAttrs().size();
1146     }
1147 
1148     // All regions are terminated by either a YieldOp or a ConditionOp. In the
1149     // latter case, the function will only have one result.
1150     unsigned num_rets;
1151     Operation *terminator = region.front().getTerminator();
1152     if (isa<ConditionOp>(terminator)) {
1153       num_rets = 1;
1154     } else {
1155       num_rets = cast<RegionBranchTerminatorOpInterface>(terminator)
1156                      .getMutableSuccessorOperands(region.getRegionNumber())
1157                      .size();
1158     }
1159     if (num_rets != attrs.getResAttrs().size()) {
1160       return emit_region_error("has ")
1161              << num_rets << " result(s) but preserved attributes has "
1162              << attrs.getResAttrs().size();
1163     }
1164   }
1165   return success();
1166 }
1167 
1168 //===----------------------------------------------------------------------===//
1169 // YieldOp
1170 
getMutableSuccessorOperands(Optional<unsigned> index)1171 MutableOperandRange YieldOp::getMutableSuccessorOperands(
1172     Optional<unsigned> index) {
1173   // Get the subrange of non-control operands.
1174   return argsMutable();
1175 }
1176 
TerminatedByYield(Block & block)1177 static bool TerminatedByYield(Block &block) {
1178   return isa<YieldOp>(block.getTerminator());
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // IfLikeRegionOp
1183 
1184 // Verify an if-like region op.
1185 template <typename IfLikeRegionOp>
VerifyIfLikeRegionOp(IfLikeRegionOp op)1186 static LogicalResult VerifyIfLikeRegionOp(IfLikeRegionOp op) {
1187   // Verify terminators.
1188   if (!TerminatedByYield(op.then_block()))
1189     return op.emitOpError("then region must be terminated by a 'tfg.yield'");
1190   if (!TerminatedByYield(op.else_block()))
1191     return op.emitOpError("else region must be terminated by a 'tfg.yield'");
1192   return VerifyPreservedAttrs(
1193       op, {op.then_region_attrsAttr(), op.else_region_attrsAttr()});
1194 }
1195 
1196 // Given an potentially null attribute that would represent a constant value,
1197 // try to narrow it to a statically known condition.
1198 // TODO(jeffniu): Incorporate the other cases of `tf.ToBool`.
GetStaticallyKnownBranch(Attribute cond_attr)1199 static Optional<bool> GetStaticallyKnownBranch(Attribute cond_attr) {
1200   // Only handle the case of a scalar tensor of i1.
1201   auto cond = cond_attr.dyn_cast_or_null<ElementsAttr>();
1202   if (cond && cond.getNumElements() == 1 &&
1203       cond.getElementType().isSignlessInteger(1))
1204     return cond.getSplatValue<bool>();
1205   return {};
1206 }
1207 
1208 // Get the successor of the regions of an if-like op.
1209 template <typename IfLikeRegionOp>
GetIfLikeRegionOpSuccessorRegions(IfLikeRegionOp op,Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1210 void GetIfLikeRegionOpSuccessorRegions(
1211     IfLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands,
1212     SmallVectorImpl<RegionSuccessor> &regions) {
1213   assert(index.has_value() ||
1214          !operands.empty() && "if-like op expected at least 1 operand");
1215   // Both regions branch back to the parent op.
1216   if (index.has_value()) {
1217     // Ignore the control token.
1218     regions.emplace_back(
1219         ResultRange(op->result_begin(), std::prev(op->result_end())));
1220   } else if (auto cond = GetStaticallyKnownBranch(operands[0])) {
1221     // Add only 1 possible successor if the condition is known.
1222     Region &region = *cond ? op.then_region() : op.else_region();
1223     regions.emplace_back(&region, GetLoopRegionDataArgs(region));
1224   } else {
1225     // Unknown successor.
1226     regions.emplace_back(&op.then_region(),
1227                          GetLoopRegionDataArgs(op.then_region()));
1228     regions.emplace_back(&op.else_region(),
1229                          GetLoopRegionDataArgs(op.else_region()));
1230   }
1231 }
1232 
1233 //===----------------------------------------------------------------------===//
1234 // CaseLikeRegionOp
1235 
1236 // Verify a case-like region op.
1237 template <typename CaseLikeRegionOp>
VerifyCaseLikeRegionOp(CaseLikeRegionOp op)1238 static LogicalResult VerifyCaseLikeRegionOp(CaseLikeRegionOp op) {
1239   for (auto &it : llvm::enumerate(op.branches())) {
1240     if (!TerminatedByYield(it.value().front())) {
1241       return op.emitOpError("branch region #")
1242              << it.index() << " is not terminated by a 'tfg.yield' op";
1243     }
1244   }
1245 
1246   if (op.branch_attrs() && op.branches().size() != op.branch_attrs()->size()) {
1247     return op.emitOpError("has ")
1248            << op.branches().size() << " regions but "
1249            << op.branch_attrs()->size() << " branch function attributes";
1250   }
1251   if (auto region_attrs = op.region_attrsAttr()) {
1252     if (region_attrs.size() != op.getNumRegions()) {
1253       return op.emitOpError("expected ")
1254              << op.getNumRegions() << " region attribute(s) but got "
1255              << region_attrs.size();
1256     }
1257     if (failed(VerifyPreservedAttrs(op, region_attrs.getValue())))
1258       return failure();
1259   }
1260   return success();
1261 }
1262 
1263 // Given a potentially null attribute that would represent a constant value,
1264 // try to narrow it to a statically known branch index.
GetStaticallyKnownCaseBranch(Attribute branch_attr)1265 static Optional<unsigned> GetStaticallyKnownCaseBranch(Attribute branch_attr) {
1266   auto branch = branch_attr.dyn_cast_or_null<ElementsAttr>();
1267   if (branch && branch.getNumElements() == 1 &&
1268       branch.getElementType().isSignlessInteger(32))
1269     return branch.getSplatValue<unsigned>();
1270   return {};
1271 }
1272 
1273 // Get the successor of the regions of a case-like op.
1274 template <typename CaseLikeRegionOp>
GetCaseLikeRegionOpSuccessorRegions(CaseLikeRegionOp op,Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1275 void GetCaseLikeRegionOpSuccessorRegions(
1276     CaseLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands,
1277     SmallVectorImpl<RegionSuccessor> &regions) {
1278   assert(index.has_value() ||
1279          !operands.empty() && "case-like op expected at least 1 operand");
1280   // All branch regions branch back to the parent op.
1281   if (index.has_value()) {
1282     // Ignore the control token.
1283     regions.emplace_back(
1284         ResultRange(op->result_begin(), std::prev(op->result_end())));
1285   } else if (auto branch_index = GetStaticallyKnownCaseBranch(operands[0])) {
1286     // Add only 1 possible successor if the condition is known.
1287     Region &region = op.branches()[*branch_index];
1288     regions.emplace_back(&region, GetLoopRegionDataArgs(region));
1289   } else {
1290     // Unknown successor. Add all of them.
1291     for (Region &branch : op.branches())
1292       regions.emplace_back(&branch, GetLoopRegionDataArgs(branch));
1293   }
1294 }
1295 
1296 //===----------------------------------------------------------------------===//
1297 // ConditionOp
1298 
getMutableSuccessorOperands(Optional<unsigned> index)1299 MutableOperandRange ConditionOp::getMutableSuccessorOperands(
1300     Optional<unsigned> index) {
1301   // Get the subrange of non-control operands that are forwarded to the
1302   // successor region.
1303   return argsMutable();
1304 }
1305 
1306 //===----------------------------------------------------------------------===//
1307 // WhileLikeRegionOp
1308 
1309 // Verify that the loop regions of a region-based loop op have N control tokens
1310 // immediately following N data values in their entry block arguments.
1311 // `RegionBranchOpInterface` will verify the number of arguments and their
1312 // types.
VerifyLoopRegionArgs(Operation * op,Region & region)1313 static LogicalResult VerifyLoopRegionArgs(Operation *op, Region &region) {
1314   const auto arg_error = [&](BlockArgument arg) {
1315     return op->emitOpError("region #")
1316            << region.getRegionNumber() << " argument #" << arg.getArgNumber()
1317            << " ";
1318   };
1319 
1320   // The interface trait verifies the number of data and control arguments. If
1321   // the first half of the arguments are not control tokens, then we know for
1322   // sure that the second half is only control tokens.
1323   for (BlockArgument data : GetLoopRegionDataArgs(region))
1324     if (data.getType().isa<ControlType>())
1325       return arg_error(data) << "should not be a control token";
1326   return success();
1327 }
1328 
1329 // Verify a while-like region op.
1330 template <typename WhileLikeRegionOp>
VerifyWhileLikeRegionOp(WhileLikeRegionOp op)1331 static LogicalResult VerifyWhileLikeRegionOp(WhileLikeRegionOp op) {
1332   // Verify terminators.
1333   if (!isa<ConditionOp>(op.cond_block().getTerminator())) {
1334     return op.emitOpError(
1335         "condition region must be terminated by a 'tfg.condition' op");
1336   }
1337   if (!TerminatedByYield(op.body_block()))
1338     op.emitOpError("body region must be terminated by a 'tfg.yield' op");
1339 
1340   if (failed(VerifyLoopRegionArgs(op, op.cond_region())) ||
1341       failed(VerifyLoopRegionArgs(op, op.body_region())))
1342     return failure();
1343   if (failed(VerifyPreservedAttrs(
1344           op, {op.cond_region_attrsAttr(), op.body_region_attrsAttr()})))
1345     return failure();
1346 
1347   return success();
1348 }
1349 
1350 template <typename WhileLikeRegionOp>
GetWhileLikeRegionOpSuccessorRegions(WhileLikeRegionOp op,Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1351 static void GetWhileLikeRegionOpSuccessorRegions(
1352     WhileLikeRegionOp op, Optional<unsigned> index,
1353     ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
1354   // The parent op and the body region always branch to the condion region.
1355   if (!index || *index == 1) {
1356     regions.emplace_back(&op.cond_region(),
1357                          GetLoopRegionDataArgs(op.cond_region()));
1358     return;
1359   }
1360   assert(*index == 0 && "invalid region index");
1361   // The condition regions branches to the loop body or back to the parent.
1362   // Try to narrow the condition value to a constant.
1363   auto condition = cast<ConditionOp>(op.cond_region().front().getTerminator());
1364   Attribute cond_attr;
1365   matchPattern(condition.cond(), m_Constant(&cond_attr));
1366   Optional<bool> cond = GetStaticallyKnownBranch(cond_attr);
1367   if (!cond || *cond) {
1368     regions.emplace_back(&op.body_region(),
1369                          GetLoopRegionDataArgs(op.body_region()));
1370   }
1371   if (!cond || !*cond) {
1372     // Drop the control token.
1373     regions.emplace_back(op.getResults().drop_back());
1374   }
1375 }
1376 
1377 //===----------------------------------------------------------------------===//
1378 // ForRegionOp
1379 
verify()1380 LogicalResult ForRegionOp::verify() {
1381   if (!TerminatedByYield(body_block())) {
1382     return emitOpError("body region must be terminated by a 'tfg.yield' op");
1383   }
1384 
1385   Block::BlockArgListType args = body_block().getArguments();
1386   if (args.empty()) {
1387     return emitOpError(
1388         "expected the body block to have at least have the loop index as an "
1389         "argument");
1390   }
1391   auto index = args.front().getType().dyn_cast<TensorType>();
1392   if (!index || !index.getElementType().isSignlessInteger(32)) {
1393     return emitOpError(
1394         "expected first body block argument to be an i32 tensor");
1395   }
1396 
1397   if (failed(VerifyLoopRegionArgs(*this, body_region()))) return failure();
1398   return VerifyPreservedAttrs(*this, {region_attrsAttr()});
1399 }
1400 
getSuccessorEntryOperands(Optional<unsigned> index)1401 OperandRange ForRegionOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1402   return init();
1403 }
1404 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1405 void ForRegionOp::getSuccessorRegions(
1406     Optional<unsigned> index, ArrayRef<Attribute> operands,
1407     SmallVectorImpl<RegionSuccessor> &regions) {
1408   // Both the parent op and the body region branch to the body. Ignore the loop
1409   // index block argument, as it is not modified by the loop body itself.
1410   regions.emplace_back(&body_region(),
1411                        GetLoopRegionDataArgs(body_region()).drop_front());
1412   if (!index) return;
1413   // The body might branch back to the parent. Drop the control token.
1414   regions.emplace_back((*this)->getResults().drop_back());
1415 }
1416 
getDataValueOf(BlockArgument ctl)1417 BlockArgument ForRegionOp::getDataValueOf(BlockArgument ctl) {
1418   return GetLoopRegionDataOf(ctl);
1419 }
getControlTokenOf(BlockArgument data)1420 BlockArgument ForRegionOp::getControlTokenOf(BlockArgument data) {
1421   return GetLoopRegionControlOf(data);
1422 }
getDataValue(Region & region,unsigned idx)1423 BlockArgument ForRegionOp::getDataValue(Region &region, unsigned idx) {
1424   return GetLoopRegionDataArgs(region)[idx];
1425 }
getControlToken(Region & region,unsigned idx)1426 BlockArgument ForRegionOp::getControlToken(Region &region, unsigned idx) {
1427   return GetLoopRegionControlTokens(region)[idx];
1428 }
1429 
1430 //===----------------------------------------------------------------------===//
1431 // Function Table
1432 //===----------------------------------------------------------------------===//
1433 
FunctionTable(ModuleOp module)1434 FunctionTable::FunctionTable(ModuleOp module) {
1435   // Collect function names (to be used for disambiguating legacy call
1436   // behavior).
1437   for (auto &op : module.getOps()) {
1438     if (auto func = dyn_cast<GraphFuncOp>(op)) functions.insert(func.getName());
1439   }
1440 }
1441 
MayBeCall(Operation * op) const1442 bool FunctionTable::MayBeCall(Operation *op) const {
1443   if (IsLegacyCall(op)) return true;
1444   // The operation might be a call if it references a symbol.
1445   bool references_symbol = false;
1446   op->getAttrDictionary().walkSubAttrs(
1447       [&](Attribute attr) { references_symbol |= attr.isa<SymbolRefAttr>(); });
1448   return references_symbol;
1449 }
1450 
IsLegacyCall(Operation * op) const1451 bool FunctionTable::IsLegacyCall(Operation *op) const {
1452   // If the operation name refers to a function in the module, then it is
1453   // guaranteed to be a legacy call. Otherwise, it is not.
1454   return functions.count(op->getName().stripDialect());
1455 }
1456 
1457 }  // namespace tfg
1458 }  // namespace mlir
1459 
1460 //===----------------------------------------------------------------------===//
1461 // ODS Definitions
1462 //===----------------------------------------------------------------------===//
1463 
1464 #define GET_OP_CLASSES
1465 #include "tensorflow/core/ir/ops.cc.inc"
1466 #define GET_ATTRDEF_CLASSES
1467 #include "tensorflow/core/ir/attributes.cc.inc"
1468