/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/meta_support.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

#if defined(GEMMLOWP_NEON_32) && !defined(TENSORFLOW_DISABLE_META) && \
    !defined(__APPLE__)
#define TENSORFLOW_USE_META (1)
#endif

namespace tensorflow {
namespace meta {

namespace {

int g_num_threads = 0;
bool g_enabled = true;
bool g_use_local_context = false;

#ifdef TENSORFLOW_USE_META

const int kAlignment = 32;
const int kScratchSize = 2048 * 1024 + kAlignment;

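// Scratch memory for the gemmlowp meta kernels, stored in the per-device
// ResourceManager so it can be reused across op invocations. The object owns
// a kScratchSize-byte heap buffer and hands out a pointer rounded up to the
// next kAlignment (32-byte) boundary.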
class Scratch : public ResourceBase {
 public:
  Scratch() : scratch_(new uint8_t[kScratchSize]) {
    // Make sure scratch is aligned to 32 bytes. Scratch object owns the
    // scratch buffer.
    scratch_32_aligned_ =
        scratch_.get() + kAlignment -
        (reinterpret_cast<uintptr_t>(scratch_.get()) % kAlignment);
  }

  uint8_t* buffer() { return scratch_32_aligned_; }

  string DebugString() const override { return "MetaGemmScratchResource"; }

 private:
  // Array form of unique_ptr so that the buffer allocated with new[] is
  // released with delete[].
  std::unique_ptr<uint8_t[]> scratch_;
  uint8_t* scratch_32_aligned_;
};

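// Returns the 32-byte aligned scratch buffer shared through the device's
// ResourceManager (container "MetaGemm", name "ScratchBuffer"), creating it
// on first use. Returns nullptr and flags the context on failure.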
uint8_t* GetScratch(OpKernelContext* context) {
  Scratch* scratch = nullptr;
  std::function<Status(Scratch**)> creator = [](Scratch** resource) {
    *resource = new Scratch();
    return Status::OK();
  };
  Status s = context->resource_manager()->LookupOrCreate(
      "MetaGemm", "ScratchBuffer", &scratch, creator);
  if (!s.ok()) {
    context->CtxFailureWithWarning(s);
    return nullptr;
  }
  return scratch->buffer();
}

gemmlowp::WorkersPool* GetWorkersPool() {
  static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
  return pool;
}

mutex& GetMutex() {
  static mutex mu(LINKER_INITIALIZED);
  return mu;
}

int GetWorkersCount(OpKernelContext* tf_context) {
  if (g_num_threads == 0) {
    return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
  }
  return g_num_threads;
}

typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;

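// Selects the gemmlowp meta GEMM executor and kernel tile parameters from the
// problem shape: very small m uses a (1, 8, 8) kernel with LHS-cache-friendly
// packing, otherwise a (2, 4, 8) kernel is used and the packing order is
// chosen by whether m or n dominates.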
template <typename Context, typename Params>
void MultiThreadGemm(Context* context, const Params& params) {
  if (params.m <= 4) {
    gemmlowp::meta::MultiThreadGemm<
        Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1,
        8, 8>(context, params);
  } else {
    if (params.m >= params.n) {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    } else {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    }
  }
}

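// Fills in the gemmlowp meta GemmParams for a uint8 x uint8 -> int32 GEMM and
// runs it on either the local gemmlowp worker pool or the TensorFlow worker
// threads. The {multiplicative,additive}_sum_offset fields fold the
// zero-point correction terms of the quantized product into the row/column
// sums computed by the packing streams; k * offset_a * offset_b is the
// constant cross term.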
template <typename LeftStream, typename RightStream>
void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
                       const quint8* b_data, qint32* c_data, int m, int n,
                       int k, int offset_a, int offset_b, int lda, int ldb,
                       int ldc) {
  typedef gemmlowp::meta::GemmParams<
      uint8_t, int32_t, LeftStream, RightStream,
      gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
      gemmlowp::meta::RowMajor>
      Params;
  Params params;

  params.m = m;
  params.n = n;
  params.k = k;

  params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
  params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
  params.result = reinterpret_cast<int32_t*>(&(c_data->value));
  params.scratch = CHECK_NOTNULL(GetScratch(tf_context));

  params.left_stream.count = k;
  params.left_stream.stride = lda;
  params.left_stream.multiplicative_sum_offset = offset_b;
  params.left_stream.additive_sum_offset = k * offset_a * offset_b;

  params.right_stream.count = k;
  params.right_stream.stride = ldb;
  params.right_stream.multiplicative_sum_offset = offset_a;
  params.right_stream.additive_sum_offset = 0;

  params.fused_kernel.kernel.count = k;
  params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);

  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    MultiThreadGemm<LocalContext, Params>(&local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
  }
}

template <typename Params, int kernel_size>
void MultiThreadTransform1D(OpKernelContext* tf_context, const Params& params) {
  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
        &local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
                                           kernel_size>(&context, params);
  }
}

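// Maps a float range [min, max] to the quantization scale of QuantizedType
// over its full integer range, i.e. scale = (max - min) / (2^bits - 1). For
// example (illustrative), a uint8 range of [-1.0f, 1.0f] gives
// 2.0 / 255 ~= 0.00784. CalculateOneOverRangeScale returns the reciprocal,
// or 0 when the range is empty.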
template <typename QuantizedType>
double CalculateRangeScale(float min, float max) {
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>(max - min) /
         ((static_cast<int64_t>(1) << bits) - 1);
}

template <typename QuantizedType>
double CalculateOneOverRangeScale(float min, float max) {
  if (min == max) {
    return 0.0;
  }
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>((static_cast<int64_t>(1) << bits) - 1) /
         (max - min);
}

#endif  // TENSORFLOW_USE_META

}  // namespace

void SetNumThreads(int num_threads) { g_num_threads = num_threads; }

int GetNumThreads() { return g_num_threads; }

void SetUseLocalContext(bool use_local_context) {
  g_use_local_context = use_local_context;
}

bool GetUseLocalContext() { return g_use_local_context; }

bool IsSupported() {
#if defined(TENSORFLOW_USE_META)
  return true;
#else
  return false;
#endif
}

bool IsEnabled() { return g_enabled; }

void SetEnabled(bool enabled) { g_enabled = enabled; }

bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }

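// Multi-threaded uint8 x uint8 -> int32 GEMM entry point. The dispatch below
// picks row- or column-major packing streams (with running sums) depending on
// the transpose flags. Illustrative call-site sketch (names are placeholders,
// not from this file):
//
//   if (meta::IsSupportedAndEnabled()) {
//     meta::QuantizedGemm(op_context, /*transpose_a=*/false,
//                         /*transpose_b=*/false, a_data, b_data, c_data,
//                         m, n, k, offset_a, offset_b, lda, ldb, ldc);
//   } else {
//     // Fall back to the reference / gemmlowp path.
//   }
//
// See the quantized op kernels for the exact sign convention expected for
// offset_a and offset_b.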
void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
                   bool transpose_b, const quint8* a_data, const quint8* b_data,
                   qint32* c_data, int m, int n, int k, int offset_a,
                   int offset_b, int lda, int ldb, int ldc) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  if (transpose_a) {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  } else {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  }
#else
  LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
#endif
}

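// Converts int32 values quantized with range [input_min, input_max] to uint8
// values quantized with range [output_min, output_max] using the gemmlowp
// meta Requantize transform. The input/output range offsets are set to the
// lowest representable values of int32 and uint8, respectively.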
void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
                float input_min, float input_max, float output_min,
                float output_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
                                            gemmlowp::meta::Requantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const int32_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.input_range_min = input_min;
  params.kernel.output_range_min = output_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<int32_t>(input_min, input_max);
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
  params.kernel.input_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

#if defined(GEMMLOWP_NEON_32)
  // After adding the output_range_offset the value is cast from float to uint.
  // The float to int/uint cast on 32-bit ARM uses round toward 0. To keep the
  // rounding consistent with Eigen, which uses round toward closest, we can
  // add 0.5f and exploit the fact that we only operate on non-negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta.
  params.kernel.output_range_offset += 0.5f;
#endif

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Requantize: Meta fastpath not supported.";
#endif
}

void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
                float range_min, float range_max, float* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
                                            gemmlowp::meta::Dequantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<float*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
#endif
}

void Quantize(OpKernelContext* tf_context, const float* input, int count,
              float range_min, float range_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
                                            gemmlowp::meta::Quantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const float*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateOneOverRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

#if defined(GEMMLOWP_NEON_32)
  // The float to int/uint cast on 32-bit ARM uses round toward 0. To keep the
  // rounding consistent with Eigen, which uses round toward closest, we can
  // add 0.5f and exploit the fact that we only operate on non-negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta.
  params.kernel.range_offset += 0.5f;
#endif

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Quantize: Meta fastpath not supported.";
#endif
}

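// Adds a quantized bias vector of bias_count elements to each row of the
// quantized input (the input is treated as input_count / bias_count rows of
// bias_count elements), producing int32 output quantized to
// [output_min, output_max]. Currently runs single-threaded; see the TODO
// below.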
void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
                      int input_count, const quint8* bias, int bias_count,
                      float input_min, float input_max, float bias_min,
                      float bias_max, float output_min, float output_max,
                      qint32* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
                                            gemmlowp::meta::BiasAdd<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<int32_t*>(output);
  params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
  params.kernel.count = bias_count;
  params.kernel.rows = input_count / bias_count;
  params.kernel.input_range_min = input_min;
  params.kernel.bias_range_min = bias_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<uint8_t>(input_min, input_max);
  params.kernel.bias_range_scale =
      CalculateRangeScale<uint8_t>(bias_min, bias_max);
  params.kernel.input_range_offset = 0;
  params.kernel.bias_range_offset = 0;
  params.kernel.output_range_min = output_min;
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<int32_t>(output_min, output_max);
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // TODO(maciekc): add multithreading to bias add.
  // Right now this kernel does not support multi-threaded execution.
  gemmlowp::meta::Transform1D<Params, 16>(params);
#else
  LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
#endif
}

void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
           quint8 clamp_min, quint8 clamp_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
                                            gemmlowp::meta::MinMax<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.min = clamp_min;
  params.kernel.max = clamp_max;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Clamp: Meta fastpath not supported.";
#endif
}

}  // namespace meta
}  // namespace tensorflow