/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Make this file empty (or nearly empty) so that it can be compiled even when
// libxsmm is not available.

#ifndef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#else

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/xsmm_conv2d.h"

#include <stdlib.h>
#include <cstring>
#if defined(_OPENMP) && defined(LIBXSMM_USE_OPENMP)
#include <omp.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "include/libxsmm_cpuid.h"
#include "include/libxsmm_malloc.h"
#include "src/libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/

#define CHECK_LIBXSMM(CONDITION_OK, MESSAGE) \
  if (!(CONDITION_OK)) VLOG(0) << (MESSAGE)
#define CHECK_LIBXSMM_DNN(STATUS, MESSAGE)                \
  CHECK_LIBXSMM(LIBXSMM_DNN_SUCCESS == (STATUS), MESSAGE) \
      << " failed: " << libxsmm_dnn_get_error(STATUS);

namespace tensorflow {

// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.

// Returns true if the convolution can be computed efficiently by XsmmConv2D;
// returns false otherwise.
bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
                      TensorFormat data_format) {
  int VECTOR_SIZE;
  int arch = libxsmm_cpuid_x86();

  if (arch == LIBXSMM_X86_AVX512_CORE) {
    VECTOR_SIZE = 16;
  } else if (arch == LIBXSMM_X86_AVX2) {
    VECTOR_SIZE = 8;
  } else {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
    return false;
  }

  if (data_format != FORMAT_NHWC) {
    VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
    return false;
  }
  if (desc.K % VECTOR_SIZE != 0) {
    VLOG(1) << "Cannot use XSMM convolutions: output feature count not"
               " divisible by vector size!";
    return false;
  }
  VLOG(2) << "Can use XSMM convolutions.";
  return true;
}
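
// Not part of the kernel: an illustrative, compiled-out sketch of the
// eligibility rule above. CanUseXsmmConv2D reduces to "pick the fp32 SIMD
// width for the detected ISA (16 lanes for AVX-512, 8 for AVX2), then require
// the output feature count K to be a multiple of it". All names below are
// hypothetical and the mock does not call into libxsmm.
#if 0
#include <cstdio>

enum MockArch { MOCK_AVX2, MOCK_AVX512, MOCK_OTHER };

static bool MockCanUseXsmm(MockArch arch, int K) {
  int vector_size;
  if (arch == MOCK_AVX512) {
    vector_size = 16;  // 16 fp32 lanes per 512-bit register
  } else if (arch == MOCK_AVX2) {
    vector_size = 8;   // 8 fp32 lanes per 256-bit register
  } else {
    return false;  // unsupported ISA
  }
  return K % vector_size == 0;
}

int main() {
  std::printf("%d\n", MockCanUseXsmm(MOCK_AVX512, 64));  // 1: 64 % 16 == 0
  std::printf("%d\n", MockCanUseXsmm(MOCK_AVX2, 20));    // 0: 20 % 8 != 0
}
#endif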

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

// Repacks a filter from TensorFlow's RSCK layout (R x S x C x K) into
// libxsmm's blocked custom layout, zero-filling the tails of partial blocks.
// Only the output-feature blocks in [start, end) are written, so the copy can
// be sharded across threads.
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
    }
  }
}
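
// Not part of the kernel: a compiled-out sketch of the index mapping that
// copy_RSCK_to_custom performs, written with plain std::vector so it can be
// read (and run) in isolation. It repacks an R x S x C x K filter into the
// blocked layout [blocksofm][blocksifm][R][S][ifmblock][ofmblock]; entries
// past C or K in a partial block stay zero. RepackRSCK is a hypothetical name.
#if 0
#include <algorithm>
#include <vector>

std::vector<float> RepackRSCK(const std::vector<float>& rsck, int R, int S,
                              int C, int K, int ifmblock, int ofmblock) {
  const int blocksifm = (C + ifmblock - 1) / ifmblock;  // ceil(C / ifmblock)
  const int blocksofm = (K + ofmblock - 1) / ofmblock;  // ceil(K / ofmblock)
  // Value-initialized to zero, so the tails of partial blocks need no
  // explicit zero-fill pass (unlike the in-place copy above).
  std::vector<float> blocked(
      size_t(blocksofm) * blocksifm * R * S * ifmblock * ofmblock, 0.0f);
  for (int k = 0; k < blocksofm; ++k)
    for (int c = 0; c < blocksifm; ++c)
      for (int r = 0; r < R; ++r)
        for (int s = 0; s < S; ++s)
          for (int v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock);
               ++v1)
            for (int v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock);
                 ++v2) {
              // blocked[k][c][r][s][v1 - c*ifmblock][v2 - k*ofmblock]
              const size_t dst =
                  ((((size_t(k) * blocksifm + c) * R + r) * S + s) * ifmblock +
                   (v1 - c * ifmblock)) * ofmblock + (v2 - k * ofmblock);
              // rsck[r][s][v1][v2]
              const size_t src = ((size_t(r) * S + s) * C + v1) * K + v2;
              blocked[dst] = rsck[src];
            }
  return blocked;
}
#endif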

struct libxsmm_dnn_registry_key {
  const libxsmm_dnn_conv_desc descriptor;
  libxsmm_dnn_registry_key(const libxsmm_dnn_conv_desc& desc_)
      : descriptor(desc_) {}
  bool operator==(const libxsmm_dnn_registry_key& regkey) const {
    return 0 == memcmp(&descriptor, &regkey.descriptor, sizeof(descriptor));
  }
};

struct HashFunction {
  std::size_t operator()(const libxsmm_dnn_registry_key& regkey) const {
    return libxsmm_hash(&regkey.descriptor, sizeof(regkey.descriptor),
                        25071975);
  }
};
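
// Not part of the kernel: a compiled-out sketch of the registry-key pattern
// above, using only the standard library. The key is treated as plain old
// data: equality is memcmp over the object representation, and the hash runs
// over the same bytes (the kernel uses libxsmm_hash with the arbitrary fixed
// seed 25071975). This is only sound when the struct has no uninitialized
// padding bytes, as with the all-int PodKey below; all names are hypothetical.
#if 0
#include <cstddef>
#include <cstring>
#include <functional>
#include <string_view>
#include <unordered_map>

struct PodKey {
  int n, c, k, r, s;  // densely packed ints: no padding to worry about
};

struct PodKeyEq {
  bool operator()(const PodKey& a, const PodKey& b) const {
    return 0 == std::memcmp(&a, &b, sizeof(PodKey));
  }
};

struct PodKeyHash {
  std::size_t operator()(const PodKey& key) const {
    // Hash the raw bytes, matching the memcmp-based equality.
    return std::hash<std::string_view>()(
        std::string_view(reinterpret_cast<const char*>(&key), sizeof(PodKey)));
  }
};

using MockRegistry = std::unordered_map<PodKey, int, PodKeyHash, PodKeyEq>;
#endif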

struct libxsmm_dnn_registry_value {
  libxsmm_dnn_tensor_datalayout* layout_input;
  libxsmm_dnn_tensor_datalayout* layout_filter;
  libxsmm_dnn_tensor_datalayout* layout_output;
  libxsmm_dnn_layer* handle;
};

typedef libxsmm_tf_allocator<libxsmm_scratch_allocator>
    libxsmm_tf_scratch_allocator;

static class libxsmm_dnn_registry_type {
 private:
  typedef std::unordered_map<libxsmm_dnn_registry_key,
                             libxsmm_dnn_registry_value, HashFunction>
      container_type;

 public:
  libxsmm_dnn_registry_type() {
    libxsmm_init(); /* must be first */
#if !defined(LIBXSMM_LOCAL_ALLOC)
    {
      libxsmm_malloc_function malloc_fn;
      libxsmm_free_function free_fn;
      malloc_fn.function = libxsmm_tf_scratch_allocator::malloc;
      free_fn.function = libxsmm_tf_scratch_allocator::free;
      libxsmm_set_scratch_allocator(0 /*context*/, malloc_fn, free_fn);
    }
#endif
    LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_LOCK_RWLOCK, &attr);
    LIBXSMM_LOCK_INIT(LIBXSMM_LOCK_RWLOCK, &lock, &attr);
  }
  ~libxsmm_dnn_registry_type() {
    LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
    const container_type::const_iterator end = container.end();
    for (container_type::const_iterator i = container.begin(); i != end; ++i) {
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_input),
          "destroy input layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_output),
          "destroy output layout");
      CHECK_LIBXSMM_DNN(
          libxsmm_dnn_destroy_tensor_datalayout(i->second.layout_filter),
          "destroy filter layout");
      CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_conv_layer(i->second.handle),
                        "destroy handle");
    }
    LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_DESTROY(LIBXSMM_LOCK_RWLOCK, &lock);
    LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_LOCK_RWLOCK, &attr);
    libxsmm_finalize();
  }
  libxsmm_dnn_registry_value find(const libxsmm_dnn_registry_key& regkey) {
    container_type::iterator i;
    LIBXSMM_LOCK_ACQREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    i = container.find(regkey);
    LIBXSMM_LOCK_RELREAD(LIBXSMM_LOCK_RWLOCK, &lock);
    if (i == container.end()) {
      libxsmm_dnn_err_t status;
      libxsmm_dnn_registry_value regentry;

      LIBXSMM_LOCK_ACQUIRE(LIBXSMM_LOCK_RWLOCK, &lock);
      i = container.find(regkey);
      if (i == container.end()) {  // re-check after lock acquisition
        regentry.handle =
            libxsmm_dnn_create_conv_layer(regkey.descriptor, &status);
        if (LIBXSMM_DNN_WARN_FALLBACK != status) {
          CHECK_LIBXSMM_DNN(status, "create handle");
        } else {  // warning
          VLOG(1) << libxsmm_dnn_get_error(status);
        }
        regentry.layout_input = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_INPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create input layout");

        regentry.layout_output = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_OUTPUT, &status);
        CHECK_LIBXSMM_DNN(status, "create output layout");

        regentry.layout_filter = libxsmm_dnn_create_tensor_datalayout(
            regentry.handle, LIBXSMM_DNN_FILTER, &status);
        CHECK_LIBXSMM_DNN(status, "create filter layout");

        i = container.insert(std::make_pair(regkey, regentry)).first;
      }
      LIBXSMM_LOCK_RELEASE(LIBXSMM_LOCK_RWLOCK, &lock);
    }
    return i->second;
  }

 private:
  container_type container;
  LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_LOCK_RWLOCK) attr;
  LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK_RWLOCK) lock;
} libxsmm_dnn_registry;
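
// Not part of the kernel: a compiled-out sketch of the double-checked
// find-or-create pattern used by find() above, with std::shared_mutex in
// place of the LIBXSMM_LOCK_RWLOCK macros. The read lock serves the common
// hit; on a miss, the map is re-checked under the write lock so that two
// racing threads cannot both create an entry. FindOrCreate is hypothetical.
#if 0
#include <mutex>
#include <shared_mutex>
#include <unordered_map>

template <typename K, typename V, typename Make>
V FindOrCreate(std::unordered_map<K, V>& map, std::shared_mutex& mu,
               const K& key, Make make) {
  {
    std::shared_lock<std::shared_mutex> read_lock(mu);  // fast path
    auto it = map.find(key);
    if (it != map.end()) return it->second;
  }
  std::unique_lock<std::shared_mutex> write_lock(mu);
  auto it = map.find(key);  // re-check after lock acquisition
  if (it == map.end()) {
    it = map.emplace(key, make()).first;  // only one thread creates the value
  }
  return it->second;
}
#endif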

// #define LIBXSMM_DETAILED_TIMING

template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
  // TODO(penporn): Fix calls to deprecated LIBXSMM API or delete this kernel.
  // Fall back to non-libxsmm code for now.
  return false;
  /*
  #if defined(LIBXSMM_DETAILED_TIMING)
    libxsmm_timer_tickint l_tick1;
    libxsmm_timer_tickint l_tick2;
    libxsmm_timer_tickint l_tick3;
    libxsmm_timer_tickint l_tick4;
    libxsmm_timer_tickint l_tick5;
    libxsmm_timer_tickint l_tick6;
    libxsmm_timer_tickint l_tick7;
    libxsmm_timer_tickint l_tick8;
    libxsmm_timer_tickint l_tick9;
    libxsmm_timer_tickint l_tick10;
    l_tick1 = libxsmm_timer_tick();
  #endif
  #if defined(LIBXSMM_LOCAL_ALLOC)
    // setup scoped allocator, which adopts the allocator of the current context
    const libxsmm_tf_scratch_allocator tf_allocator(*ctx);
  #endif
    const libxsmm_dnn_registry_key regkey(desc);
    const libxsmm_dnn_registry_value regentry =
        libxsmm_dnn_registry.find(regkey);
    libxsmm_dnn_tensor *libxsmm_input, *libxsmm_output, *libxsmm_filter;
    libxsmm_dnn_err_t status;

    status = libxsmm_dnn_get_codegen_success(regentry.handle, kind);
    if (status == LIBXSMM_DNN_WARN_FALLBACK) {
      return false;  // Use non-libxsmm code
    }
    CHECK_LIBXSMM_DNN(status, "code generation");

  #if defined(LIBXSMM_DETAILED_TIMING)
    l_tick2 = libxsmm_timer_tick();
  #endif

    const int ifmblock = regentry.handle->ifmblock;
    const int ofmblock = regentry.handle->ofmblock;

    const int blocksifm =
        (desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1);
    const int blocksofm =
        (desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1);

    const size_t filter_size =
        blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock;
    float* const native_filter = (float*)libxsmm_aligned_scratch(
        filter_size * sizeof(float), 2097152 /*alignment*//*);

  const DeviceBase::CpuWorkerThreads* const worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();
  const int num_threads = worker_threads->num_threads;

#if 1
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    if (blocksofm > num_threads) {
      const int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          const int start = work / num_threads * i;
          const int end = (start + work / num_threads) > work
                              ? work
                              : start + work / num_threads;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    } else {
      const int work = blocksofm;
      const int num_tasks = work;

      BlockingCounter count(num_tasks);
      for (int i = 0; i < num_tasks; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          const int start = i;
          const int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // weight update buffer must be in the right format
    // (LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR)
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
  }
#else
  memset(native_filter, 0, filter_size * sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_input =
      libxsmm_dnn_link_tensor(regentry.layout_input, input, &status);
  CHECK_LIBXSMM_DNN(status, "link input buffer with layout");

  // LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR
  libxsmm_output =
      libxsmm_dnn_link_tensor(regentry.layout_output, output, &status);
  CHECK_LIBXSMM_DNN(status, "link output buffer with layout");

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    // LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR
    libxsmm_filter =
        libxsmm_dnn_link_tensor(regentry.layout_filter, native_filter, &status);
    CHECK_LIBXSMM_DNN(status, "link filter with layout");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter to handle");
  }
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter forward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_REGULAR_OUTPUT),
                      "bind output forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_input), "zeroing input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_GRADIENT_INPUT),
                      "bind input backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_REGULAR_FILTER),
                      "bind filter backward");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(libxsmm_dnn_zero_tensor(libxsmm_filter),
                      "zeroing filter");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_input,
                                              LIBXSMM_DNN_REGULAR_INPUT),
                      "bind input weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_filter,
                                              LIBXSMM_DNN_GRADIENT_FILTER),
                      "bind filter weight update");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_tensor(regentry.handle, libxsmm_output,
                                              LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "bind output weight update");
  } else {
    assert(0 /*should not happen*//*);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick4 = libxsmm_timer_tick();
#endif

  const size_t scratch_size = libxsmm_dnn_get_scratch_size(
      regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status);
  CHECK_LIBXSMM_DNN(status, "get scratch size");
  void* const scratch =
      libxsmm_aligned_scratch(scratch_size, 2097152 /*alignment*//*);
  CHECK_LIBXSMM(0 != scratch, "scratch memory allocation");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_bind_scratch(
                        regentry.handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                    "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_dnn_transpose_filter(regentry.handle, LIBXSMM_DNN_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick6 = libxsmm_timer_tick();
#endif

#if !defined(_OPENMP) || !defined(LIBXSMM_USE_OPENMP)
  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      CHECK_LIBXSMM_DNN(libxsmm_dnn_execute_st(regentry.handle, kind, 0, i),
                        "worker");
      counter.DecrementCount();
    });
  }
  counter.Wait();
#else
#pragma omp parallel
  {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_execute_st(regentry.handle, kind, 0, omp_get_thread_num()),
        "worker");
  }
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick7 = libxsmm_timer_tick();
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    libxsmm_dnn_reduce_wu_filters(regentry.handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick8 = libxsmm_timer_tick();
#endif

  /* clean up */ /*
  CHECK_LIBXSMM_DNN(libxsmm_dnn_release_scratch(regentry.handle,
                                                LIBXSMM_DNN_COMPUTE_KIND_ALL),
                    "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    CHECK_LIBXSMM_DNN(
        libxsmm_dnn_release_tensor(regentry.handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_OUTPUT),
                      "release output");
    CHECK_LIBXSMM_DNN(libxsmm_dnn_release_tensor(regentry.handle,
                                                 LIBXSMM_DNN_GRADIENT_FILTER),
                      "release filter");
  } else {
    /* shouldn't happen */ /*
  }
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_input), "destroy input");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_output),
                    "destroy output");
  CHECK_LIBXSMM_DNN(libxsmm_dnn_destroy_tensor(libxsmm_filter),
                    "destroy filter");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  libxsmm_free(native_filter);
  libxsmm_free(scratch);

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
  */
}
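
// Not part of the kernel: a compiled-out sketch of the fan-out/join scheme
// the (disabled) body above uses for the filter repack and for
// libxsmm_dnn_execute_st: partition `work` items into per-thread [start, end)
// ranges, schedule one closure per thread, and block until all finish.
// Here std::thread::join stands in for the thread-pool Schedule +
// BlockingCounter pair, and the chunk size is rounded up so the tail items
// are always covered. ParallelFor is a hypothetical name.
#if 0
#include <algorithm>
#include <thread>
#include <vector>

template <typename Fn>
void ParallelFor(int work, int num_threads, Fn fn) {  // fn(start, end)
  const int chunk = (work + num_threads - 1) / num_threads;  // ceil division
  std::vector<std::thread> workers;
  for (int i = 0; i < num_threads; ++i) {
    const int start = std::min(work, i * chunk);
    const int end = std::min(work, start + chunk);
    if (start < end) workers.emplace_back([=] { fn(start, end); });
  }
  for (std::thread& t : workers) t.join();  // the BlockingCounter analogue
}
#endif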

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};
#endif

#ifdef TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};

template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};
#endif

}  // namespace functor

template struct functor::XsmmFwdConv2D<CPUDevice, float>;
template struct functor::XsmmBkwInputConv2D<CPUDevice, float>;
template struct functor::XsmmBkwFilterConv2D<CPUDevice, float>;

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS