xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/region_to_functional/impl.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/region_to_functional/impl.h"
17 
18 #include <string>
19 #include <tuple>
20 #include <utility>
21 #include <vector>
22 
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringSet.h"
27 #include "llvm/ADT/iterator.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/FunctionInterfaces.h"  // from @llvm-project
35 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
36 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
37 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
38 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
39 #include "mlir/IR/TypeRange.h"  // from @llvm-project
40 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
41 #include "mlir/IR/Value.h"  // from @llvm-project
42 #include "mlir/IR/Verifier.h"  // from @llvm-project
43 #include "mlir/Support/LLVM.h"  // from @llvm-project
44 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
45 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
46 #include "tensorflow/core/ir/dialect.h"
47 #include "tensorflow/core/ir/ops.h"
48 #include "tensorflow/core/ir/types/dialect.h"
49 #include "tensorflow/core/ir/utility.h"
50 #include "tensorflow/core/transforms/utils/utils.h"
51 
52 namespace mlir {
53 namespace tfg {
54 
55 //===----------------------------------------------------------------------===//
56 // Pattern Definitions
57 //===----------------------------------------------------------------------===//
58 
59 namespace {
60 // Cached attribute name identifiers shared by all patterns.
61 struct CachedIdentifiers {
CachedIdentifiersmlir::tfg::__anond19cc84e0111::CachedIdentifiers62   explicit CachedIdentifiers(TFGraphDialect *dialect)
63       : tfg_name(dialect->getTfgNameAttrIdentifier()),
64         tfg_regenerate_output_shapes(StringAttr::get(
65             dialect->getContext(), "tfg.regenerate_output_shapes")) {}
66 
67   // Cached identifier for "tfg.name".
68   StringAttr tfg_name;
69   // Cached identifier for "tfg.regenerate_output_shapes".
70   StringAttr tfg_regenerate_output_shapes;
71 };
72 
73 // A helper for uniqueing argument, result, and control result names, which must
74 // be unique for a function.
75 class NameUniquer {
76  public:
NameUniquer(MLIRContext * ctx)77   explicit NameUniquer(MLIRContext *ctx) : ctx_(ctx) {}
78 
79   // Unique a name. If the name is unused, returns the name. Otherwise,
80   // allocates a new name.
GetUniqued(StringAttr name)81   StringAttr GetUniqued(StringAttr name) {
82     auto it = unique_names_.insert(name);
83     if (it.second) return name;
84     unsigned suffix = 0;
85     StringAttr next_name;
86     do {
87       next_name =
88           StringAttr::get(ctx_, name.getValue() + "_" + Twine(suffix++));
89       it = unique_names_.insert(next_name);
90     } while (!it.second);
91     return next_name;
92   }
93 
94  private:
95   // The MLIR context.
96   MLIRContext *ctx_;
97   // This set contains the occupied names.
98   DenseSet<StringAttr> unique_names_;
99 };
100 
101 // Base class for patterns used to convert region control-flow ops to functional
102 // control-flow ops. This class contains common utility functions and cached
103 // attribute identifiers.
104 class BasePattern {
105  public:
// `dialect` and `table` are stored by reference (see `dialect_` and `table_`),
// so both must outlive the pattern.
BasePattern(TFGraphDialect & dialect,SymbolTable & table,bool force_control_capture,CachedIdentifiers ids)106   BasePattern(TFGraphDialect &dialect, SymbolTable &table,
107               bool force_control_capture, CachedIdentifiers ids)
108       : ctx_(dialect.getContext()),
109         dialect_(dialect),
110         table_(table),
111         force_control_capture_(force_control_capture),
112         ids_(ids) {}
113 
114  protected:
115   // Collect all values used in the region that are defined above the region.
116   // If a control token is encountered, collect its associated data value. If it
117   // doesn't have one, add it to `ctls`.
118   void CollectValuesDefinedAbove(Region &region, SetVector<Value> &datas,
119                                  SetVector<Value> &ctls) const;
120   // Collect data values used in any of the given regions that are defined above
121   // the regions. These are the values that will be converted to explicit
122   // capture. If a control token with no associated data value is encountered
123   // and `force_control_capture_` is not set, then this function returns
124   // failure. Otherwise, it inserts chain constants and rewrites uses of the
125   // token to use the control outputs of the constants.
  // The data results of those chain constants are appended to the returned
  // list after the data values found in the regions.
126   FailureOr<std::vector<Value>> CollectValuesDefinedAboveAll(
127       RegionRange regions, PatternRewriter &rewriter) const;
128   // Rewrite the regions to be isolated from above by replacing uses of the
129   // given data values with block arguments. Use the same set of values for each
130   // region so that their block arguments are the same.
131   void IsolateRegions(RegionRange regions, MutableArrayRef<Value> datas) const;
132   // Create a chain `Const` operation. The op's data result is unused; its
133   // only purpose is to convert a control edge into a data edge.
  // Result 0 is the data result and result 1 is the control token. `idx` is
  // used to derive the constant's name from `parent`'s name.
134   Operation *MakeChainConstant(Operation *parent, Value ctl, unsigned idx,
135                                PatternRewriter &rewriter) const;
136 
137   // Infer or propagate function attributes. Use a name uniquer to unique names
138   // across function arguments, results, and control results.
139   NamedAttrList BuildAttributes(RegionAttr preserved, ValueRange arguments,
140                                 ValueRange results,
141                                 NameUniquer *name_uniquer) const;
142 
143   // Try to find a name for a data or control value. For op results, check the
144   // op for a name. Otherwise, check the enclosing function's arg attributes.
  // For block arguments of non-function region ops, look through to the value
  // that feeds the argument: `args` if it is provided, otherwise the op's own
  // inputs. Returns a null attribute if no name is found.
145   StringAttr TryFindName(Value value, Optional<ValueRange> args) const;
146 
147   // Get the `control_ret_attrs` attributes for control returns. Use a name
148   // uniquer to unique names across function arguments, results, and control
149   // results.
150   ArrayAttr GetControlRetAttrs(ValueRange ctls, ValueRange args,
151                                NameUniquer *name_uniquer) const;
152 
153   // Create a function with the given name and attributes. Use the types of the
154   // block arguments and the given result types. Take the body of the region.
  // Each data argument of the region is paired with a control-token argument
  // in the function signature.
155   GraphFuncOp CreateFunc(Location loc, const Twine &sym_name, Region &region,
156                          TypeRange res_types, NamedAttrList attrs) const;
157 
158   // Convert a (yield-terminated) region to a function and return a reference.
  // The new function is inserted into `table_`.
159   FuncAttr Outline(Operation *op, PatternRewriter &rewriter, ValueRange args,
160                    Region &region, RegionAttr preserved, DictionaryAttr attrs,
161                    const Twine &func_name) const;
162 
163   // A region function to outline.
164   struct RegionFunction {
165     // The function body.
166     Region &region;
167     // Potentially null preserved function attributes.
168     RegionAttr preserved_attrs;
169     // The function call attributes.
170     DictionaryAttr call_attrs;
171     // The function name to use.
172     std::string func_name;
173   };
174   // Outline a list of (yield-terminated) region functions, but if any function
175   // could not be re-used, then new functions are created for all of them.
176   template <typename FuncAttrT>
177   void ReuseAllOrOutline(Operation *op, PatternRewriter &rewriter,
178                          ValueRange args, ArrayRef<RegionFunction> regions,
179                          SmallVectorImpl<FuncAttrT> &functions) const;
180 
181   // Try to find a "reusable" function that has the same body as the provided
182   // region. A function is "reusable" if its body has the same topology as the
183   // provided region, corresponding operands have the same attributes, except
184   // for node name, and value types are compatible.
185   FuncAttr FindReusableFunc(Region &region, RegionAttr preserved,
186                             DictionaryAttr attrs) const;
187 
188   // If a function exists and has nested regions, return false. Otherwise,
189   // return true.
  // NOTE(review): this polarity reads as the opposite of the function's name;
  // confirm against the definition before relying on this comment.
190   bool FuncHasNestedRegions(RegionAttr preserved) const;
191 
192  protected:
193   // The MLIR context, taken from the dialect.
194   MLIRContext *ctx_;
195   // Dialect reference for getting cached values.
196   TFGraphDialect &dialect_;
197   // Symbol table to use to look up existing functions.
198   SymbolTable &table_;
199   // Whether control tokens with no data values should be forcefully captured by
200   // inserting a chain `Const` op.
201   bool force_control_capture_;
202   // Cached attribute identifiers.
203   CachedIdentifiers ids_;
204 };
205 
206 //===----------------------------------------------------------------------===//
207 // ConvertToExplicitCapture
208 
// Shared driver for turning the implicit captures of a region op of type `OpT`
// into explicit ones. Subclasses implement `RebuildWith`, which creates the
// replacement op once the captured values are known.
209 template <typename OpT>
210 struct ConvertToExplicitCapture : public BasePattern {
211   using BasePattern::BasePattern;
212 
213   virtual ~ConvertToExplicitCapture() = default;
214 
215   // Convert the regions of the operation to explicit capture. Returns the
216   // newly captured values and an updated op.
217   FailureOr<std::pair<OpT, std::vector<Value>>> Run(OpT op,
218                                                     PatternRewriter &rewriter);
219 
220   // Rebuild the regions of the operation with added values.
  // `added` are the newly captured values. Implementations return the newly
  // created op.
221   virtual OpT RebuildWith(OpT op, ValueRange added,
222                           PatternRewriter &rewriter) const = 0;
223 };
224 
// Explicit-capture rebuild for if-like region ops.
225 template <typename IfLikeRegionOp>
226 struct ConvertIfLikeRegionOpToExplicitCapture
227     : public ConvertToExplicitCapture<IfLikeRegionOp> {
228   using ConvertToExplicitCapture<IfLikeRegionOp>::ConvertToExplicitCapture;
229 
  // Creates a new op with the same result types, condition, control operands
  // and attributes as `op`. `added` is not used: the captured values are not
  // passed as operands of the if-like op here. Presumably the caller moves
  // the (already isolated) regions over; confirm in `Run`.
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertIfLikeRegionOpToExplicitCapture230   IfLikeRegionOp RebuildWith(IfLikeRegionOp op, ValueRange added,
231                              PatternRewriter &rewriter) const override {
232     return rewriter.create<IfLikeRegionOp>(
233         op.getLoc(), op.getResultTypes(), op.cond(), op.ctls(),
234         op.then_attrsAttr(), op.else_attrsAttr(), op.then_region_attrsAttr(),
235         op.else_region_attrsAttr());
236   }
237 };
238 
// Explicit-capture rebuild for case-like region ops.
239 template <typename CaseLikeRegionOp>
240 struct ConvertCaseLikeRegionOpToExplicitCapture
241     : public ConvertToExplicitCapture<CaseLikeRegionOp> {
242   using ConvertToExplicitCapture<CaseLikeRegionOp>::ConvertToExplicitCapture;
243 
  // Creates a new op with the same result types, branch index, control
  // operands and attributes as `op`. `added` is not used. The last argument
  // presumably sizes the new op's branch regions to match `op`. NOTE(review):
  // confirm against the op builder.
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertCaseLikeRegionOpToExplicitCapture244   CaseLikeRegionOp RebuildWith(CaseLikeRegionOp op, ValueRange added,
245                                PatternRewriter &rewriter) const override {
246     return rewriter.create<CaseLikeRegionOp>(
247         op.getLoc(), op.getResultTypes(), op.branch_index(), op.ctls(),
248         op.branch_attrsAttr(), op.region_attrsAttr(), op.branches().size());
249   }
250 };
251 
252 // Get the block arguments that correspond to the passthrough iteration
253 // arguments created from converting implicit captures. Append them to the
254 // previous region results `prev`.
GetForwardedValues(ValueRange added,Block::BlockArgListType block_args,ValueRange prev)255 static SmallVector<Value> GetForwardedValues(ValueRange added,
256                                              Block::BlockArgListType block_args,
257                                              ValueRange prev) {
258   SmallVector<Value> args(prev.begin(), prev.end());
259   llvm::append_range(args, block_args.slice(prev.size(), added.size()));
260   return args;
261 }
262 
// Explicit-capture rebuild for while-like region ops. Captured values become
// extra loop-carried values.
263 template <typename WhileLikeRegionOp>
264 struct ConvertWhileLikeRegionOpToExplicitCapture
265     : public ConvertToExplicitCapture<WhileLikeRegionOp> {
266   using ConvertToExplicitCapture<WhileLikeRegionOp>::ConvertToExplicitCapture;
267 
  // Appends `added` to the loop's init operands and result types, and
  // rewrites the condition and body terminators to forward the corresponding
  // block arguments.
  // NOTE(review): this assumes both regions already carry block arguments for
  // `added` (presumably appended by the caller before this runs).
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertWhileLikeRegionOpToExplicitCapture268   WhileLikeRegionOp RebuildWith(WhileLikeRegionOp op, ValueRange added,
269                                 PatternRewriter &rewriter) const override {
    // Condition terminator: forward the new condition-region arguments.
270     ConditionOp cond_op = op.cond_condition();
271     rewriter.setInsertionPoint(cond_op);
272     rewriter.replaceOpWithNewOp<ConditionOp>(
273         cond_op, cond_op.cond(),
274         GetForwardedValues(added, op.cond_region().getArguments(),
275                            cond_op.args()),
276         cond_op.ctls());
277 
    // Body terminator: yield the new body-region arguments unchanged.
278     YieldOp yield_op = op.body_yield();
279     rewriter.setInsertionPoint(yield_op);
280     rewriter.replaceOpWithNewOp<YieldOp>(
281         yield_op,
282         GetForwardedValues(added, op.body_region().getArguments(),
283                            yield_op.args()),
284         yield_op.ctls());
285 
    // The new op takes `added` as extra init operands and produces one extra
    // result per added value.
286     SmallVector<Value> operands = llvm::to_vector(op.init());
287     llvm::append_range(operands, added);
288     SmallVector<Type> results = llvm::to_vector(op.outs().getTypes());
289     llvm::append_range(results, added.getTypes());
290     util::LoopRegionResultAdded(op.body_region(), added.size());
291 
292     rewriter.setInsertionPoint(op);
293     return rewriter.create<WhileLikeRegionOp>(
294         op.getLoc(), results, op.ctl().getType(), operands, op.ctls(),
295         op.parallel_iterationsAttr(), op.cond_attrsAttr(), op.body_attrsAttr(),
296         op.cond_region_attrsAttr(), op.body_region_attrsAttr());
297   }
298 };
299 
// Explicit-capture rebuild for `ForRegionOp`. Captured values become extra
// loop-carried values.
300 struct ConvertForRegionOpToExplicitCapture
301     : public ConvertToExplicitCapture<ForRegionOp> {
302   using ConvertToExplicitCapture<ForRegionOp>::ConvertToExplicitCapture;
303 
  // Appends `added` to the loop's init operands and result types, and makes
  // the body terminator forward the matching iteration arguments.
  // NOTE(review): this assumes the body region already has block arguments
  // for `added`.
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertForRegionOpToExplicitCapture304   ForRegionOp RebuildWith(ForRegionOp op, ValueRange added,
305                           PatternRewriter &rewriter) const override {
306     YieldOp yield_op = op.body_yield();
307     rewriter.setInsertionPoint(yield_op);
308     // Get the iteration arguments excluding the for loop index argument.
309     auto iter_args = GetLoopRegionDataArgs(op.body_region()).slice(1);
310     rewriter.replaceOpWithNewOp<YieldOp>(
311         yield_op, GetForwardedValues(added, iter_args, yield_op.args()),
312         yield_op.ctls());
313 
314     SmallVector<Value> operands = llvm::to_vector(op.init());
315     llvm::append_range(operands, added);
316     SmallVector<Type> results = llvm::to_vector(op.outs().getTypes());
317     llvm::append_range(results, added.getTypes());
318     util::LoopRegionResultAdded(op.body_region(), added.size());
319 
    // Start, limit and delta are unchanged; only the init operands grow.
320     rewriter.setInsertionPoint(op);
321     return rewriter.create<ForRegionOp>(
322         op.getLoc(), results, op.ctl().getType(), op.start(), op.limit(),
323         op.delta(), operands, op.ctls(), op.body_attrsAttr(),
324         op.region_attrsAttr());
325   }
326 };
327 
328 //===----------------------------------------------------------------------===//
329 // ConvertRegionToFunctional
330 
// Common base of the region-to-functional rewrite patterns. It matches
// `SourceOp`, lists `DestOp` as the op the pattern may generate, and carries
// the shared `BasePattern` state.
331 template <typename SourceOp, typename DestOp>
332 struct ConvertRegionToFunctionalPattern : public OpRewritePattern<SourceOp>,
333                                           public BasePattern {
ConvertRegionToFunctionalPatternmlir::tfg::__anond19cc84e0111::ConvertRegionToFunctionalPattern334   explicit ConvertRegionToFunctionalPattern(MLIRContext *context,
335                                             TFGraphDialect &dialect,
336                                             SymbolTable &table,
337                                             bool force_control_capture,
338                                             CachedIdentifiers ids)
339       : OpRewritePattern<SourceOp>(context, /*benefit=*/1,
340                                    {DestOp::getOperationName()}),
341         BasePattern(dialect, table, force_control_capture, ids) {}
342 };
343 
344 // Pattern template to convert an if-like TFG region op to
345 // functional form.
346 template <typename IfLikeRegionOp, typename IfLikeOp>
347 struct ConvertIfLikeOp
348     : public ConvertRegionToFunctionalPattern<IfLikeRegionOp, IfLikeOp> {
349   using ConvertRegionToFunctionalPattern<
350       IfLikeRegionOp, IfLikeOp>::ConvertRegionToFunctionalPattern;
351 
352   LogicalResult matchAndRewrite(IfLikeRegionOp op,
353                                 PatternRewriter &rewriter) const override;
354 };
355 
// Concrete instantiations for the generic, stateless and stateful variants.
356 using ConvertIfOp = ConvertIfLikeOp<IfRegionOp, IfOp>;
357 using ConvertStatelessIfOp =
358     ConvertIfLikeOp<StatelessIfRegionOp, StatelessIfOp>;
359 using ConvertStatefulIfOp = ConvertIfLikeOp<StatefulIfRegionOp, StatefulIfOp>;
360 
361 // Pattern template to convert a case-like TFG region op to
362 // functional form.
363 template <typename CaseLikeRegionOp, typename CaseLikeOp>
364 struct ConvertCaseLikeOp
365     : public ConvertRegionToFunctionalPattern<CaseLikeRegionOp, CaseLikeOp> {
366   using ConvertRegionToFunctionalPattern<
367       CaseLikeRegionOp, CaseLikeOp>::ConvertRegionToFunctionalPattern;
368 
369   LogicalResult matchAndRewrite(CaseLikeRegionOp op,
370                                 PatternRewriter &rewriter) const override;
371 };
372 
// Concrete instantiations for the generic, stateless and stateful variants.
373 using ConvertCaseOp = ConvertCaseLikeOp<CaseRegionOp, CaseOp>;
374 using ConvertStatelessCaseOp =
375     ConvertCaseLikeOp<StatelessCaseRegionOp, StatelessCaseOp>;
376 using ConvertStatefulCaseOp =
377     ConvertCaseLikeOp<StatefulCaseRegionOp, StatefulCaseOp>;
378 
379 // Pattern template to convert a while-like TFG region op to functional
380 // form.
381 template <typename WhileLikeRegionOp, typename WhileLikeOp>
382 struct ConvertWhileLikeOp
383     : public ConvertRegionToFunctionalPattern<WhileLikeRegionOp, WhileLikeOp> {
384   using ConvertRegionToFunctionalPattern<
385       WhileLikeRegionOp, WhileLikeOp>::ConvertRegionToFunctionalPattern;
386 
387   LogicalResult matchAndRewrite(WhileLikeRegionOp op,
388                                 PatternRewriter &rewriter) const override;
389 };
390 
// Concrete instantiations for the generic, stateless and stateful variants.
391 using ConvertWhileOp = ConvertWhileLikeOp<WhileRegionOp, WhileOp>;
392 using ConvertStatelessWhileOp =
393     ConvertWhileLikeOp<StatelessWhileRegionOp, StatelessWhileOp>;
394 using ConvertStatefulWhileOp =
395     ConvertWhileLikeOp<StatefulWhileRegionOp, StatefulWhileOp>;
396 
397 // Convert a region-based for-loop to a functional for-loop.
// Matches `ForRegionOp` and produces `ForOp`.
398 struct ConvertForOp
399     : public ConvertRegionToFunctionalPattern<ForRegionOp, ForOp> {
400   using ConvertRegionToFunctionalPattern<
401       ForRegionOp, ForOp>::ConvertRegionToFunctionalPattern;
402 
403   LogicalResult matchAndRewrite(ForRegionOp op,
404                                 PatternRewriter &rewriter) const override;
405 };
406 
407 }  // namespace
408 
409 //===----------------------------------------------------------------------===//
410 // Utility Functions
411 //===----------------------------------------------------------------------===//
412 
CollectValuesDefinedAbove(Region & region,SetVector<Value> & datas,SetVector<Value> & ctls) const413 void BasePattern::CollectValuesDefinedAbove(Region &region,
414                                             SetVector<Value> &datas,
415                                             SetVector<Value> &ctls) const {
416   ControlType control_ty = dialect_.getControlType();
417   visitUsedValuesDefinedAbove(region, [&](OpOperand *operand) {
418     Value value = operand->get();
419     if (value.getType() != control_ty) {
420       datas.insert(value);
421     } else if (Optional<Value> data = LookupDataValue(value)) {
422       datas.insert(*data);
423     } else {
424       ctls.insert(value);
425     }
426   });
427 }
428 
MakeChainConstant(Operation * parent,Value ctl,unsigned idx,PatternRewriter & rewriter) const429 Operation *BasePattern::MakeChainConstant(Operation *parent, Value ctl,
430                                           unsigned idx,
431                                           PatternRewriter &rewriter) const {
432   OperationName name("tfg.Const", ctl.getContext());
433   OperationState state(ctl.getLoc(), name);
434   IntegerType i32 = rewriter.getI32Type();
435   ShapedType tensor_type = RankedTensorType::get({}, i32);
436   state.addOperands(ctl);
437   state.addAttribute("value", DenseElementsAttr::get(tensor_type, 0));
438   state.addAttribute("dtype", TypeAttr::get(i32));
439   state.addTypes({tensor_type, ctl.getType()});
440 
441   // Inherit `tfg.tpu_replicate`, `assigned_device`, and `device`.
442   for (StringAttr attr_name : {StringAttr::get(ctx_, "_tpu_replicate"),
443                                dialect_.getAssignedDeviceAttrIdentifier(),
444                                dialect_.getDeviceAttrIdentifier()}) {
445     if (Attribute attr = parent->getAttr(attr_name))
446       state.addAttribute(attr_name, attr);
447   }
448 
449   // Inherit a name based on the parent name.
450   StringAttr name_id = dialect_.getNameAttrIdentifier();
451   if (StringAttr name = parent->getAttrOfType<StringAttr>(name_id)) {
452     auto const_name = rewriter.getStringAttr(
453         name.getValue() + "_mlir_const_capture_" + Twine(idx));
454     state.addAttribute(name_id, const_name);
455   }
456 
457   return rewriter.create(state);
458 }
459 
CollectValuesDefinedAboveAll(RegionRange regions,PatternRewriter & rewriter) const460 FailureOr<std::vector<Value>> BasePattern::CollectValuesDefinedAboveAll(
461     RegionRange regions, PatternRewriter &rewriter) const {
462   SetVector<Value> data_set, ctl_only;
463   for (Region &region : llvm::make_pointee_range(regions))
464     CollectValuesDefinedAbove(region, data_set, ctl_only);
465   std::vector<Value> datas = data_set.takeVector();
466 
467   // If in any of the regions we found a use of a control token defined above
468   // the regions with no associated data value, then it cannot be converted to
469   // explicit capture unless we insert chain constants. If this option was not
470   // set, return failure because the region op cannot be converted.
471   if (!force_control_capture_ && !ctl_only.empty()) return failure();
472 
473   Operation *parent = regions.front()->getParentOp();
474   for (auto &ctl : llvm::enumerate(ctl_only.takeVector())) {
475     Operation *const_op =
476         MakeChainConstant(parent, ctl.value(), ctl.index(), rewriter);
477     for (Region *region : regions)
478       replaceAllUsesInRegionWith(ctl.value(), const_op->getResult(1), *region);
479     datas.push_back(const_op->getResult(0));
480   }
481 
482   return datas;
483 }
484 
IsolateRegions(RegionRange regions,MutableArrayRef<Value> datas) const485 void BasePattern::IsolateRegions(RegionRange regions,
486                                  MutableArrayRef<Value> datas) const {
487   ValueControlRetRange ctls(datas);
488   Value data, ctl;
489   for (Region &region : llvm::make_pointee_range(regions)) {
490     for (auto it : llvm::zip(datas, ctls)) {
491       std::tie(data, ctl) = it;
492       util::LoopRegionArgumentUpdate result =
493           util::LoopRegionAddArgument(region, data.getType());
494       replaceAllUsesInRegionWith(data, result.data, region);
495       replaceAllUsesInRegionWith(ctl, result.ctl, region);
496     }
497   }
498 }
499 
BuildAttributes(RegionAttr preserved,ValueRange arguments,ValueRange results,NameUniquer * name_uniquer) const500 NamedAttrList BasePattern::BuildAttributes(RegionAttr preserved,
501                                            ValueRange arguments,
502                                            ValueRange results,
503                                            NameUniquer *name_uniquer) const {
504   NamedAttrList attrs(preserved ? preserved.getAttrs() : DictionaryAttr());
505   // The original function name is preserved in the region attributes, but don't
506   // re-use it when creating a new function.
507   attrs.erase(SymbolTable::getSymbolAttrName());
508 
509   SmallVector<Attribute> arg_attrs, res_attrs;
510   ArrayAttr preserved_arg_attrs =
511       preserved ? preserved.getArgAttrs() : ArrayAttr();
512   ArrayAttr preserved_res_attrs =
513       preserved ? preserved.getResAttrs() : ArrayAttr();
514 
515   // For each argument and result, lookup a name and regenerate output shapes.
516   const auto build_attrs = [&](ArrayAttr attr, auto &it,
517                                Optional<ValueRange> args) {
518     NamedAttrList attrs(attr ? attr[it.index()].template cast<DictionaryAttr>()
519                              : DictionaryAttr());
520     // If no name was preserved, try to find one.
521     if (!attrs.get(ids_.tfg_name)) {
522       if (StringAttr name = TryFindName(it.value(), args))
523         attrs.set(ids_.tfg_name, name_uniquer->GetUniqued(name));
524     }
525     attrs.set(ids_.tfg_regenerate_output_shapes, UnitAttr::get(ctx_));
526     return attrs.getDictionary(ctx_);
527   };
528 
529   for (auto &it : llvm::enumerate(arguments)) {
530     arg_attrs.append({build_attrs(preserved_arg_attrs, it, {}),
531                       DictionaryAttr::get(ctx_, {})});
532   }
533   for (auto &it : llvm::enumerate(results))
534     res_attrs.push_back(build_attrs(preserved_res_attrs, it, arguments));
535 
536   attrs.append(FunctionOpInterface::getArgDictAttrName(),
537                ArrayAttr::get(ctx_, arg_attrs));
538   attrs.append(FunctionOpInterface::getResultDictAttrName(),
539                ArrayAttr::get(ctx_, res_attrs));
540   return attrs;
541 }
542 
TryFindName(Value value,Optional<ValueRange> args) const543 StringAttr BasePattern::TryFindName(Value value,
544                                     Optional<ValueRange> args) const {
545   // If this is an op result, return the op's name.
546   if (auto result = value.dyn_cast<OpResult>()) {
547     Operation *op = result.getOwner();
548     if (auto name =
549             op->getAttrOfType<StringAttr>(dialect_.getNameAttrIdentifier())) {
550       return StringAttr::get(ctx_, name.getValue() + "_tfg_result_" +
551                                        Twine(result.getResultNumber()));
552     }
553     return {};
554   }
555 
556   auto arg = value.cast<BlockArgument>();
557   Operation *parent = arg.getOwner()->getParentOp();
558   auto iface = dyn_cast<ControlArgumentInterface>(parent);
559   if (!iface) return {};
560   // If we were given a control token, lookup a name using the data value.
561   if (arg.getType() == dialect_.getControlType())
562     arg = iface.getDataValueOf(arg);
563   // If the parent is a function, try to find a `tfg.name`.
564   if (auto func = dyn_cast<GraphFuncOp>(*iface))
565     return func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), ids_.tfg_name);
566   // Otherwise, "see through" to the corresponding operand.
567   if (args) {
568     assert(arg.getArgNumber() < args->size());
569     return TryFindName((*args)[arg.getArgNumber()], {});
570   }
571   if (auto for_op = dyn_cast<ForRegionOp>(parent)) {
572     unsigned arg_idx = arg.getArgNumber();
573     if (arg_idx == 0) return TryFindName(for_op.start(), {});
574     return TryFindName(for_op.init()[arg_idx - 1], {});
575   }
576   auto branch = cast<RegionBranchOpInterface>(parent);
577   ValueRange inputs = branch.getSuccessorEntryOperands(
578       arg.getParentRegion()->getRegionNumber());
579   return TryFindName(inputs[arg.getArgNumber()], {});
580 }
581 
GetControlRetAttrs(ValueRange ctls,ValueRange args,NameUniquer * name_uniquer) const582 ArrayAttr BasePattern::GetControlRetAttrs(ValueRange ctls, ValueRange args,
583                                           NameUniquer *name_uniquer) const {
584   SmallVector<Attribute> ctl_ret_attrs;
585   for (Value ctl : ctls) {
586     NamedAttrList ctl_attrs;
587     if (StringAttr name = TryFindName(ctl, args)) {
588       ctl_attrs.set(dialect_.getTfgNameAttrIdentifier(),
589                     name_uniquer->GetUniqued(name));
590     }
591     ctl_ret_attrs.push_back(ctl_attrs.getDictionary(ctx_));
592   }
593   return ArrayAttr::get(ctx_, ctl_ret_attrs);
594 }
595 
// Creates a `GraphFuncOp` named `sym_name` and moves `region`'s body into it.
// The function type pairs each data argument of the region with a control
// token and uses `res_types` as result types. `attrs` is merged with the
// attributes the builder set on the new op.
// The op is built with a context-only `OpBuilder` (no insertion point), so it
// is returned detached; the caller is expected to insert it.
CreateFunc(Location loc,const Twine & sym_name,Region & region,TypeRange res_types,NamedAttrList attrs) const596 GraphFuncOp BasePattern::CreateFunc(Location loc, const Twine &sym_name,
597                                     Region &region, TypeRange res_types,
598                                     NamedAttrList attrs) const {
599   SmallVector<Type> arg_types;
600   for (BlockArgument operand : GetLoopRegionDataArgs(region))
601     arg_types.append({operand.getType(), dialect_.getControlType()});
602   auto func_type = FunctionType::get(ctx_, arg_types, res_types);
603   auto func = OpBuilder(ctx_).create<GraphFuncOp>(loc, sym_name, func_type,
604                                                   /*generic=*/false);
605 
606   attrs.append(func->getAttrs());
607   func->setAttrs(attrs.getDictionary(ctx_));
608 
  // Re-lay out the block arguments as interleaved (data, control) pairs: for
  // each pair, append fresh arguments and redirect all uses to them. Then
  // erase the original first 2 * N arguments.
  // NOTE(review): this assumes the region had exactly N data arguments and N
  // control tokens before this point.
609   SmallVector<BlockArgument> args =
610       llvm::to_vector(GetLoopRegionDataArgs(region));
611   SmallVector<BlockArgument> ctls =
612       llvm::to_vector(GetLoopRegionControlTokens(region));
613   // TODO(jeffniu): Change GraphFuncOp to use the same argument order as region
614   // loop ops.
615   for (auto it : llvm::zip(args, ctls)) {
616     BlockArgument arg, ctl;
617     std::tie(arg, ctl) = it;
618     arg.replaceAllUsesWith(region.addArgument(arg.getType(), arg.getLoc()));
619     ctl.replaceAllUsesWith(region.addArgument(ctl.getType(), ctl.getLoc()));
620   }
621   llvm::BitVector indices(region.getNumArguments());
622   indices.set(0, args.size() * 2);
623   region.front().eraseArguments(indices);
624 
625   func.body().takeBody(region);
626   return func;
627 }
628 
629 // Check the region attributes for a preserved function name
630 // TODO(jeffniu): RegionAttr should have an optional parameter for the function
631 // name, since it is treated differently from the other attributes.
GetFunctionName(RegionAttr preserved)632 static StringAttr GetFunctionName(RegionAttr preserved) {
633   if (!preserved) return {};
634   return preserved.getAttrs().getAs<StringAttr>(
635       SymbolTable::getSymbolAttrName());
636 }
637 
Outline(Operation * op,PatternRewriter & rewriter,ValueRange args,Region & region,RegionAttr preserved,DictionaryAttr attrs,const Twine & func_name) const638 FuncAttr BasePattern::Outline(Operation *op, PatternRewriter &rewriter,
639                               ValueRange args, Region &region,
640                               RegionAttr preserved, DictionaryAttr attrs,
641                               const Twine &func_name) const {
642   // Create a name scope for the function.
643   NameUniquer name_uniquer(ctx_);
644 
645   NamedAttrList func_attrs = BuildAttributes(
646       preserved, args, cast<YieldOp>(region.front().getTerminator()).args(),
647       &name_uniquer);
648 
649   auto yield = cast<YieldOp>(region.front().getTerminator());
650   rewriter.setInsertionPoint(yield);
651   auto ret_op = rewriter.replaceOpWithNewOp<ReturnOp>(
652       yield, yield.getOperands(),
653       GetControlRetAttrs(yield.ctls(), args, &name_uniquer));
654 
655   // Derive a function name. Use a default name. If a previous name exists,
656   // use it. If the op also has a name, derive a name based on that.
657   std::string new_func_name = func_name.str();
658   if (StringAttr existing_name = GetFunctionName(preserved)) {
659     new_func_name = existing_name.getValue().str();
660     if (auto op_name =
661             op->getAttrOfType<StringAttr>(dialect_.getNameAttrIdentifier())) {
662       llvm::raw_string_ostream os(new_func_name);
663       os << "_tfg_region_specialized_";
664       for (char c : llvm::map_range(
665                op_name.getValue(), [](char c) { return isalnum(c) ? c : '_'; }))
666         os << c;
667       os << '_' << llvm::to_string(region.getRegionNumber());
668       os.flush();
669     }
670   }
671 
672   // Create the function.
673   GraphFuncOp func = CreateFunc(op->getLoc(), new_func_name, region,
674                                 TFOp(ret_op).getNonControlOperands().getTypes(),
675                                 std::move(func_attrs));
676   return FuncAttr::get(ctx_, table_.insert(func),
677                        attrs ? attrs : DictionaryAttr::get(ctx_, {}));
678 }
679 
680 template <typename FuncAttrT>
ReuseAllOrOutline(Operation * op,PatternRewriter & rewriter,ValueRange args,ArrayRef<RegionFunction> regions,SmallVectorImpl<FuncAttrT> & functions) const681 void BasePattern::ReuseAllOrOutline(
682     Operation *op, PatternRewriter &rewriter, ValueRange args,
683     ArrayRef<RegionFunction> regions,
684     SmallVectorImpl<FuncAttrT> &functions) const {
685   // Try to find reusable functions for all regions.
686   const auto get_reusable_func = [this,
687                                   &functions](const RegionFunction &func) {
688     FuncAttr ref =
689         FindReusableFunc(func.region, func.preserved_attrs, func.call_attrs);
690     functions.push_back(ref);
691     return ref;
692   };
693   if (llvm::all_of(regions, get_reusable_func)) return;
694 
695   // At least one region needs to be outlined.
696   functions.clear();
697   for (const RegionFunction &func : regions) {
698     functions.push_back(Outline(op, rewriter, args, func.region,
699                                 func.preserved_attrs, func.call_attrs,
700                                 func.func_name));
701   }
702 }
703 
704 // Returns true if the region has any nested regions.
HasNestedRegions(Region & region)705 static bool HasNestedRegions(Region &region) {
706   return llvm::any_of(region.getOps(),
707                       [](Operation &op) { return op.getNumRegions(); });
708 }
709 
710 // Check if the region is "equivalent" to the body of the given function, and so
711 // the function can be re-used when outlining the region. This compares
712 // (topologically) the arguments, results, and ops, ignoring the op names and
713 // checking for compatible types.
RegionEqualTo(Region & region,GraphFuncOp func)714 static bool RegionEqualTo(Region &region, GraphFuncOp func) {
715   assert(!HasNestedRegions(region));
716   assert(!HasNestedRegions(func.body()));
717 
718   // Outlining is performed "bottom-up". I.e. regions with no nested regions are
719   // outlined first, which means that we will not have to worry about comparing
720   // `While` to `WhileRegion`. Also, it means that we can directly compare the
721   // operations.
722   DenseMap<Value, Value> value_map;
723   auto map_value = [&](Value lhs, Value rhs) {
724     if (!tf_type::HasCompatibleElementTypes(lhs.getType(), rhs.getType()))
725       return false;
726     return value_map.insert({lhs, rhs}).first->second == rhs;
727   };
728 
729   // Compare the non-control block arguments.
730   if (region.getNumArguments() != func.getNumArguments()) return false;
731   for (auto &it : llvm::enumerate(GetLoopRegionDataArgs(region))) {
732     Value rhs = GraphFuncOp::getDataValue(func.body(), it.index());
733     if (!map_value(it.value(), rhs)) return false;
734   }
735 
736   // Compare the bodies except the terminators. We can't use
737   // OperationEquivalence due to relaxed type equality.
738   auto map_value_range = [](ValueRange lhs_range, ValueRange rhs_range,
739                             auto map_value) {
740     if (lhs_range.size() != rhs_range.size()) return false;
741     for (auto it : llvm::zip(lhs_range, rhs_range))
742       if (!map_value(std::get<0>(it), std::get<1>(it))) return false;
743     return true;
744   };
745 
746   StringAttr name_id =
747       cast<TFGraphDialect>(func->getDialect())->getNameAttrIdentifier();
748 
749   auto compare_ops = [&](Operation &lhs, Operation &rhs) {
750     if (lhs.getName() != rhs.getName()) return false;
751 
752     DictionaryAttr lhs_attrs = lhs.getAttrDictionary();
753     DictionaryAttr rhs_attrs = rhs.getAttrDictionary();
754     if (lhs_attrs.size() != rhs_attrs.size()) return false;
755     for (auto it : llvm::zip(lhs_attrs, rhs_attrs)) {
756       NamedAttribute lhs_attr = std::get<0>(it);
757       NamedAttribute rhs_attr = std::get<1>(it);
758       if (lhs_attr.getName() != rhs_attr.getName()) return false;
759       if (lhs_attr.getName() == name_id) continue;
760       if (lhs_attr.getValue() != rhs_attr.getValue()) return false;
761     }
762     if (!map_value_range(lhs.getOperands(), rhs.getOperands(), map_value))
763       return false;
764     if (!map_value_range(lhs.getResults(), rhs.getResults(), map_value))
765       return false;
766     assert(!lhs.getNumRegions() && !rhs.getNumRegions());
767     return true;
768   };
769   if (!llvm::all_of_zip(region.front().without_terminator(),
770                         func.body().front().without_terminator(), compare_ops))
771     return false;
772 
773   // Compare just the operands of the terminators.
774   auto return_op = cast<ReturnOp>(func.body().front().getTerminator());
775   Operation *terminator = region.front().getTerminator();
776   if (auto yield = dyn_cast<YieldOp>(terminator)) {
777     return map_value_range(yield->getOperands(), return_op->getOperands(),
778                            map_value);
779   } else {
780     auto cond = cast<ConditionOp>(terminator);
781     return map_value(cond.cond(), return_op->getOperand(0)) &&
782            map_value_range(
783                cond.ctls(),
784                return_op->getOperands().slice(1, cond.ctls().size()),
785                map_value);
786   }
787 }
788 
FindReusableFunc(Region & region,RegionAttr preserved,DictionaryAttr attrs) const789 FuncAttr BasePattern::FindReusableFunc(Region &region, RegionAttr preserved,
790                                        DictionaryAttr attrs) const {
791   StringAttr name = GetFunctionName(preserved);
792   if (!name) return {};
793   auto func = table_.lookup<GraphFuncOp>(name);
794   if (!func) return {};
795   if (!RegionEqualTo(region, func)) return {};
796   return FuncAttr::get(region.getContext(), name.getValue(),
797                        attrs ? attrs : DictionaryAttr::get(ctx_, {}));
798 }
799 
FuncHasNestedRegions(RegionAttr preserved) const800 bool BasePattern::FuncHasNestedRegions(RegionAttr preserved) const {
801   StringAttr name = GetFunctionName(preserved);
802   if (!name) return false;
803   auto func = table_.lookup<GraphFuncOp>(name);
804   return func && HasNestedRegions(func.body());
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // ConvertToExplicitCapture
809 //===----------------------------------------------------------------------===//
810 
811 template <typename OpT>
812 FailureOr<std::pair<OpT, std::vector<Value>>>
Run(OpT op,PatternRewriter & rewriter)813 ConvertToExplicitCapture<OpT>::Run(OpT op, PatternRewriter &rewriter) {
814   FailureOr<std::vector<Value>> operands =
815       this->CollectValuesDefinedAboveAll(op->getRegions(), rewriter);
816   if (failed(operands)) return failure();
817   this->IsolateRegions(op->getRegions(), *operands);
818   OpT new_op = RebuildWith(op, *operands, rewriter);
819   util::ForwardNonIntrinsicAttributes(op, new_op);
820   for (auto it : llvm::zip(op->getRegions(), new_op->getRegions()))
821     std::get<1>(it).takeBody(std::get<0>(it));
822   rewriter.replaceOp(op, new_op->getResults().slice(0, op->getNumResults()));
823   return std::make_pair(new_op, std::move(*operands));
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // ConvertIfLikeOp
828 //===----------------------------------------------------------------------===//
829 
830 template <typename IfLikeRegionOp, typename IfLikeOp>
matchAndRewrite(IfLikeRegionOp op,PatternRewriter & rewriter) const831 LogicalResult ConvertIfLikeOp<IfLikeRegionOp, IfLikeOp>::matchAndRewrite(
832     IfLikeRegionOp op, PatternRewriter &rewriter) const {
833   if (HasNestedRegions(op.then_region()) || HasNestedRegions(op.else_region()))
834     return failure();
835   if (this->FuncHasNestedRegions(op.then_region_attrsAttr()) ||
836       this->FuncHasNestedRegions(op.else_region_attrsAttr()))
837     return failure();
838 
839   // Convert the op to explicit capture.
840   ConvertIfLikeRegionOpToExplicitCapture<IfLikeRegionOp> converter(
841       this->dialect_, this->table_, this->force_control_capture_, this->ids_);
842   auto result = converter.Run(op, rewriter);
843   if (failed(result)) return failure();
844   std::vector<Value> args;
845   std::tie(op, args) = std::move(*result);
846 
847   // Outline the regions.
848   SmallVector<FuncAttr, 2> branches;
849   this->ReuseAllOrOutline(op, rewriter, args,
850                           {{op.then_region(), op.then_region_attrsAttr(),
851                             op.then_attrsAttr(), "if_then_function"},
852                            {op.else_region(), op.else_region_attrsAttr(),
853                             op.else_attrsAttr(), "if_else_function"}},
854                           branches);
855 
856   // Build the functional if-like op.
857   SmallVector<Value> operands = llvm::to_vector(args);
858   llvm::append_range(operands, op.ctls());
859 
860   rewriter.setInsertionPoint(op);
861   auto func_op =
862       rewriter.create<IfLikeOp>(op.getLoc(), op.getResultTypes(), op.cond(),
863                                 operands, branches[0], branches[1]);
864   util::ForwardNonIntrinsicAttributes(op, func_op);
865   rewriter.replaceOp(op, func_op.getResults());
866   return success();
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // ConvertCaseLikeOp
871 //===----------------------------------------------------------------------===//
872 
873 template <typename CaseLikeRegionOp, typename CaseLikeOp>
matchAndRewrite(CaseLikeRegionOp op,PatternRewriter & rewriter) const874 LogicalResult ConvertCaseLikeOp<CaseLikeRegionOp, CaseLikeOp>::matchAndRewrite(
875     CaseLikeRegionOp op, PatternRewriter &rewriter) const {
876   if (llvm::any_of(op.branches(), HasNestedRegions)) return failure();
877   if (ArrayAttr preserved = op.region_attrsAttr()) {
878     if (llvm::any_of(preserved.getAsRange<RegionAttr>(), [&](auto preserved) {
879           return this->FuncHasNestedRegions(preserved);
880         }))
881       return failure();
882   }
883 
884   // Convert the op to explicit capture.
885   ConvertCaseLikeRegionOpToExplicitCapture<CaseLikeRegionOp> converter(
886       this->dialect_, this->table_, this->force_control_capture_, this->ids_);
887   auto result = converter.Run(op, rewriter);
888   if (failed(result)) return failure();
889   std::vector<Value> args;
890   std::tie(op, args) = std::move(*result);
891 
892   // Outline the regions.
893   ArrayAttr branch_func_attrs = op.branch_attrsAttr();
894   SmallVector<BasePattern::RegionFunction> branch_regions;
895   for (auto &it : llvm::enumerate(op.branches())) {
896     unsigned idx = it.index();
897     // Get the preserved attributes, if there are any.
898     RegionAttr preserved =
899         op.region_attrs()
900             ? op.region_attrsAttr()[idx].template cast<RegionAttr>()
901             : nullptr;
902     DictionaryAttr attrs =
903         branch_func_attrs
904             ? branch_func_attrs[idx].template cast<DictionaryAttr>()
905             : nullptr;
906     branch_regions.push_back(BasePattern::RegionFunction{
907         it.value(), preserved, attrs, ("case_function_" + Twine(idx)).str()});
908   }
909   SmallVector<Attribute> branches;
910   this->ReuseAllOrOutline(op, rewriter, args, branch_regions, branches);
911 
912   // Build the functional case-like op.
913   SmallVector<Value> operands = llvm::to_vector(args);
914   llvm::append_range(operands, op.ctls());
915 
916   rewriter.setInsertionPoint(op);
917   auto func_op = rewriter.create<CaseLikeOp>(op.getLoc(), op.getResultTypes(),
918                                              op.branch_index(), operands,
919                                              rewriter.getArrayAttr(branches));
920   util::ForwardNonIntrinsicAttributes(op, func_op);
921   rewriter.replaceOp(op, func_op.getResults());
922   return success();
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // ConvertWhileLikeOp
927 //===----------------------------------------------------------------------===//
928 
929 template <typename WhileLikeRegionOp, typename WhileLikeOp>
930 LogicalResult
matchAndRewrite(WhileLikeRegionOp op,PatternRewriter & rewriter) const931 ConvertWhileLikeOp<WhileLikeRegionOp, WhileLikeOp>::matchAndRewrite(
932     WhileLikeRegionOp op, PatternRewriter &rewriter) const {
933   if (HasNestedRegions(op.cond_region()) || HasNestedRegions(op.body_region()))
934     return failure();
935   if (this->FuncHasNestedRegions(op.cond_region_attrsAttr()) ||
936       this->FuncHasNestedRegions(op.body_region_attrsAttr()))
937     return failure();
938 
939   // Convert the op to explicit capture.
940   ConvertWhileLikeRegionOpToExplicitCapture<WhileLikeRegionOp> converter(
941       this->dialect_, this->table_, this->force_control_capture_, this->ids_);
942   auto result = converter.Run(op, rewriter);
943   if (failed(result)) return failure();
944   op = result->first;
945 
946   // Try to find re-usable functions for both the condition and body regions.
947   FuncAttr body_ref = this->FindReusableFunc(
948       op.body_region(), op.body_region_attrsAttr(), op.body_attrsAttr());
949   FuncAttr cond_ref = this->FindReusableFunc(
950       op.cond_region(), op.cond_region_attrsAttr(), op.cond_attrsAttr());
951 
952   // If a function for either region could not be re-used, outline them out.
953   if (!body_ref || !cond_ref) {
954     // Handle the condition region. Unlike other regions, the terminator is
955     // special and the function only has one result.
956     ConditionOp cond_op = op.cond_condition();
957     // Create a name scope for the condition function.
958     NameUniquer name_uniquer(this->ctx_);
959     // Create the function.
960     NamedAttrList cond_attrs = this->BuildAttributes(
961         op.cond_region_attrsAttr(), op.init(), cond_op.cond(), &name_uniquer);
962     GraphFuncOp cond_func =
963         this->CreateFunc(op.getLoc(), "while_cond_function", op.cond_region(),
964                          cond_op.cond().getType(), std::move(cond_attrs));
965     // Replace the condition terminator.
966     rewriter.setInsertionPoint(cond_op);
967     SmallVector<Value> cond_rets = {cond_op.cond()};
968     llvm::append_range(cond_rets, cond_op.ctls());
969     rewriter.replaceOpWithNewOp<ReturnOp>(
970         cond_op, cond_rets,
971         this->GetControlRetAttrs(cond_op.ctls(), op.init(), &name_uniquer));
972     // Insert the function and grab a reference.
973     cond_ref =
974         FuncAttr::get(op.getContext(), this->table_.insert(cond_func),
975                       op.cond_attrs().value_or(rewriter.getDictionaryAttr({})));
976 
977     // Outline the body.
978     body_ref = this->Outline(op, rewriter, op.init(), op.body_region(),
979                              op.body_region_attrsAttr(), op.body_attrsAttr(),
980                              "while_body_function");
981   }
982 
983   // Create the functional op.
984   SmallVector<Value> operands = llvm::to_vector(op.init());
985   llvm::append_range(operands, op.ctls());
986 
987   rewriter.setInsertionPoint(op);
988   auto func_op = rewriter.create<WhileLikeOp>(op.getLoc(), op.getResultTypes(),
989                                               operands, cond_ref, body_ref,
990                                               op.parallel_iterationsAttr());
991   util::ForwardNonIntrinsicAttributes(op, func_op);
992   rewriter.replaceOp(op, func_op.getResults());
993   return success();
994 }
995 
996 //===----------------------------------------------------------------------===//
997 // ConvertForOp
998 //===----------------------------------------------------------------------===//
999 
matchAndRewrite(ForRegionOp op,PatternRewriter & rewriter) const1000 LogicalResult ConvertForOp::matchAndRewrite(ForRegionOp op,
1001                                             PatternRewriter &rewriter) const {
1002   if (HasNestedRegions(op.body_region())) return failure();
1003   if (this->FuncHasNestedRegions(op.region_attrsAttr())) return failure();
1004 
1005   // Convert the op to explicit capture.
1006   ConvertForRegionOpToExplicitCapture converter(dialect_, table_,
1007                                                 force_control_capture_, ids_);
1008   auto result = converter.Run(op, rewriter);
1009   if (failed(result)) return failure();
1010   op = result->first;
1011 
1012   // Outline to body.
1013   SmallVector<Value> func_args(/*Size=*/1, op.start());
1014   llvm::append_range(func_args, op.init());
1015   SmallVector<FuncAttr, 1> body_ref;
1016   ReuseAllOrOutline(op, rewriter, func_args,
1017                     {{op.body_region(), op.region_attrsAttr(),
1018                       op.body_attrsAttr(), "for_body_function"}},
1019                     body_ref);
1020 
1021   // Create the functional op.
1022   SmallVector<Value> operands = llvm::to_vector(op.init());
1023   llvm::append_range(operands, op.ctls());
1024 
1025   rewriter.setInsertionPoint(op);
1026   auto func_op = rewriter.create<tfg::ForOp>(op.getLoc(), op.getResultTypes(),
1027                                              op.start(), op.limit(), op.delta(),
1028                                              operands, body_ref[0]);
1029   util::ForwardNonIntrinsicAttributes(op, func_op);
1030   rewriter.replaceOp(op, func_op.getResults());
1031   return success();
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // Pattern Population
1036 //===----------------------------------------------------------------------===//
1037 
PopulateRegionToFunctionalPatterns(RewritePatternSet & patterns,SymbolTable & table,bool force_control_capture)1038 void PopulateRegionToFunctionalPatterns(RewritePatternSet &patterns,
1039                                         SymbolTable &table,
1040                                         bool force_control_capture) {
1041   auto *dialect = patterns.getContext()->getOrLoadDialect<TFGraphDialect>();
1042   patterns.insert<ConvertIfOp, ConvertStatelessIfOp, ConvertStatefulIfOp,
1043                   ConvertCaseOp, ConvertStatelessCaseOp, ConvertStatefulCaseOp,
1044                   ConvertWhileOp, ConvertStatelessWhileOp,
1045                   ConvertStatefulWhileOp, ConvertForOp>(
1046       patterns.getContext(), *dialect, table, force_control_capture,
1047       CachedIdentifiers(dialect));
1048 }
1049 
1050 }  // namespace tfg
1051 }  // namespace mlir
1052