/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/depthtospace_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {
namespace {

using GPUDevice = Eigen::GpuDevice;

// Depth2Space kernel for FORMAT_NHWC.
// See 'depthtospace_op.h' for a more detailed description.
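// Worked example of the index arithmetic below (illustrative, assuming
// block_size == 2): output element (b, h, w, d) is read from input element
// (b, h / 2, w / 2, d + ((h % 2) * 2 + (w % 2)) * output_depth).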
template <typename dtype>
__global__ void D2S_NHWC(const int32 nthreads,
                         const dtype* __restrict__ input_ptr,
                         const int block_size, const int batch_size,
                         const int input_height, const int input_width,
                         const int input_depth, const int output_height,
                         const int output_width, const int output_depth,
                         dtype* __restrict__ output_ptr) {
  GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
    // out_idx = d + output_depth * (w + output_width * (h + output_height * b))
    const int d = out_idx % output_depth;
    const int out_idx2 = out_idx / output_depth;
    const int w = out_idx2 % output_width;
    const int out_idx3 = out_idx2 / output_width;
    const int h = out_idx3 % output_height;
    const int b = out_idx3 / output_height;

    const int in_h = h / block_size;
    const int offset_h = h % block_size;
    const int in_w = w / block_size;
    const int offset_w = w % block_size;
    const int offset_d = (offset_h * block_size + offset_w) * output_depth;
    const int in_d = d + offset_d;
    const int inp_idx =
        in_d + input_depth * (in_w + input_width * (in_h + input_height * b));
    *(output_ptr + out_idx) = ldg(input_ptr + inp_idx);
  }
}

// Depth2Space kernel for FORMAT_NCHW.
// See 'depthtospace_op.h' for a more detailed description.
template <typename dtype>
__global__ void D2S_NCHW(const int32 nthreads,
                         const dtype* __restrict__ input_ptr,
                         const int block_size, const int input_width,
                         const int output_depth_by_input_height,
                         dtype* __restrict__ output_ptr) {
  GPU_1D_KERNEL_LOOP(input_idx, nthreads) {
    // We will be converting the image from ordering:
    // n, bY, bX, oC, iY, iX    (== input_idx)   to
    // n, oC, iY, bY, iX, bX
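    // Here n is the batch index, oC the output channel, iY/iX the input
    // spatial coordinates, and bY/bX the row/column offsets within the
    // block_size x block_size block; the input channel thus decomposes as
    // (bY * block_size + bX) * output_depth + oC.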

    // Start reading the input data straight away since we know the address.
    // We calculate the output address in parallel while this is being fetched.

    const int n_bY_bX_oC_iY = input_idx / input_width;
    const int iX = input_idx - n_bY_bX_oC_iY * input_width;

    const int n_bY_bX = n_bY_bX_oC_iY / output_depth_by_input_height;
    const int oC_iY = n_bY_bX_oC_iY - n_bY_bX * output_depth_by_input_height;

    const int n_bY = n_bY_bX / block_size;
    const int bX = n_bY_bX - n_bY * block_size;

    const int n = n_bY / block_size;
    const int bY = n_bY - n * block_size;

    const int output_idx =
        bX +
        block_size *
            (iX + input_width *
                      (bY + block_size *
                                (oC_iY + n * output_depth_by_input_height)));

    *(output_ptr + output_idx) = ldg(input_ptr + input_idx);
  }
}

template <typename dtype, int block_size>
__global__ void D2S_NCHW_LOOP(const int32 nthreads,
                              const dtype* __restrict__ input,
                              const int input_width, const int output_width,
                              const int output_depth_by_input_area,
                              const int input_depth_by_input_area,
                              dtype* __restrict__ output) {
  GPU_1D_KERNEL_LOOP(thread_idx, nthreads) {
    // We will be converting the image from ordering:
    // n, bY, bX, oC, iY, iX   to
    // n, oC, iY, bY, iX, bX

    // We assume thread_idx encodes n_oC_iY_iX, and use an unrolled loop over
    // bY and bX coordinates within the block. This kernel is significantly
    // more performant than the D2S_NCHW kernel.
    //   A likely explanation of the improvement is that although both kernels
    // get input coalescing, this one would write the output data more densely
    // per warp, so would benefit assuming delayed cache writeback is used.
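    // Concretely, each thread handles one (n, oC, iY, iX) position: it reads
    // block_size * block_size input values, strided by
    // output_depth_by_input_area, and writes them as one block_size x
    // block_size spatial patch of the output plane.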

    const int n_oC_iY = thread_idx / input_width;
    const int iX = thread_idx - n_oC_iY * input_width;

    const int n = thread_idx / output_depth_by_input_area;
    const int oC_iY_iX = thread_idx - n * output_depth_by_input_area;

    // Recombine the components and apply to the input and output pointers.
    auto input_ptr = input + n * input_depth_by_input_area + oC_iY_iX;
    auto output_ptr = output + (n_oC_iY * output_width + iX) * block_size;

#pragma unroll
    // Copy a patch of data to the output batch image.
    for (int bY = 0; bY < block_size; ++bY) {
#pragma unroll
      for (int bX = 0; bX < block_size; ++bX) {
        output_ptr[bY * output_width + bX] = ldg(
            input_ptr + (bY * block_size + bX) * output_depth_by_input_area);
      }
    }
  }
}

}  // namespace

// Specialization of DepthToSpaceOpFunctor for a GPUDevice.
namespace functor {

template <typename T>
struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NHWC> {
  void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  int block_size, typename TTypes<T, 4>::Tensor output) {
    const int batch_size = output.dimension(0);
    const int input_height = input.dimension(1);
    const int input_width = input.dimension(2);
    const int input_depth = input.dimension(3);
    const int output_height = output.dimension(1);
    const int output_width = output.dimension(2);
    const int output_depth = output.dimension(3);

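    // Launch one (virtual) GPU thread per output element; GPU_1D_KERNEL_LOOP
    // grid-strides over virtual_thread_count.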
    const int total_count =
        batch_size * output_height * output_width * output_depth;
    if (total_count == 0) {
      return;
    }
    GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
    TF_CHECK_OK(GpuLaunchKernel(
        D2S_NHWC<T>, config.block_count, config.thread_per_block, 0, d.stream(),
        config.virtual_thread_count, input.data(), block_size, batch_size,
        input_height, input_width, input_depth, output_height, output_width,
        output_depth, output.data()));
  }
  void operator()(const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input,
                  int block_size, typename TTypes<T, 5>::Tensor output) {
    LOG(FATAL) << "5-D tensors should not be used with NHWC format";
  }
};

template <typename T>
struct DepthToSpaceOpFunctor<GPUDevice, T, FORMAT_NCHW> {
  void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
                  int block_size, typename TTypes<T, 4>::Tensor output) {
    const int batch_size = input.dimension(0);
    const int input_depth = input.dimension(1);
    const int input_height = input.dimension(2);
    const int input_width = input.dimension(3);
    const int output_depth = output.dimension(1);
    const int input_area = input_width * input_height;
    const int input_depth_by_input_area = input_depth * input_area;

    // We improve performance by generating instantiations of the loop kernel
    // for the most common block sizes.
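    // Because block_size is a compile-time constant in D2S_NCHW_LOOP, the
    // compiler can fully unroll the bY/bX loops; each thread then copies
    // block_size * block_size elements, so we launch one thread per
    // (n, oC, iY, iX) position rather than one per input element.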
    if (block_size <= 4) {
      const int output_width = output.dimension(3);
      const int output_depth_by_input_area = output_depth * input_area;
      const int total_count = batch_size * output_depth_by_input_area;
      if (total_count == 0) {
        return;
      }
      GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);
      switch (block_size) {
        case 2:
          TF_CHECK_OK(GpuLaunchKernel(
              D2S_NCHW_LOOP<T, 2>, config.block_count, config.thread_per_block,
              0, d.stream(), total_count, input.data(), input_width,
              output_width, output_depth_by_input_area,
              input_depth_by_input_area, output.data()));
          return;
        case 3:
          TF_CHECK_OK(GpuLaunchKernel(
              D2S_NCHW_LOOP<T, 3>, config.block_count, config.thread_per_block,
              0, d.stream(), total_count, input.data(), input_width,
              output_width, output_depth_by_input_area,
              input_depth_by_input_area, output.data()));
          return;
        case 4:
          TF_CHECK_OK(GpuLaunchKernel(
              D2S_NCHW_LOOP<T, 4>, config.block_count, config.thread_per_block,
              0, d.stream(), total_count, input.data(), input_width,
              output_width, output_depth_by_input_area,
              input_depth_by_input_area, output.data()));
          return;
      }
    }

    // Other block sizes are processed by the generic kernel.
    const int total_count = batch_size * input_depth_by_input_area;
    if (total_count == 0) {
      return;
    }
    auto config = GetGpuLaunchConfig(total_count, d);
    TF_CHECK_OK(GpuLaunchKernel(
        D2S_NCHW<T>, config.block_count, config.thread_per_block, 0, d.stream(),
        config.virtual_thread_count, input.data(), block_size, input_width,
        output_depth * input_height, output.data()));
  }
  void operator()(const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input,
                  int block_size, typename TTypes<T, 5>::Tensor output) {
    LOG(FATAL) << "5-D tensors should not be used with NCHW format";
  }
};
}  // end namespace functor

// Instantiate the GPU implementations for float.
template struct functor::DepthToSpaceOpFunctor<GPUDevice, float, FORMAT_NCHW>;
template struct functor::DepthToSpaceOpFunctor<GPUDevice, float, FORMAT_NHWC>;

// Instantiate the GPU implementations for Eigen::half.
template struct functor::DepthToSpaceOpFunctor<GPUDevice, Eigen::half,
                                               FORMAT_NCHW>;
template struct functor::DepthToSpaceOpFunctor<GPUDevice, Eigen::half,
                                               FORMAT_NHWC>;

// NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32.
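// Each int32 element packs four qint8 channel values, so the int32 kernels
// move four quantized channels at a time.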
template struct functor::DepthToSpaceOpFunctor<GPUDevice, int32, FORMAT_NCHW>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM