/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// Given the function interface name and the InferenceDeviceType, return the
// new function name.
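// For example, an interface name "simple" targeting GPU with FLOAT inference
// would yield something like "simple_GPU_FLOAT" (the exact suffix comes from
// GetInferenceString).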
std::string GetFunctionImplName(
    std::string interface_name,
    const InferenceDeviceType& device_inference_type) {
  return absl::StrCat(interface_name, "_", device_inference_type.hardware, "_",
                      GetInferenceString(device_inference_type.inference_type));
}

// For every device, we will do the following:
// If the inference type is quantized, we will try the float alternative.
// If it's float, we will just keep it as it is.
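// For example, devices {"CPU", "GPU"} with QUANTIZED_INT8 would produce
// {CPU, QUANTIZED_INT8}, {CPU, FLOAT}, {GPU, QUANTIZED_INT8}, {GPU, FLOAT}.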
std::vector<InferenceDeviceType> GetAllAlternativeInferenceDeviceType(
    InferenceType inference_type, ArrayRef<std::string> devices) {
  std::vector<InferenceDeviceType> all_device_inference_types;
  for (const auto& device : devices) {
    if (inference_type == QUANTIZED_INT8) {
      all_device_inference_types.push_back({device, QUANTIZED_INT8});
    } else if (inference_type == QUANTIZED_UINT8) {
      all_device_inference_types.push_back({device, QUANTIZED_UINT8});
    }

    // We will always enable float.
    all_device_inference_types.push_back({device, FLOAT});
  }

  return all_device_inference_types;
}

// This pass will try to get an alternative subgraph:
// Say a subgraph is annotated with CPU (it probably means the ops it contains
// cannot be run on other devices):
//
// We will try:
// 1) To apply some mathematically equivalent transformations so this
//    subgraph can be run on other devices.
// 2) To apply device-specific optimizations as well, which may include
//    tensor layout transformations, device-specific fusion, etc.
class AlternativeSubgraphPass
    : public mlir::PassWrapper<AlternativeSubgraphPass,
                               mlir::OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AlternativeSubgraphPass)

  llvm::StringRef getArgument() const final {
    return "tfl-get-alternative-subgraph";
  }
  llvm::StringRef getDescription() const final {
    return "Get alternative subgraph representation (if applicable) for all "
           "the given devices; the CPU implementation is included by default.";
  }
  AlternativeSubgraphPass() = default;
  AlternativeSubgraphPass(const AlternativeSubgraphPass&) {}
  explicit AlternativeSubgraphPass(llvm::ArrayRef<std::string> device_specs) {
    device_specs_flag_ = device_specs;
  }

 private:
  void runOnOperation() override;

  // Given a func and targeted devices, we will try to clone the func &
  // transform/optimize for those devices.
  // This will only happen if the whole subgraph can be supported by the target
  // or can be supported after some transformations.
  void GetAlternativeGraphForFunc(ArrayRef<std::string> devices,
                                  func::FuncOp func, ModuleOp module,
                                  OpBuilder* builder);

  // If all ops in the func op can be represented by the hardware, we return
  // true; otherwise we return false.
  // This is basically all or nothing.
  bool IsAllSupportedbySpec(func::FuncOp func,
                            const InferenceDeviceType& inference_type);

  // Given a func and a targeted device, we will try to clone the func &
  // transform/optimize for that device.
  // It simply clones the FuncOp and applies hardware-specific transformations.
  func::FuncOp GetAlternativeViewForSpec(
      func::FuncOp func,
      const InferenceDeviceType& current_device_inference_type,
      const InferenceDeviceType& target_device_inference_type, ModuleOp module,
      OpBuilder* builder);

  // Apply any device-specific optimizations.
  void Optimize(func::FuncOp func, const std::string& hardware);

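  // Example (hypothetical mlir-opt-style invocation; the exact option syntax
  // depends on the driver used):
  //   --tfl-get-alternative-subgraph='device-specs=CPU,GPU'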
  ListOption<std::string> device_specs_flag_{
      *this, "device-specs",
      llvm::cl::desc(
          "comma separated list of device specs, like CPU, GPU, DSP."),
      llvm::cl::ZeroOrMore};
};

void AlternativeSubgraphPass::GetAlternativeGraphForFunc(
    ArrayRef<std::string> devices, func::FuncOp func, ModuleOp module,
    OpBuilder* builder) {
  auto current_device = GetTargetAnnotation(func);
  if (current_device->empty()) {
    func.emitError(
        "cannot find target annotation or unknown device specified for current "
        "function");
    return;
  }

  auto current_inference_type = GetInferenceTypeAnnotation(func);
  if (!current_inference_type.has_value() ||
      current_inference_type == UNKNOWN) {
    func.emitError(
        "cannot find inference type annotation or unknown inference type "
        "specified for current "
        "function");
    return;
  }

  const InferenceDeviceType current_device_type(
      {current_device.getValue(), current_inference_type.getValue()});

  const std::vector<InferenceDeviceType>& all_inference_device_type =
      GetAllAlternativeInferenceDeviceType(current_inference_type.getValue(),
                                           devices);

  for (const auto& device_inference_type : all_inference_device_type) {
    if (device_inference_type != current_device_type) {
      func::FuncOp cloned_func = GetAlternativeViewForSpec(
          func, current_device_type, device_inference_type, module, builder);
      // If we found unsupported ops, we will just go ahead and remove this
      // function.
      // TODO(b/160284136): currently we check if the ops are supported then
      // see if we need to erase the func op.
      // Ideally it would be nice if we can utilize dynamic illegal op to do
      // the job.
      if (!IsAllSupportedbySpec(cloned_func, device_inference_type)) {
        cloned_func.erase();
      }
    }
  }

  // Perform the device-specific optimization last.
  // We need to run the optimization for the current device last because we
  // need to avoid any changes made to the current graph from polluting other
  // alternative graph views.
  Optimize(func, current_device.getValue());
}

bool AlternativeSubgraphPass::IsAllSupportedbySpec(
    func::FuncOp func, const InferenceDeviceType& device_inference_type) {
  bool found_unsupported = false;
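  // An op counts against the target only if it is a non-const, non-terminator
  // op (and not a quantize/dequantize, return, func, or call op) that the
  // hardware spec reports as unsupported.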
  func.walk([&](Operation* op) {
    if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
        NotTFLQuantDequantizeOp(op) &&
        !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op) &&
        !IsSupported(op, device_inference_type.hardware)) {
      found_unsupported = true;
    }
  });
  return !found_unsupported;
}

void AlternativeSubgraphPass::Optimize(func::FuncOp func,
                                       const std::string& hardware) {
  auto* ctx = &getContext();
  RewritePatternSet patterns = GetHardwareRewritePatterns(ctx, hardware);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

// Get the alternative view of the func for the given device_inference_type.
// It's possible the transformed func can still contain unsupported ops for the
// given device_inference_type.
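// The cloned func is renamed via GetFunctionImplName, marked private, and
// appended to the module; the caller (GetAlternativeGraphForFunc) erases it
// again if the target cannot support all of its ops.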
func::FuncOp AlternativeSubgraphPass::GetAlternativeViewForSpec(
    func::FuncOp func, const InferenceDeviceType& current_device_inference_type,
    const InferenceDeviceType& target_device_inference_type, ModuleOp module,
    OpBuilder* builder) {
  func::FuncOp cloned_func = func.clone();
  cloned_func.setPrivate();
  auto interface_name = GetInterFaceName(func);
  if (!interface_name.has_value()) {
    func.emitError("the func op does not have interface_name");
    return nullptr;
  }

  cloned_func->setAttr(
      kDevice, builder->getStringAttr(target_device_inference_type.hardware));
  cloned_func->setAttr(kInferenceType,
                       builder->getStringAttr(GetInferenceString(
                           target_device_inference_type.inference_type)));
  std::string new_function_name = GetFunctionImplName(
      interface_name.getValue(), target_device_inference_type);
  cloned_func.setName(new_function_name);

  // If it's quantized -> float, we need to wrap the ops with dequantize and
  // quantize ops.
  if ((current_device_inference_type.inference_type == QUANTIZED_UINT8 ||
       current_device_inference_type.inference_type == QUANTIZED_INT8) &&
      target_device_inference_type.inference_type == FLOAT) {
    OpBuilder cloned_func_builder(cloned_func);
    ConvertQuantizedOpToFloat(cloned_func, &cloned_func_builder);
    OptimizeQuantizedOpToFloat(cloned_func, &getContext());
  }

  Optimize(cloned_func, target_device_inference_type.hardware);

  // Set device for each op.
  cloned_func.walk([&](Operation* op) {
    if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
        !llvm::isa<func::ReturnOp, func::FuncOp, CallableOpInterface>(op)) {
      op->setAttr(kDevice, builder->getStringAttr(
                               target_device_inference_type.hardware));
      op->setAttr(kInferenceType,
                  builder->getStringAttr(GetInferenceString(
                      target_device_inference_type.inference_type)));
    }
  });

  module.push_back(cloned_func);
  return cloned_func;
}

void AlternativeSubgraphPass::runOnOperation() {
  auto module = getOperation();

  // Process device specs.
  if (device_specs_flag_.empty()) {
    module.emitError("no device specs specified");
    signalPassFailure();
    return;
  }

  std::vector<std::string> device_specs;
  if (!ProcessTargetDevices(device_specs_flag_, &device_specs)) {
    module.emitError("unknown devices specified");
    signalPassFailure();
    return;
  }

  SmallVector<func::FuncOp, 25> funcs_to_be_processed;
  // We only process funcs that have device annotations.
  for (auto func : module.getOps<func::FuncOp>()) {
    auto device_attr = func->getAttrOfType<StringAttr>(kDevice);
    if (device_attr != nullptr) funcs_to_be_processed.push_back(func);
  }

  OpBuilder builder(module);
  // Go ahead and process those funcs.
  // We don't process them in the previous loop since we're adding new funcs;
  // this avoids unnecessary processing.
  for (auto func : funcs_to_be_processed) {
    GetAlternativeGraphForFunc(device_specs, func, module, &builder);
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateAlternativeSubgraphPass(
    llvm::ArrayRef<std::string> device_specs) {
  return std::make_unique<AlternativeSubgraphPass>(device_specs);
}

static PassRegistration<AlternativeSubgraphPass> pass;

}  // namespace tac
}  // namespace TFL
}  // namespace mlir