xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/quantize.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
18 
#include <algorithm>
#include <climits>
#include <cstdint>
#include <limits>
#include <numeric>
#include <vector>
23 
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/platform/bfloat16.h"
30 
31 namespace xla {
32 
33 // Represents the range used for quantization
34 struct QuantizedRange {
35   QuantizedRange() = default;
QuantizedRangeQuantizedRange36   QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {}
37 
38   bool operator==(const QuantizedRange& rhs) const {
39     return this->min == rhs.min && this->max == rhs.max;
40   }
41 
42   bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); }
43 
44   tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f);
45   tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f);
46 };
47 
48 template <typename T>
PackToUint32(absl::Span<const T> input)49 inline std::vector<uint32_t> PackToUint32(absl::Span<const T> input) {
50   const int64_t kElementsPerPack = sizeof(uint32_t) / sizeof(T);
51   const int64_t input_size = input.size();
52   const int64_t output_size = CeilOfRatio(input_size, kElementsPerPack);
53 
54   std::vector<uint32_t> output_vec;
55   constexpr int64_t kShiftBits = sizeof(T) / sizeof(uint8_t) * CHAR_BIT;
56 
57   for (int64_t i = 0; i < output_size; i++) {
58     uint32_t result = 0;
59     for (int64_t p = 0; p < kElementsPerPack; p++) {
60       int64_t index = i * kElementsPerPack + p;
61       if (index < input_size) {
62         int64_t total_shift_bits = kShiftBits * (kElementsPerPack - p - 1);
63         result |= (input[index] << total_shift_bits);
64       }
65     }
66     output_vec.push_back(result);
67   }
68 
69   return output_vec;
70 }
71 
72 // Dequantize the quantized input of packed uint32_t to bfloat16.
73 // Only uint8_t or uint16_t is supported for the original unpacked input.
74 // Returns a tensor of shape [d0,..., dn * unpack_size] if
75 // input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T).
76 // If transpose_output is true, will return a tensor of shape
77 // [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when
78 // input's rank higher than 1. The input needs to be transposed to use
79 // transpose_output feature.
80 template <typename T>
81 inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
82                         absl::string_view mode_string = "MIN_COMBINED",
83                         bool transpose_output = false) {
84   XlaBuilder* const builder = input.builder();
85   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
86     float half_range =
87         !std::is_signed<T>::value
88             ? 0.0f
89             : (static_cast<float>(std::numeric_limits<T>::max()) -
90                std::numeric_limits<T>::min() + 1) /
91                   2.0f;
92     const int64_t unpack_size = sizeof(uint32_t) / sizeof(T);
93     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input));
94 
95     auto element_type = shape.element_type();
96     if (element_type != U32) {
97       return InvalidArgument(
98           "Only U32 is supported for input type of xla::Dequantize Op.");
99     }
100 
101     // Broadcast the input to [unpack_size, d0, ..., dn] if input size is
102     // [d0, ..., dn].
103     auto broadcast_input = Broadcast(input, {unpack_size});
104 
105     XlaOp iota_r1 = Iota(builder, U32, unpack_size);
106     // Highest significant bytes needs to shift more bytes than lower
107     // significant bytes.
108     XlaOp shift_bytes =
109         xla::ConstantR0<uint32_t>(builder, unpack_size - 1) - iota_r1;
110 
111     const int bytes_of_type = sizeof(T) / sizeof(uint8_t);
112     std::vector<uint32_t> shift_vec(unpack_size, CHAR_BIT * bytes_of_type);
113     XlaOp shift_bits =
114         shift_bytes * xla::ConstantR1<uint32_t>(builder, shift_vec);
115 
116     // Make bit_mask for different data type T.
117     uint32_t bit_mask = 0x00000000;
118     for (int i = 0; i < bytes_of_type; i++) {
119       bit_mask <<= CHAR_BIT;
120       bit_mask |= 0x000000ff;
121     }
122 
123     std::vector<int64_t> shift_transpose_dimensions(shape.dimensions_size());
124     std::iota(shift_transpose_dimensions.begin(),
125               shift_transpose_dimensions.end(), 0);
126     shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1,
127                                       shape.dimensions_size());
128 
129     // Shift the input by sizeof(T) bytes and apply bit_mask to unpack.
130     XlaOp shifted_input = ShiftRightLogical(
131         broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()),
132                                    shift_transpose_dimensions));
133     XlaOp unpack_input =
134         And(shifted_input, xla::ConstantR0<uint32_t>(builder, bit_mask));
135 
136     XlaOp result;
137 
138     if (mode_string == "MIN_COMBINED") {
139       const tensorflow::bfloat16 scale_factor =
140           (range.max - range.min) /
141           (static_cast<tensorflow::bfloat16>(std::numeric_limits<T>::max() -
142                                              std::numeric_limits<T>::min()));
143       // result = bfloat16(input + half_range) * scale_factor + range.min
144       XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16);
145       XlaOp half_range_bf16 = xla::ConstantR0<tensorflow::bfloat16>(
146           builder, static_cast<bfloat16>(half_range));
147       XlaOp sum = unpack_input_bf16 + half_range_bf16;
148 
149       result =
150           sum * xla::ConstantR0<tensorflow::bfloat16>(builder, scale_factor) +
151           xla::ConstantR0<tensorflow::bfloat16>(builder, range.min);
152     } else {
153       // TODO(wangtao): support other modes.
154       return InvalidArgument(
155           "Only MIN_COMBINED mode is supported in xla::Dequantize Op.");
156     }
157 
158     std::vector<int64_t> transpose_dimensions(shape.dimensions_size());
159     std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1);
160     std::reverse(transpose_dimensions.begin(), transpose_dimensions.end());
161     transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0);
162 
163     // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0].
164     XlaOp transposed_result = Transpose(result, transpose_dimensions);
165 
166     // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0].
167     XlaOp reshaped_result = Collapse(transposed_result, {0, 1});
168 
169     // Return the transpose result if transpose_output is true.
170     if (transpose_output) {
171       return reshaped_result;
172     }
173 
174     // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size].
175     std::vector<int64_t> result_dimensions(shape.dimensions_size());
176     std::iota(result_dimensions.begin(), result_dimensions.end(), 0);
177     std::reverse(result_dimensions.begin(), result_dimensions.end());
178 
179     return Transpose(reshaped_result, result_dimensions);
180   });
181 }
182 
183 }  // namespace xla
184 
185 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
186