/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is a pass that reduces the operands of TFL::WhileOp without changing
// the outcome.

#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace TFL {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

struct ReduceWhileOperandsPass
    : public ReduceWhileOperandsPassBase<ReduceWhileOperandsPass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReduceWhileOperandsPass)

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<TFL::TensorFlowLiteDialect, TF::TensorFlowDialect>();
  }
  void runOnOperation() override;
};

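// Marks every loop-carried value that is (transitively) consumed. Starting
// from `explicitly_consumed_ids`, walks `dependency_graph` (an edge i -> j
// means result i depends on argument j) and sets `is_consumed_id` for every
// index reached. Fails if a consumed index is not present in the graph.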
LogicalResult FindImplicityProducers(
    const std::vector<uint64_t> &explicitly_consumed_ids,
    std::vector<bool> &is_consumed_id,
    const std::vector<std::vector<uint64_t>> &dependency_graph) {
  std::vector<uint64_t> queue;
  queue.reserve(is_consumed_id.size());
  for (auto id : explicitly_consumed_ids) {
    is_consumed_id[id] = true;
    queue.push_back(id);
  }
  while (!queue.empty()) {
    auto i = queue.back();
    queue.pop_back();

    // If a consumed index cannot be found in the dependency graph, fail.
    if (i >= dependency_graph.size()) {
      return failure();
    }

    for (auto j : dependency_graph.at(i)) {
      if (is_consumed_id[j]) continue;
      queue.push_back(j);
      is_consumed_id[j] = true;
    }
  }

  return success();
}

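// Collects, into `neighbors`, the indices of the block arguments that
// `start_node` transitively depends on by walking its use-def chain backwards.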
void FindProducers(Value start_node, std::vector<uint64_t> &neighbors) {
  llvm::DenseSet<Value> visited;
  std::vector<Value> queue;
  queue.push_back(start_node);
  visited.insert(start_node);
  while (!queue.empty()) {
    auto node = queue.back();
    queue.pop_back();
    if (auto arg = node.dyn_cast_or_null<BlockArgument>()) {
      neighbors.push_back(arg.getArgNumber());
      continue;
    }
    if (!node.getDefiningOp()) continue;
    for (Value operand : node.getDefiningOp()->getOperands()) {
      if (visited.contains(operand)) continue;
      queue.push_back(operand);
      visited.insert(operand);
    }
  }
}

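// Adds `start_op` and every operation it transitively depends on (through the
// defining ops of its operands) to `consumed_ops`.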
void FindConsumedOp(Operation *start_op,
                    llvm::DenseSet<Operation *> &consumed_ops) {
  if (consumed_ops.contains(start_op)) return;
  std::vector<Operation *> queue;
  queue.push_back(start_op);
  consumed_ops.insert(start_op);
  while (!queue.empty()) {
    auto op = queue.back();
    queue.pop_back();
    for (Value operand : op->getOperands()) {
      if (!operand.getDefiningOp()) continue;
      auto def_op = operand.getDefiningOp();
      if (consumed_ops.contains(def_op)) continue;
      queue.push_back(def_op);
      consumed_ops.insert(def_op);
    }
  }
}

inline bool IsConstant(Operation *op) { return matchPattern(op, m_Constant()); }

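// Returns true if every operation in `block` can be erased safely once it is
// unused: apart from the WhileOp itself and terminators, ops must have no
// memory effects and must not implicitly capture non-constant values defined
// outside their own block.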
bool AllOperationSafe(Block &block) {
  auto walk_result = block.walk([&](Operation *op) {
    // Op has side effects.
    if (!isa_and_nonnull<TFL::WhileOp>(op) &&
        !op->hasTrait<OpTrait::IsTerminator>() &&
        !MemoryEffectOpInterface::hasNoEffect(op)) {
      return WalkResult::interrupt();
    }
    // Op has implicit arguments not listed in its operands.
    // Fact: if every op's operands are defined in the same block as the op,
    //       then no operation has implicit arguments (constants don't count).
    for (auto operand : op->getOperands()) {
      if (operand.dyn_cast_or_null<BlockArgument>()) continue;
      auto operand_op = operand.getDefiningOp();
      if (IsConstant(operand_op)) continue;
      if (operand_op->getBlock() != op->getBlock()) {
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });
  return !walk_result.wasInterrupted();
}

// It reduces the following pattern:
//
// S = (0, 0, 0)
// while S[2] < 3:
//  a0 = S[0] * 2
//  a1 = a0 + S[1]
//  a2 = S[2] + 1
//  S = (a0, a1, a2)
// return S[0]
//
// The 2nd operand (i = 1) as well as its related op (a1 = a0 + S[1]) can be
// removed since only S[0] is returned.
// It cannot be removed by the loop-invariant-code-motion pass since every
// value is used and changed in the while loop.

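// After the reduction, only S[0] and S[2] stay loop-carried, so the loop
// above would become something like:
//
// S = (0, 0)
// while S[1] < 3:
//  a0 = S[0] * 2
//  a1 = S[1] + 1
//  S = (a0, a1)
// return S[0]
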
// Moreover, we require:
// 1. No implicit arguments: for every operation in the whileOp, all dependent
//    values (except for constants) are explicitly passed in.
// 2. No side effects: every operation inside the whileOp can be safely
//    removed when it is use_empty().
// 3. No function calls inside the while body.
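//
// Returns true if the while_op was erased or at least one of its operands was
// removed.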
bool ReduceWhileOperands(TFL::WhileOp while_op) {
  std::vector<uint64_t> explicitly_consumed_ids;
  Block &cond = while_op.cond().front();
  Block &body = while_op.body().front();

  auto n = while_op.getNumOperands();
  if (!AllOperationSafe(cond) || !AllOperationSafe(body)) return false;

  // Find all consumed indices.
  // Index i is consumed if result(i) is used outside the whileOp or
  // argument(i) is used in whileOp.cond().
  for (auto i = 0; i < n; ++i) {
    if (!while_op.getResult(i).use_empty() ||
        !cond.getArgument(i).use_empty()) {
      explicitly_consumed_ids.push_back(i);
    }
  }
  // An empty explicitly_consumed_ids implies that none of the results is used.
  if (explicitly_consumed_ids.empty()) {
    while_op.erase();
    return true;
  }
  // If every element is consumed, no operand can be reduced.
  if (explicitly_consumed_ids.size() == n) {
    return false;
  }

  // Build the dependency graph.
  // If result(i) depends on argument(j) in While.body(), then we put a
  // directed edge (i -> j) into the graph.
  std::vector<std::vector<uint64_t>> dependency_graph;
  dependency_graph.reserve(n);

  Operation &yield_op = body.back();
  auto results = yield_op.getOperands();
  for (auto i = 0; i < n; ++i) {
    std::vector<uint64_t> neighbors;
    neighbors.reserve(n);
    FindProducers(results[i], neighbors);
    dependency_graph.push_back(neighbors);
  }

  std::vector<bool> is_consumed_id(n, false);
  if (failed(FindImplicityProducers(explicitly_consumed_ids, is_consumed_id,
                                    dependency_graph))) {
    return false;
  }

  // Find all consumed operations in the while body.
  llvm::DenseSet<Operation *> consumed_ops;
  // We'll pass in erase_indices to erase several operands simultaneously.
  llvm::BitVector erase_indices(n);
  consumed_ops.insert(&yield_op);
  for (auto i = 0; i < n; ++i) {
    if (!is_consumed_id[i]) {
      erase_indices.set(i);
    } else if (results[i].getDefiningOp()) {
      FindConsumedOp(results[i].getDefiningOp(), consumed_ops);
    }
  }
  // Remove elements and operations in the while body that are not
  // indispensable.
  yield_op.eraseOperands(erase_indices);
  // Remove ops from bottom to top.
  for (Operation &op :
       llvm::make_early_inc_range(reverse(body.getOperations())))
    // Constants will not be removed in case they are implicitly used.
    if (!consumed_ops.contains(&op) && !IsConstant(&op)) {
      op.erase();
    }
  body.eraseArguments(erase_indices);
  cond.eraseArguments(erase_indices);

  llvm::SmallVector<Value> new_operands;
  llvm::SmallVector<Type> new_result_types;
  new_operands.reserve(n - erase_indices.count());
  new_result_types.reserve(n - erase_indices.count());
  // After reducing, the number of results is decreased. The i-th result of the
  // old WhileOp becomes the j-th (j <= i) result of the new WhileOp. This
  // information is stored in id_map (id_map[i] = j).
  std::vector<uint64_t> id_map(n, 0);
  uint64_t j = 0;
  for (auto i = 0; i < n; ++i) {
    if (is_consumed_id[i]) {
      id_map[i] = j++;
      new_operands.push_back(while_op.getOperand(i));
      new_result_types.push_back(while_op.getResultTypes()[i]);
    }
  }

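  // Build a new WhileOp with the reduced operand and result lists, then move
  // the (already shrunk) cond and body regions of the old WhileOp into it.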
  auto new_while_op = OpBuilder(while_op).create<WhileOp>(
      while_op.getLoc(), new_result_types, new_operands, while_op->getAttrs());
  new_while_op.cond().takeBody(while_op.cond());
  new_while_op.body().takeBody(while_op.body());

  for (auto i = 0; i < n; ++i) {
    if (!while_op.getResult(i).use_empty()) {
      auto j = id_map[i];
      while_op.getResult(i).replaceAllUsesWith(new_while_op.getResult(j));
    }
  }
  while_op.erase();
  return erase_indices.any();
}

void ReduceWhileOperandsPass::runOnOperation() {
  auto fn = getOperation();
  fn.walk([&](TFL::WhileOp while_op) { ReduceWhileOperands(while_op); });
}

static PassRegistration<ReduceWhileOperandsPass> pass;
}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> CreateReduceWhileOperandsPass() {
  return std::make_unique<ReduceWhileOperandsPass>();
}

}  // namespace TFL
}  // namespace mlir