// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/GridSampler.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/GridSampler.h>
#include <ATen/native/GridSamplerUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/cpu/GridSamplerKernel.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_native.h>
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_native.h>
#include <ATen/ops/cudnn_grid_sampler.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/grid_sampler_2d.h>
#include <ATen/ops/grid_sampler_2d_backward_native.h>
#include <ATen/ops/grid_sampler_2d_native.h>
#include <ATen/ops/grid_sampler_3d.h>
#include <ATen/ops/grid_sampler_3d_backward_native.h>
#include <ATen/ops/grid_sampler_3d_native.h>
#include <ATen/ops/grid_sampler_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

using at::native::detail::GridSamplerInterpolation;
using at::native::detail::GridSamplerPadding;

namespace {

  template<typename scalar_t>
  Tensor grid_sampler_3d_cpu_impl(const Tensor& input, const Tensor& grid,
                                  GridSamplerInterpolation interpolation_mode,
                                  GridSamplerPadding padding_mode,
                                  bool align_corners) {
    // See NOTE [ grid_sampler Native Functions ].
    // Add checks here in case this is called instead of grid_sampler.
    check_grid_sampler_common(input, grid);
    check_grid_sampler_3d(
      input, grid, static_cast<int64_t>(interpolation_mode));

    int64_t N = input.size(0);
    int64_t C = input.size(1);
    int64_t inp_D = input.size(2);
    int64_t inp_H = input.size(3);
    int64_t inp_W = input.size(4);
    int64_t out_D = grid.size(1);
    int64_t out_H = grid.size(2);
    int64_t out_W = grid.size(3);
    auto output = at::empty({N, C, out_D, out_H, out_W}, input.options());
    if (output.numel() == 0) {
        return output;
    }
    int64_t inp_sN = input.stride(0);
    int64_t inp_sC = input.stride(1);
    int64_t inp_sD = input.stride(2);
    int64_t inp_sH = input.stride(3);
    int64_t inp_sW = input.stride(4);
    int64_t grid_sN = grid.stride(0);
    int64_t grid_sD = grid.stride(1);
    int64_t grid_sH = grid.stride(2);
    int64_t grid_sW = grid.stride(3);
    int64_t grid_sCoor = grid.stride(4);
    int64_t out_sN = output.stride(0);
    int64_t out_sC = output.stride(1);
    int64_t out_sD = output.stride(2);
    int64_t out_sH = output.stride(3);
    int64_t out_sW = output.stride(4);
    const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
    scalar_t *out_ptr = output.data_ptr<scalar_t>();
    const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
    // loop over each output pixel
    at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
      for (const auto n : c10::irange(start, end)) {
        const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
        const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
        for (const auto d : c10::irange(out_D)) {
          for (const auto h : c10::irange(out_H)) {
            for (const auto w : c10::irange(out_W)) {
              // get the corresponding input x, y, z co-ordinates from grid
              const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
              scalar_t ix = *grid_ptr_NDHW;
              scalar_t iy = grid_ptr_NDHW[grid_sCoor];
              scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];

              ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
              iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
              iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
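              // Illustrative note: the grid stores normalized coordinates in
              // [-1, 1]; they are mapped to input indices as
              //   align_corners == true : ((coord + 1) / 2) * (size - 1)
              //   align_corners == false: ((coord + 1) * size - 1) / 2
              // with padding_mode deciding how out-of-range indices are
              // clipped or reflected (see GridSampler.h for the exact rules).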
98 
99               if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
100                 // get corner pixel values from (x, y, z)
101                 // for 4d, we used north-east-south-west
102                 // for 5d, we add top-bottom
                int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
                int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
                int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));

                int64_t ix_tne = ix_tnw + 1;
                int64_t iy_tne = iy_tnw;
                int64_t iz_tne = iz_tnw;

                int64_t ix_tsw = ix_tnw;
                int64_t iy_tsw = iy_tnw + 1;
                int64_t iz_tsw = iz_tnw;

                int64_t ix_tse = ix_tnw + 1;
                int64_t iy_tse = iy_tnw + 1;
                int64_t iz_tse = iz_tnw;

                int64_t ix_bnw = ix_tnw;
                int64_t iy_bnw = iy_tnw;
                int64_t iz_bnw = iz_tnw + 1;

                int64_t ix_bne = ix_tnw + 1;
                int64_t iy_bne = iy_tnw;
                int64_t iz_bne = iz_tnw + 1;

                int64_t ix_bsw = ix_tnw;
                int64_t iy_bsw = iy_tnw + 1;
                int64_t iz_bsw = iz_tnw + 1;

                int64_t ix_bse = ix_tnw + 1;
                int64_t iy_bse = iy_tnw + 1;
                int64_t iz_bse = iz_tnw + 1;

                // get surfaces to each neighbor:
                scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
                scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
                scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
                scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
                scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
                scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
                scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
                scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);
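                // Illustrative note: each trilinear weight is the volume of
                // the box spanned by (ix, iy, iz) and the corner *opposite*
                // the one it scales (tnw pairs with bse, etc.), so the eight
                // weights sum to 1 for points inside the corner cube.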

                // calculate bilinear weighted pixel value and set output pixel
                scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
                const scalar_t *inp_ptr_NC = inp_ptr_N;
                for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
                  //   (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
                  // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
                  // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
                  // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
                  *out_ptr_NCDHW = static_cast<scalar_t>(0);
                  if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
                  }
                  if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
                  }
                  if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
                  }
                  if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
                  }
                  if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
                  }
                  if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
                  }
                  if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
                  }
                  if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
                  }
                }
              } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
                int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
                int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
                int64_t iz_nearest = static_cast<int64_t>(std::nearbyint(iz));

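                // Note: std::nearbyint rounds with the current rounding mode
                // (round-half-to-even by default), so ties such as 0.5 go to
                // the even neighbor rather than always rounding up.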
                // assign nearest neighbour pixel value to output pixel
                scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
                const scalar_t *inp_ptr_NC = inp_ptr_N;
                for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
                  if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
                    *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
                  } else {
                    *out_ptr_NCDHW = static_cast<scalar_t>(0);
                  }
                }
              }
            }
          }
        }
      }
    });
    return output;
  }

  template<typename scalar_t>
  std::tuple<Tensor, Tensor>
  grid_sampler_3d_backward_cpu_impl(const Tensor& grad_output,
                                    const Tensor& input, const Tensor& grid,
                                    GridSamplerInterpolation interpolation_mode,
                                    GridSamplerPadding padding_mode,
                                    bool align_corners, std::array<bool,2> output_mask) {
    // See NOTE [ grid_sampler Native Functions ].
    // Add checks here in case this is called instead of grid_sampler.
    check_grid_sampler_common(input, grid);
    check_grid_sampler_3d(
      input, grid, static_cast<int64_t>(interpolation_mode));

    auto input_requires_grad = output_mask[0];
    Tensor grad_input = ([&]() {
      if (input_requires_grad) {
        return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      } else {
        return Tensor();
      }
    })();
    auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    if (grid.numel() == 0 || input.numel() == 0) {
      grad_grid.zero_();
      return std::make_tuple(grad_input, grad_grid);
    }
    // If interpolation mode is Nearest, then grad_grid is not filled in the
    // loop below.
    if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      grad_grid.zero_();
    }
    int64_t N = input.size(0);
    int64_t C = input.size(1);
    int64_t inp_D = input.size(2);
    int64_t inp_H = input.size(3);
    int64_t inp_W = input.size(4);
    int64_t out_D = grid.size(1);
    int64_t out_H = grid.size(2);
    int64_t out_W = grid.size(3);
    int64_t inp_sN = input.stride(0);
    int64_t inp_sC = input.stride(1);
    int64_t inp_sD = input.stride(2);
    int64_t inp_sH = input.stride(3);
    int64_t inp_sW = input.stride(4);
    int64_t grid_sN = grid.stride(0);
    int64_t grid_sD = grid.stride(1);
    int64_t grid_sH = grid.stride(2);
    int64_t grid_sW = grid.stride(3);
    int64_t grid_sCoor = grid.stride(4);
    int64_t gOut_sN = grad_output.stride(0);
    int64_t gOut_sC = grad_output.stride(1);
    int64_t gOut_sD = grad_output.stride(2);
    int64_t gOut_sH = grad_output.stride(3);
    int64_t gOut_sW = grad_output.stride(4);
    int64_t gInp_sN = 0;
    int64_t gInp_sC = 0;
    int64_t gInp_sD = 0;
    int64_t gInp_sH = 0;
    int64_t gInp_sW = 0;
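    // Note: the zero strides above make the gInp pointer arithmetic in the
    // loop below a no-op when the input gradient is not requested.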
    if (input_requires_grad) {
      gInp_sN = grad_input.stride(0);
      gInp_sC = grad_input.stride(1);
      gInp_sD = grad_input.stride(2);
      gInp_sH = grad_input.stride(3);
      gInp_sW = grad_input.stride(4);
    }
    int64_t gGrid_sN = grad_grid.stride(0);
    int64_t gGrid_sW = grad_grid.stride(3);
    const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
    const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
    const scalar_t *gOut_ptr = grad_output.const_data_ptr<scalar_t>();
    scalar_t *gInp_ptr = nullptr;
    if (input_requires_grad) {
      gInp_ptr = grad_input.mutable_data_ptr<scalar_t>();
    }
    scalar_t *gGrid_ptr = grad_grid.data_ptr<scalar_t>();
    // loop over each output pixel
    at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
      for (const auto n : c10::irange(start, end)) {
        const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
        const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
        scalar_t *gGrid_ptr_NDHW = gGrid_ptr + n * gGrid_sN;
        for (const auto d : c10::irange(out_D)) {
          for (const auto h : c10::irange(out_H)) {
            for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) {
              // get the corresponding input x, y, z co-ordinates from grid
              const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
              scalar_t ix = *grid_ptr_NDHW;
              scalar_t iy = grid_ptr_NDHW[grid_sCoor];
              scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];

              // multipliers for gradients on ix, iy, and iz
              scalar_t gix_mult, giy_mult, giz_mult;
              ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
              iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
              iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);

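              // Illustrative note: each *_mult is the derivative of the
              // unnormalize-plus-padding mapping at this coordinate, so the
              // chain rule below only needs d(output)/d(ix) etc. in input
              // index space, scaled by these multipliers at the end.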
              if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
                // get corner pixel values from (x, y, z)
                // for 4d, we use north-east-south-west
                // for 5d, we add top-bottom
                int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
                int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
                int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));

                int64_t ix_tne = ix_tnw + 1;
                int64_t iy_tne = iy_tnw;
                int64_t iz_tne = iz_tnw;

                int64_t ix_tsw = ix_tnw;
                int64_t iy_tsw = iy_tnw + 1;
                int64_t iz_tsw = iz_tnw;

                int64_t ix_tse = ix_tnw + 1;
                int64_t iy_tse = iy_tnw + 1;
                int64_t iz_tse = iz_tnw;

                int64_t ix_bnw = ix_tnw;
                int64_t iy_bnw = iy_tnw;
                int64_t iz_bnw = iz_tnw + 1;

                int64_t ix_bne = ix_tnw + 1;
                int64_t iy_bne = iy_tnw;
                int64_t iz_bne = iz_tnw + 1;

                int64_t ix_bsw = ix_tnw;
                int64_t iy_bsw = iy_tnw + 1;
                int64_t iz_bsw = iz_tnw + 1;

                int64_t ix_bse = ix_tnw + 1;
                int64_t iy_bse = iy_tnw + 1;
                int64_t iz_bse = iz_tnw + 1;

                // get surfaces to each neighbor:
                scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
                scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
                scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
                scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
                scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
                scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
                scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
                scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);

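                // Illustrative note: the grid gradients below follow from the
                // product rule on these weights; e.g. since
                //   tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz),
                // d(tnw)/d(ix) = -(iy_bse - iy) * (iz_bse - iz), which is the
                // term subtracted from gix for the tnw corner.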
                scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
                const scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
                const scalar_t *inp_ptr_NC = inp_ptr_N;
                scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
                // calculate bilinear weighted pixel value and set output pixel
                for (int64_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
                  scalar_t gOut = *gOut_ptr_NCDHW;

                  // calculate and set grad_input
                  if (input_requires_grad) {
                    safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
                    safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
                    safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
                    safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
                    safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
                    safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
                    safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
                    safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
                  }
                  // calculate grad_grid
                  if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
                    scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
                    gix -= tnw_val * (iy_bse - iy)    * (iz_bse - iz)    * gOut;
                    giy -= tnw_val * (ix_bse - ix)    * (iz_bse - iz)    * gOut;
                    giz -= tnw_val * (ix_bse - ix)    * (iy_bse - iy)    * gOut;
                  }
                  if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
                    scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
                    gix += tne_val * (iy_bsw - iy)    * (iz_bsw - iz)    * gOut;
                    giy -= tne_val * (ix    - ix_bsw) * (iz_bsw - iz)    * gOut;
                    giz -= tne_val * (ix    - ix_bsw) * (iy_bsw - iy)    * gOut;
                  }
                  if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
                    scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
                    gix -= tsw_val * (iy - iy_bne)    * (iz_bne - iz)    * gOut;
                    giy += tsw_val * (ix_bne - ix)    * (iz_bne - iz)    * gOut;
                    giz -= tsw_val * (ix_bne - ix)    * (iy    - iy_bne) * gOut;
                  }
                  if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
                    scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
                    gix += tse_val * (iy - iy_bnw)    * (iz_bnw - iz)    * gOut;
                    giy += tse_val * (ix    - ix_bnw) * (iz_bnw - iz)    * gOut;
                    giz -= tse_val * (ix    - ix_bnw) * (iy    - iy_bnw) * gOut;
                  }
                  if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
                    scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
                    gix -= bnw_val * (iy_tse - iy)    * (iz - iz_tse)    * gOut;
                    giy -= bnw_val * (ix_tse - ix)    * (iz - iz_tse)    * gOut;
                    giz += bnw_val * (ix_tse - ix)    * (iy_tse - iy)    * gOut;
                  }
                  if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
                    scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
                    gix += bne_val * (iy_tsw - iy)    * (iz - iz_tsw)    * gOut;
                    giy -= bne_val * (ix    - ix_tsw) * (iz - iz_tsw)    * gOut;
                    giz += bne_val * (ix    - ix_tsw) * (iy_tsw - iy)    * gOut;
                  }
                  if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
                    scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
                    gix -= bsw_val * (iy - iy_tne)    * (iz - iz_tne)    * gOut;
                    giy += bsw_val * (ix_tne - ix)    * (iz - iz_tne)    * gOut;
                    giz += bsw_val * (ix_tne - ix)    * (iy    - iy_tne) * gOut;
                  }
                  if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
                    scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
                    gix += bse_val * (iy - iy_tnw)    * (iz - iz_tnw)    * gOut;
                    giy += bse_val * (ix    - ix_tnw) * (iz - iz_tnw)    * gOut;
                    giz += bse_val * (ix    - ix_tnw) * (iy    - iy_tnw) * gOut;
                  }
                }

                // assuming grad_grid is contiguous
                gGrid_ptr_NDHW[0] = gix_mult * gix;
                gGrid_ptr_NDHW[1] = giy_mult * giy;
                gGrid_ptr_NDHW[2] = giz_mult * giz;
              } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
                int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
                int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
                int64_t iz_nearest = static_cast<int64_t>(std::nearbyint(iz));

                // assign nearest neighbour pixel value to output pixel
                const scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
                if (input_requires_grad) {
                  scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
                  for (int64_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
                    // calculate and set grad_input
                    safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest,
                                gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW);
                  }
                }
              }
            }
          }
        }
      }
    });
    return std::make_tuple(grad_input, grad_grid);
  }

}  // namespace

static Tensor _grid_sampler_2d_cpu_quantized(
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode_,
    int64_t padding_mode_,
    bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  auto interpolation_mode =
      static_cast<GridSamplerInterpolation>(interpolation_mode_);
  /* Bilinear interpolation is supported using the fact that we can perform
   * linear interpolations on quantized values without rescaling. */
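  /* Illustrative note: with affine quantization x = scale * (q - zero_point),
   * any weighted sum whose weights add to 1 satisfies
   *   sum_i w_i * x_i = scale * ((sum_i w_i * q_i) - zero_point),
   * so interpolating the raw uint8 values and reusing the input's
   * scale/zero_point is exact up to rounding. */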
  TORCH_CHECK(
      interpolation_mode == GridSamplerInterpolation::Bilinear,
      "_grid_sampler_2d_cpu_quantized(): only bilinear interpolation supported");
  auto padding_mode = static_cast<GridSamplerPadding>(padding_mode_);

  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_H = input.size(2);
  int64_t inp_W = input.size(3);
  int64_t out_H = grid.size(1);
  int64_t out_W = grid.size(2);
  uint8_t zero_point = input.q_zero_point();
  auto output = at::_empty_affine_quantized(
      {N, C, out_H, out_W},
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      input.q_scale(),
      zero_point);
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sH = input.stride(2);
  int64_t inp_sW = input.stride(3);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sH = grid.stride(1);
  int64_t grid_sW = grid.stride(2);
  int64_t grid_sCoor = grid.stride(3);
  int64_t out_sN = output.stride(0);
  int64_t out_sC = output.stride(1);
  int64_t out_sH = output.stride(2);
  int64_t out_sW = output.stride(3);
  uint8_t* inp_ptr = (uint8_t*)input.data_ptr<quint8>();
  uint8_t* out_ptr = (uint8_t*)output.data_ptr<quint8>();
  float* grid_ptr = grid.data_ptr<float>();
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      float* grid_ptr_N = grid_ptr + n * grid_sN;
      uint8_t* inp_ptr_N = inp_ptr + n * inp_sN;
      for (const auto h : c10::irange(out_H)) {
        for (const auto w : c10::irange(out_W)) {
          // get the corresponding input x, y co-ordinates from grid
          float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
          float x = *grid_ptr_NHW;
          float y = grid_ptr_NHW[grid_sCoor];

          float ix = grid_sampler_compute_source_index(
              x, inp_W, padding_mode, align_corners);
          float iy = grid_sampler_compute_source_index(
              y, inp_H, padding_mode, align_corners);

          // get corner pixel values from (x, y)
          // for 4d, we use north-east-south-west
          int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
          int64_t iy_nw = static_cast<int64_t>(std::floor(iy));

          int64_t ix_ne = ix_nw + 1;
          int64_t iy_ne = iy_nw;

          int64_t ix_sw = ix_nw;
          int64_t iy_sw = iy_nw + 1;

          int64_t ix_se = ix_nw + 1;
          int64_t iy_se = iy_nw + 1;

          // get surfaces to each neighbor:
          float nw = (ix_se - ix) * (iy_se - iy);
          float ne = (ix - ix_sw) * (iy_sw - iy);
          float sw = (ix_ne - ix) * (iy - iy_ne);
          float se = (ix - ix_nw) * (iy - iy_nw);

          // calculate bilinear weighted pixel value and set output pixel
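          // Note: out-of-bounds taps below contribute zero_point rather than
          // 0, since zero_point is the quantized encoding of real 0; this
          // reproduces zeros padding after dequantization.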
          uint8_t* inp_ptr_NC = inp_ptr_N;
          uint8_t* out_ptr_NCHW =
              out_ptr + n * out_sN + h * out_sH + w * out_sW;
          for (int64_t c = 0; c < C;
               ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
            float res = 0;
            res += within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)
                ? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw
                : zero_point * nw;
            res += within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)
                ? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne
                : zero_point * ne;
            res += within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)
                ? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw
                : zero_point * sw;
            res += within_bounds_2d(iy_se, ix_se, inp_H, inp_W)
                ? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se
                : zero_point * se;
            *out_ptr_NCHW = std::nearbyint(res);
          }
        }
      }
    }
  });
  return output;
}

Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
                                     int64_t interpolation_mode_,
                                     int64_t padding_mode_,
                                     bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  auto interpolation_mode = static_cast<GridSamplerInterpolation>(interpolation_mode_);
  auto padding_mode = static_cast<GridSamplerPadding>(padding_mode_);
  using scalar_t = float;

  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_H = input.size(2);
  int64_t inp_W = input.size(3);
  int64_t out_H = grid.size(1);
  int64_t out_W = grid.size(2);
  auto output = at::empty({N, C, out_H, out_W}, input.options());
  if (output.numel() == 0) {
      return output;
  }
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sH = input.stride(2);
  int64_t inp_sW = input.stride(3);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sH = grid.stride(1);
  int64_t grid_sW = grid.stride(2);
  int64_t grid_sCoor = grid.stride(3);
  int64_t out_sN = output.stride(0);
  int64_t out_sC = output.stride(1);
  int64_t out_sH = output.stride(2);
  int64_t out_sW = output.stride(3);
  const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
  scalar_t *out_ptr = output.data_ptr<scalar_t>();
  const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
  // loop over each output pixel
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
      const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
      for (const auto h : c10::irange(out_H)) {
        for (const auto w : c10::irange(out_W)) {
          // get the corresponding input x, y co-ordinates from grid
          const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
          scalar_t x = *grid_ptr_NHW;
          scalar_t y = grid_ptr_NHW[grid_sCoor];

          scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
          scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);

          if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
            // get corner pixel values from (x, y)
            // for 4d, we use north-east-south-west
            int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
            int64_t iy_nw = static_cast<int64_t>(std::floor(iy));

            int64_t ix_ne = ix_nw + 1;
            int64_t iy_ne = iy_nw;

            int64_t ix_sw = ix_nw;
            int64_t iy_sw = iy_nw + 1;

            int64_t ix_se = ix_nw + 1;
            int64_t iy_se = iy_nw + 1;

            // get surfaces to each neighbor:
            scalar_t nw = (ix_se - ix)    * (iy_se - iy);
            scalar_t ne = (ix    - ix_sw) * (iy_sw - iy);
            scalar_t sw = (ix_ne - ix)    * (iy    - iy_ne);
            scalar_t se = (ix    - ix_nw) * (iy    - iy_nw);

            // calculate bilinear weighted pixel value and set output pixel
            const scalar_t *inp_ptr_NC = inp_ptr_N;
            scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
            for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
              auto res = static_cast<scalar_t>(0);
              if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
              }
              if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
              }
              if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
              }
              if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
              }
              *out_ptr_NCHW = res;
            }
          } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
            int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
            int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));

            // assign nearest neighbour pixel value to output pixel
            scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
            const scalar_t *inp_ptr_NC = inp_ptr_N;
            for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
              if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
                *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
              } else {
                *out_ptr_NCHW = static_cast<scalar_t>(0);
              }
            }
          } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
            // grid_sampler_compute_source_index would clip the index according
            // to the padding mode, which would make the bicubic math wrong:
            // e.g. x = -0.1 -> ix = 0 under zeros padding, but bicubic needs
            // ix = floor(x) = -1. Reflection padding is even more problematic,
            // since the -1/+1 neighbor directions are not fixed at the boundary.
            ix = grid_sampler_unnormalize(x, inp_W, align_corners);
            iy = grid_sampler_unnormalize(y, inp_H, align_corners);

            scalar_t ix_nw = std::floor(ix);
            scalar_t iy_nw = std::floor(iy);

            const scalar_t tx = ix - ix_nw;
            const scalar_t ty = iy - iy_nw;

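            // Illustrative note: this is separable bicubic interpolation:
            // four 1-D cubic interpolations along x over a 4x4 neighborhood,
            // then one along y. cubic_interp1d (from UpSample.h) implements
            // the cubic convolution kernel (a = -0.75, as in upsampling).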
            const scalar_t *inp_ptr_NC = inp_ptr_N;
            scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
            for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
              scalar_t coefficients[4];

              // Interpolate 4 values in the x direction
              for (const auto i : c10::irange(4)) {
                coefficients[i] = cubic_interp1d<scalar_t>(
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  tx);
              }

              // Interpolate in the y direction
              *out_ptr_NCHW = cubic_interp1d<scalar_t>(
                coefficients[0],
                coefficients[1],
                coefficients[2],
                coefficients[3],
                ty);
            }
          }
        }
      }
    }
  });
  return output;
}

std::tuple<Tensor, Tensor>
_grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
                                       const Tensor& input, const Tensor& grid,
                                       int64_t interpolation_mode_,
                                       int64_t padding_mode_,
                                       bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  const auto interpolation_mode = static_cast<GridSamplerInterpolation>(interpolation_mode_);
  const auto padding_mode = static_cast<GridSamplerPadding>(padding_mode_);
  using scalar_t = float;

  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  if (grid.numel() == 0 || input.numel() == 0) {
    grad_grid.zero_();
    return std::make_tuple(grad_input, grad_grid);
  }
  // If interpolation mode is Nearest, then grad_grid is not filled in the
  // loop below.
  if (interpolation_mode == GridSamplerInterpolation::Nearest) {
    grad_grid.zero_();
  }
  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_H = input.size(2);
  int64_t inp_W = input.size(3);
  int64_t out_H = grid.size(1);
  int64_t out_W = grid.size(2);
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sH = input.stride(2);
  int64_t inp_sW = input.stride(3);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sH = grid.stride(1);
  int64_t grid_sW = grid.stride(2);
  int64_t grid_sCoor = grid.stride(3);
  int64_t gOut_sN = grad_output.stride(0);
  int64_t gOut_sC = grad_output.stride(1);
  int64_t gOut_sH = grad_output.stride(2);
  int64_t gOut_sW = grad_output.stride(3);
  int64_t gInp_sN = grad_input.stride(0);
  int64_t gInp_sC = grad_input.stride(1);
  int64_t gInp_sH = grad_input.stride(2);
  int64_t gInp_sW = grad_input.stride(3);
  int64_t gGrid_sN = grad_grid.stride(0);
  int64_t gGrid_sW = grad_grid.stride(2);
  const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
  const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
  const scalar_t *gOut_ptr = grad_output.const_data_ptr<scalar_t>();
  scalar_t *gInp_ptr = grad_input.mutable_data_ptr<scalar_t>();
  scalar_t *gGrid_ptr = grad_grid.data_ptr<scalar_t>();
  // loop over each output pixel
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
      const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
      scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN;
      for (const auto h : c10::irange(out_H)) {
        for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) {
          // get the corresponding input x, y co-ordinates from grid
          const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
          scalar_t x = *grid_ptr_NHW;
          scalar_t y = grid_ptr_NHW[grid_sCoor];

          // multipliers for gradients on ix, iy
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          scalar_t gix_mult, giy_mult;
          scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
          scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);

          if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
            // get corner pixel values from (x, y)
            // for 4d, we use north-east-south-west
            int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
            int64_t iy_nw = static_cast<int64_t>(std::floor(iy));

            int64_t ix_ne = ix_nw + 1;
            int64_t iy_ne = iy_nw;

            int64_t ix_sw = ix_nw;
            int64_t iy_sw = iy_nw + 1;

            int64_t ix_se = ix_nw + 1;
            int64_t iy_se = iy_nw + 1;

            // get surfaces to each neighbor:
            scalar_t nw = (ix_se - ix)    * (iy_se - iy);
            scalar_t ne = (ix    - ix_sw) * (iy_sw - iy);
            scalar_t sw = (ix_ne - ix)    * (iy    - iy_ne);
            scalar_t se = (ix    - ix_nw) * (iy    - iy_nw);

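            // Illustrative note: the 2-D analogue of the 3-D case above; each
            // weight is the area of the rectangle opposite its corner, so
            // e.g. d(nw)/d(ix) = -(iy_se - iy) drives the gix terms below.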
            scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
            const scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
            scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
            const scalar_t *inp_ptr_NC = inp_ptr_N;
            // calculate bilinear weighted pixel value and set output pixel
            for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
              scalar_t gOut = *gOut_ptr_NCHW;

              // calculate and set grad_input
              safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut);
              safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut);
              safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut);
              safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut);

              // calculate grad_grid
              if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
                scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
                gix -= nw_val * (iy_se - iy) * gOut;
                giy -= nw_val * (ix_se - ix) * gOut;
              }
              if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
                scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
                gix += ne_val * (iy_sw - iy) * gOut;
                giy -= ne_val * (ix - ix_sw) * gOut;
              }
              if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
                scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
                gix -= sw_val * (iy - iy_ne) * gOut;
                giy += sw_val * (ix_ne - ix) * gOut;
              }
              if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
                scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
                gix += se_val * (iy - iy_nw) * gOut;
                giy += se_val * (ix - ix_nw) * gOut;
              }
            }

            // assuming grad_grid is contiguous
            gGrid_ptr_NHW[0] = gix_mult * gix;
            gGrid_ptr_NHW[1] = giy_mult * giy;
          } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
            int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
            int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));

            // assign nearest neighbour pixel value to output pixel
            const scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
            scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
            for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
              // calculate and set grad_input
              safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW,
                          inp_H, inp_W, *gOut_ptr_NCHW);
            }
          } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {

            ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult);
            iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult);

            scalar_t ix_nw = std::floor(ix);
            scalar_t iy_nw = std::floor(iy);

            const scalar_t tx = ix - ix_nw;
            const scalar_t ty = iy - iy_nw;

            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t x_coeffs[4];
            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t y_coeffs[4];
            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t x_coeffs_grad[4];
            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t y_coeffs_grad[4];

            get_cubic_upsample_coefficients<scalar_t>(x_coeffs, tx);
            get_cubic_upsample_coefficients<scalar_t>(y_coeffs, ty);
            get_cubic_coefficients_grad<scalar_t>(x_coeffs_grad, tx);
            get_cubic_coefficients_grad<scalar_t>(y_coeffs_grad, ty);

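            // Illustrative note: since out = sum_{i,j} v_ij * cx_i(tx) * cy_j(ty),
            // the chain rule gives d(out)/d(ix) = sum_{i,j} v_ij * cx_i'(tx) * cy_j(ty);
            // x_coeffs_grad/y_coeffs_grad hold those kernel derivatives, and
            // the sign convention below matches get_cubic_coefficients_grad.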
            scalar_t gix = static_cast<scalar_t>(0);
            scalar_t giy = static_cast<scalar_t>(0);

            const scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
            scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
            const scalar_t *inp_ptr_NC = inp_ptr_N;

            for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
              scalar_t gOut = *gOut_ptr_NCHW;

              for (const auto i : c10::irange(4)) {
                for (const auto j : c10::irange(4)) {

                  // set input gradient
                  add_value_bounded<scalar_t>(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
                    inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners);

                  // set grid gradient
                  scalar_t val = get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
                    inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners);

                  gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut;
                  giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut;
                }
              }
            }
            gGrid_ptr_NHW[0] = gix_mult * gix;
            gGrid_ptr_NHW[1] = giy_mult * giy;
          }
        }
      }
    }
  });
  return std::make_tuple(grad_input, grad_grid);
}

Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid,
                           int64_t interpolation_mode, int64_t padding_mode,
                           bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  if (input.scalar_type() == kQUInt8) {
    return native::_grid_sampler_2d_cpu_quantized(
        input, grid, interpolation_mode, padding_mode, align_corners);
  }
  // AVX gather instructions use signed 32-bit offsets to gather float values.
  // Check for possible overflow and fall back to the scalar implementation.
  if (input.scalar_type() != kDouble) {
    TORCH_CHECK(input.scalar_type() == kFloat,
                "grid_sampler_2d_cpu not implemented for ", input.scalar_type());
    auto sizes = input.sizes();
    auto strides = input.strides();
    const auto grid_sW = grid.strides()[2];
    // NOTE: Gather offsets are only used for the input H, W dimensions
    //       or only for strided access to the grid tensor
    auto max_gather_offset = std::max(
      (sizes[2] - 1) * strides[2] + (sizes[3] - 1) * strides[3],
      grid_sW * (vec::Vectorized<float>::size() - 1));

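    // Illustrative example: for a contiguous float input, a single channel
    // plane of more than 2^31 elements (e.g. 50000 x 50000) makes the largest
    // element offset exceed INT32_MAX, so the scalar fallback is taken.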
    if (max_gather_offset > std::numeric_limits<int32_t>::max()) {
      return native::_grid_sampler_2d_cpu_fallback(
        input, grid, interpolation_mode, padding_mode, align_corners);
    }
  }

  auto in_size = input.sizes();
  auto grid_size = grid.sizes();
  auto output = at::empty(
      {in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
  grid_sampler_2d_cpu_kernel(
      kCPU, output, input, grid, interpolation_mode, padding_mode, align_corners);
  return output;
}

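// Note: grid_sampler_2d_cpu_kernel is a DispatchStub; the vectorized
// implementation is registered from ATen/native/cpu/GridSamplerKernel.cpp.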
DEFINE_DISPATCH(grid_sampler_2d_cpu_kernel);


Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid,
                           int64_t interpolation_mode, int64_t padding_mode,
                           bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(input, grid, interpolation_mode);

  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler3d_cpu", [&] {
    return grid_sampler_3d_cpu_impl<scalar_t>(
      input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
      static_cast<GridSamplerPadding>(padding_mode), align_corners);
  });
}

std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
                             int64_t interpolation_mode, int64_t padding_mode, bool align_corners,
                             std::array<bool,2> output_mask) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  // AVX gather instructions use signed 32-bit offsets to gather float values.
  // Check for possible overflow and fall back to the scalar implementation.
  if (input.scalar_type() != kDouble) {
    TORCH_CHECK(input.scalar_type() == kFloat,
                "grid_sampler_2d_backward_cpu not implemented for ", input.scalar_type());
    auto isizes = input.sizes();
    auto istrides = input.strides();
    auto gsizes = grad_output.sizes();
    auto gstrides = grad_output.strides();
    const auto grid_sW = grid.strides()[2];
    // NOTE: Gather offsets are only used for the height and width dimensions
    auto max_gather_offset = std::max(
      std::max(
        (isizes[2] - 1) * istrides[2] + (isizes[3] - 1) * istrides[3],
        (gsizes[2] - 1) * gstrides[2] + (gsizes[3] - 1) * gstrides[3]),
      grid_sW * (vec::Vectorized<float>::size() - 1));

    if (max_gather_offset > std::numeric_limits<int32_t>::max()) {
      return native::_grid_sampler_2d_cpu_fallback_backward(
        grad_output, input, grid, interpolation_mode, padding_mode, align_corners);
    }
  }

  auto input_requires_grad = output_mask[0];
  Tensor grad_input = ([&]() {
    if (input_requires_grad) {
      return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    } else {
      return Tensor();
    }
  })();
  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  grid_sampler_2d_backward_cpu_kernel(
      kCPU, grad_input, grad_grid, grad_output, input, grid,
      interpolation_mode, padding_mode, align_corners, output_mask);
  return std::make_tuple(std::move(grad_input), std::move(grad_grid));
}

DEFINE_DISPATCH(grid_sampler_2d_backward_cpu_kernel);

std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
                             int64_t interpolation_mode, int64_t padding_mode, bool align_corners,
                             std::array<bool,2> output_mask) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(input, grid, interpolation_mode);

  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] {
    return grid_sampler_3d_backward_cpu_impl<scalar_t>(
      grad_output, input, grid,
      static_cast<GridSamplerInterpolation>(interpolation_mode),
      static_cast<GridSamplerPadding>(padding_mode),
      align_corners, output_mask);
  });
}

// See NOTE [ grid_sampler Native Functions ].
Tensor grid_sampler(
  const Tensor& input,
  const Tensor& grid,
  int64_t interpolation_mode,
  int64_t padding_mode,
  bool align_corners
) {
  if (cond_cudnn_grid_sampler(input, grid) &&
      static_cast<GridSamplerInterpolation>(interpolation_mode) ==
        GridSamplerInterpolation::Bilinear &&
      static_cast<GridSamplerPadding>(padding_mode) ==
        GridSamplerPadding::Zeros &&
      align_corners) {
    return cudnn_grid_sampler(input, grid);
  }

  if (input.dim() == 4) {
    return at::grid_sampler_2d(
      input, grid, interpolation_mode, padding_mode, align_corners);
  } else {
    return at::grid_sampler_3d(
      input, grid, interpolation_mode, padding_mode, align_corners);
  }
}

}  // namespace at::native