xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/optimize/reduced_precision_support.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_SUPPORT_H
16 #define TENSORFLOW_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_SUPPORT_H
17 
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>

#include "tensorflow/lite/kernels/internal/compatibility.h"
21 
22 namespace tflite {
23 namespace optimize {
24 
// Key of the metadata entry produced by MetadataForReducedPrecisionSupport().
constexpr char kTfLiteReducedPrecisionKey[] = "reduced_precision_support";

// Tokens that make up the metadata value string. The fp16 token is used for
// both the inference-type and the accumulation-type position.
constexpr char kTfLiteFloat16String[] = "fp16";
constexpr char kTfLiteBfloat16String[] = "bf16";
constexpr char kTfLiteFloat32String[] = "fp32";
// Separator placed between the inference types and the accumulation type.
constexpr char kTfLiteAccumulationString[] = "acc";
32 
// Bit mask of reduced-precision capabilities. Every enumerator other than
// None occupies its own bit, so values can be combined and queried with the
// bitwise operators defined below.
enum class ReducedPrecisionSupport : std::uint8_t {
  None = 0,
  Float16Inference = 0x1,
  Bfloat16Inference = 0x2,
  Float16Accumulation = 0x4,
  Float32Accumulation = 0x8,
};

// Returns a mask with every bit that is set in `a` or in `b`.
// The arithmetic is done in the enum's own underlying type so the casts stay
// in sync with the enum declaration.
constexpr ReducedPrecisionSupport operator|(
    ReducedPrecisionSupport a, ReducedPrecisionSupport b) noexcept {
  using Bits = std::underlying_type_t<ReducedPrecisionSupport>;
  return static_cast<ReducedPrecisionSupport>(static_cast<Bits>(a) |
                                              static_cast<Bits>(b));
}

// Returns a mask with only the bits that are set in both `a` and `b`.
constexpr ReducedPrecisionSupport operator&(
    ReducedPrecisionSupport a, ReducedPrecisionSupport b) noexcept {
  using Bits = std::underlying_type_t<ReducedPrecisionSupport>;
  return static_cast<ReducedPrecisionSupport>(static_cast<Bits>(a) &
                                              static_cast<Bits>(b));
}

// Sets in `a` every bit that is set in `b`; returns a reference to `a`.
constexpr ReducedPrecisionSupport& operator|=(
    ReducedPrecisionSupport& a, ReducedPrecisionSupport b) noexcept {
  a = a | b;
  return a;
}

// Clears in `a` every bit that is not set in `b`; returns a reference to `a`.
constexpr ReducedPrecisionSupport& operator&=(
    ReducedPrecisionSupport& a, ReducedPrecisionSupport b) noexcept {
  a = a & b;
  return a;
}
64 
SupportsFP16Inference(const ReducedPrecisionSupport & mask)65 inline bool SupportsFP16Inference(const ReducedPrecisionSupport& mask) {
66   return static_cast<bool>(mask & ReducedPrecisionSupport::Float16Inference);
67 }
68 
SupportsBfloat16Inference(const ReducedPrecisionSupport & mask)69 inline bool SupportsBfloat16Inference(const ReducedPrecisionSupport& mask) {
70   return static_cast<bool>(mask & ReducedPrecisionSupport::Bfloat16Inference);
71 }
72 
SupportsFP16Accumulation(const ReducedPrecisionSupport & mask)73 inline bool SupportsFP16Accumulation(const ReducedPrecisionSupport& mask) {
74   return static_cast<bool>(mask & ReducedPrecisionSupport::Float16Accumulation);
75 }
76 
SupportsFP32Accumulation(const ReducedPrecisionSupport & mask)77 inline bool SupportsFP32Accumulation(const ReducedPrecisionSupport& mask) {
78   return static_cast<bool>(mask & ReducedPrecisionSupport::Float32Accumulation);
79 }
80 
SupportsReducedPrecisionInference(const ReducedPrecisionSupport & mask)81 inline bool SupportsReducedPrecisionInference(
82     const ReducedPrecisionSupport& mask) {
83   return SupportsFP16Inference(mask) || SupportsBfloat16Inference(mask);
84 }
85 
SupportsEitherFP16OrFP32Accumulation(const ReducedPrecisionSupport & mask)86 inline bool SupportsEitherFP16OrFP32Accumulation(
87     const ReducedPrecisionSupport& mask) {
88   return SupportsFP16Accumulation(mask) != SupportsFP32Accumulation(mask);
89 }
90 
91 // Return the key-value pair for reduced precision support metadata.
92 // Example: mask = Float16Inference | Bfloat16Inference | Float32Accumulation;
93 // Returned value would be <"reduced_precision_support", "fp16bf16accfp32">.
MetadataForReducedPrecisionSupport(const ReducedPrecisionSupport & mask)94 inline std::pair<std::string, std::string> MetadataForReducedPrecisionSupport(
95     const ReducedPrecisionSupport& mask) {
96   TFLITE_DCHECK(SupportsReducedPrecisionInference(mask));
97   TFLITE_DCHECK(SupportsEitherFP16OrFP32Accumulation(mask));
98   std::string value = "";
99   if (SupportsFP16Inference(mask)) {
100     value += kTfLiteFloat16String;
101   }
102   if (SupportsBfloat16Inference(mask)) {
103     value += kTfLiteBfloat16String;
104   }
105   value += kTfLiteAccumulationString;
106   if (SupportsFP16Accumulation(mask)) {
107     value += kTfLiteFloat16String;
108   } else if (SupportsFP32Accumulation(mask)) {
109     value += kTfLiteFloat32String;
110   }
111   return std::make_pair(std::string(kTfLiteReducedPrecisionKey), value);
112 }
113 
ReadInferenceType(const std::string & metadata,size_t * idx,ReducedPrecisionSupport * mask)114 inline bool ReadInferenceType(const std::string& metadata, size_t* idx,
115                               ReducedPrecisionSupport* mask) {
116   if (metadata.substr(*idx, 4) == kTfLiteFloat16String) {
117     *idx += 4;
118     *mask = *mask | ReducedPrecisionSupport::Float16Inference;
119     return true;
120   } else if (metadata.substr(*idx, 4) == kTfLiteBfloat16String) {
121     *idx += 4;
122     *mask = *mask | ReducedPrecisionSupport::Bfloat16Inference;
123     return true;
124   }
125   return false;
126 }
127 
ReadAccumulationType(const std::string & metadata,size_t * idx,ReducedPrecisionSupport * mask)128 inline bool ReadAccumulationType(const std::string& metadata, size_t* idx,
129                                  ReducedPrecisionSupport* mask) {
130   if (metadata.substr(*idx, 4) == kTfLiteFloat16String) {
131     *idx += 4;
132     *mask = *mask | ReducedPrecisionSupport::Float16Accumulation;
133     return true;
134   } else if (metadata.substr(*idx, 4) == kTfLiteFloat32String) {
135     *idx += 4;
136     *mask = *mask | ReducedPrecisionSupport::Float32Accumulation;
137     return true;
138   }
139   return false;
140 }
141 
142 // If the string is valid, set the given mask to indicate the state in
143 // string and return true. If the string is invalid, return false.
144 // A valid string is:
145 // >= 1 valid inference types + accumulation token + 1 valid accumulation type.
146 // Valid examples would be: "fp16accfp16", "bf16accfp32"
SetMaskFromReducedPrecisionMetadata(const std::string & metadata,ReducedPrecisionSupport * mask)147 inline bool SetMaskFromReducedPrecisionMetadata(const std::string& metadata,
148                                                 ReducedPrecisionSupport* mask) {
149   bool check = true;
150   size_t idx = 0;
151   ReducedPrecisionSupport rsp = ReducedPrecisionSupport::None;
152   do {
153     check = ReadInferenceType(metadata, &idx, &rsp);
154   } while (check);
155   // Ensure we read at least 1 inference type.
156   if (idx == 0) {
157     return false;
158   }
159   // Next read the accumulation token.
160   if (metadata.substr(idx, 3) != kTfLiteAccumulationString) {
161     return false;
162   }
163   idx += std::string(kTfLiteAccumulationString).size();
164   // Next read a valid accumulation type.
165   if (!ReadAccumulationType(metadata, &idx, &rsp)) {
166     return false;
167   }
168   // This should be the end of string.
169   if (idx != metadata.length()) {
170     return false;
171   }
172   // The string is a valid mask description. Set the value and return.
173   *mask = rsp;
174   return true;
175 }
176 
177 }  // namespace optimize
178 }  // namespace tflite
179 
180 #endif  // TENSORFLOW_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_SUPPORT_H
181