1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
28 #include "mlir/Support/LLVM.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
30 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
31 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
32 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"
33 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h"
35 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
36
37 namespace mlir {
38 namespace TFL {
39 namespace tac {
40 namespace {
41
42 class TargetAnnotationPass : public TacFunctionPass<TargetAnnotationPass> {
43 public:
getArgument() const44 llvm::StringRef getArgument() const final { return "tfl-target-annotation"; }
getDescription() const45 llvm::StringRef getDescription() const final {
46 return "Add user specified target annotations to the TFL operations given "
47 "operation capabilities, will default to CPU.";
48 }
49 // using TacFunctionPass::TacFunctionPass;
TargetAnnotationPass()50 TargetAnnotationPass() : TacFunctionPass(nullptr) {}
TargetAnnotationPass(const TargetAnnotationPass & copy)51 TargetAnnotationPass(const TargetAnnotationPass& copy)
52 : TacFunctionPass(copy.module_) {}
TargetAnnotationPass(llvm::ArrayRef<std::string> device_specs)53 explicit TargetAnnotationPass(llvm::ArrayRef<std::string> device_specs)
54 : TacFunctionPass(nullptr) {
55 device_specs_flag_ = device_specs;
56 }
57
TargetAnnotationPass(const TacModule * module)58 explicit TargetAnnotationPass(const TacModule* module)
59 : TacFunctionPass(module) {}
60
61 private:
62 void runOnFunction() override;
63 void SetTargetAnnotation(Operation* op,
64 llvm::ArrayRef<std::string> device_specs,
65 OpBuilder* builder);
66
67 ListOption<std::string> device_specs_flag_{
68 *this, "device-specs",
69 llvm::cl::desc(
70 "comma separated list of device specs, like CPU, GPU, Hexagon."),
71 llvm::cl::ZeroOrMore};
72 };
73
SetAnnotation(Operation * op,std::string attribute,std::string annotation,OpBuilder * builder)74 void SetAnnotation(Operation* op, std::string attribute, std::string annotation,
75 OpBuilder* builder) {
76 // TODO(karimnosseir): Maybe set device capabilities to allow us to have
77 // more flexbility when raise the subgraphs.
78 auto default_target = builder->getStringAttr(annotation);
79 op->setAttr(attribute, default_target);
80 }
81
SetTargetAnnotation(Operation * op,llvm::ArrayRef<std::string> device_specs,OpBuilder * builder)82 void TargetAnnotationPass::SetTargetAnnotation(
83 Operation* op, llvm::ArrayRef<std::string> device_specs,
84 OpBuilder* builder) {
85 const InferenceType inference_type = GetInferenceType(op);
86 const std::string inference_type_str = GetInferenceString(inference_type);
87 SetAnnotation(op, kInferenceType, inference_type_str, builder);
88 bool device_is_set = false;
89 // TODO(b/177376459): Remove the usage of device_specs.
90 // TODO(b/177376459): Update if needed to make testing easy.
91 if (!module_) {
92 for (const auto& device : device_specs) {
93 auto* hardware = this->GetTargetHardware(device);
94 if (hardware == nullptr) continue;
95 if (hardware->IsOpSupported(op)) {
96 SetAnnotation(op, kDevice, device, builder);
97 device_is_set = true;
98 break;
99 }
100 }
101 } else {
102 for (const auto* hardware : module_->GetAvailableHardwares()) {
103 if (hardware == nullptr) continue;
104 if (hardware->IsOpSupported(op)) {
105 SetAnnotation(op, kDevice, GetHardwareName(hardware), builder);
106 device_is_set = true;
107 break;
108 }
109 }
110 }
111 // default to CPU
112 if (!device_is_set) {
113 if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
114 !llvm::isa<func::ReturnOp, func::FuncOp, CallableOpInterface>(op)) {
115 SetAnnotation(op, kDevice, "CPU", builder);
116 device_is_set = true;
117 }
118 }
119 if (!device_is_set) {
120 op->emitError("cannot set target device for this ops");
121 }
122 }
123
runOnFunction()124 void TargetAnnotationPass::runOnFunction() {
125 auto func = getFunction();
126 OpBuilder builder(func);
127
128 func.walk([&](Operation* op) {
129 // We only care about TFL dialect.
130 if (IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) &&
131 !IsTerminatorOp(op) &&
132 !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op)) {
133 SetTargetAnnotation(op, device_specs_flag_, &builder);
134 }
135 });
136 }
137
138 } // namespace
139
CreateTargetAnnotationPass(llvm::ArrayRef<std::string> device_specs)140 std::unique_ptr<OperationPass<func::FuncOp>> CreateTargetAnnotationPass(
141 llvm::ArrayRef<std::string> device_specs) {
142 return std::make_unique<TargetAnnotationPass>(device_specs);
143 }
144
CreateTargetAnnotationPass(const TacModule * module)145 std::unique_ptr<OperationPass<func::FuncOp>> CreateTargetAnnotationPass(
146 const TacModule* module) {
147 return std::make_unique<TargetAnnotationPass>(module);
148 }
149
150 static PassRegistration<TargetAnnotationPass> pass;
151
152 } // namespace tac
153 } // namespace TFL
154 } // namespace mlir
155