/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is a pass that reduces the number of TFL::WhileOp operands (and the
// computations that produce them) without changing the loop's outcome.

#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace TFL {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

struct ReduceWhileOperandsPass
    : public ReduceWhileOperandsPassBase<ReduceWhileOperandsPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReduceWhileOperandsPass)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect, TF::TensorFlowDialect>();
  }
  void runOnOperation() override;
};

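// Marks every state element that is transitively required by the explicitly
// consumed elements: starting from explicitly_consumed_ids, walks the
// dependency graph and sets is_consumed_id[j] for every reachable id j.
// Fails if a consumed id has no entry in the dependency graph.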
LogicalResult FindImplicityProducers(
    const std::vector<uint64_t> &explicitly_consumed_ids,
    std::vector<bool> &is_consumed_id,
    const std::vector<std::vector<uint64_t>> &dependency_graph) {
  std::vector<uint64_t> queue;
  queue.reserve(is_consumed_id.size());
  for (auto id : explicitly_consumed_ids) {
    is_consumed_id[id] = true;
    queue.push_back(id);
  }
  while (!queue.empty()) {
    auto i = queue.back();
    queue.pop_back();

    // If a consumer cannot be found in the dependency graph, signal failure.
    if (i >= dependency_graph.size()) {
      return failure();
    }

    for (auto j : dependency_graph.at(i)) {
      if (is_consumed_id[j]) continue;
      queue.push_back(j);
      is_consumed_id[j] = true;
    }
  }

  return success();
}

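// Collects the argument numbers of all block arguments that start_node
// transitively depends on, by walking backwards through defining ops.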
void FindProducers(Value start_node, std::vector<uint64_t> &neighbors) {
  llvm::DenseSet<Value> visited;
  std::vector<Value> queue;
  queue.push_back(start_node);
  visited.insert(start_node);
  while (!queue.empty()) {
    auto node = queue.back();
    queue.pop_back();
    if (auto arg = node.dyn_cast_or_null<BlockArgument>()) {
      neighbors.push_back(arg.getArgNumber());
      continue;
    }
    if (!node.getDefiningOp()) continue;
    for (Value operand : node.getDefiningOp()->getOperands()) {
      if (visited.contains(operand)) continue;
      queue.push_back(operand);
      visited.insert(operand);
    }
  }
}

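// Adds start_op and every operation it transitively depends on (through the
// defining ops of its operands) to consumed_ops.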
void FindConsumedOp(Operation *start_op,
                    llvm::DenseSet<Operation *> &consumed_ops) {
  if (consumed_ops.contains(start_op)) return;
  std::vector<Operation *> queue;
  queue.push_back(start_op);
  consumed_ops.insert(start_op);
  while (!queue.empty()) {
    auto op = queue.back();
    queue.pop_back();
    for (Value operand : op->getOperands()) {
      if (!operand.getDefiningOp()) continue;
      auto def_op = operand.getDefiningOp();
      if (consumed_ops.contains(def_op)) continue;
      queue.push_back(def_op);
      consumed_ops.insert(def_op);
    }
  }
}

inline bool IsConstant(Operation *op) { return matchPattern(op, m_Constant()); }

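// Returns true iff every operation in the block is safe to remove once it
// becomes unused: no side effects (WhileOps and terminators excepted) and no
// implicit arguments, i.e. every non-constant operand is defined in the same
// block.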
bool AllOperationSafe(Block &block) {
  auto walk_result = block.walk([&](Operation *op) {
    // Op has a side effect.
    if (!isa_and_nonnull<TFL::WhileOp>(op) &&
        !op->hasTrait<OpTrait::IsTerminator>() &&
        !MemoryEffectOpInterface::hasNoEffect(op)) {
      return WalkResult::interrupt();
    }
    // Op has implicit arguments not listed in its operands.
    // Fact: if every op's operands are defined in the same block as op,
    // then no operation has implicit arguments (constants don't count).
    for (auto operand : op->getOperands()) {
      if (operand.dyn_cast_or_null<BlockArgument>()) continue;
      auto operand_op = operand.getDefiningOp();
      if (IsConstant(operand_op)) continue;
      if (operand_op->getBlock() != op->getBlock()) {
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  return !walk_result.wasInterrupted();
}


// It reduces the following pattern:
//
// S = (0, 0, 0)
// while S[2] < 3:
//   a0 = S[0] * 2
//   a1 = a0 + S[1]
//   a2 = S[2] + 1
//   S = (a0, a1, a2)
// return S[0]
//
// The 2nd operand (i = 1), as well as its related op (a1 = a0 + S[1]),
// can be removed since only S[0] is returned.
// It cannot be removed by the loop-invariant-code-motion pass, since every
// value is both used and changed in the while loop.
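//
// After the reduction, the loop above becomes:
//
// S = (0, 0)
// while S[1] < 3:
//   a0 = S[0] * 2
//   a2 = S[1] + 1
//   S = (a0, a2)
// return S[0]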

// Moreover, we require that
// 1. there are no implicit arguments: for every operation in the whileOp, all
//    dependent values (except for constants) are explicitly passed in;
// 2. there are no side effects: every operation inside the whileOp can be
//    safely removed when it is use_empty();
// 3. there are no function calls inside the while body.
bool ReduceWhileOperands(TFL::WhileOp while_op) {
  std::vector<uint64_t> explicitly_consumed_ids;
  Block &cond = while_op.cond().front();
  Block &body = while_op.body().front();

  auto n = while_op.getNumOperands();
  if (!AllOperationSafe(cond) || !AllOperationSafe(body)) return false;

  // Find all consumed indices.
  // i is a consumed element if result(i) is used outside the whileOp or
  // argument(i) is used in whileOp.cond().
  for (auto i = 0; i < n; ++i) {
    if (!while_op.getResult(i).use_empty() ||
        !cond.getArgument(i).use_empty()) {
      explicitly_consumed_ids.push_back(i);
    }
  }
  // An empty explicitly_consumed_ids implies that none of the results is used.
  if (explicitly_consumed_ids.empty()) {
    while_op.erase();
    return true;
  }
  // If every element is consumed, no operand can be reduced.
  if (explicitly_consumed_ids.size() == n) {
    return false;
  }

  // Build the dependency graph.
  // If result(i) depends on argument(j) in While.body(), then we put a
  // directed edge (i -> j) into the graph.
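  // For the example above, the edges are 0 -> {0}, 1 -> {0, 1}, and
  // 2 -> {2}: a0 reads S[0]; a1 reads a0 (hence S[0]) and S[1]; a2 reads S[2].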
  std::vector<std::vector<uint64_t>> dependency_graph;
  dependency_graph.reserve(n);

  Operation &yield_op = body.back();
  auto results = yield_op.getOperands();
  for (auto i = 0; i < n; ++i) {
    std::vector<uint64_t> neighbors;
    neighbors.reserve(n);
    FindProducers(results[i], neighbors);
    dependency_graph.push_back(neighbors);
  }

  std::vector<bool> is_consumed_id(n, false);
  if (failed(FindImplicityProducers(explicitly_consumed_ids, is_consumed_id,
                                    dependency_graph))) {
    return false;
  }

233
234 // Find all consumed operations in while body.
235 llvm::DenseSet<Operation *> consumed_ops;
236 // We'll pass in the erase_indices to erase several operands simultaneously.
237 llvm::BitVector erase_indices(n);
238 consumed_ops.insert(&yield_op);
239 for (auto i = 0; i < n; ++i) {
240 if (!is_consumed_id[i]) {
241 erase_indices.set(i);
242 } else if (results[i].getDefiningOp()) {
243 FindConsumedOp(results[i].getDefiningOp(), consumed_ops);
244 }
245 }
246 // Remove elements and operations in while_body that are not indispensable.
247 yield_op.eraseOperands(erase_indices);
248 // Remove ops from bottom to top.
249 for (Operation &op :
250 llvm::make_early_inc_range(reverse(body.getOperations())))
251 // Constant will not be removed in case it is implicitly used.
252 if (!consumed_ops.contains(&op) && !IsConstant(&op)) {
253 op.erase();
254 }
255 body.eraseArguments(erase_indices);
256 cond.eraseArguments(erase_indices);
257
  llvm::SmallVector<Value> new_operands;
  llvm::SmallVector<Type> new_result_types;
  new_operands.reserve(n - erase_indices.count());
  new_result_types.reserve(n - erase_indices.count());
  // After reducing, the number of results is decreased. The i-th result of the
  // old WhileOp becomes the j-th (j <= i) result of the new WhileOp. This
  // information is stored in id_map (id_map[i] = j).
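  // For the example above, is_consumed_id = {true, false, true}, so
  // id_map[0] = 0 and id_map[2] = 1 (id_map[1] stays 0 and is never read).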
  std::vector<uint64_t> id_map(n, 0);
  uint64_t j = 0;
  for (auto i = 0; i < n; ++i) {
    if (is_consumed_id[i]) {
      id_map[i] = j++;
      new_operands.push_back(while_op.getOperand(i));
      new_result_types.push_back(while_op.getResultTypes()[i]);
    }
  }

  auto new_while_op = OpBuilder(while_op).create<WhileOp>(
      while_op.getLoc(), new_result_types, new_operands, while_op->getAttrs());
  new_while_op.cond().takeBody(while_op.cond());
  new_while_op.body().takeBody(while_op.body());

  for (auto i = 0; i < n; ++i) {
    if (!while_op.getResult(i).use_empty()) {
      auto j = id_map[i];
      while_op.getResult(i).replaceAllUsesWith(new_while_op.getResult(j));
    }
  }
  while_op.erase();
  return erase_indices.any();
}

void ReduceWhileOperandsPass::runOnOperation() {
  auto fn = getOperation();
  fn.walk([&](TFL::WhileOp while_op) { ReduceWhileOperands(while_op); });
}

static PassRegistration<ReduceWhileOperandsPass> pass;
}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> CreateReduceWhileOperandsPass() {
  return std::make_unique<ReduceWhileOperandsPass>();
}

}  // namespace TFL
}  // namespace mlir