#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include <c10/macros/Macros.h>

#include <algorithm>
#include <limits>

namespace at {
namespace native {

using namespace at::cuda::detail;

// Kernel for fast unfold+copy on volumes
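// Each thread handles one (channel, t_out, h_out, w_out) output location and
// copies the corresponding ksize_t x ksize_h x ksize_w input patch into one
// column of data_col; data_col is laid out with
// channels * ksize_t * ksize_h * ksize_w rows and
// depth_col * height_col * width_col columns.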
template <typename T>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void vol2col_kernel(
    const int64_t n,
    const T* data_vol,
    const int depth,
    const int height,
    const int width,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    const int depth_col,
    const int height_col,
    const int width_col,
    T* data_col) {
  CUDA_KERNEL_LOOP_TYPE(index, n, int64_t) {
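    // Decompose the flat index into (channel_in, t_out, h_out, w_out);
    // width varies fastest, then height, then depth.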
    auto w_out = index % width_col;
    index /= width_col;
    auto h_out = index % height_col;
    index /= height_col;
    auto t_out = index % depth_col;
    auto channel_in = index / depth_col;
    auto channel_out = channel_in * ksize_t * ksize_h * ksize_w;
    auto t_in = t_out * stride_t - pad_t;
    auto h_in = h_out * stride_h - pad_h;
    auto w_in = w_out * stride_w - pad_w;
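    // Advance data_col to this thread's first output element and data_vol to
    // the top-left-front corner of the input patch (which may lie inside the
    // padding region).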
    data_col +=
        ((channel_out * depth_col + t_out) * height_col + h_out) * width_col +
        w_out;
    data_vol += ((channel_in * depth + t_in) * height + h_in) * width + w_in;
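    // Copy the patch element by element, writing zeros wherever the window
    // overlaps the zero-padding border.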
    for (int i = 0; i < ksize_t; ++i) {
      for (int j = 0; j < ksize_h; ++j) {
        for (int k = 0; k < ksize_w; ++k) {
          auto t = t_in + i * dilation_t;
          auto h = h_in + j * dilation_h;
          auto w = w_in + k * dilation_w;
          *data_col = (t >= 0 && h >= 0 && w >= 0 && t < depth && h < height &&
                       w < width)
              ? data_vol
                    [i * dilation_t * height * width + j * dilation_h * width +
                     k * dilation_w]
              : static_cast<T>(0);
          data_col += depth_col * height_col * width_col;
        }
      }
    }
  }
}

template <typename T>
void vol2col(
    cudaStream_t stream,
    const T* data_vol,
    const int channels,
    const int depth,
    const int height,
    const int width,
    const int depth_col,
    const int height_col,
    const int width_col,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    T* data_col) {
  // We are going to launch channels * depth_col * height_col * width_col
  // threads, each thread responsible for copying one
  // ksize_t * ksize_h * ksize_w patch into a column of data_col.
  // We cast an operand to int64_t so that the product will not overflow.
  const auto num_kernels = static_cast<int64_t>(channels) * depth_col * height_col * width_col;
  // Launch
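  // GET_BLOCKS and CUDA_NUM_THREADS come from ATen's detail/KernelUtils.h;
  // they size a 1-D grid of 1024-thread blocks, matching the
  // C10_LAUNCH_BOUNDS_1(1024) annotation on the kernel above.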
  vol2col_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
      num_kernels,
      data_vol,
      depth,
      height,
      width,
      ksize_t,
      ksize_h,
      ksize_w,
      pad_t,
      pad_h,
      pad_w,
      stride_t,
      stride_h,
      stride_w,
      dilation_t,
      dilation_h,
      dilation_w,
      depth_col,
      height_col,
      width_col,
      data_col);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
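
// Usage sketch (hypothetical caller, not part of this header): the output
// extents follow the usual convolution arithmetic, e.g.
//   int depth_col =
//       (depth + 2 * pad_t - dilation_t * (ksize_t - 1) - 1) / stride_t + 1;
// (likewise for height_col / width_col). data_col must hold
// channels * ksize_t * ksize_h * ksize_w * depth_col * height_col * width_col
// elements before calling
//   vol2col<float>(at::cuda::getCurrentCUDAStream(), input, channels,
//                  depth, height, width, depth_col, height_col, width_col,
//                  ksize_t, ksize_h, ksize_w, pad_t, pad_h, pad_w,
//                  stride_t, stride_h, stride_w,
//                  dilation_t, dilation_h, dilation_w, columns);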

template <typename T, typename accT>
__global__ void vol2im_kernel(
    const int64_t n,
    const T* data_col,
    const unsigned depth,
    const unsigned height,
    const unsigned width,
    const unsigned channels,
    const unsigned kernel_t,
    const unsigned kernel_h,
    const unsigned kernel_w,
    const unsigned pad_t,
    const unsigned pad_h,
    const unsigned pad_w,
    const unsigned stride_t,
    const unsigned stride_h,
    const unsigned stride_w,
    const unsigned dilation_t,
    const unsigned dilation_h,
    const unsigned dilation_w,
    const unsigned depth_col,
    const unsigned height_col,
    const unsigned width_col,
    T* data_vol) {
  CUDA_KERNEL_LOOP(index, n) {
    accT val = static_cast<accT>(0);
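    // Recover this input voxel's (c, t, h, w) coordinates, shifted into the
    // padded coordinate frame so they can be compared against column windows.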
    const auto w_im = index % width + pad_w;
    const auto h_im = (index / width) % height + pad_h;
    const auto t_im = (index / width / height) % depth + pad_t;
    const auto c_im = index / (width * height * depth);
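    // Effective kernel extent along each axis once dilation is applied.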
    auto kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
    auto kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
    auto kernel_extent_t = (kernel_t - 1) * dilation_t + 1;
    // compute the start and end of the output
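    // A column at w_col covers padded input positions
    // [w_col * stride_w, w_col * stride_w + kernel_extent_w); inverting that
    // range gives the first and one-past-last columns overlapping w_im
    // (and likewise for h and t).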
    const auto w_col_start =
        (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
    const auto w_col_end = std::min(w_im / stride_w + 1, width_col);
    const auto h_col_start =
        (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
    const auto h_col_end = std::min(h_im / stride_h + 1, height_col);
    const auto t_col_start =
        (t_im < kernel_extent_t) ? 0 : (t_im - kernel_extent_t) / stride_t + 1;
    const auto t_col_end = std::min(t_im / stride_t + 1, depth_col);
    // TODO: use LCM of stride and dilation to avoid unnecessary loops
    for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) {
      for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) {
        for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) {
          uint64_t t_k = (t_im - t_col * stride_t);
          uint64_t h_k = (h_im - h_col * stride_h);
          uint64_t w_k = (w_im - w_col * stride_w);
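          // The offset into the kernel window contributes only if it lands
          // exactly on a dilated tap.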
          if (t_k % dilation_t == 0 && h_k % dilation_h == 0 &&
              w_k % dilation_w == 0) {
            t_k /= dilation_t;
            h_k /= dilation_h;
            w_k /= dilation_w;
            const int64_t idx_k =
                ((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k;
            const int64_t data_col_index =
                ((idx_k * depth_col + t_col) *
                    height_col + h_col) *
                  width_col + w_col;
            val += data_col[data_col_index];
          }
        }
      }
    }
    data_vol[index] = static_cast<T>(val);
  }
}

template <typename T, typename accT>
void col2vol(
    cudaStream_t stream,
    const T* data_col,
    const int64_t channels,
    const int64_t depth,
    const int64_t height,
    const int64_t width,
    const int64_t output_depth,
    const int64_t output_height,
    const int64_t output_width,
    const int64_t patch_t,
    const int64_t patch_h,
    const int64_t patch_w,
    const int64_t pad_t,
    const int64_t pad_h,
    const int64_t pad_w,
    const int64_t stride_t,
    const int64_t stride_h,
    const int64_t stride_w,
    const int64_t dilation_t,
    const int64_t dilation_h,
    const int64_t dilation_w,
    T* data_vol) {
  const auto num_kernels = channels * depth * height * width;

  auto check_fits_in_unsigned =
    [](int64_t val, const char * name) {
      constexpr auto umax = std::numeric_limits<unsigned>::max();
      TORCH_CHECK(val >= 0 && val <= umax,
                  name, " must fit in a 32-bit unsigned value");
    };
  check_fits_in_unsigned(num_kernels, "input size");
  check_fits_in_unsigned(
      channels * patch_t * patch_h * patch_w, "channels x kernel size");

  // To avoid atomic operations, we launch one thread per bottom (input)
  // element, and each thread gathers and sums every top (column) entry
  // that maps onto it.
  vol2im_kernel<T, accT>
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels,
          data_col,
          depth,
          height,
          width,
          channels,
          patch_t,
          patch_h,
          patch_w,
          pad_t,
          pad_h,
          pad_w,
          stride_t,
          stride_h,
          stride_w,
          dilation_t,
          dilation_h,
          dilation_w,
          output_depth,
          output_height,
          output_width,
          data_vol);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
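
// Usage sketch (hypothetical): col2vol is the adjoint of vol2col, used e.g.
// when scattering columns back into a volume for 3-D convolution gradients.
// accT selects the accumulation type, so a reduced-precision T can still be
// summed in float:
//   col2vol<at::Half, float>(at::cuda::getCurrentCUDAStream(), columns,
//                            channels, depth, height, width,
//                            output_depth, output_height, output_width,
//                            patch_t, patch_h, patch_w, pad_t, pad_h, pad_w,
//                            stride_t, stride_h, stride_w,
//                            dilation_t, dilation_h, dilation_w, grad_input);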

} // namespace native
} // namespace at