/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/transforms/consolidate_attrs/pass.h"

#include <memory>
#include <utility>

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/ops.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/ir/utility.h"
#include "tensorflow/core/transforms/pass_detail.h"

namespace mlir {
namespace tfg {

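// Marker left on ops whose `_output_shapes` attribute has been folded into the
// result types, so the export pass knows to regenerate the attribute.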
static const char *kRegenerateOutputShapes = "tfg.regenerate_output_shapes";

// Returns true if an attribute is an array of shapes.
static bool IsArrayOfShapes(ArrayAttr array) {
  return llvm::all_of(array,
                      [](Attribute attr) { return attr.isa<ShapeAttr>(); });
}

// Given a tensor type and shape information, try to refine the type.
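// Illustrative example (assumed shape-attribute syntax): refining
// tensor<*xi32> with #tf_type.shape<4x?> yields tensor<4x?xi32>, while a shape
// that is not cast-compatible with the original type leaves it unchanged.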
static Type GetReifiedType(Type orig, ShapeAttr shape) {
  Type element_type = orig.cast<ShapedType>().getElementType();
  TensorType inferred;
  if (shape.hasRank()) {
    // Replace dimensions less than -1 with ?
    SmallVector<int64_t> dims = llvm::to_vector(shape.getShape());
    for (int64_t &dim : dims)
      if (dim < -1) dim = -1;
    inferred = RankedTensorType::get(dims, element_type);
  } else {
    inferred = UnrankedTensorType::get(element_type);
  }
  Type reified_type = tf_type::GetCastCompatibleType(inferred, orig);
  // If the types are not compatible, return the original type.
  return reified_type ? reified_type : orig;
}

namespace {
// CRTP base class for the consolidate attributes passes. This base class
// defines cached identifiers for the attributes.
template <typename PassT>
class AttributesPassBase : public PassWrapper<PassT, OperationPass<>> {
 public:
  LogicalResult initialize(MLIRContext *context) override {
    input_shapes_id_ = StringAttr::get(context, "tf._input_shapes");
    regenerate_input_shapes_id_ =
        StringAttr::get(context, "tfg.regenerate_input_shapes");
    output_shapes_id_ = StringAttr::get(context, "tf._output_shapes");
    regenerate_output_shapes_id_ =
        StringAttr::get(context, "tfg.regenerate_output_shapes");
    handle_data_id_ = StringAttr::get(context, "tfg.handle_data");
    dtype_id_ = StringAttr::get(context, "tfg.dtype");
    is_ref_id_ = StringAttr::get(context, "tfg.is_ref");
    control_type_ = ControlType::get(context);
    return success();
  }

 protected:
  // Identifier for `tf._input_shapes`.
  StringAttr input_shapes_id_;
  // Identifier for `tfg.regenerate_input_shapes`.
  StringAttr regenerate_input_shapes_id_;
  // Identifier for `tf._output_shapes`.
  StringAttr output_shapes_id_;
  // Identifier for `tfg.regenerate_output_shapes`.
  StringAttr regenerate_output_shapes_id_;
  // Identifier for `tfg.handle_data`.
  StringAttr handle_data_id_;
  // Identifier for `tfg.dtype`.
  StringAttr dtype_id_;
  // Identifier for `tfg.is_ref`.
  StringAttr is_ref_id_;
  // Cached control type.
  ControlType control_type_;
};

class ConsolidateAttributesPassImpl
    : public AttributesPassBase<ConsolidateAttributesPassImpl> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConsolidateAttributesPassImpl)

  void runOnOperation() override;

 private:
  // Reify `tf._input_shapes`, `tf._output_shapes` and `tfg.handle_data` into
  // the types of the function arguments. Drop the attributes `tfg.dtype` and
  // `tfg.is_ref`. Return the new argument attributes.
  ArrayAttr reifyAndDropFunctionArgumentAttributes(GraphFuncOp func);
  // Reify `tf._output_shapes` and `tfg.handle_data` into the types of the
  // function results. Drop the attribute `tfg.dtype`. Return the new result
  // attributes.
  ArrayAttr reifyAndDropFunctionResultAttributes(GraphFuncOp func);

  // Refine a type with `tf._output_shapes`.
  Type refineTypeWithOutputShapes(Type type, NamedAttrList &attrs);
  // Refine a type with `tfg.handle_data`.
  Type refineTypeWithHandleData(Type type, Attribute handle_data);
};
}  // namespace

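// Illustrative example (assumed attribute syntax): with
//   tf._output_shapes = [#tf_type.shape<8x?>]
// in `attrs`, a value of type tensor<*xf32> is refined to tensor<8x?xf32>; the
// attribute is erased and replaced with the `tfg.regenerate_output_shapes`
// marker.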
Type ConsolidateAttributesPassImpl::refineTypeWithOutputShapes(
    Type type, NamedAttrList &attrs) {
  // Get the output shapes attribute. If the attribute is not an array of
  // exactly one shape, ignore it.
  if (auto output_shapes =
          attrs.get(output_shapes_id_).dyn_cast_or_null<ArrayAttr>()) {
    if (output_shapes.size() == 1 && IsArrayOfShapes(output_shapes)) {
      attrs.erase(output_shapes_id_);
      attrs.set(regenerate_output_shapes_id_, UnitAttr::get(&getContext()));
      return GetReifiedType(type, output_shapes[0].cast<ShapeAttr>());
    }
  }
  return type;
}

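// Illustrative example (assumed type syntax): with
//   tfg.handle_data = [tensor<4xf32>],
// a value of type tensor<*x!tf_type.resource> is refined to
// tensor<*x!tf_type.resource<tensor<4xf32>>>, provided the two types are
// cast-compatible.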
Type ConsolidateAttributesPassImpl::refineTypeWithHandleData(
    Type type, Attribute handle_data) {
  if (!handle_data) return type;
  SmallVector<TensorType> subtypes;
  // Because `tfg.handle_data` is a TFG internal attribute, it will be
  // well-formed.
  for (Type type : handle_data.cast<ArrayAttr>().getAsValueRange<TypeAttr>())
    subtypes.push_back(type.cast<TensorType>());
  auto resource =
      UnrankedTensorType::get(ResourceType::get(subtypes, &getContext()));
  Type reified = tf_type::GetCastCompatibleType(resource, type);
  return reified ? reified : type;
}

ArrayAttr ConsolidateAttributesPassImpl::reifyAndDropFunctionArgumentAttributes(
    GraphFuncOp func) {
  // Get the input shapes attribute. If it is a UnitAttr, then it is empty and
  // we will ignore it. If it isn't an array of shapes or has an inconsistent
  // number of shapes, ignore it.
  ArrayAttr input_shapes =
      func->getAttr(input_shapes_id_).dyn_cast_or_null<ArrayAttr>();
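  // In a `tfg.func`, each data argument is paired with a control-token
  // argument, so only half of the block arguments carry tensor types.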
  unsigned num_args = func.getNumArguments() / 2;
  if (input_shapes) {
    if (input_shapes.size() != num_args || !IsArrayOfShapes(input_shapes)) {
      input_shapes = {};
    } else {
      func->removeAttr(input_shapes_id_);
      func->setAttr(regenerate_input_shapes_id_, UnitAttr::get(&getContext()));
    }
  }

  SmallVector<Attribute> arg_attrs;
  auto empty_dict = DictionaryAttr::get(&getContext());
  for (auto i : llvm::seq<unsigned>(0, num_args)) {
    BlockArgument arg = GraphFuncOp::getDataValue(func.body(), i);
    NamedAttrList attrs(func.getArgAttrs(arg.getArgNumber()));
    Type arg_type = arg.getType();
    arg_type = refineTypeWithOutputShapes(arg_type, attrs);
    arg_type = refineTypeWithHandleData(arg_type, attrs.erase(handle_data_id_));
    if (input_shapes)
      arg_type = GetReifiedType(arg_type, input_shapes[i].cast<ShapeAttr>());
    arg.setType(arg_type);
    attrs.erase(dtype_id_);
    attrs.erase(is_ref_id_);
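    // The control-token argument that follows each data argument carries no
    // attributes, so pair the rewritten dictionary with an empty one.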
    arg_attrs.append({attrs.getDictionary(&getContext()), empty_dict});
  }
  return ArrayAttr::get(&getContext(), arg_attrs);
}

ArrayAttr ConsolidateAttributesPassImpl::reifyAndDropFunctionResultAttributes(
    GraphFuncOp func) {
  ArrayAttr res_attrs = func.getAllResultAttrs();
  if (!res_attrs) return ArrayAttr::get(&getContext(), {});

  SmallVector<Attribute> ret_attrs;
  // The result types are propagated to the data operands of `return`.
  auto ret_op = cast<ReturnOp>(func.body().front().getTerminator());
  for (auto &it : llvm::enumerate(res_attrs.getAsRange<DictionaryAttr>())) {
    NamedAttrList attrs(it.value());
    Value ret = ret_op.getOperand(it.index());
    Type ret_type = ret.getType();
    ret_type = refineTypeWithOutputShapes(ret_type, attrs);
    ret_type = refineTypeWithHandleData(ret_type, attrs.erase(handle_data_id_));
    ret.setType(ret_type);
    attrs.erase(dtype_id_);
    ret_attrs.push_back(attrs.getDictionary(&getContext()));
  }
  return ArrayAttr::get(&getContext(), ret_attrs);
}

namespace {
// This pattern reifies an op's result shape info into the result types and
// drops the output shapes attributes.
class ReifyOperationOutputShapes : public RewritePattern {
 public:
  ReifyOperationOutputShapes(MLIRContext *context, PatternBenefit benefit,
                             StringRef attr_name)
      : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, context),
        output_shapes_id_(StringAttr::get(context, attr_name)) {}

  // Returns true if this instance of the pattern should match the op.
  virtual bool shouldMatch(Operation *op) const = 0;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!shouldMatch(op)) return failure();

    ResultRange results = TFOp(op).getNonControlResults();

    // Get the output shapes attribute. Ignore it if it is not an array
    // attribute, if it has an inconsistent number of shapes, or if it is not
    // an array of shapes.
    ArrayAttr output_shapes =
        op->getAttr(output_shapes_id_).dyn_cast_or_null<ArrayAttr>();
    if (!output_shapes || results.size() != output_shapes.size() ||
        !IsArrayOfShapes(output_shapes))
      return failure();

    rewriter.updateRootInPlace(op, [&] {
      op->removeAttr(output_shapes_id_);
      assert(output_shapes.size() == results.size());
      for (auto it :
           llvm::zip(results, output_shapes.getAsRange<ShapeAttr>())) {
        Value result = std::get<0>(it);
        result.setType(GetReifiedType(result.getType(), std::get<1>(it)));
      }
      rewriteImpl(op, rewriter);
    });
    return success();
  }

  virtual void rewriteImpl(Operation *op, PatternRewriter &rewriter) const {}

 private:
  // Identifier for `_output_shapes`.
  StringAttr output_shapes_id_;
};

// This pattern matches any TFG op and reifies `_output_shapes`. The pattern
// leaves behind a `tfg.regenerate_output_shapes` marker that is used by the
// converse pattern to detect whether the attribute should be materialized.
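// Illustrative before/after (assumed textual IR): an op carrying
//   {_output_shapes = [#tf_type.shape<4>]}
// with a result of type tensor<*xf32> is rewritten so the result has type
// tensor<4xf32> and the op instead carries `tfg.regenerate_output_shapes`.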
class ReifyTFGOpOutputShapes : public ReifyOperationOutputShapes {
 public:
  explicit ReifyTFGOpOutputShapes(MLIRContext *context)
      : ReifyOperationOutputShapes(context, /*benefit=*/1, "_output_shapes"),
        dialect_(context->getOrLoadDialect<TFGraphDialect>()),
        regenerate_output_shapes_id_(
            StringAttr::get(context, kRegenerateOutputShapes)) {}

  bool shouldMatch(Operation *op) const override {
    return op->getDialect() == dialect_ && op->getNumResults();
  }

  void rewriteImpl(Operation *op, PatternRewriter &rewriter) const override {
    op->setAttr(regenerate_output_shapes_id_, rewriter.getUnitAttr());
  }

 private:
  // Cached TFG dialect instance.
  TFGraphDialect *dialect_;
  // Identifier for `tfg.regenerate_output_shapes`.
  StringAttr regenerate_output_shapes_id_;
};

// This pattern matches `If`, `Case`, and `While` and reifies their
// `output_shapes` attribute.
struct ReifyCFOpOutputShapes : public ReifyOperationOutputShapes {
  // Set a higher benefit to ensure that "output_shapes" is reified before
  // "_output_shapes".
  explicit ReifyCFOpOutputShapes(MLIRContext *context)
      : ReifyOperationOutputShapes(context, /*benefit=*/2, "output_shapes") {}

  bool shouldMatch(Operation *op) const override {
    return isa<IfOp, StatelessIfOp, StatefulIfOp, CaseOp, StatelessCaseOp,
               StatefulCaseOp, WhileOp, StatelessWhileOp, StatefulWhileOp>(op);
  }
};

// This pattern removes a list of attributes from the given op types.
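// For example (as instantiated in ConsolidateAttributesPassImpl::
// runOnOperation below), DropAttributes<IfOp, StatelessIfOp, StatefulIfOp>
// constructed with {"Tcond", "Tin", "Tout"} strips the derived type attributes
// from if-like ops; the export pass can rematerialize them.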
template <typename... OpTs>
class DropAttributes : public RewritePattern {
 public:
  // Create the pattern. Specify which attributes to remove.
  DropAttributes(MLIRContext *context, ArrayRef<StringRef> attr_names)
      : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context) {
    for (StringRef attr_name : attr_names)
      attr_ids_.push_back(StringAttr::get(context, attr_name));
  }

  // Remove the specified attributes from the op. Fail if none of the
  // attributes were present.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!isa<OpTs...>(op)) return failure();
    rewriter.startRootUpdate(op);
    if (!llvm::count_if(attr_ids_, [&](StringAttr attr_id) {
          return op->removeAttr(attr_id);
        })) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }

 private:
  // The identifiers of the attributes to remove.
  SmallVector<StringAttr> attr_ids_;
};
}  // namespace

template <typename... OpTs>
static std::unique_ptr<RewritePattern> RemoveAttributes(
    MLIRContext *context, ArrayRef<StringRef> attr_names) {
  return std::make_unique<DropAttributes<OpTs...>>(context, attr_names);
}

void ConsolidateAttributesPassImpl::runOnOperation() {
  // Skip this pass on generic functions. Generic functions contain only opaque
  // tensor types, into which shape and data type info cannot be reified.
  auto func = dyn_cast<GraphFuncOp>(getOperation());
  if (func && func.generic()) return;

  // Reify operation attributes.
  RewritePatternSet patterns(&getContext());
  patterns.insert<ReifyTFGOpOutputShapes, ReifyCFOpOutputShapes>(
      &getContext());
  patterns.add(RemoveAttributes<IfOp, StatelessIfOp, StatefulIfOp>(
      &getContext(), {"Tcond", "Tin", "Tout"}));
  patterns.add(RemoveAttributes<CaseOp, StatelessCaseOp, StatefulCaseOp>(
      &getContext(), {"Tin", "Tout"}));
  patterns.add(
      RemoveAttributes<WhileOp, StatelessWhileOp, StatefulWhileOp, ForOp>(
          &getContext(), {"T"}));
  if (failed(
          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
    getOperation()->emitError(getArgument() + " pass failed");
    signalPassFailure();
    return;
  }

  // If the pass was run on a function, reify its attributes and then rebuild
  // the signature. Because the attributes may have conflicting type info, the
  // order in which we visit the attributes determines their priority.
  if (!func) return;
  ArrayAttr arg_attrs = reifyAndDropFunctionArgumentAttributes(func);
  ArrayAttr res_attrs = reifyAndDropFunctionResultAttributes(func);
  Block &body = func.body().front();
  auto type = FunctionType::get(
      &getContext(), body.getArgumentTypes(),
      TFOp(body.getTerminator()).getNonControlOperands().getTypes());
  NamedAttrList attrs(func->getAttrDictionary());
  attrs.set(func.function_typeAttrName(), TypeAttr::get(type));
  attrs.set(func.arg_attrsAttrName(), arg_attrs);
  attrs.set(func.res_attrsAttrName(), res_attrs);
  func->setAttrs(attrs.getDictionary(&getContext()));
}

namespace {
class PrepareAttributesForExportPassImpl
    : public AttributesPassBase<PrepareAttributesForExportPassImpl> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      PrepareAttributesForExportPassImpl)

  void runOnOperation() override;

 private:
  // Materialize required `tfg.` attributes for export. Also adds
  // `tf._input_shapes` to the function attributes, and `tf._output_shapes` and
  // `tfg.handle_data` to the argument and result attributes.
  void prepareFunctionAttributes(GraphFuncOp func);

  // Prepare attributes for a single type.
  DictionaryAttr prepareAttributesFor(Type type, DictionaryAttr attr_dict);
};
}  // namespace

void PrepareAttributesForExportPassImpl::prepareFunctionAttributes(
    GraphFuncOp func) {
  NamedAttrList attrs(func->getAttrDictionary());
  SmallVector<Attribute> input_shapes, arg_attrs, res_attrs;

  ArrayAttr func_arg_attrs = func.getAllArgAttrs();
  if (!func_arg_attrs) func_arg_attrs = ArrayAttr::get(&getContext(), {});
  for (auto it : llvm::zip(func.getArgumentTypes(),
                           func_arg_attrs.getAsRange<DictionaryAttr>())) {
    Type type = std::get<0>(it);
    DictionaryAttr attrs = std::get<1>(it);
    if (type == control_type_) {
      arg_attrs.push_back(attrs);
      continue;
    }
    arg_attrs.push_back(prepareAttributesFor(type, attrs));
    if (auto ranked = type.dyn_cast<RankedTensorType>()) {
      input_shapes.push_back(ShapeAttr::get(&getContext(), ranked.getShape()));
    } else {
      input_shapes.push_back(ShapeAttr::get(&getContext(), llvm::None));
    }
  }

  ArrayAttr func_res_attrs = func.getAllResultAttrs();
  if (!func_res_attrs) func_res_attrs = ArrayAttr::get(&getContext(), {});
  for (auto it : llvm::zip(func.getResultTypes(),
                           func_res_attrs.getAsRange<DictionaryAttr>()))
    res_attrs.push_back(prepareAttributesFor(std::get<0>(it), std::get<1>(it)));

  // Add input shapes only if their regeneration is required.
  if (attrs.erase(regenerate_input_shapes_id_))
    attrs.set(input_shapes_id_, ArrayAttr::get(&getContext(), input_shapes));
  attrs.set(func.arg_attrsAttrName(), ArrayAttr::get(&getContext(), arg_attrs));
  attrs.set(func.res_attrsAttrName(), ArrayAttr::get(&getContext(), res_attrs));
  func->setAttrs(attrs.getDictionary(&getContext()));
}

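// Illustrative example for prepareAttributesFor (assumed attribute/type
// syntax): for a value of type tensor<4x!tf_type.resource<tensor<8xf32>>>
// whose attribute dictionary carries `tfg.regenerate_output_shapes`, the
// result contains
//   {tf._output_shapes = [#tf_type.shape<4>],
//    tfg.handle_data = [tensor<8xf32>]}.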
DictionaryAttr PrepareAttributesForExportPassImpl::prepareAttributesFor(
    Type type, DictionaryAttr attr_dict) {
  NamedAttrList attrs(attr_dict);
  // Add shape data if requested.
  if (attrs.erase(regenerate_output_shapes_id_)) {
    auto shape = ShapeAttr::get(&getContext(),
                                type.isa<RankedTensorType>()
                                    ? type.cast<RankedTensorType>().getShape()
                                    : Optional<ArrayRef<int64_t>>());
    attrs.set(output_shapes_id_, ArrayAttr::get(&getContext(), {shape}));
  }
  auto element_type = type.cast<TensorType>().getElementType();
  if (auto resource = element_type.dyn_cast<ResourceType>()) {
    SmallVector<Attribute> handle_data;
    for (TensorType subtype : resource.getSubtypes())
      handle_data.push_back(TypeAttr::get(subtype));
    // Only bother adding handle data if there are subtypes.
    if (!handle_data.empty())
      attrs.set(handle_data_id_, ArrayAttr::get(&getContext(), handle_data));
  }
  if (element_type.isa<tf_type::TensorFlowRefType>())
    attrs.set(is_ref_id_, UnitAttr::get(&getContext()));
  return attrs.getDictionary(&getContext());
}

// Get the element types of the values as an array attribute.
static ArrayAttr GetElementTypesAttr(PatternRewriter &rewriter,
                                     ValueRange values) {
  SmallVector<Attribute> types;
  for (Value value : values) {
    types.push_back(
        TypeAttr::get(value.getType().cast<TensorType>().getElementType()));
  }
  return rewriter.getArrayAttr(types);
}

namespace {
// Base class for patterns that materialize control-flow op attributes. The
// pattern holds a cached control type.
template <typename OpT>
class MaterializeAttrsPattern : public OpRewritePattern<OpT> {
 public:
  // Create the pattern with a cached control type instance.
  explicit MaterializeAttrsPattern(ControlType control_type)
      : OpRewritePattern<OpT>(control_type.getContext()),
        control_type_(control_type) {}

  // Get an array of the element types of the data arguments of the op. The
  // arguments exclude "op-specific" operands such as the if condition, the
  // case branch index, and the for-loop indices.
  ArrayAttr getArgumentElementTypesAttr(PatternRewriter &rewriter,
                                        OpT op) const {
    return GetElementTypesAttr(
        rewriter, SplitDataAndControlValues(op.args(), control_type_).first);
  }

 private:
  // The cached control type.
  ControlType control_type_;
};

template <typename IfLikeOp>
struct MaterializeIfAttrs : public MaterializeAttrsPattern<IfLikeOp> {
  using MaterializeAttrsPattern<IfLikeOp>::MaterializeAttrsPattern;

  // Materialize `Tcond`, `Tin`, and `Tout`.
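  // Illustrative example: for an if-like op whose condition has type
  // tensor<i1>, whose data arguments have element types f32 and i32, and whose
  // results have element type f32, this sets Tcond = i1, Tin = [f32, i32], and
  // Tout = [f32].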
  LogicalResult matchAndRewrite(IfLikeOp op,
                                PatternRewriter &rewriter) const override {
    if (op.Tcond() && op.Tin() && op.Tout()) return failure();
    NamedAttrList attrs(op->getAttrDictionary());
    attrs.set(
        op.TcondAttrName(),
        TypeAttr::get(
            op.cond().getType().template cast<TensorType>().getElementType()));
    attrs.set(op.TinAttrName(),
              this->getArgumentElementTypesAttr(rewriter, op));
    attrs.set(op.ToutAttrName(), GetElementTypesAttr(rewriter, op.outs()));
    rewriter.updateRootInPlace(
        op, [&] { op->setAttrs(attrs.getDictionary(op->getContext())); });
    return success();
  }
};

template <typename CaseLikeOp>
struct MaterializeCaseAttrs : public MaterializeAttrsPattern<CaseLikeOp> {
  using MaterializeAttrsPattern<CaseLikeOp>::MaterializeAttrsPattern;

  // Materialize `Tin` and `Tout`.
  LogicalResult matchAndRewrite(CaseLikeOp op,
                                PatternRewriter &rewriter) const override {
    if (op.Tin() && op.Tout()) return failure();
    NamedAttrList attrs(op->getAttrDictionary());
    attrs.set(op.TinAttrName(),
              this->getArgumentElementTypesAttr(rewriter, op));
    attrs.set(op.ToutAttrName(), GetElementTypesAttr(rewriter, op.outs()));
    rewriter.updateRootInPlace(
        op, [&] { op->setAttrs(attrs.getDictionary(op->getContext())); });
    return success();
  }
};

template <typename WhileOrForLikeOp>
struct MaterializeTAttr : public MaterializeAttrsPattern<WhileOrForLikeOp> {
  using MaterializeAttrsPattern<WhileOrForLikeOp>::MaterializeAttrsPattern;

  // Materialize `T`.
  LogicalResult matchAndRewrite(WhileOrForLikeOp op,
                                PatternRewriter &rewriter) const override {
    if (op.T()) return failure();
    rewriter.updateRootInPlace(
        op, [&] { op.TAttr(this->getArgumentElementTypesAttr(rewriter, op)); });
    return success();
  }
};

// Base class for patterns that materialize an output shapes attribute from an
// op's result types.
class MaterializeOutputShapesBase : public RewritePattern {
 public:
  explicit MaterializeOutputShapesBase(MLIRContext *context,
                                       StringRef attr_name)
      : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
        attr_id_(StringAttr::get(context, attr_name)) {}

  virtual bool shouldMatch(Operation *op) const = 0;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Exclude internal TFG ops.
    if (isa<ReturnOp>(op)) return failure();
    if (!shouldMatch(op) || op->hasAttr(attr_id_)) return failure();
    ResultRange results = TFOp(op).getNonControlResults();

    SmallVector<Attribute> shapes;
    for (Value result : results) {
      if (auto ranked = result.getType().dyn_cast<RankedTensorType>()) {
        shapes.push_back(ShapeAttr::get(op->getContext(), ranked.getShape()));
      } else {
        shapes.push_back(ShapeAttr::get(op->getContext(), llvm::None));
      }
    }
    rewriter.updateRootInPlace(op, [&] {
      op->setAttr(attr_id_, rewriter.getArrayAttr(shapes));
      rewriteImpl(op, rewriter);
    });
    return success();
  }

  virtual void rewriteImpl(Operation *op, PatternRewriter &rewriter) const {}

 private:
  // Cached identifier for the output shapes attribute.
  StringAttr attr_id_;
};

// Materialize `_output_shapes` for any TFG op.
class MaterializeTFGOpOutputShapes : public MaterializeOutputShapesBase {
 public:
  explicit MaterializeTFGOpOutputShapes(MLIRContext *context)
      : MaterializeOutputShapesBase(context, "_output_shapes"),
        dialect_(context->getOrLoadDialect<TFGraphDialect>()),
        regenerate_output_shapes_id_(
            StringAttr::get(context, kRegenerateOutputShapes)) {}

  bool shouldMatch(Operation *op) const override {
    return op->getDialect() == dialect_ &&
           op->getAttrOfType<UnitAttr>(regenerate_output_shapes_id_);
  }

  void rewriteImpl(Operation *op, PatternRewriter &rewriter) const override {
    op->removeAttr(regenerate_output_shapes_id_);
  }

 private:
  // Cached TFG dialect instance.
  TFGraphDialect *dialect_;
  // Identifier for `tfg.regenerate_output_shapes`.
  StringAttr regenerate_output_shapes_id_;
};

// Materialize `output_shapes` for `If`, `Case`, and `While` ops.
struct MaterializeCFOpOutputShapes : public MaterializeOutputShapesBase {
  explicit MaterializeCFOpOutputShapes(MLIRContext *context)
      : MaterializeOutputShapesBase(context, "output_shapes") {}

  bool shouldMatch(Operation *op) const override {
    return isa<IfOp, StatelessIfOp, StatefulIfOp, CaseOp, StatelessCaseOp,
               StatefulCaseOp, WhileOp, StatelessWhileOp, StatefulWhileOp>(op);
  }
};
}  // namespace

template <template <typename OpT> class PatternT, typename... OpTs,
          typename... Args>
static void InsertPatterns(RewritePatternSet &patterns, Args &&...args) {
  patterns.insert<PatternT<OpTs>...>(std::forward<Args>(args)...);
}

void PrepareAttributesForExportPassImpl::runOnOperation() {
  // Skip this pass on generic functions. Generic functions contain only opaque
  // tensor types, into which shape and data type info cannot be reified.
  auto func = dyn_cast<GraphFuncOp>(getOperation());
  if (func && func.generic()) return;

  RewritePatternSet patterns(&getContext());
  ControlType control_type = ControlType::get(&getContext());
  InsertPatterns<MaterializeIfAttrs, IfOp, StatelessIfOp, StatefulIfOp>(
      patterns, control_type);
  InsertPatterns<MaterializeCaseAttrs, CaseOp, StatelessCaseOp, StatefulCaseOp>(
      patterns, control_type);
  InsertPatterns<MaterializeTAttr, WhileOp, StatelessWhileOp, StatefulWhileOp,
                 ForOp>(patterns, control_type);
  patterns.insert<MaterializeTFGOpOutputShapes, MaterializeCFOpOutputShapes>(
      &getContext());
  if (failed(
          applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
    getOperation()->emitError(getArgument() + " pass failed");
    signalPassFailure();
    return;
  }

  // If the pass was run on a function, materialize function, argument, and
  // result attributes with type info.
  if (func) prepareFunctionAttributes(func);
}

namespace {
struct ConsolidateAttributesPass
    : public ConsolidateAttributesBase<ConsolidateAttributesPass> {
  void runOnOperation() override {
    // Run the sub-pass on both `tfg.graph` and `tfg.func`.
    PassManager mgr(&getContext());
    mgr.addNestedPass<GraphOp>(
        std::make_unique<ConsolidateAttributesPassImpl>());
    mgr.addNestedPass<GraphFuncOp>(
        std::make_unique<ConsolidateAttributesPassImpl>());
    if (failed(runPipeline(mgr, getOperation()))) signalPassFailure();
  }
};

struct PrepareAttributesForExportPass
    : public PrepareAttributesForExportBase<PrepareAttributesForExportPass> {
  void runOnOperation() override {
    // Run the sub-pass on both `tfg.graph` and `tfg.func`.
    PassManager mgr(&getContext());
    mgr.addNestedPass<GraphOp>(
        std::make_unique<PrepareAttributesForExportPassImpl>());
    mgr.addNestedPass<GraphFuncOp>(
        std::make_unique<PrepareAttributesForExportPassImpl>());
    if (failed(runPipeline(mgr, getOperation()))) signalPassFailure();
  }
};
}  // namespace

std::unique_ptr<Pass> CreateConsolidateAttributesPass() {
  return std::make_unique<ConsolidateAttributesPass>();
}

std::unique_ptr<Pass> CreatePrepareAttributesForExportPass() {
  return std::make_unique<PrepareAttributesForExportPass>();
}

}  // namespace tfg
}  // namespace mlir