1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
22
23 #include "absl/memory/memory.h"
24 #include "absl/strings/string_view.h"
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/None.h"
27 #include "llvm/ADT/Optional.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
35 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
36 #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
37 #include "mlir/IR/Attributes.h" // from @llvm-project
38 #include "mlir/IR/Builders.h" // from @llvm-project
39 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
41 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
42 #include "mlir/IR/MLIRContext.h" // from @llvm-project
43 #include "mlir/IR/SymbolTable.h" // from @llvm-project
44 #include "mlir/IR/Value.h" // from @llvm-project
45 #include "mlir/IR/Visitors.h" // from @llvm-project
46 #include "mlir/Pass/Pass.h" // from @llvm-project
47 #include "mlir/Support/LLVM.h" // from @llvm-project
48 #include "mlir/Support/LogicalResult.h" // from @llvm-project
49 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
50 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
51 #include "mlir/Transforms/InliningUtils.h" // from @llvm-project
52 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
53 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
54 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
55 #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h"
56 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
57 #include "tensorflow/compiler/mlir/tfr/utils/utils.h"
58 #include "tensorflow/core/lib/monitoring/counter.h"
59
namespace tensorflow {
namespace {

// Streamz counter tracking how many composite ops have been expanded,
// partitioned by the op name (the single "name" label).
auto* tf_core_op_expansion_op_counter =
    monitoring::Counter<1>::New("/tensorflow/core/op_expansion/op_counter",
                                "The number of composite op expanded.", "name");
}  // namespace

// Increments the op-expansion counter cell for `op_name` by one. Called each
// time a composite op is rewritten to its TFR composition below.
void IncreaseOpExpansionExecuteCounterByOne(const std::string& op_name) {
  tf_core_op_expansion_op_counter->GetCell(op_name)->IncrementBy(1);
}

}  // namespace tensorflow
73
74 //===----------------------------------------------------------------------===//
75 // The pass to decompose unregistered TF ops with the TFR compose function.
76 //
77 namespace mlir {
78 namespace TFR {
79
80 namespace {
81
82 // Quantize the float value based on given scale and zero point attributes.
Quantize(float value,Attribute scale_attr,Attribute zp_attr,OpBuilder builder)83 Attribute Quantize(float value, Attribute scale_attr, Attribute zp_attr,
84 OpBuilder builder) {
85 double scale = scale_attr.cast<FloatAttr>().getValueAsDouble();
86 int64_t zp = zp_attr.cast<IntegerAttr>().getInt();
87
88 int quantized = static_cast<int>(std::round(value / scale) + zp);
89 quantized =
90 std::min(quantized, static_cast<int>(std::numeric_limits<int8_t>::max()));
91 quantized =
92 std::max(quantized, static_cast<int>(std::numeric_limits<int8_t>::min()));
93 return builder.getI32IntegerAttr(quantized);
94 }
95
// Decompose the TF ops with the registered composition library.
class DecomposeTFOpsPass
    : public PassWrapper<DecomposeTFOpsPass, OperationPass<func::FuncOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DecomposeTFOpsPass)

  // `external_tfr_module` optionally supplies the module holding the TFR
  // compose functions; when absent, they are looked up in the module that
  // encloses the function being transformed.
  explicit DecomposeTFOpsPass(llvm::Optional<ModuleOp> external_tfr_module)
      : external_tfr_module_(external_tfr_module) {}

  // Command-line name under which this pass is registered.
  StringRef getArgument() const final { return "tfr-decompose"; }

  StringRef getDescription() const final {
    return "Decompose TF ops with the registered composition library.";
  }

  // Iterates canonicalize -> rewrite -> inline until no more progress is
  // made (or an iteration cap is reached).
  void runOnOperation() override;

 private:
  // Apply canonicalization, mainly constant folding, on the function.
  void ApplyCanonicalization();

  // Rewrite unregistered TF ops to TFR func call ops. Return failure if all the
  // ops are registered or the compose function doesn't exist.
  LogicalResult RewriteUnregisteredTFOps();

  // Inline the TFR func call ops.
  LogicalResult InlineTFRFuncCalls();

  // Optional external symbol table to look up the TFR function.
  llvm::Optional<ModuleOp> external_tfr_module_;
};
127
128 #include "tensorflow/compiler/mlir/tfr/passes/generated_decompose.inc"
129
ApplyCanonicalization()130 void DecomposeTFOpsPass::ApplyCanonicalization() {
131 func::FuncOp func = getOperation();
132 RewritePatternSet patterns(&getContext());
133
134 populateWithGenerated(patterns);
135 populateCanonicalizationPatterns(func, patterns);
136
137 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
138 }
139
// Walks the function and replaces every un-registered TF op that has a TFR
// compose function with a tfr.call to that function, casting operands and
// results across the TF/TFR type boundary. Returns success iff at least one
// op was rewritten.
LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
  func::FuncOp func = getOperation();
  // Resolve compose functions from the external TFR module when one was
  // provided; otherwise from the module enclosing this function.
  SymbolTable table(external_tfr_module_.has_value()
                        ? *external_tfr_module_
                        : func->getParentOfType<ModuleOp>());
  OpBuilder builder(func);
  bool changed = false;
  func.walk([&table, &builder, &changed](Operation* op) {
    // Only the un-registered ops requires decomposition. The remaining ones
    // either will be constant folded or lowered by the rules defined in the
    // bridge.
    if (op->isRegistered()) {
      return WalkResult::advance();
    }

    // Find out the compose function
    auto compose_func_name = GetComposeFuncName(op->getName().getStringRef());
    auto compose_func = table.lookup<TFRFuncOp>(compose_func_name);
    if (!compose_func || compose_func.isExternal()) {
      // There are no decomposition methods defined for this op, skip.
      return WalkResult::advance();
    }

    // Make sure all the attributes are valid. An attribute is valid when it is
    // in the signature or it is allowed explicitly. The "<name>_" variant,
    // when present, carries the signature used for attribute validation;
    // otherwise the compose function itself is used.
    auto compose_func_signature =
        table.lookup<TFRFuncOp>(compose_func_name + "_");
    if (!compose_func_signature) compose_func_signature = compose_func;
    auto defined_attrs = compose_func_signature.getDefinedAttributeNames();
    if (failed(ValidateAttrs(op, defined_attrs))) {
      return WalkResult::interrupt();
    }

    // Record the expansion in the monitoring counter, keyed by op name.
    tensorflow::IncreaseOpExpansionExecuteCounterByOne(
        op->getName().getStringRef().str());

    auto compose_func_type = compose_func.getFunctionType();
    builder.setInsertionPoint(op);
    TFRTensorType unconstrainted_tensor_type = builder.getType<TFRTensorType>();

    // Create the new operands. This is mapping the operands from the target
    // TF ops to the TFR function arguments. If the TFR function argument is
    // a tensor_list, a "tfr.build_list" op is used to concat the available
    // TF op operands. If the TFR function argument isn't a tensor/tensor_list,
    // a constant is created by using the attribute stored in the TF op or the
    // default value in the argument attribute.
    llvm::SmallVector<Value, 4> new_operands;
    for (auto arg : llvm::enumerate(compose_func_type.getInputs())) {
      if (auto tensor_type = arg.value().dyn_cast<TFRTensorType>()) {
        // Single-tensor argument: cast the matching TF operand to the
        // constrained TFR tensor type.
        auto casted = builder.create<CastOp>(op->getLoc(), tensor_type,
                                             op->getOperand(arg.index()));
        new_operands.push_back(casted);
      } else if (auto list_type = arg.value().dyn_cast<TFRTensorListType>()) {
        // Tensor-list argument: pack all remaining TF operands into one
        // tfr.build_list (so a tensor_list argument consumes the tail of
        // the operand list).
        llvm::SmallVector<Value, 4> variadic_operands;
        for (int i = arg.index(); i < op->getNumOperands(); i++) {
          auto casted = builder.create<CastOp>(
              op->getLoc(), unconstrainted_tensor_type, op->getOperand(i));
          variadic_operands.push_back(casted);
        }
        auto build_list_op = builder.create<BuildListOp>(
            op->getLoc(), list_type, variadic_operands);
        new_operands.push_back(build_list_op.out());
      } else {
        // Attribute argument: take the value from the TF op attribute, or
        // fall back to the default stored on the compose function argument.
        auto attr_name = compose_func.getArgAttrOfType<StringAttr>(
            arg.index(), kAttrArgumentNameAttr);
        auto attribute = op->getAttr(attr_name.getValue());
        if (!attribute) {
          attribute =
              compose_func.getArgAttr(arg.index(), kAttrArgumentDefaultAttr);
        }
        // Special case: a missing "out_type" is derived from the element
        // type of the op's first result.
        if (!attribute && attr_name.getValue() == "out_type") {
          auto type = op->getResult(0).getType();
          if (type.isa<TensorType>()) {
            type = type.cast<TensorType>().getElementType();
          }
          attribute = TypeAttr::get(type);
        }
        Value attr_cst;
        // Wrap these special attributes as a special TFR constant, so the SSA
        // value has a valid type to be used as TFR function argument. These
        // attributes are not expected to be manipulated by the lowering passes.
        // NOTE(review): `attribute` can still be null here (no op attr, no
        // default, and name != "out_type"); the isa<> calls below would then
        // assert. Presumably generated compose functions always carry a
        // default for such arguments — confirm.
        if (attribute.isa<TypeAttr>() || attribute.isa<ArrayAttr>() ||
            attribute.isa<StringAttr>() || attribute.isa<FlatSymbolRefAttr>()) {
          TFRAttrType output_type = TFRAttrType::get(builder.getContext());
          attr_cst =
              builder.create<ConstOp>(op->getLoc(), output_type, attribute);
        } else {
          attr_cst =
              builder.create<mlir::arith::ConstantOp>(op->getLoc(), attribute);
        }
        new_operands.push_back(attr_cst);
      }
    }

    // Create the TFR call op
    auto new_op = builder.create<CallOp>(
        op->getLoc(), compose_func_type.getResults(),
        SymbolRefAttr::get(builder.getContext(), compose_func.getName()),
        new_operands);

    // Replace the use of the old op. This is mapping the results from the
    // target TF ops to the TFR function returns. If the TFR function return is
    // a tensor_list, "tfr.get_element" op is used to extract the required TF
    // op result.
    llvm::SmallVector<Value, 4> new_results;
    for (auto res : llvm::enumerate(compose_func_type.getResults())) {
      if (res.value().dyn_cast<TFRTensorType>()) {
        new_results.push_back(new_op.getResult(res.index()));
      } else if (auto list_type = res.value().dyn_cast<TFRTensorListType>()) {
        // Unpack a tensor_list result: each remaining TF op result maps to
        // one element extracted by index.
        for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) {
          auto index = builder.create<mlir::arith::ConstantOp>(
              op->getLoc(), builder.getIndexAttr(j));
          auto element_op = builder.create<GetElementOp>(
              op->getLoc(), unconstrainted_tensor_type,
              new_op.getResult(res.index()), index.getResult());
          new_results.push_back(element_op.out());
        }
      }
    }
    // Cast each new result back to the original TF result type before
    // replacing all uses of the old op's results.
    for (auto res : llvm::zip(op->getResults(), new_results)) {
      auto casted = builder.create<CastOp>(
          op->getLoc(), std::get<0>(res).getType(), std::get<1>(res));
      std::get<0>(res).replaceAllUsesWith(casted.out());
    }

    // Copy all the unregisted attributes to the new op.
    if (failed(CopyAllowedUnregisteredAttrs(op, new_op, defined_attrs))) {
      return WalkResult::interrupt();
    }

    op->erase();
    changed |= true;
    return WalkResult::advance();
  });

  // If `changed` is false, it is considered as a failure, so the recursive
  // rewrite will stop.
  return success(changed);
}
279
// Inlines every tfr.call whose callee is a defined (non-external) TFR
// function. Returns success iff at least one call was inlined; signals pass
// failure if the inliner itself fails on any call.
LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
  // The Inliner will automatically use the registered dialect inliner.
  InlinerInterface inliner(&getContext());
  func::FuncOp func = getOperation();
  // Resolve callees from the external TFR module when one was provided;
  // otherwise from the module enclosing this function.
  SymbolTable table(external_tfr_module_.has_value()
                        ? *external_tfr_module_
                        : func->getParentOfType<ModuleOp>());

  // The inliner only inlines the TFR call op.
  bool changed = false;
  auto walk_result = func.walk([&](CallOp call_op) {
    auto callee = table.lookup<TFRFuncOp>(call_op.callee());
    if (!callee || callee.isExternal()) return WalkResult::advance();

    // Record the boundary of the inlined operations. The inlined operation will
    // be inserted between these two operations.
    Operation* inlined_point = call_op.getOperation();
    Operation* after_inlined_point =
        &*std::next(Block::iterator(call_op.getOperation()));

    // Use the inliner to replace all the uses of the call_op by its
    // composition. The callee region is cloned so the original TFR function
    // stays intact for later calls.
    if (failed(inlineCall(inliner,
                          cast<CallOpInterface>(call_op.getOperation()),
                          cast<CallableOpInterface>(callee.getOperation()),
                          callee.getCallableRegion(),
                          /*shouldCloneInlinedRegion=*/true))) {
      // This failure is usually because the decompose function is not defined.
      // This call will be raised to TF ops.
      return WalkResult::interrupt();
    }

    // Propagate all the attributes to the inlined operations, which are defined
    // by the two boundary operations.
    PropagateAttrsToOperations(call_op, Block::iterator(inlined_point),
                               Block::iterator(after_inlined_point));

    // Remove the call_op to finish the op expansion.
    call_op.erase();
    changed |= true;
    return WalkResult::advance();
  });

  if (walk_result.wasInterrupted()) {
    signalPassFailure();
    return failure();
  }

  // If `changed` is false, it is considered as a failure, so the recursive
  // rewrite will stop.
  return success(changed);
}
332
runOnOperation()333 void DecomposeTFOpsPass::runOnOperation() {
334 // Set a maximum iteration threshold in case there are infinite loops in the
335 // call stack.
336 int max_iterators = 10;
337 do {
338 // canonicalization
339 ApplyCanonicalization();
340
341 // rewrite unregistered tf ops. Failed either because no ops can be
342 // decomposed or the compose function isn't defined.
343 auto rewrite_status = RewriteUnregisteredTFOps();
344 // inline the tfr call op until there are no tfr.call op can be inlined.
345 auto inline_status = InlineTFRFuncCalls();
346
347 if (failed(rewrite_status) && failed(inline_status)) {
348 break;
349 }
350 } while (max_iterators-- >= 0);
351 }
352
353 } // namespace
354
355 // Creates an instance of the pass to decompose the TF ops.
CreateDecomposeTFOpsPass(llvm::Optional<ModuleOp> tfr_module)356 std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeTFOpsPass(
357 llvm::Optional<ModuleOp> tfr_module) {
358 return std::make_unique<DecomposeTFOpsPass>(tfr_module);
359 }
360
// Registers the pass so pipelines can select it by its "tfr-decompose"
// argument. The factory calls CreateDecomposeTFOpsPass() with no arguments,
// which presumably relies on a defaulted `tfr_module` parameter declared in
// passes.h (not visible here) — confirm before changing that signature.
static PassRegistration<DecomposeTFOpsPass> pass([] {
  return CreateDecomposeTFOpsPass();
});
364
365 } // namespace TFR
366 } // namespace mlir
367