1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
17 #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
18
#include <algorithm>
#include <cctype>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
29
30 namespace mlir {
31 namespace TFL {
32 namespace tac {
33
// Name of the string attribute on TFL dialect ops that holds the target
// device (hardware) annotation.
//
// Declared `inline` so there is exactly one definition program-wide: the
// inline helpers in this header pass these arrays by address (through
// llvm::StringRef), and a per-translation-unit internal-linkage copy would
// make those inline definitions refer to different entities in different TUs.
inline constexpr char kDevice[] = "tac.device";

// Name of the string attribute that holds the op's inference type.
inline constexpr char kInferenceType[] = "tac.inference_type";
39
// Inference type an op is annotated with. Its string form is stored in the
// `tac.inference_type` attribute and converted with GetInferenceTypeEnum /
// GetInferenceString.
// TODO(renjieliu): Add more inference types.
enum InferenceType {
  UNKNOWN = 0,  // Fallback for any unrecognized inference-type string.
  FLOAT = 1,
  QUANTIZED_INT8 = 2,
  QUANTIZED_UINT8 = 3,
  HYBRID = 4,
};
48
GetInferenceTypeEnum(llvm::StringRef inference_type_str)49 inline InferenceType GetInferenceTypeEnum(llvm::StringRef inference_type_str) {
50 if (inference_type_str == "FLOAT") {
51 return FLOAT;
52 } else if (inference_type_str == "QUANTIZED_INT8") {
53 return QUANTIZED_INT8;
54 } else if (inference_type_str == "QUANTIZED_UINT8") {
55 return QUANTIZED_UINT8;
56 } else if (inference_type_str == "HYBRID") {
57 return HYBRID;
58 } else {
59 return UNKNOWN;
60 }
61 }
62
GetInferenceString(InferenceType inference_type)63 inline std::string GetInferenceString(InferenceType inference_type) {
64 if (inference_type == FLOAT) {
65 return "FLOAT";
66 } else if (inference_type == QUANTIZED_INT8) {
67 return "QUANTIZED_INT8";
68 } else if (inference_type == QUANTIZED_UINT8) {
69 return "QUANTIZED_UINT8";
70 } else if (inference_type == HYBRID) {
71 return "HYBRID";
72 } else {
73 return "UNKNOWN";
74 }
75 }
76
// Returns canonical representation for hardware name (All uppercase).
// TODO(b/177376459): Remove this in favor of the string defined by hardwares
// MyHardware::kId.
//
// Each byte is upper-cased with std::toupper (declared in <cctype>), so the
// result follows the current C locale; the input is returned as a new string
// and left unmodified.
inline std::string GetCanonicalHardwareName(const std::string& hardware_name) {
  std::string name = hardware_name;
  // Convert to unsigned char before calling std::toupper: passing a negative
  // `char` value (other than EOF) to it is undefined behavior.
  std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) {
    return static_cast<char>(std::toupper(c));
  });
  return name;
}
87
88 // Get the target annotation form the op.
GetTargetAnnotation(Operation * op)89 inline llvm::Optional<std::string> GetTargetAnnotation(Operation* op) {
90 auto device = op->getAttrOfType<StringAttr>(kDevice);
91 if (device == nullptr || device.getValue().empty()) return llvm::None;
92
93 return GetCanonicalHardwareName(device.getValue().str());
94 }
95
96 // Get inference type attribute from the operation if available.
GetInferenceTypeAnnotation(Operation * op)97 inline llvm::Optional<InferenceType> GetInferenceTypeAnnotation(Operation* op) {
98 auto inference_type = op->getAttrOfType<StringAttr>(kInferenceType);
99 if (inference_type == nullptr) return llvm::None;
100
101 llvm::StringRef device_name_str = inference_type.getValue();
102 return GetInferenceTypeEnum(device_name_str);
103 }
104
105 // InferenceDeviceType is a combination of the hardware with inference type.
106 struct InferenceDeviceType {
107 std::string hardware;
108 InferenceType inference_type;
109
110 bool operator==(const InferenceDeviceType& other) const {
111 return (hardware == other.hardware) &&
112 (inference_type == other.inference_type);
113 }
114
115 bool operator!=(const InferenceDeviceType& other) const {
116 return !(*this == other);
117 }
118
119 struct inference_device_type_hash {
operatorInferenceDeviceType::inference_device_type_hash120 size_t operator()(const InferenceDeviceType& p) const {
121 auto hash1 = std::hash<std::string>{}(p.hardware);
122 auto hash2 = std::hash<InferenceType>{}(p.inference_type);
123 return hash1 ^ hash2;
124 }
125 };
126 };
127
128 // Get InferenceDeviceType attribute from the operation if available.
GetInferenceDeviceTypeForOp(Operation * op)129 inline llvm::Optional<InferenceDeviceType> GetInferenceDeviceTypeForOp(
130 Operation* op) {
131 auto hardware = GetTargetAnnotation(op);
132 if (!hardware.has_value()) return llvm::None;
133
134 auto inference_type = GetInferenceTypeAnnotation(op);
135 if (!inference_type.has_value()) return llvm::None;
136
137 InferenceDeviceType inference_device_type;
138 inference_device_type.hardware = hardware.getValue();
139 inference_device_type.inference_type = inference_type.getValue();
140 return inference_device_type;
141 }
142
143 } // namespace tac
144 } // namespace TFL
145 } // namespace mlir
146
147 #endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
148