/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_

#include <algorithm>
#include <cctype>
#include <functional>
#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Operation.h"  // from @llvm-project

namespace mlir {
namespace TFL {
namespace tac {

// Device attribute string on the TFL dialect.
constexpr char kDevice[] = "tac.device";

// Inference type attribute string on the TFL dialect.
constexpr char kInferenceType[] = "tac.inference_type";
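
// For illustration, an op annotated by TAC might carry these attributes
// (the values shown here are examples only, not the only options):
//
//   %0 = "tfl.add"(%arg0, %arg1)
//       {tac.device = "GPU", tac.inference_type = "FLOAT", ...} : ...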

// TODO(renjieliu): Add more inference types.
enum InferenceType {
  UNKNOWN = 0,
  FLOAT = 1,
  QUANTIZED_INT8 = 2,
  QUANTIZED_UINT8 = 3,
  HYBRID = 4
};

// Map an inference type string (e.g. "FLOAT") to its enum value; any
// unrecognized string maps to UNKNOWN.
inline InferenceType GetInferenceTypeEnum(llvm::StringRef inference_type_str) {
  if (inference_type_str == "FLOAT") {
    return FLOAT;
  } else if (inference_type_str == "QUANTIZED_INT8") {
    return QUANTIZED_INT8;
  } else if (inference_type_str == "QUANTIZED_UINT8") {
    return QUANTIZED_UINT8;
  } else if (inference_type_str == "HYBRID") {
    return HYBRID;
  } else {
    return UNKNOWN;
  }
}

// Map an inference type enum back to its canonical string representation.
inline std::string GetInferenceString(InferenceType inference_type) {
  if (inference_type == FLOAT) {
    return "FLOAT";
  } else if (inference_type == QUANTIZED_INT8) {
    return "QUANTIZED_INT8";
  } else if (inference_type == QUANTIZED_UINT8) {
    return "QUANTIZED_UINT8";
  } else if (inference_type == HYBRID) {
    return "HYBRID";
  } else {
    return "UNKNOWN";
  }
}
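
// For example (a sketch of the expected round trip):
//   GetInferenceTypeEnum("QUANTIZED_INT8") == QUANTIZED_INT8
//   GetInferenceString(QUANTIZED_INT8) == "QUANTIZED_INT8"
//   GetInferenceTypeEnum("bogus") == UNKNOWN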

// Returns the canonical representation for a hardware name (all uppercase).
// TODO(b/177376459): Remove this in favor of the string defined by hardwares
// MyHardware::kId.
inline std::string GetCanonicalHardwareName(const std::string& hardware_name) {
  std::string name = hardware_name;
  std::transform(
      name.begin(), name.end(), name.begin(),
      [](unsigned char c) -> unsigned char { return std::toupper(c); });
  return name;
}
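
// For example, GetCanonicalHardwareName("gpu") returns "GPU".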

// Get the target annotation from the op.
inline llvm::Optional<std::string> GetTargetAnnotation(Operation* op) {
  auto device = op->getAttrOfType<StringAttr>(kDevice);
  if (device == nullptr || device.getValue().empty()) return llvm::None;

  return GetCanonicalHardwareName(device.getValue().str());
}
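
// For example, an op carrying {tac.device = "gpu"} yields "GPU", while an op
// with a missing or empty tac.device attribute yields llvm::None.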

// Get inference type attribute from the operation if available.
inline llvm::Optional<InferenceType> GetInferenceTypeAnnotation(Operation* op) {
  auto inference_type = op->getAttrOfType<StringAttr>(kInferenceType);
  if (inference_type == nullptr) return llvm::None;

  llvm::StringRef inference_type_str = inference_type.getValue();
  return GetInferenceTypeEnum(inference_type_str);
}

// InferenceDeviceType is a combination of the hardware with inference type.
struct InferenceDeviceType {
  std::string hardware;
  InferenceType inference_type;

  bool operator==(const InferenceDeviceType& other) const {
    return (hardware == other.hardware) &&
           (inference_type == other.inference_type);
  }

  bool operator!=(const InferenceDeviceType& other) const {
    return !(*this == other);
  }

  // Hash functor so that InferenceDeviceType can be used as a key in
  // std::unordered_* containers.
  struct inference_device_type_hash {
    size_t operator()(const InferenceDeviceType& p) const {
      auto hash1 = std::hash<std::string>{}(p.hardware);
      auto hash2 = std::hash<InferenceType>{}(p.inference_type);
      return hash1 ^ hash2;
    }
  };
};
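
// For illustration, the functor above lets InferenceDeviceType key an
// unordered container. A minimal sketch (the map and its contents are
// hypothetical, not part of this header):
//
//   std::unordered_map<InferenceDeviceType, float,
//                      InferenceDeviceType::inference_device_type_hash>
//       cost_per_device;
//   cost_per_device[{"GPU", FLOAT}] = 1.0f;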

// Get InferenceDeviceType attribute from the operation if available.
inline llvm::Optional<InferenceDeviceType> GetInferenceDeviceTypeForOp(
    Operation* op) {
  auto hardware = GetTargetAnnotation(op);
  if (!hardware.has_value()) return llvm::None;

  auto inference_type = GetInferenceTypeAnnotation(op);
  if (!inference_type.has_value()) return llvm::None;

  InferenceDeviceType inference_device_type;
  inference_device_type.hardware = hardware.getValue();
  inference_device_type.inference_type = inference_type.getValue();
  return inference_device_type;
}
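
// For example (a sketch; `op` is a hypothetical annotated mlir::Operation*):
//
//   if (auto device_type = GetInferenceDeviceTypeForOp(op)) {
//     // For an op annotated with {tac.device = "GPU",
//     // tac.inference_type = "FLOAT"}, *device_type compares equal to
//     // InferenceDeviceType{"GPU", FLOAT}.
//   }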

}  // namespace tac
}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_