xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/split_merged_operands.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <vector>
17 
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/StringMap.h"
20 #include "llvm/Support/Casting.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Block.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
28 #include "mlir/IR/Matchers.h"  // from @llvm-project
29 #include "mlir/IR/Operation.h"  // from @llvm-project
30 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
31 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
32 #include "mlir/IR/Types.h"  // from @llvm-project
33 #include "mlir/IR/Value.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
36 #include "mlir/Support/LLVM.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
39 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
40 #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
41 
// Background info:
// Currently the model passed to the MLIR converter is frozen: all variables
// have been converted to constants, all assign ops are gone, etc. However,
// TFLite has variable-tensor semantics, so the variable mapping from TF to
// TFLite is effectively broken here. As a workaround, we hard-code the
// variable tensors based on the ops that use them, such as
// unidirectional_sequence_lstm.
//
// The MLIR converter also benefits from many typical compiler optimizations,
// such as merging input values that are identical. These optimizations are
// desirable in general, but not for TFLite ops that take variable tensors as
// inputs. Such inputs may have identical values, but those values are
// "stateful": they can change between invocations.
//
// A typical example is unidirectional_sequence_lstm, which has two variable
// tensor inputs: activation state and cell state. They may have the same
// initial values (typically zero-initialized), but their values will change,
// so we cannot just merge them.
//
// This pass is more of a short-term workaround, since we don't have a good
// variable representation right now.
//
// This pass duplicates the input values that feed those variable tensor inputs.
64 
65 namespace mlir {
66 namespace TFL {
67 namespace {
68 #define GEN_PASS_CLASSES
69 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
70 
71 struct SplitMergedOperandsPass
72     : public SplitMergedOperandsPassBase<SplitMergedOperandsPass> {
73   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SplitMergedOperandsPass)
74 
75   void runOnOperation() override;
76 };
77 
DuplicateValueIfNeeded(Operation * op,llvm::DenseSet<Value> * values,OpBuilder * builder)78 LogicalResult DuplicateValueIfNeeded(Operation* op,
79                                      llvm::DenseSet<Value>* values,
80                                      OpBuilder* builder) {
81   std::vector<int> stateful_operands_index;
82   if (!IsStatefulOp(op, &stateful_operands_index)) return success();
83 
84   for (int index : stateful_operands_index) {
85     Value operand = op->getOperand(index);
86     auto inserted_value = values->insert(operand).second;
87     if (inserted_value) continue;
88     // We can only clone the constant op at this point.
89     // Since all ops have been legalized to tflite ops, so we only care about
90     // ConstOp or QConstOp or mlir constant op/
91     Operation* input_op = operand.getDefiningOp();
92     if (input_op == nullptr) return failure();
93 
94     Attribute attr;
95     if (!matchPattern(input_op, m_Constant(&attr))) {
96       op->emitError()
97           << "We cannot duplicate the value since it's not constant.\n";
98       return failure();
99     }
100     builder->setInsertionPoint(op);
101     Operation* duplicated_input_op = builder->clone(*input_op);
102 
103     // Rewire the inputs.
104     op->setOperand(index, duplicated_input_op->getResult(0));
105   }
106   return success();
107 }
108 
runOnOperation()109 void SplitMergedOperandsPass::runOnOperation() {
110   llvm::DenseSet<Value> stateful_values;
111   auto func = getOperation();
112   OpBuilder builder(func);
113   for (auto& bb : func.getBody()) {
114     for (auto& op : bb) {
115       if (failed(DuplicateValueIfNeeded(&op, &stateful_values, &builder))) {
116         func.emitError() << "Failed to duplicate values for the stateful op\n";
117         return signalPassFailure();
118       }
119     }
120   }
121 }
122 
123 }  // namespace
124 
125 /// Creates an instance of the TensorFlow Lite dialect SplitMergedOperands
126 /// pass.
CreateSplitMergedOperandsPass()127 std::unique_ptr<OperationPass<func::FuncOp>> CreateSplitMergedOperandsPass() {
128   return std::make_unique<SplitMergedOperandsPass>();
129 }
130 
131 }  // namespace TFL
132 }  // namespace mlir
133