/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/meta_support.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

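// The gemmlowp meta fastpath kernels are only compiled in for 32-bit ARM NEON
// builds, can be switched off with TENSORFLOW_DISABLE_META, and are not built
// on Apple platforms.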
#if defined(GEMMLOWP_NEON_32) && !defined(TENSORFLOW_DISABLE_META) && \
    !defined(__APPLE__)
#define TENSORFLOW_USE_META (1)
#endif

namespace tensorflow {
namespace meta {

namespace {

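// Runtime configuration shared by all meta kernels. A g_num_threads value of
// 0 means "use the TensorFlow CPU device's worker thread count" (see
// GetWorkersCount below); g_use_local_context selects a gemmlowp-owned worker
// pool instead of TensorFlow's intra-op thread pool.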
int g_num_threads = 0;
bool g_enabled = true;
bool g_use_local_context = false;

#ifdef TENSORFLOW_USE_META

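// Scratch buffer used by the gemmlowp meta GEMM kernels: 2 MiB plus enough
// slack to align the usable region to a 32-byte boundary.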
const int kAlignment = 32;
const int kScratchSize = 2048 * 1024 + kAlignment;

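// Owns the scratch buffer and exposes a 32-byte-aligned view of it. Stored in
// the ResourceMgr (see GetScratch) so it is created once and reused across
// kernel invocations.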
class Scratch : public ResourceBase {
 public:
  Scratch() : scratch_(new uint8_t[kScratchSize]) {
    // Make sure scratch is aligned to 32 bytes. Scratch object owns the
    // scratch buffer.
    scratch_32_aligned_ =
        scratch_.get() + kAlignment -
        (reinterpret_cast<uintptr_t>(scratch_.get()) % kAlignment);
  }

  uint8_t* buffer() { return scratch_32_aligned_; }

  string DebugString() const override { return "MetaGemmScratchResource"; }

 private:
  // The array form of unique_ptr ensures the buffer allocated with new[] is
  // released with delete[].
  std::unique_ptr<uint8_t[]> scratch_;
  uint8_t* scratch_32_aligned_;
};

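// Looks up (or lazily creates) the shared Scratch resource in the context's
// ResourceMgr and returns its aligned buffer, or nullptr on failure after
// flagging the error on the context.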
uint8_t* GetScratch(OpKernelContext* context) {
  Scratch* scratch = nullptr;
  std::function<Status(Scratch**)> creator = [](Scratch** resource) {
    *resource = new Scratch();
    return Status::OK();
  };
  Status s = context->resource_manager()->LookupOrCreate(
      "MetaGemm", "ScratchBuffer", &scratch, creator);
  if (!s.ok()) {
    context->CtxFailureWithWarning(s);
    return nullptr;
  }
  return scratch->buffer();
}

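// Process-wide gemmlowp worker pool, used only when a local context is
// requested via SetUseLocalContext(true).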
gemmlowp::WorkersPool* GetWorkersPool() {
  static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
  return pool;
}

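// Serializes all of the public meta kernel entry points below.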
mutex& GetMutex() {
  static mutex mu(LINKER_INITIALIZED);
  return mu;
}

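// Returns the worker count: the value set via SetNumThreads(), or the
// TensorFlow CPU device's thread count when no override (0) is set.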
int GetWorkersCount(OpKernelContext* tf_context) {
  if (g_num_threads == 0) {
    return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
  }
  return g_num_threads;
}

typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;

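// Picks a gemmlowp meta GEMM executor based on the output shape: a 1x8x8
// kernel for very short LHS (m <= 4), otherwise 2x4x8 kernels that pack the
// RHS cache-friendly when m >= n and the LHS otherwise.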
template <typename Context, typename Params>
void MultiThreadGemm(Context* context, const Params& params) {
  if (params.m <= 4) {
    gemmlowp::meta::MultiThreadGemm<
        Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1,
        8, 8>(context, params);
  } else {
    if (params.m >= params.n) {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    } else {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    }
  }
}

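// Fills in gemmlowp meta GEMM parameters for uint8 inputs with int32
// accumulation. The quantization zero points offset_a and offset_b are folded
// into the streams' multiplicative/additive sum offsets rather than being
// subtracted per element, and the GEMM runs on either a local gemmlowp worker
// pool or TensorFlow's intra-op threads depending on g_use_local_context.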
template <typename LeftStream, typename RightStream>
void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
                       const quint8* b_data, qint32* c_data, int m, int n,
                       int k, int offset_a, int offset_b, int lda, int ldb,
                       int ldc) {
  typedef gemmlowp::meta::GemmParams<
      uint8_t, int32_t, LeftStream, RightStream,
      gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
      gemmlowp::meta::RowMajor>
      Params;
  Params params;

  params.m = m;
  params.n = n;
  params.k = k;

  params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
  params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
  params.result = reinterpret_cast<int32_t*>(&(c_data->value));
  params.scratch = CHECK_NOTNULL(GetScratch(tf_context));

  params.left_stream.count = k;
  params.left_stream.stride = lda;
  params.left_stream.multiplicative_sum_offset = offset_b;
  params.left_stream.additive_sum_offset = k * offset_a * offset_b;

  params.right_stream.count = k;
  params.right_stream.stride = ldb;
  params.right_stream.multiplicative_sum_offset = offset_a;
  params.right_stream.additive_sum_offset = 0;

  params.fused_kernel.kernel.count = k;
  params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);

  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    MultiThreadGemm<LocalContext, Params>(&local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
  }
}

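// Runs a gemmlowp meta 1-D transform on either the local worker pool or
// TensorFlow's intra-op threads, mirroring the GEMM dispatch above.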
template <typename Params, int kernel_size>
void MultiThreadTransform1D(OpKernelContext* tf_context, const Params& params) {
  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
        &local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
                                           kernel_size>(&context, params);
  }
}

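// Quantization range helpers: the scale is (max - min) / (2^bits - 1), i.e.
// the real-valued step represented by one quantized unit, and its inverse
// (with a guard for degenerate min == max ranges).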
template <typename QuantizedType>
double CalculateRangeScale(float min, float max) {
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>(max - min) /
         ((static_cast<int64_t>(1) << bits) - 1);
}

template <typename QuantizedType>
double CalculateOneOverRangeScale(float min, float max) {
  if (min == max) {
    return 0.0;
  }
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>((static_cast<int64_t>(1) << bits) - 1) /
         (max - min);
}

#endif  // TENSORFLOW_USE_META

}  // namespace

void SetNumThreads(int num_threads) { g_num_threads = num_threads; }

int GetNumThreads() { return g_num_threads; }

void SetUseLocalContext(bool use_local_context) {
  g_use_local_context = use_local_context;
}

bool GetUseLocalContext() { return g_use_local_context; }

bool IsSupported() {
#if defined(TENSORFLOW_USE_META)
  return true;
#else
  return false;
#endif
}

bool IsEnabled() { return g_enabled; }

void SetEnabled(bool enabled) { g_enabled = enabled; }

bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }

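// Public GEMM entry point: takes the library lock and instantiates
// QuantizedGemmImpl with the lhs/rhs stream layouts implied by the transpose
// flags.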
void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
                   bool transpose_b, const quint8* a_data, const quint8* b_data,
                   qint32* c_data, int m, int n, int k, int offset_a,
                   int offset_b, int lda, int ldb, int ldc) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  if (transpose_a) {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  } else {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  }
#else
  LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
#endif
}

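// Requantizes int32 values quantized for [input_min, input_max] into uint8
// values quantized for [output_min, output_max].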
void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
                float input_min, float input_max, float output_min,
                float output_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
                                            gemmlowp::meta::Requantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const int32_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.input_range_min = input_min;
  params.kernel.output_range_min = output_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<int32_t>(input_min, input_max);
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
  params.kernel.input_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

#if defined(GEMMLOWP_NEON_32)
  // After adding the output_range_offset the value is cast from float to uint.
  // The float to int/uint cast in 32bit arm uses round toward 0. To keep the
  // rounding consistent with Eigen, which uses round toward closest, we can
  // add 0.5f and exploit the fact that we only operate on non negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
  params.kernel.output_range_offset += 0.5f;
#endif

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Requantize: Meta fastpath not supported.";
#endif
}

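// Converts quantized uint8 values back to floats using the linear mapping
// defined by [range_min, range_max].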
void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
                float range_min, float range_max, float* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
                                            gemmlowp::meta::Dequantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<float*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
#endif
}

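// Quantizes floats into uint8 using the linear mapping defined by
// [range_min, range_max].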
void Quantize(OpKernelContext* tf_context, const float* input, int count,
              float range_min, float range_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
                                            gemmlowp::meta::Quantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const float*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateOneOverRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

#if defined(GEMMLOWP_NEON_32)
  // The float to int/uint cast on 32bit arm uses round toward 0. To keep the
  // rounding consistent with Eigen, which uses round toward closest, we can
  // add 0.5f and exploit the fact that we only operate on non negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
  params.kernel.range_offset += 0.5f;
#endif

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Quantize: Meta fastpath not supported.";
#endif
}

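// Adds a quantized bias vector to each row of a quantized input, producing
// int32 output quantized for [output_min, output_max]; input_count is
// expected to be a multiple of bias_count.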
void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
                      int input_count, const quint8* bias, int bias_count,
                      float input_min, float input_max, float bias_min,
                      float bias_max, float output_min, float output_max,
                      qint32* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
                                            gemmlowp::meta::BiasAdd<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<int32_t*>(output);
  params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
  params.kernel.count = bias_count;
  params.kernel.rows = input_count / bias_count;
  params.kernel.input_range_min = input_min;
  params.kernel.bias_range_min = bias_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<uint8_t>(input_min, input_max);
  params.kernel.bias_range_scale =
      CalculateRangeScale<uint8_t>(bias_min, bias_max);
  params.kernel.input_range_offset = 0;
  params.kernel.bias_range_offset = 0;
  params.kernel.output_range_min = output_min;
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<int32_t>(output_min, output_max);
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // TODO(maciekc): add multithreading to bias add.
  // Right now this kernel does not support multi threaded execution.
  gemmlowp::meta::Transform1D<Params, 16>(params);
#else
  LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
#endif
}

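// Clamps each quantized uint8 value to [clamp_min, clamp_max].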
void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
           quint8 clamp_min, quint8 clamp_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
                                            gemmlowp::meta::MinMax<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.min = clamp_min;
  params.kernel.max = clamp_max;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Clamp: Meta fastpath not supported.";
#endif
}

}  // namespace meta
}  // namespace tensorflow