1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/ir/ops.h"
17
#include <algorithm>
#include <array>
#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <utility>
24
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/SMLoc.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35 #include "mlir/IR/Diagnostics.h" // from @llvm-project
36 #include "mlir/IR/Dialect.h" // from @llvm-project
37 #include "mlir/IR/DialectImplementation.h" // from @llvm-project
38 #include "mlir/IR/FunctionImplementation.h" // from @llvm-project
39 #include "mlir/IR/FunctionInterfaces.h" // from @llvm-project
40 #include "mlir/IR/MLIRContext.h" // from @llvm-project
41 #include "mlir/IR/OpImplementation.h" // from @llvm-project
42 #include "mlir/IR/OperationSupport.h" // from @llvm-project
43 #include "mlir/IR/SymbolTable.h" // from @llvm-project
44 #include "mlir/IR/TypeRange.h" // from @llvm-project
45 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
46 #include "mlir/IR/Value.h" // from @llvm-project
47 #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
48 #include "mlir/Support/LLVM.h" // from @llvm-project
49 #include "mlir/Support/LogicalResult.h" // from @llvm-project
50 #include "tensorflow/core/ir/dialect.h"
51 #include "tensorflow/core/ir/interfaces.h"
52 #include "tensorflow/core/ir/types/dialect.h"
53 #include "tensorflow/core/ir/utility.h"
54
55 // Generated definitions.
56 #include "tensorflow/core/ir/dialect.cc.inc"
57
58 namespace mlir {
59 namespace tfg {
60
61 //===----------------------------------------------------------------------===//
62 // TFGraph dialect.
63 //===----------------------------------------------------------------------===//
64
65 // Name operation results with the operation name, except control outputs which
66 // are named "ctl". MLIR will automatically use a numerical suffix to unique.
GenericGetAsmResultNames(Operation * op,OpAsmSetValueNameFn set_name_fn)67 static void GenericGetAsmResultNames(Operation *op,
68 OpAsmSetValueNameFn set_name_fn) {
69 // We only name the results when there are results to name, an op like `print`
70 // which does not have results will just use the `ctl` name for the control
71 // output.
72 if (op->getNumResults() > 1 && !op->getResult(0).getType().isa<ControlType>())
73 set_name_fn(op->getResult(0), op->getName().stripDialect());
74 for (Value result : op->getResults()) {
75 if (result.getType().isa<ControlType>()) {
76 set_name_fn(op->getResult(op->getNumResults() - 1), "ctl");
77 break;
78 }
79 }
80 }
81
82 // TFGraph support for interacting with the AsmPrinter.
83 // Gives prettier names to SSA values.
84 struct TFGraphOpAsmInterface
85 : public OpAsmOpInterface::FallbackModel<TFGraphOpAsmInterface> {
classofmlir::tfg::TFGraphOpAsmInterface86 static bool classof(Operation *op) { return true; }
87
getAsmResultNamesmlir::tfg::TFGraphOpAsmInterface88 void getAsmResultNames(Operation *op, OpAsmSetValueNameFn set_name_fn) const {
89 GenericGetAsmResultNames(op, set_name_fn);
90 }
getAsmBlockArgumentNamesmlir::tfg::TFGraphOpAsmInterface91 void getAsmBlockArgumentNames(Operation *op, Region ®ion,
92 OpAsmSetValueNameFn setNameFn) const {}
getAsmBlockNamesmlir::tfg::TFGraphOpAsmInterface93 void getAsmBlockNames(Operation *op,
94 mlir::OpAsmSetBlockNameFn setNameFn) const {}
95 };
96
// Dialect construction: there is one instance per context. It loads the
// tf_type dialect, registers its operations, attributes, and interfaces, and
// caches frequently used context-owned values.
void TFGraphDialect::initialize() {
  getContext()->getOrLoadDialect<tf_type::TFTypeDialect>();
  addOperations<
#define GET_OP_LIST
#include "tensorflow/core/ir/ops.cc.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "tensorflow/core/ir/attributes.cc.inc"
      >();

  // Support unknown operations because not all TensorFlow operations are
  // registered.
  allowUnknownOperations();

  // Create the fallback OpAsmOpInterface instance. The dialect owns it and
  // deletes it in ~TFGraphDialect().
  fallbackOpAsmInterface_ = new TFGraphOpAsmInterface;

  // Register the memory effects interface adaptor.
  addInterfaces<StatefulMemoryEffectInterface>();

  // Initialize the cached operation names.
#define GET_OP_NAME_DEFS
#include "tensorflow/core/ir/tf_op_names.inc"

  // Cache some frequently used context-owned information for fast access.
  name_key_ = StringAttr::get(getContext(), getNameAttrKey());
  device_key_ = StringAttr::get(getContext(), getDeviceAttrKey());
  assigned_device_key_ =
      StringAttr::get(getContext(), getAssignedDeviceAttrKey());
  fulltype_key_ = StringAttr::get(getContext(), getFullTypeAttrKey());
  lifted_graph_func_name_ =
      StringAttr::get(getContext(), getLiftedGraphFuncNameKey());
  tfg_name_key_ = StringAttr::get(getContext(), getTfgNameAttrKey());
  tfg_description_key_ =
      StringAttr::get(getContext(), getTfgDescriptionAttrKey());
  tfg_is_ref_key_ = StringAttr::get(getContext(), getTfgIsRefAttrKey());
  tfg_handle_data_key_ =
      StringAttr::get(getContext(), getTfgHandleDataAttrKey());
  tfg_full_type_key_ = StringAttr::get(getContext(), getTfgFullTypeAttrKey());

  control_ty_ = ControlType::get(getContext());
}
142
143 // Provides a hook for op interface.
getRegisteredInterfaceForOp(TypeID interface,OperationName opName)144 void *TFGraphDialect::getRegisteredInterfaceForOp(TypeID interface,
145 OperationName opName) {
146 if (interface == TypeID::get<OpAsmOpInterface>()) {
147 return fallbackOpAsmInterface_;
148 }
149
150 // Intrinsic operations explicitly implement intefaces.
151 if (opName.hasTrait<OpTrait::IntrinsicOperation>()) {
152 return nullptr;
153 }
154
155 if (interface == TypeID::get<TensorFlowRegistryInterface>()) {
156 if (auto *instance =
157 getRegisteredInterface<TensorFlowRegistryInterfaceBase>()) {
158 // Important: cast to (Concept *) to shift the pointer off the vtable.
159 return static_cast<TensorFlowRegistryInterfaceBase::Concept *>(
160 const_cast<TensorFlowRegistryInterfaceBase *>(instance));
161 }
162 } else if (interface == TypeID::get<MemoryEffectOpInterface>()) {
163 auto *instance = getRegisteredInterface<StatefulMemoryEffectInterface>();
164 assert(instance && "expected the memory interface to be registered");
165 return static_cast<StatefulMemoryEffectInterface::Concept *>(
166 const_cast<StatefulMemoryEffectInterface *>(instance));
167 }
168
169 return nullptr;
170 }
171
~TFGraphDialect()172 TFGraphDialect::~TFGraphDialect() { delete fallbackOpAsmInterface_; }
173
174 // The name of certain optional attributes.
175 static std::array<StringRef, 3> keyword_attrs{
176 "_mlir_device", "_mlir_assigned_device", "_mlir_name"};
177
PrintKeywordAttributes(Operation * op,OpAsmPrinter & printer,ArrayRef<StringRef> elided_attrs={})178 static void PrintKeywordAttributes(Operation *op, OpAsmPrinter &printer,
179 ArrayRef<StringRef> elided_attrs = {}) {
180 // Handles the optional "device" and "name" attribute.
181 for (StringRef keyword : keyword_attrs) {
182 if (StringAttr value_attr = op->getAttrOfType<StringAttr>(keyword)) {
183 assert(!value_attr.getValue().empty());
184 printer << " " << keyword.drop_front(/*len(_mlir_)*/ 6) << "(\""
185 << value_attr.getValue() << "\")";
186 }
187 }
188
189 // Print attributes (other than name and device).
190 SmallVector<StringRef> attrs_to_elide = llvm::to_vector(elided_attrs);
191 llvm::append_range(attrs_to_elide, keyword_attrs);
192 printer.printOptionalAttrDict(op->getAttrs(), attrs_to_elide);
193 }
194
195 // Print an operation that belongs to this dialect, if unregistered.
196 // The general syntax is:
197 // tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2]
198 // name("<node_name>") device("<device>") { attribute-dict } :
199 // (input types) -> (result_types)
printCustomTfOp(Operation * op,OpAsmPrinter & printer) const200 void TFGraphDialect::printCustomTfOp(Operation *op,
201 OpAsmPrinter &printer) const {
202 ControlType control_ty = getControlType();
203
204 // Check that all control dependencies are after the regular values,
205 // otherwise print the generic form. We don't expect this to happen but
206 // we're defensive in the printer since this may happen in "hard-to-debug"
207 // issues.
208 {
209 bool has_control_dep = false;
210 for (Value operand : op->getOperands()) {
211 if (operand.getType() == control_ty) {
212 has_control_dep = true;
213 continue;
214 }
215 if (has_control_dep) {
216 printer.printGenericOp(op);
217 return;
218 }
219 }
220 has_control_dep = false;
221 for (Value result : op->getResults()) {
222 if (result.getType() == control_ty) {
223 has_control_dep = true;
224 continue;
225 }
226 if (has_control_dep) {
227 printer.printGenericOp(op);
228 return;
229 }
230 }
231 }
232
233 // Print the inputs (other than the control dependencies), if any.
234 TFOp tfg_op(op);
235 OperandRange data = tfg_op.getNonControlOperands();
236 if (!data.empty()) printer << '(' << data << ')';
237 // Print the control dependencies (if any).
238 OperandRange ctls = tfg_op.getControlOperands();
239 if (!ctls.empty()) printer << " [" << ctls << ']';
240
241 // Print the keyword attributes and optional attribute dictionary.
242 PrintKeywordAttributes(op, printer);
243
244 // Print the type, but omit control dependencies.
245 // If there is a single control return, just print the list of input types,
246 // otherwise print the complete type in a "function-style" way: (operands)
247 // -> (results).
248 ResultRange results = tfg_op.getNonControlResults();
249 if (results.empty()) {
250 if (!data.empty()) printer << " : " << data.getTypes();
251 } else {
252 printer << " : (" << data.getTypes() << ") -> (" << results.getTypes()
253 << ")";
254 }
255 }
256
257 // Print a custom TFG op.
PrintCustomTfOp(Operation * op,OpAsmPrinter & printer)258 static void PrintCustomTfOp(Operation *op, OpAsmPrinter &printer) {
259 cast<TFGraphDialect>(op->getDialect())->printCustomTfOp(op, printer);
260 }
261
262 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation * op) const263 TFGraphDialect::getOperationPrinter(Operation *op) const {
264 return [this](Operation *op, OpAsmPrinter &printer) {
265 this->printCustomTfOp(op, printer);
266 };
267 }
268
269 // Try to parse optional keyword attributes and prefix them with `_mlir_`, of
270 // `device`, `assigned_device`, and `name`.
ParseKeywordAttributes(OpAsmParser & parser,OperationState & result)271 static ParseResult ParseKeywordAttributes(OpAsmParser &parser,
272 OperationState &result) {
273 for (const char *keyword : {"device", "assigned_device", "name"}) {
274 if (succeeded(parser.parseOptionalKeyword(keyword))) {
275 StringAttr value;
276 if (parser.parseLParen() ||
277 parser.parseAttribute<StringAttr>(
278 value, NoneType::get(parser.getContext())) ||
279 parser.parseRParen())
280 return failure();
281 result.addAttribute((Twine("_mlir_") + keyword).str(), value);
282 }
283 }
284 return parser.parseOptionalAttrDict(result.attributes);
285 }
286
287 // Parse an operation that belongs to this dialect, if unregistered.
288 // The general syntax is:
289 // tfg.OpName(%input1, %input2, %input3) [%control_dep1, %control_dep2]
290 // name("<node_name>") device("<device>") { attribute-dict } :
291 // (input types) -> (result_types)
ParseCustomTfOp(OpAsmParser & parser,OperationState & result)292 static ParseResult ParseCustomTfOp(OpAsmParser &parser,
293 OperationState &result) {
294 MLIRContext *context = parser.getBuilder().getContext();
295 // Parse optional argument list
296 SmallVector<OpAsmParser::UnresolvedOperand, 4> op_infos;
297 if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalParen))
298 return failure();
299 unsigned numNonControlOperands = op_infos.size();
300 // Optional control list, in between brackets.
301 if (parser.parseOperandList(op_infos, AsmParser::Delimiter::OptionalSquare))
302 return failure();
303
304 // Parse the optional keyword attributes and optional attribute dictionary.
305 if (ParseKeywordAttributes(parser, result)) return failure();
306
307 // Parse the functional type.
308 SmallVector<Type> arg_types;
309 arg_types.reserve(op_infos.size());
310 llvm::SMLoc loc = parser.getCurrentLocation();
311 Type control_type = ControlType::get(context);
312 if (failed(parser.parseOptionalColonTypeList(arg_types))) return failure();
313 if (arg_types.size() == 1 && arg_types.front().isa<FunctionType>()) {
314 auto funcType = arg_types.front().cast<FunctionType>();
315 if (funcType.getNumInputs() != numNonControlOperands)
316 return parser.emitError(loc)
317 << "got " << numNonControlOperands
318 << " non-control operands, but the type defines "
319 << funcType.getNumInputs() << " input types";
320 arg_types.clear();
321 arg_types.append(funcType.getInputs().begin(), funcType.getInputs().end());
322 result.types.append(funcType.getResults().begin(),
323 funcType.getResults().end());
324 }
325
326 // The control input are elided from the type list, add them here.
327 arg_types.resize(op_infos.size(), control_type);
328 if (!arg_types.empty())
329 if (parser.resolveOperands(op_infos, arg_types, loc, result.operands))
330 return failure();
331 if (result.name.getStringRef() != "tfg.return")
332 result.types.push_back(control_type);
333 return success();
334 }
335
getParseOperationHook(StringRef opName) const336 Optional<Dialect::ParseOpHook> TFGraphDialect::getParseOperationHook(
337 StringRef opName) const {
338 return ParseOpHook(ParseCustomTfOp);
339 }
340
VerifyGenericTFGOperation(Operation & op)341 static bool VerifyGenericTFGOperation(Operation &op) {
342 TFGraphDialect *dialect = dyn_cast<TFGraphDialect>(op.getDialect());
343 if (!dialect) return true;
344 ControlType control_ty = dialect->getControlType();
345
346 // verifies that control operands (or results) are always after regular
347 // inputs (or results).
348 auto check_ctl_at_end = [&](TypeRange types, StringRef input_or_output) {
349 int has_control_dep = -1;
350 for (auto &indexed_operand : llvm::enumerate(types)) {
351 if (indexed_operand.value() == control_ty) {
352 has_control_dep = indexed_operand.index();
353 continue;
354 }
355 if (has_control_dep != -1) {
356 op.emitOpError() << "found non-control " << input_or_output
357 << " in position #" << indexed_operand.index()
358 << " after control " << input_or_output
359 << " in position #" << has_control_dep;
360 return false;
361 }
362 }
363 return true;
364 };
365 if (!check_ctl_at_end(op.getOperandTypes(), "input")) return false;
366 if (!check_ctl_at_end(op.getResultTypes(), "result")) return false;
367
368 // Certain attributes are supposed to be inserted with non-empty value.
369 for (StringRef keyword : keyword_attrs) {
370 if (StringAttr value_attr = op.getAttrOfType<StringAttr>(keyword)) {
371 if (value_attr.getValue().empty()) {
372 op.emitOpError() << keyword
373 << " has empty value. Only insert this attribute when "
374 "it has a value";
375 }
376 }
377 }
378
379 return true;
380 }
381
382 //===----------------------------------------------------------------------===//
383 // Graph Operation
384 //===----------------------------------------------------------------------===//
385
verify()386 LogicalResult GraphOp::verify() {
387 GraphOp op = *this;
388 // Check all ops in the body.
389 if (!all_of(*op.getBody(), VerifyGenericTFGOperation)) return failure();
390
391 return success();
392 }
393 //===----------------------------------------------------------------------===//
394 // Func Operation
395 //===----------------------------------------------------------------------===//
396
isMarkedForCompilation()397 bool GraphFuncOp::isMarkedForCompilation() {
398 auto is_enabled = [this](StringRef attr_name) -> bool {
399 Attribute attr = (*this)->getAttr(attr_name);
400 if (!attr) return false;
401 if (auto bool_attr = attr.dyn_cast<BoolAttr>()) return bool_attr.getValue();
402 if (auto str_attr = attr.dyn_cast<StringAttr>())
403 return !str_attr.getValue().empty();
404 return false;
405 };
406 return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
407 is_enabled("_XlaMustCompile");
408 }
409
410 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
411 // attribute is present and checks if it holds a function type. Ensures
412 // getType, getNumFuncArguments, and getNumFuncResults can be called safely
verifyType()413 LogicalResult GraphFuncOp::verifyType() {
414 auto type = getFunctionTypeAttr().getValue();
415 if (!type.isa<FunctionType>())
416 return emitOpError("requires '" + getTypeAttrName() +
417 "' attribute of function type");
418 return success();
419 }
420
421 // Hook for OpTrait::FunctionLike, called after verifying the function
422 // type and the presence of the (potentially empty) function body.
verifyBody()423 LogicalResult GraphFuncOp::verifyBody() {
424 FunctionType type = getFunctionType();
425 // Check that the body is terminated with a tfg.return.
426 if (getRegion().empty() || getBody()->empty())
427 return emitOpError() << "expects a non empty body";
428
429 if (getBody()->getNumArguments() != type.getNumInputs())
430 return emitOpError() << "function type indicated " << type.getNumInputs()
431 << " args but block has "
432 << getBody()->getNumArguments();
433
434 for (auto &arg_types : llvm::enumerate(
435 llvm::zip(type.getInputs(), getBody()->getArgumentTypes()))) {
436 Type signature_arg = std::get<0>(arg_types.value());
437 Type block_arg = std::get<1>(arg_types.value());
438 if (signature_arg != block_arg)
439 return emitOpError() << "type mismatch for arg #" << arg_types.index()
440 << ", signature defines " << signature_arg
441 << " block arg is " << block_arg;
442 }
443
444 if (!isa<ReturnOp>(getBody()->back()))
445 return emitOpError()
446 << "expects body to be terminated with a tfg.return, but got: "
447 << getBody()->back().getName().getStringRef();
448
449 ReturnOp return_op = cast<ReturnOp>(getBody()->getTerminator());
450
451 if (type.getNumResults() > return_op->getNumOperands())
452 return emitOpError() << "expects " << type.getNumResults()
453 << " returned values but tfg.return has "
454 << return_op->getNumOperands() << " operands";
455 for (auto &indexed_type : llvm::enumerate(type.getResults())) {
456 Type expected_type = indexed_type.value();
457 int res_num = indexed_type.index();
458 Type actual_type = return_op->getOperand(res_num).getType();
459 if (expected_type == actual_type) continue;
460 return emitOpError() << "type mismatch for returned value #" << res_num
461 << ", expected " << expected_type << " but got "
462 << actual_type;
463 }
464 Type control_type = getDialect()->getControlType();
465 for (auto &indexed_type : llvm::enumerate(llvm::drop_begin(
466 return_op->getOperandTypes(), type.getNumResults()))) {
467 Type actual_type = indexed_type.value();
468 if (actual_type != control_type) {
469 return emitOpError() << "returned value #" << indexed_type.index()
470 << " overflow the expected " << type.getNumResults()
471 << " returned value for function " << getName()
472 << ", expected a ControlType but got "
473 << actual_type;
474 }
475 }
476
477 // Check all ops in the body.
478 if (!all_of(*getBody(), VerifyGenericTFGOperation)) return failure();
479
480 return success();
481 }
482
canonicalize(GraphFuncOp func_op,PatternRewriter & rewriter)483 LogicalResult GraphFuncOp::canonicalize(GraphFuncOp func_op,
484 PatternRewriter &rewriter) {
485 // Prune function body: the body is a graph where feeds/fetches a materialized
486 // with function arguments and returned values. As such any operation not
487 // reachable from the "fetches" can be pruned. The return statement also has
488 // control input so that side-effecting operations without results (print for
489 // example) aren't pruned.
490 bool changed = true;
491 while (changed) {
492 changed = false;
493 for (Operation &op :
494 llvm::make_early_inc_range(llvm::reverse(*func_op.getBody()))) {
495 if (isa<ReturnOp>(op)) continue;
496 if (op.getUses().empty()) {
497 rewriter.eraseOp(&op);
498 changed = true;
499 }
500 }
501 }
502 return failure();
503 }
504
verify()505 LogicalResult GraphFuncOp::verify() {
506 GraphFuncOp func_op = *this;
507 if (func_op.getNumArguments() % 2)
508 return func_op.emitOpError() << "expects an even number of arguments";
509 ArrayAttr args_attrs = func_op.getAllArgAttrs();
510 if (args_attrs && args_attrs.size() != func_op.getNumArguments())
511 return func_op.emitOpError()
512 << "expects argument attributes for each argument ("
513 << args_attrs.size() << " vs " << func_op.getNumArguments() << ")";
514 ArrayAttr res_attrs = func_op.getAllResultAttrs();
515 if (res_attrs && res_attrs.size() != func_op.getNumResults())
516 return func_op.emitOpError()
517 << "expects results attributes for each result (" << res_attrs.size()
518 << " vs " << func_op.getNumResults() << ")";
519 return success();
520 }
521
parse(OpAsmParser & parser,OperationState & result)522 ParseResult GraphFuncOp::parse(OpAsmParser &parser, OperationState &result) {
523 SmallVector<OpAsmParser::UnresolvedOperand> entry_args;
524 SmallVector<Attribute> arg_attrs;
525 SmallVector<Attribute> result_attrs;
526 SmallVector<Type> arg_types;
527 SmallVector<Type> result_types;
528 auto &builder = parser.getBuilder();
529 MLIRContext *context = builder.getContext();
530
531 // Parse visibility.
532 StringRef visibility;
533 if (!parser.parseOptionalKeyword(&visibility,
534 {"public", "private", "nested"})) {
535 StringAttr visibility_attr = parser.getBuilder().getStringAttr(visibility);
536 result.attributes.push_back(parser.getBuilder().getNamedAttr(
537 SymbolTable::getVisibilityAttrName(), visibility_attr));
538 }
539
540 if (succeeded(parser.parseOptionalKeyword("generic")))
541 result.addAttribute("generic", builder.getUnitAttr());
542
543 // Parse the name as a symbol.
544 StringAttr name_attr;
545 if (parser.parseSymbolName(name_attr, SymbolTable::getSymbolAttrName(),
546 result.attributes))
547 return failure();
548
549 // Parse the function signature.
550 // The difference with usual functions, is that for every single argument
551 // parsed, we create two block arguments: one for the expected value and one
552 // for the control dependency.
553 if (parser.parseLParen()) return failure();
554 Type control_ty = ControlType::get(builder.getContext());
555 std::list<std::string> control_operand_names;
556
557 // Helper to parse a single argument and its attributes.
558 auto parse_argument = [&]() -> ParseResult {
559 // Parse argument name if present.
560 entry_args.emplace_back();
561 arg_types.emplace_back();
562 if (parser.parseOperand(entry_args.back(), /*allowResultNumber=*/false) ||
563 parser.parseColonType(arg_types.back()))
564 return failure();
565
566 // Parse any argument attributes.
567 NamedAttrList attrs;
568 if (parser.parseOptionalAttrDict(attrs)) return failure();
569 arg_attrs.push_back(attrs.getDictionary(context));
570
571 // Define the control input: it's not printed but is added as block
572 // argument. Note the name computed here (suffixed ".ctl") is coupled to the
573 // implementation of:
574 // TFGraphOpAsmInterface::getAsmBlockArgumentNames()
575 // at the top of this file.
576 OpAsmParser::UnresolvedOperand control_operand = entry_args.back();
577 control_operand_names.push_back((control_operand.name + ".ctl").str());
578 control_operand.name = control_operand_names.back();
579 entry_args.push_back(control_operand);
580 arg_types.push_back(control_ty);
581 arg_attrs.push_back(DictionaryAttr::get(context));
582 return success();
583 };
584
585 // Parse the function arguments and their attributes.
586 if (failed(parser.parseOptionalRParen())) {
587 do {
588 if (parse_argument()) return failure();
589 } while (succeeded(parser.parseOptionalComma()));
590 if (parser.parseRParen()) return failure();
591 }
592
593 // Parse the result types and their attributes.
594 if (succeeded(parser.parseOptionalArrow())) {
595 if (failed(parser.parseLParen())) return failure();
596 if (failed(parser.parseOptionalRParen())) {
597 // Parse individual function results.
598 do {
599 result_types.emplace_back();
600 NamedAttrList result_attr;
601 if (parser.parseType(result_types.back()) ||
602 parser.parseOptionalAttrDict(result_attr)) {
603 return failure();
604 }
605 result_attrs.push_back(builder.getDictionaryAttr(result_attr));
606 } while (succeeded(parser.parseOptionalComma()));
607 if (parser.parseRParen()) return failure();
608 }
609 }
610
611 auto type = builder.getFunctionType(arg_types, result_types);
612 result.addAttribute(GraphFuncOp::getTypeAttrName(), TypeAttr::get(type));
613
614 // If function attributes are present, parse them.
615 NamedAttrList parsed_attributes;
616 if (parser.parseOptionalAttrDictWithKeyword(parsed_attributes))
617 return failure();
618 result.attributes.append(parsed_attributes);
619
620 // Add the attributes to the function arguments.
621 assert(arg_attrs.size() == arg_types.size());
622 assert(result_attrs.size() == result_types.size());
623 result.attributes.append(
624 builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
625 builder.getArrayAttr(arg_attrs)));
626 result.attributes.append(
627 builder.getNamedAttr(FunctionOpInterface::getResultDictAttrName(),
628 builder.getArrayAttr(result_attrs)));
629
630 // Parse the function body.
631 auto *body = result.addRegion();
632 llvm::SMLoc loc = parser.getCurrentLocation();
633 SmallVector<OpAsmParser::Argument> args;
634 if (entry_args.size()) {
635 for (auto argAndType : llvm::zip(entry_args, arg_types)) {
636 auto &arg = args.emplace_back();
637 arg.ssaName = std::get<0>(argAndType);
638 arg.type = std::get<1>(argAndType);
639 }
640 }
641
642 if (failed(parser.parseRegion(*body, args, /*enableNameShadowing=*/false)))
643 return failure();
644
645 // Function body was parsed, make sure it's not empty.
646 if (body->empty())
647 return parser.emitError(loc, "expected non-empty function body");
648
649 return success();
650 }
651
print(OpAsmPrinter & p)652 void GraphFuncOp::print(OpAsmPrinter &p) {
653 // Print the operation and the function name.
654 Operation *op = *this;
655 p << " ";
656 int argIndentSize = op->getName().getStringRef().size() + 3;
657 StringRef visibility_attr_name = SymbolTable::getVisibilityAttrName();
658 if (auto visibility = op->getAttrOfType<StringAttr>(visibility_attr_name)) {
659 p << visibility.getValue() << ' ';
660 argIndentSize += visibility.getValue().size() + 1;
661 }
662 if (generic()) p << "generic ";
663 auto funcName =
664 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
665 .getValue();
666 p.printSymbolName(funcName);
667 argIndentSize += funcName.size();
668 std::string indent(argIndentSize, ' ');
669 FunctionType fnType = getFunctionType();
670 ArrayRef<Type> arg_types = fnType.getInputs();
671 ArrayRef<Type> result_types = fnType.getResults();
672 assert((arg_types.size() % 2) == 0);
673 // Print operand list with attributes.
674 p << '(';
675 ArrayAttr args_attr = getAllArgAttrs();
676 for (unsigned i = 0, e = arg_types.size(); i < e; i += 2) {
677 // Args come by pair: input+control.
678 p.printOperand(getArgument(i));
679 p << ": ";
680 p.printType(arg_types[i]);
681 if (auto arg_attrs = args_attr[i].dyn_cast<DictionaryAttr>())
682 p.printOptionalAttrDict(arg_attrs.getValue());
683 if (i != e - 2) {
684 p << ", ";
685 p.printNewline();
686 p << indent;
687 }
688 }
689 p << ')';
690
691 // Print result types, if any.
692 if (!result_types.empty()) {
693 p.printNewline();
694 p.getStream() << " -> (";
695 indent = std::string(9, ' ');
696 ArrayAttr results_attr = getAllResultAttrs();
697 for (int i = 0, e = result_types.size(); i < e; ++i) {
698 p.printType(result_types[i]);
699 if (auto result_attrs = results_attr[i].dyn_cast<DictionaryAttr>())
700 p.printOptionalAttrDict(result_attrs.getValue());
701 if (i != e - 1) {
702 p << ", ";
703 p.printNewline();
704 p << indent;
705 }
706 }
707 p << ")";
708 }
709 // Print attributes.
710 if (!op->getAttrs().empty()) {
711 p.printNewline();
712 function_interface_impl::printFunctionAttributes(
713 p, *this, fnType.getNumInputs(), fnType.getNumResults(),
714 {"generic", SymbolTable::getVisibilityAttrName()});
715 }
716 // Print body.
717 p << ' ';
718 p.printRegion(body(), /*printEntryBlockArgs=*/false);
719 }
720
getCalledFunction(Operation * op,SymbolTable & symbol_table)721 GraphFuncOp GraphFuncOp::getCalledFunction(Operation *op,
722 SymbolTable &symbol_table) {
723 // Check if a node does indirect function call via PartitionedCallOp.
724 // TODO(aminim): consider replacing with isa<...> when possible.
725 if (op->getName().getStringRef() == "tfg.PartitionCall" ||
726 op->getName().getStringRef() == "tfg.StatefulPartitionedCall") {
727 auto func_attr = op->getAttrOfType<FuncAttr>("f");
728 if (!func_attr) return {};
729 GraphFuncOp callee = symbol_table.lookup<GraphFuncOp>(
730 func_attr.getName().getLeafReference());
731 if (callee) return callee;
732 }
733 return symbol_table.lookup<GraphFuncOp>(op->getName().stripDialect());
734 }
735
getDataValueOf(BlockArgument ctl)736 BlockArgument GraphFuncOp::getDataValueOf(BlockArgument ctl) {
737 return ctl.getOwner()->getArgument(ctl.getArgNumber() - 1);
738 }
739
getControlTokenOf(BlockArgument data)740 BlockArgument GraphFuncOp::getControlTokenOf(BlockArgument data) {
741 return data.getOwner()->getArgument(data.getArgNumber() + 1);
742 }
743
getDataValue(Region & region,unsigned idx)744 BlockArgument GraphFuncOp::getDataValue(Region ®ion, unsigned idx) {
745 return region.getArgument(idx * 2);
746 }
747
748 // This is naming block arguments for GraphFuncOp, we rely on the arg attributes
749 // for computing the names.
getAsmBlockArgumentNames(Region & region,OpAsmSetValueNameFn set_name_fn)750 void GraphFuncOp::getAsmBlockArgumentNames(Region ®ion,
751 OpAsmSetValueNameFn set_name_fn) {
752 ArrayRef<BlockArgument> args = getBody()->getArguments();
753 ControlType control_ty = ControlType::get(getContext());
754 // Sanity checking: this is verified by the op but this may be called before
755 // the verifier or in some diagnostic/debug context, let's not crash.
756 // We expect the function block operands to come as pair: tensor+control.
757 if (args.size() % 2) return;
758 for (unsigned i = 0, e = args.size(); i < e; i += 2)
759 if (args[i].getType() == control_ty || args[i + 1].getType() != control_ty)
760 return;
761
762 // Name the values based on the `tfg.name` arg attribute retrieved from the
763 // func_op.
764 ArrayAttr args_attr = getAllArgAttrs();
765 if (!args_attr || args_attr.size() != args.size()) return;
766 for (int arg_num = 0, e = args.size(); arg_num < e; arg_num += 2) {
767 DictionaryAttr arg_attrs = args_attr[arg_num].dyn_cast<DictionaryAttr>();
768 if (!arg_attrs) continue;
769 if (auto strAttr = arg_attrs.getAs<StringAttr>("tfg.name")) {
770 set_name_fn(args[arg_num], strAttr.getValue());
771 set_name_fn(args[arg_num + 1], (strAttr.getValue() + ".ctl").str());
772 }
773 }
774 }
775
776 //===----------------------------------------------------------------------===//
777 // ReturnOp
778 //===----------------------------------------------------------------------===//
779
verify()780 LogicalResult ReturnOp::verify() {
781 ReturnOp op = *this;
782 // If the control result attributes are present, there must be the same number
783 // of entries as control results.
784 if (op.control_ret_attrs().size() != TFOp(op).getControlOperands().size()) {
785 return op.emitOpError(
786 "expected as many control result attributes as there are control "
787 "operands");
788 }
789 return success();
790 }
791
// Parses a `tfg.return`. The inverse of `ReturnOp::print`.
//
// Accepted form (in parse order): an optional parenthesized list of data
// operands, an optional `[...]` list of control operands each optionally
// followed by an attribute dictionary, the remaining attributes, and an
// optional `: type-list`. The per-control-operand dictionaries are stored as
// the `control_ret_attrs` array attribute.
ParseResult ReturnOp::parse(OpAsmParser &parser, OperationState &result) {
  // ReturnOp has the same assembly format as generic TFG ops except that the
  // control result attributes are embedded with the control operands:
  //   [%ctl {tfg.name = "foo"}, %ctl_0 {tfg.name = "bar"}]
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parser.parseOperandList(operands, AsmParser::Delimiter::OptionalParen))
    return failure();

  // Control operands. Exactly one dictionary (possibly empty) is recorded per
  // control operand, so `control_ret_attrs` stays index-aligned with them.
  SmallVector<Attribute> control_ret_attrs;
  if (succeeded(parser.parseOptionalLSquare())) {
    OpAsmParser::UnresolvedOperand operand;
    do {
      NamedAttrList attrs;
      OptionalParseResult parse_result = parser.parseOptionalOperand(operand);
      // No operand present (e.g. `[]`): stop and expect the closing bracket.
      if (!parse_result.hasValue()) break;
      if (failed(parse_result.getValue())) return failure();
      if (parser.parseOptionalAttrDict(attrs)) return failure();
      control_ret_attrs.push_back(attrs.getDictionary(result.getContext()));
      // Control operands are appended after the data operands.
      operands.push_back(std::move(operand));
    } while (succeeded(parser.parseOptionalComma()));
    if (parser.parseRSquare()) return failure();
  }

  // Remaining attributes (parsed by a helper defined elsewhere in this file).
  if (ParseKeywordAttributes(parser, result)) return failure();
  result.addAttribute(ReturnOp::control_ret_attrsAttrName(result.name),
                      ArrayAttr::get(result.getContext(), control_ret_attrs));

  // The printer only emits types for data operands. Pad the type list with
  // the control type so that every trailing (control) operand resolves.
  SmallVector<Type> types;
  if (parser.parseOptionalColonTypeList(types)) return failure();
  types.resize(operands.size(), ControlType::get(result.getContext()));
  if (parser.resolveOperands(operands, types, parser.getCurrentLocation(),
                             result.operands))
    return failure();
  return success();
}
827
// Prints a `tfg.return` in the form accepted by `ReturnOp::parse`:
//   (data operands) [ctl {attrs}, ...] {other attrs} : data types
// Each control operand is printed with its dictionary from
// `control_ret_attrs`; empty dictionaries are omitted.
void ReturnOp::print(OpAsmPrinter &printer) {
  TFOp tfg_op(*this);
  OperandRange data = tfg_op.getNonControlOperands();
  // The parenthesized data list is only printed when non-empty.
  if (!data.empty()) printer << '(' << data << ')';

  OperandRange ctls = tfg_op.getControlOperands();
  if (!ctls.empty()) {
    printer << " [";
    // NOTE(review): `llvm::zip` stops at the shorter range. If
    // `control_ret_attrs` has fewer entries than there are control operands
    // (IR rejected by `ReturnOp::verify`), the trailing control operands are
    // not printed.
    llvm::interleave(
        llvm::zip(ctls, control_ret_attrs().getAsRange<DictionaryAttr>()),
        printer,
        [&](auto it) {
          printer << std::get<0>(it);
          if (!std::get<1>(it).empty()) printer << ' ' << std::get<1>(it);
        },
        ", ");
    printer << ']';
  }

  // `control_ret_attrs` was already printed inline above; the list argument
  // presumably names attributes for the helper to elide.
  PrintKeywordAttributes(*this, printer, {"control_ret_attrs"});

  // Only data operand types are printed; the parser assigns the control type
  // to the remaining operands.
  if (!data.empty()) printer << " : " << data.getTypes();
}
851
build(OpBuilder & odsBuilder,OperationState & odsState,ValueRange operands,ValueRange control_operands)852 void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState,
853 ValueRange operands, ValueRange control_operands) {
854 odsState.addOperands(operands);
855 odsState.addOperands(control_operands);
856 // Populate `control_ret_attrs` with empty dictionaries.
857 odsState.addAttribute(
858 ReturnOp::control_ret_attrsAttrName(odsState.name),
859 odsBuilder.getArrayAttr(SmallVector<Attribute>(
860 control_operands.size(), odsBuilder.getDictionaryAttr({}))));
861 }
862
863 //===----------------------------------------------------------------------===//
864 // Concrete Ops
865 //===----------------------------------------------------------------------===//
866
// TODO(jeffniu): The ODS definitions of TFG ops, as well as parts of their
// verifiers, can be autogenerated. These hand-written verifiers focus on
// checking the ops' operand and result types against the types of the
// functions they reference; that logic differs slightly between operations.
871
872 // Verify that all control operands follow non-control operands, and return the
873 // subrange of non-control operands.
VerifyOperands(Operation * op)874 static FailureOr<TypeRange> VerifyOperands(Operation *op) {
875 ControlType control_ty =
876 cast<TFGraphDialect>(op->getDialect())->getControlType();
877 Operation::operand_type_iterator it =
878 llvm::find(op->getOperandTypes(), control_ty);
879 if (!std::all_of(it, op->operand_type_end(),
880 [&](Type type) { return type == control_ty; })) {
881 return op->emitOpError(
882 "not all control tokens come after non-control operands");
883 }
884 return {Operation::operand_type_range(op->operand_type_begin(), it)};
885 }
886
887 // Verify that the last result of an operation is the only control result, and
888 // return a subrange of the non-control results.
VerifyResults(Operation * op)889 static FailureOr<TypeRange> VerifyResults(Operation *op) {
890 ControlType control_ty =
891 cast<TFGraphDialect>(op->getDialect())->getControlType();
892 Operation::result_type_iterator it =
893 llvm::find(op->getResultTypes(), control_ty);
894 if (it == op->result_type_end())
895 return op->emitOpError("does not define a control result");
896 if (it != std::prev(op->result_type_end())) {
897 return op->emitOpError(
898 "must have a control token result as and only as its last result");
899 }
900 return {Operation::result_type_range(op->result_type_begin(), it)};
901 }
902
903 // Verify that the signature of the function matches the operation's operands
904 // and results.
VerifySignature(GraphFuncOp func,Operation * op,TypeRange operands,TypeRange results,const Twine & func_name)905 static LogicalResult VerifySignature(GraphFuncOp func, Operation *op,
906 TypeRange operands, TypeRange results,
907 const Twine &func_name) {
908 auto attach_func = [&](InFlightDiagnostic diag) -> LogicalResult {
909 return diag.attachNote(func.getLoc()).appendOp(*func, OpPrintingFlags())
910 << "\nsee referenced function";
911 };
912
913 ArrayRef<Type> arguments = func.getFunctionType().getInputs();
914 ArrayRef<Type> returns = func.getFunctionType().getResults();
915 if (operands.size() * 2 != arguments.size()) {
916 return attach_func(op->emitOpError(func_name)
917 << " function has " << arguments.size() / 2
918 << " arguments but was provided " << operands.size());
919 }
920 if (results.size() != returns.size()) {
921 return attach_func(op->emitOpError(func_name)
922 << " function has " << returns.size()
923 << " results but expected " << results.size());
924 }
925
926 if (func.generic()) return success();
927
928 for (auto &it : llvm::enumerate(operands)) {
929 Type arg_type = arguments[it.index() * 2];
930 Type op_type = it.value();
931 if (!tf_type::HasCompatibleElementTypes(arg_type, op_type)) {
932 return attach_func(
933 op->emitOpError(func_name)
934 << " function argument #" << it.index() << " type " << arg_type
935 << " is not compatible with corresponding operand type: " << op_type);
936 }
937 }
938 for (auto &it : llvm::enumerate(results)) {
939 Type ret_type = returns[it.index()];
940 Type res_type = it.value();
941 if (!tf_type::HasCompatibleElementTypes(ret_type, res_type)) {
942 return attach_func(
943 op->emitOpError(func_name)
944 << " function result #" << it.index() << " type " << ret_type
945 << " is not compatible with corresponding result type: " << res_type);
946 }
947 }
948 return success();
949 }
950
951 // This function verifies that the types of `values`, which are either operands
952 // or results of `op`, match the types specified in `types`, which is expected
953 // to be an array of type attributes.
VerifyTypeArray(Operation * op,ValueRange values,ArrayAttr types,StringRef kind)954 static LogicalResult VerifyTypeArray(Operation *op, ValueRange values,
955 ArrayAttr types, StringRef kind) {
956 // Don't verify if the types are not present.
957 if (!types) return success();
958 if (values.size() != types.size()) {
959 return op->emitOpError("has ") << values.size() << " " << kind << "s but "
960 << types.size() << " " << kind << " types";
961 }
962 for (auto it :
963 llvm::zip(llvm::enumerate(values), types.getAsRange<TypeAttr>())) {
964 Type type = std::get<0>(it).value().getType();
965 Type dtype = std::get<1>(it).getValue();
966 if (!tf_type::HasCompatibleElementTypes(type,
967 UnrankedTensorType::get(dtype))) {
968 return op->emitOpError(kind)
969 << " #" << std::get<0>(it).index()
970 << " is incompatible with dtype " << dtype << ", got: " << type;
971 }
972 }
973 return success();
974 }
975
namespace detail {
// Check if the op type has `T`.
// `has_T<OpT>` is well-formed only when `OpT` has a `T()` accessor, so
// `detect_has_T<OpT>::value` is true exactly for such op types.
template <typename OpT>
using has_T = decltype(std::declval<OpT>().T());
template <typename OpT>
using detect_has_T = llvm::is_detected<has_T, OpT>;

// Get the input and output type arrays. If the op has a single type array,
// use it for both input and output. Otherwise, return separate type arrays.
//
// Primary template, selected when `OpT` has no `T`: input types come from
// `Tin` and output types from `Tout`.
template <typename OpT, bool = detect_has_T<OpT>::value>
struct GetTypeArray {
  static ArrayAttr getInputTypes(OpT op) { return op.TinAttr(); }
  static ArrayAttr getOutputTypes(OpT op) { return op.ToutAttr(); }
};
// Specialization for ops with a single `T` array shared by inputs and outputs.
template <typename OpT>
struct GetTypeArray<OpT, true> {
  static ArrayAttr getInputTypes(OpT op) { return op.TAttr(); }
  static ArrayAttr getOutputTypes(OpT op) { return op.TAttr(); }
};
}  // namespace detail
996
997 // Verify a functional op's inputs and outputs against its data type arrays. For
998 // loop ops, this also checks that the number of inputs and outputs match. This
999 // is guaranteed to be valid on import but may be violated by a transformation.
1000 template <typename OpT>
VerifyTypeArrayAttributes(OpT op)1001 static LogicalResult VerifyTypeArrayAttributes(OpT op) {
1002 using GetTypeArray = typename detail::GetTypeArray<OpT>;
1003 ValueRange args =
1004 SplitDataAndControlValues(op.args(), ControlType::get(op.getContext()))
1005 .first;
1006 return success(
1007 succeeded(VerifyTypeArray(op, args, GetTypeArray::getInputTypes(op),
1008 "argument")) &&
1009 succeeded(VerifyTypeArray(op, op.outs(), GetTypeArray::getOutputTypes(op),
1010 "result")));
1011 }
1012
1013 //===----------------------------------------------------------------------===//
1014 // If-Like Ops
1015
1016 template <typename IfLikeOp>
VerifyIfLikeOp(IfLikeOp op,SymbolTableCollection & symbol_table)1017 static LogicalResult VerifyIfLikeOp(IfLikeOp op,
1018 SymbolTableCollection &symbol_table) {
1019 if (failed(op.verifyInvariants())) return failure();
1020 FailureOr<TypeRange> ins = VerifyOperands(op);
1021 if (failed(ins)) return failure();
1022 FailureOr<TypeRange> outs = VerifyResults(op);
1023 if (failed(outs)) return failure();
1024
1025 // The first operand is the condition and is not passed to the functions.
1026 TypeRange func_args = ins->drop_front();
1027
1028 auto then_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1029 op, op.then_branch().getName());
1030 if (then_func &&
1031 failed(VerifySignature(then_func, op, func_args, *outs, "then")))
1032 return failure();
1033
1034 auto else_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1035 op, op.else_branch().getName());
1036 if (else_func &&
1037 failed(VerifySignature(else_func, op, func_args, *outs, "else")))
1038 return failure();
1039
1040 return VerifyTypeArrayAttributes(op);
1041 }
1042
1043 //===----------------------------------------------------------------------===//
1044 // Case-Like Ops
1045
1046 template <typename CaseLikeOp>
VerifyCaseLikeOp(CaseLikeOp op,SymbolTableCollection & symbol_table)1047 static LogicalResult VerifyCaseLikeOp(CaseLikeOp op,
1048 SymbolTableCollection &symbol_table) {
1049 if (failed(op.verifyInvariants())) return failure();
1050 FailureOr<TypeRange> ins = VerifyOperands(op);
1051 if (failed(ins)) return failure();
1052 FailureOr<TypeRange> outs = VerifyResults(op);
1053 if (failed(outs)) return failure();
1054
1055 // The first operand is the branch index and is not passed to the functions.
1056 TypeRange func_args = ins->drop_front();
1057
1058 for (auto &it : llvm::enumerate(op.branches())) {
1059 SymbolRefAttr func_name = it.value().template cast<FuncAttr>().getName();
1060 auto func =
1061 symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(op, func_name);
1062 if (func && failed(VerifySignature(func, op, func_args, *outs,
1063 "branch #" + Twine(it.index()))))
1064 return failure();
1065 }
1066
1067 return VerifyTypeArrayAttributes(op);
1068 }
1069
1070 //===----------------------------------------------------------------------===//
1071 // While-Like Ops
1072
1073 template <typename WhileLikeOp>
VerifyWhileLikeOp(WhileLikeOp op,SymbolTableCollection & symbol_table)1074 static LogicalResult VerifyWhileLikeOp(WhileLikeOp op,
1075 SymbolTableCollection &symbol_table) {
1076 if (failed(op.verifyInvariants())) return failure();
1077 FailureOr<TypeRange> ins = VerifyOperands(op);
1078 if (failed(ins)) return failure();
1079 FailureOr<TypeRange> outs = VerifyResults(op);
1080 if (failed(outs)) return failure();
1081
1082 SymbolRefAttr body_name = op.body().getName();
1083
1084 auto cond_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1085 op, op.cond().getName());
1086 auto i1_type = UnrankedTensorType::get(Builder(op.getContext()).getI1Type());
1087 if (cond_func &&
1088 failed(VerifySignature(cond_func, op, *ins, i1_type, "cond")))
1089 return failure();
1090
1091 auto body_func = symbol_table.lookupNearestSymbolFrom<GraphFuncOp>(
1092 op, op.body().getName());
1093 if (body_func && failed(VerifySignature(body_func, op, *ins, *outs, "body")))
1094 return failure();
1095
1096 return VerifyTypeArrayAttributes(op);
1097 }
1098
1099 //===----------------------------------------------------------------------===//
1100 // ForOp
1101
verifySymbolUses(SymbolTableCollection & symbolTable)1102 LogicalResult ForOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1103 if (failed(verifyInvariants())) return failure();
1104 FailureOr<TypeRange> ins = VerifyOperands(*this);
1105 if (failed(ins)) return failure();
1106 FailureOr<TypeRange> outs = VerifyResults(*this);
1107 if (failed(outs)) return failure();
1108
1109 auto body_func =
1110 symbolTable.lookupNearestSymbolFrom<GraphFuncOp>(*this, body().getName());
1111 // The first three arguments are the for-loop indices, but the current loop
1112 // index is passed in.
1113 TypeRange func_args = llvm::drop_begin(*ins, /*N=*/2);
1114 if (body_func &&
1115 failed(VerifySignature(body_func, *this, func_args, *outs, "body")))
1116 return failure();
1117
1118 return VerifyTypeArrayAttributes(*this);
1119 }
1120
1121 //===----------------------------------------------------------------------===//
1122 // Region Ops and Terminators
1123 //===----------------------------------------------------------------------===//
1124
1125 // If a region op has preserved attributes, verify that they match the number of
1126 // results and block arguments.
VerifyPreservedAttrs(Operation * op,ArrayRef<Attribute> preserved_attrs)1127 static LogicalResult VerifyPreservedAttrs(Operation *op,
1128 ArrayRef<Attribute> preserved_attrs) {
1129 assert(op->getNumRegions() == preserved_attrs.size());
1130 for (auto it : llvm::zip(preserved_attrs, op->getRegions())) {
1131 // Preserved attributes for a particular region may not exist.
1132 auto attrs = std::get<0>(it).dyn_cast_or_null<RegionAttr>();
1133 if (!attrs) continue;
1134 Region ®ion = std::get<1>(it);
1135
1136 const auto emit_region_error = [&](StringRef msg) {
1137 return op->emitOpError("region #")
1138 << region.getRegionNumber() << " " << msg;
1139 };
1140
1141 unsigned num_args = GetLoopRegionDataArgs(region).size();
1142 if (num_args != attrs.getArgAttrs().size()) {
1143 return emit_region_error("has ")
1144 << num_args << " argument(s) but preserved attributes has "
1145 << attrs.getArgAttrs().size();
1146 }
1147
1148 // All regions are terminated by either a YieldOp or a ConditionOp. In the
1149 // latter case, the function will only have one result.
1150 unsigned num_rets;
1151 Operation *terminator = region.front().getTerminator();
1152 if (isa<ConditionOp>(terminator)) {
1153 num_rets = 1;
1154 } else {
1155 num_rets = cast<RegionBranchTerminatorOpInterface>(terminator)
1156 .getMutableSuccessorOperands(region.getRegionNumber())
1157 .size();
1158 }
1159 if (num_rets != attrs.getResAttrs().size()) {
1160 return emit_region_error("has ")
1161 << num_rets << " result(s) but preserved attributes has "
1162 << attrs.getResAttrs().size();
1163 }
1164 }
1165 return success();
1166 }
1167
1168 //===----------------------------------------------------------------------===//
1169 // YieldOp
1170
getMutableSuccessorOperands(Optional<unsigned> index)1171 MutableOperandRange YieldOp::getMutableSuccessorOperands(
1172 Optional<unsigned> index) {
1173 // Get the subrange of non-control operands.
1174 return argsMutable();
1175 }
1176
TerminatedByYield(Block & block)1177 static bool TerminatedByYield(Block &block) {
1178 return isa<YieldOp>(block.getTerminator());
1179 }
1180
1181 //===----------------------------------------------------------------------===//
1182 // IfLikeRegionOp
1183
1184 // Verify an if-like region op.
1185 template <typename IfLikeRegionOp>
VerifyIfLikeRegionOp(IfLikeRegionOp op)1186 static LogicalResult VerifyIfLikeRegionOp(IfLikeRegionOp op) {
1187 // Verify terminators.
1188 if (!TerminatedByYield(op.then_block()))
1189 return op.emitOpError("then region must be terminated by a 'tfg.yield'");
1190 if (!TerminatedByYield(op.else_block()))
1191 return op.emitOpError("else region must be terminated by a 'tfg.yield'");
1192 return VerifyPreservedAttrs(
1193 op, {op.then_region_attrsAttr(), op.else_region_attrsAttr()});
1194 }
1195
1196 // Given an potentially null attribute that would represent a constant value,
1197 // try to narrow it to a statically known condition.
1198 // TODO(jeffniu): Incorporate the other cases of `tf.ToBool`.
GetStaticallyKnownBranch(Attribute cond_attr)1199 static Optional<bool> GetStaticallyKnownBranch(Attribute cond_attr) {
1200 // Only handle the case of a scalar tensor of i1.
1201 auto cond = cond_attr.dyn_cast_or_null<ElementsAttr>();
1202 if (cond && cond.getNumElements() == 1 &&
1203 cond.getElementType().isSignlessInteger(1))
1204 return cond.getSplatValue<bool>();
1205 return {};
1206 }
1207
1208 // Get the successor of the regions of an if-like op.
1209 template <typename IfLikeRegionOp>
GetIfLikeRegionOpSuccessorRegions(IfLikeRegionOp op,Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1210 void GetIfLikeRegionOpSuccessorRegions(
1211 IfLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands,
1212 SmallVectorImpl<RegionSuccessor> ®ions) {
1213 assert(index.has_value() ||
1214 !operands.empty() && "if-like op expected at least 1 operand");
1215 // Both regions branch back to the parent op.
1216 if (index.has_value()) {
1217 // Ignore the control token.
1218 regions.emplace_back(
1219 ResultRange(op->result_begin(), std::prev(op->result_end())));
1220 } else if (auto cond = GetStaticallyKnownBranch(operands[0])) {
1221 // Add only 1 possible successor if the condition is known.
1222 Region ®ion = *cond ? op.then_region() : op.else_region();
1223 regions.emplace_back(®ion, GetLoopRegionDataArgs(region));
1224 } else {
1225 // Unknown successor.
1226 regions.emplace_back(&op.then_region(),
1227 GetLoopRegionDataArgs(op.then_region()));
1228 regions.emplace_back(&op.else_region(),
1229 GetLoopRegionDataArgs(op.else_region()));
1230 }
1231 }
1232
1233 //===----------------------------------------------------------------------===//
1234 // CaseLikeRegionOp
1235
1236 // Verify a case-like region op.
1237 template <typename CaseLikeRegionOp>
VerifyCaseLikeRegionOp(CaseLikeRegionOp op)1238 static LogicalResult VerifyCaseLikeRegionOp(CaseLikeRegionOp op) {
1239 for (auto &it : llvm::enumerate(op.branches())) {
1240 if (!TerminatedByYield(it.value().front())) {
1241 return op.emitOpError("branch region #")
1242 << it.index() << " is not terminated by a 'tfg.yield' op";
1243 }
1244 }
1245
1246 if (op.branch_attrs() && op.branches().size() != op.branch_attrs()->size()) {
1247 return op.emitOpError("has ")
1248 << op.branches().size() << " regions but "
1249 << op.branch_attrs()->size() << " branch function attributes";
1250 }
1251 if (auto region_attrs = op.region_attrsAttr()) {
1252 if (region_attrs.size() != op.getNumRegions()) {
1253 return op.emitOpError("expected ")
1254 << op.getNumRegions() << " region attribute(s) but got "
1255 << region_attrs.size();
1256 }
1257 if (failed(VerifyPreservedAttrs(op, region_attrs.getValue())))
1258 return failure();
1259 }
1260 return success();
1261 }
1262
1263 // Given a potentially null attribute that would represent a constant value,
1264 // try to narrow it to a statically known branch index.
GetStaticallyKnownCaseBranch(Attribute branch_attr)1265 static Optional<unsigned> GetStaticallyKnownCaseBranch(Attribute branch_attr) {
1266 auto branch = branch_attr.dyn_cast_or_null<ElementsAttr>();
1267 if (branch && branch.getNumElements() == 1 &&
1268 branch.getElementType().isSignlessInteger(32))
1269 return branch.getSplatValue<unsigned>();
1270 return {};
1271 }
1272
1273 // Get the successor of the regions of a case-like op.
1274 template <typename CaseLikeRegionOp>
GetCaseLikeRegionOpSuccessorRegions(CaseLikeRegionOp op,Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1275 void GetCaseLikeRegionOpSuccessorRegions(
1276 CaseLikeRegionOp op, Optional<unsigned> index, ArrayRef<Attribute> operands,
1277 SmallVectorImpl<RegionSuccessor> ®ions) {
1278 assert(index.has_value() ||
1279 !operands.empty() && "case-like op expected at least 1 operand");
1280 // All branch regions branch back to the parent op.
1281 if (index.has_value()) {
1282 // Ignore the control token.
1283 regions.emplace_back(
1284 ResultRange(op->result_begin(), std::prev(op->result_end())));
1285 } else if (auto branch_index = GetStaticallyKnownCaseBranch(operands[0])) {
1286 // Add only 1 possible successor if the condition is known.
1287 Region ®ion = op.branches()[*branch_index];
1288 regions.emplace_back(®ion, GetLoopRegionDataArgs(region));
1289 } else {
1290 // Unknown successor. Add all of them.
1291 for (Region &branch : op.branches())
1292 regions.emplace_back(&branch, GetLoopRegionDataArgs(branch));
1293 }
1294 }
1295
1296 //===----------------------------------------------------------------------===//
1297 // ConditionOp
1298
getMutableSuccessorOperands(Optional<unsigned> index)1299 MutableOperandRange ConditionOp::getMutableSuccessorOperands(
1300 Optional<unsigned> index) {
1301 // Get the subrange of non-control operands that are forwarded to the
1302 // successor region.
1303 return argsMutable();
1304 }
1305
1306 //===----------------------------------------------------------------------===//
1307 // WhileLikeRegionOp
1308
1309 // Verify that the loop regions of a region-based loop op have N control tokens
1310 // immediately following N data values in their entry block arguments.
1311 // `RegionBranchOpInterface` will verify the number of arguments and their
1312 // types.
VerifyLoopRegionArgs(Operation * op,Region & region)1313 static LogicalResult VerifyLoopRegionArgs(Operation *op, Region ®ion) {
1314 const auto arg_error = [&](BlockArgument arg) {
1315 return op->emitOpError("region #")
1316 << region.getRegionNumber() << " argument #" << arg.getArgNumber()
1317 << " ";
1318 };
1319
1320 // The interface trait verifies the number of data and control arguments. If
1321 // the first half of the arguments are not control tokens, then we know for
1322 // sure that the second half is only control tokens.
1323 for (BlockArgument data : GetLoopRegionDataArgs(region))
1324 if (data.getType().isa<ControlType>())
1325 return arg_error(data) << "should not be a control token";
1326 return success();
1327 }
1328
1329 // Verify a while-like region op.
1330 template <typename WhileLikeRegionOp>
VerifyWhileLikeRegionOp(WhileLikeRegionOp op)1331 static LogicalResult VerifyWhileLikeRegionOp(WhileLikeRegionOp op) {
1332 // Verify terminators.
1333 if (!isa<ConditionOp>(op.cond_block().getTerminator())) {
1334 return op.emitOpError(
1335 "condition region must be terminated by a 'tfg.condition' op");
1336 }
1337 if (!TerminatedByYield(op.body_block()))
1338 op.emitOpError("body region must be terminated by a 'tfg.yield' op");
1339
1340 if (failed(VerifyLoopRegionArgs(op, op.cond_region())) ||
1341 failed(VerifyLoopRegionArgs(op, op.body_region())))
1342 return failure();
1343 if (failed(VerifyPreservedAttrs(
1344 op, {op.cond_region_attrsAttr(), op.body_region_attrsAttr()})))
1345 return failure();
1346
1347 return success();
1348 }
1349
1350 template <typename WhileLikeRegionOp>
GetWhileLikeRegionOpSuccessorRegions(WhileLikeRegionOp op,Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1351 static void GetWhileLikeRegionOpSuccessorRegions(
1352 WhileLikeRegionOp op, Optional<unsigned> index,
1353 ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> ®ions) {
1354 // The parent op and the body region always branch to the condion region.
1355 if (!index || *index == 1) {
1356 regions.emplace_back(&op.cond_region(),
1357 GetLoopRegionDataArgs(op.cond_region()));
1358 return;
1359 }
1360 assert(*index == 0 && "invalid region index");
1361 // The condition regions branches to the loop body or back to the parent.
1362 // Try to narrow the condition value to a constant.
1363 auto condition = cast<ConditionOp>(op.cond_region().front().getTerminator());
1364 Attribute cond_attr;
1365 matchPattern(condition.cond(), m_Constant(&cond_attr));
1366 Optional<bool> cond = GetStaticallyKnownBranch(cond_attr);
1367 if (!cond || *cond) {
1368 regions.emplace_back(&op.body_region(),
1369 GetLoopRegionDataArgs(op.body_region()));
1370 }
1371 if (!cond || !*cond) {
1372 // Drop the control token.
1373 regions.emplace_back(op.getResults().drop_back());
1374 }
1375 }
1376
1377 //===----------------------------------------------------------------------===//
1378 // ForRegionOp
1379
verify()1380 LogicalResult ForRegionOp::verify() {
1381 if (!TerminatedByYield(body_block())) {
1382 return emitOpError("body region must be terminated by a 'tfg.yield' op");
1383 }
1384
1385 Block::BlockArgListType args = body_block().getArguments();
1386 if (args.empty()) {
1387 return emitOpError(
1388 "expected the body block to have at least have the loop index as an "
1389 "argument");
1390 }
1391 auto index = args.front().getType().dyn_cast<TensorType>();
1392 if (!index || !index.getElementType().isSignlessInteger(32)) {
1393 return emitOpError(
1394 "expected first body block argument to be an i32 tensor");
1395 }
1396
1397 if (failed(VerifyLoopRegionArgs(*this, body_region()))) return failure();
1398 return VerifyPreservedAttrs(*this, {region_attrsAttr()});
1399 }
1400
getSuccessorEntryOperands(Optional<unsigned> index)1401 OperandRange ForRegionOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1402 return init();
1403 }
1404
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1405 void ForRegionOp::getSuccessorRegions(
1406 Optional<unsigned> index, ArrayRef<Attribute> operands,
1407 SmallVectorImpl<RegionSuccessor> ®ions) {
1408 // Both the parent op and the body region branch to the body. Ignore the loop
1409 // index block argument, as it is not modified by the loop body itself.
1410 regions.emplace_back(&body_region(),
1411 GetLoopRegionDataArgs(body_region()).drop_front());
1412 if (!index) return;
1413 // The body might branch back to the parent. Drop the control token.
1414 regions.emplace_back((*this)->getResults().drop_back());
1415 }
1416
getDataValueOf(BlockArgument ctl)1417 BlockArgument ForRegionOp::getDataValueOf(BlockArgument ctl) {
1418 return GetLoopRegionDataOf(ctl);
1419 }
getControlTokenOf(BlockArgument data)1420 BlockArgument ForRegionOp::getControlTokenOf(BlockArgument data) {
1421 return GetLoopRegionControlOf(data);
1422 }
getDataValue(Region & region,unsigned idx)1423 BlockArgument ForRegionOp::getDataValue(Region ®ion, unsigned idx) {
1424 return GetLoopRegionDataArgs(region)[idx];
1425 }
getControlToken(Region & region,unsigned idx)1426 BlockArgument ForRegionOp::getControlToken(Region ®ion, unsigned idx) {
1427 return GetLoopRegionControlTokens(region)[idx];
1428 }
1429
1430 //===----------------------------------------------------------------------===//
1431 // Function Table
1432 //===----------------------------------------------------------------------===//
1433
FunctionTable(ModuleOp module)1434 FunctionTable::FunctionTable(ModuleOp module) {
1435 // Collect function names (to be used for disambiguating legacy call
1436 // behavior).
1437 for (auto &op : module.getOps()) {
1438 if (auto func = dyn_cast<GraphFuncOp>(op)) functions.insert(func.getName());
1439 }
1440 }
1441
MayBeCall(Operation * op) const1442 bool FunctionTable::MayBeCall(Operation *op) const {
1443 if (IsLegacyCall(op)) return true;
1444 // The operation might be a call if it references a symbol.
1445 bool references_symbol = false;
1446 op->getAttrDictionary().walkSubAttrs(
1447 [&](Attribute attr) { references_symbol |= attr.isa<SymbolRefAttr>(); });
1448 return references_symbol;
1449 }
1450
IsLegacyCall(Operation * op) const1451 bool FunctionTable::IsLegacyCall(Operation *op) const {
1452 // If the operation name refers to a function in the module, then it is
1453 // guaranteed to be a legacy call. Otherwise, it is not.
1454 return functions.count(op->getName().stripDialect());
1455 }
1456
1457 } // namespace tfg
1458 } // namespace mlir
1459
1460 //===----------------------------------------------------------------------===//
1461 // ODS Definitions
1462 //===----------------------------------------------------------------------===//
1463
1464 #define GET_OP_CLASSES
1465 #include "tensorflow/core/ir/ops.cc.inc"
1466 #define GET_ATTRDEF_CLASSES
1467 #include "tensorflow/core/ir/attributes.cc.inc"
1468