1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/transforms/region_to_functional/impl.h"
17
18 #include <string>
19 #include <tuple>
20 #include <utility>
21 #include <vector>
22
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringSet.h"
27 #include "llvm/ADT/iterator.h"
28 #include "llvm/Support/ScopedPrinter.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
31 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
34 #include "mlir/IR/FunctionInterfaces.h" // from @llvm-project
35 #include "mlir/IR/MLIRContext.h" // from @llvm-project
36 #include "mlir/IR/OpDefinition.h" // from @llvm-project
37 #include "mlir/IR/OperationSupport.h" // from @llvm-project
38 #include "mlir/IR/PatternMatch.h" // from @llvm-project
39 #include "mlir/IR/TypeRange.h" // from @llvm-project
40 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
41 #include "mlir/IR/Value.h" // from @llvm-project
42 #include "mlir/IR/Verifier.h" // from @llvm-project
43 #include "mlir/Support/LLVM.h" // from @llvm-project
44 #include "mlir/Support/LogicalResult.h" // from @llvm-project
45 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
46 #include "tensorflow/core/ir/dialect.h"
47 #include "tensorflow/core/ir/ops.h"
48 #include "tensorflow/core/ir/types/dialect.h"
49 #include "tensorflow/core/ir/utility.h"
50 #include "tensorflow/core/transforms/utils/utils.h"
51
52 namespace mlir {
53 namespace tfg {
54
55 //===----------------------------------------------------------------------===//
56 // Pattern Definitions
57 //===----------------------------------------------------------------------===//
58
59 namespace {
60 // Cached attribute name identifiers shared by all patterns.
61 struct CachedIdentifiers {
CachedIdentifiersmlir::tfg::__anond19cc84e0111::CachedIdentifiers62 explicit CachedIdentifiers(TFGraphDialect *dialect)
63 : tfg_name(dialect->getTfgNameAttrIdentifier()),
64 tfg_regenerate_output_shapes(StringAttr::get(
65 dialect->getContext(), "tfg.regenerate_output_shapes")) {}
66
67 // Cached identifier for "tfg.name".
68 StringAttr tfg_name;
69 // Cached identifier for "tfg.regenerate_output_shapes".
70 StringAttr tfg_regenerate_output_shapes;
71 };
72
73 // A helper for uniqueing argument, result, and control result names, which must
74 // be unique for a function.
75 class NameUniquer {
76 public:
NameUniquer(MLIRContext * ctx)77 explicit NameUniquer(MLIRContext *ctx) : ctx_(ctx) {}
78
79 // Unique a name. If the name is unused, returns the name. Otherwise,
80 // allocates a new name.
GetUniqued(StringAttr name)81 StringAttr GetUniqued(StringAttr name) {
82 auto it = unique_names_.insert(name);
83 if (it.second) return name;
84 unsigned suffix = 0;
85 StringAttr next_name;
86 do {
87 next_name =
88 StringAttr::get(ctx_, name.getValue() + "_" + Twine(suffix++));
89 it = unique_names_.insert(next_name);
90 } while (!it.second);
91 return next_name;
92 }
93
94 private:
95 // The MLIR context.
96 MLIRContext *ctx_;
97 // This set contains the occupied names.
98 DenseSet<StringAttr> unique_names_;
99 };
100
101 // Base class for patterns used to convert region control-flow ops to functional
102 // control-flow ops. This class contains common utility functions and cached
103 // attribute identifiers.
104 class BasePattern {
105 public:
BasePattern(TFGraphDialect & dialect,SymbolTable & table,bool force_control_capture,CachedIdentifiers ids)106 BasePattern(TFGraphDialect &dialect, SymbolTable &table,
107 bool force_control_capture, CachedIdentifiers ids)
108 : ctx_(dialect.getContext()),
109 dialect_(dialect),
110 table_(table),
111 force_control_capture_(force_control_capture),
112 ids_(ids) {}
113
114 protected:
115 // Collect all values used in the region that are defined above the region.
116 // If a control token is encountered, collect its associated data value. If it
117 // doesn't have one, add it to `ctls`.
118 void CollectValuesDefinedAbove(Region ®ion, SetVector<Value> &datas,
119 SetVector<Value> &ctls) const;
120 // Collect data values used in any of the given regions that are defined above
121 // the regions. These are the values that will be converted to explicit
122 // capture. If a control token with no associated data value is encountered
123 // and `force_control_capture_` is not set, then this function returns
124 // failure. Otherwise, it inserts chain constants and rewrites uses of the
125 // token to use the control outputs of the constants.
126 FailureOr<std::vector<Value>> CollectValuesDefinedAboveAll(
127 RegionRange regions, PatternRewriter &rewriter) const;
128 // Rewrite the regions to be isolated from above by replacing uses of the
129 // given data values with block arguments. Use the same set of values for each
130 // region so that their block arguments are the same.
131 void IsolateRegions(RegionRange regions, MutableArrayRef<Value> datas) const;
132 // Create a chain `Const` operation. The op's data result is unused; its
133 // only purpose is to convert a control edge into a data edge.
134 Operation *MakeChainConstant(Operation *parent, Value ctl, unsigned idx,
135 PatternRewriter &rewriter) const;
136
137 // Infer or propagate function attributes. Use a name uniquer to unique names
138 // across function arguments, results, and control results.
139 NamedAttrList BuildAttributes(RegionAttr preserved, ValueRange arguments,
140 ValueRange results,
141 NameUniquer *name_uniquer) const;
142
143 // Try to find a name for a data or control value. For op results, check the
144 // op for a name. Otherwise, check the enclosing function's arg attributes.
145 StringAttr TryFindName(Value value, Optional<ValueRange> args) const;
146
147 // Get the `control_ret_attrs` attributes for control returns. Use a name
148 // uniquer to unique names across function arguments, results, and control
149 // results,
150 ArrayAttr GetControlRetAttrs(ValueRange ctls, ValueRange args,
151 NameUniquer *name_uniquer) const;
152
153 // Create a function with the given name and attributes. Use the types of the
154 // block arguments and the given results types. Take the body of the region.
155 GraphFuncOp CreateFunc(Location loc, const Twine &sym_name, Region ®ion,
156 TypeRange res_types, NamedAttrList attrs) const;
157
158 // Convert a (yield-terminated) region to a function and return a reference.
159 FuncAttr Outline(Operation *op, PatternRewriter &rewriter, ValueRange args,
160 Region ®ion, RegionAttr preserved, DictionaryAttr attrs,
161 const Twine &func_name) const;
162
163 // A region function to outline.
164 struct RegionFunction {
165 // The function body.
166 Region ®ion;
167 // Potentially null preserved function attributes.
168 RegionAttr preserved_attrs;
169 // The function call attributes.
170 DictionaryAttr call_attrs;
171 // The function name to use.
172 std::string func_name;
173 };
174 // Outline a list of (yield-terminated) region functions, but if any function
175 // could not be re-used, then new functions are created for all of them.
176 template <typename FuncAttrT>
177 void ReuseAllOrOutline(Operation *op, PatternRewriter &rewriter,
178 ValueRange args, ArrayRef<RegionFunction> regions,
179 SmallVectorImpl<FuncAttrT> &functions) const;
180
181 // Try to find a "reusable" function that has the same body as the provided
182 // region. A function is "reusable" if its body has the same topology as the
183 // provided region, corresponding operands have the same attributes, except
184 // for node name, and value types are compatible.
185 FuncAttr FindReusableFunc(Region ®ion, RegionAttr preserved,
186 DictionaryAttr attrs) const;
187
188 // If a function exists and has nested regions, return false. Otherwise,
189 // return true.
190 bool FuncHasNestedRegions(RegionAttr preserved) const;
191
192 protected:
193 // Reference to the context.
194 MLIRContext *ctx_;
195 // Dialect reference for getting cached values;
196 TFGraphDialect &dialect_;
197 // Symbol table to use to look up existing functions.
198 SymbolTable &table_;
199 // Whether control tokens with no data values should be forcefully captured by
200 // inserting a chain `Const` op.
201 bool force_control_capture_;
202 // Cached attribute identifiers.
203 CachedIdentifiers ids_;
204 };
205
206 //===----------------------------------------------------------------------===//
207 // ConvertToExplicitCapture
208
209 template <typename OpT>
210 struct ConvertToExplicitCapture : public BasePattern {
211 using BasePattern::BasePattern;
212
213 virtual ~ConvertToExplicitCapture() = default;
214
215 // Convert the regions of the operation to explicit capture. Returns the
216 // newly captured values and an updated op.
217 FailureOr<std::pair<OpT, std::vector<Value>>> Run(OpT op,
218 PatternRewriter &rewriter);
219
220 // Rebuild the regions of the operation with added values.
221 virtual OpT RebuildWith(OpT op, ValueRange added,
222 PatternRewriter &rewriter) const = 0;
223 };
224
225 template <typename IfLikeRegionOp>
226 struct ConvertIfLikeRegionOpToExplicitCapture
227 : public ConvertToExplicitCapture<IfLikeRegionOp> {
228 using ConvertToExplicitCapture<IfLikeRegionOp>::ConvertToExplicitCapture;
229
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertIfLikeRegionOpToExplicitCapture230 IfLikeRegionOp RebuildWith(IfLikeRegionOp op, ValueRange added,
231 PatternRewriter &rewriter) const override {
232 return rewriter.create<IfLikeRegionOp>(
233 op.getLoc(), op.getResultTypes(), op.cond(), op.ctls(),
234 op.then_attrsAttr(), op.else_attrsAttr(), op.then_region_attrsAttr(),
235 op.else_region_attrsAttr());
236 }
237 };
238
239 template <typename CaseLikeRegionOp>
240 struct ConvertCaseLikeRegionOpToExplicitCapture
241 : public ConvertToExplicitCapture<CaseLikeRegionOp> {
242 using ConvertToExplicitCapture<CaseLikeRegionOp>::ConvertToExplicitCapture;
243
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertCaseLikeRegionOpToExplicitCapture244 CaseLikeRegionOp RebuildWith(CaseLikeRegionOp op, ValueRange added,
245 PatternRewriter &rewriter) const override {
246 return rewriter.create<CaseLikeRegionOp>(
247 op.getLoc(), op.getResultTypes(), op.branch_index(), op.ctls(),
248 op.branch_attrsAttr(), op.region_attrsAttr(), op.branches().size());
249 }
250 };
251
252 // Get the block arguments that correspond to the passthrough iteration
253 // arguments created from converting implicit captures. Append them to the
254 // previous region results `prev`.
GetForwardedValues(ValueRange added,Block::BlockArgListType block_args,ValueRange prev)255 static SmallVector<Value> GetForwardedValues(ValueRange added,
256 Block::BlockArgListType block_args,
257 ValueRange prev) {
258 SmallVector<Value> args(prev.begin(), prev.end());
259 llvm::append_range(args, block_args.slice(prev.size(), added.size()));
260 return args;
261 }
262
263 template <typename WhileLikeRegionOp>
264 struct ConvertWhileLikeRegionOpToExplicitCapture
265 : public ConvertToExplicitCapture<WhileLikeRegionOp> {
266 using ConvertToExplicitCapture<WhileLikeRegionOp>::ConvertToExplicitCapture;
267
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertWhileLikeRegionOpToExplicitCapture268 WhileLikeRegionOp RebuildWith(WhileLikeRegionOp op, ValueRange added,
269 PatternRewriter &rewriter) const override {
270 ConditionOp cond_op = op.cond_condition();
271 rewriter.setInsertionPoint(cond_op);
272 rewriter.replaceOpWithNewOp<ConditionOp>(
273 cond_op, cond_op.cond(),
274 GetForwardedValues(added, op.cond_region().getArguments(),
275 cond_op.args()),
276 cond_op.ctls());
277
278 YieldOp yield_op = op.body_yield();
279 rewriter.setInsertionPoint(yield_op);
280 rewriter.replaceOpWithNewOp<YieldOp>(
281 yield_op,
282 GetForwardedValues(added, op.body_region().getArguments(),
283 yield_op.args()),
284 yield_op.ctls());
285
286 SmallVector<Value> operands = llvm::to_vector(op.init());
287 llvm::append_range(operands, added);
288 SmallVector<Type> results = llvm::to_vector(op.outs().getTypes());
289 llvm::append_range(results, added.getTypes());
290 util::LoopRegionResultAdded(op.body_region(), added.size());
291
292 rewriter.setInsertionPoint(op);
293 return rewriter.create<WhileLikeRegionOp>(
294 op.getLoc(), results, op.ctl().getType(), operands, op.ctls(),
295 op.parallel_iterationsAttr(), op.cond_attrsAttr(), op.body_attrsAttr(),
296 op.cond_region_attrsAttr(), op.body_region_attrsAttr());
297 }
298 };
299
300 struct ConvertForRegionOpToExplicitCapture
301 : public ConvertToExplicitCapture<ForRegionOp> {
302 using ConvertToExplicitCapture<ForRegionOp>::ConvertToExplicitCapture;
303
RebuildWithmlir::tfg::__anond19cc84e0111::ConvertForRegionOpToExplicitCapture304 ForRegionOp RebuildWith(ForRegionOp op, ValueRange added,
305 PatternRewriter &rewriter) const override {
306 YieldOp yield_op = op.body_yield();
307 rewriter.setInsertionPoint(yield_op);
308 // Get the iteration arguments excluding the for loop index argument.
309 auto iter_args = GetLoopRegionDataArgs(op.body_region()).slice(1);
310 rewriter.replaceOpWithNewOp<YieldOp>(
311 yield_op, GetForwardedValues(added, iter_args, yield_op.args()),
312 yield_op.ctls());
313
314 SmallVector<Value> operands = llvm::to_vector(op.init());
315 llvm::append_range(operands, added);
316 SmallVector<Type> results = llvm::to_vector(op.outs().getTypes());
317 llvm::append_range(results, added.getTypes());
318 util::LoopRegionResultAdded(op.body_region(), added.size());
319
320 rewriter.setInsertionPoint(op);
321 return rewriter.create<ForRegionOp>(
322 op.getLoc(), results, op.ctl().getType(), op.start(), op.limit(),
323 op.delta(), operands, op.ctls(), op.body_attrsAttr(),
324 op.region_attrsAttr());
325 }
326 };
327
328 //===----------------------------------------------------------------------===//
329 // ConvertRegionToFunctional
330
331 template <typename SourceOp, typename DestOp>
332 struct ConvertRegionToFunctionalPattern : public OpRewritePattern<SourceOp>,
333 public BasePattern {
ConvertRegionToFunctionalPatternmlir::tfg::__anond19cc84e0111::ConvertRegionToFunctionalPattern334 explicit ConvertRegionToFunctionalPattern(MLIRContext *context,
335 TFGraphDialect &dialect,
336 SymbolTable &table,
337 bool force_control_capture,
338 CachedIdentifiers ids)
339 : OpRewritePattern<SourceOp>(context, /*benefit=*/1,
340 {DestOp::getOperationName()}),
341 BasePattern(dialect, table, force_control_capture, ids) {}
342 };
343
344 // Base class for patterns to convert an if-like TFG region op to
345 // functional form.
346 template <typename IfLikeRegionOp, typename IfLikeOp>
347 struct ConvertIfLikeOp
348 : public ConvertRegionToFunctionalPattern<IfLikeRegionOp, IfLikeOp> {
349 using ConvertRegionToFunctionalPattern<
350 IfLikeRegionOp, IfLikeOp>::ConvertRegionToFunctionalPattern;
351
352 LogicalResult matchAndRewrite(IfLikeRegionOp op,
353 PatternRewriter &rewriter) const override;
354 };
355
356 using ConvertIfOp = ConvertIfLikeOp<IfRegionOp, IfOp>;
357 using ConvertStatelessIfOp =
358 ConvertIfLikeOp<StatelessIfRegionOp, StatelessIfOp>;
359 using ConvertStatefulIfOp = ConvertIfLikeOp<StatefulIfRegionOp, StatefulIfOp>;
360
361 // Base class for patterns to convert an case-like TFG region op to
362 // functional form.
363 template <typename CaseLikeRegionOp, typename CaseLikeOp>
364 struct ConvertCaseLikeOp
365 : public ConvertRegionToFunctionalPattern<CaseLikeRegionOp, CaseLikeOp> {
366 using ConvertRegionToFunctionalPattern<
367 CaseLikeRegionOp, CaseLikeOp>::ConvertRegionToFunctionalPattern;
368
369 LogicalResult matchAndRewrite(CaseLikeRegionOp op,
370 PatternRewriter &rewriter) const override;
371 };
372
373 using ConvertCaseOp = ConvertCaseLikeOp<CaseRegionOp, CaseOp>;
374 using ConvertStatelessCaseOp =
375 ConvertCaseLikeOp<StatelessCaseRegionOp, StatelessCaseOp>;
376 using ConvertStatefulCaseOp =
377 ConvertCaseLikeOp<StatefulCaseRegionOp, StatefulCaseOp>;
378
379 // Base class for patterns to convert a while-like TFG region op to functional
380 // form.
381 template <typename WhileLikeRegionOp, typename WhileLikeOp>
382 struct ConvertWhileLikeOp
383 : public ConvertRegionToFunctionalPattern<WhileLikeRegionOp, WhileLikeOp> {
384 using ConvertRegionToFunctionalPattern<
385 WhileLikeRegionOp, WhileLikeOp>::ConvertRegionToFunctionalPattern;
386
387 LogicalResult matchAndRewrite(WhileLikeRegionOp op,
388 PatternRewriter &rewriter) const override;
389 };
390
391 using ConvertWhileOp = ConvertWhileLikeOp<WhileRegionOp, WhileOp>;
392 using ConvertStatelessWhileOp =
393 ConvertWhileLikeOp<StatelessWhileRegionOp, StatelessWhileOp>;
394 using ConvertStatefulWhileOp =
395 ConvertWhileLikeOp<StatefulWhileRegionOp, StatefulWhileOp>;
396
397 // Convert a region-based for-loop to a functional for-loop.
398 struct ConvertForOp
399 : public ConvertRegionToFunctionalPattern<ForRegionOp, ForOp> {
400 using ConvertRegionToFunctionalPattern<
401 ForRegionOp, ForOp>::ConvertRegionToFunctionalPattern;
402
403 LogicalResult matchAndRewrite(ForRegionOp op,
404 PatternRewriter &rewriter) const override;
405 };
406
407 } // namespace
408
409 //===----------------------------------------------------------------------===//
410 // Utility Functions
411 //===----------------------------------------------------------------------===//
412
CollectValuesDefinedAbove(Region & region,SetVector<Value> & datas,SetVector<Value> & ctls) const413 void BasePattern::CollectValuesDefinedAbove(Region ®ion,
414 SetVector<Value> &datas,
415 SetVector<Value> &ctls) const {
416 ControlType control_ty = dialect_.getControlType();
417 visitUsedValuesDefinedAbove(region, [&](OpOperand *operand) {
418 Value value = operand->get();
419 if (value.getType() != control_ty) {
420 datas.insert(value);
421 } else if (Optional<Value> data = LookupDataValue(value)) {
422 datas.insert(*data);
423 } else {
424 ctls.insert(value);
425 }
426 });
427 }
428
MakeChainConstant(Operation * parent,Value ctl,unsigned idx,PatternRewriter & rewriter) const429 Operation *BasePattern::MakeChainConstant(Operation *parent, Value ctl,
430 unsigned idx,
431 PatternRewriter &rewriter) const {
432 OperationName name("tfg.Const", ctl.getContext());
433 OperationState state(ctl.getLoc(), name);
434 IntegerType i32 = rewriter.getI32Type();
435 ShapedType tensor_type = RankedTensorType::get({}, i32);
436 state.addOperands(ctl);
437 state.addAttribute("value", DenseElementsAttr::get(tensor_type, 0));
438 state.addAttribute("dtype", TypeAttr::get(i32));
439 state.addTypes({tensor_type, ctl.getType()});
440
441 // Inherit `tfg.tpu_replicate`, `assigned_device`, and `device`.
442 for (StringAttr attr_name : {StringAttr::get(ctx_, "_tpu_replicate"),
443 dialect_.getAssignedDeviceAttrIdentifier(),
444 dialect_.getDeviceAttrIdentifier()}) {
445 if (Attribute attr = parent->getAttr(attr_name))
446 state.addAttribute(attr_name, attr);
447 }
448
449 // Inherit a name based on the parent name.
450 StringAttr name_id = dialect_.getNameAttrIdentifier();
451 if (StringAttr name = parent->getAttrOfType<StringAttr>(name_id)) {
452 auto const_name = rewriter.getStringAttr(
453 name.getValue() + "_mlir_const_capture_" + Twine(idx));
454 state.addAttribute(name_id, const_name);
455 }
456
457 return rewriter.create(state);
458 }
459
CollectValuesDefinedAboveAll(RegionRange regions,PatternRewriter & rewriter) const460 FailureOr<std::vector<Value>> BasePattern::CollectValuesDefinedAboveAll(
461 RegionRange regions, PatternRewriter &rewriter) const {
462 SetVector<Value> data_set, ctl_only;
463 for (Region ®ion : llvm::make_pointee_range(regions))
464 CollectValuesDefinedAbove(region, data_set, ctl_only);
465 std::vector<Value> datas = data_set.takeVector();
466
467 // If in any of the regions we found a use of a control token defined above
468 // the regions with no associated data value, then it cannot be converted to
469 // explicit capture unless we insert chain constants. If this option was not
470 // set, return failure because the region op cannot be converted.
471 if (!force_control_capture_ && !ctl_only.empty()) return failure();
472
473 Operation *parent = regions.front()->getParentOp();
474 for (auto &ctl : llvm::enumerate(ctl_only.takeVector())) {
475 Operation *const_op =
476 MakeChainConstant(parent, ctl.value(), ctl.index(), rewriter);
477 for (Region *region : regions)
478 replaceAllUsesInRegionWith(ctl.value(), const_op->getResult(1), *region);
479 datas.push_back(const_op->getResult(0));
480 }
481
482 return datas;
483 }
484
IsolateRegions(RegionRange regions,MutableArrayRef<Value> datas) const485 void BasePattern::IsolateRegions(RegionRange regions,
486 MutableArrayRef<Value> datas) const {
487 ValueControlRetRange ctls(datas);
488 Value data, ctl;
489 for (Region ®ion : llvm::make_pointee_range(regions)) {
490 for (auto it : llvm::zip(datas, ctls)) {
491 std::tie(data, ctl) = it;
492 util::LoopRegionArgumentUpdate result =
493 util::LoopRegionAddArgument(region, data.getType());
494 replaceAllUsesInRegionWith(data, result.data, region);
495 replaceAllUsesInRegionWith(ctl, result.ctl, region);
496 }
497 }
498 }
499
BuildAttributes(RegionAttr preserved,ValueRange arguments,ValueRange results,NameUniquer * name_uniquer) const500 NamedAttrList BasePattern::BuildAttributes(RegionAttr preserved,
501 ValueRange arguments,
502 ValueRange results,
503 NameUniquer *name_uniquer) const {
504 NamedAttrList attrs(preserved ? preserved.getAttrs() : DictionaryAttr());
505 // The original function name is preserved in the region attributes, but don't
506 // re-use it when creating a new function.
507 attrs.erase(SymbolTable::getSymbolAttrName());
508
509 SmallVector<Attribute> arg_attrs, res_attrs;
510 ArrayAttr preserved_arg_attrs =
511 preserved ? preserved.getArgAttrs() : ArrayAttr();
512 ArrayAttr preserved_res_attrs =
513 preserved ? preserved.getResAttrs() : ArrayAttr();
514
515 // For each argument and result, lookup a name and regenerate output shapes.
516 const auto build_attrs = [&](ArrayAttr attr, auto &it,
517 Optional<ValueRange> args) {
518 NamedAttrList attrs(attr ? attr[it.index()].template cast<DictionaryAttr>()
519 : DictionaryAttr());
520 // If no name was preserved, try to find one.
521 if (!attrs.get(ids_.tfg_name)) {
522 if (StringAttr name = TryFindName(it.value(), args))
523 attrs.set(ids_.tfg_name, name_uniquer->GetUniqued(name));
524 }
525 attrs.set(ids_.tfg_regenerate_output_shapes, UnitAttr::get(ctx_));
526 return attrs.getDictionary(ctx_);
527 };
528
529 for (auto &it : llvm::enumerate(arguments)) {
530 arg_attrs.append({build_attrs(preserved_arg_attrs, it, {}),
531 DictionaryAttr::get(ctx_, {})});
532 }
533 for (auto &it : llvm::enumerate(results))
534 res_attrs.push_back(build_attrs(preserved_res_attrs, it, arguments));
535
536 attrs.append(FunctionOpInterface::getArgDictAttrName(),
537 ArrayAttr::get(ctx_, arg_attrs));
538 attrs.append(FunctionOpInterface::getResultDictAttrName(),
539 ArrayAttr::get(ctx_, res_attrs));
540 return attrs;
541 }
542
TryFindName(Value value,Optional<ValueRange> args) const543 StringAttr BasePattern::TryFindName(Value value,
544 Optional<ValueRange> args) const {
545 // If this is an op result, return the op's name.
546 if (auto result = value.dyn_cast<OpResult>()) {
547 Operation *op = result.getOwner();
548 if (auto name =
549 op->getAttrOfType<StringAttr>(dialect_.getNameAttrIdentifier())) {
550 return StringAttr::get(ctx_, name.getValue() + "_tfg_result_" +
551 Twine(result.getResultNumber()));
552 }
553 return {};
554 }
555
556 auto arg = value.cast<BlockArgument>();
557 Operation *parent = arg.getOwner()->getParentOp();
558 auto iface = dyn_cast<ControlArgumentInterface>(parent);
559 if (!iface) return {};
560 // If we were given a control token, lookup a name using the data value.
561 if (arg.getType() == dialect_.getControlType())
562 arg = iface.getDataValueOf(arg);
563 // If the parent is a function, try to find a `tfg.name`.
564 if (auto func = dyn_cast<GraphFuncOp>(*iface))
565 return func.getArgAttrOfType<StringAttr>(arg.getArgNumber(), ids_.tfg_name);
566 // Otherwise, "see through" to the corresponding operand.
567 if (args) {
568 assert(arg.getArgNumber() < args->size());
569 return TryFindName((*args)[arg.getArgNumber()], {});
570 }
571 if (auto for_op = dyn_cast<ForRegionOp>(parent)) {
572 unsigned arg_idx = arg.getArgNumber();
573 if (arg_idx == 0) return TryFindName(for_op.start(), {});
574 return TryFindName(for_op.init()[arg_idx - 1], {});
575 }
576 auto branch = cast<RegionBranchOpInterface>(parent);
577 ValueRange inputs = branch.getSuccessorEntryOperands(
578 arg.getParentRegion()->getRegionNumber());
579 return TryFindName(inputs[arg.getArgNumber()], {});
580 }
581
GetControlRetAttrs(ValueRange ctls,ValueRange args,NameUniquer * name_uniquer) const582 ArrayAttr BasePattern::GetControlRetAttrs(ValueRange ctls, ValueRange args,
583 NameUniquer *name_uniquer) const {
584 SmallVector<Attribute> ctl_ret_attrs;
585 for (Value ctl : ctls) {
586 NamedAttrList ctl_attrs;
587 if (StringAttr name = TryFindName(ctl, args)) {
588 ctl_attrs.set(dialect_.getTfgNameAttrIdentifier(),
589 name_uniquer->GetUniqued(name));
590 }
591 ctl_ret_attrs.push_back(ctl_attrs.getDictionary(ctx_));
592 }
593 return ArrayAttr::get(ctx_, ctl_ret_attrs);
594 }
595
CreateFunc(Location loc,const Twine & sym_name,Region & region,TypeRange res_types,NamedAttrList attrs) const596 GraphFuncOp BasePattern::CreateFunc(Location loc, const Twine &sym_name,
597 Region ®ion, TypeRange res_types,
598 NamedAttrList attrs) const {
599 SmallVector<Type> arg_types;
600 for (BlockArgument operand : GetLoopRegionDataArgs(region))
601 arg_types.append({operand.getType(), dialect_.getControlType()});
602 auto func_type = FunctionType::get(ctx_, arg_types, res_types);
603 auto func = OpBuilder(ctx_).create<GraphFuncOp>(loc, sym_name, func_type,
604 /*generic=*/false);
605
606 attrs.append(func->getAttrs());
607 func->setAttrs(attrs.getDictionary(ctx_));
608
609 SmallVector<BlockArgument> args =
610 llvm::to_vector(GetLoopRegionDataArgs(region));
611 SmallVector<BlockArgument> ctls =
612 llvm::to_vector(GetLoopRegionControlTokens(region));
613 // TODO(jeffniu): Change GraphFuncOp to use the same argument order as region
614 // loop ops.
615 for (auto it : llvm::zip(args, ctls)) {
616 BlockArgument arg, ctl;
617 std::tie(arg, ctl) = it;
618 arg.replaceAllUsesWith(region.addArgument(arg.getType(), arg.getLoc()));
619 ctl.replaceAllUsesWith(region.addArgument(ctl.getType(), ctl.getLoc()));
620 }
621 llvm::BitVector indices(region.getNumArguments());
622 indices.set(0, args.size() * 2);
623 region.front().eraseArguments(indices);
624
625 func.body().takeBody(region);
626 return func;
627 }
628
629 // Check the region attributes for a preserved function name
630 // TODO(jeffniu): RegionAttr should have an optional parameter for the function
631 // name, since it is treated differently from the other attributes.
GetFunctionName(RegionAttr preserved)632 static StringAttr GetFunctionName(RegionAttr preserved) {
633 if (!preserved) return {};
634 return preserved.getAttrs().getAs<StringAttr>(
635 SymbolTable::getSymbolAttrName());
636 }
637
Outline(Operation * op,PatternRewriter & rewriter,ValueRange args,Region & region,RegionAttr preserved,DictionaryAttr attrs,const Twine & func_name) const638 FuncAttr BasePattern::Outline(Operation *op, PatternRewriter &rewriter,
639 ValueRange args, Region ®ion,
640 RegionAttr preserved, DictionaryAttr attrs,
641 const Twine &func_name) const {
642 // Create a name scope for the function.
643 NameUniquer name_uniquer(ctx_);
644
645 NamedAttrList func_attrs = BuildAttributes(
646 preserved, args, cast<YieldOp>(region.front().getTerminator()).args(),
647 &name_uniquer);
648
649 auto yield = cast<YieldOp>(region.front().getTerminator());
650 rewriter.setInsertionPoint(yield);
651 auto ret_op = rewriter.replaceOpWithNewOp<ReturnOp>(
652 yield, yield.getOperands(),
653 GetControlRetAttrs(yield.ctls(), args, &name_uniquer));
654
655 // Derive a function name. Use a default name. If a previous name exists,
656 // use it. If the op also has a name, derive a name based on that.
657 std::string new_func_name = func_name.str();
658 if (StringAttr existing_name = GetFunctionName(preserved)) {
659 new_func_name = existing_name.getValue().str();
660 if (auto op_name =
661 op->getAttrOfType<StringAttr>(dialect_.getNameAttrIdentifier())) {
662 llvm::raw_string_ostream os(new_func_name);
663 os << "_tfg_region_specialized_";
664 for (char c : llvm::map_range(
665 op_name.getValue(), [](char c) { return isalnum(c) ? c : '_'; }))
666 os << c;
667 os << '_' << llvm::to_string(region.getRegionNumber());
668 os.flush();
669 }
670 }
671
672 // Create the function.
673 GraphFuncOp func = CreateFunc(op->getLoc(), new_func_name, region,
674 TFOp(ret_op).getNonControlOperands().getTypes(),
675 std::move(func_attrs));
676 return FuncAttr::get(ctx_, table_.insert(func),
677 attrs ? attrs : DictionaryAttr::get(ctx_, {}));
678 }
679
680 template <typename FuncAttrT>
ReuseAllOrOutline(Operation * op,PatternRewriter & rewriter,ValueRange args,ArrayRef<RegionFunction> regions,SmallVectorImpl<FuncAttrT> & functions) const681 void BasePattern::ReuseAllOrOutline(
682 Operation *op, PatternRewriter &rewriter, ValueRange args,
683 ArrayRef<RegionFunction> regions,
684 SmallVectorImpl<FuncAttrT> &functions) const {
685 // Try to find reusable functions for all regions.
686 const auto get_reusable_func = [this,
687 &functions](const RegionFunction &func) {
688 FuncAttr ref =
689 FindReusableFunc(func.region, func.preserved_attrs, func.call_attrs);
690 functions.push_back(ref);
691 return ref;
692 };
693 if (llvm::all_of(regions, get_reusable_func)) return;
694
695 // At least one region needs to be outlined.
696 functions.clear();
697 for (const RegionFunction &func : regions) {
698 functions.push_back(Outline(op, rewriter, args, func.region,
699 func.preserved_attrs, func.call_attrs,
700 func.func_name));
701 }
702 }
703
704 // Returns true if the region has any nested regions.
HasNestedRegions(Region & region)705 static bool HasNestedRegions(Region ®ion) {
706 return llvm::any_of(region.getOps(),
707 [](Operation &op) { return op.getNumRegions(); });
708 }
709
710 // Check if the region is "equivalent" to the body of the given function, and so
711 // the function can be re-used when outlining the region. This compares
712 // (topologically) the arguments, results, and ops, ignoring the op names and
713 // checking for compatible types.
RegionEqualTo(Region & region,GraphFuncOp func)714 static bool RegionEqualTo(Region ®ion, GraphFuncOp func) {
715 assert(!HasNestedRegions(region));
716 assert(!HasNestedRegions(func.body()));
717
718 // Outlining is performed "bottom-up". I.e. regions with no nested regions are
719 // outlined first, which means that we will not have to worry about comparing
720 // `While` to `WhileRegion`. Also, it means that we can directly compare the
721 // operations.
722 DenseMap<Value, Value> value_map;
723 auto map_value = [&](Value lhs, Value rhs) {
724 if (!tf_type::HasCompatibleElementTypes(lhs.getType(), rhs.getType()))
725 return false;
726 return value_map.insert({lhs, rhs}).first->second == rhs;
727 };
728
729 // Compare the non-control block arguments.
730 if (region.getNumArguments() != func.getNumArguments()) return false;
731 for (auto &it : llvm::enumerate(GetLoopRegionDataArgs(region))) {
732 Value rhs = GraphFuncOp::getDataValue(func.body(), it.index());
733 if (!map_value(it.value(), rhs)) return false;
734 }
735
736 // Compare the bodies except the terminators. We can't use
737 // OperationEquivalence due to relaxed type equality.
738 auto map_value_range = [](ValueRange lhs_range, ValueRange rhs_range,
739 auto map_value) {
740 if (lhs_range.size() != rhs_range.size()) return false;
741 for (auto it : llvm::zip(lhs_range, rhs_range))
742 if (!map_value(std::get<0>(it), std::get<1>(it))) return false;
743 return true;
744 };
745
746 StringAttr name_id =
747 cast<TFGraphDialect>(func->getDialect())->getNameAttrIdentifier();
748
749 auto compare_ops = [&](Operation &lhs, Operation &rhs) {
750 if (lhs.getName() != rhs.getName()) return false;
751
752 DictionaryAttr lhs_attrs = lhs.getAttrDictionary();
753 DictionaryAttr rhs_attrs = rhs.getAttrDictionary();
754 if (lhs_attrs.size() != rhs_attrs.size()) return false;
755 for (auto it : llvm::zip(lhs_attrs, rhs_attrs)) {
756 NamedAttribute lhs_attr = std::get<0>(it);
757 NamedAttribute rhs_attr = std::get<1>(it);
758 if (lhs_attr.getName() != rhs_attr.getName()) return false;
759 if (lhs_attr.getName() == name_id) continue;
760 if (lhs_attr.getValue() != rhs_attr.getValue()) return false;
761 }
762 if (!map_value_range(lhs.getOperands(), rhs.getOperands(), map_value))
763 return false;
764 if (!map_value_range(lhs.getResults(), rhs.getResults(), map_value))
765 return false;
766 assert(!lhs.getNumRegions() && !rhs.getNumRegions());
767 return true;
768 };
769 if (!llvm::all_of_zip(region.front().without_terminator(),
770 func.body().front().without_terminator(), compare_ops))
771 return false;
772
773 // Compare just the operands of the terminators.
774 auto return_op = cast<ReturnOp>(func.body().front().getTerminator());
775 Operation *terminator = region.front().getTerminator();
776 if (auto yield = dyn_cast<YieldOp>(terminator)) {
777 return map_value_range(yield->getOperands(), return_op->getOperands(),
778 map_value);
779 } else {
780 auto cond = cast<ConditionOp>(terminator);
781 return map_value(cond.cond(), return_op->getOperand(0)) &&
782 map_value_range(
783 cond.ctls(),
784 return_op->getOperands().slice(1, cond.ctls().size()),
785 map_value);
786 }
787 }
788
FindReusableFunc(Region & region,RegionAttr preserved,DictionaryAttr attrs) const789 FuncAttr BasePattern::FindReusableFunc(Region ®ion, RegionAttr preserved,
790 DictionaryAttr attrs) const {
791 StringAttr name = GetFunctionName(preserved);
792 if (!name) return {};
793 auto func = table_.lookup<GraphFuncOp>(name);
794 if (!func) return {};
795 if (!RegionEqualTo(region, func)) return {};
796 return FuncAttr::get(region.getContext(), name.getValue(),
797 attrs ? attrs : DictionaryAttr::get(ctx_, {}));
798 }
799
FuncHasNestedRegions(RegionAttr preserved) const800 bool BasePattern::FuncHasNestedRegions(RegionAttr preserved) const {
801 StringAttr name = GetFunctionName(preserved);
802 if (!name) return false;
803 auto func = table_.lookup<GraphFuncOp>(name);
804 return func && HasNestedRegions(func.body());
805 }
806
807 //===----------------------------------------------------------------------===//
808 // ConvertToExplicitCapture
809 //===----------------------------------------------------------------------===//
810
811 template <typename OpT>
812 FailureOr<std::pair<OpT, std::vector<Value>>>
Run(OpT op,PatternRewriter & rewriter)813 ConvertToExplicitCapture<OpT>::Run(OpT op, PatternRewriter &rewriter) {
814 FailureOr<std::vector<Value>> operands =
815 this->CollectValuesDefinedAboveAll(op->getRegions(), rewriter);
816 if (failed(operands)) return failure();
817 this->IsolateRegions(op->getRegions(), *operands);
818 OpT new_op = RebuildWith(op, *operands, rewriter);
819 util::ForwardNonIntrinsicAttributes(op, new_op);
820 for (auto it : llvm::zip(op->getRegions(), new_op->getRegions()))
821 std::get<1>(it).takeBody(std::get<0>(it));
822 rewriter.replaceOp(op, new_op->getResults().slice(0, op->getNumResults()));
823 return std::make_pair(new_op, std::move(*operands));
824 }
825
826 //===----------------------------------------------------------------------===//
827 // ConvertIfLikeOp
828 //===----------------------------------------------------------------------===//
829
830 template <typename IfLikeRegionOp, typename IfLikeOp>
matchAndRewrite(IfLikeRegionOp op,PatternRewriter & rewriter) const831 LogicalResult ConvertIfLikeOp<IfLikeRegionOp, IfLikeOp>::matchAndRewrite(
832 IfLikeRegionOp op, PatternRewriter &rewriter) const {
833 if (HasNestedRegions(op.then_region()) || HasNestedRegions(op.else_region()))
834 return failure();
835 if (this->FuncHasNestedRegions(op.then_region_attrsAttr()) ||
836 this->FuncHasNestedRegions(op.else_region_attrsAttr()))
837 return failure();
838
839 // Convert the op to explicit capture.
840 ConvertIfLikeRegionOpToExplicitCapture<IfLikeRegionOp> converter(
841 this->dialect_, this->table_, this->force_control_capture_, this->ids_);
842 auto result = converter.Run(op, rewriter);
843 if (failed(result)) return failure();
844 std::vector<Value> args;
845 std::tie(op, args) = std::move(*result);
846
847 // Outline the regions.
848 SmallVector<FuncAttr, 2> branches;
849 this->ReuseAllOrOutline(op, rewriter, args,
850 {{op.then_region(), op.then_region_attrsAttr(),
851 op.then_attrsAttr(), "if_then_function"},
852 {op.else_region(), op.else_region_attrsAttr(),
853 op.else_attrsAttr(), "if_else_function"}},
854 branches);
855
856 // Build the functional if-like op.
857 SmallVector<Value> operands = llvm::to_vector(args);
858 llvm::append_range(operands, op.ctls());
859
860 rewriter.setInsertionPoint(op);
861 auto func_op =
862 rewriter.create<IfLikeOp>(op.getLoc(), op.getResultTypes(), op.cond(),
863 operands, branches[0], branches[1]);
864 util::ForwardNonIntrinsicAttributes(op, func_op);
865 rewriter.replaceOp(op, func_op.getResults());
866 return success();
867 }
868
869 //===----------------------------------------------------------------------===//
870 // ConvertCaseLikeOp
871 //===----------------------------------------------------------------------===//
872
873 template <typename CaseLikeRegionOp, typename CaseLikeOp>
matchAndRewrite(CaseLikeRegionOp op,PatternRewriter & rewriter) const874 LogicalResult ConvertCaseLikeOp<CaseLikeRegionOp, CaseLikeOp>::matchAndRewrite(
875 CaseLikeRegionOp op, PatternRewriter &rewriter) const {
876 if (llvm::any_of(op.branches(), HasNestedRegions)) return failure();
877 if (ArrayAttr preserved = op.region_attrsAttr()) {
878 if (llvm::any_of(preserved.getAsRange<RegionAttr>(), [&](auto preserved) {
879 return this->FuncHasNestedRegions(preserved);
880 }))
881 return failure();
882 }
883
884 // Convert the op to explicit capture.
885 ConvertCaseLikeRegionOpToExplicitCapture<CaseLikeRegionOp> converter(
886 this->dialect_, this->table_, this->force_control_capture_, this->ids_);
887 auto result = converter.Run(op, rewriter);
888 if (failed(result)) return failure();
889 std::vector<Value> args;
890 std::tie(op, args) = std::move(*result);
891
892 // Outline the regions.
893 ArrayAttr branch_func_attrs = op.branch_attrsAttr();
894 SmallVector<BasePattern::RegionFunction> branch_regions;
895 for (auto &it : llvm::enumerate(op.branches())) {
896 unsigned idx = it.index();
897 // Get the preserved attributes, if there are any.
898 RegionAttr preserved =
899 op.region_attrs()
900 ? op.region_attrsAttr()[idx].template cast<RegionAttr>()
901 : nullptr;
902 DictionaryAttr attrs =
903 branch_func_attrs
904 ? branch_func_attrs[idx].template cast<DictionaryAttr>()
905 : nullptr;
906 branch_regions.push_back(BasePattern::RegionFunction{
907 it.value(), preserved, attrs, ("case_function_" + Twine(idx)).str()});
908 }
909 SmallVector<Attribute> branches;
910 this->ReuseAllOrOutline(op, rewriter, args, branch_regions, branches);
911
912 // Build the functional case-like op.
913 SmallVector<Value> operands = llvm::to_vector(args);
914 llvm::append_range(operands, op.ctls());
915
916 rewriter.setInsertionPoint(op);
917 auto func_op = rewriter.create<CaseLikeOp>(op.getLoc(), op.getResultTypes(),
918 op.branch_index(), operands,
919 rewriter.getArrayAttr(branches));
920 util::ForwardNonIntrinsicAttributes(op, func_op);
921 rewriter.replaceOp(op, func_op.getResults());
922 return success();
923 }
924
925 //===----------------------------------------------------------------------===//
926 // ConvertWhileLikeOp
927 //===----------------------------------------------------------------------===//
928
929 template <typename WhileLikeRegionOp, typename WhileLikeOp>
930 LogicalResult
matchAndRewrite(WhileLikeRegionOp op,PatternRewriter & rewriter) const931 ConvertWhileLikeOp<WhileLikeRegionOp, WhileLikeOp>::matchAndRewrite(
932 WhileLikeRegionOp op, PatternRewriter &rewriter) const {
933 if (HasNestedRegions(op.cond_region()) || HasNestedRegions(op.body_region()))
934 return failure();
935 if (this->FuncHasNestedRegions(op.cond_region_attrsAttr()) ||
936 this->FuncHasNestedRegions(op.body_region_attrsAttr()))
937 return failure();
938
939 // Convert the op to explicit capture.
940 ConvertWhileLikeRegionOpToExplicitCapture<WhileLikeRegionOp> converter(
941 this->dialect_, this->table_, this->force_control_capture_, this->ids_);
942 auto result = converter.Run(op, rewriter);
943 if (failed(result)) return failure();
944 op = result->first;
945
946 // Try to find re-usable functions for both the condition and body regions.
947 FuncAttr body_ref = this->FindReusableFunc(
948 op.body_region(), op.body_region_attrsAttr(), op.body_attrsAttr());
949 FuncAttr cond_ref = this->FindReusableFunc(
950 op.cond_region(), op.cond_region_attrsAttr(), op.cond_attrsAttr());
951
952 // If a function for either region could not be re-used, outline them out.
953 if (!body_ref || !cond_ref) {
954 // Handle the condition region. Unlike other regions, the terminator is
955 // special and the function only has one result.
956 ConditionOp cond_op = op.cond_condition();
957 // Create a name scope for the condition function.
958 NameUniquer name_uniquer(this->ctx_);
959 // Create the function.
960 NamedAttrList cond_attrs = this->BuildAttributes(
961 op.cond_region_attrsAttr(), op.init(), cond_op.cond(), &name_uniquer);
962 GraphFuncOp cond_func =
963 this->CreateFunc(op.getLoc(), "while_cond_function", op.cond_region(),
964 cond_op.cond().getType(), std::move(cond_attrs));
965 // Replace the condition terminator.
966 rewriter.setInsertionPoint(cond_op);
967 SmallVector<Value> cond_rets = {cond_op.cond()};
968 llvm::append_range(cond_rets, cond_op.ctls());
969 rewriter.replaceOpWithNewOp<ReturnOp>(
970 cond_op, cond_rets,
971 this->GetControlRetAttrs(cond_op.ctls(), op.init(), &name_uniquer));
972 // Insert the function and grab a reference.
973 cond_ref =
974 FuncAttr::get(op.getContext(), this->table_.insert(cond_func),
975 op.cond_attrs().value_or(rewriter.getDictionaryAttr({})));
976
977 // Outline the body.
978 body_ref = this->Outline(op, rewriter, op.init(), op.body_region(),
979 op.body_region_attrsAttr(), op.body_attrsAttr(),
980 "while_body_function");
981 }
982
983 // Create the functional op.
984 SmallVector<Value> operands = llvm::to_vector(op.init());
985 llvm::append_range(operands, op.ctls());
986
987 rewriter.setInsertionPoint(op);
988 auto func_op = rewriter.create<WhileLikeOp>(op.getLoc(), op.getResultTypes(),
989 operands, cond_ref, body_ref,
990 op.parallel_iterationsAttr());
991 util::ForwardNonIntrinsicAttributes(op, func_op);
992 rewriter.replaceOp(op, func_op.getResults());
993 return success();
994 }
995
996 //===----------------------------------------------------------------------===//
997 // ConvertForOp
998 //===----------------------------------------------------------------------===//
999
matchAndRewrite(ForRegionOp op,PatternRewriter & rewriter) const1000 LogicalResult ConvertForOp::matchAndRewrite(ForRegionOp op,
1001 PatternRewriter &rewriter) const {
1002 if (HasNestedRegions(op.body_region())) return failure();
1003 if (this->FuncHasNestedRegions(op.region_attrsAttr())) return failure();
1004
1005 // Convert the op to explicit capture.
1006 ConvertForRegionOpToExplicitCapture converter(dialect_, table_,
1007 force_control_capture_, ids_);
1008 auto result = converter.Run(op, rewriter);
1009 if (failed(result)) return failure();
1010 op = result->first;
1011
1012 // Outline to body.
1013 SmallVector<Value> func_args(/*Size=*/1, op.start());
1014 llvm::append_range(func_args, op.init());
1015 SmallVector<FuncAttr, 1> body_ref;
1016 ReuseAllOrOutline(op, rewriter, func_args,
1017 {{op.body_region(), op.region_attrsAttr(),
1018 op.body_attrsAttr(), "for_body_function"}},
1019 body_ref);
1020
1021 // Create the functional op.
1022 SmallVector<Value> operands = llvm::to_vector(op.init());
1023 llvm::append_range(operands, op.ctls());
1024
1025 rewriter.setInsertionPoint(op);
1026 auto func_op = rewriter.create<tfg::ForOp>(op.getLoc(), op.getResultTypes(),
1027 op.start(), op.limit(), op.delta(),
1028 operands, body_ref[0]);
1029 util::ForwardNonIntrinsicAttributes(op, func_op);
1030 rewriter.replaceOp(op, func_op.getResults());
1031 return success();
1032 }
1033
1034 //===----------------------------------------------------------------------===//
1035 // Pattern Population
1036 //===----------------------------------------------------------------------===//
1037
PopulateRegionToFunctionalPatterns(RewritePatternSet & patterns,SymbolTable & table,bool force_control_capture)1038 void PopulateRegionToFunctionalPatterns(RewritePatternSet &patterns,
1039 SymbolTable &table,
1040 bool force_control_capture) {
1041 auto *dialect = patterns.getContext()->getOrLoadDialect<TFGraphDialect>();
1042 patterns.insert<ConvertIfOp, ConvertStatelessIfOp, ConvertStatefulIfOp,
1043 ConvertCaseOp, ConvertStatelessCaseOp, ConvertStatefulCaseOp,
1044 ConvertWhileOp, ConvertStatelessWhileOp,
1045 ConvertStatefulWhileOp, ConvertForOp>(
1046 patterns.getContext(), *dialect, table, force_control_capture,
1047 CachedIdentifiers(dialect));
1048 }
1049
1050 } // namespace tfg
1051 } // namespace mlir
1052