/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <algorithm>
#include <array>
#include <limits>
#include <utility>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T, bool conjugate>
struct maybe_conj {
  __device__ static __inline__ T run(T x) {
    if (conjugate) {
      return Eigen::numext::conj(x);
    } else {
      return x;
    }
  }
};

// Partial specializations for Gpu types used to store complex numbers.
template <bool conjugate>
struct maybe_conj<float2, conjugate> {
  __device__ static __inline__ float2 run(float2 c) {
    if (conjugate) {
      float2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

template <bool conjugate>
struct maybe_conj<double2, conjugate> {
  __device__ static __inline__ double2 run(double2 c) {
    if (conjugate) {
      double2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

// TODO(mjanusz): Move this to a shared util file.
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount, T DefaultValue>
struct Array {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {
    data[0] = a0;
    for (int i = 1; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {
    data[0] = a0;
    data[1] = a1;
    for (int i = 2; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {
    data[0] = a0;
    data[1] = a1;
    data[2] = a2;
    for (int i = 3; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_STRONG_INLINE Array(const std::array<T, IndexCount>& array) {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = array[i];
    }
  }
  T data[IndexCount];
};

// A dimension type with compile-time known size.
template <int IndexCount>
struct Dimension : Array<int, IndexCount, 1> {
  typedef Array<int, IndexCount, 1> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
      : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
  EIGEN_STRONG_INLINE Dimension(const std::array<int, IndexCount>& array)
      : Base(array) {}
};

// An index type with compile-time known size.
template <int IndexCount>
struct Index : Array<int, IndexCount, 0> {
  typedef Array<int, IndexCount, 0> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
};

// A helper function that converts a tensor index into a flat array index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat(
    const Index<IndexCount>& index, const Dimension<IndexCount>& dims) {
  int flat_index = index[0];
  for (int i = 1; i < IndexCount; i++) {
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

// A helper function that converts a flat array index into a tensor index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
    int index, const Dimension<IndexCount>& dims) {
  Index<IndexCount> tensor_index;
  for (int i = IndexCount - 1; i >= 0; i--) {
    int new_index = index / dims[i];
    tensor_index[i] = index - dims[i] * new_index;
    index = new_index;
  }
  return tensor_index;
}
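// For illustration, a worked round trip through the two helpers above
// (row-major layout, innermost dimension last):
//
//   Dimension<3> dims(4, 5, 6);
//   Index<3> idx(2, 3, 4);
//   int flat = TensorIndexToFlat(idx, dims);      // (2 * 5 + 3) * 6 + 4 == 82
//   Index<3> back = FlatToTensorIndex(flat, dims);  // back == {2, 3, 4}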
// A simple CUDA custom kernel that shuffles the dimensions of a 3D tensor
// according to the shuffle permutation given in the template parameters.
// Shuffle permutation <sp0, sp1, sp2> shuffles dimensions such that input
// dimension 0 goes to sp0, 1 goes to sp1 and 2 goes to sp2. For example,
// shuffle permutation <2, 0, 1> will populate output so that input[x][y][z]
// is equal to (*output)[y][z][x].
//
// Requires that nthreads is equal to the total number of elements in the
// input tensor.
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3Simple(int nthreads,
                                       const T* __restrict__ input,
                                       Dimension<3> input_dims,
                                       T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  // Iterate over the output rather than the input for better performance:
  // iterating over the output generates sequential writes and random reads,
  // which perform better than sequential reads and random writes.
  GPU_1D_KERNEL_LOOP(output_index, nthreads) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}

static constexpr int kUnroll = 4;

template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3SimpleVector(int nthreads,
                                             const T* __restrict__ input,
                                             Dimension<3> input_dims,
                                             T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  const int stride = blockDim.x * gridDim.x * kUnroll;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  T buf[kUnroll];

  int output_index;
  for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads;
       output_index += stride) {
#pragma unroll
    for (int i = 0; i < kUnroll; i++) {
      int output_index_i = output_index + i;
      Index<3> output_tensor_index =
          FlatToTensorIndex(output_index_i, output_dims);
      Index<3> input_tensor_index;
      input_tensor_index[0] = output_tensor_index[sp0];
      input_tensor_index[1] = output_tensor_index[sp1];
      input_tensor_index[2] = output_tensor_index[sp2];

      int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims);
      buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i));
    }
    float2* out = reinterpret_cast<float2*>(output + output_index);
    *out = *reinterpret_cast<float2*>(buf);
  }

  for (; output_index < nthreads; ++output_index) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}
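// Note on the vectorized kernel above: the unrolled loop stages kUnroll (4)
// elements in `buf` and writes them back with a single 8-byte float2 store,
// so it is only meaningful when sizeof(T) == 2 and the output pointer is
// suitably aligned. This matches the `sizeof(T) == 2 && aligned` guard at its
// only call site, SwapDimension0And2InTensor3, further below.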
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// Each thread block operates on a single tile, a rectangle of dimensions
// TileSizeI x TileSizeJ.
//
// In general, for best performance, you should probably set TileSizeI and
// TileSizeJ equal to the number of threads in a warp (32 on NVIDIA GPUs).
// With TileSizeI = TileSizeJ = 32, a NumThreads of 128 or 256 seems to get
// the best performance on K40 GPUs.
template <typename T, int NumThreads, int TileSizeI, int TileSizeJ,
          bool conjugate = false>
__global__ void SwapDimension1And2InTensor3UsingTiles(
    const T* __restrict__ input, Dimension<3> input_dims,
    T* __restrict__ output) {
  eigen_assert(blockDim.x == NumThreads);
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
  constexpr int WriteRowPerPass = NumThreads / TileSizeI;
  // One extra line in the inner dimension to avoid shared memory bank
  // conflicts. This mimics the following, but no constructor of T can be
  // invoked:
  //   __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#if GOOGLE_CUDA
  __shared__ __align__(
      alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
  typedef T(*SharedMemoryTile)[TileSizeJ + 1];
  SharedMemoryTile shared_memory_tile =
      reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
#elif TENSORFLOW_USE_ROCM
  __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#endif

  int x = threadIdx.x;

  Dimension<3> output_dims = {
      input_dims[0],
      input_dims[2],
      input_dims[1],
  };

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      (input_dims[1] + TileSizeI - 1) / TileSizeI,
      (input_dims[2] + TileSizeJ - 1) / TileSizeJ,
  };

  Index<3> input_tile_index =
      FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);

  Index<3> input_tile_origin = {
      input_tile_index[0],
      input_tile_index[1] * TileSizeI,
      input_tile_index[2] * TileSizeJ,
  };

  int input_origin_flat_index =
      TensorIndexToFlat(input_tile_origin, input_dims);

  bool full_tile = true;
  int tile_width = TileSizeJ;

  // Only the last row or column may not have the full size.
  if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
    tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ;
    full_tile &= false;
  }

  int tile_height = TileSizeI;

  if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
    tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI;
    full_tile &= false;
  }

  // Calculate the effective thread number. This ensures that we use the
  // largest number of threads available to form a regular thread block with
  // no trailing incomplete lines.
  constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ;

  if (x < in_effective_thread_num) {
    // Orient the logical thread block with respect to the input array,
    // i.e., align the contiguous dimension of thread blocks with the
    // contiguous dimension of the input array.
    int ti = x / TileSizeJ;
    int tj = x % TileSizeJ;
    int input_index = input_origin_flat_index + ti * input_dims[2] + tj;
    int input_increment = ReadRowPerPass * input_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) {
        shared_memory_tile[i_loc][tj] =
            maybe_conj<T, conjugate>::run(input[input_index]);
        input_index += input_increment;
      }
    } else {
      if (tj < tile_width) {
        for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) {
          shared_memory_tile[i_loc][tj] =
              maybe_conj<T, conjugate>::run(input[input_index]);
          input_index += input_increment;
        }
      }
    }
  }

  __syncthreads();

  Index<3> output_tile_index = {
      input_tile_index[0],
      input_tile_index[2],
      input_tile_index[1],
  };

  Index<3> output_tile_origin = {
      output_tile_index[0],
      output_tile_index[1] * TileSizeJ,
      output_tile_index[2] * TileSizeI,
  };

  int output_origin_flat_index =
      TensorIndexToFlat(output_tile_origin, output_dims);

  constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI;

  if (x < out_effective_thread_num) {
    // Re-orient the logical thread block with respect to the output array,
    // i.e., align the contiguous dimension of thread blocks with the
    // contiguous dimension of the output array.
    int ti = x / TileSizeI;
    int tj = x % TileSizeI;
    int output_index = output_origin_flat_index + ti * output_dims[2] + tj;
    int output_increment = WriteRowPerPass * output_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) {
        output[output_index] = shared_memory_tile[tj][i_loc];
        output_index += output_increment;
      }
    } else {
      if (tj < tile_height) {
        for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) {
          output[output_index] = shared_memory_tile[tj][i_loc];
          output_index += output_increment;
        }
      }
    }
  }
}
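// For a concrete instantiation of the kernel above: with TileSizeI and
// TileSizeJ both 32 and NumThreads 256, ReadRowPerPass and WriteRowPerPass are
// both 256 / 32 = 8, so a block reads its 32 x 32 tile into shared memory in
// 32 / 8 = 4 passes of 8 rows each, synchronizes, and then writes the
// transposed tile back out in another 4 passes.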
// A GPU custom kernel that converts input to output, given proper padding on
// the left and the top.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    bool ok = true;
    for (int i = 1; i < NDIMS - 1; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }
    input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1];  // channels

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}

template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    input_tensor_index[1] = output_tensor_index[1];  // channels
    bool ok = true;
    for (int i = 2; i < NDIMS; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}
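// To illustrate the index math above for the NHWC kernel with NDIMS == 4 and
// padding_left == {pad_top, pad_left}: each output element is
//
//   output[n][y][x][c] = input[n][y - pad_top][x - pad_left][c]
//
// whenever the shifted spatial indices fall inside input_dims, and
// padding_value otherwise. The NCHW kernel applies the same shift to the
// trailing spatial dimensions instead.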
// A GPU helper function that converts TensorFlow filter format to Cudnn filter
// format.
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  typename TTypes<T, NDIMS, int>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // spatial dimensions
    for (int i = 1; i < NDIMS - 2; i++) {
      combined_dims[0] *= in.dimension(i);
    }
    combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
    combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);

    if (dst_filter_format == FORMAT_OIHW) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (dst_filter_format == FORMAT_OHWI) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      LOG(ERROR) << "Unsupported filter format: "
                 << ToString(dst_filter_format);
    }
  }
};

// Converts Cudnn filter format OIHW or OHWI back to TensorFlow filter format
// HWIO.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat src_filter_format,
                  typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;

    if (src_filter_format == FORMAT_OIHW) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // input filters
      combined_dims[2] = in.dimension(2);  // spatial dimensions
      for (int i = 3; i < NDIMS; ++i) {
        combined_dims[2] *= in.dimension(i);
      }

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (src_filter_format == FORMAT_OHWI) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // spatial dimensions
      for (int i = 2; i < NDIMS - 1; i++) {
        combined_dims[1] *= in.dimension(i);
      }
      combined_dims[2] = in.dimension(NDIMS - 1);  // input filters

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      // TODO(ezhulenev): Set error status in OpKernelContext instead.
      LOG(FATAL) << "Unsupported filter format: "
                 << ToString(src_filter_format);
    }
  }
};
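// As a concrete example of the forward transform above: a 3 x 3 x 64 x 128
// HWIO filter is collapsed into combined_dims = {9, 64, 128} (HW, I, O).
// The <2, 1, 0> shuffle produces {128, 64, 9}, i.e. OIHW, while the
// <1, 2, 0> shuffle produces {128, 9, 64}, i.e. OHWI.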
// A GPU helper function that converts an input tensor to a larger output
// tensor, given proper padding values. Elements that fall outside the input
// are filled with padding_value.
template <typename T, int NDIMS>
struct PadInput<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  const std::array<int, NDIMS - 2>& padding_left,
                  const std::array<int, NDIMS - 2>& padding_right,
                  typename TTypes<T, NDIMS, int>::Tensor out,
                  TensorFormat format, const T& padding_value) {
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
    Dimension<NDIMS> input_dims;
    for (int i = 0; i < NDIMS; ++i) {
      input_dims[i] = in.dimension(i);
    }
    Dimension<NDIMS> output_dims;
    for (int i = 0; i < NDIMS; ++i) {
      output_dims[i] = out.dimension(i);
    }

    const Dimension<NDIMS - 2> padding_left_dim(padding_left);

    if (format == FORMAT_NHWC) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else if (format == FORMAT_NCHW) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else {
      LOG(FATAL) << "Invalid data format: " << format;
    }
  }
};
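// A hedged usage sketch of the functor above (the variable names here are
// hypothetical, and the exact call sites in the conv kernels may differ):
// padding a 4-D NHWC float tensor by one row on top and two columns on the
// left, with the remaining padding implied by the shape of `padded`:
//
//   functor::PadInput<GPUDevice, float, int, 4>()(
//       context->eigen_device<GPUDevice>(), To32Bit(in.tensor<float, 4>()),
//       {{1, 2}}, {{0, 0}}, To32Bit(padded.tensor<float, 4>()), FORMAT_NHWC,
//       /*padding_value=*/0.0f);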
// We want std::equal_to and std::greater, but they're not constexpr until
// C++14.
struct EqualTo {
  constexpr bool operator()(int a, int b) const { return a == b; }
};

struct GreaterThan {
  constexpr bool operator()(int a, int b) const { return a > b; }
};

// For each data type, the tile size possibility frontier denotes the tile size
// combinations that consume the most computational resources constrained by
// - the number of threads per SM limit,
// - the limit on the size of the short dimension (<= 15) due to the definition
//   of a narrow matrix,
// - the shared memory limit, and
// - some experimentally determined, type-specific constraint on the product of
//   the two side lengths to increase grid-level parallelism.
//
// A tile size combination lies on the frontier if and only if one or more of
// the constraints mentioned above is hit. Tile size combinations lying outside
// this frontier are either not possible, or are slower than the alternatives.
//
// It is instrumental to consider, for each data type, two subsets of the
// corresponding frontier:
// - long side frontier: the union of the biggest tile size combination for
//   each legal long side length.
// - non long side frontier: the frontier set minus the long side frontier.
//
// TileSizePossibilityFrontierCheck defines the frontier using only the long
// side frontier tile size combinations (since one can easily extrapolate
// the entire frontier from this subset). It serves as a utility function
// to help us determine where a tile size combination of interest lies with
// respect to the frontier.
template <typename Op>
constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide,
                                                int TileShortSide,
                                                int size_of_t, Op op) {
  // clang-format off

  return (size_of_t == 16 && ((TileLongSide == 32  && op(TileShortSide, 4)) ||
                              (TileLongSide == 64  && op(TileShortSide, 4)) ||
                              (TileLongSide == 128 && op(TileShortSide, 4)) ||
                              (TileLongSide == 256 && op(TileShortSide, 2)))) ||
         (size_of_t == 8 && ((TileLongSide == 32  && op(TileShortSide, 15)) ||
                             (TileLongSide == 64  && op(TileShortSide, 15)) ||
                             (TileLongSide == 128 && op(TileShortSide, 8))  ||
                             (TileLongSide == 256 && op(TileShortSide, 4))  ||
                             (TileLongSide == 512 && op(TileShortSide, 2)))) ||
         (size_of_t == 4 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
         (size_of_t == 2 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
         (size_of_t == 1 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2))));

  // clang-format on
}

constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide,
                                          int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, EqualTo());
}
constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide,
                                       int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, GreaterThan());
}
constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide,
                                             int TileShortSide, int size_of_t) {
  // For a tile size combination (longside, shortside), lying on the frontier
  // implies that (longside, shortside) is on or within the frontier but
  // (longside*2, shortside) or (longside, shortside+1) is not. With the above
  // criterion, we simply need !TileSizeOnLongSideFrontier to ensure that it is
  // not on the long side frontier.
  return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) &&
         (TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) ||
          TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1,
                                  size_of_t)) &&
         !TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t);
}
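// To make the frontier concrete for 4-byte element types: (256, 8) lies on the
// long side frontier (the largest short side allowed for a long side of 256),
// (256, 4) lies strictly within the frontier, and (256, 9) lies outside it.
// The full long side frontier for this type size is {(32, 15), (64, 15),
// (128, 15), (256, 8), (512, 4), (1024, 2)}.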
// Helper function to launch a batch narrow matrix transpose kernel.
template <typename T, int TileLongSide, int TileShortSide, bool conjugate>
void LaunchBatchNarrowMatrixTransposeKernel(
    const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count,
    const T* input, const Dimension<3>& input_dims, T* output) {
  constexpr int NumThreads = TileLongSide;
  if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
                                              TileShortSide, conjugate>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
                                              TileLongSide, conjugate>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  }
}

// Recursive template function to search, in a trial-and-error manner, for the
// minimum tile size configuration satisfying the requested tile side lengths.
// An important invariant of this search procedure is that for an unsatisfied
// request, we always try doubling the long side length first, and only after
// the request is satisfied for the long side length do we begin incrementing
// the short side length.
//
// We have three specializations of this search function depending on where the
// current tile size combination lies with respect to the frontier.
// - It lies within the frontier. If the request is not satisfied, for the next
//   tile size combination, we first try doubling the long side length and, if
//   that does not work, we then increment the short side length.
// - It lies on the non long side frontier. If the request is not satisfied, we
//   can only increment the short side length.
// - It lies on the long side frontier. We launch the kernel without checking
//   whether the request is satisfied or not.
template <typename T, int TileLongSide, int TileShortSide, bool conjugate,
          typename dummy = void>
struct BatchNarrowMatrixTransposeDispatcher {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide,
                                             conjugate>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; we then
    // determine whether it is the long side or the short side that falls short
    // of the request and increase that parameter accordingly.
    const bool long_side_request_not_satisfied =
        std::max(tile_size_i, tile_size_j) > TileLongSide;

    if (long_side_request_not_satisfied) {
      BatchNarrowMatrixTransposeDispatcher<T, TileLongSide * 2, TileShortSide,
                                           conjugate>::DoIt(d, tile_size_i,
                                                            tile_size_j,
                                                            total_tiles_count,
                                                            input, input_dims,
                                                            output);
    } else {
      BatchNarrowMatrixTransposeDispatcher<T, TileLongSide, TileShortSide + 1,
                                           conjugate>::DoIt(d, tile_size_i,
                                                            tile_size_j,
                                                            total_tiles_count,
                                                            input, input_dims,
                                                            output);
    }
  }
};

template <typename T, int TileLongSide, int TileShortSide, bool conjugate>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide, conjugate,
    typename std::enable_if<TileSizeOnNonLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide,
                                             conjugate>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; since
    // we are on the non long side frontier, we increment the short dimension
    // and try again.
    BatchNarrowMatrixTransposeDispatcher<T, TileLongSide, TileShortSide + 1,
                                         conjugate>::DoIt(d, tile_size_i,
                                                          tile_size_j,
                                                          total_tiles_count,
                                                          input, input_dims,
                                                          output);
  }
};

template <typename T, int TileLongSide, int TileShortSide, bool conjugate>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide, conjugate,
    typename std::enable_if<TileSizeOnLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");

    LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide,
                                           conjugate>(
        d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
        output);
  }
};
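// As a worked example of the search above: for a 4-byte element type and a
// request of tile_size_i = 48, tile_size_j = 3, the search starts at <32, 2>,
// doubles the long side to <64, 2> (since 48 > 32), then increments the short
// side to <64, 3>, which satisfies the request and launches
// SwapDimension1And2InTensor3UsingTiles with a 64 x 3 tile.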
// This function tries to recover, in a brute force way, the frontier defined
// in TileSizePossibilityFrontierCheck as a vector of tile size combinations
// lying on the long side frontier. This vector is sufficient to determine the
// entire frontier.
//
// Note that if one changes the frontier definition in
// TileSizePossibilityFrontierCheck and forgets to set the largest short side
// length of the largest legal long side length to 2, this function will fail
// and crash the program.
template <int SizeOfT>
const std::vector<std::pair<int, int>>& GetTileSizesFrontier() {
  static_assert(
      SizeOfT <= 16,
      "Currently, only data types of sizes 16 bytes or less are supported.");
  static_assert((SizeOfT & (SizeOfT - 1)) == 0,
                "Data types must have sizes that are powers of 2.");

  // Expensive work to populate sizes, lazily run in a thread-safe
  // manner the first time GetTileSizesFrontier<N> is called.
  static auto* frontier = [] {
    auto* frontier = new std::vector<std::pair<int, int>>();
    const int kMaxLongSideLen = 1024;
    const int kMaxShortSideLen = 15;
    for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
      for (int short_side = 2; short_side <= kMaxShortSideLen;
           short_side += 1) {
        if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) {
          // The current combination lies on the frontier, thus we
          // add it to the frontier definition.
          frontier->push_back(std::make_pair(long_side, short_side));

          // The long side length is the largest one allowed iff its
          // corresponding short side length is 2.
          if (short_side == 2) return frontier;

          // We have exhausted all the possibilities in the frontier
          // with the given long side length.
          break;
        }
      }
    }
    LOG(FATAL)
        << "The corresponding short side length of the largest long side "
           "length has to be 2.";
  }();
  return *frontier;
}

// Helper structs to help determine which data type to use given the size of
// the matrix data type. A transpose of elements of size N will use a kernel
// which operates on an array of TransposeElemType<N>::type.
template <int ElemBytes>
struct TransposeElemType;
template <>
struct TransposeElemType<1> {
  using type = uint8;
};
template <>
struct TransposeElemType<2> {
  using type = uint16;
};
template <>
struct TransposeElemType<4> {
  using type = uint32;
};
template <>
struct TransposeElemType<8> {
  using type = float2;
};
template <>
struct TransposeElemType<16> {
  using type = double2;
};
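// For example, with the mapping above a transpose of Eigen::half or bfloat16
// (2-byte) data runs on uint16 elements, float on uint32, complex64 (8 bytes)
// on float2, and complex128 (16 bytes) on double2; only the element size
// matters, not the arithmetic type.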
// A helper function to make RunSwapDimension1And2InTensor3 concise. This
// helper function looks at the data type and input matrix sizes and decides
// the thread numbers and tile sizes to use.
template <typename T, bool conjugate = false>
void SwapDimension1And2InTensor3WithNarrowMatrices(
    const GPUDevice& d, const T* input, const Dimension<3>& input_dims,
    T* output, const int kMinDimensionToUseTiles) {
  // Get available tile sizes here for the data type requested:
  const auto& tile_spec = GetTileSizesFrontier<sizeof(T)>();

  int tile_long_side_len = 0;
  int tile_short_side_len = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int data_long_side = std::max(input_dims[1], input_dims[2]);

  for (auto tile_size_pair : tile_spec) {
    int proposed_tile_long_side_len = tile_size_pair.first;

    // Number of threads that will not be doing anything useful when reading
    // the matrix because the thread block size is bigger than the data block
    // size.
    int num_wasted_threads =
        data_long_side - MathUtil::FloorOfRatio<int>(
                             data_long_side, proposed_tile_long_side_len) *
                             proposed_tile_long_side_len;

    int num_full_tiles = MathUtil::FloorOfRatio<int>(
        data_long_side, proposed_tile_long_side_len);

    float cost = 0;

    // However, if we can execute two or more full tiles, then we gladly
    // accept any number of wasted threads and ignore its cost.
    if (num_full_tiles <= 1) cost = num_wasted_threads;

    // Using less than or equal to here because given the same cost, we
    // would like to launch as many threads as possible.
    if (cost <= lowest_cost) {
      tile_long_side_len = proposed_tile_long_side_len;
      tile_short_side_len = tile_size_pair.second;
      lowest_cost = cost;
    }
  }

  // Request tile sizes such that the longer side of the thread block aligns
  // with the longer side of the input data block to maximize read throughput.
  // The ideal tile shape is one where the length of the shorter side of the
  // tile is equal to the length of the shorter side of the input matrix.
  int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles
                                  ? tile_long_side_len
                                  : input_dims[1];
  int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles
                                  ? input_dims[2]
                                  : tile_long_side_len;

  // Truncate the shorter size requested according to the manual limit set in
  // tile_spec to make sure that we do not launch configurations violating
  // hardware limits.
  requested_tile_size_i =
      requested_tile_size_i == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_i, tile_short_side_len);
  requested_tile_size_j =
      requested_tile_size_j == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_j, tile_short_side_len);

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      MathUtil::CeilOfRatio<int>(input_dims[1], requested_tile_size_i),
      MathUtil::CeilOfRatio<int>(input_dims[2], requested_tile_size_j),
  };

  int total_tiles_count =
      input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];

  using ElemType = typename TransposeElemType<sizeof(T)>::type;
  static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment.");
  BatchNarrowMatrixTransposeDispatcher<ElemType, 32, 2, conjugate>::DoIt(
      d, requested_tile_size_i, requested_tile_size_j, total_tiles_count,
      reinterpret_cast<const ElemType*>(input), input_dims,
      reinterpret_cast<ElemType*>(output));
}
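// To illustrate the tile selection above: for 4-byte elements and a long data
// side of 200, candidate long sides 32 and 64 both fit at least two full tiles
// (cost 0), while 128 fits only one and wastes 72 threads. Because ties are
// resolved with <=, the later, larger candidate wins, so the 64-long tile
// (with a short side limit of 15) is selected.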
// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
// 3D tensor. It looks at the shape of the incoming data, and decides the best
// strategy to launch.
template <typename T, bool conjugate = false>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
                                    const Dimension<3>& input_dims, T* output) {
  // If both non-batch dimensions are large enough, use the full tiled kernel
  // for the actual swapping. If only one of them is large (a narrow matrix),
  // use the narrow-matrix tile path. Otherwise, the simple shuffle kernel
  // relying on the ldg cache is more efficient.
  static const int kMinDimensionToUseTiles = 16;
  static const int kMinDimensionToUseRectTiles = 96;

  bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles &&
                      input_dims[2] >= kMinDimensionToUseTiles;
  bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles ||
                       input_dims[2] >= kMinDimensionToUseRectTiles;
  if (large_matrix) {
    // We get the best performance when kTileSize is the number of threads in a
    // warp (32 on our GPUs) and NumSubTiles is 8, so our block size is
    // 8 * 32 = 256 threads.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dimension<3> input_dims_in_tiles = {
        input_dims[0],
        MathUtil::CeilOfRatio<int>(input_dims[1], kTileSize),
        MathUtil::CeilOfRatio<int>(input_dims[2], kTileSize),
    };

    int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
                            input_dims_in_tiles[2];
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize,
                                              kTileSize, conjugate>,
        total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims,
        output));

  } else if (narrow_matrix) {
    SwapDimension1And2InTensor3WithNarrowMatrices<T, conjugate>(
        d, input, input_dims, output, kMinDimensionToUseTiles);
  } else {
    int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d);
    TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), config.virtual_thread_count, input,
                                input_dims, output));
  }
}
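// For example, with the thresholds above a (batch, 32, 1024) tensor takes the
// full 32 x 32 tiled kernel, a (batch, 4, 1024) tensor takes the narrow-matrix
// path (one side is below 16 but the other is at least 96), and a
// (batch, 4, 8) tensor falls through to the simple shuffle kernel.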
// A GPU helper functor that does general dimension 1 and 2 switch for 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension1And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64_t>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    RunSwapDimension1And2InTensor3<T, conjugate>(d, in, input_dims, out);
  }
};

// A GPU helper functor that does general dimension 0 and 2 switch for 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64_t>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);

    auto out_ptr = reinterpret_cast<uintptr_t>(out);
    bool aligned = out_ptr % 16 == 0;

    bool use_vector = false;
    bool use_custom_config = false;
    if ((input_dims[0] <= 128 && input_dims[2] <= 128) ||
        input_dims[0] * input_dims[1] <= 128 ||
        input_dims[1] * input_dims[2] <= 8) {
      use_vector = true;
      use_custom_config = true;
    } else if (input_dims[1] * input_dims[2] <= 16384) {
      use_vector = true;
    }

    if (sizeof(T) == 2 && aligned && use_vector) {
      int block_count;
      if (use_custom_config) {
        block_count = (total_size + config.thread_per_block - 1) /
                      config.thread_per_block;
      } else {
        block_count = config.block_count;
      }

      TF_CHECK_OK(
          GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>,
                          block_count, config.thread_per_block / kUnroll, 0,
                          d.stream(), total_size, in, input_dims, out));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in, input_dims, out));
    }
  }
};

// A GPU helper functor that converts NHWC TensorFlow data format to
// NCHW format that is accepted by Cudnn.
template <typename T, int NDIMS>
struct NHWCToNCHW<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // spatial dimensions (HW)
    for (int i = 2; i < NDIMS - 1; ++i) {
      combined_dims[1] *= in.dimension(i);
    }
    combined_dims[2] = in.dimension(NDIMS - 1);  // C (channels)
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

// A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow
// format.
template <typename T, int NDIMS>
struct NCHWToNHWC<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // C (channel)
    combined_dims[2] = in.dimension(2);  // spatial dimensions (HW)
    for (int i = 3; i < NDIMS; ++i) {
      combined_dims[2] *= in.dimension(i);
    }
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_