/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// This pass "determines" the best combination of hardware choices for the
// whole "graph".
//
// Assume we have a graph that looks like below:
//    subgraph1 (CPU/GPU)    subgraph2 (CPU)
//          \                  /
//     subgraph3 (CPU/GPU)   subgraph4 (CPU/GPU)
//            |               /
//       subgraph5 (CPU/GPU)
//            |
//       subgraph6 (CPU)
//
// We want to evaluate the possible options and minimize the overall cost to
// produce a graph like below:
//
//    subgraph1 (GPU)    subgraph2 (CPU)
//          \               /
//     subgraph3 (GPU)   subgraph4 (GPU)
//            |            /
//       subgraph5 (GPU)
//            |
//       subgraph6 (CPU)
//
// The overall workflow of the pick subgraphs pass:
//   1) Build subgraphs
//     1.1) Collect output subgraphs.
//     1.2) Build `Subgraph` and their "alternative view" from FuncOp.
//   2) Pick subgraphs
//     2.1) Populate the "dp table" for (subgraph, hardware).
//     2.2) Make decisions based on the populated dp table.
//     2.3) Rewire the whole graph based on the decisions.
//
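// As an illustration (a hypothetical module shape with made-up names; the
// exact form is produced by earlier TAC passes), the input to this pass is a
// module where each "conceptual" subgraph is a call op annotated with an
// interface name, and each alternative implementation is a FuncOp annotated
// with a device and an inference type:
//
//   func.func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> {
//     %0 = func.call @func_0_GPU_FLOAT(%arg0) {tac.interface_name = "func_0"}
//         : (tensor<1xf32>) -> tensor<1xf32>
//     func.return %0 : tensor<1xf32>
//   }
//   func.func @func_0_GPU_FLOAT(...) attributes {tac.device = "GPU",
//       tac.inference_type = "FLOAT", tac.interface_name = "func_0"} {...}
//   func.func @func_0_CPU_FLOAT(...) attributes {tac.device = "CPU",
//       tac.inference_type = "FLOAT", tac.interface_name = "func_0"} {...}
//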
namespace mlir {
namespace TFL {
namespace tac {
namespace {

// GraphView is used to hold the aggregated cost for the given hardware
// view.
struct GraphView {
  float total_cost;
  std::unordered_map<Operation*, InferenceDeviceType> input_subgraph_plans;
};

// Subgraph holds the "conceptual" subgraph.
// A subgraph may be associated with 1...n FuncOps, and each FuncOp may target
// a different hardware.
struct Subgraph {
  // The call can be thought of as an "API".
  func::CallOp call;

  // available_choices can be viewed as the "real implementations" associated
  // with the hardware.
  std::unordered_map<InferenceDeviceType, func::FuncOp,
                     InferenceDeviceType::inference_device_type_hash>
      available_choices;

  // This will include self (the subgraph itself).
  //   subgraphn
  //      |
  //   current_subgraph  <- aggregated cost
  std::unordered_map<InferenceDeviceType, GraphView,
                     InferenceDeviceType::inference_device_type_hash>
      aggregated_cost_with_decisions;
};
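
// For example (illustrative names only): a subgraph whose interface is
// "func_0" with a float CPU impl and a float GPU impl would hold
//   available_choices:
//     {CPU, FLOAT} -> @func_0_CPU_FLOAT
//     {GPU, FLOAT} -> @func_0_GPU_FLOAT
// and, once processed, aggregated_cost_with_decisions maps each choice to the
// cheapest aggregated cost of this subgraph plus everything feeding into it.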

// If the output is produced by a CallOp, returns the CallOp; otherwise,
// returns nullptr.
inline func::CallOp GetProducerCallOpOrNull(Value output) {
  Operation* output_op = output.getDefiningOp();
  if (output_op != nullptr && llvm::isa<func::CallOp>(output_op)) {
    return llvm::cast<func::CallOp>(output_op);
  }
  return nullptr;
}

class PickSubgraphsPass
    : public mlir::PassWrapper<PickSubgraphsPass,
                               mlir::OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PickSubgraphsPass)

 private:
  llvm::StringRef getArgument() const final { return "tfl-pick-subgraphs"; }
  llvm::StringRef getDescription() const final {
    return "Pick the best subgraphs to minimize the overall total costs.";
  }
  void runOnOperation() override;

  std::unordered_map<std::string, std::vector<func::FuncOp>>
  CollectSubgraphFuncs(ModuleOp module);

  void BuildSubgraphs(
      func::FuncOp main_fn,
      const std::unordered_map<std::string, std::vector<func::FuncOp>>&
          func_impls,
      llvm::SetVector<Operation*>* unprocessed_subgraphs,
      SmallVector<func::CallOp, 4>* output_subgraphs);

  void ProcessSubgraph(func::CallOp current_graph,
                       llvm::SetVector<Operation*>* unprocessed_subgraphs);

  bool PickSubgraphs(
      llvm::SetVector<Operation*>* all_subgraphs,
      ArrayRef<func::CallOp> output_subgraphs,
      const std::unordered_map<std::string, std::vector<func::FuncOp>>&
          collected_impl_funcs,
      OpBuilder* builder);

  // Makes the decisions based on the subgraphs.
  // It may be the case that we cannot decide the best scenario for the user;
  // in this case, we just return false.
  bool MakeDecisions(ArrayRef<func::CallOp> output_subgraphs);

  // Rewires the subgraphs based on the decisions made.
  // If we cannot make a decision, we just don't do anything.
  // TODO(renjieliu): we may change the vector to a map of hardware with the
  // corresponding impl.
  void RewireSubgraphs(
      const std::unordered_map<std::string, std::vector<func::FuncOp>>&
          collected_impl_funcs,
      OpBuilder* builder);

  float GetCostOrFail(func::FuncOp func);

  llvm::DenseMap<Operation*, Subgraph> subgraphs_;

  llvm::DenseMap<Operation*, InferenceDeviceType> decisions_;
};

float PickSubgraphsPass::GetCostOrFail(func::FuncOp func) {
  float self_cost;
  if (!GetCostOnOp(func, &self_cost)) {
    func.emitError("we cannot find cost for this func");
    signalPassFailure();
  }
  return self_cost;
}

// Here we choose a greedy dynamic-programming-based algorithm for simplicity.
//
// See the following graph:
//
//   input_subgraph_1  ....  input_subgraph_n
//          \                   /
//           \                 /
//            current_subgraph
//             /      |      \
//
// Assume all the input subgraphs of the current subgraph are independent, and
// we already have optimal results for all the input subgraphs. Then the
// current_subgraph's aggregated optimal cost from the target perspective is
// simply:
//   for target in current_subgraph.supported_targets:
//     total_cost = 0
//     for input_subgraph in current_subgraph.input_subgraphs:
//       input_cost = kInfinity
//       for input_target in input_subgraph.supported_targets:
//         # cost = aggregated cost for input_subgraph with transfer cost.
//         input_cost = min(input_cost, cost)
//       total_cost += input_cost
//     total_cost += current_subgraph.get_computation_cost(target)
//
// Note: the case where the input subgraphs are not independent is a little
// more complicated to handle in the dp. A potential approach is to resolve
// conflicts only where they "happened".
//
// The above idea should probably be revisited and expanded into a more
// careful design.
// TODO(renjieliu): We may revisit this later.
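//
// As a tiny worked example (illustrative numbers only): suppose the current
// subgraph supports {CPU, GPU} with compute costs {CPU: 5, GPU: 2}, its single
// input subgraph has aggregated costs {CPU: 10, GPU: 6}, the CPU<->GPU
// transfer cost is 3, and there is no quant/dequant cost. Then:
//   current on GPU: input via CPU = 10 + 3 = 13, via GPU = 6 + 0 = 6
//                   -> total = 6 + 2 = 8
//   current on CPU: input via CPU = 10 + 0 = 10, via GPU = 6 + 3 = 9
//                   -> total = 9 + 5 = 14
// so the dp table records cost 8 for GPU and 14 for CPU at this subgraph.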
void PickSubgraphsPass::ProcessSubgraph(
    func::CallOp current_graph_call,
    llvm::SetVector<Operation*>* unprocessed_subgraphs) {
  Subgraph& current_subgraph = subgraphs_.find(current_graph_call)->second;

  std::vector<Subgraph*> input_subgraphs;
  for (auto input : current_graph_call.getOperands()) {
    func::CallOp input_call = GetProducerCallOpOrNull(input);
    if (input_call == nullptr) continue;

    // If the input subgraph is not processed yet, we just go ahead and process
    // that one first.
    if (unprocessed_subgraphs->count(input_call) > 0) {
      unprocessed_subgraphs->remove(input_call);
      ProcessSubgraph(input_call, unprocessed_subgraphs);
    }
    Subgraph& input_subgraph = subgraphs_.find(input_call)->second;
    input_subgraphs.push_back(&input_subgraph);
  }

  // Find the best plan for the current subgraph.
  for (const auto& kv : current_subgraph.available_choices) {
    const auto& current_inference_device_type = kv.first;
    func::FuncOp impl_target = kv.second;
    float self_compute_cost = GetCostOrFail(impl_target);

    GraphView current_graph_view;
    auto& input_subgraph_plans = current_graph_view.input_subgraph_plans;

    float inputs_total_costs = 0.0;
    for (Subgraph* input_subgraph : input_subgraphs) {
      float input_total_cost = std::numeric_limits<float>::max();
      for (const auto& input_kv : input_subgraph->available_choices) {
        const auto& input_inference_device_type = input_kv.first;
        func::FuncOp input_impl_target = input_kv.second;
        float input_compute_cost = GetCostOrFail(input_impl_target);

        float transfer_cost =
            GetTransferCost(input_inference_device_type.hardware,
                            current_inference_device_type.hardware,
                            input_subgraph->call, current_graph_call);
        float quant_dequant_cost =
            GetQuantDequantCost(input_inference_device_type.inference_type,
                                current_inference_device_type.inference_type,
                                input_subgraph->call, current_graph_call);
        float summed_cost =
            transfer_cost + quant_dequant_cost + input_compute_cost;

        if (summed_cost < input_total_cost) {
          // Looks like this hardware is better for this input_subgraph; let's
          // take it.
          input_total_cost = summed_cost;
          input_subgraph_plans[input_subgraph->call] =
              input_inference_device_type;
        }
      }  // for every hardware of input_subgraph
      inputs_total_costs += input_total_cost;
    }  // for every input_subgraph
    current_graph_view.total_cost = inputs_total_costs + self_compute_cost;
    current_subgraph
        .aggregated_cost_with_decisions[current_inference_device_type] =
        current_graph_view;
  }  // for every hardware of current_subgraph
}

void PickSubgraphsPass::BuildSubgraphs(
    func::FuncOp fn,
    const std::unordered_map<std::string, std::vector<func::FuncOp>>&
        func_impls,
    llvm::SetVector<Operation*>* unprocessed_subgraphs,
    SmallVector<func::CallOp, 4>* output_subgraphs) {
  llvm::DenseSet<Operation*> returned_call_op_set;
  // Collect all the returns from the main function first; the outputs of the
  // main function are the final outputs.
  // main_func:
  //   %output1 = call @subgraph_1...
  //   ...
  //   %output2 = call @subgraph_m...
  //   ...
  //   %outputn = call @subgraph_k...
  //   return %output1, %output2, ..., %outputn
  fn.walk([&](func::ReturnOp return_op) {
    for (auto output : return_op.getOperands()) {
      func::CallOp output_call = GetProducerCallOpOrNull(output);
      if (output_call != nullptr) {
        returned_call_op_set.insert(output_call);
      }
    }
  });

  // Each call op is actually the entry to a subgraph.
  fn.walk([&](func::CallOp call_op) {
    auto interface_name = GetInterFaceName(call_op);
    // We only care about the call ops that have an interface_name.
    if (!interface_name.has_value()) return;

    unprocessed_subgraphs->insert(call_op);

    // Build the subgraph.
    Subgraph subgraph;
    subgraph.call = call_op;
    auto impl_iter = func_impls.find(interface_name.getValue());
    if (impl_iter == func_impls.end()) {
      call_op.emitError(
          "we cannot find corresponding implementation for this call op");
      signalPassFailure();
    }

    for (auto impl : impl_iter->second) {
      auto inference_device_type = GetInferenceDeviceTypeForOp(impl);
      if (!inference_device_type.has_value()) {
        impl.emitError("we cannot find inference device type for this func");
        signalPassFailure();
      }
      subgraph.available_choices.emplace(inference_device_type.getValue(),
                                         impl);
    }

    // Insert into the subgraphs.
    subgraphs_.try_emplace(call_op, subgraph);

    // If it's an output subgraph, add it to the output_subgraphs.
    if (returned_call_op_set.find(call_op) != returned_call_op_set.end()) {
      output_subgraphs->push_back(call_op);
    }
  });
}

// Collect all the subgraphs (and their alternatives) in the module.
std::unordered_map<std::string, std::vector<func::FuncOp>>
PickSubgraphsPass::CollectSubgraphFuncs(ModuleOp module) {
  std::unordered_map<std::string, std::vector<func::FuncOp>> func_impls;
  for (auto func : module.getOps<func::FuncOp>()) {
    auto interface_name = GetInterFaceName(func);
    if (interface_name.has_value()) {
      auto impls_iter = func_impls.find(interface_name.getValue());
      if (impls_iter == func_impls.end())
        impls_iter =
            func_impls
                .emplace(interface_name.getValue(), std::vector<func::FuncOp>())
                .first;
      impls_iter->second.push_back(func);
    }
  }
  return func_impls;
}

// Given the final outputs, evaluates the overall costs and picks the best
// plan. If we cannot make a decision, nothing changes; we just fall back to
// the original plan.
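//
// For example (a hypothetical scenario): if two output subgraphs share an
// input subgraph, and the cheapest plan for one output wants that input on
// CPU while the cheapest plan for the other wants it on GPU, the two
// decisions conflict and this returns false.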
bool PickSubgraphsPass::MakeDecisions(ArrayRef<func::CallOp> output_subgraphs) {
  // BFS to make decisions.
  std::queue<const GraphView*> processing_queue;
  for (func::CallOp output : output_subgraphs) {
    const GraphView* preferred_graph_view = nullptr;
    float minimum_cost = std::numeric_limits<float>::max();

    const Subgraph& subgraph = subgraphs_.find(output)->second;
    for (const auto& kv : subgraph.aggregated_cost_with_decisions) {
      if (minimum_cost > kv.second.total_cost) {
        minimum_cost = kv.second.total_cost;
        preferred_graph_view = &kv.second;
        decisions_[output] = kv.first;
      }
    }

    processing_queue.push(preferred_graph_view);
  }

  // If we see a conflict, we just abort.
  while (!processing_queue.empty()) {
    const GraphView* current = processing_queue.front();
    processing_queue.pop();
    for (const auto& input_with_plans : current->input_subgraph_plans) {
      func::CallOp input = llvm::cast<func::CallOp>(input_with_plans.first);
      const InferenceDeviceType& input_decision = input_with_plans.second;
      auto made_input_decision_it = decisions_.find(input);
      if (made_input_decision_it == decisions_.end()) {
        // Input is not processed yet.
        // Let's process it, and also push it to the queue.
        decisions_[input] = input_decision;
        const Subgraph& input_subgraph = subgraphs_.find(input)->second;
        const GraphView& input_subgraph_view =
            input_subgraph.aggregated_cost_with_decisions.find(input_decision)
                ->second;
        processing_queue.push(&input_subgraph_view);
      } else if (made_input_decision_it->second != input_decision) {
        // We see a conflict, so we need to abort.
        return false;
      }
    }
  }
  return true;
}

// Rewiring the subgraphs is essentially "hooking up" each call op with its
// "best" choice of implementation (subgraph).
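//
// For instance (illustrative names only): if the decision for interface
// "func_0" is {GPU, FLOAT} but the call currently targets the CPU impl,
//   %0 = func.call @func_0_CPU_FLOAT(%arg0) {tac.interface_name = "func_0", ...}
// is rebuilt as
//   %0 = func.call @func_0_GPU_FLOAT(%arg0) {tac.interface_name = "func_0",
//       tac.device = "GPU", tac.inference_type = "FLOAT", ...}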
void PickSubgraphsPass::RewireSubgraphs(
    const std::unordered_map<std::string, std::vector<func::FuncOp>>&
        collected_impl_funcs,
    OpBuilder* builder) {
  for (auto& kv : decisions_) {
    func::CallOp call = llvm::cast<func::CallOp>(kv.first);

    const InferenceDeviceType& preferred_inference_device_type = kv.second;

    // We need to rewire the call.
    std::string interface_name = GetInterFaceName(call).getValue();
    for (auto impl : collected_impl_funcs.find(interface_name)->second) {
      const auto& impl_inference_device_type =
          GetInferenceDeviceTypeForOp(impl);
      if (impl_inference_device_type.getValue() ==
          preferred_inference_device_type) {
        if (call.getCallee() != impl.getName()) {
          // We need to rebuild the call op. :(
          builder->setInsertionPoint(call);
          auto new_call = builder->create<func::CallOp>(call.getLoc(), impl,
                                                        call.getOperands());

          // Set the interface_name & target on the new call op as well.
          new_call->setAttr(kInterfaceNameAttr,
                            builder->getStringAttr(interface_name));
          new_call->setAttr(
              kDevice,
              builder->getStringAttr(preferred_inference_device_type.hardware));
          new_call->setAttr(
              kInferenceType,
              builder->getStringAttr(GetInferenceString(
                  preferred_inference_device_type.inference_type)));

          call.replaceAllUsesWith(new_call.getResults());
          call.erase();
        }
      }
    }
  }
}

bool PickSubgraphsPass::PickSubgraphs(
    llvm::SetVector<Operation*>* all_subgraphs,
    ArrayRef<func::CallOp> output_subgraphs,
    const std::unordered_map<std::string, std::vector<func::FuncOp>>&
        collected_impl_funcs,
    OpBuilder* builder) {
  // Process the collected unprocessed subgraphs.
  //
  // Algorithm complexity:
  // The complexity should be O(edges * specs^2), and we should expect the
  // number of specs to be small.
  // In the future, the specs can be hardwares x inference_types, where
  // the hardwares can be {CPU, GPU, DSP, EDGE_TPU} and
  // the inference_types can be {float, Q_INT8, float16}.
  // But still, we should expect the number of specs to be small.
  //
  // The process is essentially evaluating the accumulated cost in the dp table
  // for all the subgraphs (and their alternatives).
  while (!all_subgraphs->empty()) {
    func::CallOp current_subgraph =
        llvm::cast<func::CallOp>(all_subgraphs->front());
    all_subgraphs->remove(current_subgraph);
    ProcessSubgraph(current_subgraph, all_subgraphs);
  }

  // Make decisions given the "outputs" and the populated dp table, hoping to
  // achieve a global minimum.
  if (!MakeDecisions(output_subgraphs)) {
    return false;
  }

  // Once the decisions have been made, start from the outputs, go backwards,
  // and apply the plan.
  RewireSubgraphs(collected_impl_funcs, builder);

  return true;
}

void PickSubgraphsPass::runOnOperation() {
  auto module = getOperation();
  // Collect & build the subgraphs.
  // Also collect the output subgraphs.
  // Output subgraphs are essentially those subgraphs pointed to by the return
  // op.
  const std::unordered_map<std::string, std::vector<func::FuncOp>> func_impls =
      CollectSubgraphFuncs(module);
  llvm::SetVector<Operation*> unprocessed_subgraphs;
  SmallVector<func::CallOp, 4> output_subgraphs;

  for (auto fn : module.getOps<func::FuncOp>()) {
    BuildSubgraphs(fn, func_impls, &unprocessed_subgraphs, &output_subgraphs);
  }
  OpBuilder builder(module);
  if (!PickSubgraphs(&unprocessed_subgraphs, output_subgraphs, func_impls,
                     &builder)) {
    module.emitWarning(
        "we cannot find the best scenarios for your case, so we just use "
        "your original model plans");
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreatePickSubgraphsPass() {
  return std::make_unique<PickSubgraphsPass>();
}

static PassRegistration<PickSubgraphsPass> pass;

}  // namespace tac
}  // namespace TFL
}  // namespace mlir