/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_

#include <algorithm>
#include <climits>
#include <cstdint>
#include <limits>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/bfloat16.h"

namespace xla {

33 // Represents the range used for quantization
34 struct QuantizedRange {
35 QuantizedRange() = default;
QuantizedRangeQuantizedRange36 QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {}
37
38 bool operator==(const QuantizedRange& rhs) const {
39 return this->min == rhs.min && this->max == rhs.max;
40 }
41
42 bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); }
43
44 tensorflow::bfloat16 min = tensorflow::bfloat16(0.0f);
45 tensorflow::bfloat16 max = tensorflow::bfloat16(0.0f);
46 };
48 template <typename T>
PackToUint32(absl::Span<const T> input)49 inline std::vector<uint32_t> PackToUint32(absl::Span<const T> input) {
50 const int64_t kElementsPerPack = sizeof(uint32_t) / sizeof(T);
51 const int64_t input_size = input.size();
52 const int64_t output_size = CeilOfRatio(input_size, kElementsPerPack);
53
54 std::vector<uint32_t> output_vec;
55 constexpr int64_t kShiftBits = sizeof(T) / sizeof(uint8_t) * CHAR_BIT;
56
57 for (int64_t i = 0; i < output_size; i++) {
58 uint32_t result = 0;
59 for (int64_t p = 0; p < kElementsPerPack; p++) {
60 int64_t index = i * kElementsPerPack + p;
61 if (index < input_size) {
62 int64_t total_shift_bits = kShiftBits * (kElementsPerPack - p - 1);
63 result |= (input[index] << total_shift_bits);
64 }
65 }
66 output_vec.push_back(result);
67 }
68
69 return output_vec;
70 }
72 // Dequantize the quantized input of packed uint32_t to bfloat16.
73 // Only uint8_t or uint16_t is supported for the original unpacked input.
74 // Returns a tensor of shape [d0,..., dn * unpack_size] if
75 // input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T).
76 // If transpose_output is true, will return a tensor of shape
77 // [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when
78 // input's rank higher than 1. The input needs to be transposed to use
79 // transpose_output feature.
80 template <typename T>
81 inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
82 absl::string_view mode_string = "MIN_COMBINED",
83 bool transpose_output = false) {
84 XlaBuilder* const builder = input.builder();
85 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
86 float half_range =
87 !std::is_signed<T>::value
88 ? 0.0f
89 : (static_cast<float>(std::numeric_limits<T>::max()) -
90 std::numeric_limits<T>::min() + 1) /
91 2.0f;
92 const int64_t unpack_size = sizeof(uint32_t) / sizeof(T);
93 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input));
94
95 auto element_type = shape.element_type();
96 if (element_type != U32) {
97 return InvalidArgument(
98 "Only U32 is supported for input type of xla::Dequantize Op.");
99 }
100
101 // Broadcast the input to [unpack_size, d0, ..., dn] if input size is
102 // [d0, ..., dn].
103 auto broadcast_input = Broadcast(input, {unpack_size});
104
105 XlaOp iota_r1 = Iota(builder, U32, unpack_size);
106 // Highest significant bytes needs to shift more bytes than lower
107 // significant bytes.
108 XlaOp shift_bytes =
109 xla::ConstantR0<uint32_t>(builder, unpack_size - 1) - iota_r1;
110
111 const int bytes_of_type = sizeof(T) / sizeof(uint8_t);
112 std::vector<uint32_t> shift_vec(unpack_size, CHAR_BIT * bytes_of_type);
113 XlaOp shift_bits =
114 shift_bytes * xla::ConstantR1<uint32_t>(builder, shift_vec);
115
116 // Make bit_mask for different data type T.
117 uint32_t bit_mask = 0x00000000;
118 for (int i = 0; i < bytes_of_type; i++) {
119 bit_mask <<= CHAR_BIT;
120 bit_mask |= 0x000000ff;
121 }
122
123 std::vector<int64_t> shift_transpose_dimensions(shape.dimensions_size());
124 std::iota(shift_transpose_dimensions.begin(),
125 shift_transpose_dimensions.end(), 0);
126 shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1,
127 shape.dimensions_size());
128
129 // Shift the input by sizeof(T) bytes and apply bit_mask to unpack.
130 XlaOp shifted_input = ShiftRightLogical(
131 broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()),
132 shift_transpose_dimensions));
133 XlaOp unpack_input =
134 And(shifted_input, xla::ConstantR0<uint32_t>(builder, bit_mask));
135
136 XlaOp result;
137
138 if (mode_string == "MIN_COMBINED") {
139 const tensorflow::bfloat16 scale_factor =
140 (range.max - range.min) /
141 (static_cast<tensorflow::bfloat16>(std::numeric_limits<T>::max() -
142 std::numeric_limits<T>::min()));
143 // result = bfloat16(input + half_range) * scale_factor + range.min
144 XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16);
145 XlaOp half_range_bf16 = xla::ConstantR0<tensorflow::bfloat16>(
146 builder, static_cast<bfloat16>(half_range));
147 XlaOp sum = unpack_input_bf16 + half_range_bf16;
148
149 result =
150 sum * xla::ConstantR0<tensorflow::bfloat16>(builder, scale_factor) +
151 xla::ConstantR0<tensorflow::bfloat16>(builder, range.min);
152 } else {
153 // TODO(wangtao): support other modes.
154 return InvalidArgument(
155 "Only MIN_COMBINED mode is supported in xla::Dequantize Op.");
156 }
157
158 std::vector<int64_t> transpose_dimensions(shape.dimensions_size());
159 std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1);
160 std::reverse(transpose_dimensions.begin(), transpose_dimensions.end());
161 transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0);
162
163 // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0].
164 XlaOp transposed_result = Transpose(result, transpose_dimensions);
165
166 // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0].
167 XlaOp reshaped_result = Collapse(transposed_result, {0, 1});
168
169 // Return the transpose result if transpose_output is true.
170 if (transpose_output) {
171 return reshaped_result;
172 }
173
174 // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size].
175 std::vector<int64_t> result_dimensions(shape.dimensions_size());
176 std::iota(result_dimensions.begin(), result_dimensions.end(), 0);
177 std::reverse(result_dimensions.begin(), result_dimensions.end());
178
179 return Transpose(reshaped_result, result_dimensions);
180 });
181 }

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QUANTIZE_H_