1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/transforms/functional_to_region/impl.h"
17
18 #include <algorithm>
19 #include <tuple>
20
21 #include "llvm/ADT/DenseSet.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
27 #include "mlir/IR/Diagnostics.h" // from @llvm-project
28 #include "mlir/IR/PatternMatch.h" // from @llvm-project
29 #include "mlir/IR/SymbolTable.h" // from @llvm-project
30 #include "mlir/IR/Value.h" // from @llvm-project
31 #include "mlir/Support/LogicalResult.h" // from @llvm-project
32 #include "tensorflow/core/ir/dialect.h"
33 #include "tensorflow/core/ir/ops.h"
34 #include "tensorflow/core/ir/types/dialect.h"
35 #include "tensorflow/core/ir/utility.h"
36 #include "tensorflow/core/transforms/utils/utils.h"
37
38 namespace mlir {
39 namespace tfg {
40
41 //===----------------------------------------------------------------------===//
42 // Pattern Definitions
43 //===----------------------------------------------------------------------===//
44
45 namespace {
46 // Base class for patterns that convert functional ops to region-based ops. This
47 // class contains common utility functions and class members.
48 class BasePattern {
49 public:
BasePattern(SymbolTable & table,TFGraphDialect & dialect)50 BasePattern(SymbolTable &table, TFGraphDialect &dialect)
51 : table_(table), dialect_(dialect) {}
52
53 protected:
54 // Lookup, using the symbol table, a graph function.
LookupFunc(FuncAttr func_ref) const55 GraphFuncOp LookupFunc(FuncAttr func_ref) const {
56 return table_.lookup<GraphFuncOp>(func_ref.getName().getLeafReference());
57 }
58
59 // Split a range of non-control and control operands.
SplitControl(ValueRange values) const60 std::pair<ValueRange, ValueRange> SplitControl(ValueRange values) const {
61 return SplitDataAndControlValues(values, dialect_.getControlType());
62 }
63
64 // Convert the terminator of a region from `return` to `yield`.
65 YieldOp ReplaceReturnWithYield(Block &block, TypeRange types,
66 PatternRewriter &rewriter) const;
67
68 // Copy a region from a function body to a loop body, reordering the arguments
69 // from function order (pairs of data and control values) to loop order (all
70 // data values followed by all control values).
71 void CloneAndReorderArgs(TypeRange types, Region &from, Region &to,
72 PatternRewriter &rewriter) const;
73
74 // Clone ops from one region to another with a given value mapping. Rename
75 // clone ops with unique names.
76 void CloneAndRename(Region &from, Region &to, BlockAndValueMapping &bv) const;
77
78 protected:
79 // Symbol table for looking up branch/loop functions.
80 SymbolTable &table_;
81 // Dialect reference for getting cached values.
82 TFGraphDialect &dialect_;
83 };
84
85 // Base class for converting a functional control-flow `SourceOp` to a
86 // region-based `DestOp`.
87 template <typename SourceOp, typename DestOp>
88 class ConvertFunctionalToRegionPattern : public OpRewritePattern<SourceOp>,
89 public BasePattern {
90 public:
ConvertFunctionalToRegionPattern(MLIRContext * context,SymbolTable & table,TFGraphDialect & dialect)91 explicit ConvertFunctionalToRegionPattern(MLIRContext *context,
92 SymbolTable &table,
93 TFGraphDialect &dialect)
94 : OpRewritePattern<SourceOp>(context, /*benefit=*/1,
95 {DestOp::getOperationName()}),
96 BasePattern(table, dialect) {}
97 };
98
99 // Base class for patterns to convert an if-like TFG op to region form.
100 template <typename IfLikeOp, typename IfLikeRegionOp>
101 struct ConvertIfLikeOp
102 : public ConvertFunctionalToRegionPattern<IfLikeOp, IfLikeRegionOp> {
103 using ConvertFunctionalToRegionPattern<
104 IfLikeOp, IfLikeRegionOp>::ConvertFunctionalToRegionPattern;
105
106 LogicalResult matchAndRewrite(IfLikeOp op,
107 PatternRewriter &rewriter) const override;
108 };
109
110 using ConvertIfOp = ConvertIfLikeOp<IfOp, IfRegionOp>;
111 using ConvertStatelessIfOp =
112 ConvertIfLikeOp<StatelessIfOp, StatelessIfRegionOp>;
113 using ConvertStatefulIfOp = ConvertIfLikeOp<StatefulIfOp, StatefulIfRegionOp>;
114
115 // Base class for patterns to convert a case-like TFG op to region form.
116 template <typename CaseLikeOp, typename CaseLikeRegionOp>
117 struct ConvertCaseLikeOp
118 : public ConvertFunctionalToRegionPattern<CaseLikeOp, CaseLikeRegionOp> {
119 using ConvertFunctionalToRegionPattern<
120 CaseLikeOp, CaseLikeRegionOp>::ConvertFunctionalToRegionPattern;
121
122 LogicalResult matchAndRewrite(CaseLikeOp op,
123 PatternRewriter &rewriter) const override;
124 };
125
126 using ConvertCaseOp = ConvertCaseLikeOp<CaseOp, CaseRegionOp>;
127 using ConvertStatelessCaseOp =
128 ConvertCaseLikeOp<StatelessCaseOp, StatelessCaseRegionOp>;
129 using ConvertStatefulCaseOp =
130 ConvertCaseLikeOp<StatefulCaseOp, StatefulCaseRegionOp>;
131
132 // Base class for patterns to convert a while-like TFG op to region form.
133 template <typename WhileLikeOp, typename WhileLikeRegionOp>
134 struct ConvertWhileLikeOp
135 : public ConvertFunctionalToRegionPattern<WhileLikeOp, WhileLikeRegionOp> {
136 using ConvertFunctionalToRegionPattern<
137 WhileLikeOp, WhileLikeRegionOp>::ConvertFunctionalToRegionPattern;
138
139 LogicalResult matchAndRewrite(WhileLikeOp op,
140 PatternRewriter &rewriter) const override;
141 };
142
143 using ConvertWhileOp = ConvertWhileLikeOp<WhileOp, WhileRegionOp>;
144 using ConvertStatelessWhileOp =
145 ConvertWhileLikeOp<StatelessWhileOp, StatelessWhileRegionOp>;
146 using ConvertStatefulWhileOp =
147 ConvertWhileLikeOp<StatefulWhileOp, StatefulWhileRegionOp>;
148
149 // Convert a functional for-loop to a region-based for-loop.
150 struct ConvertForOp
151 : public ConvertFunctionalToRegionPattern<ForOp, ForRegionOp> {
152 using ConvertFunctionalToRegionPattern<
153 ForOp, ForRegionOp>::ConvertFunctionalToRegionPattern;
154
155 LogicalResult matchAndRewrite(tfg::ForOp op,
156 PatternRewriter &rewriter) const override;
157 };
158
159 } // namespace
160
161 //===----------------------------------------------------------------------===//
162 // Utility Functions
163 //===----------------------------------------------------------------------===//
164
165 // We cannot inline or modify a function if it does not exist, if it is generic,
166 // if it has a computed gradient, or if it is marked for compilation (e.g. by
167 // XLA).
CannotInline(GraphFuncOp func)168 static bool CannotInline(GraphFuncOp func) {
169 return !func || func.generic() || func.gradient() ||
170 func.isMarkedForCompilation();
171 }
172
173 // Determine which optional attributes of a non-generic function to preserve.
174 // Preserved attributes:
175 // - `description`
176 // - `is_stateful`
177 // - `resource_arg_unique_ids_keys`
178 // - `resource_arg_unique_ids_values`
179 //
180 // The attributes of a non-generic function to preserve:
181 // - Intrinsic `tfg.*` attributes are preserved.
182 // - Non-intrinsic `tf.*` attributes are preserved.
183 //
184 // The result attributes of a non-generic function to preserve:
185 // - Intrinsic `tfg.*` attributes are preserved.
PreserveFunctionAttributes(GraphFuncOp func)186 static DictionaryAttr PreserveFunctionAttributes(GraphFuncOp func) {
187 NamedAttrList preserved_attrs;
188 const auto preserve = [&](StringAttr name) {
189 if (Attribute attr = func->getAttr(name))
190 preserved_attrs.append(name, attr);
191 };
192 preserve(func.descriptionAttrName());
193 preserve(func.is_statefulAttrName());
194 preserve(func.resource_arg_unique_ids_keysAttrName());
195 preserve(func.resource_arg_unique_ids_valuesAttrName());
196 // Propagate tf.* attributes.
197 // TODO(jeffniu): `tf` dialect is not loaded.
198 for (const NamedAttribute &attr : func->getAttrs())
199 if (attr.getName().getValue().startswith("tf."))
200 preserved_attrs.append(attr);
201
202 // Certain pipelines (Brella) will split a graph into subgraphs before merging
203 // them back together. If the subgraphs pass through conversion to and from
204 // region form, the previously unique branch/loop body function names become
205 // not unique, which prevents the graphs from being correctly merged back
206 // together. Also, if an op is referenced in two different subgraphs, if
207 // Grappler changes the function name, the reference will only be valid in the
208 // first subgraph, leading to a function-not-found error. Preserve the
209 // original function name.
210 preserve(func.sym_nameAttrName());
211
212 return preserved_attrs.getDictionary(func.getContext());
213 }
214
215 // Given the function, argument, and result attributes to be preserved,
216 // determine if they are empty and can be dropped.
ArePreservedAttrsEmpty(DictionaryAttr func_attrs,ArrayAttr arg_attrs,ArrayAttr res_attrs)217 static bool ArePreservedAttrsEmpty(DictionaryAttr func_attrs,
218 ArrayAttr arg_attrs, ArrayAttr res_attrs) {
219 const auto is_empty = [](DictionaryAttr dict) { return dict.empty(); };
220 return func_attrs.empty() &&
221 llvm::all_of(arg_attrs.getAsRange<DictionaryAttr>(), is_empty) &&
222 llvm::all_of(res_attrs.getAsRange<DictionaryAttr>(), is_empty);
223 }
224
225 // Determine if the region attributes are empty.
AreRegionAttrsEmpty(RegionAttr attrs)226 static bool AreRegionAttrsEmpty(RegionAttr attrs) {
227 return ArePreservedAttrsEmpty(attrs.getAttrs(), attrs.getArgAttrs(),
228 attrs.getResAttrs());
229 }
230
231 // Preserve certain attributes of a function so that they can be used later if
232 // the region op is converted back to functional form. When `If` and `Case` are
233 // converted, all arguments attributes are dropped because the arguments are
234 // converted to implicit captures. For `While` and `For`, no arguments are
235 // removed.
236 //
237 // If `drop_args` is set, then all argument attributes are dropped, regardless
238 // of the number of arguments in the function.
239 //
240 // If `allow_empty` is set, then this function will always return a non-null
241 // attribute, even if the region attributes are empty.
PreserveAttributes(GraphFuncOp func,bool drop_args=false,bool allow_empty=false)242 static RegionAttr PreserveAttributes(GraphFuncOp func, bool drop_args = false,
243 bool allow_empty = false) {
244 DictionaryAttr func_attrs = PreserveFunctionAttributes(func);
245 // Since all argument and result attributes are preserved, just propagate the
246 // array attributes. Remove the control argument attributes from the argument
247 // attributes.
248 const auto every_other = [](ArrayAttr attrs) {
249 SmallVector<Attribute> others;
250 for (unsigned i = 0; i < attrs.size(); i += 2) others.push_back(attrs[i]);
251 return ArrayAttr::get(attrs.getContext(), others);
252 };
253
254 ArrayAttr arg_attrs = drop_args || !func.arg_attrs()
255 ? ArrayAttr::get(func.getContext(), {})
256 : every_other(*func.arg_attrs());
257 ArrayAttr res_attrs = func.res_attrs()
258 ? *func.res_attrs()
259 : ArrayAttr::get(func.getContext(), {});
260
261 if (!allow_empty && ArePreservedAttrsEmpty(func_attrs, arg_attrs, res_attrs))
262 return nullptr;
263 return RegionAttr::get(func_attrs, arg_attrs, res_attrs);
264 }
265
ReplaceReturnWithYield(Block & block,TypeRange types,PatternRewriter & rewriter) const266 YieldOp BasePattern::ReplaceReturnWithYield(Block &block, TypeRange types,
267 PatternRewriter &rewriter) const {
268 auto op = cast<ReturnOp>(block.getTerminator());
269 rewriter.setInsertionPoint(op);
270 ValueRange args, ctls;
271 std::tie(args, ctls) = SplitControl(op.getOperands());
272 return rewriter.replaceOpWithNewOp<YieldOp>(op, args, ctls);
273 }
274
CloneAndReorderArgs(TypeRange types,Region & from,Region & to,PatternRewriter & rewriter) const275 void BasePattern::CloneAndReorderArgs(TypeRange types, Region &from, Region &to,
276 PatternRewriter &rewriter) const {
277 ControlType control_ty = dialect_.getControlType();
278 BlockAndValueMapping bv;
279 CloneAndRename(from, to, bv);
280 SmallVector<Location> arg_locs(types.size(), from.getLoc());
281 for (auto &it :
282 llvm::enumerate(llvm::to_vector(to.addArguments(types, arg_locs)))) {
283 BlockArgument arg = to.getArgument(it.index() * 2);
284 BlockArgument ctl = to.getArgument(arg.getArgNumber() + 1);
285 arg.replaceAllUsesWith(it.value());
286 ctl.replaceAllUsesWith(to.addArgument(control_ty, arg.getLoc()));
287 }
288 llvm::BitVector erase_indices(to.getNumArguments());
289 erase_indices.set(0, types.size() * 2);
290 to.front().eraseArguments(erase_indices);
291 }
292
CloneAndRename(Region & from,Region & to,BlockAndValueMapping & bv) const293 void BasePattern::CloneAndRename(Region &from, Region &to,
294 BlockAndValueMapping &bv) const {
295 from.cloneInto(&to, bv);
296 StringAttr name_id = dialect_.getNameAttrIdentifier();
297 auto op_name = to.getParentOp()->getAttrOfType<StringAttr>(name_id);
298 if (!op_name) return;
299 for (Operation &op : to.getOps()) {
300 if (auto name = op.getAttrOfType<StringAttr>(name_id)) {
301 auto new_name =
302 StringAttr::get(op.getContext(), name.getValue() + "_tfg_inlined_" +
303 op_name.getValue() + "_" +
304 Twine(to.getRegionNumber()));
305 op.setAttr(name_id, new_name);
306 }
307 }
308 }
309
310 //===----------------------------------------------------------------------===//
311 // ConvertIfLikeOp
312 //===----------------------------------------------------------------------===//
313
314 template <typename IfLikeOp, typename IfLikeRegionOp>
matchAndRewrite(IfLikeOp op,PatternRewriter & rewriter) const315 LogicalResult ConvertIfLikeOp<IfLikeOp, IfLikeRegionOp>::matchAndRewrite(
316 IfLikeOp op, PatternRewriter &rewriter) const {
317 GraphFuncOp then_func = this->LookupFunc(op.then_branch());
318 GraphFuncOp else_func = this->LookupFunc(op.else_branch());
319 if (CannotInline(then_func) || CannotInline(else_func)) return failure();
320
321 // Create the region-based op, passing in the required attributes.
322 ValueRange args, ctls;
323 std::tie(args, ctls) = this->SplitControl(op.args());
324 auto region_op = rewriter.create<IfLikeRegionOp>(
325 op.getLoc(), op.getResultTypes(), op.cond(), ctls,
326 op.then_branch().getAttrs(), op.else_branch().getAttrs(),
327 PreserveAttributes(then_func, /*drop_args=*/true),
328 PreserveAttributes(else_func, /*drop_args=*/true));
329 util::ForwardNonIntrinsicAttributes(op, region_op);
330
331 // Move the regions over and replace the block arguments.
332 ControlType control_ty = this->dialect_.getControlType();
333 BlockAndValueMapping then_bv, else_bv;
334 auto func_args =
335 llvm::zip(then_func.getArguments(), else_func.getArguments()).begin();
336 rewriter.setInsertionPoint(region_op);
337 Value then_arg, else_arg, then_ctl, else_ctl;
338 for (Value arg : args) {
339 std::tie(then_arg, else_arg) = *func_args;
340 ++func_args;
341 std::tie(then_ctl, else_ctl) = *func_args;
342 ++func_args;
343 Value ctl = LookupControlDependency(arg);
344 then_bv.map(then_arg, arg);
345 else_bv.map(else_arg, arg);
346 then_bv.map(then_ctl, ctl);
347 else_bv.map(else_ctl, ctl);
348 }
349 this->CloneAndRename(then_func.body(), region_op.then_region(), then_bv);
350 this->CloneAndRename(else_func.body(), region_op.else_region(), else_bv);
351
352 // Replace the terminators `return` with `yield`.
353 TypeRange ret_types = region_op.outs().getTypes();
354 this->ReplaceReturnWithYield(region_op.then_block(), ret_types, rewriter);
355 this->ReplaceReturnWithYield(region_op.else_block(), ret_types, rewriter);
356 rewriter.replaceOp(op, region_op.getResults());
357 return success();
358 }
359
360 //===----------------------------------------------------------------------===//
361 // ConvertCaseLikeOp
362 //===----------------------------------------------------------------------===//
363
364 template <typename CaseLikeOp, typename CaseLikeRegionOp>
matchAndRewrite(CaseLikeOp op,PatternRewriter & rewriter) const365 LogicalResult ConvertCaseLikeOp<CaseLikeOp, CaseLikeRegionOp>::matchAndRewrite(
366 CaseLikeOp op, PatternRewriter &rewriter) const {
367 // Lookup all the branch functions and save their attributes.
368 SmallVector<GraphFuncOp> branch_funcs;
369 SmallVector<Attribute> branch_attrs;
370 branch_funcs.reserve(op.branches().size());
371 for (auto attr : op.branches().template getAsRange<FuncAttr>()) {
372 GraphFuncOp branch_func = this->LookupFunc(attr);
373 if (CannotInline(branch_func)) return failure();
374 branch_funcs.push_back(branch_func);
375 branch_attrs.push_back(attr.getAttrs());
376 }
377
378 SmallVector<Attribute> preserved_attrs;
379 for (GraphFuncOp func : branch_funcs) {
380 preserved_attrs.push_back(
381 PreserveAttributes(func, /*drop_args=*/true, /*allow_empty=*/true));
382 }
383 ArrayAttr region_attrs = nullptr;
384 if (!llvm::all_of(preserved_attrs, [](Attribute attr) {
385 return AreRegionAttrsEmpty(attr.cast<RegionAttr>());
386 }))
387 region_attrs = rewriter.getArrayAttr(preserved_attrs);
388
389 // Create the region-based op, passing in the required attributes.
390 ValueRange args, ctls;
391 std::tie(args, ctls) = this->SplitControl(op.args());
392 auto region_op = rewriter.create<CaseLikeRegionOp>(
393 op.getLoc(), op.getResultTypes(), op.branch_index(), ctls,
394 rewriter.getArrayAttr(branch_attrs), region_attrs, op.branches().size());
395 util::ForwardNonIntrinsicAttributes(op, region_op);
396
397 // Move the regions over and replace the block arguments.
398 ControlType control_ty = this->dialect_.getControlType();
399 SmallVector<BlockAndValueMapping> bvs(branch_funcs.size(), {});
400 rewriter.setInsertionPoint(region_op);
401 for (auto &arg : llvm::enumerate(args)) {
402 for (auto it : llvm::zip(branch_funcs, bvs)) {
403 BlockArgument branch_arg =
404 GraphFuncOp::getDataValue(std::get<0>(it).body(), arg.index());
405 BlockAndValueMapping &bv = std::get<1>(it);
406 bv.map(branch_arg, arg.value());
407 bv.map(GraphFuncOp::getControlTokenOf(branch_arg),
408 LookupControlDependency(arg.value()));
409 }
410 }
411 for (auto it : llvm::zip(branch_funcs, region_op.branches(), bvs)) {
412 this->CloneAndRename(std::get<0>(it).body(), std::get<1>(it),
413 std::get<2>(it));
414 }
415
416 // Replace the terminators `return` with `yield`.
417 TypeRange ret_types = region_op.outs().getTypes();
418 for (Region &branch : region_op.branches())
419 this->ReplaceReturnWithYield(branch.front(), ret_types, rewriter);
420 rewriter.replaceOp(op, region_op.getResults());
421 return success();
422 }
423
424 //===----------------------------------------------------------------------===//
425 // ConvertWhileLikeOp
426 //===----------------------------------------------------------------------===//
427
428 template <typename WhileLikeOp, typename WhileLikeRegionOp>
429 LogicalResult
matchAndRewrite(WhileLikeOp op,PatternRewriter & rewriter) const430 ConvertWhileLikeOp<WhileLikeOp, WhileLikeRegionOp>::matchAndRewrite(
431 WhileLikeOp op, PatternRewriter &rewriter) const {
432 GraphFuncOp cond_func = this->LookupFunc(op.cond());
433 GraphFuncOp body_func = this->LookupFunc(op.body());
434 if (CannotInline(cond_func) || CannotInline(body_func)) return failure();
435
436 // Note that `tfg.While` may not have the same input and output types. We will
437 // need to insert casts.
438 // TODO(jeffniu): Change this to call the infer return types builder.
439 ValueRange init, ctls;
440 std::tie(init, ctls) = this->SplitControl(op.args());
441 auto region_op = rewriter.create<WhileLikeRegionOp>(
442 op.getLoc(), op.getResultTypes(), init, ctls,
443 op.parallel_iterationsAttr(), op.cond().getAttrs(), op.body().getAttrs(),
444 PreserveAttributes(cond_func), PreserveAttributes(body_func));
445 util::ForwardNonIntrinsicAttributes(op, region_op);
446
447 // Just copy the function bodies into the regions. `RegionBranchOpInterface`
448 // requires that we re-order the block arguments such that the control tokens
449 // all come after the data arguments.
450 this->CloneAndReorderArgs(init.getTypes(), cond_func.body(),
451 region_op.cond_region(), rewriter);
452 this->CloneAndReorderArgs(init.getTypes(), body_func.body(),
453 region_op.body_region(), rewriter);
454 this->ReplaceReturnWithYield(region_op.body_block(), init.getTypes(),
455 rewriter);
456
457 // Replace `return(tensor<*xi1>)` with `condition`.
458 auto ret_op = cast<ReturnOp>(region_op.cond_block().getTerminator());
459 ValueRange ret_args, ret_ctls;
460 std::tie(ret_args, ret_ctls) = this->SplitControl(ret_op.getOperands());
461 rewriter.setInsertionPoint(ret_op);
462 rewriter.replaceOpWithNewOp<ConditionOp>(
463 ret_op, ret_args.front(), GetLoopRegionDataArgs(region_op.cond_region()),
464 ret_ctls);
465 rewriter.replaceOp(op, region_op->getResults());
466 return success();
467 }
468
469 //===----------------------------------------------------------------------===//
470 // ConvertForOp
471 //===----------------------------------------------------------------------===//
472
matchAndRewrite(tfg::ForOp op,PatternRewriter & rewriter) const473 LogicalResult ConvertForOp::matchAndRewrite(tfg::ForOp op,
474 PatternRewriter &rewriter) const {
475 GraphFuncOp body_func = LookupFunc(op.body());
476 if (CannotInline(body_func)) return failure();
477
478 // Note that `For` may not have the same input and output typse, although
479 // `ForRegion` does. We will need to insert casts.
480 ValueRange init, ctls;
481 std::tie(init, ctls) = SplitControl(op.args());
482 auto region_op = rewriter.create<ForRegionOp>(
483 op.getLoc(), op.getResultTypes(), op.start(), op.limit(), op.delta(),
484 init, ctls, op.body().getAttrs(), PreserveAttributes(body_func));
485 util::ForwardNonIntrinsicAttributes(op, region_op);
486
487 // Copy the function body into the region. One index type must be added.
488 OperandRange args = op.getOperands().drop_front(2).drop_back(ctls.size());
489 CloneAndReorderArgs(args.getTypes(), body_func.body(),
490 region_op.body_region(), rewriter);
491 ReplaceReturnWithYield(region_op.body_block(), init.getTypes(), rewriter);
492 rewriter.replaceOp(op, region_op->getResults());
493 return success();
494 }
495
496 //===----------------------------------------------------------------------===//
497 // Populate Patterns
498 //===----------------------------------------------------------------------===//
499
PopulateFunctionalToRegionPatterns(RewritePatternSet & patterns,SymbolTable & table)500 void PopulateFunctionalToRegionPatterns(RewritePatternSet &patterns,
501 SymbolTable &table) {
502 patterns.insert<ConvertIfOp, ConvertStatelessIfOp, ConvertStatefulIfOp,
503 ConvertWhileOp, ConvertStatelessWhileOp,
504 ConvertStatefulWhileOp, ConvertCaseOp, ConvertStatelessCaseOp,
505 ConvertStatefulCaseOp, ConvertForOp>(
506 patterns.getContext(), table,
507 *patterns.getContext()->getOrLoadDialect<TFGraphDialect>());
508 }
509
510 } // namespace tfg
511 } // namespace mlir
512