xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfr/passes/decompose.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <utility>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/None.h"
27 #include "llvm/ADT/Optional.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
35 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
36 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
37 #include "mlir/IR/Attributes.h"  // from @llvm-project
38 #include "mlir/IR/Builders.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
41 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
42 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
43 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
44 #include "mlir/IR/Value.h"  // from @llvm-project
45 #include "mlir/IR/Visitors.h"  // from @llvm-project
46 #include "mlir/Pass/Pass.h"  // from @llvm-project
47 #include "mlir/Support/LLVM.h"  // from @llvm-project
48 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
49 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
50 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
51 #include "mlir/Transforms/InliningUtils.h"  // from @llvm-project
52 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
54 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
55 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
56 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
57 #include "tensorflow/compiler/mlir/tfr/utils/utils.h"
58 #include "tensorflow/core/lib/monitoring/counter.h"
59 
namespace tensorflow {
namespace {

// Process-wide monitoring counter recording how many composite (unregistered)
// TF ops have been expanded by the decompose pass, keyed by op "name".
auto* tf_core_op_expansion_op_counter =
    monitoring::Counter<1>::New("/tensorflow/core/op_expansion/op_counter",
                                "The number of composite op expanded.", "name");
}  // namespace

// Bumps the expansion counter cell for `op_name` by one. Called once for each
// op that gets rewritten into a TFR compose-function call.
void IncreaseOpExpansionExecuteCounterByOne(const std::string& op_name) {
  tf_core_op_expansion_op_counter->GetCell(op_name)->IncrementBy(1);
}

}  // namespace tensorflow
73 
74 //===----------------------------------------------------------------------===//
75 // The pass to decompose unregistered TF ops with the TFR compose function.
76 //
77 namespace mlir {
78 namespace TFR {
79 
80 namespace {
81 
82 // Quantize the float value based on given scale and zero point attributes.
Quantize(float value,Attribute scale_attr,Attribute zp_attr,OpBuilder builder)83 Attribute Quantize(float value, Attribute scale_attr, Attribute zp_attr,
84                    OpBuilder builder) {
85   double scale = scale_attr.cast<FloatAttr>().getValueAsDouble();
86   int64_t zp = zp_attr.cast<IntegerAttr>().getInt();
87 
88   int quantized = static_cast<int>(std::round(value / scale) + zp);
89   quantized =
90       std::min(quantized, static_cast<int>(std::numeric_limits<int8_t>::max()));
91   quantized =
92       std::max(quantized, static_cast<int>(std::numeric_limits<int8_t>::min()));
93   return builder.getI32IntegerAttr(quantized);
94 }
95 
96 // Decompose the TF ops with the registered composition library.
// Decompose the TF ops with the registered composition library.
//
// Runs on each func::FuncOp: unregistered TF ops are rewritten into calls to
// their TFR compose functions, which are then inlined, with canonicalization
// applied between rounds (see runOnOperation).
class DecomposeTFOpsPass
    : public PassWrapper<DecomposeTFOpsPass, OperationPass<func::FuncOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DecomposeTFOpsPass)

  // If `external_tfr_module` is provided, compose functions are looked up in
  // that module; otherwise the module enclosing the processed function is used.
  explicit DecomposeTFOpsPass(llvm::Optional<ModuleOp> external_tfr_module)
      : external_tfr_module_(external_tfr_module) {}

  // Command-line argument that selects this pass ("tfr-decompose").
  StringRef getArgument() const final { return "tfr-decompose"; }

  StringRef getDescription() const final {
    return "Decompose TF ops with the registered composition library.";
  }

  void runOnOperation() override;

 private:
  // Apply canonicalization, mainly constant folding, on the function.
  void ApplyCanonicalization();

  // Rewrite unregistered TF ops to TFR func call ops. Return failure if all the
  // ops are registered or the compose function doesn't exist.
  LogicalResult RewriteUnregisteredTFOps();

  // Inline the TFR func call ops.
  LogicalResult InlineTFRFuncCalls();

  // Optional external symbol table to look up the TFR function.
  llvm::Optional<ModuleOp> external_tfr_module_;
};
127 
128 #include "tensorflow/compiler/mlir/tfr/passes/generated_decompose.inc"
129 
ApplyCanonicalization()130 void DecomposeTFOpsPass::ApplyCanonicalization() {
131   func::FuncOp func = getOperation();
132   RewritePatternSet patterns(&getContext());
133 
134   populateWithGenerated(patterns);
135   populateCanonicalizationPatterns(func, patterns);
136 
137   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
138 }
139 
// Walks the function and rewrites every unregistered TF op into a tfr.call of
// its compose function, casting operands/results through TFR tensor types.
// Returns success iff at least one op was rewritten, so the caller's loop
// stops when no progress is made.
LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
  func::FuncOp func = getOperation();
  // Compose functions are resolved against the external TFR module when one
  // was supplied; otherwise against the module enclosing this function.
  SymbolTable table(external_tfr_module_.has_value()
                        ? *external_tfr_module_
                        : func->getParentOfType<ModuleOp>());
  OpBuilder builder(func);
  bool changed = false;
  func.walk([&table, &builder, &changed](Operation* op) {
    // Only the un-registered ops requires decomposition. The remaining ones
    // either will be constant folded or lowered by the rules defined in the
    // bridge.
    if (op->isRegistered()) {
      return WalkResult::advance();
    }

    // Find out the compose function
    auto compose_func_name = GetComposeFuncName(op->getName().getStringRef());
    auto compose_func = table.lookup<TFRFuncOp>(compose_func_name);
    if (!compose_func || compose_func.isExternal()) {
      // There are no decomposition methods defined for this op, skip.
      return WalkResult::advance();
    }

    // Make sure all the attributes are valid. An attribute is valid when it is
    // in the signature or it is allowed explicitly. The "_"-suffixed signature
    // function, if present, carries the full attribute list.
    auto compose_func_signature =
        table.lookup<TFRFuncOp>(compose_func_name + "_");
    if (!compose_func_signature) compose_func_signature = compose_func;
    auto defined_attrs = compose_func_signature.getDefinedAttributeNames();
    if (failed(ValidateAttrs(op, defined_attrs))) {
      // Invalid attribute: abort the whole walk so the pass fails.
      return WalkResult::interrupt();
    }

    // Record the expansion in the monitoring counter before rewriting.
    tensorflow::IncreaseOpExpansionExecuteCounterByOne(
        op->getName().getStringRef().str());

    auto compose_func_type = compose_func.getFunctionType();
    builder.setInsertionPoint(op);
    TFRTensorType unconstrainted_tensor_type = builder.getType<TFRTensorType>();

    // Create the new operands. This is mapping the operands from the target
    // TF ops to the TFR function arguments. If the TFR function argument is
    // a tensor_list, a "tfr.build_list" op is used to concat the available
    // TF op operands. If the TFR function argument isn't a tensor/tensor_list,
    // a constant is created by using the attribute stored in the TF op or the
    // default value in the argument attribute.
    llvm::SmallVector<Value, 4> new_operands;
    for (auto arg : llvm::enumerate(compose_func_type.getInputs())) {
      if (auto tensor_type = arg.value().dyn_cast<TFRTensorType>()) {
        // Plain tensor argument: cast the positional TF operand to TFR type.
        auto casted = builder.create<CastOp>(op->getLoc(), tensor_type,
                                             op->getOperand(arg.index()));
        new_operands.push_back(casted);
      } else if (auto list_type = arg.value().dyn_cast<TFRTensorListType>()) {
        // Tensor-list argument: consumes all remaining TF operands, so a
        // list argument is implicitly expected to be the last tensor input.
        llvm::SmallVector<Value, 4> variadic_operands;
        for (int i = arg.index(); i < op->getNumOperands(); i++) {
          auto casted = builder.create<CastOp>(
              op->getLoc(), unconstrainted_tensor_type, op->getOperand(i));
          variadic_operands.push_back(casted);
        }
        auto build_list_op = builder.create<BuildListOp>(
            op->getLoc(), list_type, variadic_operands);
        new_operands.push_back(build_list_op.out());
      } else {
        // Attribute argument: materialize it as a constant SSA value.
        // NOTE(review): assumes the TFR arg carries a tfr.name attribute;
        // attr_name would be null otherwise — confirm the TFR verifier
        // guarantees this for non-tensor arguments.
        auto attr_name = compose_func.getArgAttrOfType<StringAttr>(
            arg.index(), kAttrArgumentNameAttr);
        auto attribute = op->getAttr(attr_name.getValue());
        if (!attribute) {
          // Fall back to the default value recorded on the TFR argument.
          attribute =
              compose_func.getArgAttr(arg.index(), kAttrArgumentDefaultAttr);
        }
        if (!attribute && attr_name.getValue() == "out_type") {
          // Special case: derive the missing "out_type" from the op's first
          // result type (element type for ranked/unranked tensors).
          auto type = op->getResult(0).getType();
          if (type.isa<TensorType>()) {
            type = type.cast<TensorType>().getElementType();
          }
          attribute = TypeAttr::get(type);
        }
        Value attr_cst;
        // Wrap these special attributes as a special TFR constant, so the SSA
        // value has a valid type to be used as TFR function argument. These
        // attributes are not expected to be manipulated by the lowering passes.
        if (attribute.isa<TypeAttr>() || attribute.isa<ArrayAttr>() ||
            attribute.isa<StringAttr>() || attribute.isa<FlatSymbolRefAttr>()) {
          TFRAttrType output_type = TFRAttrType::get(builder.getContext());
          attr_cst =
              builder.create<ConstOp>(op->getLoc(), output_type, attribute);
        } else {
          // Numeric/bool attributes become ordinary arith constants.
          attr_cst =
              builder.create<mlir::arith::ConstantOp>(op->getLoc(), attribute);
        }
        new_operands.push_back(attr_cst);
      }
    }

    // Create the TFR call op
    auto new_op = builder.create<CallOp>(
        op->getLoc(), compose_func_type.getResults(),
        SymbolRefAttr::get(builder.getContext(), compose_func.getName()),
        new_operands);

    // Replace the use of the old op. This is mapping the results from the
    // target TF ops to the TFR function returns. If the TFR function return is
    // a tensor_list, "tfr.get_element" op is used to extract the required TF
    // op result.
    llvm::SmallVector<Value, 4> new_results;
    for (auto res : llvm::enumerate(compose_func_type.getResults())) {
      if (res.value().dyn_cast<TFRTensorType>()) {
        new_results.push_back(new_op.getResult(res.index()));
      } else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
        // A list result covers all remaining TF results; pull each element
        // out by a constant index.
        for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) {
          auto index = builder.create<mlir::arith::ConstantOp>(
              op->getLoc(), builder.getIndexAttr(j));
          auto element_op = builder.create<GetElementOp>(
              op->getLoc(), unconstrainted_tensor_type,
              new_op.getResult(res.index()), index.getResult());
          new_results.push_back(element_op.out());
        }
      }
    }
    // Cast each new result back to the original TF result type and rewire
    // all uses of the old op onto it.
    for (auto res : llvm::zip(op->getResults(), new_results)) {
      auto casted = builder.create<CastOp>(
          op->getLoc(), std::get<0>(res).getType(), std::get<1>(res));
      std::get<0>(res).replaceAllUsesWith(casted.out());
    }

    // Copy all the unregisted attributes to the new op.
    if (failed(CopyAllowedUnregisteredAttrs(op, new_op, defined_attrs))) {
      return WalkResult::interrupt();
    }

    op->erase();
    changed |= true;
    return WalkResult::advance();
  });

  // If `changed` is false, it is considered as a failure, so the recursive
  // rewrite will stop.
  return success(changed);
}
279 
// Inlines the body of each tfr.call op whose callee is a defined TFR function.
// Returns success iff at least one call was inlined; signals pass failure if
// the inliner itself fails on some call.
LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
  // The Inliner will automatically use the registered dialect inliner.
  InlinerInterface inliner(&getContext());
  func::FuncOp func = getOperation();
  // Callees are resolved against the external TFR module when supplied,
  // otherwise against the enclosing module.
  SymbolTable table(external_tfr_module_.has_value()
                        ? *external_tfr_module_
                        : func->getParentOfType<ModuleOp>());

  // The inliner only inlines the TFR call op.
  bool changed = false;
  auto walk_result = func.walk([&](CallOp call_op) {
    auto callee = table.lookup<TFRFuncOp>(call_op.callee());
    // External (declaration-only) callees have no body to inline; skip.
    if (!callee || callee.isExternal()) return WalkResult::advance();

    // Record the boundary of the inlined operations. The inlined operation will
    // be inserted between these two operations.
    Operation* inlined_point = call_op.getOperation();
    Operation* after_inlined_point =
        &*std::next(Block::iterator(call_op.getOperation()));

    // Use the inliner to replace all the uses of the call_op by its
    // composition.
    if (failed(inlineCall(inliner,
                          cast<CallOpInterface>(call_op.getOperation()),
                          cast<CallableOpInterface>(callee.getOperation()),
                          callee.getCallableRegion(),
                          /*shouldCloneInlinedRegion=*/true))) {
      // This failure is usually because the decompose function is not defined.
      // This call will be raised to TF ops.
      return WalkResult::interrupt();
    }

    // Propagate all the attributes to the inlined operations, which are defined
    // by the two boundary operations.
    PropagateAttrsToOperations(call_op, Block::iterator(inlined_point),
                               Block::iterator(after_inlined_point));

    // Remove the call_op to finish the op expansion.
    call_op.erase();
    changed |= true;
    return WalkResult::advance();
  });

  if (walk_result.wasInterrupted()) {
    signalPassFailure();
    return failure();
  }

  // If `changed` is false, it is considered as a failure, so the recursive
  // rewrite will stop.
  return success(changed);
}
332 
runOnOperation()333 void DecomposeTFOpsPass::runOnOperation() {
334   // Set a maximum iteration threshold in case there are infinite loops in the
335   // call stack.
336   int max_iterators = 10;
337   do {
338     // canonicalization
339     ApplyCanonicalization();
340 
341     // rewrite unregistered tf ops. Failed either because no ops can be
342     // decomposed or the compose function isn't defined.
343     auto rewrite_status = RewriteUnregisteredTFOps();
344     // inline the tfr call op until there are no tfr.call op can be inlined.
345     auto inline_status = InlineTFRFuncCalls();
346 
347     if (failed(rewrite_status) && failed(inline_status)) {
348       break;
349     }
350   } while (max_iterators-- >= 0);
351 }
352 
353 }  // namespace
354 
355 // Creates an instance of the pass to decompose the TF ops.
CreateDecomposeTFOpsPass(llvm::Optional<ModuleOp> tfr_module)356 std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeTFOpsPass(
357     llvm::Optional<ModuleOp> tfr_module) {
358   return std::make_unique<DecomposeTFOpsPass>(tfr_module);
359 }
360 
// Registers the pass under its "tfr-decompose" argument so mlir-opt-style
// drivers can run it. Constructed without an explicit TFR module here —
// presumably CreateDecomposeTFOpsPass defaults its argument; see passes.h.
static PassRegistration<DecomposeTFOpsPass> pass([] {
  return CreateDecomposeTFOpsPass();
});
364 
365 }  // namespace TFR
366 }  // namespace mlir
367