#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/GridSampler.h>
#include <ATen/native/GridSamplerUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/cpu/GridSamplerKernel.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_backward_native.h>
#include <ATen/ops/_grid_sampler_2d_cpu_fallback_native.h>
#include <ATen/ops/cudnn_grid_sampler.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/grid_sampler_2d.h>
#include <ATen/ops/grid_sampler_2d_backward_native.h>
#include <ATen/ops/grid_sampler_2d_native.h>
#include <ATen/ops/grid_sampler_3d.h>
#include <ATen/ops/grid_sampler_3d_backward_native.h>
#include <ATen/ops/grid_sampler_3d_native.h>
#include <ATen/ops/grid_sampler_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

using at::native::detail::GridSamplerInterpolation;
using at::native::detail::GridSamplerPadding;

namespace {

template<typename scalar_t>
Tensor grid_sampler_3d_cpu_impl(const Tensor& input, const Tensor& grid,
                                GridSamplerInterpolation interpolation_mode,
                                GridSamplerPadding padding_mode,
                                bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(
    input, grid, static_cast<int64_t>(interpolation_mode));

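  // Shapes at this point (enforced by the checks above):
  //   input: (N, C, D_in, H_in, W_in), grid: (N, D_out, H_out, W_out, 3),
  //   output: (N, C, D_out, H_out, W_out).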
  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_D = input.size(2);
  int64_t inp_H = input.size(3);
  int64_t inp_W = input.size(4);
  int64_t out_D = grid.size(1);
  int64_t out_H = grid.size(2);
  int64_t out_W = grid.size(3);
  auto output = at::empty({N, C, out_D, out_H, out_W}, input.options());
  if (output.numel() == 0) {
    return output;
  }
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sD = input.stride(2);
  int64_t inp_sH = input.stride(3);
  int64_t inp_sW = input.stride(4);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sD = grid.stride(1);
  int64_t grid_sH = grid.stride(2);
  int64_t grid_sW = grid.stride(3);
  int64_t grid_sCoor = grid.stride(4);
  int64_t out_sN = output.stride(0);
  int64_t out_sC = output.stride(1);
  int64_t out_sD = output.stride(2);
  int64_t out_sH = output.stride(3);
  int64_t out_sW = output.stride(4);
  const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
  scalar_t *out_ptr = output.data_ptr<scalar_t>();
  const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
  // loop over each output pixel
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
      const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
      for (const auto d : c10::irange(out_D)) {
        for (const auto h : c10::irange(out_H)) {
          for (const auto w : c10::irange(out_W)) {
            // get the corresponding input x, y, z co-ordinates from grid
            const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
            scalar_t ix = *grid_ptr_NDHW;
            scalar_t iy = grid_ptr_NDHW[grid_sCoor];
            scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];

            ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
            iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
            iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);
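            // ix, iy, iz are now source coordinates in input pixel space, already
            // clipped or reflected according to padding_mode.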

            if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
              // get corner pixel values from (x, y, z)
              // for 4d, we use north-east-south-west
              // for 5d, we add top-bottom
              int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
              int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
              int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));

              int64_t ix_tne = ix_tnw + 1;
              int64_t iy_tne = iy_tnw;
              int64_t iz_tne = iz_tnw;

              int64_t ix_tsw = ix_tnw;
              int64_t iy_tsw = iy_tnw + 1;
              int64_t iz_tsw = iz_tnw;

              int64_t ix_tse = ix_tnw + 1;
              int64_t iy_tse = iy_tnw + 1;
              int64_t iz_tse = iz_tnw;

              int64_t ix_bnw = ix_tnw;
              int64_t iy_bnw = iy_tnw;
              int64_t iz_bnw = iz_tnw + 1;

              int64_t ix_bne = ix_tnw + 1;
              int64_t iy_bne = iy_tnw;
              int64_t iz_bne = iz_tnw + 1;

              int64_t ix_bsw = ix_tnw;
              int64_t iy_bsw = iy_tnw + 1;
              int64_t iz_bsw = iz_tnw + 1;

              int64_t ix_bse = ix_tnw + 1;
              int64_t iy_bse = iy_tnw + 1;
              int64_t iz_bse = iz_tnw + 1;

              // get surfaces to each neighbor:
              scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
              scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
              scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
              scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
              scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
              scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
              scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
              scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
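              // Each weight above is the volume of the sub-cuboid between the sample
              // point (ix, iy, iz) and the corner diagonally opposite the named
              // neighbour; for a point inside the cell the eight weights sum to 1.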

              // calculate bilinear weighted pixel value and set output pixel
              scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
              const scalar_t *inp_ptr_NC = inp_ptr_N;
              for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
                // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
                // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
                // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
                // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
                *out_ptr_NCDHW = static_cast<scalar_t>(0);
                if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
                }
                if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
                }
                if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
                }
                if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
                }
                if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
                }
                if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
                }
                if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
                }
                if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
                }
              }
            } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
              int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
              int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
              int64_t iz_nearest = static_cast<int64_t>(std::nearbyint(iz));

              // assign nearest neighbour pixel value to output pixel
              scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
              const scalar_t *inp_ptr_NC = inp_ptr_N;
              for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) {
                if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
                  *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
                } else {
                  *out_ptr_NCDHW = static_cast<scalar_t>(0);
                }
              }
            }
          }
        }
      }
    }
  });
  return output;
}

template<typename scalar_t>
std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cpu_impl(const Tensor& grad_output,
                                  const Tensor& input, const Tensor& grid,
                                  GridSamplerInterpolation interpolation_mode,
                                  GridSamplerPadding padding_mode,
                                  bool align_corners, std::array<bool,2> output_mask) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(
    input, grid, static_cast<int64_t>(interpolation_mode));

  auto input_requires_grad = output_mask[0];
  Tensor grad_input = ([&]() {
    if (input_requires_grad) {
      return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    } else {
      return Tensor();
    }
  })();
  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  if (grid.numel() == 0 || input.numel() == 0) {
    grad_grid.zero_();
    return std::make_tuple(grad_input, grad_grid);
  }
  // If interpolation mode is Nearest, then grad_grid is not filled in the
  // loop below.
  if (interpolation_mode == GridSamplerInterpolation::Nearest) {
    grad_grid.zero_();
  }
  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_D = input.size(2);
  int64_t inp_H = input.size(3);
  int64_t inp_W = input.size(4);
  int64_t out_D = grid.size(1);
  int64_t out_H = grid.size(2);
  int64_t out_W = grid.size(3);
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sD = input.stride(2);
  int64_t inp_sH = input.stride(3);
  int64_t inp_sW = input.stride(4);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sD = grid.stride(1);
  int64_t grid_sH = grid.stride(2);
  int64_t grid_sW = grid.stride(3);
  int64_t grid_sCoor = grid.stride(4);
  int64_t gOut_sN = grad_output.stride(0);
  int64_t gOut_sC = grad_output.stride(1);
  int64_t gOut_sD = grad_output.stride(2);
  int64_t gOut_sH = grad_output.stride(3);
  int64_t gOut_sW = grad_output.stride(4);
  int64_t gInp_sN = 0;
  int64_t gInp_sC = 0;
  int64_t gInp_sD = 0;
  int64_t gInp_sH = 0;
  int64_t gInp_sW = 0;
  if (input_requires_grad) {
    gInp_sN = grad_input.stride(0);
    gInp_sC = grad_input.stride(1);
    gInp_sD = grad_input.stride(2);
    gInp_sH = grad_input.stride(3);
    gInp_sW = grad_input.stride(4);
  }
  int64_t gGrid_sN = grad_grid.stride(0);
  int64_t gGrid_sW = grad_grid.stride(3);
  const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
  const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
  const scalar_t *gOut_ptr = grad_output.const_data_ptr<scalar_t>();
  scalar_t *gInp_ptr = nullptr;
  if (input_requires_grad) {
    gInp_ptr = grad_input.mutable_data_ptr<scalar_t>();
  }
  scalar_t *gGrid_ptr = grad_grid.data_ptr<scalar_t>();
  // loop over each output pixel
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
      const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
      scalar_t *gGrid_ptr_NDHW = gGrid_ptr + n * gGrid_sN;
      for (const auto d : c10::irange(out_D)) {
        for (const auto h : c10::irange(out_H)) {
          for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) {
            // get the corresponding input x, y, z co-ordinates from grid
            const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW;
            scalar_t ix = *grid_ptr_NDHW;
            scalar_t iy = grid_ptr_NDHW[grid_sCoor];
            scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor];

            // multipliers for gradients on ix, iy, and iz
            scalar_t gix_mult, giy_mult, giz_mult;
            ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
            iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
            iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
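            // gix_mult / giy_mult / giz_mult are the derivatives of the unnormalization
            // (and padding) mapping; they convert gradients accumulated in input pixel
            // space back to normalized grid space via the chain rule below.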

            if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
              // get corner pixel values from (x, y, z)
              // for 4d, we use north-east-south-west
              // for 5d, we add top-bottom
              int64_t ix_tnw = static_cast<int64_t>(std::floor(ix));
              int64_t iy_tnw = static_cast<int64_t>(std::floor(iy));
              int64_t iz_tnw = static_cast<int64_t>(std::floor(iz));

              int64_t ix_tne = ix_tnw + 1;
              int64_t iy_tne = iy_tnw;
              int64_t iz_tne = iz_tnw;

              int64_t ix_tsw = ix_tnw;
              int64_t iy_tsw = iy_tnw + 1;
              int64_t iz_tsw = iz_tnw;

              int64_t ix_tse = ix_tnw + 1;
              int64_t iy_tse = iy_tnw + 1;
              int64_t iz_tse = iz_tnw;

              int64_t ix_bnw = ix_tnw;
              int64_t iy_bnw = iy_tnw;
              int64_t iz_bnw = iz_tnw + 1;

              int64_t ix_bne = ix_tnw + 1;
              int64_t iy_bne = iy_tnw;
              int64_t iz_bne = iz_tnw + 1;

              int64_t ix_bsw = ix_tnw;
              int64_t iy_bsw = iy_tnw + 1;
              int64_t iz_bsw = iz_tnw + 1;

              int64_t ix_bse = ix_tnw + 1;
              int64_t iy_bse = iy_tnw + 1;
              int64_t iz_bse = iz_tnw + 1;

              // get surfaces to each neighbor:
              scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
              scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
              scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
              scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
              scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
              scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
              scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
              scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
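              // gix / giy / giz accumulate d(loss)/d(ix, iy, iz): each corner's pixel value
              // times the derivative of its trilinear weight (product rule), summed over
              // channels in the loop below.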

              scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
              const scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
              const scalar_t *inp_ptr_NC = inp_ptr_N;
              scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
              // accumulate gradients into grad_input and grad_grid
              for (int64_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
                scalar_t gOut = *gOut_ptr_NCDHW;

                // calculate and set grad_input
                if (input_requires_grad) {
                  safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
                  safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
                  safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
                  safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
                  safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
                  safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
                  safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
                  safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
                }
                // calculate grad_grid
                if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
                  scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
                  gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
                  giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
                  giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
                }
                if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
                  scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
                  gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
                  giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
                  giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
                }
                if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
                  scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
                  gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
                  giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
                  giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
                }
                if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
                  scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
                  gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
                  giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
                  giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
                }
                if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
                  scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
                  gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
                  giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
                  giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
                }
                if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
                  scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
                  gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
                  giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
                  giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
                }
                if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
                  scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
                  gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
                  giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
                  giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
                }
                if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
                  scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
                  gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
                  giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
                  giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
                }
              }

              // assuming grad_grid is contiguous
              gGrid_ptr_NDHW[0] = gix_mult * gix;
              gGrid_ptr_NDHW[1] = giy_mult * giy;
              gGrid_ptr_NDHW[2] = giz_mult * giz;
            } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
              int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
              int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
              int64_t iz_nearest = static_cast<int64_t>(std::nearbyint(iz));

              // scatter the output gradient to the nearest input pixel
              const scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
              if (input_requires_grad) {
                scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
                for (int64_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
                  // calculate and set grad_input
                  safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest,
                              gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW);
                }
              }
            }
          }
        }
      }
    }
  });
  return std::make_tuple(grad_input, grad_grid);
}

} // namespace

static Tensor _grid_sampler_2d_cpu_quantized(
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode_,
    int64_t padding_mode_,
    bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  auto interpolation_mode =
      static_cast<GridSamplerInterpolation>(interpolation_mode_);
  /* Bilinear interpolation is supported using the fact that we can perform
   * linear interpolations on quantized values without rescaling. */
  TORCH_CHECK(
      interpolation_mode == GridSamplerInterpolation::Bilinear,
      "_grid_sampler_2d_cpu_quantized(): only bilinear interpolation supported");
  auto padding_mode = static_cast<GridSamplerPadding>(padding_mode_);

  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_H = input.size(2);
  int64_t inp_W = input.size(3);
  int64_t out_H = grid.size(1);
  int64_t out_W = grid.size(2);
  uint8_t zero_point = input.q_zero_point();
  auto output = at::_empty_affine_quantized(
      {N, C, out_H, out_W},
      at::device(c10::kCPU).dtype(c10::kQUInt8),
      input.q_scale(),
      zero_point);
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sH = input.stride(2);
  int64_t inp_sW = input.stride(3);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sH = grid.stride(1);
  int64_t grid_sW = grid.stride(2);
  int64_t grid_sCoor = grid.stride(3);
  int64_t out_sN = output.stride(0);
  int64_t out_sC = output.stride(1);
  int64_t out_sH = output.stride(2);
  int64_t out_sW = output.stride(3);
  uint8_t* inp_ptr = (uint8_t*)input.data_ptr<quint8>();
  uint8_t* out_ptr = (uint8_t*)output.data_ptr<quint8>();
  float* grid_ptr = grid.data_ptr<float>();
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      float* grid_ptr_N = grid_ptr + n * grid_sN;
      uint8_t* inp_ptr_N = inp_ptr + n * inp_sN;
      for (const auto h : c10::irange(out_H)) {
        for (const auto w : c10::irange(out_W)) {
          // get the corresponding input x, y co-ordinates from grid
          float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
          float x = *grid_ptr_NHW;
          float y = grid_ptr_NHW[grid_sCoor];

          float ix = grid_sampler_compute_source_index(
              x, inp_W, padding_mode, align_corners);
          float iy = grid_sampler_compute_source_index(
              y, inp_H, padding_mode, align_corners);

          // get corner pixel values from (x, y)
          // for 4d, we use north-east-south-west
          int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
          int64_t iy_nw = static_cast<int64_t>(std::floor(iy));

          int64_t ix_ne = ix_nw + 1;
          int64_t iy_ne = iy_nw;

          int64_t ix_sw = ix_nw;
          int64_t iy_sw = iy_nw + 1;

          int64_t ix_se = ix_nw + 1;
          int64_t iy_se = iy_nw + 1;

          // get surfaces to each neighbor:
          float nw = (ix_se - ix) * (iy_se - iy);
          float ne = (ix - ix_sw) * (iy_sw - iy);
          float sw = (ix_ne - ix) * (iy - iy_ne);
          float se = (ix - ix_nw) * (iy - iy_nw);

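          // Out-of-bounds taps contribute zero_point, the quantized encoding of real 0,
          // so zero padding matches the float kernels; the float accumulator is rounded
          // back to the quantized domain with std::nearbyint below.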
          // calculate bilinear weighted pixel value and set output pixel
          uint8_t* inp_ptr_NC = inp_ptr_N;
          uint8_t* out_ptr_NCHW =
              out_ptr + n * out_sN + h * out_sH + w * out_sW;
          for (int64_t c = 0; c < C;
               ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
            float res = 0;
            res += within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)
                ? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw
                : zero_point * nw;
            res += within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)
                ? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne
                : zero_point * ne;
            res += within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)
                ? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw
                : zero_point * sw;
            res += within_bounds_2d(iy_se, ix_se, inp_H, inp_W)
                ? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se
                : zero_point * se;
            *out_ptr_NCHW = std::nearbyint(res);
          }
        }
      }
    }
  });
  return output;
}

Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid,
                                     int64_t interpolation_mode_,
                                     int64_t padding_mode_,
                                     bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  auto interpolation_mode = static_cast<GridSamplerInterpolation>(interpolation_mode_);
  auto padding_mode = static_cast<GridSamplerPadding>(padding_mode_);
  using scalar_t = float;

  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_H = input.size(2);
  int64_t inp_W = input.size(3);
  int64_t out_H = grid.size(1);
  int64_t out_W = grid.size(2);
  auto output = at::empty({N, C, out_H, out_W}, input.options());
  if (output.numel() == 0) {
    return output;
  }
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sH = input.stride(2);
  int64_t inp_sW = input.stride(3);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sH = grid.stride(1);
  int64_t grid_sW = grid.stride(2);
  int64_t grid_sCoor = grid.stride(3);
  int64_t out_sN = output.stride(0);
  int64_t out_sC = output.stride(1);
  int64_t out_sH = output.stride(2);
  int64_t out_sW = output.stride(3);
  const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
  scalar_t *out_ptr = output.data_ptr<scalar_t>();
  const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
  // loop over each output pixel
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
      const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
      for (const auto h : c10::irange(out_H)) {
        for (const auto w : c10::irange(out_W)) {
          // get the corresponding input x, y co-ordinates from grid
          const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
          scalar_t x = *grid_ptr_NHW;
          scalar_t y = grid_ptr_NHW[grid_sCoor];

          scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
          scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);

          if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
            // get corner pixel values from (x, y)
            // for 4d, we use north-east-south-west
            int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
            int64_t iy_nw = static_cast<int64_t>(std::floor(iy));

            int64_t ix_ne = ix_nw + 1;
            int64_t iy_ne = iy_nw;

            int64_t ix_sw = ix_nw;
            int64_t iy_sw = iy_nw + 1;

            int64_t ix_se = ix_nw + 1;
            int64_t iy_se = iy_nw + 1;

            // get surfaces to each neighbor:
            scalar_t nw = (ix_se - ix) * (iy_se - iy);
            scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
            scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
            scalar_t se = (ix - ix_nw) * (iy - iy_nw);

            // calculate bilinear weighted pixel value and set output pixel
            const scalar_t *inp_ptr_NC = inp_ptr_N;
            scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
            for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
              auto res = static_cast<scalar_t>(0);
              if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
              }
              if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
              }
              if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
              }
              if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
                res += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
              }
              *out_ptr_NCHW = res;
            }
          } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
            int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
            int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));

            // assign nearest neighbour pixel value to output pixel
            scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
            const scalar_t *inp_ptr_NC = inp_ptr_N;
            for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
              if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
                *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
              } else {
                *out_ptr_NCHW = static_cast<scalar_t>(0);
              }
            }
          } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
            // grid_sampler_compute_source_index would clip the index according to the
            // padding mode, which is wrong for bicubic: e.g. x = -0.1 -> ix = 0 under
            // zero padding, whereas bicubic needs ix = floor(x) = -1. Reflection padding
            // is even trickier because the direction of the -1/+1 neighbours is not fixed
            // at the boundary. So unnormalize without clipping here and let
            // get_value_bounded apply the padding per sample.
            ix = grid_sampler_unnormalize(x, inp_W, align_corners);
            iy = grid_sampler_unnormalize(y, inp_H, align_corners);

            scalar_t ix_nw = std::floor(ix);
            scalar_t iy_nw = std::floor(iy);

            const scalar_t tx = ix - ix_nw;
            const scalar_t ty = iy - iy_nw;
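            // Bicubic samples a 4x4 neighbourhood around (ix_nw, iy_nw): four cubic
            // interpolations along x, then one along y over those results.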

            const scalar_t *inp_ptr_NC = inp_ptr_N;
            scalar_t *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
            for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
              scalar_t coefficients[4];

              // Interpolate 4 values in the x direction
              for (const auto i : c10::irange(4)) {
                coefficients[i] = cubic_interp1d<scalar_t>(
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
                  tx);
              }

              // Interpolate in the y direction
              *out_ptr_NCHW = cubic_interp1d<scalar_t>(
                coefficients[0],
                coefficients[1],
                coefficients[2],
                coefficients[3],
                ty);
            }
          }
        }
      }
    }
  });
  return output;
}

std::tuple<Tensor, Tensor>
_grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output,
                                       const Tensor& input, const Tensor& grid,
                                       int64_t interpolation_mode_,
                                       int64_t padding_mode_,
                                       bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  const auto interpolation_mode = static_cast<GridSamplerInterpolation>(interpolation_mode_);
  const auto padding_mode = static_cast<GridSamplerPadding>(padding_mode_);
  using scalar_t = float;

  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  if (grid.numel() == 0 || input.numel() == 0) {
    grad_grid.zero_();
    return std::make_tuple(grad_input, grad_grid);
  }
  // If interpolation mode is Nearest, then grad_grid is not filled in the
  // loop below.
  if (interpolation_mode == GridSamplerInterpolation::Nearest) {
    grad_grid.zero_();
  }
  int64_t N = input.size(0);
  int64_t C = input.size(1);
  int64_t inp_H = input.size(2);
  int64_t inp_W = input.size(3);
  int64_t out_H = grid.size(1);
  int64_t out_W = grid.size(2);
  int64_t inp_sN = input.stride(0);
  int64_t inp_sC = input.stride(1);
  int64_t inp_sH = input.stride(2);
  int64_t inp_sW = input.stride(3);
  int64_t grid_sN = grid.stride(0);
  int64_t grid_sH = grid.stride(1);
  int64_t grid_sW = grid.stride(2);
  int64_t grid_sCoor = grid.stride(3);
  int64_t gOut_sN = grad_output.stride(0);
  int64_t gOut_sC = grad_output.stride(1);
  int64_t gOut_sH = grad_output.stride(2);
  int64_t gOut_sW = grad_output.stride(3);
  int64_t gInp_sN = grad_input.stride(0);
  int64_t gInp_sC = grad_input.stride(1);
  int64_t gInp_sH = grad_input.stride(2);
  int64_t gInp_sW = grad_input.stride(3);
  int64_t gGrid_sN = grad_grid.stride(0);
  int64_t gGrid_sW = grad_grid.stride(2);
  const scalar_t *inp_ptr = input.const_data_ptr<scalar_t>();
  const scalar_t *grid_ptr = grid.const_data_ptr<scalar_t>();
  const scalar_t *gOut_ptr = grad_output.const_data_ptr<scalar_t>();
  scalar_t *gInp_ptr = grad_input.mutable_data_ptr<scalar_t>();
  scalar_t *gGrid_ptr = grad_grid.data_ptr<scalar_t>();
  // loop over each output pixel
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
    for (const auto n : c10::irange(start, end)) {
      const scalar_t *grid_ptr_N = grid_ptr + n * grid_sN;
      const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN;
      scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN;
      for (const auto h : c10::irange(out_H)) {
        for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) {
          // get the corresponding input x, y co-ordinates from grid
          const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
          scalar_t x = *grid_ptr_NHW;
          scalar_t y = grid_ptr_NHW[grid_sCoor];

          // multipliers for gradients on ix, iy
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          scalar_t gix_mult, giy_mult;
          scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
          scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);

          if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
            // get corner pixel values from (x, y)
            // for 4d, we use north-east-south-west
            int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
            int64_t iy_nw = static_cast<int64_t>(std::floor(iy));

            int64_t ix_ne = ix_nw + 1;
            int64_t iy_ne = iy_nw;

            int64_t ix_sw = ix_nw;
            int64_t iy_sw = iy_nw + 1;

            int64_t ix_se = ix_nw + 1;
            int64_t iy_se = iy_nw + 1;

            // get surfaces to each neighbor:
            scalar_t nw = (ix_se - ix) * (iy_se - iy);
            scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
            scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
            scalar_t se = (ix - ix_nw) * (iy - iy_nw);

            scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
            const scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
            scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
            const scalar_t *inp_ptr_NC = inp_ptr_N;
            // accumulate gradients into grad_input and grad_grid
            for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
              scalar_t gOut = *gOut_ptr_NCHW;

              // calculate and set grad_input
              safe_add_2d(gInp_ptr_NC, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw * gOut);
              safe_add_2d(gInp_ptr_NC, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne * gOut);
              safe_add_2d(gInp_ptr_NC, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw * gOut);
              safe_add_2d(gInp_ptr_NC, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se * gOut);

              // calculate grad_grid
              if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
                scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
                gix -= nw_val * (iy_se - iy) * gOut;
                giy -= nw_val * (ix_se - ix) * gOut;
              }
              if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
                scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
                gix += ne_val * (iy_sw - iy) * gOut;
                giy -= ne_val * (ix - ix_sw) * gOut;
              }
              if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
                scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
                gix -= sw_val * (iy - iy_ne) * gOut;
                giy += sw_val * (ix_ne - ix) * gOut;
              }
              if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
                scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
                gix += se_val * (iy - iy_nw) * gOut;
                giy += se_val * (ix - ix_nw) * gOut;
              }
            }

            // assuming grad_grid is contiguous
            gGrid_ptr_NHW[0] = gix_mult * gix;
            gGrid_ptr_NHW[1] = giy_mult * giy;
          } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
            int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
            int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));

            // scatter the output gradient to the nearest input pixel
            const scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
            scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
            for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC) {
              // calculate and set grad_input
              safe_add_2d(gInp_ptr_NC, iy_nearest, ix_nearest, gInp_sH, gInp_sW,
                          inp_H, inp_W, *gOut_ptr_NCHW);
            }
          } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {

            ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gix_mult);
            iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &giy_mult);

            scalar_t ix_nw = std::floor(ix);
            scalar_t iy_nw = std::floor(iy);

            const scalar_t tx = ix - ix_nw;
            const scalar_t ty = iy - iy_nw;

            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t x_coeffs[4];
            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t y_coeffs[4];
            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t x_coeffs_grad[4];
            // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
            scalar_t y_coeffs_grad[4];

            get_cubic_upsample_coefficients<scalar_t>(x_coeffs, tx);
            get_cubic_upsample_coefficients<scalar_t>(y_coeffs, ty);
            get_cubic_coefficients_grad<scalar_t>(x_coeffs_grad, tx);
            get_cubic_coefficients_grad<scalar_t>(y_coeffs_grad, ty);
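            // x_coeffs / y_coeffs are the cubic interpolation weights for the 4x4 window;
            // *_coeffs_grad are their derivatives w.r.t. tx / ty, used below to build the
            // grid gradient.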

            scalar_t gix = static_cast<scalar_t>(0);
            scalar_t giy = static_cast<scalar_t>(0);

            const scalar_t *gOut_ptr_NCHW = gOut_ptr + n * gOut_sN + h * gOut_sH + w * gOut_sW;
            scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN;
            const scalar_t *inp_ptr_NC = inp_ptr_N;

            for (int64_t c = 0; c < C; ++c, gOut_ptr_NCHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
              scalar_t gOut = *gOut_ptr_NCHW;

              for (const auto i : c10::irange(4)) {
                for (const auto j : c10::irange(4)) {

                  // set input gradient
                  add_value_bounded<scalar_t>(gInp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
                    inp_W, inp_H, gInp_sW, gInp_sH, gOut * x_coeffs[i] * y_coeffs[j], padding_mode, align_corners);

                  // set grid gradient
                  scalar_t val = get_value_bounded<scalar_t>(inp_ptr_NC, ix_nw - 1 + i, iy_nw - 1 + j,
                    inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners);

                  gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut;
                  giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut;
                }
              }
            }
            gGrid_ptr_NHW[0] = gix_mult * gix;
            gGrid_ptr_NHW[1] = giy_mult * giy;
          }
        }
      }
    }
  });
  return std::make_tuple(grad_input, grad_grid);
}

Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid,
                           int64_t interpolation_mode, int64_t padding_mode,
                           bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  if (input.scalar_type() == kQUInt8) {
    return native::_grid_sampler_2d_cpu_quantized(
        input, grid, interpolation_mode, padding_mode, align_corners);
  }
  // AVX gather instructions use signed 32-bit offsets to gather float values.
  // Check for possible overflow and fall back to the scalar implementation.
  if (input.scalar_type() != kDouble) {
    TORCH_CHECK(input.scalar_type() == kFloat,
                "grid_sampler_2d_cpu not implemented for ", input.scalar_type());
    auto sizes = input.sizes();
    auto strides = input.strides();
    const auto grid_sW = grid.strides()[2];
    // NOTE: Gather offsets are only used for the input H, W dimensions
    //       or only for strided access to the grid tensor
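    // The first term below is the largest element offset reachable within one (n, c)
    // plane of the input; the second is the largest grid offset touched when gathering
    // one SIMD vector of x-coordinates.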
    auto max_gather_offset = std::max(
      (sizes[2] - 1) * strides[2] + (sizes[3] - 1) * strides[3],
      grid_sW * (vec::Vectorized<float>::size() - 1));

    if (max_gather_offset > std::numeric_limits<int32_t>::max()) {
      return native::_grid_sampler_2d_cpu_fallback(
          input, grid, interpolation_mode, padding_mode, align_corners);
    }
  }

  auto in_size = input.sizes();
  auto grid_size = grid.sizes();
  auto output = at::empty(
      {in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
  grid_sampler_2d_cpu_kernel(
      kCPU, output, input, grid, interpolation_mode, padding_mode, align_corners);
  return output;
}

DEFINE_DISPATCH(grid_sampler_2d_cpu_kernel);


Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid,
                           int64_t interpolation_mode, int64_t padding_mode,
                           bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(input, grid, interpolation_mode);

  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler3d_cpu", [&] {
    return grid_sampler_3d_cpu_impl<scalar_t>(
      input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode),
      static_cast<GridSamplerPadding>(padding_mode), align_corners);
  });
}

std::tuple<Tensor, Tensor>
grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
                             int64_t interpolation_mode, int64_t padding_mode, bool align_corners,
                             std::array<bool,2> output_mask) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_2d(input, grid);

  // AVX gather instructions use signed 32-bit offsets to gather float values.
  // Check for possible overflow and fall back to the scalar implementation.
  if (input.scalar_type() != kDouble) {
    TORCH_CHECK(input.scalar_type() == kFloat,
                "grid_sampler_2d_backward_cpu not implemented for ", input.scalar_type());
    auto isizes = input.sizes();
    auto istrides = input.strides();
    auto gsizes = grad_output.sizes();
    auto gstrides = grad_output.strides();
    const auto grid_sW = grid.strides()[2];
    // NOTE: Gather offsets are only used for the height and width dimensions
    auto max_gather_offset = std::max(
      std::max(
        (isizes[2] - 1) * istrides[2] + (isizes[3] - 1) * istrides[3],
        (gsizes[2] - 1) * gstrides[2] + (gsizes[3] - 1) * gstrides[3]),
      grid_sW * (vec::Vectorized<float>::size() - 1));

    if (max_gather_offset > std::numeric_limits<int32_t>::max()) {
      return native::_grid_sampler_2d_cpu_fallback_backward(
          grad_output, input, grid, interpolation_mode, padding_mode, align_corners);
    }
  }

  auto input_requires_grad = output_mask[0];
  Tensor grad_input = ([&]() {
    if (input_requires_grad) {
      return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    } else {
      return Tensor();
    }
  })();
  auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  grid_sampler_2d_backward_cpu_kernel(
      kCPU, grad_input, grad_grid, grad_output, input, grid,
      interpolation_mode, padding_mode, align_corners, output_mask);
  return std::make_tuple(std::move(grad_input), std::move(grad_grid));
}

DEFINE_DISPATCH(grid_sampler_2d_backward_cpu_kernel);

std::tuple<Tensor, Tensor>
grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid,
                             int64_t interpolation_mode, int64_t padding_mode, bool align_corners,
                             std::array<bool,2> output_mask) {
  // See NOTE [ grid_sampler Native Functions ].
  // Add checks here in case this is called instead of grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(input, grid, interpolation_mode);

  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] {
    return grid_sampler_3d_backward_cpu_impl<scalar_t>(
      grad_output, input, grid,
      static_cast<GridSamplerInterpolation>(interpolation_mode),
      static_cast<GridSamplerPadding>(padding_mode),
      align_corners, output_mask);
  });
}

// See NOTE [ grid_sampler Native Functions ].
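// A minimal usage sketch (shapes purely illustrative):
//   auto input = at::rand({1, 2, 4, 4, 4});          // (N, C, D, H, W)
//   auto grid  = at::rand({1, 3, 3, 3, 3}) * 2 - 1;  // (N, D_out, H_out, W_out, 3), values in [-1, 1]
//   auto out   = at::grid_sampler(input, grid,
//                                 /*interpolation_mode=*/0,  // 0: bilinear, 1: nearest, 2: bicubic (2-D only)
//                                 /*padding_mode=*/0,        // 0: zeros, 1: border, 2: reflection
//                                 /*align_corners=*/false);  // out: (1, 2, 3, 3, 3)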
Tensor grid_sampler(
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners
) {
  if (cond_cudnn_grid_sampler(input, grid) &&
      static_cast<GridSamplerInterpolation>(interpolation_mode) ==
        GridSamplerInterpolation::Bilinear &&
      static_cast<GridSamplerPadding>(padding_mode) ==
        GridSamplerPadding::Zeros &&
      align_corners) {
    return cudnn_grid_sampler(input, grid);
  }

  if (input.dim() == 4) {
    return at::grid_sampler_2d(
        input, grid, interpolation_mode, padding_mode, align_corners);
  } else {
    return at::grid_sampler_3d(
        input, grid, interpolation_mode, padding_mode, align_corners);
  }
}

} // namespace at::native
