// aten/src/ATen/native/cuda/GridSampler.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/GridSampler.h>
#include <ATen/native/GridSamplerUtils.h>
#include <ATen/native/cuda/GridSampler.cuh>
#include <ATen/native/cuda/UpSample.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>
#include <cmath>

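// CUDA implementations of the grid_sampler forward and backward passes:
// 2D and 3D kernels supporting bilinear (trilinear in 3D), nearest, and
// (2D only) bicubic interpolation, with the padding modes defined by
// GridSamplerPadding.
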
namespace at::native {

using namespace at::cuda::detail;

using at::native::detail::GridSamplerInterpolation;
using at::native::detail::GridSamplerPadding;

namespace {
  template <typename scalar_t, typename index_t>
  C10_LAUNCH_BOUNDS_1(256)
  __global__ void grid_sampler_2d_kernel(
      const index_t nthreads,
      TensorInfo<const scalar_t, index_t> input,
      TensorInfo<const scalar_t, index_t> grid,
      TensorInfo<scalar_t, index_t> output,
      const GridSamplerInterpolation interpolation_mode,
      const GridSamplerPadding padding_mode,
      bool align_corners) {

    using opmath_t = at::opmath_type<scalar_t>;
    index_t C = input.sizes[1];
    index_t inp_H = input.sizes[2];
    index_t inp_W = input.sizes[3];
    index_t out_H = grid.sizes[1];
    index_t out_W = grid.sizes[2];
    index_t inp_sN = input.strides[0];
    index_t inp_sC = input.strides[1];
    index_t inp_sH = input.strides[2];
    index_t inp_sW = input.strides[3];
    index_t grid_sN = grid.strides[0];
    index_t grid_sH = grid.strides[1];
    index_t grid_sW = grid.strides[2];
    index_t grid_sCoor = grid.strides[3];
    index_t out_sN = output.strides[0];
    index_t out_sC = output.strides[1];
    index_t out_sH = output.strides[2];
    index_t out_sW = output.strides[3];

    CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
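      // decompose the flat thread index into output coordinates (n, h, w);
      // each thread produces one output location across all C channels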
      const index_t w = index % out_W;
      const index_t h = (index / out_W) % out_H;
      const index_t n = index / (out_H * out_W);
      const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;

      // get the corresponding input x, y co-ordinates from grid
      opmath_t x = grid.data[grid_offset];
      opmath_t y = grid.data[grid_offset + grid_sCoor];

      opmath_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
      opmath_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);

      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
        // get NE, NW, SE, SW pixel values from (x, y)
        index_t ix_nw = static_cast<index_t>(::floor(ix));
        index_t iy_nw = static_cast<index_t>(::floor(iy));
        index_t ix_ne = ix_nw + 1;
        index_t iy_ne = iy_nw;
        index_t ix_sw = ix_nw;
        index_t iy_sw = iy_nw + 1;
        index_t ix_se = ix_nw + 1;
        index_t iy_se = iy_nw + 1;

        // get surfaces to each neighbor:
        opmath_t nw = (ix_se - ix)    * (iy_se - iy);
        opmath_t ne = (ix    - ix_sw) * (iy_sw - iy);
        opmath_t sw = (ix_ne - ix)    * (iy    - iy_ne);
        opmath_t se = (ix    - ix_nw) * (iy    - iy_nw);
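        // each weight is the area of the rectangle spanned by (ix, iy) and the
        // diagonally opposite corner; the four weights sum to 1, giving
        // standard bilinear interpolation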

        // calculate bilinear weighted pixel value and set output pixel
        auto inp_ptr_NC = input.data + n * inp_sN;
        auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
        for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
          opmath_t out_acc = 0;
          if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
          }
          if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
          }
          if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
          }
          if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
          }
          *out_ptr_NCHW = out_acc;
        }
      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
        index_t ix_nearest = static_cast<index_t>(std::nearbyint(ix));
        index_t iy_nearest = static_cast<index_t>(std::nearbyint(iy));
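        // nearbyint rounds to the nearest integer, with halfway cases going to
        // the nearest even value under the default FP rounding mode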

        // assign nearest neighbour pixel value to output pixel
        auto inp_ptr_NC = input.data + n * inp_sN;
        auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
        for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
          if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
            *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
          } else {
            *out_ptr_NCHW = static_cast<scalar_t>(0);
          }
        }
      } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {

        ix = grid_sampler_unnormalize(x, inp_W, align_corners);
        iy = grid_sampler_unnormalize(y, inp_H, align_corners);

        opmath_t ix_nw = std::floor(ix);
        opmath_t iy_nw = std::floor(iy);

        const opmath_t tx = ix - ix_nw;
        const opmath_t ty = iy - iy_nw;
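        // bicubic interpolation is separable: interpolate 4 taps along x for
        // each of the 4 rows around (ix, iy), then interpolate the 4 row
        // results along y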

        auto inp_ptr_NC = input.data + n * inp_sN;
        auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
        for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
          opmath_t coefficients[4];

          #pragma unroll 4
          for (index_t i = 0; i < 4; ++i) {
            coefficients[i] = cubic_interp1d(
              get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
              get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
              get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
              get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
              tx);
          }

          *out_ptr_NCHW = cubic_interp1d(
            coefficients[0],
            coefficients[1],
            coefficients[2],
            coefficients[3],
            ty);
        }
      }
    }
  }

  template <typename scalar_t, typename index_t>
  C10_LAUNCH_BOUNDS_1(512)
  __global__ void grid_sampler_3d_kernel(
      const index_t nthreads,
      TensorInfo<const scalar_t, index_t> input,
      TensorInfo<const scalar_t, index_t> grid,
      TensorInfo<scalar_t, index_t> output,
      const GridSamplerInterpolation interpolation_mode,
      const GridSamplerPadding padding_mode,
      bool align_corners) {

    using opmath_t = at::opmath_type<scalar_t>;
    index_t C = input.sizes[1];
    index_t inp_D = input.sizes[2];
    index_t inp_H = input.sizes[3];
    index_t inp_W = input.sizes[4];
    index_t out_D = grid.sizes[1];
    index_t out_H = grid.sizes[2];
    index_t out_W = grid.sizes[3];
    index_t inp_sN = input.strides[0];
    index_t inp_sC = input.strides[1];
    index_t inp_sD = input.strides[2];
    index_t inp_sH = input.strides[3];
    index_t inp_sW = input.strides[4];
    index_t grid_sN = grid.strides[0];
    index_t grid_sD = grid.strides[1];
    index_t grid_sH = grid.strides[2];
    index_t grid_sW = grid.strides[3];
    index_t grid_sCoor = grid.strides[4];
    index_t out_sN = output.strides[0];
    index_t out_sC = output.strides[1];
    index_t out_sD = output.strides[2];
    index_t out_sH = output.strides[3];
    index_t out_sW = output.strides[4];

    CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
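      // same flat-index decomposition as the 2D kernel, with an extra depth
      // coordinate d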
      const index_t w = index % out_W;
      const index_t h = (index / out_W) % out_H;
      const index_t d = (index / (out_H * out_W)) % out_D;
      const index_t n = index / (out_D * out_H * out_W);
      const index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;

      // get the corresponding input x, y, z co-ordinates from grid
      opmath_t x = grid.data[grid_offset];
      opmath_t y = grid.data[grid_offset + grid_sCoor];
      opmath_t z = grid.data[grid_offset + 2 * grid_sCoor];

      opmath_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
      opmath_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);
      opmath_t iz = grid_sampler_compute_source_index(z, inp_D, padding_mode, align_corners);

      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
        // get corner pixel values from (x, y, z)
        // for 4d we use the north-east-south-west (nw/ne/sw/se) naming;
        // for 5d we add a top-bottom (t/b) prefix for the eight corners
        index_t ix_tnw = static_cast<index_t>(::floor(ix));
        index_t iy_tnw = static_cast<index_t>(::floor(iy));
        index_t iz_tnw = static_cast<index_t>(::floor(iz));

        index_t ix_tne = ix_tnw + 1;
        index_t iy_tne = iy_tnw;
        index_t iz_tne = iz_tnw;

        index_t ix_tsw = ix_tnw;
        index_t iy_tsw = iy_tnw + 1;
        index_t iz_tsw = iz_tnw;

        index_t ix_tse = ix_tnw + 1;
        index_t iy_tse = iy_tnw + 1;
        index_t iz_tse = iz_tnw;

        index_t ix_bnw = ix_tnw;
        index_t iy_bnw = iy_tnw;
        index_t iz_bnw = iz_tnw + 1;

        index_t ix_bne = ix_tnw + 1;
        index_t iy_bne = iy_tnw;
        index_t iz_bne = iz_tnw + 1;

        index_t ix_bsw = ix_tnw;
        index_t iy_bsw = iy_tnw + 1;
        index_t iz_bsw = iz_tnw + 1;

        index_t ix_bse = ix_tnw + 1;
        index_t iy_bse = iy_tnw + 1;
        index_t iz_bse = iz_tnw + 1;

        // get surfaces to each neighbor:
        opmath_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
        opmath_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
        opmath_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
        opmath_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
        opmath_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
        opmath_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
        opmath_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
        opmath_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);
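        // each weight is the volume of the box spanned by (ix, iy, iz) and the
        // diagonally opposite corner; the eight weights sum to 1 (trilinear
        // interpolation)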

        auto inp_ptr_NC = input.data + n * inp_sN;
        auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
        for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
          //   (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
          // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
          // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
          // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
          opmath_t out_acc = 0;
          if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
          }
          if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
          }
          if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
          }
          if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
          }
          if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
          }
          if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
          }
          if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
          }
          if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
            out_acc += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
          }
          *out_ptr_NCDHW = out_acc;
        }
      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
        index_t ix_nearest = static_cast<index_t>(std::nearbyint(ix));
        index_t iy_nearest = static_cast<index_t>(std::nearbyint(iy));
        index_t iz_nearest = static_cast<index_t>(std::nearbyint(iz));

        // assign nearest neighbour pixel value to output pixel
        auto inp_ptr_NC = input.data + n * inp_sN;
        auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
        for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
          if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
            *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
          } else {
            *out_ptr_NCDHW = static_cast<scalar_t>(0);
          }
        }
      }
    }
  }

// Note [Passing pointer and offset to fastAtomicAdd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// For its internal bounds checking, fastAtomicAdd needs to know where the destination address
// lies relative to the entire tensor, so we pass the base grad_input.data and full offset information,
// including batch * channel offset (NC_offset).
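//
// For example, a bilinear tap in the 2D backward kernel below is accumulated as
//   safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW,
//               inp_H, inp_W, nw * gOut, NC_offset, grad_input_memory_span);
// i.e. the base pointer plus offsets, rather than a pre-offset pointer.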

  template <typename scalar_t, typename index_t>
  C10_LAUNCH_BOUNDS_1(256)
  __global__ void grid_sampler_2d_backward_kernel(
      const index_t nthreads,
      TensorInfo<const scalar_t, index_t> grad_output,
      TensorInfo<const scalar_t, index_t> input,
      TensorInfo<const scalar_t, index_t> grid,
      TensorInfo<scalar_t, index_t> grad_input,  // initialized to zeros (or unused if input_requires_grad is false)
      TensorInfo<scalar_t, index_t> grad_grid,   // initialized to empty
      const GridSamplerInterpolation interpolation_mode,
      const GridSamplerPadding padding_mode,
      bool align_corners,
      const index_t grad_input_memory_span,
      const bool input_requires_grad) {

    index_t C = input.sizes[1];
    index_t inp_H = input.sizes[2];
    index_t inp_W = input.sizes[3];
    index_t out_H = grid.sizes[1];
    index_t out_W = grid.sizes[2];
    index_t inp_sN = input.strides[0];
    index_t inp_sC = input.strides[1];
    index_t inp_sH = input.strides[2];
    index_t inp_sW = input.strides[3];
    index_t grid_sN = grid.strides[0];
    index_t grid_sH = grid.strides[1];
    index_t grid_sW = grid.strides[2];
    index_t grid_sCoor = grid.strides[3];
    index_t gOut_sN = grad_output.strides[0];
    index_t gOut_sC = grad_output.strides[1];
    index_t gOut_sH = grad_output.strides[2];
    index_t gOut_sW = grad_output.strides[3];
    // gInp_* (and NC_offset below) are only used if input_requires_grad is
    // true; they are zero-initialized so the unconditional arithmetic on them
    // below stays well-defined.
    index_t gInp_sN = 0;
    index_t gInp_sC = 0;
    index_t gInp_sH = 0;
    index_t gInp_sW = 0;
    if (input_requires_grad) {
      gInp_sN = grad_input.strides[0];
      gInp_sC = grad_input.strides[1];
      gInp_sH = grad_input.strides[2];
      gInp_sW = grad_input.strides[3];
    }
    index_t gGrid_sW = grad_grid.strides[2];

    CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
      const index_t w = index % out_W;
      const index_t h = (index / out_W) % out_H;
      const index_t n = index / (out_H * out_W);
      const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;

      // get the corresponding input x, y co-ordinates from grid
      scalar_t x = grid.data[grid_offset];
      scalar_t y = grid.data[grid_offset + grid_sCoor];

      // multipliers for gradients on ix and iy
      scalar_t gix_mult, giy_mult;
      scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
      scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);

      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
        // get NE, NW, SE, SW pixel values from (x, y)
        index_t ix_nw = static_cast<index_t>(std::floor(ix));
        index_t iy_nw = static_cast<index_t>(std::floor(iy));
        index_t ix_ne = ix_nw + 1;
        index_t iy_ne = iy_nw;
        index_t ix_sw = ix_nw;
        index_t iy_sw = iy_nw + 1;
        index_t ix_se = ix_nw + 1;
        index_t iy_se = iy_nw + 1;

        // get surfaces to each neighbor:
        scalar_t nw = (ix_se - ix)    * (iy_se - iy);
        scalar_t ne = (ix    - ix_sw) * (iy_sw - iy);
        scalar_t sw = (ix_ne - ix)    * (iy    - iy_ne);
        scalar_t se = (ix    - ix_nw) * (iy    - iy_nw);

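        // gix and giy accumulate d(output)/d(ix) and d(output)/d(iy): each
        // in-bounds corner value is weighted by the derivative of its bilinear
        // weight with respect to ix or iy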
        scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
        const scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
        index_t NC_offset = n * gInp_sN;
        const scalar_t *inp_ptr_NC = input.data + n * inp_sN;
        for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
          const scalar_t gOut = *gOut_ptr_NCHW;

          if (input_requires_grad) {
            // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
            safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut, NC_offset, grad_input_memory_span);
            safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut, NC_offset, grad_input_memory_span);
            safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut, NC_offset, grad_input_memory_span);
            safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut, NC_offset, grad_input_memory_span);
          }

          // calculate grad_grid
          if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
            scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
            gix -= nw_val * (iy_se - iy) * gOut;
            giy -= nw_val * (ix_se - ix) * gOut;
          }
          if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
            scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
            gix += ne_val * (iy_sw - iy) * gOut;
            giy -= ne_val * (ix - ix_sw) * gOut;
          }
          if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
            scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
            gix -= sw_val * (iy - iy_ne) * gOut;
            giy += sw_val * (ix_ne - ix) * gOut;
          }
          if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
            scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
            gix += se_val * (iy - iy_nw) * gOut;
            giy += se_val * (ix - ix_nw) * gOut;
          }
        }

        // assuming grad_grid is contiguous, we can
        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
        //   2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
        scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
        gGrid_ptr_NHW[0] = gix_mult * gix;
        gGrid_ptr_NHW[1] = giy_mult * giy;
      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
        if (input_requires_grad) {
          index_t ix_nearest = static_cast<index_t>(std::nearbyint(ix));
          index_t iy_nearest = static_cast<index_t>(std::nearbyint(iy));

          // assign nearest neighbour pixel value to output pixel
          const scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
          index_t NC_offset = n * gInp_sN;
          for (index_t c = 0; c < C; ++c, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
            // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
            safe_add_2d(grad_input.data, iy_nearest, ix_nearest, gInp_sH, gInp_sW, inp_H, inp_W, *gOut_ptr_NCHW, NC_offset, grad_input_memory_span);
          }
        }

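        // nearest-neighbour sampling is piecewise constant in the grid
        // coordinates, so the gradient with respect to the grid is zero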
        // assuming grad_grid is contiguous, we can
        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
        //   2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
        scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
        gGrid_ptr_NHW[0] = static_cast<scalar_t>(0);
        gGrid_ptr_NHW[1] = static_cast<scalar_t>(0);
      } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {

        ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult);
        iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult);

        scalar_t ix_nw = std::floor(ix);
        scalar_t iy_nw = std::floor(iy);

        const scalar_t tx = ix - ix_nw;
        const scalar_t ty = iy - iy_nw;

        scalar_t x_coeffs[4];
        scalar_t y_coeffs[4];
        scalar_t x_coeffs_grad[4];
        scalar_t y_coeffs_grad[4];

        get_cubic_upsampling_coefficients<scalar_t>(x_coeffs, tx);
        get_cubic_upsampling_coefficients<scalar_t>(y_coeffs, ty);
        get_cubic_coefficients_grad<scalar_t>(x_coeffs_grad, tx);
        get_cubic_coefficients_grad<scalar_t>(y_coeffs_grad, ty);
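        // x_coeffs/y_coeffs are the 4-tap cubic convolution weights;
        // x_coeffs_grad/y_coeffs_grad are their derivatives with respect to
        // tx/ty, used below to chain-rule the gradient back to the grid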

        scalar_t gix = static_cast<scalar_t>(0);
        scalar_t giy = static_cast<scalar_t>(0);

        const scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
        index_t NC_offset = n * gInp_sN;
        const scalar_t *inp_ptr_NC = input.data + n * inp_sN;

        for (index_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) {
          const scalar_t gOut = *gOut_ptr_NCHW;

          #pragma unroll 4
          for (index_t i = 0; i < 4; ++i) {
            #pragma unroll 4
            for (index_t j = 0; j < 4; ++j) {

              if (input_requires_grad) {
                // set input gradient. See Note [Passing pointer and offset to fastAtomicAdd].
                add_value_bounded<scalar_t>(grad_input.data, ix_nw - 1 + i, iy_nw - 1 + j, inp_W, inp_H, gInp_sW, gInp_sH,
                  gOut * x_coeffs[i] * y_coeffs[j],
                  padding_mode,
                  align_corners,
                  NC_offset,
                  grad_input_memory_span);
              }

              // set grid gradient
              scalar_t val = get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
                inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners);

              gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut;
              giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut;
            }
          }
        }

        scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
        gGrid_ptr_NHW[0] = gix_mult * gix;
        gGrid_ptr_NHW[1] = giy_mult * giy;
      }
    }
  }

  template <typename scalar_t, typename index_t>
  C10_LAUNCH_BOUNDS_1(256)
  __global__ void grid_sampler_3d_backward_kernel(
      const index_t nthreads,
      TensorInfo<const scalar_t, index_t> grad_output,
      TensorInfo<const scalar_t, index_t> input,
      TensorInfo<const scalar_t, index_t> grid,
      TensorInfo<scalar_t, index_t> grad_input,  // initialized to zeros (or unused if input_requires_grad is false)
      TensorInfo<scalar_t, index_t> grad_grid,   // initialized to empty
      const GridSamplerInterpolation interpolation_mode,
      const GridSamplerPadding padding_mode,
      bool align_corners,
      const index_t grad_input_memory_span,
      const bool input_requires_grad) {

    index_t C = input.sizes[1];
    index_t inp_D = input.sizes[2];
    index_t inp_H = input.sizes[3];
    index_t inp_W = input.sizes[4];
    index_t out_D = grid.sizes[1];
    index_t out_H = grid.sizes[2];
    index_t out_W = grid.sizes[3];
    index_t inp_sN = input.strides[0];
    index_t inp_sC = input.strides[1];
    index_t inp_sD = input.strides[2];
    index_t inp_sH = input.strides[3];
    index_t inp_sW = input.strides[4];
    index_t grid_sN = grid.strides[0];
    index_t grid_sD = grid.strides[1];
    index_t grid_sH = grid.strides[2];
    index_t grid_sW = grid.strides[3];
    index_t grid_sCoor = grid.strides[4];
    index_t gOut_sN = grad_output.strides[0];
    index_t gOut_sC = grad_output.strides[1];
    index_t gOut_sD = grad_output.strides[2];
    index_t gOut_sH = grad_output.strides[3];
    index_t gOut_sW = grad_output.strides[4];
    // gInp_* (and NC_offset below) are not really needed if input_requires_grad is false.
    int64_t gInp_sN = 0;
    int64_t gInp_sC = 0;
    int64_t gInp_sD = 0;
    int64_t gInp_sH = 0;
    int64_t gInp_sW = 0;
    if (input_requires_grad) {
      gInp_sN = grad_input.strides[0];
      gInp_sC = grad_input.strides[1];
      gInp_sD = grad_input.strides[2];
      gInp_sH = grad_input.strides[3];
      gInp_sW = grad_input.strides[4];
    }
    index_t gGrid_sW = grad_grid.strides[3];

    CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
      const index_t w = index % out_W;
      const index_t h = (index / out_W) % out_H;
      const index_t d = (index / (out_H * out_W)) % out_D;
      const index_t n = index / (out_D * out_H * out_W);
      const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;

      // get the corresponding input x, y, z co-ordinates from grid
      scalar_t ix = grid.data[grid_offset];
      scalar_t iy = grid.data[grid_offset + grid_sCoor];
      scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];

      // multipliers for gradients on ix, iy, and iz
      scalar_t gix_mult, giy_mult, giz_mult;
      ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
      iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
      iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);

      if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
        // get corner pixel values from (x, y, z)
        // for 4d we use the north-east-south-west (nw/ne/sw/se) naming;
        // for 5d we add a top-bottom (t/b) prefix for the eight corners
        index_t ix_tnw = static_cast<index_t>(std::floor(ix));
        index_t iy_tnw = static_cast<index_t>(std::floor(iy));
        index_t iz_tnw = static_cast<index_t>(std::floor(iz));

        index_t ix_tne = ix_tnw + 1;
        index_t iy_tne = iy_tnw;
        index_t iz_tne = iz_tnw;

        index_t ix_tsw = ix_tnw;
        index_t iy_tsw = iy_tnw + 1;
        index_t iz_tsw = iz_tnw;

        index_t ix_tse = ix_tnw + 1;
        index_t iy_tse = iy_tnw + 1;
        index_t iz_tse = iz_tnw;

        index_t ix_bnw = ix_tnw;
        index_t iy_bnw = iy_tnw;
        index_t iz_bnw = iz_tnw + 1;

        index_t ix_bne = ix_tnw + 1;
        index_t iy_bne = iy_tnw;
        index_t iz_bne = iz_tnw + 1;

        index_t ix_bsw = ix_tnw;
        index_t iy_bsw = iy_tnw + 1;
        index_t iz_bsw = iz_tnw + 1;

        index_t ix_bse = ix_tnw + 1;
        index_t iy_bse = iy_tnw + 1;
        index_t iz_bse = iz_tnw + 1;

        // get surfaces to each neighbor:
        scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
        scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
        scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
        scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
        scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
        scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
        scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
        scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);

        scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
        const scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
        // zero when input_requires_grad is false, so the unconditional
        // increment in the loop below stays well-defined
        index_t NC_offset = input_requires_grad ? n * gInp_sN : 0;
        const scalar_t *inp_ptr_NC = input.data + n * inp_sN;
        // accumulate grad_input and grad_grid contributions for each channel
        for (index_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC, inp_ptr_NC += inp_sC) {
          scalar_t gOut = *gOut_ptr_NCDHW;

          // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
          if (input_requires_grad) {
            safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut,
                        NC_offset, grad_input_memory_span);
            safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut,
                        NC_offset, grad_input_memory_span);
            safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut,
                        NC_offset, grad_input_memory_span);
            safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut,
                        NC_offset, grad_input_memory_span);
            safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut,
                        NC_offset, grad_input_memory_span);
            safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut,
                        NC_offset, grad_input_memory_span);
            safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut,
                        NC_offset, grad_input_memory_span);
            safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut,
                        NC_offset, grad_input_memory_span);
          }
          // calculate grad_grid
          if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
            scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
            gix -= tnw_val * (iy_bse - iy)    * (iz_bse - iz)    * gOut;
            giy -= tnw_val * (ix_bse - ix)    * (iz_bse - iz)    * gOut;
            giz -= tnw_val * (ix_bse - ix)    * (iy_bse - iy)    * gOut;
          }
          if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
            scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
            gix += tne_val * (iy_bsw - iy)    * (iz_bsw - iz)    * gOut;
            giy -= tne_val * (ix    - ix_bsw) * (iz_bsw - iz)    * gOut;
            giz -= tne_val * (ix    - ix_bsw) * (iy_bsw - iy)    * gOut;
          }
          if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
            scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
            gix -= tsw_val * (iy - iy_bne)    * (iz_bne - iz)    * gOut;
            giy += tsw_val * (ix_bne - ix)    * (iz_bne - iz)    * gOut;
            giz -= tsw_val * (ix_bne - ix)    * (iy    - iy_bne) * gOut;
          }
          if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
            scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
            gix += tse_val * (iy - iy_bnw)    * (iz_bnw - iz)    * gOut;
            giy += tse_val * (ix    - ix_bnw) * (iz_bnw - iz)    * gOut;
            giz -= tse_val * (ix    - ix_bnw) * (iy    - iy_bnw) * gOut;
          }
          if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
            scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
            gix -= bnw_val * (iy_tse - iy)    * (iz - iz_tse)    * gOut;
            giy -= bnw_val * (ix_tse - ix)    * (iz - iz_tse)    * gOut;
            giz += bnw_val * (ix_tse - ix)    * (iy_tse - iy)    * gOut;
          }
          if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
            scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
            gix += bne_val * (iy_tsw - iy)    * (iz - iz_tsw)    * gOut;
            giy -= bne_val * (ix    - ix_tsw) * (iz - iz_tsw)    * gOut;
            giz += bne_val * (ix    - ix_tsw) * (iy_tsw - iy)    * gOut;
          }
          if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
            scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
            gix -= bsw_val * (iy - iy_tne)    * (iz - iz_tne)    * gOut;
            giy += bsw_val * (ix_tne - ix)    * (iz - iz_tne)    * gOut;
            giz += bsw_val * (ix_tne - ix)    * (iy    - iy_tne) * gOut;
          }
          if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
            scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
            gix += bse_val * (iy - iy_tnw)    * (iz - iz_tnw)    * gOut;
            giy += bse_val * (ix    - ix_tnw) * (iz - iz_tnw)    * gOut;
            giz += bse_val * (ix    - ix_tnw) * (iy    - iy_tnw) * gOut;
          }
        }

        // assuming grad_grid is contiguous, we can
        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
        //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
        scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
        gGrid_ptr_NDHW[0] = gix_mult * gix;
        gGrid_ptr_NDHW[1] = giy_mult * giy;
        gGrid_ptr_NDHW[2] = giz_mult * giz;
      } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
        if (input_requires_grad) {
          auto ix_nearest = static_cast<index_t>(std::nearbyint(ix));
          auto iy_nearest = static_cast<index_t>(std::nearbyint(iy));
          auto iz_nearest = static_cast<index_t>(std::nearbyint(iz));

          // assign nearest neighbour pixel value to output pixel
          const scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
          index_t NC_offset = n * gInp_sN;
          for (index_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, NC_offset += gInp_sC) {
            // calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
            safe_add_3d(grad_input.data, iz_nearest, iy_nearest, ix_nearest,
                        gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW,
                        NC_offset, grad_input_memory_span);
          }
        }
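        // as in the 2D kernel, nearest sampling is piecewise constant in the
        // grid coordinates, so the grid gradient is zero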
        // assuming grad_grid is contiguous, we can
        //   1. use index with gGrid_sW to directly compute gGrid_ptr_NDHW
        //   2. directly assign to gGrid_ptr_NDHW[0], gGrid_ptr_NDHW[1], gGrid_ptr_NDHW[2]
        scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
        gGrid_ptr_NDHW[0] = static_cast<scalar_t>(0);
        gGrid_ptr_NDHW[1] = static_cast<scalar_t>(0);
        gGrid_ptr_NDHW[2] = static_cast<scalar_t>(0);
      }
    }
  }
}  // namespace

void launch_grid_sampler_2d_forward_kernel(
    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
    int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  auto N = input.size(0);
  auto H = grid.size(1);
  auto W = grid.size(2);
  int64_t count = N * H * W;
  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      input.scalar_type(), "grid_sampler_2d_cuda", [&] {
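      // prefer 32-bit index arithmetic when every tensor's element count fits
      // in an int; it is cheaper on the GPU than 64-bit indexing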
      if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
          canUse32BitIndexMath(output)) {
        grid_sampler_2d_kernel<scalar_t>
          <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            static_cast<int>(count),
            getTensorInfo<const scalar_t, int>(input),
            getTensorInfo<const scalar_t, int>(grid),
            getTensorInfo<scalar_t, int>(output),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        grid_sampler_2d_kernel<scalar_t>
          <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            getTensorInfo<const scalar_t, int64_t>(input),
            getTensorInfo<const scalar_t, int64_t>(grid),
            getTensorInfo<scalar_t, int64_t>(output),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  }
}

void launch_grid_sampler_3d_forward_kernel(
    const TensorBase &output, const TensorBase &input, const TensorBase &grid,
    int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(input, grid, interpolation_mode);

  auto N = input.size(0);
  auto D = grid.size(1);
  auto H = grid.size(2);
  auto W = grid.size(3);
  int64_t count = N * D * H * W;
  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      input.scalar_type(), "grid_sampler_3d_cuda", [&] {
      if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
          canUse32BitIndexMath(output)) {
        grid_sampler_3d_kernel<scalar_t>
          <<<GET_BLOCKS(count, 512), 512, 0, at::cuda::getCurrentCUDAStream()>>>(
            static_cast<int>(count),
            getTensorInfo<const scalar_t, int>(input),
            getTensorInfo<const scalar_t, int>(grid),
            getTensorInfo<scalar_t, int>(output),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        grid_sampler_3d_kernel<scalar_t>
          <<<GET_BLOCKS(count, 512), 512, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            getTensorInfo<const scalar_t, int64_t>(input),
            getTensorInfo<const scalar_t, int64_t>(grid),
            getTensorInfo<scalar_t, int64_t>(output),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  }
}

void launch_grid_sampler_2d_backward_kernel(
    const TensorBase &grad_input, const TensorBase &grad_grid,
    const TensorBase &grad_output, const TensorBase &input,
    const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
    bool align_corners, std::array<bool,2> output_mask) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("grid_sampler_2d_backward_cuda");
  auto N = input.size(0);
  auto H = grid.size(1);
  auto W = grid.size(2);

  // If `input` gradient is not required, we skip computing it -- not needing to create
  // the tensor to hold the gradient can markedly increase performance. (`grid` gradient
  // is always computed.)
  auto input_requires_grad = output_mask[0];

  int64_t count = N * H * W;
  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
      if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
          canUse32BitIndexMath(grad_output)) {
        grid_sampler_2d_backward_kernel<scalar_t>
          <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            static_cast<int>(count),
            getTensorInfo<const scalar_t, int>(grad_output),
            getTensorInfo<const scalar_t, int>(input),
            getTensorInfo<const scalar_t, int>(grid),
            input_requires_grad ? getTensorInfo<scalar_t, int>(grad_input) : TensorInfo<scalar_t, int>(),
            getTensorInfo<scalar_t, int>(grad_grid),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners,
            /*grad_input_memory_span =*/input_requires_grad ? static_cast<int>(grad_input.numel()) : 0,
            input_requires_grad);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        grid_sampler_2d_backward_kernel<scalar_t>
          <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            getTensorInfo<const scalar_t, int64_t>(grad_output),
            getTensorInfo<const scalar_t, int64_t>(input),
            getTensorInfo<const scalar_t, int64_t>(grid),
            input_requires_grad ? getTensorInfo<scalar_t, int64_t>(grad_input) : TensorInfo<scalar_t, int64_t>(),
            getTensorInfo<scalar_t, int64_t>(grad_grid),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners,
            /*grad_input_memory_span =*/input_requires_grad ? grad_input.numel() : 0,
            input_requires_grad);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  }
}

void launch_grid_sampler_3d_backward_kernel(
    const TensorBase &grad_input, const TensorBase &grad_grid,
    const TensorBase& grad_output, const TensorBase& input,
    const TensorBase& grid, int64_t interpolation_mode, int64_t padding_mode,
    bool align_corners, std::array<bool,2> output_mask) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(input, grid, interpolation_mode);

  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("grid_sampler_3d_backward_cuda");
  auto N = input.size(0);
  auto D = grid.size(1);
  auto H = grid.size(2);
  auto W = grid.size(3);
  int64_t count = N * D * H * W;
  auto input_requires_grad = output_mask[0];
  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      input.scalar_type(), "grid_sampler_3d_backward_cuda", [&] {
      if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
          canUse32BitIndexMath(grad_output)) {
        grid_sampler_3d_backward_kernel<scalar_t>
          <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            static_cast<int>(count),
            getTensorInfo<const scalar_t, int>(grad_output),
            getTensorInfo<const scalar_t, int>(input),
            getTensorInfo<const scalar_t, int>(grid),
            input_requires_grad ? getTensorInfo<scalar_t, int>(grad_input) : TensorInfo<scalar_t, int>(),
            getTensorInfo<scalar_t, int>(grad_grid),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners,
            /*grad_input_memory_span =*/input_requires_grad ? static_cast<int>(grad_input.numel()) : 0,
            input_requires_grad);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        grid_sampler_3d_backward_kernel<scalar_t>
          <<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            count,
            getTensorInfo<const scalar_t, int64_t>(grad_output),
            getTensorInfo<const scalar_t, int64_t>(input),
            getTensorInfo<const scalar_t, int64_t>(grid),
            input_requires_grad ? getTensorInfo<scalar_t, int64_t>(grad_input) : TensorInfo<scalar_t, int64_t>(),
            getTensorInfo<scalar_t, int64_t>(grad_grid),
            static_cast<GridSamplerInterpolation>(interpolation_mode),
            static_cast<GridSamplerPadding>(padding_mode),
            align_corners,
            /*grad_input_memory_span =*/input_requires_grad ? grad_input.numel() : 0,
            input_requires_grad);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  }
}

}  // namespace at::native