/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {
namespace tac {
namespace {

// Given the function interface name and the InferenceDeviceType, return the
// new function name.
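// For example (illustrative): interface_name "fc" with device inference type
// {hardware: "GPU", inference_type: FLOAT} would map to something like
// "fc_GPU_FLOAT", assuming GetInferenceString(FLOAT) returns "FLOAT".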
std::string GetFunctionImplName(
    std::string interface_name,
    const InferenceDeviceType& device_inference_type) {
  return absl::StrCat(interface_name, "_", device_inference_type.hardware, "_",
                      GetInferenceString(device_inference_type.inference_type));
}

// For every device, we will do the following:
// If the inference type is quantized, we will try the float alternative.
// If it's float, we will just keep it as it is.
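// For example (illustrative): inference_type == QUANTIZED_INT8 with devices
// {"CPU", "GPU"} yields {CPU, QUANTIZED_INT8}, {CPU, FLOAT},
// {GPU, QUANTIZED_INT8}, {GPU, FLOAT}.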
std::vector<InferenceDeviceType> GetAllAlternativeInferenceDeviceType(
    InferenceType inference_type, ArrayRef<std::string> devices) {
  std::vector<InferenceDeviceType> all_device_inference_types;
  for (const auto& device : devices) {
    if (inference_type == QUANTIZED_INT8) {
      all_device_inference_types.push_back({device, QUANTIZED_INT8});
    } else if (inference_type == QUANTIZED_UINT8) {
      all_device_inference_types.push_back({device, QUANTIZED_UINT8});
    }

    // We will always enable float.
    all_device_inference_types.push_back({device, FLOAT});
  }

  return all_device_inference_types;
}

// This pass will try to get an alternative subgraph:
// Say a subgraph is annotated with CPU (it probably means the ops it contains
// cannot be run on other devices):
//
// We will try:
// 1) To find a mathematically equivalent transformation so that this subgraph
//    can be run on other devices.
// 2) To apply device-specific optimizations as well, which may include tensor
//    layout transformation, device-specific fusion, etc.
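// For instance (illustrative), a function annotated with CPU/QUANTIZED_INT8
// may gain cloned siblings specialized for GPU/FLOAT (with quantize/dequantize
// wrapping applied), kept only if all of their ops are supported on that
// target.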
class AlternativeSubgraphPass
    : public mlir::PassWrapper<AlternativeSubgraphPass,
                               mlir::OperationPass<ModuleOp>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AlternativeSubgraphPass)

  llvm::StringRef getArgument() const final {
    return "tfl-get-alternative-subgraph";
  }
  llvm::StringRef getDescription() const final {
    return "Get alternative subgraph representation (if applicable) for all "
           "the given devices; will by default include the cpu implementation.";
  }
  AlternativeSubgraphPass() = default;
  AlternativeSubgraphPass(const AlternativeSubgraphPass&) {}
  explicit AlternativeSubgraphPass(llvm::ArrayRef<std::string> device_specs) {
    device_specs_flag_ = device_specs;
  }

 private:
  void runOnOperation() override;

  // Given a func and targeted devices, we will try to clone the func &
  // transform/optimize for those devices.
  // This will only happen if the whole subgraph can be supported by the target
  // or can be supported after some transformations.
  void GetAlternativeGraphForFunc(ArrayRef<std::string> devices,
                                  func::FuncOp func, ModuleOp module,
                                  OpBuilder* builder);

  // Returns true if every op in the func can be represented on the hardware,
  // false otherwise.
  // This is basically all or nothing.
  bool IsAllSupportedbySpec(func::FuncOp func,
                            const InferenceDeviceType& inference_type);

  // Given a func and a targeted device, we will try to clone the func &
  // transform/optimize for that device.
  // It simply clones the FuncOp and applies hardware-specific transformations.
  func::FuncOp GetAlternativeViewForSpec(
      func::FuncOp func,
      const InferenceDeviceType& current_device_inference_type,
      const InferenceDeviceType& target_device_inference_type, ModuleOp module,
      OpBuilder* builder);

  // Apply any device-specific optimizations.
  void Optimize(func::FuncOp func, const std::string& hardware);

  ListOption<std::string> device_specs_flag_{
      *this, "device-specs",
      llvm::cl::desc(
          "comma separated list of device specs, like CPU, GPU, DSP."),
      llvm::cl::ZeroOrMore};
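  // Illustrative invocation through an mlir-opt style driver (exact tool name
  // and quoting are assumptions):
  //   -tfl-get-alternative-subgraph='device-specs=CPU,GPU'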
};

void AlternativeSubgraphPass::GetAlternativeGraphForFunc(
    ArrayRef<std::string> devices, func::FuncOp func, ModuleOp module,
    OpBuilder* builder) {
  auto current_device = GetTargetAnnotation(func);
  if (current_device->empty()) {
    func.emitError(
        "cannot find target annotation or unknown device specified for "
        "current function");
    return;
  }

  auto current_inference_type = GetInferenceTypeAnnotation(func);
  if (!current_inference_type.has_value() ||
      current_inference_type == UNKNOWN) {
    func.emitError(
        "cannot find inference type annotation or unknown inference type "
        "specified for current function");
    return;
  }

  const InferenceDeviceType current_device_type(
      {current_device.getValue(), current_inference_type.getValue()});

  const std::vector<InferenceDeviceType>& all_inference_device_type =
      GetAllAlternativeInferenceDeviceType(current_inference_type.getValue(),
                                           devices);

  for (const auto& device_inference_type : all_inference_device_type) {
    if (device_inference_type != current_device_type) {
      func::FuncOp cloned_func = GetAlternativeViewForSpec(
          func, current_device_type, device_inference_type, module, builder);
      // If we found unsupported ops, we will just go ahead and remove this
      // function.
      // TODO(b/160284136): currently we check if the ops are supported and
      // then see if we need to erase the func op.
      // Ideally it would be nice if we could utilize a dynamically illegal op
      // to do the job.
      if (!IsAllSupportedbySpec(cloned_func, device_inference_type)) {
        cloned_func.erase();
      }
    }
  }

  // Perform the device-specific optimization last.
  // We need to run the optimization for the current device last so that
  // changes made to the current graph do not pollute the other alternative
  // graph views.
  Optimize(func, current_device.getValue());
}

bool AlternativeSubgraphPass::IsAllSupportedbySpec(
    func::FuncOp func, const InferenceDeviceType& device_inference_type) {
  bool found_unsupported = false;
  func.walk([&](Operation* op) {
    if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
        NotTFLQuantDequantizeOp(op) &&
        !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op) &&
        !IsSupported(op, device_inference_type.hardware)) {
      found_unsupported = true;
    }
  });
  return !found_unsupported;
}

void AlternativeSubgraphPass::Optimize(func::FuncOp func,
                                       const std::string& hardware) {
  auto* ctx = &getContext();
  RewritePatternSet patterns = GetHardwareRewritePatterns(ctx, hardware);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

// Get the alternative view of the func for the given device_inference_type.
// It's possible the transformed func can still contain unsupported ops for the
// given device_inference_type.
func::FuncOp AlternativeSubgraphPass::GetAlternativeViewForSpec(
    func::FuncOp func, const InferenceDeviceType& current_device_inference_type,
    const InferenceDeviceType& target_device_inference_type, ModuleOp module,
    OpBuilder* builder) {
  func::FuncOp cloned_func = func.clone();
  cloned_func.setPrivate();
  auto interface_name = GetInterFaceName(func);
  if (!interface_name.has_value()) {
    func.emitError("the func op does not have interface_name");
    return nullptr;
  }

  cloned_func->setAttr(
      kDevice, builder->getStringAttr(target_device_inference_type.hardware));
  cloned_func->setAttr(kInferenceType,
                       builder->getStringAttr(GetInferenceString(
                           target_device_inference_type.inference_type)));
  std::string new_function_name = GetFunctionImplName(
      interface_name.getValue(), target_device_inference_type);
  cloned_func.setName(new_function_name);

  // If it's quantized -> float, we need to wrap the ops with dequantize and
  // quantize ops.
  if ((current_device_inference_type.inference_type == QUANTIZED_UINT8 ||
       current_device_inference_type.inference_type == QUANTIZED_INT8) &&
      target_device_inference_type.inference_type == FLOAT) {
    OpBuilder cloned_func_builder(cloned_func);
    ConvertQuantizedOpToFloat(cloned_func, &cloned_func_builder);
    OptimizeQuantizedOpToFloat(cloned_func, &getContext());
  }

  Optimize(cloned_func, target_device_inference_type.hardware);

  // Set device for each op.
  cloned_func.walk([&](Operation* op) {
    if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
        !llvm::isa<func::ReturnOp, func::FuncOp, CallableOpInterface>(op)) {
      op->setAttr(kDevice, builder->getStringAttr(
                               target_device_inference_type.hardware));
      op->setAttr(kInferenceType,
                  builder->getStringAttr(GetInferenceString(
                      target_device_inference_type.inference_type)));
    }
  });

  module.push_back(cloned_func);
  return cloned_func;
}

void AlternativeSubgraphPass::runOnOperation() {
  auto module = getOperation();

  // Process device specs.
  if (device_specs_flag_.empty()) {
    module.emitError("no device specs specified");
    signalPassFailure();
  }

  std::vector<std::string> device_specs;
  if (!ProcessTargetDevices(device_specs_flag_, &device_specs)) {
    module.emitError("unknown devices specified");
    signalPassFailure();
  }

  SmallVector<func::FuncOp, 25> funcs_to_be_processed;
  // We only process funcs that have device annotations.
  for (auto func : module.getOps<func::FuncOp>()) {
    auto device_attr = func->getAttrOfType<StringAttr>(kDevice);
    if (device_attr != nullptr) funcs_to_be_processed.push_back(func);
  }

  OpBuilder builder(module);
  // Go ahead and process those funcs.
  // We don't process them in the previous loop because we're adding new funcs;
  // this is to avoid unnecessary processing.
  for (auto func : funcs_to_be_processed) {
    GetAlternativeGraphForFunc(device_specs, func, module, &builder);
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateAlternativeSubgraphPass(
    llvm::ArrayRef<std::string> device_specs) {
  return std::make_unique<AlternativeSubgraphPass>(device_specs);
}

static PassRegistration<AlternativeSubgraphPass> pass;

}  // namespace tac
}  // namespace TFL
}  // namespace mlir