/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <algorithm>
#include <array>
#include <limits>
#include <utility>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T, bool conjugate>
struct maybe_conj {
  __device__ static __inline__ T run(T x) {
    if (conjugate) {
      return Eigen::numext::conj(x);
    } else {
      return x;
    }
  }
};

// Partial specializations for Gpu types used to store complex numbers.
template <bool conjugate>
struct maybe_conj<float2, conjugate> {
  __device__ static __inline__ float2 run(float2 c) {
    if (conjugate) {
      float2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

template <bool conjugate>
struct maybe_conj<double2, conjugate> {
  __device__ static __inline__ double2 run(double2 c) {
    if (conjugate) {
      double2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};
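
// Example (added for clarity): maybe_conj<float2, true>::run({1.f, 2.f})
// yields {1.f, -2.f}. The float2 and double2 specializations exist because
// these GPU vector types carry complex64 and complex128 values, for which
// negating .y is complex conjugation.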

// TODO(mjanusz): Move this to a shared util file.
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount, T DefaultValue>
struct Array {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {
    data[0] = a0;
    for (int i = 1; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {
    data[0] = a0;
    data[1] = a1;
    for (int i = 2; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {
    data[0] = a0;
    data[1] = a1;
    data[2] = a2;
    for (int i = 3; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_STRONG_INLINE Array(const std::array<T, IndexCount>& array) {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = array[i];
    }
  }
  T data[IndexCount];
};

// A dimension type with compile-time known size.
template <int IndexCount>
struct Dimension : Array<int, IndexCount, 1> {
  typedef Array<int, IndexCount, 1> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
      : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
  EIGEN_STRONG_INLINE Dimension(const std::array<int, IndexCount>& array)
      : Base(array) {}
};

// An index type with compile-time known size.
template <int IndexCount>
struct Index : Array<int, IndexCount, 0> {
  typedef Array<int, IndexCount, 0> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
};

// A helper function that converts a tensor index into a flat array index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat(
    const Index<IndexCount>& index, const Dimension<IndexCount>& dims) {
  int flat_index = index[0];
  for (int i = 1; i < IndexCount; i++) {
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

// A helper function that converts a flat array index into a tensor index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
    int index, const Dimension<IndexCount>& dims) {
  Index<IndexCount> tensor_index;
  for (int i = IndexCount - 1; i >= 0; i--) {
    int new_index = index / dims[i];
    tensor_index[i] = index - dims[i] * new_index;
    index = new_index;
  }
  return tensor_index;
}
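
// Worked example (added for clarity): for dims = {5, 4, 3} and
// index = {2, 1, 0}, TensorIndexToFlat returns (2 * 4 + 1) * 3 + 0 = 27, and
// FlatToTensorIndex(27, dims) recovers {2, 1, 0}.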

// A simple CUDA custom kernel to shuffle dimensions of a 3D tensor according to
// the given shuffle permutation in template parameters. Shuffle permutation
// <sp0, sp1, sp2> shuffles dimensions such that input dimension 0 goes to sp0,
// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1>
// will populate output so that input[x][y][z] is equal to (*output)[y][z][x].
//
// Requires that nthreads is equal to the total number of elements in the input
// tensor.
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3Simple(int nthreads,
                                       const T* __restrict__ input,
                                       Dimension<3> input_dims,
                                       T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  // Iterate over output as opposed to iterating over input for better
  // performance. Iterating over output generates sequential writes and
  // random reads, which perform better than sequential reads and random
  // writes.
  GPU_1D_KERNEL_LOOP(output_index, nthreads) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}
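
// Illustrative example (added for clarity): with permutation <2, 0, 1> and
// input_dims = {2, 3, 4}, the kernel produces output_dims = {3, 4, 2}, and
// output element (y, z, x) is read from input element (x, y, z).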

static constexpr int kUnroll = 4;

template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3SimpleVector(int nthreads,
                                             const T* __restrict__ input,
                                             Dimension<3> input_dims,
                                             T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  const int stride = blockDim.x * gridDim.x * kUnroll;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  T buf[kUnroll];

  int output_index;
  for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads;
       output_index += stride) {
#pragma unroll
    for (int i = 0; i < kUnroll; i++) {
      int output_index_i = output_index + i;
      Index<3> output_tensor_index =
          FlatToTensorIndex(output_index_i, output_dims);
      Index<3> input_tensor_index;
      input_tensor_index[0] = output_tensor_index[sp0];
      input_tensor_index[1] = output_tensor_index[sp1];
      input_tensor_index[2] = output_tensor_index[sp2];

      int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims);
      buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i));
    }
    float2* out = reinterpret_cast<float2*>(output + output_index);
    *out = *reinterpret_cast<float2*>(buf);
  }

  for (; output_index < nthreads; ++output_index) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}
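
// Note (added for clarity): the unrolled loop above writes kUnroll elements at
// once through a single float2 (8-byte) store, which only lines up when
// sizeof(T) * kUnroll equals sizeof(float2). Accordingly,
// SwapDimension0And2InTensor3 below only selects this kernel for 2-byte
// element types with a 16-byte aligned output pointer.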

// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// Each thread block operates on a single tile, a rectangle of dimensions
// TileSizeI x TileSizeJ.
//
// In general, for best performance, you should probably set TileSizeI,
// TileSizeJ equal to the number of threads in a warp (32 in NVIDIA GPUs).
// With a TileSizeI, TileSizeJ of 32, NumThreads of 128 or 256 seems to get
// the best performance on K40 GPUs.
template <typename T, int NumThreads, int TileSizeI, int TileSizeJ,
          bool conjugate = false>
__global__ void SwapDimension1And2InTensor3UsingTiles(
    const T* __restrict__ input, Dimension<3> input_dims,
    T* __restrict__ output) {
  eigen_assert(blockDim.x == NumThreads);
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
  constexpr int WriteRowPerPass = NumThreads / TileSizeI;
  // One extra line in the inner dimension to avoid shared memory bank
  // conflicts. This mimics the following declaration, but no constructor of T
  // can be invoked.
  //     __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#if GOOGLE_CUDA
  __shared__ __align__(
      alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
  typedef T(*SharedMemoryTile)[TileSizeJ + 1];
  SharedMemoryTile shared_memory_tile =
      reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
#elif TENSORFLOW_USE_ROCM
  __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#endif

  int x = threadIdx.x;

  Dimension<3> output_dims = {
      input_dims[0],
      input_dims[2],
      input_dims[1],
  };

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      (input_dims[1] + TileSizeI - 1) / TileSizeI,
      (input_dims[2] + TileSizeJ - 1) / TileSizeJ,
  };

  Index<3> input_tile_index =
      FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);

  Index<3> input_tile_origin = {
      input_tile_index[0],
      input_tile_index[1] * TileSizeI,
      input_tile_index[2] * TileSizeJ,
  };

  int input_origin_flat_index =
      TensorIndexToFlat(input_tile_origin, input_dims);

  bool full_tile = true;
  int tile_width = TileSizeJ;

  // Only the last row or column may not have the full size.
  if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
    tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ;
    full_tile &= false;
  }

  int tile_height = TileSizeI;

  if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
    tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI;
    full_tile &= false;
  }

  // Calculate effective thread number. This ensures that we use the largest
  // number of threads available to form a regular thread block with no
  // trailing incomplete lines.
  constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ;

  if (x < in_effective_thread_num) {
    // Orient the logical thread block with respect to the input array.
    // i.e. align the contiguous dimension of thread blocks with the contiguous
    // dimension of the input array.
    int ti = x / TileSizeJ;
    int tj = x % TileSizeJ;
    int input_index = input_origin_flat_index + ti * input_dims[2] + tj;
    int input_increment = ReadRowPerPass * input_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) {
        shared_memory_tile[i_loc][tj] =
            maybe_conj<T, conjugate>::run(input[input_index]);
        input_index += input_increment;
      }
    } else {
      if (tj < tile_width) {
        for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) {
          shared_memory_tile[i_loc][tj] =
              maybe_conj<T, conjugate>::run(input[input_index]);
          input_index += input_increment;
        }
      }
    }
  }

  __syncthreads();

  Index<3> output_tile_index = {
      input_tile_index[0],
      input_tile_index[2],
      input_tile_index[1],
  };

  Index<3> output_tile_origin = {
      output_tile_index[0],
      output_tile_index[1] * TileSizeJ,
      output_tile_index[2] * TileSizeI,
  };

  int output_origin_flat_index =
      TensorIndexToFlat(output_tile_origin, output_dims);

  constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI;

  if (x < out_effective_thread_num) {
    // Re-orient the logical thread block with respect to the output array.
    // i.e. align the contiguous dimension of thread blocks with the contiguous
    // dimension of the output array.
    int ti = x / TileSizeI;
    int tj = x % TileSizeI;
    int output_index = output_origin_flat_index + ti * output_dims[2] + tj;
    int output_increment = WriteRowPerPass * output_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) {
        output[output_index] = shared_memory_tile[tj][i_loc];
        output_index += output_increment;
      }
    } else {
      if (tj < tile_height) {
        for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) {
          output[output_index] = shared_memory_tile[tj][i_loc];
          output_index += output_increment;
        }
      }
    }
  }
}
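
// Illustrative launch sketch (added for clarity; this mirrors the large-matrix
// path in RunSwapDimension1And2InTensor3 below): with TileSizeI = TileSizeJ =
// 32, NumThreads = 256 and input_dims = {N, H, W}, the kernel is launched on
// N * ceil(H / 32) * ceil(W / 32) blocks of 256 threads, and each block
// transposes one 32 x 32 tile through the shared memory buffer above.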

// A GPU custom kernel that converts input to output, given proper padding on
// the left and the top.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    bool ok = true;
    for (int i = 1; i < NDIMS - 1; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }
    input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1];  // channels

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}
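
// Worked example (added for clarity): for an NHWC input of shape {1, 4, 4, 3},
// padding_left = {1, 1} and an output of shape {1, 6, 6, 3}, output element
// (0, 2, 3, c) is copied from input element (0, 1, 2, c), while output element
// (0, 0, j, c) maps to input row -1, fails the bounds check, and is filled
// with padding_value.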

template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    input_tensor_index[1] = output_tensor_index[1];  // channels
    bool ok = true;
    for (int i = 2; i < NDIMS; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}

// A GPU helper functor that converts the TensorFlow filter format to the
// Cudnn filter format.
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  typename TTypes<T, NDIMS, int>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // spatial dimensions
    for (int i = 1; i < NDIMS - 2; i++) {
      combined_dims[0] *= in.dimension(i);
    }
    combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
    combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);

    if (dst_filter_format == FORMAT_OIHW) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (dst_filter_format == FORMAT_OHWI) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      LOG(ERROR) << "Unsupported filter format: "
                 << ToString(dst_filter_format);
    }
  }
};

// Converts Cudnn filter format OIHW or OHWI back to TensorFlow filter format
// HWIO.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat src_filter_format,
                  typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;

    if (src_filter_format == FORMAT_OIHW) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // input filters
      combined_dims[2] = in.dimension(2);  // spatial dimensions
      for (int i = 3; i < NDIMS; ++i) {
        combined_dims[2] *= in.dimension(i);
      }

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (src_filter_format == FORMAT_OHWI) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // spatial dimensions
      for (int i = 2; i < NDIMS - 1; i++) {
        combined_dims[1] *= in.dimension(i);
      }
      combined_dims[2] = in.dimension(NDIMS - 1);  // input filters

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      // TODO(ezhulenev): Set error status in OpKernelContext instead.
      LOG(FATAL) << "Unsupported filter format: "
                 << ToString(src_filter_format);
    }
  }
};

// A GPU helper functor that converts an input tensor to a larger output
// tensor, given proper padding values. Padded elements are set to
// padding_value.
template <typename T, int NDIMS>
struct PadInput<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  const std::array<int, NDIMS - 2>& padding_left,
                  const std::array<int, NDIMS - 2>& padding_right,
                  typename TTypes<T, NDIMS, int>::Tensor out,
                  TensorFormat format, const T& padding_value) {
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
    Dimension<NDIMS> input_dims;
    for (int i = 0; i < NDIMS; ++i) {
      input_dims[i] = in.dimension(i);
    }
    Dimension<NDIMS> output_dims;
    for (int i = 0; i < NDIMS; ++i) {
      output_dims[i] = out.dimension(i);
    }

    const Dimension<NDIMS - 2> padding_left_dim(padding_left);

    if (format == FORMAT_NHWC) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else if (format == FORMAT_NCHW) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else {
      LOG(FATAL) << "Invalid data format: " << format;
    }
  }
};

// We want std::equal_to and std::greater, but they're not constexpr until
// C++14.
struct EqualTo {
  constexpr bool operator()(int a, int b) const { return a == b; }
};

struct GreaterThan {
  constexpr bool operator()(int a, int b) const { return a > b; }
};

// For each data type, the tile size possibility frontier denotes the tile size
// combinations that consume the most computational resources constrained by
// - number of threads per SM limit,
// - limit on size of the short dimension (<=15) due to the definition of
//   narrow matrix,
// - shared memory limit and
// - some experimentally determined, type-specific constraint on the product of
//   two side lengths to increase grid-level parallelism.
//
// A tile size combination lies on the frontier if and only if one or more of
// the constraints mentioned above is hit. Tile size combinations lying outside
// this frontier are either not possible, or are slower than the alternatives.
//
// It is useful to consider, for each data type, two subsets of the
// corresponding frontier:
// - long side frontier: the union of the biggest tile size combination for
//   each legal long side len.
// - non long side frontier: the frontier set minus the long side frontier.
//
// TileSizePossibilityFrontierCheck defines the frontier using only the long
// side frontier tile size combinations (since one can easily extrapolate
// the entire frontier from this subset). It serves as a utility function
// to help us determine where a tile size combination of interest lies with
// respect to the frontier.
template <typename Op>
constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide,
                                                int TileShortSide,
                                                int size_of_t, Op op) {
  // clang-format off

  return (size_of_t == 16 && ((TileLongSide == 32   && op(TileShortSide, 4))  ||
                             (TileLongSide == 64   && op(TileShortSide, 4))  ||
                             (TileLongSide == 128  && op(TileShortSide, 4))  ||
                             (TileLongSide == 256  && op(TileShortSide, 2)))) ||
          (size_of_t == 8 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 8))  ||
                             (TileLongSide == 256  && op(TileShortSide, 4))  ||
                             (TileLongSide == 512  && op(TileShortSide, 2)))) ||
          (size_of_t == 4 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
          (size_of_t == 2 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
          (size_of_t == 1 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2))));

  // clang-format on
}

constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide,
                                          int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, EqualTo());
}
constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide,
                                       int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, GreaterThan());
}
constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide,
                                             int TileShortSide, int size_of_t) {
  // For a tile size combination (longside, shortside), lying on the frontier
  // implies that (longside, shortside) is on or within the frontier but
  // (longside*2, shortside) or (longside, shortside+1) is not. With the above
  // criterion, we simply need to use !TileSizeOnLongSideFrontier to ensure that
  // it is not on the long side frontier.
  return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) &&
         (TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) ||
          TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1,
                                  size_of_t)) &&
         !TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t);
}

// Helper function to launch a batch narrow matrix transpose kernel.
template <typename T, int TileLongSide, int TileShortSide, bool conjugate>
void LaunchBatchNarrowMatrixTransposeKernel(
    const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count,
    const T* input, const Dimension<3>& input_dims, T* output) {
  constexpr int NumThreads = TileLongSide;
  if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
                                              TileShortSide, conjugate>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
                                              TileLongSide, conjugate>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  }
}
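
// Note (added for clarity): the dispatchers below call this helper only when
// max(tile_size_i, tile_size_j) <= TileLongSide and
// min(tile_size_i, tile_size_j) <= TileShortSide, so at least one of the two
// tile orientations above fits the request; the branch launches the
// <TileLongSide, TileShortSide> orientation when the request matches it
// directly and the swapped orientation otherwise.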

// Recursive template function to search, in a trial-and-error manner, for the
// minimum tile size configuration satisfying the requested tile side lengths.
// An important invariant of this search procedure is that for an unsatisfied
// request, we always try doubling the long side len first, and only after
// the request is satisfied for the long side len do we begin incrementing
// the short side len.
//
// We have three specializations of this search function depending on where the
// current tile size combination lies with respect to the frontier.
// - It lies within the frontier. If request is not satisfied, for the next tile
// size combination, we first try doubling the long side len and if that does
// not work, we then increment the short side len.
// - It lies on the non long side frontier. If the request is not satisfied, we
// can only increment the short side len.
// - It lies on the long side frontier. We launch the kernel without checking if
// the request is satisfied or not.
template <typename T, int TileLongSide, int TileShortSide, bool conjugate,
          typename dummy = void>
struct BatchNarrowMatrixTransposeDispatcher {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide,
                                             conjugate>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; we then
    // determine whether it is the long side or the short side that falls short
    // of the request and increase that parameter accordingly.
    const bool long_side_request_not_satisfied =
        std::max(tile_size_i, tile_size_j) > TileLongSide;

    if (long_side_request_not_satisfied) {
      BatchNarrowMatrixTransposeDispatcher<T, TileLongSide * 2, TileShortSide,
                                           conjugate>::DoIt(d, tile_size_i,
                                                            tile_size_j,
                                                            total_tiles_count,
                                                            input, input_dims,
                                                            output);
    } else {
      BatchNarrowMatrixTransposeDispatcher<T, TileLongSide, TileShortSide + 1,
                                           conjugate>::DoIt(d, tile_size_i,
                                                            tile_size_j,
                                                            total_tiles_count,
                                                            input, input_dims,
                                                            output);
    }
  }
};

template <typename T, int TileLongSide, int TileShortSide, bool conjugate>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide, conjugate,
    typename std::enable_if<TileSizeOnNonLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide,
                                             conjugate>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; since
    // we are on the non long side frontier, we increment the short dimension
    // and try again.
    BatchNarrowMatrixTransposeDispatcher<T, TileLongSide, TileShortSide + 1,
                                         conjugate>::DoIt(d, tile_size_i,
                                                          tile_size_j,
                                                          total_tiles_count,
                                                          input, input_dims,
                                                          output);
  }
};

template <typename T, int TileLongSide, int TileShortSide, bool conjugate>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide, conjugate,
    typename std::enable_if<TileSizeOnLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");

    LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide,
                                           conjugate>(
        d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
        output);
  }
};

// This function tries to recover, in a brute force way, the frontier defined in
// TileSizePossibilityFrontierCheck as a vector of tile size combinations lying
// on the long side frontier. This vector is sufficient to determine the entire
// frontier.
//
// Note that if one changes the frontier definition in
// TileSizePossibilityFrontierCheck and forgets to set the largest short
// side len of the largest legal long side len to 2, this function will fail
// and crash the program.
template <int SizeOfT>
const std::vector<std::pair<int, int>>& GetTileSizesFrontier() {
  static_assert(
      SizeOfT <= 16,
      "Currently, only data types of sizes 16 bytes or less are supported.");
  static_assert((SizeOfT & (SizeOfT - 1)) == 0,
                "Data types must have sizes that are powers of 2.");

  // Expensive work to populate sizes, lazily run in a thread-safe
  // manner the first time GetTileSizesFrontier<N> is called.
  static auto* frontier = [] {
    auto* frontier = new std::vector<std::pair<int, int>>();
    const int kMaxLongSideLen = 1024;
    const int kMaxShortSideLen = 15;
    for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
      for (int short_side = 2; short_side <= kMaxShortSideLen;
           short_side += 1) {
        if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) {
          // The current combination lies on the frontier, thus we
          // add it to the frontier definition.
          frontier->push_back(std::make_pair(long_side, short_side));

          // The long side length is the largest one allowed iff its
          // corresponding short side length is 2.
          if (short_side == 2) return frontier;

          // We have exhausted all the possibilities in the frontier
          // with the given long side length.
          break;
        }
      }
    }
    LOG(FATAL)
        << "The corresponding short side length of the largest long side "
           "length has to be 2.";
  }();
  return *frontier;
}
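
// Illustrative example (added for clarity): for 4-byte element types,
// GetTileSizesFrontier<4>() yields the long side frontier
// {(32, 15), (64, 15), (128, 15), (256, 8), (512, 4), (1024, 2)}, matching the
// size_of_t == 4 clause of TileSizePossibilityFrontierCheck above.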

// Helper structs to help determine which data type to use given the size of
// the matrix data type. A transpose of elements of size N will use a kernel
// which operates on an array of TransposeElemType<N>::type.
template <int ElemBytes>
struct TransposeElemType;
template <>
struct TransposeElemType<1> {
  using type = uint8;
};
template <>
struct TransposeElemType<2> {
  using type = uint16;
};
template <>
struct TransposeElemType<4> {
  using type = uint32;
};
template <>
struct TransposeElemType<8> {
  using type = float2;
};
template <>
struct TransposeElemType<16> {
  using type = double2;
};
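
// For example, a transpose of complex64 (8-byte) elements is dispatched on
// float2 values and a transpose of Eigen::half (2-byte) elements on uint16
// values; only the element size matters for the data movement, while the
// maybe_conj specializations above handle conjugation for the complex cases.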

// A helper function to make RunSwapDimension1And2InTensor3 concise. This
// helper function looks at the data type and input matrix sizes and decides
// the thread numbers and tile sizes to use.
template <typename T, bool conjugate = false>
void SwapDimension1And2InTensor3WithNarrowMatrices(
    const GPUDevice& d, const T* input, const Dimension<3>& input_dims,
    T* output, const int kMinDimensionToUseTiles) {
  // Get available tile sizes here for the data type requested:
  const auto& tile_spec = GetTileSizesFrontier<sizeof(T)>();

  int tile_long_side_len = 0;
  int tile_short_side_len = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int data_long_side = std::max(input_dims[1], input_dims[2]);

  for (auto tile_size_pair : tile_spec) {
    int proposed_tile_long_side_len = tile_size_pair.first;

    // Number of threads that will not be doing anything useful when reading
    // the matrix because the thread block size is bigger than the data block
    // size.
    int num_wasted_threads =
        data_long_side - MathUtil::FloorOfRatio<int>(
                             data_long_side, proposed_tile_long_side_len) *
                             proposed_tile_long_side_len;

    int num_full_tiles = MathUtil::FloorOfRatio<int>(
        data_long_side, proposed_tile_long_side_len);

    float cost = 0;

    // However, if we can execute two or more full tiles, then we gladly
    // accept any number of wasted threads and ignore its cost.
    if (num_full_tiles <= 1) cost = num_wasted_threads;

    // Using less than or equal to here because given the same cost, we
    // would like to launch as many threads as possible.
    if (cost <= lowest_cost) {
      tile_long_side_len = proposed_tile_long_side_len;
      tile_short_side_len = tile_size_pair.second;
      lowest_cost = cost;
    }
  }

  // Request tile sizes such that the longer side of threadblock aligns with
  // the longer side of input data block to maximize read throughput.
  // The ideal tile shape is one where the length of the shorter side of the
  // tile is equal to the length of the shorter side of the input matrix.
  int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles
                                  ? tile_long_side_len
                                  : input_dims[1];
  int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles
                                  ? input_dims[2]
                                  : tile_long_side_len;

  // Truncate the shorter size requested according to the manual limit set in
  // tile_spec to make sure that we do not launch configurations violating
  // hardware limits.
  requested_tile_size_i =
      requested_tile_size_i == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_i, tile_short_side_len);
  requested_tile_size_j =
      requested_tile_size_j == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_j, tile_short_side_len);

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      MathUtil::CeilOfRatio<int>(input_dims[1], requested_tile_size_i),
      MathUtil::CeilOfRatio<int>(input_dims[2], requested_tile_size_j),
  };

  int total_tiles_count =
      input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];

  using ElemType = typename TransposeElemType<sizeof(T)>::type;
  static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment.");
  BatchNarrowMatrixTransposeDispatcher<ElemType, 32, 2, conjugate>::DoIt(
      d, requested_tile_size_i, requested_tile_size_j, total_tiles_count,
      reinterpret_cast<const ElemType*>(input), input_dims,
      reinterpret_cast<ElemType*>(output));
}

// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
// 3D tensor. It looks at the shape of the incoming data, and decides the best
// strategy to launch.
template <typename T, bool conjugate = false>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
                                    const Dimension<3>& input_dims, T* output) {
  // If both swapped dimensions are large, use the shared-memory tiled kernel.
  // If one of them is small but the other is large enough, use the narrow
  // matrix path. Otherwise, the simple shuffle kernel relying on the ldg cache
  // is more efficient.
  static const int kMinDimensionToUseTiles = 16;
  static const int kMinDimensionToUseRectTiles = 96;

  bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles &&
                      input_dims[2] >= kMinDimensionToUseTiles;
  bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles ||
                       input_dims[2] >= kMinDimensionToUseRectTiles;
  if (large_matrix) {
    // We get best performance when kTileSize is the number of threads in a warp
    // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
    // threads.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dimension<3> input_dims_in_tiles = {
        input_dims[0],
        MathUtil::CeilOfRatio<int>(input_dims[1], kTileSize),
        MathUtil::CeilOfRatio<int>(input_dims[2], kTileSize),
    };

    int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
                            input_dims_in_tiles[2];
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize,
                                              kTileSize, conjugate>,
        total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims,
        output));

  } else if (narrow_matrix) {
    SwapDimension1And2InTensor3WithNarrowMatrices<T, conjugate>(
        d, input, input_dims, output, kMinDimensionToUseTiles);
  } else {
    int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d);
    TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), config.virtual_thread_count, input,
                                input_dims, output));
  }
}
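
// Dispatch example (added for clarity): an input of shape {64, 8, 300} skips
// the large-matrix path (8 < kMinDimensionToUseTiles) but takes the narrow
// matrix path (300 >= kMinDimensionToUseRectTiles), while an input of shape
// {64, 8, 20} falls through to the simple ShuffleInTensor3Simple launch.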

// A GPU helper functor that swaps dimension 1 and dimension 2 of a 3D tensor.
template <typename T, bool conjugate>
struct SwapDimension1And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64_t>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    RunSwapDimension1And2InTensor3<T, conjugate>(d, in, input_dims, out);
  }
};

// A GPU helper functor that swaps dimension 0 and dimension 2 of a 3D tensor.
template <typename T, bool conjugate>
struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64_t>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);

    auto out_ptr = reinterpret_cast<uintptr_t>(out);
    bool aligned = out_ptr % 16 == 0;

    bool use_vector = false;
    bool use_custom_config = false;
    if ((input_dims[0] <= 128 && input_dims[2] <= 128) ||
        input_dims[0] * input_dims[1] <= 128 ||
        input_dims[1] * input_dims[2] <= 8) {
      use_vector = true;
      use_custom_config = true;
    } else if (input_dims[1] * input_dims[2] <= 16384) {
      use_vector = true;
    }

    if (sizeof(T) == 2 && aligned && use_vector) {
      int block_count;
      if (use_custom_config) {
        block_count = (total_size + config.thread_per_block - 1) /
                      config.thread_per_block;
      } else {
        block_count = config.block_count;
      }

      TF_CHECK_OK(
          GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>,
                          block_count, config.thread_per_block / kUnroll, 0,
                          d.stream(), total_size, in, input_dims, out));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in, input_dims, out));
    }
  }
};

// A GPU helper functor that converts NHWC TensorFlow data format to the
// NCHW format that is accepted by Cudnn.
template <typename T, int NDIMS>
struct NHWCToNCHW<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // spatial dimensions (HW)
    for (int i = 2; i < NDIMS - 1; ++i) {
      combined_dims[1] *= in.dimension(i);
    }
    combined_dims[2] = in.dimension(NDIMS - 1);  // C (channels)
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

// A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow
// format.
template <typename T, int NDIMS>
struct NCHWToNHWC<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // C (channel)
    combined_dims[2] = in.dimension(2);  // spatial dimensions (HW)
    for (int i = 3; i < NDIMS; ++i) {
      combined_dims[2] *= in.dimension(i);
    }
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_