1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <limits>
17 #include <memory>
18 #include <queue>
19 #include <string>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/DenseSet.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
31 #include "mlir/IR/Attributes.h"  // from @llvm-project
32 #include "mlir/IR/Block.h"  // from @llvm-project
33 #include "mlir/IR/Builders.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
36 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
37 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
41 #include "mlir/Support/LLVM.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h"
43 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
44 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
45 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
46 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h"
47 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
48 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
49 
50 // This pass is used to "determine" the best combinations of the whole "graph".
51 //
// Assume we have a graph that looks like the one below:
53 //    subgraph1 (CPU/GPU)    subgraph2 (CPU)
54 //      \                     /
55 //      subgraph3 (CPU/GPU)     subgraph4 (CPU/GPU)
56 //         |                  /
57 //      subgraph5 (CPU/GPU)
58 //         |
59 //      subgraph6 (CPU)
60 //
// We want to evaluate the possible options and minimize the overall costs to
// produce a graph like below:
63 //
64 //    subgraph1 (GPU)   subgraph2(CPU)
65 //       \              /
66 //     subgraph3 (GPU)      subgraph4(GPU)
67 //         |             /
68 //      subgraph5 (GPU)
69 //         |
70 //      subgraph6 (CPU)
71 //
72 // The overall workflow of the pick subgraphs pass:
73 //  1) Build subgraphs
74 //    1.1) Collect output subgraphs.
75 //    1.2) Build `Subgraph` and their "alternative view" from FuncOp.
76 //  2) Pick subgraphs
77 //    2.1) Populate the "dp table" for (subgraph, hardware).
78 //    2.2) Make decisions based on the populated dp table.
//    2.3) Rewire the whole graph based on the decisions.
80 //
81 namespace mlir {
82 namespace TFL {
83 namespace tac {
84 namespace {
85 
86 // GrapView is used to hold the aggregated cost for the given hardware
87 // view.
88 struct GraphView {
89   float total_cost;
90   std::unordered_map<Operation*, InferenceDeviceType> input_subgraph_plans;
91 };
92 
93 // Subgraph is to hold the "conceptual" subgraph.
94 // A subgraph may associate with 1...n FuncOp, and each FuncOp may correspond
95 // with different hardwares.
96 struct Subgraph {
97   // The call can be thought as an "API".
98   func::CallOp call;
99 
100   // available_choces can be viewed as "real implementation" assosicated with
101   // the hardware.
102   std::unordered_map<InferenceDeviceType, func::FuncOp,
103                      InferenceDeviceType::inference_device_type_hash>
104       available_choices;
105 
106   // This will include self (the subgraph itself).
107   // subgraphn
108   //    |
109   // current_subgraph   <- aggregated cost
110   std::unordered_map<InferenceDeviceType, GraphView,
111                      InferenceDeviceType::inference_device_type_hash>
112       aggregated_cost_with_decisions;
113 };
114 
115 // If the output is produced by a callop, will return the callop, otherwise,
116 // will return nullptr.
GetProducerCallOpOrNull(Value output)117 inline func::CallOp GetProducerCallOpOrNull(Value output) {
118   Operation* output_op = output.getDefiningOp();
119   if (output_op != nullptr && llvm::isa<func::CallOp>(output_op)) {
120     return llvm::cast<func::CallOp>(output_op);
121   }
122   return nullptr;
123 }
124 
125 class PickSubgraphsPass
126     : public mlir::PassWrapper<PickSubgraphsPass,
127                                mlir::OperationPass<ModuleOp>> {
128  public:
129   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PickSubgraphsPass)
130 
131  private:
getArgument() const132   llvm::StringRef getArgument() const final { return "tfl-pick-subgraphs"; }
getDescription() const133   llvm::StringRef getDescription() const final {
134     return "Pick the best subgraphs to minimize the overall total costs.";
135   }
136   void runOnOperation() override;
137 
138   std::unordered_map<std::string, std::vector<func::FuncOp>>
139   CollectSubgraphFuncs(ModuleOp module);
140 
141   void BuildSubgraphs(
142       func::FuncOp main_fn,
143       const std::unordered_map<std::string, std::vector<func::FuncOp>>&
144           func_impls,
145       llvm::SetVector<Operation*>* unprocessed_subgraphs,
146       SmallVector<func::CallOp, 4>* output_subgraphs);
147 
148   void ProcessSubgraph(func::CallOp current_graph,
149                        llvm::SetVector<Operation*>* unprocessed_subgraphs);
150 
151   bool PickSubgraphs(
152       llvm::SetVector<Operation*>* all_subgraphs,
153       ArrayRef<func::CallOp> output_subgraphs,
154       const std::unordered_map<std::string, std::vector<func::FuncOp>>&
155           collected_impl_funcs,
156       OpBuilder* builder);
157 
158   // Make the decisions based on the subgraphs.
159   // It may be the case we cannot decide the best scenarios for the user,
160   // in this case, we just return false.
161   bool MakeDecisions(ArrayRef<func::CallOp> output_subgraphs);
162 
163   // Rewire the subgraphs based on the decisions made.
164   // If we cannot make a decisions, we just don't do anything.
165   // TODO(renjieliu): we may change the vector to a map of hardware with
166   // corresponding ipml.
167   void RewireSubgraphs(
168       const std::unordered_map<std::string, std::vector<func::FuncOp>>&
169           collected_impl_funcs,
170       OpBuilder* builder);
171 
172   float GetCostOrFail(func::FuncOp func);
173 
174   llvm::DenseMap<Operation*, Subgraph> subgraphs_;
175 
176   llvm::DenseMap<Operation*, InferenceDeviceType> decisions_;
177 };
178 
GetCostOrFail(func::FuncOp func)179 float PickSubgraphsPass::GetCostOrFail(func::FuncOp func) {
180   float self_cost;
181   if (!GetCostOnOp(func, &self_cost)) {
182     func.emitError("we cannot find cost for this func");
183     signalPassFailure();
184   }
185   return self_cost;
186 }
187 
188 // Here we choose to do a greedy dynamic programming based algorithm for
189 // simplicity.
190 //
191 // See the following graph:
192 //
193 //    input_subgraph_1      ....      input_subgraph_n
194 //              \                          /
195 //               \                        /
196 //                   current_subgraph
197 //                      /     |      \
198 //
199 // Assume all the input subgraphs of the current subgraph are independent.
200 // If we already got optimal results for all the input subgraphs.
201 // Then the current_subgraph's aggregated optimal costs with regards to target
202 // perspective is simply:
203 //     for target in current_subgraph.supported_targets:
204 //       total_cost = 0
205 //       for input_subgraph in current_subgraph.input_subgraphs:
206 //         input_cost = kInfinity
207 //         for input_target in input_subgraphs.upported_targets:
208 //           # cost = aggregated cost for input_subgraph with transfer cost.
209 //           input_cost = min(input_cost, cost)
210 //         total_cost += input_cost
211 //       total_cost += current_subgraph.get_computation_cost(target)
212 //
213 // Note: for input subgraphs are not independent case, the dp case it a little
214 // bit complicated to handle. A potential thought is resolve only where
215 // conflict "happened".
216 //
217 // The above mentioned thought should probably be revisited for better thought
218 // or expanded more for more careful design.
219 // TODO(renjieliu): We may revisit this later.
ProcessSubgraph(func::CallOp current_graph_call,llvm::SetVector<Operation * > * unprocessed_subgraphs)220 void PickSubgraphsPass::ProcessSubgraph(
221     func::CallOp current_graph_call,
222     llvm::SetVector<Operation*>* unprocessed_subgraphs) {
223   Subgraph& current_subgraph = subgraphs_.find(current_graph_call)->second;
224 
225   std::vector<Subgraph*> input_subgraphs;
226   for (auto input : current_graph_call.getOperands()) {
227     func::CallOp input_call = GetProducerCallOpOrNull(input);
228     // If the input subgraph is not processed yet, we just go ahead and process
229     // that one first.
230     if (input_call == nullptr) continue;
231 
232     if (unprocessed_subgraphs->count(input_call) > 0) {
233       unprocessed_subgraphs->remove(input_call);
234       ProcessSubgraph(input_call, unprocessed_subgraphs);
235     }
236     Subgraph& input_subgraph = subgraphs_.find(input_call)->second;
237     input_subgraphs.push_back(&input_subgraph);
238   }
239 
240   // Find the best plan for the current subgraph.
241   for (const auto& kv : current_subgraph.available_choices) {
242     const auto& current_inference_device_type = kv.first;
243     func::FuncOp impl_target = kv.second;
244     float self_compute_cost = GetCostOrFail(impl_target);
245 
246     GraphView current_graph_view;
247     auto& input_subgraph_plans = current_graph_view.input_subgraph_plans;
248 
249     float inputs_total_costs = 0.0;
250     for (Subgraph* input_subgraph : input_subgraphs) {
251       float input_total_cost = std::numeric_limits<float>::max();
252       for (const auto& input_kv : input_subgraph->available_choices) {
253         const auto& input_inference_device_type = input_kv.first;
254         func::FuncOp input_impl_target = input_kv.second;
255         float input_compute_cost = GetCostOrFail(input_impl_target);
256 
257         float transfer_cost =
258             GetTransferCost(input_inference_device_type.hardware,
259                             current_inference_device_type.hardware,
260                             input_subgraph->call, current_graph_call);
261         float quant_dequant_cost =
262             GetQuantDequantCost(input_inference_device_type.inference_type,
263                                 current_inference_device_type.inference_type,
264                                 input_subgraph->call, current_graph_call);
265         float summed_cost =
266             transfer_cost + quant_dequant_cost + input_compute_cost;
267 
268         if (summed_cost < input_total_cost) {
269           // Looks this hardware is better for this input_subgraph, let's change
270           // it.
271           input_total_cost = summed_cost;
272           input_subgraph_plans[input_subgraph->call] =
273               input_inference_device_type;
274         }
275       }  // for every hardware of input_subgraph
276       inputs_total_costs += input_total_cost;
277     }  // for every input_subgraph
278     current_graph_view.total_cost = inputs_total_costs + self_compute_cost;
279     current_subgraph
280         .aggregated_cost_with_decisions[current_inference_device_type] =
281         current_graph_view;
282   }  // for every subgraph
283 }
284 
BuildSubgraphs(func::FuncOp fn,const std::unordered_map<std::string,std::vector<func::FuncOp>> & func_impls,llvm::SetVector<Operation * > * unprocessed_subgraphs,SmallVector<func::CallOp,4> * output_subgraphs)285 void PickSubgraphsPass::BuildSubgraphs(
286     func::FuncOp fn,
287     const std::unordered_map<std::string, std::vector<func::FuncOp>>&
288         func_impls,
289     llvm::SetVector<Operation*>* unprocessed_subgraphs,
290     SmallVector<func::CallOp, 4>* output_subgraphs) {
291   llvm::DenseSet<Operation*> returned_call_op_set;
292   // Collect all returns first from the main function.
293   // all the outputs of the main function are actually the final outputs.
294   // main_func:
295   //  %output1 = call @subgraph_1...
296   //   ...
297   //  %output2 = call @subgraph_m...
298   //   ...
299   //  %outputn = call @subgraph_k...
300   //  return %output1, output2, ..., outputn.
301   fn.walk([&](func::ReturnOp return_op) {
302     for (auto output : return_op.getOperands()) {
303       func::CallOp output_call = GetProducerCallOpOrNull(output);
304       if (output_call != nullptr) {
305         returned_call_op_set.insert(output_call);
306       }
307     }
308   });
309 
310   // Each call op actually is the entry of the subgraph.
311   fn.walk([&](func::CallOp call_op) {
312     auto interface_name = GetInterFaceName(call_op);
313     // we only need to care about the call ops those have interface_name.
314     if (!interface_name.has_value()) return;
315 
316     unprocessed_subgraphs->insert(call_op);
317 
318     // Build the subgraph.
319     Subgraph subgraph;
320     subgraph.call = call_op;
321     auto impl_iter = func_impls.find(interface_name.getValue());
322     if (impl_iter == func_impls.end()) {
323       call_op.emitError(
324           "we cannot find corresponding implementation for this call op");
325       signalPassFailure();
326     }
327 
328     for (auto impl : impl_iter->second) {
329       auto inference_device_type = GetInferenceDeviceTypeForOp(impl);
330       if (!inference_device_type.has_value()) {
331         impl.emitError("we cannot find inference device type for this func");
332         signalPassFailure();
333       }
334       subgraph.available_choices.emplace(inference_device_type.getValue(),
335                                          impl);
336     }
337 
338     // Insert in the subgraphs.
339     subgraphs_.try_emplace(call_op, subgraph);
340 
341     // If it's an output subgraph, we will add to the output_subgraphs.
342     if (returned_call_op_set.find(call_op) != returned_call_op_set.end()) {
343       output_subgraphs->push_back(call_op);
344     }
345   });
346 }
347 
348 // Collect all the subgraphs (and their alternatives) in the module.
349 std::unordered_map<std::string, std::vector<func::FuncOp>>
CollectSubgraphFuncs(ModuleOp module)350 PickSubgraphsPass::CollectSubgraphFuncs(ModuleOp module) {
351   std::unordered_map<std::string, std::vector<func::FuncOp>> func_impls;
352   for (auto func : module.getOps<func::FuncOp>()) {
353     auto interface_name = GetInterFaceName(func);
354     if (interface_name.has_value()) {
355       auto impls_iter = func_impls.find(interface_name.getValue());
356       if (impls_iter == func_impls.end())
357         impls_iter =
358             func_impls
359                 .emplace(interface_name.getValue(), std::vector<func::FuncOp>())
360                 .first;
361       impls_iter->second.push_back(func);
362     }
363   }
364   return func_impls;
365 }
366 
367 // Given the final outputs, evaluate on the overall costs and pick the best
368 // plan, if we cannot make a decision, nothing would change, just fallback
369 // to the original plan.
MakeDecisions(ArrayRef<func::CallOp> output_subgraphs)370 bool PickSubgraphsPass::MakeDecisions(ArrayRef<func::CallOp> output_subgraphs) {
371   // BFS to make decisions.
372   std::queue<const GraphView*> processing_queue;
373   for (func::CallOp output : output_subgraphs) {
374     const GraphView* preferred_graph_view;
375     float minimum_cost = std::numeric_limits<float>::max();
376 
377     const Subgraph& subgraph = subgraphs_.find(output)->second;
378     for (const auto& kv : subgraph.aggregated_cost_with_decisions) {
379       if (minimum_cost > kv.second.total_cost) {
380         minimum_cost = kv.second.total_cost;
381         preferred_graph_view = &kv.second;
382         decisions_[output] = kv.first;
383       }
384     }
385 
386     processing_queue.push(preferred_graph_view);
387   }
388 
389   // If we see conflict, we will just abort.
390   while (!processing_queue.empty()) {
391     const GraphView* current = processing_queue.front();
392     processing_queue.pop();
393     for (const auto& input_with_plans : current->input_subgraph_plans) {
394       func::CallOp input = llvm::cast<func::CallOp>(input_with_plans.first);
395       const InferenceDeviceType& input_decision = input_with_plans.second;
396       auto made_input_decision_it = decisions_.find(input);
397       if (made_input_decision_it == decisions_.end()) {
398         // Input is not processed.
399         // Let's process it, also push it to the queue.
400         decisions_[input] = input_decision;
401         const Subgraph& input_subgraph = subgraphs_.find(input)->second;
402         const GraphView& input_subgraph_view =
403             input_subgraph.aggregated_cost_with_decisions.find(input_decision)
404                 ->second;
405         processing_queue.push(&input_subgraph_view);
406       } else if (made_input_decision_it->second != input_decision) {
407         // We see confliction, we need to abort.
408         return false;
409       }
410     }
411   }
412   return true;
413 }
414 
415 // This rewire subgraph is essentially "hook" the call op with the "best" choice
416 // (subgraph).
RewireSubgraphs(const std::unordered_map<std::string,std::vector<func::FuncOp>> & collected_impl_funcs,OpBuilder * builder)417 void PickSubgraphsPass::RewireSubgraphs(
418     const std::unordered_map<std::string, std::vector<func::FuncOp>>&
419         collected_impl_funcs,
420     OpBuilder* builder) {
421   for (auto& kv : decisions_) {
422     func::CallOp call = llvm::cast<func::CallOp>(kv.first);
423 
424     const InferenceDeviceType& preferred_inference_device_type = kv.second;
425 
426     // We need to rewire the call.
427     std::string interface_name = GetInterFaceName(call).getValue();
428     for (auto impl : collected_impl_funcs.find(interface_name)->second) {
429       const auto& impl_inference_device_type =
430           GetInferenceDeviceTypeForOp(impl);
431       if (impl_inference_device_type.getValue() ==
432           preferred_inference_device_type) {
433         if (call.getCallee() != impl.getName()) {
434           // We need to rebuild the call op. :(
435           builder->setInsertionPoint(call);
436           auto new_call = builder->create<func::CallOp>(call.getLoc(), impl,
437                                                         call.getOperands());
438 
439           // Set interface_name & target to the call_op as well.
440           new_call->setAttr(kInterfaceNameAttr,
441                             builder->getStringAttr(interface_name));
442           new_call->setAttr(
443               kDevice,
444               builder->getStringAttr(preferred_inference_device_type.hardware));
445           new_call->setAttr(
446               kInferenceType,
447               builder->getStringAttr(GetInferenceString(
448                   preferred_inference_device_type.inference_type)));
449 
450           call.replaceAllUsesWith(new_call.getResults());
451           call.erase();
452         }
453       }
454     }
455   }
456 }
457 
PickSubgraphs(llvm::SetVector<Operation * > * all_subgraphs,ArrayRef<func::CallOp> output_subgraphs,const std::unordered_map<std::string,std::vector<func::FuncOp>> & collected_impl_funcs,OpBuilder * builder)458 bool PickSubgraphsPass::PickSubgraphs(
459     llvm::SetVector<Operation*>* all_subgraphs,
460     ArrayRef<func::CallOp> output_subgraphs,
461     const std::unordered_map<std::string, std::vector<func::FuncOp>>&
462         collected_impl_funcs,
463     OpBuilder* builder) {
464   // Process those collected unprocessed subgraphs.
465   //
466   // Algorithm complexity for this:
467   // This Complexity should be O(edge * specs ^ 2).
468   // We should expect the specs to be a small number.
469   // In future, the spesc can be Hardwares x inference_types
470   // The Hardware can be {CPU, GPU, DSP, EDGE_TPU}
471   // The inference_types can be {float, Q_INT8, float16}.
472   // But still, we should expect the specs to be a small number.
473   //
474   // The process is essentially evaluating the accumulated cost for the dp table
475   // for all the subgraphs (and their alternatives).
476   while (!all_subgraphs->empty()) {
477     func::CallOp current_subgraph =
478         llvm::cast<func::CallOp>(all_subgraphs->front());
479     all_subgraphs->remove(current_subgraph);
480     ProcessSubgraph(current_subgraph, all_subgraphs);
481   }
482 
483   // Make decisions given the "outputs" and the populated dp table.
484   // This is hoping to achieve a global minimum.
485   if (!MakeDecisions(output_subgraphs)) {
486     return false;
487   }
488 
489   // Once the design has been made.
490   // Start from the outputs and go back and checkout the plan.
491   RewireSubgraphs(collected_impl_funcs, builder);
492 
493   return true;
494 }
495 
runOnOperation()496 void PickSubgraphsPass::runOnOperation() {
497   auto module = getOperation();
498   // Collect & build the subgraphs.
499   // Also collect the output subgraphs.
500   // Output subgraphs are essentially those subgraphs pointed by the return
501   // op.
502   const std::unordered_map<std::string, std::vector<func::FuncOp>> func_impls =
503       CollectSubgraphFuncs(module);
504   llvm::SetVector<Operation*> unprocessed_subgraphs;
505   SmallVector<func::CallOp, 4> output_subgraphs;
506 
507   for (auto fn : module.getOps<func::FuncOp>()) {
508     BuildSubgraphs(fn, func_impls, &unprocessed_subgraphs, &output_subgraphs);
509   }
510   OpBuilder builder(module);
511   if (!PickSubgraphs(&unprocessed_subgraphs, output_subgraphs, func_impls,
512                      &builder)) {
513     module.emitWarning(
514         "we cannot find the best scenarios for your case, so we just use "
515         "your original model plans");
516   }
517 }
518 
519 }  // namespace
520 
CreatePickSubgraphsPass()521 std::unique_ptr<OperationPass<ModuleOp>> CreatePickSubgraphsPass() {
522   return std::make_unique<PickSubgraphsPass>();
523 }
524 
// File-local PassRegistration object for PickSubgraphsPass.
// NOTE(review): presumably this is what registers the pass under its
// "tfl-pick-subgraphs" argument at static-initialization time; confirm
// against mlir/Pass/PassRegistry.h.
static PassRegistration<PickSubgraphsPass> pass;
526 
527 }  // namespace tac
528 }  // namespace TFL
529 }  // namespace mlir
530