1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
28 #include "mlir/Support/LLVM.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
30 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
31 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
32 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h"
33 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h"
35 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
36 
37 namespace mlir {
38 namespace TFL {
39 namespace tac {
40 namespace {
41 
42 class TargetAnnotationPass : public TacFunctionPass<TargetAnnotationPass> {
43  public:
getArgument() const44   llvm::StringRef getArgument() const final { return "tfl-target-annotation"; }
getDescription() const45   llvm::StringRef getDescription() const final {
46     return "Add user specified target annotations to the TFL operations given "
47            "operation capabilities, will default to CPU.";
48   }
49   // using TacFunctionPass::TacFunctionPass;
TargetAnnotationPass()50   TargetAnnotationPass() : TacFunctionPass(nullptr) {}
TargetAnnotationPass(const TargetAnnotationPass & copy)51   TargetAnnotationPass(const TargetAnnotationPass& copy)
52       : TacFunctionPass(copy.module_) {}
TargetAnnotationPass(llvm::ArrayRef<std::string> device_specs)53   explicit TargetAnnotationPass(llvm::ArrayRef<std::string> device_specs)
54       : TacFunctionPass(nullptr) {
55     device_specs_flag_ = device_specs;
56   }
57 
TargetAnnotationPass(const TacModule * module)58   explicit TargetAnnotationPass(const TacModule* module)
59       : TacFunctionPass(module) {}
60 
61  private:
62   void runOnFunction() override;
63   void SetTargetAnnotation(Operation* op,
64                            llvm::ArrayRef<std::string> device_specs,
65                            OpBuilder* builder);
66 
67   ListOption<std::string> device_specs_flag_{
68       *this, "device-specs",
69       llvm::cl::desc(
70           "comma separated list of device specs, like CPU, GPU, Hexagon."),
71       llvm::cl::ZeroOrMore};
72 };
73 
SetAnnotation(Operation * op,std::string attribute,std::string annotation,OpBuilder * builder)74 void SetAnnotation(Operation* op, std::string attribute, std::string annotation,
75                    OpBuilder* builder) {
76   // TODO(karimnosseir): Maybe set device capabilities to allow us to have
77   // more flexbility when raise the subgraphs.
78   auto default_target = builder->getStringAttr(annotation);
79   op->setAttr(attribute, default_target);
80 }
81 
SetTargetAnnotation(Operation * op,llvm::ArrayRef<std::string> device_specs,OpBuilder * builder)82 void TargetAnnotationPass::SetTargetAnnotation(
83     Operation* op, llvm::ArrayRef<std::string> device_specs,
84     OpBuilder* builder) {
85   const InferenceType inference_type = GetInferenceType(op);
86   const std::string inference_type_str = GetInferenceString(inference_type);
87   SetAnnotation(op, kInferenceType, inference_type_str, builder);
88   bool device_is_set = false;
89   // TODO(b/177376459): Remove the usage of device_specs.
90   // TODO(b/177376459): Update if needed to make testing easy.
91   if (!module_) {
92     for (const auto& device : device_specs) {
93       auto* hardware = this->GetTargetHardware(device);
94       if (hardware == nullptr) continue;
95       if (hardware->IsOpSupported(op)) {
96         SetAnnotation(op, kDevice, device, builder);
97         device_is_set = true;
98         break;
99       }
100     }
101   } else {
102     for (const auto* hardware : module_->GetAvailableHardwares()) {
103       if (hardware == nullptr) continue;
104       if (hardware->IsOpSupported(op)) {
105         SetAnnotation(op, kDevice, GetHardwareName(hardware), builder);
106         device_is_set = true;
107         break;
108       }
109     }
110   }
111   // default to CPU
112   if (!device_is_set) {
113     if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
114         !llvm::isa<func::ReturnOp, func::FuncOp, CallableOpInterface>(op)) {
115       SetAnnotation(op, kDevice, "CPU", builder);
116       device_is_set = true;
117     }
118   }
119   if (!device_is_set) {
120     op->emitError("cannot set target device for this ops");
121   }
122 }
123 
runOnFunction()124 void TargetAnnotationPass::runOnFunction() {
125   auto func = getFunction();
126   OpBuilder builder(func);
127 
128   func.walk([&](Operation* op) {
129     // We only care about TFL dialect.
130     if (IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) &&
131         !IsTerminatorOp(op) &&
132         !llvm::isa<func::ReturnOp, func::FuncOp, CallOpInterface>(op)) {
133       SetTargetAnnotation(op, device_specs_flag_, &builder);
134     }
135   });
136 }
137 
138 }  // namespace
139 
CreateTargetAnnotationPass(llvm::ArrayRef<std::string> device_specs)140 std::unique_ptr<OperationPass<func::FuncOp>> CreateTargetAnnotationPass(
141     llvm::ArrayRef<std::string> device_specs) {
142   return std::make_unique<TargetAnnotationPass>(device_specs);
143 }
144 
CreateTargetAnnotationPass(const TacModule * module)145 std::unique_ptr<OperationPass<func::FuncOp>> CreateTargetAnnotationPass(
146     const TacModule* module) {
147   return std::make_unique<TargetAnnotationPass>(module);
148 }
149 
150 static PassRegistration<TargetAnnotationPass> pass;
151 
152 }  // namespace tac
153 }  // namespace TFL
154 }  // namespace mlir
155