1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_SUPPORT_H
16 #define TENSORFLOW_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_SUPPORT_H
17
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>

#include "tensorflow/lite/kernels/internal/compatibility.h"
21
22 namespace tflite {
23 namespace optimize {
24
// Key of the metadata entry produced by MetadataForReducedPrecisionSupport().
static constexpr char kTfLiteReducedPrecisionKey[] =
    "reduced_precision_support";

// Tokens that make up the metadata value string. The "fp16" and "bf16" tokens
// name inference types; after the "acc" token, "fp16" or "fp32" names the
// accumulation type (e.g. "fp16bf16accfp32").
static constexpr char kTfLiteFloat16String[] = "fp16";
static constexpr char kTfLiteBfloat16String[] = "bf16";
static constexpr char kTfLiteFloat32String[] = "fp32";
static constexpr char kTfLiteAccumulationString[] = "acc";
32
// Bit flags describing reduced precision support. Each enumerator other than
// None occupies its own bit so that several of them can be combined into one
// mask value.
enum class ReducedPrecisionSupport : std::uint8_t {
  None = 0,
  Float16Inference = 1 << 0,
  Bfloat16Inference = 1 << 1,
  Float16Accumulation = 1 << 2,
  Float32Accumulation = 1 << 3,
};
40
41 inline ReducedPrecisionSupport operator|(ReducedPrecisionSupport a,
42 ReducedPrecisionSupport b) {
43 return static_cast<ReducedPrecisionSupport>(static_cast<std::uint32_t>(a) |
44 static_cast<std::uint32_t>(b));
45 }
46
47 inline ReducedPrecisionSupport& operator|=(ReducedPrecisionSupport& a,
48 ReducedPrecisionSupport b) {
49 return a = static_cast<ReducedPrecisionSupport>(
50 static_cast<std::uint32_t>(a) | static_cast<std::uint32_t>(b));
51 }
52
53 inline ReducedPrecisionSupport operator&(ReducedPrecisionSupport a,
54 ReducedPrecisionSupport b) {
55 return static_cast<ReducedPrecisionSupport>(static_cast<std::uint32_t>(a) &
56 static_cast<std::uint32_t>(b));
57 }
58
59 inline ReducedPrecisionSupport& operator&=(ReducedPrecisionSupport& a,
60 ReducedPrecisionSupport b) {
61 return a = static_cast<ReducedPrecisionSupport>(
62 static_cast<std::uint32_t>(a) & static_cast<std::uint32_t>(b));
63 }
64
SupportsFP16Inference(const ReducedPrecisionSupport & mask)65 inline bool SupportsFP16Inference(const ReducedPrecisionSupport& mask) {
66 return static_cast<bool>(mask & ReducedPrecisionSupport::Float16Inference);
67 }
68
SupportsBfloat16Inference(const ReducedPrecisionSupport & mask)69 inline bool SupportsBfloat16Inference(const ReducedPrecisionSupport& mask) {
70 return static_cast<bool>(mask & ReducedPrecisionSupport::Bfloat16Inference);
71 }
72
SupportsFP16Accumulation(const ReducedPrecisionSupport & mask)73 inline bool SupportsFP16Accumulation(const ReducedPrecisionSupport& mask) {
74 return static_cast<bool>(mask & ReducedPrecisionSupport::Float16Accumulation);
75 }
76
SupportsFP32Accumulation(const ReducedPrecisionSupport & mask)77 inline bool SupportsFP32Accumulation(const ReducedPrecisionSupport& mask) {
78 return static_cast<bool>(mask & ReducedPrecisionSupport::Float32Accumulation);
79 }
80
SupportsReducedPrecisionInference(const ReducedPrecisionSupport & mask)81 inline bool SupportsReducedPrecisionInference(
82 const ReducedPrecisionSupport& mask) {
83 return SupportsFP16Inference(mask) || SupportsBfloat16Inference(mask);
84 }
85
SupportsEitherFP16OrFP32Accumulation(const ReducedPrecisionSupport & mask)86 inline bool SupportsEitherFP16OrFP32Accumulation(
87 const ReducedPrecisionSupport& mask) {
88 return SupportsFP16Accumulation(mask) != SupportsFP32Accumulation(mask);
89 }
90
91 // Return the key-value pair for reduced precision support metadata.
92 // Example: mask = Float16Inference | Bfloat16Inference | Float32Accumulation;
93 // Returned value would be <"reduced_precision_support", "fp16bf16accfp32">.
MetadataForReducedPrecisionSupport(const ReducedPrecisionSupport & mask)94 inline std::pair<std::string, std::string> MetadataForReducedPrecisionSupport(
95 const ReducedPrecisionSupport& mask) {
96 TFLITE_DCHECK(SupportsReducedPrecisionInference(mask));
97 TFLITE_DCHECK(SupportsEitherFP16OrFP32Accumulation(mask));
98 std::string value = "";
99 if (SupportsFP16Inference(mask)) {
100 value += kTfLiteFloat16String;
101 }
102 if (SupportsBfloat16Inference(mask)) {
103 value += kTfLiteBfloat16String;
104 }
105 value += kTfLiteAccumulationString;
106 if (SupportsFP16Accumulation(mask)) {
107 value += kTfLiteFloat16String;
108 } else if (SupportsFP32Accumulation(mask)) {
109 value += kTfLiteFloat32String;
110 }
111 return std::make_pair(std::string(kTfLiteReducedPrecisionKey), value);
112 }
113
ReadInferenceType(const std::string & metadata,size_t * idx,ReducedPrecisionSupport * mask)114 inline bool ReadInferenceType(const std::string& metadata, size_t* idx,
115 ReducedPrecisionSupport* mask) {
116 if (metadata.substr(*idx, 4) == kTfLiteFloat16String) {
117 *idx += 4;
118 *mask = *mask | ReducedPrecisionSupport::Float16Inference;
119 return true;
120 } else if (metadata.substr(*idx, 4) == kTfLiteBfloat16String) {
121 *idx += 4;
122 *mask = *mask | ReducedPrecisionSupport::Bfloat16Inference;
123 return true;
124 }
125 return false;
126 }
127
ReadAccumulationType(const std::string & metadata,size_t * idx,ReducedPrecisionSupport * mask)128 inline bool ReadAccumulationType(const std::string& metadata, size_t* idx,
129 ReducedPrecisionSupport* mask) {
130 if (metadata.substr(*idx, 4) == kTfLiteFloat16String) {
131 *idx += 4;
132 *mask = *mask | ReducedPrecisionSupport::Float16Accumulation;
133 return true;
134 } else if (metadata.substr(*idx, 4) == kTfLiteFloat32String) {
135 *idx += 4;
136 *mask = *mask | ReducedPrecisionSupport::Float32Accumulation;
137 return true;
138 }
139 return false;
140 }
141
142 // If the string is valid, set the given mask to indicate the state in
143 // string and return true. If the string is invalid, return false.
144 // A valid string is:
145 // >= 1 valid inference types + accumulation token + 1 valid accumulation type.
146 // Valid examples would be: "fp16accfp16", "bf16accfp32"
SetMaskFromReducedPrecisionMetadata(const std::string & metadata,ReducedPrecisionSupport * mask)147 inline bool SetMaskFromReducedPrecisionMetadata(const std::string& metadata,
148 ReducedPrecisionSupport* mask) {
149 bool check = true;
150 size_t idx = 0;
151 ReducedPrecisionSupport rsp = ReducedPrecisionSupport::None;
152 do {
153 check = ReadInferenceType(metadata, &idx, &rsp);
154 } while (check);
155 // Ensure we read at least 1 inference type.
156 if (idx == 0) {
157 return false;
158 }
159 // Next read the accumulation token.
160 if (metadata.substr(idx, 3) != kTfLiteAccumulationString) {
161 return false;
162 }
163 idx += std::string(kTfLiteAccumulationString).size();
164 // Next read a valid accumulation type.
165 if (!ReadAccumulationType(metadata, &idx, &rsp)) {
166 return false;
167 }
168 // This should be the end of string.
169 if (idx != metadata.length()) {
170 return false;
171 }
172 // The string is a valid mask description. Set the value and return.
173 *mask = rsp;
174 return true;
175 }
176
177 } // namespace optimize
178 } // namespace tflite
179
180 #endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_SUPPORT_H
181