/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Make this file empty (or nearly empty) so that it can be compiled even when
// libxsmm is not available.

#ifndef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#else

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/xsmm_conv2d.h"

#include <stdlib.h>
#include <cstring>
#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
#include <omp.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "include/libxsmm_cpuid.h"
#include "include/libxsmm_malloc.h"
#include "src/libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/

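// Error-handling helpers: on failure these macros only log through VLOG
// (CHECK_LIBXSMM_DNN additionally appends libxsmm's own error string), so a
// failing libxsmm call degrades to a log message rather than aborting the op.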
#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) \
  if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE) \
  CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
  << " failed: " << libxsmm_dnn_get_error(STATUS);

namespace tensorflow {

// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.

// Returns true if convolution can be computed efficiently by XsmmConv2D,
// returns false otherwise.
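// For example, on an AVX-512 CORE machine (vector size 16) an NHWC
// convolution with K = 64 output feature maps qualifies (64 % 16 == 0),
// whereas K = 20 does not and falls back to the default convolution path.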
bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
                      TensorFormat data_format) {
  int VECTOR_SIZE;
  int arch = libxsmm_cpuid_x86();

  if (arch == LIBXSMM_X86_AVX512_CORE) {
    VECTOR_SIZE = 16;
  } else if (arch == LIBXSMM_X86_AVX2) {
    VECTOR_SIZE = 8;
  } else {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
    return false;
  }

  if (data_format != FORMAT_NHWC) {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
    return false;
  }
  if (desc.K % VECTOR_SIZE != 0) {
    VLOG(1) << "Cannot use XSMM convolutions: output features count not"
               " divisible by vector size!";
    return false;
  }
  VLOG(2) << "Can use XSMM convolutions.";
  return true;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

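// Repacks a filter from TensorFlow's RSCK layout (filter rows, filter
// columns, input channels, output channels) into libxsmm's blocked custom
// layout [blocksofm][blocksifm][R][S][ifmblock][ofmblock], zero-padding any
// partial block when C or K is not a multiple of the block size.  The
// [start, end) range selects the output-feature blocks handled by this call,
// which lets the copy be sharded across worker threads.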
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
    }
  }
}

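// The registry below caches one libxsmm convolution handle, together with its
// input/output/filter tensor layouts, per distinct convolution descriptor.
// Registry keys compare and hash the raw descriptor bytes.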
struct libxsmm_dnn_registry_key {
  const libxsmm_dnn_conv_desc descriptor;
  libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_)
      : descriptor(desc_) {}
  bool operator==(const libxsmm_dnn_registry_key& regkey) const {
    return 0 == memcmp(&descriptor, &regkey.descriptor, sizeof(descriptor));
  }
};

struct HashFunction {
  std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
    return libxsmm_hash(&regkey.descriptor, sizeof(regkey.descriptor),
                        25071975);
  }
};

struct libxsmm_dnn_registry_value {
  libxsmm_dnn_tensor_datalayout* layout_input;
  libxsmm_dnn_tensor_datalayout* layout_filter;
  libxsmm_dnn_tensor_datalayout* layout_output;
  libxsmm_dnn_layer* handle;
};

typedef libxsmm_tf_allocator<libxsmm_scratch_allocator>
    libxsmm_tf_scratch_allocator;

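// Process-wide registry singleton.  The constructor initializes libxsmm and,
// unless LIBXSMM_LOCAL_ALLOC is defined, installs TensorFlow's scratch
// allocator as libxsmm's scratch allocator; the destructor releases all cached
// handles and layouts and finalizes libxsmm.  find() performs a double-checked
// lookup under a reader/writer lock: a read lock for the common hit path and a
// write lock only when a new handle must be created and inserted.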
static class libxsmm_dnn_registry_type {
 private:
  typedef std::unordered_map<libxsmm_dnn_registry_key,
                             libxsmm_dnn_registry_value, HashFunction>
      container_type;

 public:
  libxsmm_dnn_registry_type() {
    libxsmm_init(); /* must be first */
#if !defined(LIBXSMM_LOCAL_ALLOC)
    {
      libxsmm_malloc_function malloc_fn;
      libxsmm_free_function free_fn;
      malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
      free_fn.function = libxsmm_tf_scratch_allocator::free;
      libxsmm_set_scratch_allocator(0 /*context*/, malloc_fn, free_fn);
    }
#endif
    LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
    LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
  }
  ~libxsmm_dnn_registry_type() {
    LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
    const container_type::const_iterator end = container.end();
    for (container_type::const_iterator i = container.begin(); i != end; ++i) {
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
          "destroy input layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
          "destroy output layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
          "destroy filter layout");
      CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_conv_layer(i->second.handle),
                        "destroy handle");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
    libxsmm_finalize();
  }
  libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
    container_type::iterator i;
    LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    i = container.find(regkey);
    LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    if (i == container.end()) {
      libxsmm_dnn_err_t status;
      libxsmm_dnn_registry_value regentry;

      LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
      i = container.find(regkey);
      if (i == container.end()) {  // re-check after lock acquisition
        regentry.handle =
            libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
        if (LIBXSMM_DNN_WARN_FALLBACK != status) {
          CHECK_LIBXSMM_DNN(status, "create handle");
        } else {  // warning
          VLOG(1) << libxsmm_dnn_get_error(status);
        }
        regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_INPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create input layout");

        regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create output layout");

        regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_FILTER, &status);
        CHECK_LIBXSMM_DNN(status, "create filter layout");

        i = container.insert(std::make_pair(regkey, regentry)).first;
      }
      LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    }
    return i->second;
  }

 private:
  container_type container;
  LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
  LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
} libxsmm_dnn_registry;

// #define LIBXSMM_DETAILED_TIMING

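// Generic driver shared by the forward (FWD), backward-input (BWD), and
// weight-update (UPD) functors below.  It currently always returns false so
// callers fall back to the non-libxsmm convolution paths; the commented-out
// body preserves the original flow for reference: look up or create the
// handle in the registry, repack the filter into libxsmm's layout, link and
// bind the input/output/filter tensors and scratch memory, run
// libxsmm_dnn_execute_st() across the worker threads (or an OpenMP parallel
// region), and finally release the bindings and free the scratch buffers.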
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
  // TODO(penporn): Fix calls to deprecated LIBXSMM API or delete this kernel.
  // Fall back to non-libxsmm code for now.
  return false;
  /*
#if defined(LIBXSMM_DETAILED_TIMING)
  libxsmm_timer_tickint l_tick1;
  libxsmm_timer_tickint l_tick2;
  libxsmm_timer_tickint l_tick3;
  libxsmm_timer_tickint l_tick4;
  libxsmm_timer_tickint l_tick5;
  libxsmm_timer_tickint l_tick6;
  libxsmm_timer_tickint l_tick7;
  libxsmm_timer_tickint l_tick8;
  libxsmm_timer_tickint l_tick9;
  libxsmm_timer_tickint l_tick10;
  l_tick1 = libxsmm_timer_tick();
#endif
#if defined(LIBXSMM_LOCAL_ALLOC)
  // setup scoped allocator, which adopts the allocator of the current context
  const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
#endif
  const libxsmm_dnn_registry_key regkey(desc);
  const libxsmm_dnn_registry_value regentry =
      libxsmm_dnn_registry.find(regkey);
  libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
  libxsmm_dnn_err_t status;

  status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
  if (status == LIBXSMM_DNN_WARN_FALLBACK) {
    return false;  // Use non-libxsmm code
  }
  CHECK_LIBXSMM_DNN(status, "code generation");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick2 = libxsmm_timer_tick();
#endif

  const int ifmblock = regentry.handle->ifmblock;
  const int ofmblock = regentry.handle->ofmblock;

  const int blocksifm =
      (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
  const int blocksofm =
      (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);

  const size_t filter_size =
      blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
  float* const native_filter = (float*)libxsmm_aligned_scratch(
      filter_size * sizeof(float), 2097152 /*alignment*//*);

  const DeviceBase::CpuWorkerThreads* const worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();
  const int num_threads = worker_threads->num_threads;

#if 1
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    if (blocksofm > num_threads) {
      const int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          const int start = work / num_threads * i;
          const int end = (start + work / num_threads) > work
                              ? work
                              : start + work / num_threads;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    } else {
      const int work = blocksofm;
      const int num_tasks = work;

      BlockingCounter count(num_tasks);
      for (int i = 0; i < num_tasks; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          const int start = i;
          const int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // weight update buffer must be in the right format
    // (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
  }
#else
  memset(native_filter, 0, filter_size * sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_input =
      libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
  CHECK_LIBXSMM_DNN(status, "link input buffer with layout");

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_output =
      libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
  CHECK_LIBXSMM_DNN(status, "link output buffer with layout");

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter to handle");
  }
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_REGULAR_OUTPUT),
                      "bind output forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_GRADIENT_INPUT),
                      "bind input backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter),
                      "zeroing filter");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_GRADIENT_FILTER),
                      "bind filter weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output weight update");
  } else {
    assert(0 /*should not happen*//*);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick4 = libxsmm_timer_tick();
#endif

  const size_t scratch_size = libxsmm_dnn_get_scratch_size(
      regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
  CHECK_LIBXSMM_DNN(status, "get scratch size");
  void* const scratch =
      libxsmm_aligned_scratch(scratch_size, 2097152 /*alignment*//*);
  CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_scratch(
                        regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                    "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick6 = libxsmm_timer_tick();
#endif

#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      CHECK_LIBXSMM_DNN(libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
                        "worker");
      counter.DecrementCount();
    });
  }
  counter.Wait();
#else
#pragma omp parallel
  {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
        "worker");
  }
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick7 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick8 = libxsmm_timer_tick();
#endif

  /* clean up */ /*
  CHECK_LIBXSMM_DNN(libxsmm_dnn_release_scratch(regentry.handle,
                                                LIBXSMM_DNN_COMPUTE_KIND_ALL),
                    "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_FILTER),
                      "release filter");
  } else {
    /* shouldn't happen */ /*
  }
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output),
                    "destroy output");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter),
                    "destroy filter");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  libxsmm_free(native_filter);
  libxsmm_free(scratch);

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
  */
}

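// Functor specializations that route TensorFlow's CPU convolution kernels to
// the generic driver above: FWD for the forward pass, BWD for the gradient
// w.r.t. the input, and UPD for the gradient w.r.t. the filter.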
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};
#endif

#ifdef TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};

template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};
#endif

}  // namespace functor

template struct functor::XsmmFwdConv2D<CPUDevice, float>;
template struct functor::XsmmBkwInputConv2D<CPUDevice, float>;
template struct functor::XsmmBkwFilterConv2D<CPUDevice, float>;

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS