xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/consolidate_attrs/pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/consolidate_attrs/pass.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassManager.h"  // from @llvm-project
28 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
30 #include "tensorflow/core/ir/dialect.h"
31 #include "tensorflow/core/ir/ops.h"
32 #include "tensorflow/core/ir/tf_op_wrapper.h"
33 #include "tensorflow/core/ir/types/dialect.h"
34 #include "tensorflow/core/ir/utility.h"
35 #include "tensorflow/core/transforms/pass_detail.h"
36 
37 namespace mlir {
38 namespace tfg {
39 
// Name of the unit attribute that marks an op (or argument/result) whose
// `_output_shapes` must be regenerated on export. A `constexpr` array keeps
// the name immutable; the previous `const char *` pointer could be reseated.
static constexpr char kRegenerateOutputShapes[] =
    "tfg.regenerate_output_shapes";
41 
42 // Returns true if an attribute is an array of shapes;
IsArrayOfShapes(ArrayAttr array)43 static bool IsArrayOfShapes(ArrayAttr array) {
44   return llvm::all_of(array,
45                       [](Attribute attr) { return attr.isa<ShapeAttr>(); });
46 }
47 
48 // Given a tensor type and shape information, try to refine the type.
GetReifiedType(Type orig,ShapeAttr shape)49 static Type GetReifiedType(Type orig, ShapeAttr shape) {
50   Type element_type = orig.cast<ShapedType>().getElementType();
51   TensorType inferred;
52   if (shape.hasRank()) {
53     // Replace dimensions less than -1 with ?
54     SmallVector<int64_t> dims = llvm::to_vector(shape.getShape());
55     for (int64_t &dim : dims)
56       if (dim < -1) dim = -1;
57     inferred = RankedTensorType::get(dims, element_type);
58   } else {
59     inferred = UnrankedTensorType::get(element_type);
60   }
61   Type reified_type = tf_type::GetCastCompatibleType(inferred, orig);
62   // If the types are not compatible, return the original type.
63   return reified_type ? reified_type : orig;
64 }
65 
66 namespace {
67 // CRTP base class for consolidate attribute passes. This base class defines
68 // cached identifiers for the attributes.
69 template <typename PassT>
70 class AttributesPassBase : public PassWrapper<PassT, OperationPass<>> {
71  public:
initialize(MLIRContext * context)72   LogicalResult initialize(MLIRContext *context) override {
73     input_shapes_id_ = StringAttr::get(context, "tf._input_shapes");
74     regenerate_input_shapes_id_ =
75         StringAttr::get(context, "tfg.regenerate_input_shapes");
76     output_shapes_id_ = StringAttr::get(context, "tf._output_shapes");
77     regenerate_output_shapes_id_ =
78         StringAttr::get(context, "tfg.regenerate_output_shapes");
79     handle_data_id_ = StringAttr::get(context, "tfg.handle_data");
80     dtype_id_ = StringAttr::get(context, "tfg.dtype");
81     is_ref_id_ = StringAttr::get(context, "tfg.is_ref");
82     control_type_ = ControlType::get(context);
83     return success();
84   }
85 
86  protected:
87   // Identifier for `tf._input_shapes`.
88   StringAttr input_shapes_id_;
89   // Identifier for `tf._regenerate_input_shapes`.
90   StringAttr regenerate_input_shapes_id_;
91   // Identifier for `tf._output_shapes`.
92   StringAttr output_shapes_id_;
93   // Identifier for `tf._regenerate_output_shapes`.
94   StringAttr regenerate_output_shapes_id_;
95   // Identifier for `tfg.handle_data`.
96   StringAttr handle_data_id_;
97   // Identifier for `tfg.dtype`.
98   StringAttr dtype_id_;
99   // Identifier for `tfg.is_ref`.
100   StringAttr is_ref_id_;
101   // Cacched control type.
102   ControlType control_type_;
103 };
104 
105 class ConsolidateAttributesPassImpl
106     : public AttributesPassBase<ConsolidateAttributesPassImpl> {
107  public:
108   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConsolidateAttributesPassImpl)
109 
110   void runOnOperation() override;
111 
112  private:
113   // Reify `tf._input_shapes`, `tf._output_shapes` and `tfg.handle_data` into
114   // the types of the function arguments. Drop the attributes `tfg.dtype` and
115   // `tfg.is_ref`. Return the new argument attributes.
116   ArrayAttr reifyAndDropFunctionArgumentAttributes(GraphFuncOp func);
117   // Reify `tf._output_shapes` and `tfg.handle_data` into the types of the
118   // function results. Drop the attribute `tfg.dtype`. Return the new result
119   // attributes.
120   ArrayAttr reifyAndDropFunctionResultAttributes(GraphFuncOp func);
121 
122   // Refine a type with `tf._output_shapes`.
123   Type refineTypeWithOutputShapes(Type type, NamedAttrList &attrs);
124   // Refine a type with `tfg.handle_data`.
125   Type refineTypeWithHandleData(Type type, Attribute handle_data);
126 };
127 }  // namespace
128 
refineTypeWithOutputShapes(Type type,NamedAttrList & attrs)129 Type ConsolidateAttributesPassImpl::refineTypeWithOutputShapes(
130     Type type, NamedAttrList &attrs) {
131   // Get the output shapes attribute. If the attribute is not an array of
132   // exactly one shape, ignore it.
133   if (auto output_shapes =
134           attrs.get(output_shapes_id_).dyn_cast_or_null<ArrayAttr>()) {
135     if (output_shapes.size() == 1 && IsArrayOfShapes(output_shapes)) {
136       attrs.erase(output_shapes_id_);
137       attrs.set(regenerate_output_shapes_id_, UnitAttr::get(&getContext()));
138       return GetReifiedType(type, output_shapes[0].cast<ShapeAttr>());
139     }
140   }
141   return type;
142 }
143 
refineTypeWithHandleData(Type type,Attribute handle_data)144 Type ConsolidateAttributesPassImpl::refineTypeWithHandleData(
145     Type type, Attribute handle_data) {
146   if (!handle_data) return type;
147   SmallVector<TensorType> subtypes;
148   // Because `tfg.handle_data` is a TFG internal attribute, it will be
149   // well-formed.
150   for (Type type : handle_data.cast<ArrayAttr>().getAsValueRange<TypeAttr>())
151     subtypes.push_back(type.cast<TensorType>());
152   auto resource =
153       UnrankedTensorType::get(ResourceType::get(subtypes, &getContext()));
154   Type reified = tf_type::GetCastCompatibleType(resource, type);
155   return reified ? reified : type;
156 }
157 
reifyAndDropFunctionArgumentAttributes(GraphFuncOp func)158 ArrayAttr ConsolidateAttributesPassImpl::reifyAndDropFunctionArgumentAttributes(
159     GraphFuncOp func) {
160   // Get the input shapes attribute. If it is a UnitAttr, then it is empty and
161   // we will ignore it. If it isn't an array of shapes or has an inconsistent
162   // number of shapes, ignore it.
163   ArrayAttr input_shapes =
164       func->getAttr(input_shapes_id_).dyn_cast_or_null<ArrayAttr>();
165   unsigned num_args = func.getNumArguments() / 2;
166   if (input_shapes) {
167     if (input_shapes.size() != num_args || !IsArrayOfShapes(input_shapes)) {
168       input_shapes = {};
169     } else {
170       func->removeAttr(input_shapes_id_);
171       func->setAttr(regenerate_input_shapes_id_, UnitAttr::get(&getContext()));
172     }
173   }
174 
175   SmallVector<Attribute> arg_attrs;
176   auto empty_dict = DictionaryAttr::get(&getContext());
177   for (auto i : llvm::seq<unsigned>(0, num_args)) {
178     BlockArgument arg = GraphFuncOp::getDataValue(func.body(), i);
179     NamedAttrList attrs(func.getArgAttrs(arg.getArgNumber()));
180     Type arg_type = arg.getType();
181     arg_type = refineTypeWithOutputShapes(arg_type, attrs);
182     arg_type = refineTypeWithHandleData(arg_type, attrs.erase(handle_data_id_));
183     if (input_shapes)
184       arg_type = GetReifiedType(arg_type, input_shapes[i].cast<ShapeAttr>());
185     arg.setType(arg_type);
186     attrs.erase(dtype_id_);
187     attrs.erase(is_ref_id_);
188     arg_attrs.append({attrs.getDictionary(&getContext()), empty_dict});
189   }
190   return ArrayAttr::get(&getContext(), arg_attrs);
191 }
192 
reifyAndDropFunctionResultAttributes(GraphFuncOp func)193 ArrayAttr ConsolidateAttributesPassImpl::reifyAndDropFunctionResultAttributes(
194     GraphFuncOp func) {
195   ArrayAttr res_attrs = func.getAllResultAttrs();
196   if (!res_attrs) return ArrayAttr::get(&getContext(), {});
197 
198   SmallVector<Attribute> ret_attrs;
199   // The result types are propagated to the data operands to `return`.
200   auto ret_op = cast<ReturnOp>(func.body().front().getTerminator());
201   for (auto &it : llvm::enumerate(res_attrs.getAsRange<DictionaryAttr>())) {
202     NamedAttrList attrs(it.value());
203     Value ret = ret_op.getOperand(it.index());
204     Type ret_type = ret.getType();
205     ret_type = refineTypeWithOutputShapes(ret_type, attrs);
206     ret_type = refineTypeWithHandleData(ret_type, attrs.erase(handle_data_id_));
207     ret.setType(ret_type);
208     attrs.erase(dtype_id_);
209     ret_attrs.push_back(attrs.getDictionary(&getContext()));
210   }
211   return ArrayAttr::get(&getContext(), ret_attrs);
212 }
213 
214 namespace {
215 // This pattern reifies an op's result shape info into the result types and
216 // drops the output shapes attributes.
217 class ReifyOperationOutputShapes : public RewritePattern {
218  public:
ReifyOperationOutputShapes(MLIRContext * context,PatternBenefit benefit,StringRef attr_name)219   ReifyOperationOutputShapes(MLIRContext *context, PatternBenefit benefit,
220                              StringRef attr_name)
221       : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, context),
222         output_shapes_id_(StringAttr::get(context, attr_name)) {}
223 
224   // Returns true if this instance of the pattern should match the op.
225   virtual bool shouldMatch(Operation *op) const = 0;
226 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const227   LogicalResult matchAndRewrite(Operation *op,
228                                 PatternRewriter &rewriter) const override {
229     if (!shouldMatch(op)) return failure();
230 
231     ResultRange results = TFOp(op).getNonControlResults();
232 
233     // Get the output shapes attribute. Ignore it if it is not an array
234     // attribute, if it has an inconsistent number of shapes, or if it is not
235     // an array of shapes.
236     ArrayAttr output_shapes =
237         op->getAttr(output_shapes_id_).dyn_cast_or_null<ArrayAttr>();
238     if (!output_shapes || results.size() != output_shapes.size() ||
239         !IsArrayOfShapes(output_shapes))
240       return failure();
241 
242     rewriter.updateRootInPlace(op, [&] {
243       op->removeAttr(output_shapes_id_);
244       assert(output_shapes.size() == results.size());
245       for (auto it :
246            llvm::zip(results, output_shapes.getAsRange<ShapeAttr>())) {
247         Value result = std::get<0>(it);
248         result.setType(GetReifiedType(result.getType(), std::get<1>(it)));
249       }
250       rewriteImpl(op, rewriter);
251     });
252     return success();
253   }
254 
rewriteImpl(Operation * op,PatternRewriter & rewriter) const255   virtual void rewriteImpl(Operation *op, PatternRewriter &rewriter) const {}
256 
257  private:
258   // Identifier for `_output_shapes`.
259   StringAttr output_shapes_id_;
260 };
261 
262 // This pattern matches and TFG op and reifies `_output_shapes`. The pattern
263 // leaves behind an attribute `_regenerate_output_shapes` that is used by the
264 // converse pattern to detect whether the attribute should be materialized.
265 class ReifyTFGOpOutputShapes : public ReifyOperationOutputShapes {
266  public:
ReifyTFGOpOutputShapes(MLIRContext * context)267   explicit ReifyTFGOpOutputShapes(MLIRContext *context)
268       : ReifyOperationOutputShapes(context, /*benefit=*/1, "_output_shapes"),
269         dialect_(context->getOrLoadDialect<TFGraphDialect>()),
270         regenerate_output_shapes_id_(
271             StringAttr::get(context, kRegenerateOutputShapes)) {}
272 
shouldMatch(Operation * op) const273   bool shouldMatch(Operation *op) const override {
274     return op->getDialect() == dialect_ && op->getNumResults();
275   }
276 
rewriteImpl(Operation * op,PatternRewriter & rewriter) const277   void rewriteImpl(Operation *op, PatternRewriter &rewriter) const override {
278     op->setAttr(regenerate_output_shapes_id_, rewriter.getUnitAttr());
279   }
280 
281  private:
282   // Cached TFG dialect instance.
283   TFGraphDialect *dialect_;
284   // Identifier to `_regenerate_output_shapes`.
285   StringAttr regenerate_output_shapes_id_;
286 };
287 
288 // This pattern matches `If`, `Case`, and `While` and reifies their
289 // `output_shapes` attribute.
290 struct ReifyCFOpOutputShapes : public ReifyOperationOutputShapes {
291   // Set a higher benefit to ensure that "output_shapes" is reified before
292   // "_output_shapes".
ReifyCFOpOutputShapesmlir::tfg::__anon17f873fd0311::ReifyCFOpOutputShapes293   explicit ReifyCFOpOutputShapes(MLIRContext *context)
294       : ReifyOperationOutputShapes(context, /*benefit=*/2, "output_shapes") {}
295 
shouldMatchmlir::tfg::__anon17f873fd0311::ReifyCFOpOutputShapes296   bool shouldMatch(Operation *op) const override {
297     return isa<IfOp, StatelessIfOp, StatefulIfOp, CaseOp, StatelessCaseOp,
298                StatefulCaseOp, WhileOp, StatelessWhileOp, StatefulWhileOp>(op);
299   }
300 };
301 
302 // This pattern removes a list of attributes from the given op types.
303 template <typename... OpTs>
304 class DropAttributes : public RewritePattern {
305  public:
306   // Create the pattern. Specify which attributes to remove.
DropAttributes(MLIRContext * context,ArrayRef<StringRef> attr_names)307   DropAttributes(MLIRContext *context, ArrayRef<StringRef> attr_names)
308       : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context) {
309     for (StringRef attr_name : attr_names)
310       attr_ids_.push_back(StringAttr::get(context, attr_name));
311   }
312 
313   // Remove the specified attributes from the op. Fail if none of the attributes
314   // were present.
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const315   LogicalResult matchAndRewrite(Operation *op,
316                                 PatternRewriter &rewriter) const override {
317     if (!isa<OpTs...>(op)) return failure();
318     rewriter.startRootUpdate(op);
319     if (!llvm::count_if(attr_ids_, [&](StringAttr attr_id) {
320           return op->removeAttr(attr_id);
321         })) {
322       rewriter.cancelRootUpdate(op);
323       return failure();
324     }
325     rewriter.finalizeRootUpdate(op);
326     return success();
327   }
328 
329  private:
330   // The identifiers of the attributes to remove.
331   SmallVector<StringAttr> attr_ids_;
332 };
333 }  // namespace
334 
335 template <typename... OpTs>
RemoveAttributes(MLIRContext * context,ArrayRef<StringRef> attr_names)336 static std::unique_ptr<RewritePattern> RemoveAttributes(
337     MLIRContext *context, ArrayRef<StringRef> attr_names) {
338   return std::make_unique<DropAttributes<OpTs...>>(context, attr_names);
339 }
340 
runOnOperation()341 void ConsolidateAttributesPassImpl::runOnOperation() {
342   // Skip this pass on generic functions. Generic functions contain only opaque
343   // tensor types, into which shape and data type info cannot be reified.
344   auto func = dyn_cast<GraphFuncOp>(getOperation());
345   if (func && func.generic()) return;
346 
347   // Reify operation attributes.
348   RewritePatternSet patterns(&getContext());
349   patterns.insert<ReifyTFGOpOutputShapes, ReifyCFOpOutputShapes>(&getContext());
350   patterns.add(RemoveAttributes<IfOp, StatelessIfOp, StatefulIfOp>(
351       &getContext(), {"Tcond", "Tin", "Tout"}));
352   patterns.add(RemoveAttributes<CaseOp, StatelessCaseOp, StatefulCaseOp>(
353       &getContext(), {"Tin", "Tout"}));
354   patterns.add(
355       RemoveAttributes<WhileOp, StatelessWhileOp, StatefulWhileOp, ForOp>(
356           &getContext(), {"T"}));
357   if (failed(
358           applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
359     getOperation()->emitError(getArgument() + " pass failed");
360     signalPassFailure();
361     return;
362   }
363 
364   // If the pass was run on a function, reify its attributes and then rebuild
365   // the signature. Because the attributes may have conflicting type info, the
366   // order in which we visit the attributes is the priority.
367   if (!func) return;
368   ArrayAttr arg_attrs = reifyAndDropFunctionArgumentAttributes(func);
369   ArrayAttr res_attrs = reifyAndDropFunctionResultAttributes(func);
370   Block &body = func.body().front();
371   auto type = FunctionType::get(
372       &getContext(), body.getArgumentTypes(),
373       TFOp(body.getTerminator()).getNonControlOperands().getTypes());
374   NamedAttrList attrs(func->getAttrDictionary());
375   attrs.set(func.function_typeAttrName(), TypeAttr::get(type));
376   attrs.set(func.arg_attrsAttrName(), arg_attrs);
377   attrs.set(func.res_attrsAttrName(), res_attrs);
378   func->setAttrs(attrs.getDictionary(&getContext()));
379 }
380 
381 namespace {
382 class PrepareAttributesForExportPassImpl
383     : public AttributesPassBase<PrepareAttributesForExportPassImpl> {
384  public:
385   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
386       PrepareAttributesForExportPassImpl)
387 
388   void runOnOperation() override;
389 
390  private:
391   // Materialize required `tfg.` attributes for export. Also, adds
392   // `tf._input_shapes` to the function attributes. And `tf._output_shapes` and
393   // `tf._handle_data` to the argument and result attributes.
394   void prepareFunctionAttributes(GraphFuncOp func);
395 
396   // Prepare attributes for a single type.
397   DictionaryAttr prepareAttributesFor(Type type, DictionaryAttr attr_dict);
398 };
399 }  // namespace
400 
prepareFunctionAttributes(GraphFuncOp func)401 void PrepareAttributesForExportPassImpl::prepareFunctionAttributes(
402     GraphFuncOp func) {
403   NamedAttrList attrs(func->getAttrDictionary());
404   SmallVector<Attribute> input_shapes, arg_attrs, res_attrs;
405 
406   ArrayAttr func_arg_attrs = func.getAllArgAttrs();
407   if (!func_arg_attrs) func_arg_attrs = ArrayAttr::get(&getContext(), {});
408   for (auto it : llvm::zip(func.getArgumentTypes(),
409                            func_arg_attrs.getAsRange<DictionaryAttr>())) {
410     Type type = std::get<0>(it);
411     DictionaryAttr attrs = std::get<1>(it);
412     if (type == control_type_) {
413       arg_attrs.push_back(attrs);
414       continue;
415     }
416     arg_attrs.push_back(prepareAttributesFor(type, attrs));
417     if (auto ranked = type.dyn_cast<RankedTensorType>()) {
418       input_shapes.push_back(ShapeAttr::get(&getContext(), ranked.getShape()));
419     } else {
420       input_shapes.push_back(ShapeAttr::get(&getContext(), llvm::None));
421     }
422   }
423 
424   ArrayAttr func_res_attrs = func.getAllResultAttrs();
425   if (!func_res_attrs) func_res_attrs = ArrayAttr::get(&getContext(), {});
426   for (auto it : llvm::zip(func.getResultTypes(),
427                            func_res_attrs.getAsRange<DictionaryAttr>()))
428     res_attrs.push_back(prepareAttributesFor(std::get<0>(it), std::get<1>(it)));
429 
430   // Add input shapes only if its regeneration is required.
431   if (attrs.erase(regenerate_input_shapes_id_))
432     attrs.set(input_shapes_id_, ArrayAttr::get(&getContext(), input_shapes));
433   attrs.set(func.arg_attrsAttrName(), ArrayAttr::get(&getContext(), arg_attrs));
434   attrs.set(func.res_attrsAttrName(), ArrayAttr::get(&getContext(), res_attrs));
435   func->setAttrs(attrs.getDictionary(&getContext()));
436 }
437 
prepareAttributesFor(Type type,DictionaryAttr attr_dict)438 DictionaryAttr PrepareAttributesForExportPassImpl::prepareAttributesFor(
439     Type type, DictionaryAttr attr_dict) {
440   NamedAttrList attrs(attr_dict);
441   // Add shape data if requested.
442   if (attrs.erase(regenerate_output_shapes_id_)) {
443     auto shape = ShapeAttr::get(&getContext(),
444                                 type.isa<RankedTensorType>()
445                                     ? type.cast<RankedTensorType>().getShape()
446                                     : Optional<ArrayRef<int64_t>>());
447     attrs.set(output_shapes_id_, ArrayAttr::get(&getContext(), {shape}));
448   }
449   auto element_type = type.cast<TensorType>().getElementType();
450   if (auto resource = element_type.dyn_cast<ResourceType>()) {
451     SmallVector<Attribute> handle_data;
452     for (TensorType subtype : resource.getSubtypes())
453       handle_data.push_back(TypeAttr::get(subtype));
454     // Only bother adding handle data if there are subtypes.
455     if (!handle_data.empty())
456       attrs.set(handle_data_id_, ArrayAttr::get(&getContext(), handle_data));
457   }
458   if (element_type.isa<tf_type::TensorFlowRefType>())
459     attrs.set(is_ref_id_, UnitAttr::get(&getContext()));
460   return attrs.getDictionary(&getContext());
461 }
462 
463 // Get the element types of the values as an array attributes.
GetElementTypesAttr(PatternRewriter & rewriter,ValueRange values)464 static ArrayAttr GetElementTypesAttr(PatternRewriter &rewriter,
465                                      ValueRange values) {
466   SmallVector<Attribute> types;
467   for (Value value : values) {
468     types.push_back(
469         TypeAttr::get(value.getType().cast<TensorType>().getElementType()));
470   }
471   return rewriter.getArrayAttr(types);
472 }
473 
474 namespace {
475 // Base class for patterns that materialize control-flow op attributes. This
476 // patterns contains a cached control type.
477 template <typename OpT>
478 class MaterializeAttrsPattern : public OpRewritePattern<OpT> {
479  public:
480   // Create the pattern with a cached control type instance.
MaterializeAttrsPattern(ControlType control_type)481   explicit MaterializeAttrsPattern(ControlType control_type)
482       : OpRewritePattern<OpT>(control_type.getContext()),
483         control_type_(control_type) {}
484 
485   // Get an array of the element types of the data arguments of the op. The
486   // arguments exclude "op-specific" operands such as if condition, case branch
487   // index, and for loop indices.
getArgumentElementTypesAttr(PatternRewriter & rewriter,OpT op) const488   ArrayAttr getArgumentElementTypesAttr(PatternRewriter &rewriter,
489                                         OpT op) const {
490     return GetElementTypesAttr(
491         rewriter, SplitDataAndControlValues(op.args(), control_type_).first);
492   }
493 
494  private:
495   // The cached control type.
496   ControlType control_type_;
497 };
498 
499 template <typename IfLikeOp>
500 struct MaterializeIfAttrs : public MaterializeAttrsPattern<IfLikeOp> {
501   using MaterializeAttrsPattern<IfLikeOp>::MaterializeAttrsPattern;
502 
503   // Materialize `Tcond`, `Tin`, and `Tout`.
matchAndRewritemlir::tfg::__anon17f873fd0711::MaterializeIfAttrs504   LogicalResult matchAndRewrite(IfLikeOp op,
505                                 PatternRewriter &rewriter) const override {
506     if (op.Tcond() && op.Tin() && op.Tout()) return failure();
507     NamedAttrList attrs(op->getAttrDictionary());
508     attrs.set(
509         op.TcondAttrName(),
510         TypeAttr::get(
511             op.cond().getType().template cast<TensorType>().getElementType()));
512     attrs.set(op.TinAttrName(),
513               this->getArgumentElementTypesAttr(rewriter, op));
514     attrs.set(op.ToutAttrName(), GetElementTypesAttr(rewriter, op.outs()));
515     rewriter.updateRootInPlace(
516         op, [&] { op->setAttrs(attrs.getDictionary(op->getContext())); });
517     return success();
518   }
519 };
520 
521 template <typename CaseLikeOp>
522 struct MaterializeCaseAttrs : public MaterializeAttrsPattern<CaseLikeOp> {
523   using MaterializeAttrsPattern<CaseLikeOp>::MaterializeAttrsPattern;
524 
525   // Materialize `Tin` and `Tout`.
matchAndRewritemlir::tfg::__anon17f873fd0711::MaterializeCaseAttrs526   LogicalResult matchAndRewrite(CaseLikeOp op,
527                                 PatternRewriter &rewriter) const override {
528     if (op.Tin() && op.Tout()) return failure();
529     NamedAttrList attrs(op->getAttrDictionary());
530     attrs.set(op.TinAttrName(),
531               this->getArgumentElementTypesAttr(rewriter, op));
532     attrs.set(op.ToutAttrName(), GetElementTypesAttr(rewriter, op.outs()));
533     rewriter.updateRootInPlace(
534         op, [&] { op->setAttrs(attrs.getDictionary(op->getContext())); });
535     return success();
536   }
537 };
538 
539 template <typename WhileOrForLikeOp>
540 struct MaterializeTAttr : public MaterializeAttrsPattern<WhileOrForLikeOp> {
541   using MaterializeAttrsPattern<WhileOrForLikeOp>::MaterializeAttrsPattern;
542 
543   // Materialize `T`.
matchAndRewritemlir::tfg::__anon17f873fd0711::MaterializeTAttr544   LogicalResult matchAndRewrite(WhileOrForLikeOp op,
545                                 PatternRewriter &rewriter) const override {
546     if (op.T()) return failure();
547     rewriter.updateRootInPlace(
548         op, [&] { op.TAttr(this->getArgumentElementTypesAttr(rewriter, op)); });
549     return success();
550   }
551 };
552 
553 // Base class for a pattern that
554 class MaterializeOutputShapesBase : public RewritePattern {
555  public:
MaterializeOutputShapesBase(MLIRContext * context,StringRef attr_name)556   explicit MaterializeOutputShapesBase(MLIRContext *context,
557                                        StringRef attr_name)
558       : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
559         attr_id_(StringAttr::get(context, attr_name)) {}
560 
561   virtual bool shouldMatch(Operation *op) const = 0;
562 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const563   LogicalResult matchAndRewrite(Operation *op,
564                                 PatternRewriter &rewriter) const override {
565     // Exclude internal TFG ops.
566     if (isa<ReturnOp>(op)) return failure();
567     if (!shouldMatch(op) || op->hasAttr(attr_id_)) return failure();
568     ResultRange results = TFOp(op).getNonControlResults();
569 
570     SmallVector<Attribute> shapes;
571     for (Value result : results) {
572       if (auto ranked = result.getType().dyn_cast<RankedTensorType>()) {
573         shapes.push_back(ShapeAttr::get(op->getContext(), ranked.getShape()));
574       } else {
575         shapes.push_back(ShapeAttr::get(op->getContext(), llvm::None));
576       }
577     }
578     rewriter.updateRootInPlace(op, [&] {
579       op->setAttr(attr_id_, rewriter.getArrayAttr(shapes));
580       rewriteImpl(op, rewriter);
581     });
582     return success();
583   }
584 
rewriteImpl(Operation * op,PatternRewriter & rewriter) const585   virtual void rewriteImpl(Operation *op, PatternRewriter &rewriter) const {}
586 
587  private:
588   // Cached identifier for the output shapes attribute.
589   StringAttr attr_id_;
590 };
591 
592 // Materialize `_output_shapes` for any TFG op.
593 class MaterializeTFGOpOutputShapes : public MaterializeOutputShapesBase {
594  public:
MaterializeTFGOpOutputShapes(MLIRContext * context)595   explicit MaterializeTFGOpOutputShapes(MLIRContext *context)
596       : MaterializeOutputShapesBase(context, "_output_shapes"),
597         dialect_(context->getOrLoadDialect<TFGraphDialect>()),
598         regenerate_output_shapes_id_(
599             StringAttr::get(context, kRegenerateOutputShapes)) {}
600 
shouldMatch(Operation * op) const601   bool shouldMatch(Operation *op) const override {
602     return op->getDialect() == dialect_ &&
603            op->getAttrOfType<UnitAttr>(regenerate_output_shapes_id_);
604   }
605 
rewriteImpl(Operation * op,PatternRewriter & rewriter) const606   void rewriteImpl(Operation *op, PatternRewriter &rewriter) const override {
607     op->removeAttr(regenerate_output_shapes_id_);
608   }
609 
610  private:
611   // Cached TFG dialect instance.
612   TFGraphDialect *dialect_;
613   // Identifier to `_regenerate_output_shapes`.
614   StringAttr regenerate_output_shapes_id_;
615 };
616 
617 // Materialize `output_shapes` for `If`, `Case`, and `While` ops.
618 struct MaterializeCFOpOutputShapes : public MaterializeOutputShapesBase {
MaterializeCFOpOutputShapesmlir::tfg::__anon17f873fd0711::MaterializeCFOpOutputShapes619   explicit MaterializeCFOpOutputShapes(MLIRContext *context)
620       : MaterializeOutputShapesBase(context, "output_shapes") {}
621 
shouldMatchmlir::tfg::__anon17f873fd0711::MaterializeCFOpOutputShapes622   bool shouldMatch(Operation *op) const override {
623     return isa<IfOp, StatelessIfOp, StatefulIfOp, CaseOp, StatelessCaseOp,
624                StatefulCaseOp, WhileOp, StatelessWhileOp, StatefulWhileOp>(op);
625   }
626 };
627 }  // namespace
628 
629 template <template <typename OpT> class PatternT, typename... OpTs,
630           typename... Args>
InsertPatterns(RewritePatternSet & patterns,Args &&...args)631 static void InsertPatterns(RewritePatternSet &patterns, Args &&...args) {
632   patterns.insert<PatternT<OpTs>...>(std::forward<Args>(args)...);
633 }
634 
runOnOperation()635 void PrepareAttributesForExportPassImpl::runOnOperation() {
636   // Skip this pass on generic functions. Generic functions contain only opaque
637   // tensor types, into which shape and data type info cannot be reified.
638   auto func = dyn_cast<GraphFuncOp>(getOperation());
639   if (func && func.generic()) return;
640 
641   RewritePatternSet patterns(&getContext());
642   ControlType control_type = ControlType::get(&getContext());
643   InsertPatterns<MaterializeIfAttrs, IfOp, StatelessIfOp, StatefulIfOp>(
644       patterns, control_type);
645   InsertPatterns<MaterializeCaseAttrs, CaseOp, StatelessCaseOp, StatefulCaseOp>(
646       patterns, control_type);
647   InsertPatterns<MaterializeTAttr, WhileOp, StatelessWhileOp, StatefulWhileOp,
648                  ForOp>(patterns, control_type);
649   patterns.insert<MaterializeTFGOpOutputShapes, MaterializeCFOpOutputShapes>(
650       &getContext());
651   if (failed(
652           applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
653     getOperation()->emitError(getArgument() + " pass failed");
654     signalPassFailure();
655     return;
656   }
657 
658   // If the pass was run on a function, materialize function, argument, and
659   // result attributes with type info.
660   if (func) prepareFunctionAttributes(func);
661 }
662 
663 namespace {
664 struct ConsolidateAttributesPass
665     : public ConsolidateAttributesBase<ConsolidateAttributesPass> {
runOnOperationmlir::tfg::__anon17f873fd0c11::ConsolidateAttributesPass666   void runOnOperation() override {
667     // Run the sub-pass on both `tfg.graph` and `tfg.func`.
668     PassManager mgr(&getContext());
669     mgr.addNestedPass<GraphOp>(
670         std::make_unique<ConsolidateAttributesPassImpl>());
671     mgr.addNestedPass<GraphFuncOp>(
672         std::make_unique<ConsolidateAttributesPassImpl>());
673     if (failed(runPipeline(mgr, getOperation()))) signalPassFailure();
674   }
675 };
676 
677 struct PrepareAttributesForExportPass
678     : public PrepareAttributesForExportBase<PrepareAttributesForExportPass> {
runOnOperationmlir::tfg::__anon17f873fd0c11::PrepareAttributesForExportPass679   void runOnOperation() override {
680     // Run the sub-pass on both `tfg.graph` and `tfg.func`.
681     PassManager mgr(&getContext());
682     mgr.addNestedPass<GraphOp>(
683         std::make_unique<PrepareAttributesForExportPassImpl>());
684     mgr.addNestedPass<GraphFuncOp>(
685         std::make_unique<PrepareAttributesForExportPassImpl>());
686     if (failed(runPipeline(mgr, getOperation()))) signalPassFailure();
687   }
688 };
689 }  // namespace
690 
CreateConsolidateAttributesPass()691 std::unique_ptr<Pass> CreateConsolidateAttributesPass() {
692   return std::make_unique<ConsolidateAttributesPass>();
693 }
694 
CreatePrepareAttributesForExportPass()695 std::unique_ptr<Pass> CreatePrepareAttributesForExportPass() {
696   return std::make_unique<PrepareAttributesForExportPass>();
697 }
698 
699 }  // namespace tfg
700 }  // namespace mlir
701