xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/quantization_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
17 #define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
18 
19 #include <cmath>
20 #define EIGEN_USE_THREADS
21 
22 // This is a set of functions that standardizes how quantized values are
23 // interpreted as float numbers.
24 // All of the current implementations are for reference and have not been
25 // optimized. They should be implementable using fixed point representations
26 // to avoid a dependency on floating-point hardware.
27 
28 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
29 #define QUANTIZATION_UTILS_USE_NEON
30 #include <arm_neon.h>
31 #endif
32 
33 #include <array>
34 
35 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
36 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
37 #include "public/gemmlowp.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/lib/core/threadpool.h"
40 
41 namespace tensorflow {
42 
43 // We have to be able to detect and handle overflows in int32, so this function
44 // uses doubles and int64's to make sure we have enough room.
45 template <class T>
FloatToQuantizedUnclamped(float input,float range_min,float range_max)46 inline int64_t FloatToQuantizedUnclamped(float input, float range_min,
47                                          float range_max) {
48   const int64_t lowest_quantized =
49       static_cast<double>(Eigen::NumTraits<T>::lowest());
50   if (range_min == range_max) {
51     return lowest_quantized;
52   }
53   const int number_of_bits = sizeof(T) * 8;
54   const int64_t number_of_steps = static_cast<int64_t>(1) << number_of_bits;
55   const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
56   const double range = ((range_max - range_min) * range_adjust);
57   const double range_scale = (number_of_steps / range);
58   int64_t quantized =
59       (round(input * range_scale) - round(range_min * range_scale));
60   quantized += lowest_quantized;
61   return quantized;
62 }
63 
64 template <>
65 inline int64_t FloatToQuantizedUnclamped<float>(float input, float range_min,
66                                                 float range_max) {
67   return -1;
68 }
69 
70 // This converts the float into the final quantized type, clamping/saturating
71 // any over or underflows.
72 template <class T>
FloatToQuantized(float input,float range_min,float range_max)73 T FloatToQuantized(float input, float range_min, float range_max) {
74   if (std::is_same<T, float>::value) {
75     // Specialization for float. This is used in reference implementation
76     // for float which is useful to compare performance between float
77     // and quantized type.
78     return input;
79   }
80   int64_t quantized = FloatToQuantizedUnclamped<T>(input, range_min, range_max);
81   const int64_t lowest_quantized =
82       static_cast<int64_t>(Eigen::NumTraits<T>::lowest());
83   const int64_t highest_quantized =
84       static_cast<int64_t>(Eigen::NumTraits<T>::highest());
85   quantized = std::max(quantized, lowest_quantized);
86   quantized = std::min(quantized, highest_quantized);
87   return static_cast<T>(static_cast<int32>(quantized));
88 }
89 
90 template <class T>
QuantizedToFloat(T input,float range_min,float range_max)91 float QuantizedToFloat(T input, float range_min, float range_max) {
92   if (std::is_same<T, float>::value) {
93     // Specialization for float. This is used in reference implementation
94     // for float which is useful to compare performance between float
95     // and quantized type.
96     return input;
97   }
98   if (range_min == range_max) {
99     return range_min;
100   }
101   const int number_of_bits = sizeof(T) * 8;
102   const int64_t number_of_steps = static_cast<int64_t>(1) << number_of_bits;
103   const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
104   const double range = ((range_max - range_min) * range_adjust);
105   const double range_scale = (range / number_of_steps);
106   const int64_t lowest_quantized =
107       static_cast<int64_t>(Eigen::NumTraits<T>::lowest());
108   const double offset_input = static_cast<double>(input) - lowest_quantized;
109   // For compatibility with DEQUANTIZE_WITH_EIGEN, we should convert
110   // range_scale to a float, otherwise range_min_rounded might be slightly
111   // different.
112   const double range_min_rounded =
113       std::round(range_min / static_cast<float>(range_scale)) *
114       static_cast<float>(range_scale);
115   const double result = range_min_rounded + (offset_input * range_scale);
116   return static_cast<float>(result);
117 }
118 
119 template <class T>
FloatForOneQuantizedLevel(float range_min,float range_max)120 float FloatForOneQuantizedLevel(float range_min, float range_max) {
121   const int64_t highest = static_cast<int64_t>(Eigen::NumTraits<T>::highest());
122   const int64_t lowest = static_cast<int64_t>(Eigen::NumTraits<T>::lowest());
123   const float float_for_one_quantized_level =
124       (range_max - range_min) / (highest - lowest);
125   return float_for_one_quantized_level;
126 }
127 
128 template <class T1, class T2, class T3>
QuantizationRangeForMultiplication(float min_a,float max_a,float min_b,float max_b,float * min_c,float * max_c)129 void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
130                                         float max_b, float* min_c,
131                                         float* max_c) {
132   const float a_float_for_one_quant_level =
133       FloatForOneQuantizedLevel<T1>(min_a, max_a);
134   const float b_float_for_one_quant_level =
135       FloatForOneQuantizedLevel<T2>(min_b, max_b);
136 
137   const int64_t c_highest =
138       static_cast<int64_t>(Eigen::NumTraits<T3>::highest());
139   const int64_t c_lowest = static_cast<int64_t>(Eigen::NumTraits<T3>::lowest());
140   const float c_float_for_one_quant_level =
141       a_float_for_one_quant_level * b_float_for_one_quant_level;
142 
143   *min_c = c_float_for_one_quant_level * c_lowest;
144   *max_c = c_float_for_one_quant_level * c_highest;
145 }
146 
// input_array is an eigen Tensor.  q2f is a QuantizedToFloatStruct.
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
//
// Both arguments are parenthesized in the expansion so that passing an
// expression (rather than a plain identifier) cannot change how the trailing
// member accesses bind.  Note that q2f is expanded several times, so it should
// not have side effects.
#define DEQUANTIZE_WITH_EIGEN(input_array, q2f)                               \
  (((q2f).range_min_rounded - (q2f).lowest_quantized() * (q2f).range_scale) + \
   (input_array).template cast<float>() * (q2f).range_scale)
153 
// input_array is an eigen Tensor.  f2q is a FloatToQuantizedStruct.
// OutputType is the type of output (e.g. quint8).
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T);
//
// The value is scaled and rounded, offset by the scaled range minimum, clamped
// to [lower_bound_float(), upper_bound_float()], and then cast to int32 before
// the final cast to OutputType.  input_array and f2q are parenthesized in the
// expansion so that passing an expression cannot change operator binding;
// f2q is expanded several times, so it should not have side effects.
#define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType)   \
  (((input_array) * (f2q).range_scale).round() -            \
   ((f2q).range_min_scaled - (f2q).lowest_quantized()))     \
      .cwiseMax((f2q).lower_bound_float())                  \
      .cwiseMin((f2q).upper_bound_float())                  \
      .template cast<int32>()                               \
      .template cast<OutputType>()
165 
166 // For use with DEQUANTIZE_WITH_EIGEN.
167 template <typename T>
168 struct QuantizedToFloatStruct {
169   static constexpr int number_of_bits = sizeof(T) * 8;
170   static constexpr int64_t number_of_steps = static_cast<int64_t>(1)
171                                              << number_of_bits;
172 
lowest_quantizedQuantizedToFloatStruct173   static float lowest_quantized() {
174     return static_cast<float>(Eigen::NumTraits<T>::lowest());
175   }
176 
QuantizedToFloatStructQuantizedToFloatStruct177   QuantizedToFloatStruct(float range_min, float range_max)
178       : range_min(range_min),
179         range_scale((range_max - range_min) / (number_of_steps - 1.0)),
180         range_min_rounded(range_max == range_min
181                               ? range_min
182                               : std::round(range_min / range_scale) *
183                                     range_scale) {}
184 
185   const float range_min;
186   const float range_scale;
187   const float range_min_rounded;
188 };
189 
190 // For use with QUANTIZE_WITH_EIGEN.
191 template <typename T>
192 struct FloatToQuantizedStruct {
193   static constexpr int number_of_bits = sizeof(T) * 8;
194   static constexpr int64_t number_of_steps = static_cast<int64_t>(1)
195                                              << number_of_bits;
196   static constexpr double range_adjust =
197       (number_of_steps / (number_of_steps - 1.0));
198 
199   // Casting QInt32's lowest or highest to a float gives a float that can't be
200   // cast back to int32 or QInt32.  Instead, use bounds that can be converted
201   // back to int32 without going outside the range of an int32.
lower_bound_floatFloatToQuantizedStruct202   static float lower_bound_float() {
203     return Eigen::numext::maxi(
204         static_cast<float>(Eigen::NumTraits<T>::lowest()), -2.147483648e+09f);
205   }
upper_bound_floatFloatToQuantizedStruct206   static float upper_bound_float() {
207     return Eigen::numext::mini(
208         static_cast<float>(Eigen::NumTraits<T>::highest()), +2.147483520e+09f);
209   }
210 
lowest_quantizedFloatToQuantizedStruct211   static float lowest_quantized() {
212     return static_cast<float>(Eigen::NumTraits<T>::lowest());
213   }
214 
FloatToQuantizedStructFloatToQuantizedStruct215   FloatToQuantizedStruct(float range_min, float range_max)
216       : range_min(range_min),
217         range_scale(range_max == range_min
218                         ? 0.0
219                         : (number_of_steps - 1.0) / (range_max - range_min)),
220         range_min_scaled(std::round(range_min * range_scale)) {}
221 
222   const float range_min;
223   const float range_scale;
224   const float range_min_scaled;
225 };
226 
227 template <class T1, class T2>
RequantizeInNewRange(T1 input,float min_input,float max_input,float min_new,float max_new)228 inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input,
229                                float min_new, float max_new) {
230   const float input_float = QuantizedToFloat<T1>(input, min_input, max_input);
231   return FloatToQuantized<T2>(input_float, min_new, max_new);
232 }
233 
234 template <class T1, class T2>
RequantizeManyInNewRange(const T1 * input,int64_t count,float min_input,float max_input,float min_output,float max_output,T2 * output)235 inline void RequantizeManyInNewRange(const T1* input, int64_t count,
236                                      float min_input, float max_input,
237                                      float min_output, float max_output,
238                                      T2* output) {
239   for (size_t index = 0; index < count; ++index) {
240     const float input_float =
241         QuantizedToFloat<T1>(input[index], min_input, max_input);
242     output[index] = FloatToQuantized<T2>(input_float, min_output, max_output);
243   }
244 }
245 
246 // Because converting 32-bit accumulated results down to eight bit is a common
247 // case, we have a specialized code path to handle it as efficiently as
248 // possible using only fixed-point math for the inner loop.
RequantizeManyInNewRangeReference(const qint32 * input,int64_t count,float min_input,float max_input,float min_output,float max_output,quint8 * output)249 inline void RequantizeManyInNewRangeReference(const qint32* input,
250                                               int64_t count, float min_input,
251                                               float max_input, float min_output,
252                                               float max_output,
253                                               quint8* output) {
254   // Initially we calculate all the constants we need once, before we go into
255   // the inner loop.  If this is updated, also update the Eigen version.
256   const int fp_shift = 16;
257   const float input_range = max_input - min_input;
258   const float output_range = max_output - min_output;
259   const float recip_output_range =
260       output_range == 0.0 ? 0.0 : (255.0 / output_range);
261   const float input_rezero = (min_input + max_input) / 2.0;
262   const int64_t range_scale_fp =
263       output_range == 0.0 ? 0.0
264                           : static_cast<int64_t>(255.0 * (1 << fp_shift) *
265                                                  input_range / output_range);
266   const int64_t input_offset_fp =
267       static_cast<int64_t>(input_rezero * recip_output_range * (1 << fp_shift));
268   const int64_t output_offset_fp =
269       output_range == 0.0
270           ? 0
271           : std::lround((1 << fp_shift) * (min_output * 255.0) / output_range);
272   const int64_t rounding_delta = 1 << (fp_shift - 1);
273 
274   // Inside this loop we just do minimal adds, multiplies, and shifts, in a way
275   // that could be easily adapted for a SIMD implementation. It should also be
276   // possible to perform all the calculations in 32-bit rather than 64, but
277   // that's not been implemented yet.
278   for (int64_t index = 0; index < count; ++index) {
279     const int64_t input_value = static_cast<int64_t>(input[index]);
280     const int64_t fp_value =
281         ((input_value * range_scale_fp) >> 32) + input_offset_fp;
282     const int64_t offset_intermediate = fp_value - output_offset_fp;
283     const int64_t round_intermediate = offset_intermediate + rounding_delta;
284     int64_t quantized_int64 = round_intermediate >> fp_shift;
285     quantized_int64 = std::max(quantized_int64, int64_t{0});
286     quantized_int64 = std::min(quantized_int64, int64_t{255});
287     output[index] = static_cast<quint8>(static_cast<int32>(quantized_int64));
288   }
289 }
290 
291 // Another common case is converting eight bit inputs up to thirty two bits, so
292 // we have specialized fixed-point code to accelerate that. There is also a NEON
293 // version for ARM devices below.
RequantizeManyInNewRange8To32BitReference(const quint8 * input,int64_t count,float min_input,float max_input,float min_output,float max_output,qint32 * output)294 inline void RequantizeManyInNewRange8To32BitReference(
295     const quint8* input, int64_t count, float min_input, float max_input,
296     float min_output, float max_output, qint32* output) {
297   const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
298   const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
299   const int64_t code_0_int64 =
300       FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
301   const int64_t code_1_int64 =
302       FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
303   const int32_t mult_int32 = code_1_int64 - code_0_int64;
304   const int64_t lowest_quantized =
305       static_cast<int64_t>(Eigen::NumTraits<qint32>::lowest());
306   const int64_t highest_quantized =
307       static_cast<int64_t>(Eigen::NumTraits<qint32>::highest());
308   for (int64_t i = 0; i < count; ++i) {
309     const int64_t input_value = static_cast<int64_t>(input[i]);
310     int64_t output_value = code_0_int64 + (input_value * mult_int32);
311     output_value = std::max(output_value, lowest_quantized);
312     output_value = std::min(output_value, highest_quantized);
313     output[i] = static_cast<int32>(output_value);
314   }
315 }
316 
317 #ifdef QUANTIZATION_UTILS_USE_NEON
318 // Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD
319 // intrinsics for ARM platforms.
RequantizeManyInNewRangeNeon(const qint32 * input,int64 count,float min_input,float max_input,float min_output,float max_output,quint8 * output)320 inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count,
321                                          float min_input, float max_input,
322                                          float min_output, float max_output,
323                                          quint8* output) {
324   // Initially we calculate all the constants we need once, before we go into
325   // the inner loop.  If this is updated, also update the Eigen version.
326   const int fp_shift = 16;
327 
328   // Calculate range variables in advance.
329   // Input range.
330   const float input_range = max_input - min_input;
331   // Output range.
332   const float output_range = max_output - min_output;
333   // Ratio of output range.
334   const float recip_output_range =
335       output_range == 0.0 ? 0.0 : (255.0 / output_range);
336   // Average of input range as zero position of input.
337   const float input_rezero = (min_input + max_input) / 2.0;
338   // In-out range scale.
339   const int32 range_scale_fp =
340       output_range == 0.0 ? 0.0
341                           : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) *
342                                                input_range / output_range);
343   // Input zero position offset to output.
344   const int32 input_offset_fp =
345       static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift));
346   // Output min offset.
347   const int32 output_offset_fp =
348       output_range == 0.0
349           ? 0
350           : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) /
351                                output_range);
352   const int32 rounding_delta = 1 << (fp_shift - 1);
353 
354   // broadcast range to each lane
355   const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp);
356   const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp);
357   const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp);
358   const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta);
359 
360   int64 index = 0;
361   // Use SIMD to requantize.
362   for (; index < (count - 7); index += 8) {
363     const int32* input_ptr = &(input->value) + index;
364     const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr);
365     const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4);
366     const int32x4_t fp_value_low_32x4 = vaddq_s32(
367         input_offset_fp_32x4,
368         vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4));
369     const int32x4_t fp_value_high_32x4 = vaddq_s32(
370         input_offset_fp_32x4,
371         vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4));
372     const int32x4_t offset_intermediate_low_32x4 =
373         vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4);
374     const int32x4_t offset_intermediate_high_32x4 =
375         vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4);
376     const int32x4_t round_intermediate_low_32x4 =
377         vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4);
378     const int32x4_t round_intermediate_high_32x4 =
379         vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4);
380     const int16x4_t quantized_low_16x4 =
381         vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift));
382     const int16x4_t quantized_high_16x4 =
383         vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift));
384     const uint8x8_t quantized_8x8 =
385         vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4));
386     uint8* output_ptr = &(output->value) + index;
387     vst1_u8(output_ptr, quantized_8x8);
388   }
389 
390   // Requantize remaining elements in array without SIMD.
391   for (; index < count; ++index) {
392     const int32 input_value = static_cast<int32>(input[index]);
393     const int32 fp_value =
394         static_cast<int32>(
395             (static_cast<int32>(input_value >> 16) * (range_scale_fp))) +
396         input_offset_fp;
397     const int32 offset_intermediate = fp_value - output_offset_fp;
398     const int32 round_intermediate = offset_intermediate + rounding_delta;
399     int32 quantized_int32 = round_intermediate >> fp_shift;
400     quantized_int32 = std::max(quantized_int32, 0);
401     quantized_int32 = std::min(quantized_int32, 255);
402     output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32));
403   }
404 }
405 
406 template <>
407 inline void RequantizeManyInNewRange<qint32, quint8>(
408     const qint32* input, int64 count, float min_input, float max_input,
409     float min_output, float max_output, quint8* output) {
410   const float input_range = max_input - min_input;
411   const float output_range = max_output - min_output;
412   if ((input_range / output_range) > 16384.0f) {
413     // Our NEON implementation uses 32-bit math and can't handle very
414     // large ranges, so fall back to the reference implementation. We don't
415     // expect these to be common in models, so this shouldn't be a performance
416     // problem in practice.
417     RequantizeManyInNewRangeReference(input, count, min_input, max_input,
418                                       min_output, max_output, output);
419   } else {
420     RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output,
421                                  max_output, output);
422   }
423 }
424 
425 // NEON accelerated 16bit rounded division by 2^n.
426 template <int POW>
Divide16x8PowRound(const int16x8_t val)427 inline int16x8_t Divide16x8PowRound(const int16x8_t val) {
428   const int16x8_t val_sign = vshrq_n_s16(val, 15);
429   const int16x8_t val_xor = veorq_s16(val, val_sign);
430   const int16x8_t val_pos = vsubq_s16(val_xor, val_sign);
431   const int16x8_t shifted_val_pos = vrshrq_n_s16(val_pos, POW);
432   const int16x8_t shifted_val_pos_xor = veorq_s16(shifted_val_pos, val_sign);
433   const int16x8_t shifted_val = vsubq_s16(shifted_val_pos_xor, val_sign);
434   return shifted_val;
435 }
436 
437 // NEON accelerated 64bit rounded division by 2^n.
438 template <int POW>
Divide64x2PowRound(const int64x2_t val)439 inline int64x2_t Divide64x2PowRound(const int64x2_t val) {
440   const int64x2_t val_sign = vshrq_n_s64(val, 63);
441   const int64x2_t val_xor = veorq_s64(val, val_sign);
442   const int64x2_t val_pos = vsubq_s64(val_xor, val_sign);
443   const int64x2_t shifted_val_pos = vrshrq_n_s64(val_pos, POW);
444   const int64x2_t shifted_val_pos_xor = veorq_s64(shifted_val_pos, val_sign);
445   const int64x2_t shifted_val = vsubq_s64(shifted_val_pos_xor, val_sign);
446   return shifted_val;
447 }
448 
449 // NEON accelerated 16bit division by 2^n.
450 // CAVEAT: The input must be greater than min-int16 to avoid underflow.
451 template <int POW>
Divide16x8Pow(const int16x8_t val)452 inline int16x8_t Divide16x8Pow(const int16x8_t val) {
453   static constexpr int16 FIRST_BIT_VAL = 0x0000000000000001;
454   static const int16x8_t FIRST_BIT = vmovq_n_s16(FIRST_BIT_VAL);
455   const int16x8_t val_sign = vshrq_n_s16(val, 15);
456   const int16x8_t neg_offset = vandq_s16(val_sign, FIRST_BIT);
457   const int16x8_t val_with_offset = vsubq_s16(val, neg_offset);
458   const int16x8_t shifted_wo_offset =
459       vsraq_n_s16(neg_offset, val_with_offset, POW);
460   return shifted_wo_offset;
461 }
462 
463 // NEON accelerated 64bit division by 2^n.
464 // CAVEAT: The input must be greater than min-int64 to avoid underflow.
465 template <int POW>
Divide64x2Pow(const int64x2_t val)466 inline int64x2_t Divide64x2Pow(const int64x2_t val) {
467   static constexpr int64 FIRST_BIT_VAL = 0x0000000000000001;
468   static const int64x2_t FIRST_BIT = vmovq_n_s64(FIRST_BIT_VAL);
469   const int64x2_t val_sign = vshrq_n_s64(val, 63);
470   const int64x2_t neg_offset = vandq_s64(val_sign, FIRST_BIT);
471   const int64x2_t val_with_offset = vsubq_s64(val, neg_offset);
472   const int64x2_t shifted_wo_offset =
473       vsraq_n_s64(neg_offset, val_with_offset, POW);
474   return shifted_wo_offset;
475 }
476 
477 // 32bit x 2 NEON accelerated lerp computation.
478 template <int RESOLUTION>
ComputeLerp32x2(const int32x2_t top_left,const int32x2_t top_right,const int32x2_t bottom_left,const int32x2_t bottom_right,const int32x2_t x_lerp,const int32x2_t y_lerp)479 inline int32x2_t ComputeLerp32x2(const int32x2_t top_left,
480                                  const int32x2_t top_right,
481                                  const int32x2_t bottom_left,
482                                  const int32x2_t bottom_right,
483                                  const int32x2_t x_lerp,
484                                  const int32x2_t y_lerp) {
485   static_assert(RESOLUTION < 31, "RESOLUTION must be less than 31");
486   constexpr int32 RESOLUTION_MULT32 = (1 << RESOLUTION);
487   static const int32x2_t RESOLUTION_MULT32x2 = vmov_n_s32(RESOLUTION_MULT32);
488 
489   const int64x2_t top_left_x_res = vmull_s32(top_left, RESOLUTION_MULT32x2);
490   const int64x2_t bottom_left_x_res =
491       vmull_s32(bottom_left, RESOLUTION_MULT32x2);
492 
493   const int32x2_t top_right_sub_top_left = vsub_s32(top_right, top_left);
494   const int64x2_t top_x_res =
495       vmlal_s32(top_left_x_res, top_right_sub_top_left, x_lerp);
496   const int32x2_t bottom_right_sub_bottom_left =
497       vsub_s32(bottom_right, bottom_left);
498   const int64x2_t bottom_x_res =
499       vmlal_s32(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);
500 
501   const int64x2_t bottom_sub_top_x_res = vsubq_s64(bottom_x_res, top_x_res);
502   const int64x2_t bottom_sub_top =
503       Divide64x2Pow<RESOLUTION>(bottom_sub_top_x_res);
504   const int32x2_t bottom_sub_top_32 = vqmovn_s64(bottom_sub_top);
505   const int64x2_t top_add_bottom_sub_top_mul_ylerp_x_res =
506       vmlal_s32(top_x_res, bottom_sub_top_32, y_lerp);
507   const int64x2_t retval =
508       Divide64x2PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
509   const int32x2_t retval32 = vqmovn_s64(retval);
510   return retval32;
511 }
512 
513 // 8bit x 8 NEON accelerated lerp computation.
514 template <int RESOLUTION>
ComputeLerp8x8(const uint8x8_t top_left8x8,const uint8x8_t top_right8x8,const uint8x8_t bottom_left8x8,const uint8x8_t bottom_right8x8,const int16x8_t x_lerp,const int16x8_t y_lerp)515 inline uint8x8_t ComputeLerp8x8(const uint8x8_t top_left8x8,
516                                 const uint8x8_t top_right8x8,
517                                 const uint8x8_t bottom_left8x8,
518                                 const uint8x8_t bottom_right8x8,
519                                 const int16x8_t x_lerp,
520                                 const int16x8_t y_lerp) {
521   static_assert(RESOLUTION < 8, "RESOLUTION must be less than 8");
522   constexpr uint8 RESOLUTION_MULT_VAL = (1 << RESOLUTION);
523   static const uint8x8_t RESOLUTION_MULT = vdup_n_u8(RESOLUTION_MULT_VAL);
524 
525   const int16x8_t top_left_x_res =
526       vreinterpretq_s16_u16(vmull_u8(top_left8x8, RESOLUTION_MULT));
527   const int16x8_t bottom_left_x_res =
528       vreinterpretq_s16_u16(vmull_u8(bottom_left8x8, RESOLUTION_MULT));
529 
530   const int16x8_t top_right_sub_top_left =
531       vreinterpretq_s16_u16(vsubl_u8(top_right8x8, top_left8x8));
532   const int16x8_t top_x_res =
533       vmlaq_s16(top_left_x_res, top_right_sub_top_left, x_lerp);
534 
535   const int16x8_t bottom_right_sub_bottom_left =
536       vreinterpretq_s16_u16(vsubl_u8(bottom_right8x8, bottom_left8x8));
537   const int16x8_t bottom_x_res =
538       vmlaq_s16(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);
539 
540   const int16x8_t bottom_sub_top_x_res = vsubq_s16(bottom_x_res, top_x_res);
541   const int16x8_t bottom_sub_top =
542       Divide16x8Pow<RESOLUTION>(bottom_sub_top_x_res);
543   const int16x8_t top_add_bottom_sub_top_mul_ylerp_x_res =
544       vmlaq_s16(top_x_res, bottom_sub_top, y_lerp);
545   const int16x8_t retval16 =
546       Divide16x8PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
547   const uint8x8_t retval = vmovn_u16(vreinterpretq_u16_s16(retval16));
548   return retval;
549 }
550 
551 // Requantize 8 x 8 quints to 8 x 32 qints in parallel by neon
552 // Return std::array instead of pointer to leverage return value optimization
Requantize8x8To32Neon(const uint8 * input_ptr,const int64x2_t input_0_64x2,const int32x2_t input_mult_32x2)553 inline std::array<int32x4_t, 2> Requantize8x8To32Neon(
554     const uint8* input_ptr, const int64x2_t input_0_64x2,
555     const int32x2_t input_mult_32x2) {
556   const uint8x8_t input_value_8x8 = vld1_u8(input_ptr);
557   const int16x8_t input_value_16x8 =
558       vreinterpretq_s16_u16(vmovl_u8(input_value_8x8));
559   const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8);
560   const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8);
561   const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4);
562   const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4);
563   const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4);
564   const int32x2_t input_value_low_high_32x2 =
565       vget_high_s32(input_value_low_32x4);
566   const int32x2_t input_value_high_low_32x2 =
567       vget_low_s32(input_value_high_32x4);
568   const int32x2_t input_value_high_high_32x2 =
569       vget_high_s32(input_value_high_32x4);
570   const int64x2_t mult_result_low_low_64x2 =
571       vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2);
572   const int64x2_t mult_result_low_high_64x2 =
573       vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2);
574   const int64x2_t mult_result_high_low_64x2 =
575       vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2);
576   const int64x2_t mult_result_high_high_64x2 =
577       vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2);
578   const int32x2_t output_value_low_low_32x2 =
579       vqmovn_s64(mult_result_low_low_64x2);
580   const int32x2_t output_value_low_high_32x2 =
581       vqmovn_s64(mult_result_low_high_64x2);
582   const int32x2_t output_value_high_low_32x2 =
583       vqmovn_s64(mult_result_high_low_64x2);
584   const int32x2_t output_value_high_high_32x2 =
585       vqmovn_s64(mult_result_high_high_64x2);
586   const int32x4_t output_value_low_32x4 =
587       vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2);
588   const int32x4_t output_value_high_32x4 =
589       vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2);
590   return std::array<int32x4_t, 2>{
591       {output_value_low_32x4, output_value_high_32x4}};
592 }
593 
594 // Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD
595 // intrinsics for ARM platforms.
596 template <>
597 inline void RequantizeManyInNewRange<quint8, qint32>(
598     const quint8* input, int64 count, float min_input, float max_input,
599     float min_output, float max_output, qint32* output) {
600   // Pre-calculate zero position and multiplier.
601   // Calculate 0 and 1 value in float.
602   const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
603   const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
604 
605   // Cast 0 and 1 value in int64.
606   const int64 code_0_int64 =
607       FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
608   const int64 code_1_int64 =
609       FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
610 
611   // Calculate multiplier.
612   const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64);
613 
614   // Broadcast 0 position and multiplier to lanes
615   const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64);
616   const int32x2_t mult_32x2 = vmov_n_s32(mult_int32);
617 
618   int64 i = 0;
619 
620   // Use SIMD to requantize array.
621   for (; i < (count - 7); i += 8) {
622     const uint8* input_ptr = &(input->value) + i;
623     int32* output_ptr = &(output->value) + i;
624     const std::array<int32x4_t, 2> output_value =
625         Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2);
626     vst1q_s32(output_ptr + 0, output_value[0]);
627     vst1q_s32(output_ptr + 4, output_value[1]);
628   }
629 
630   // Requantize remaining elements in array without SIMD.
631   const int64 lowest_quantized =
632       static_cast<int64_t>(Eigen::NumTraits<qint32>::lowest());
633   const int64 highest_quantized =
634       static_cast<int64_t>(Eigen::NumTraits<qint32>::highest());
635 
636   for (; i < count; ++i) {
637     const int64 input_value = static_cast<int64_t>(input[i]);
638     int64 output_value = code_0_int64 + (input_value * mult_int32);
639     output_value = std::max(output_value, lowest_quantized);
640     output_value = std::min(output_value, highest_quantized);
641     output[i] = static_cast<int32>(output_value);
642   }
643 }
644 
645 #else
646 
647 // If SIMD implementations aren't available, then use these default reference
648 // versions.
649 template <>
650 inline void RequantizeManyInNewRange<qint32, quint8>(
651     const qint32* input, int64_t count, float min_input, float max_input,
652     float min_output, float max_output, quint8* output) {
653   RequantizeManyInNewRangeReference(input, count, min_input, max_input,
654                                     min_output, max_output, output);
655 }
656 
657 template <>
658 inline void RequantizeManyInNewRange<quint8, qint32>(
659     const quint8* input, int64_t count, float min_input, float max_input,
660     float min_output, float max_output, qint32* output) {
661   RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input,
662                                             min_output, max_output, output);
663 }
664 
665 #endif
666 
667 template <int shift>
668 struct int64_right_shift_op {
669   EIGEN_DEVICE_FUNC
operatorint64_right_shift_op670   EIGEN_STRONG_INLINE const int64_t operator()(const int64_t a) const {
671     return a >> shift;
672   }
673 };
674 
675 // See RequantizeManyInNewRange() for a non-eigen reference implementation.
676 template <class T1, class T2>
RequantizeManyInNewRangeUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min_input,float max_input,float min_output,float max_output,Tensor * output)677 inline void RequantizeManyInNewRangeUsingEigen(
678     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
679     float max_input, float min_output, float max_output, Tensor* output) {
680   auto input_array = input.flat<T1>();
681   QuantizedToFloatStruct<T1> q2f(min_input, max_input);
682   auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
683   FloatToQuantizedStruct<T2> f2q(min_output, max_output);
684   auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2);
685 
686   output->flat<T2>().device(device) = input_requantized;
687 }
688 
689 // See RequantizeManyInNewRange() for a non-eigen reference implementation.
690 //
691 // Because converting 32-bit accumulated results down to eight bit is a common
692 // case, we have a specialized code path to handle it as efficiently as
693 // possible using only fixed-point math for the inner loop.
694 template <>
695 inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
696     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
697     float max_input, float min_output, float max_output, Tensor* output) {
698   // Initially we calculate all the constants we need once, before we go into
699   // the inner loop.  If this is updated, also update the non-Eigen version.
700   const int fp_shift = 16;
701   const float input_range = max_input - min_input;
702   const float output_range = max_output - min_output;
703   const float recip_output_range =
704       output_range == 0.0 ? 0.0 : (255.0 / output_range);
705   const float input_rezero = (min_input + max_input) / 2.0;
706   const int64_t range_scale_fp =
707       output_range == 0.0 ? 0.0
708                           : static_cast<int64_t>(255.0 * (1 << fp_shift) *
709                                                  input_range / output_range);
710   const int64_t input_offset_fp =
711       static_cast<int64_t>(input_rezero * recip_output_range * (1 << fp_shift));
712   const int64_t output_offset_fp =
713       output_range == 0.0
714           ? 0
715           : std::lround((1 << fp_shift) * (min_output * 255.0) / output_range);
716   const int64_t rounding_delta = 1 << (fp_shift - 1);
717 
718   // Inside this eigen expression we just do minimal adds, multiplies, and
719   // shifts. It should be possible to perform all the calculations in 32-bit
720   // rather than 64, but that's not been implemented yet.
721   auto input_array = input.flat<qint32>();
722   auto fp_value = ((input_array.template cast<int64_t>() * range_scale_fp)
723                        .unaryExpr(int64_right_shift_op<32>())) +
724                   (input_offset_fp - output_offset_fp + rounding_delta);
725   auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>());
726   auto input_requantized = intermediate.cwiseMax(int64_t{0})
727                                .cwiseMin(int64_t{255})
728                                .template cast<int32>()
729                                .template cast<quint8>();
730   output->flat<quint8>().device(device) = input_requantized;
731 }
732 
733 // REQUIRES: 'result->NumElements() == input.NumElements()'
734 template <class T>
FloatTensorToQuantizedInPlaceUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min,float max,Tensor * result)735 void FloatTensorToQuantizedInPlaceUsingEigen(
736     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
737     float max, Tensor* result) {
738   DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
739   auto flat_input = input.flat<float>();
740   auto flat_result = result->flat<T>();
741   DCHECK_EQ(flat_input.size(), flat_result.size());
742 
743   FloatToQuantizedStruct<T> f2q(min, max);
744   flat_result.device(device) = QUANTIZE_WITH_EIGEN(flat_input, f2q, T);
745 }
746 
747 template <class T>
FloatTensorToQuantizedInPlace(const Tensor & input,float min,float max,Tensor * result)748 void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max,
749                                    Tensor* result) {
750   DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
751   auto flat_input = input.flat<float>();
752   auto flat_result = result->flat<T>();
753   const int data_size = flat_input.size();
754   DCHECK(data_size == flat_result.size());
755   for (int i = 0; i < data_size; ++i) {
756     flat_result(i) = FloatToQuantized<T>(flat_input(i), min, max);
757   }
758 }
759 
760 template <class T>
FloatTensorToQuantized(const Tensor & input,float min,float max)761 Tensor FloatTensorToQuantized(const Tensor& input, float min, float max) {
762   Tensor result(DataTypeToEnum<T>::v(), input.shape());
763   FloatTensorToQuantizedInPlace<T>(input, min, max, &result);
764   return result;
765 }
766 
767 // REQUIRES: 'result->NumElements() == input.NumElements()'
768 template <class T>
QuantizedTensorToFloatInPlaceUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min,float max,Tensor * result)769 void QuantizedTensorToFloatInPlaceUsingEigen(
770     const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
771     float max, Tensor* result) {
772   DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
773   auto flat_input = input.flat<T>();
774   auto flat_result = result->flat<float>();
775   const int data_size = flat_input.size();
776   DCHECK(data_size == flat_result.size());
777 
778   QuantizedToFloatStruct<T> q2f(min, max);
779   flat_result.device(device) = DEQUANTIZE_WITH_EIGEN(flat_input, q2f);
780 }
781 
782 // REQUIRES: 'result->NumElements() == input.NumElements()'
783 template <class T>
QuantizedTensorToFloatInPlace(const Tensor & input,float min,float max,Tensor * result)784 void QuantizedTensorToFloatInPlace(const Tensor& input, float min, float max,
785                                    Tensor* result) {
786   DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
787   auto flat_input = input.flat<T>();
788   auto flat_result = result->flat<float>();
789   const int data_size = flat_input.size();
790   DCHECK(data_size == flat_result.size());
791   for (int i = 0; i < data_size; ++i) {
792     flat_result(i) = QuantizedToFloat<T>(flat_input(i), min, max);
793   }
794 }
795 
796 template <class T>
QuantizedTensorToFloat(const Tensor & input,float min,float max)797 Tensor QuantizedTensorToFloat(const Tensor& input, float min, float max) {
798   Tensor result(DT_FLOAT, input.shape());
799   QuantizedTensorToFloatInPlace<T>(input, min, max, &result);
800   return result;
801 }
802 
// Declaration only; the definition is not part of this header. The
// QuantizedAdd*() helpers below call it with the ranges of both operands and
// afterwards read the output range through `output_min` / `output_max`.
void GetOutputMinAndMaxForQuantizedAdd(float input_min,
                                       float input_max,
                                       float smaller_input_min,
                                       float smaller_input_max,
                                       float* output_min,
                                       float* output_max);
807 
808 // Add <input> and <smaller_input>.  If <smaller_input> has fewer elements than
809 // <input>, then it is broadcast onto <input>.
810 template <typename T1, typename T2, typename T3>
QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float input_min,float input_max,const Tensor & smaller_input,float smaller_input_min,float smaller_input_max,Tensor * output,float * output_min,float * output_max)811 void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device,
812                             const Tensor& input, float input_min,
813                             float input_max, const Tensor& smaller_input,
814                             float smaller_input_min, float smaller_input_max,
815                             Tensor* output, float* output_min,
816                             float* output_max) {
817   const auto& input_flat = input.flat<T1>();
818   const auto& smaller_input_flat = smaller_input.flat<T2>();
819   auto output_flat = output->flat<T3>();
820 
821   GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
822                                     smaller_input_max, output_min, output_max);
823   // To do addition properly, we need to compensate for a possibly unbalanced
824   // zero point in the total representation. The quantized value that
825   // represents the real number zero needs to be subtracted before addition to
826   // make sure that the identity of zero + zero = zero holds.
827   const T3 zero_in_total_space =
828       FloatToQuantized<T3>(0.0f, *output_min, *output_max);
829 
830   const int64_t input_element_count = input.NumElements();
831   const int64_t smaller_input_element_count = smaller_input.NumElements();
832 
833   QuantizedToFloatStruct<T1> input_q2f(input_min, input_max);
834   QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min,
835                                                smaller_input_max);
836   FloatToQuantizedStruct<T3> f2q(*output_min, *output_max);
837 
838   auto smaller_input_float =
839       DEQUANTIZE_WITH_EIGEN(smaller_input_flat, smaller_input_q2f);
840   auto smaller_input_in_total_space =
841       QUANTIZE_WITH_EIGEN(smaller_input_float, f2q, T3);
842 
843   auto input_float = DEQUANTIZE_WITH_EIGEN(input_flat, input_q2f);
844   auto input_in_total_space = QUANTIZE_WITH_EIGEN(input_float, f2q, T3);
845 
846   Eigen::array<Eigen::DenseIndex, 1> bcast;
847   bcast[0] = input_element_count / smaller_input_element_count;
848   output_flat.device(device) =
849       input_in_total_space +
850       (smaller_input_in_total_space.broadcast(bcast) + zero_in_total_space);
851 }
852 
853 // This is a reference implementation of the bias addition for quantized
854 // buffers, designed to provide a clear specification for the result we
855 // want. We'll want to specialize this for particular hardware, and
856 // probably even fuse it with matrix multiplications in a lot of cases. It's
857 // important to show the clamping behavior we want in particular.
858 template <typename T1, typename T2, typename T3>
QuantizedAdd(const Eigen::ThreadPoolDevice & device,const Tensor & input,float input_min,float input_max,const Tensor & smaller_input,float smaller_input_min,float smaller_input_max,Tensor * output,float * output_min,float * output_max)859 void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input,
860                   float input_min, float input_max, const Tensor& smaller_input,
861                   float smaller_input_min, float smaller_input_max,
862                   Tensor* output, float* output_min, float* output_max) {
863   const auto& input_flat = input.flat<T1>();
864   const auto& smaller_input_flat = smaller_input.flat<T2>();
865   auto output_flat = output->flat<T3>();
866 
867   GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
868                                     smaller_input_max, output_min, output_max);
869   // To do addition properly, we need to compensate for a possibly unbalanced
870   // zero point in the total representation. The quantized value that
871   // represents the real number zero needs to be subtracted before addition to
872   // make sure that the identity of zero + zero = zero holds.
873   const T3 zero_in_total_space =
874       FloatToQuantized<T3>(0.0f, *output_min, *output_max);
875 
876   const int64_t input_element_count = input.NumElements();
877   const int64_t smaller_input_element_count = smaller_input.NumElements();
878 
879   float total_min = *output_min;
880   float total_max = *output_max;
881   const size_t how_many_iterations =
882       (input_element_count / smaller_input_element_count);
883   for (size_t iteration = 0; iteration < how_many_iterations; ++iteration) {
884     const size_t offset = iteration * smaller_input_element_count;
885     for (int c = 0; c < smaller_input_element_count; ++c) {
886       const int index = (offset + c);
887       // The two numbers we're going to add can each be in very different
888       // ranges (e.g. the quantized value '127' may represent very different
889       // real numbers in both) so we need to convert them to a common range
890       // before we sum them.
891       const T1 input_value = input_flat(index);
892       const T3 input_in_total_space = RequantizeInNewRange<T1, T3>(
893           input_value, input_min, input_max, total_min, total_max);
894       const T2 smaller_input_value = smaller_input_flat(c);
895       const T3 smaller_input_in_total_space =
896           RequantizeInNewRange<T2, T3>(smaller_input_value, smaller_input_min,
897                                        smaller_input_max, total_min, total_max);
898       const T3 total_pre = input_in_total_space + smaller_input_in_total_space;
899       // As noted above, we need to compensate for the offset of the actual
900       // zero point in the space we're operating in.
901       const T3 total = total_pre + zero_in_total_space;
902       output_flat(index) = total;
903     }
904   }
905 }
906 
907 // See gemmlowp/internal/multi_thread_gemm.h for the semantics of Execute.
908 class TensorflowGemmlowpWorkersPool {
909  public:
TensorflowGemmlowpWorkersPool(thread::ThreadPool * workers)910   TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
911       : workers_(workers) {}
912 
~TensorflowGemmlowpWorkersPool()913   ~TensorflowGemmlowpWorkersPool() {
914     // This workaround ensures that all worker tasks have exited methods in the
915     // BlockingCounter. Without this, there is a race where the context is torn
916     // down while the counter is in use.
917     counter_to_decrement_when_ready_.Reset(0);
918   }
919 
Execute(const std::vector<gemmlowp::Task * > & tasks)920   void Execute(const std::vector<gemmlowp::Task*>& tasks) {
921     assert(!tasks.empty());
922     assert(workers_ != nullptr);
923     counter_to_decrement_when_ready_.Reset(tasks.size());
924     for (gemmlowp::Task* task : tasks) {
925       workers_->Schedule([this, task]() {
926         // TODO(cwhipkey): get a local_allocator from a thread local storage.
927         gemmlowp::Allocator local_allocator;
928         CHECK(task != nullptr);
929         task->local_allocator = &local_allocator;
930         task->Run();
931         counter_to_decrement_when_ready_.DecrementCount();
932       });
933     }
934     counter_to_decrement_when_ready_.Wait();
935     for (gemmlowp::Task* task : tasks) {
936       delete task;
937     }
938   }
939 
940  private:
941   thread::ThreadPool* const workers_;
942 
943   // The BlockingCounter used to wait for the workers.
944   gemmlowp::BlockingCounter counter_to_decrement_when_ready_;
945 
946   TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool);
947 };
948 
949 class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
950  public:
TensorflowGemmContext(int num_threads,thread::ThreadPool * workers)951   TensorflowGemmContext(int num_threads, thread::ThreadPool* workers)
952       : workers_pool_(workers) {
953     set_max_num_threads(num_threads);
954   }
955 
workers_pool()956   TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; }
957 
958  private:
959   TensorflowGemmlowpWorkersPool workers_pool_;
960 
961   TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext);
962 };
963 
964 }  // namespace tensorflow
965 
966 #endif  // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
967