1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/GridSampler.h>
3 #include <ATen/native/cpu/GridSamplerKernel.h>
4 #include <ATen/core/TensorBase.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/Parallel.h>
7 #include <ATen/TensorGeometry.h>
8 #include <ATen/TensorIterator.h>
9 #include <ATen/cpu/vec/vec.h>
10 #include <c10/util/irange.h>
11 
12 #include <algorithm>
13 #include <cstring>
14 
15 namespace at::native { namespace {
16 
17 /**  NOTE [ Grid Sample CPU Kernels ]
18  *
19  *   Implementation of vectorized grid sample CPU kernels is divided into three
20  *   parts. More detailed descriptions follow this paragraph, but at a high
21  *   level, they are
22  *   1. `ComputeLocation` struct
23  *      + Computes the interpolation location based on the padding mode.
24  *   2. `ApplyGridSample` struct
25  *      + Owns N (# spatial dims) `ComputeLocation` structs, and uses them to
26  *        compute the interpolation locations.
27  *      + Interpolates the values and writes to output.
28  *   3. `grid_sample_2d_grid_slice_iterator` function
29  *      + Iterates over a slice of the grid tensor based on its geometry, in
30  *        spatial order, i.e., the first iteration will process grid values
31  *           grid[n, 0, 0, :], grid[n, 0, 1, :], grid[n, 0, 2, :], ...
32  *        (Recall that, e.g., 2D grid has shape [N x H x W x 2], so grid[n, ...]
33  *         is a slice, and grid[n, h, w, :] contains the values for a single
34  *         output spatial location.)
35  *      + Applies a given operator at each iteration, so we can use the same
36  *        pattern for forward and backward.
37  *
38  *   Putting everything together, we have, e.g., the forward kernel implemented
39  *   as
40  *
41  *      // `ApplyGridSample` struct that processes grid values, extracts and
42  *      // interpolates input values, and writes to output.
43  *      ApplyGridSample<scalar_t, 2, interp, padding> grid_sample(input_accessor);
44  *
45  *      // For each slice, we call `grid_sample_2d_grid_slice_iterator` with
46  *      //   1. the grid slice, and
47  *      //   2. a lambda that takes in
48  *      //      i.   location vectors (x and y for 2D) extracted from grid
49  *      //      ii.  `spatial_offset` as the spatial offset of these vectors
50  *      //           from the beginning of this slice.
51  *      //      iii. `len` as the number of valid locations in the vectors.
52  *      //           (There might not be enough near the boundary.)
53  *      for (const auto n : c10::irange(input_accessor.size(0))) {
54  *        grid_sample_2d_grid_slice_iterator(
55  *          grid_accessor[n],
56  *          [&](const Vectorized<scalar_t>& grid_x,
57  *              const Vectorized<scalar_t>& grid_y,
58  *              int64_t spatial_offset, int64_t len) {
59  *            grid_sample.forward(out_accessor[n], input_accessor[n],
60  *                                spatial_offset, grid_x, grid_y, len);
61  *          });
62  *      }
63  *
64  *   Now we talk about details of each of these three parts:
65  *
66  *   1. `ComputeLocation` struct
67  *      Transforms grid values into interpolation locations of the input tensor
68  *      for a particular spatial dimension, based on the size of that dimension
69  *      in input tensor, and the padding mode.
70  *
71  *        template<typename scalar_t, GridSamplerPadding padding>
72  *        struct ComputeLocation {
73  *          using Vec = Vectorized<scalar_t>;
74  *
75  *          // ctor
76  *          ComputeLocation(int64_t size);
77  *
78  *          // Given grid values `in`, return the interpolation locations after
79  *          // applying the un-normalization and padding mechanism (elementwise).
80  *          Vec apply(const Vec &in) const;
81  *
82  *          // Similar to `apply`, but also returns `d apply(in) / d in`
83  *          // (elementwise).
84  *          // This is often used in gradient computation.
85  *          std::pair<Vec, Vec> apply_get_grad(const Vec &in) const;
86  *        };
87  *
88  *   2. `ApplyGridSample` struct
89  *      Owns N `ComputeLocation` structs, where N is the number of spatial
90  *      dimensions. Given N input grid vectors (one for each spatial dimension)
91  *      and spatial offset, it gets the interpolation locations from
92  *      `ComputeLocation`s, applies interpolation procedure, and then writes to
93  *      the output (or grad_input & grad_grid in backward).
94  *
95  *        template<typename scalar_t, int spatial_dim,
96  *                 GridSamplerInterpolation interp,
97  *                 GridSamplerPadding padding>
98  *        struct ApplyGridSample {
99  *
100  *          // ctor
101  *          ApplyGridSample(const TensorAccessor<scalar_t, 4>& input);
102  *
103  *          // Applies grid sampling (forward) procedure:
104  *          //   1. computes interpolation locations from grid values `grid_x`
105  *          //      and `grid_y`,
106  *          //   2. interpolates output values using the locations and input
107  *          //      data in `inp_slice`, and
108  *          //   3. writes the first `len` values in the interpolated vector to
109  *          //      `out_slice` with spatial offset being `offset`.
110  *          //
111  *          // This assumes that `grid_x` and `grid_y` both contain valid grid
112  *          // values \in [-1, 1], even at indices greater than `len`.
113  *          //
114  *          // The `*_slice` argument names mean samples within a batch (i.e.,
115  *          // with the batch dimension sliced out).
116  *          void forward(TensorAccessor<scalar_t, 3>& out_slice,
117  *                       const TensorAccessor<scalar_t, 3>& inp_slice,
118  *                       int64_t offset, const Vec& grid_x, const Vec& grid_y,
119  *                       int64_t len) const;
120  *
121  *          // Applies grid sampling (backward) procedure. Argument semantics
122  *          // and strategy are similar to those of `forward`, with the
123  *          // exception that `backward` has branches based on whether `input`
124  *          // requires gradient (passed in as a template parameter). The
125  *          // TensorAccessor for the input gradient is also given as a
126  *          // pointer instead of reference, so that it can be null if the
127  *          // gradient is not calculated.
128  *          template <bool input_requires_grad>
129  *          void backward(TensorAccessor<scalar_t, 3>* gInp_slice_ptr,
130  *                        TensorAccessor<scalar_t, 3>& gGrid_slice,
131  *                        const TensorAccessor<scalar_t, 3>& gOut_slice,
132  *                        const TensorAccessor<scalar_t, 3>& inp_slice,
133  *                        int64_t offset, const Vec& grid_x, const Vec& grid_y,
134  *                        int64_t len) const;
135  *        };
136  *
137  *   3. `grid_sample_2d_grid_slice_iterator` function
138  *      Among the tensors we work with, we know that the output tensors are
139  *      contiguous (i.e., `output` in forward, and `grad_input` & `grad_grid` in
140  *      backward), we need to randomly read `input` anyway, and `grad_output`
141  *      usually comes from autograd and is often contiguous. So we base our
142  *      iterating strategy on the geometry of the grid.
143  *      The `grid_sample_2d_grid_slice_iterator` function provides an abstraction to
144  *      efficiently iterate through a `grid` slice (without the batch dimension).
145  *      See comments of that function on the specific cases and strategies used.
146  *
147  *        template<typename scalar_t, typename ApplyFn>
148  *        void grid_sample_2d_grid_slice_iterator(
149  *          const TensorAccessor<scalar_t, 3>& grid_slice,
150  *          const ApplyFn &apply_fn);
151  *
152  *      `apply_fn` is a function/lambda that takes in
153  *           i.   location vectors (x and y for 2D) extracted from grid
154  *           ii.  `spatial_offset` as the spatial offset of these vectors
155  *                from the beginning of this slice.
156  *           iii. `len` as the number of valid locations in the vectors.
157  *                (There might not be enough near the boundary.)
158  *
159  *      It should be callable as if it had the declaration:
160  *          void apply_fn(const Vectorized<scalar_t>& grid_x,
161  *                        const Vectorized<scalar_t>& grid_y,
162  *                        int64_t spatial_offset, int64_t len);
163  *
164  *      `apply_fn` will be called multiple times, and together the calls cover
165  *      the entire output spatial space.
166  *
167  *  Now you should be able to understand everything about the implementation of
168  *  2D forward kernel shown at the beginning of this note.
169  *
170  **/
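// For intuition, here is a rough scalar sketch (illustrative only, not used by
// the kernels below) of what a single bilinear sample computes for one output
// location with grid values (gx, gy), assuming zeros padding and
// align_corners=false:
//
//   scalar_t x = ((gx + 1) * inp_W - 1) / 2, y = ((gy + 1) * inp_H - 1) / 2;
//   scalar_t x_w = std::floor(x), y_n = std::floor(y);
//   scalar_t w = x - x_w, e = 1 - w, n = y - y_n, s = 1 - n;
//   // in(r, c) reads the input at (r, c) if it is in bounds and 0 otherwise
//   out = in(y_n,     x_w) * s * e + in(y_n,     x_w + 1) * s * w
//       + in(y_n + 1, x_w) * n * e + in(y_n + 1, x_w + 1) * n * w;
//
// The vectorized code below computes Vec::size() such samples per call and
// handles the other padding modes and interpolation schemes analogously.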
171 
172 
173 using at::native::detail::GridSamplerInterpolation;
174 using at::native::detail::GridSamplerPadding;
175 using namespace at::vec;
176 
177 
178 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ComputeLocation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
179 // Struct to compute interpolation location from grid values, and to apply
180 // padding mechanism (e.g., reflection).
181 // See NOTE [ Grid Sample CPU Kernels ] for details.
182 
183 template<typename scalar_t, bool align_corners>
184 struct ComputeLocationBase;
185 
186 template<typename scalar_t>
187 struct ComputeLocationBase<scalar_t, /*align_corners=*/true> {
188   using Vec = Vectorized<scalar_t>;
189 
190   // values are clipped to between 0 and max_val
191   const scalar_t max_val;
192   // unnormalization scaling factor
193   const scalar_t scaling_factor;
194   // reflection parameters: reflected coordinates land in [low, low+span] inclusive
195   const scalar_t low; // only used when align_corners=False
196   const scalar_t twice_span;
197   // if the reflecting span is empty, all reflected coords are set to 0
198   const bool empty;
199 
200   ComputeLocationBase(int64_t size)
201     : max_val(static_cast<scalar_t>(size - 1))
202     , scaling_factor(static_cast<scalar_t>(size - 1) / 2)
203     , low(static_cast<scalar_t>(0))
204     , twice_span(static_cast<scalar_t>(size - 1) * 2)
205     , empty(size <= 1) {}
206 
207   inline Vec unnormalize(const Vec &in) const {
208     return (in + Vec(1)) * Vec(scaling_factor);
209   }
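  // e.g., with size = 5 (align_corners=true): scaling_factor = 2, so grid
  // values -1, 0, 1 map to coordinates 0, 2, 4 (= max_val) respectively.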
210 
211   inline Vec clip_coordinates(const Vec &in) const {
212     // Invert order of clamp_min operands in order to clamp NaNs to zero
213     return clamp_max(Vec(max_val), clamp_min(Vec(0), in));
214   }
215 
216   // same as clip_coordinates but also returns the gradient multiplier
217   inline std::pair<Vec, Vec> clip_coordinates_get_grad(const Vec &in) const {
218     using int_t = int_same_size_t<scalar_t>;
219     auto bounded_lo = maximum(in, Vec(0));
220     // Integral type equality comparison is very very fast because it just looks
221     // at the bits. Casting is free too. So we use the following pattern instead
222     // of comparison + blendv.
223     // Note that it is important for the gradient calculation that borders
224     // are considered out of bounds.
225     auto in_bound_lo = cast<scalar_t>(cast<int_t>(bounded_lo) != cast<int_t>(Vec(0)));
226     auto res = minimum(bounded_lo, Vec(max_val));
227     auto in_bound_hi = cast<scalar_t>(cast<int_t>(res) != cast<int_t>(Vec(max_val)));
228     return std::make_pair(res, in_bound_lo & in_bound_hi);
229   }
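  // e.g., an input of -0.3 clamps to 0, so `in_bound_lo` is all-zero for that
  // lane and the returned gradient mask zeroes its gradient, consistent with
  // treating the borders themselves as out of bounds.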
230 
231   inline Vec reflect_coordinates(const Vec &in) const {
232     if (empty) {
233       return Vec(0);
234     }
235     Vec twice_span_vec(twice_span);
236     auto abs_in = in.abs();
237     auto fdouble_flips = abs_in / twice_span_vec;
238     auto double_flips = fdouble_flips.trunc();
239     auto extra = abs_in - double_flips * twice_span_vec;
240     // Now we need to test if extra > max_val to find out if another flip is
241     // needed. The following comparison does that and returns the correct
242     // flipped value.
243     return minimum(extra, twice_span_vec - extra);
244   }
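  // e.g., with size = 5 (max_val = 4, twice_span = 8): an input of 5.5 gives
  // extra = 5.5 and min(5.5, 8 - 5.5) = 2.5, i.e., 5.5 is reflected about the
  // border 4 down to 2.5.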
245 
246   // same as reflect_coordinates but also returns the gradient multiplier
247   inline std::pair<Vec, Vec> reflect_coordinates_get_grad(const Vec &in) const {
248     if (empty) {
249       return std::make_pair(Vec(0), Vec(0));
250     }
251     Vec twice_span_vec(twice_span);
252     auto neg_in = in < Vec(0);
253     auto abs_in = in.abs();
254     auto fdouble_flips = abs_in / twice_span_vec;
255     auto double_flips = fdouble_flips.trunc();
256 
257     auto extra = abs_in - double_flips * twice_span_vec;
258     auto reflected_extra = twice_span_vec - extra;
259     auto one_more_flip = extra > reflected_extra;
260 
261     return std::make_pair(
262       Vec::blendv(extra, reflected_extra, one_more_flip),
263       Vec::blendv(Vec(1), Vec(-1), one_more_flip ^ neg_in)
264     );
265   }
266 };
267 
268 template<typename scalar_t>
269 struct ComputeLocationBase<scalar_t, /*align_corners=*/false> {
270   using Vec = Vectorized<scalar_t>;
271 
272   // values are clipped to between 0 and max_val
273   const scalar_t max_val;
274   // unnormalization scaling factor
275   const scalar_t scaling_factor;
276   // reflection parameters: reflected coordinates land in [low, low+span] inclusive
277   const scalar_t low;
278   const scalar_t twice_span;
279   // if the reflecting span is empty, all reflected coords are set to 0
280   const bool empty; // only used when align_corners=True
281 
282   ComputeLocationBase(int64_t size)
283     : max_val(static_cast<scalar_t>(size - 1))
284     , scaling_factor(static_cast<scalar_t>(size) / 2)
285     , low(static_cast<scalar_t>(-0.5))
286     , twice_span(static_cast<scalar_t>(size) * 2)
287     , empty(size <= 0) {}
288 
289   inline Vec unnormalize(const Vec &in) const {
290     return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5);
291   }
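  // e.g., with size = 4 (align_corners=false): scaling_factor = 2, so grid
  // values -1 and 1 map to -0.5 and 3.5, the outer edges of the border pixels
  // whose centers sit at 0 and 3.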
292 
293   inline Vec clip_coordinates(const Vec &in) const {
294     // Invert order of clamp_min operands in order to clamp NaNs to zero
295     return clamp_max(Vec(max_val), clamp_min(Vec(0), in));
296   }
297 
298   // same as clip_coordinates but also returns the gradient multiplier
299   inline std::pair<Vec, Vec> clip_coordinates_get_grad(const Vec &in) const {
300     using int_t = int_same_size_t<scalar_t>;
301     auto bounded_lo = maximum(in, Vec(0));
302     // Integral type equality comparison is very very fast because it just looks
303     // at the bits. Casting is free too. So we use the following pattern instead
304     // of comparison + blendv.
305     // Note that it is important for the gradient calculation that borders
306     // are considered out of bounds.
307     auto in_bound_lo = cast<scalar_t>(cast<int_t>(bounded_lo) != cast<int_t>(Vec(0)));
308     auto res = minimum(bounded_lo, Vec(max_val));
309     auto in_bound_hi = cast<scalar_t>(cast<int_t>(res) != cast<int_t>(Vec(max_val)));
310     return std::make_pair(res, in_bound_lo & in_bound_hi);
311   }
312 
313   inline Vec reflect_coordinates(const Vec &in) const {
314     Vec twice_span_vec(twice_span), low_vec(low);
315     // Since reflection is around low and low+span, subtract low before
316     // the reflection, and then add it back at the end.
317     auto abs_in = (in - low_vec).abs();
318     auto fdouble_flips = abs_in / twice_span_vec;
319     auto double_flips = fdouble_flips.trunc();
320     auto extra = abs_in - double_flips * twice_span_vec;
321     // Now we need to test if extra > max_val to find out if another flip is
322     // needed. The following comparison does that and returns the correct
323     // flipped value.
324     return minimum(extra, twice_span_vec - extra) + low_vec;
325   }
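  // e.g., with size = 4 (low = -0.5, twice_span = 8): an input of 4.5 gives
  // in - low = 5, min(5, 8 - 5) = 3, and adding low back yields 2.5, i.e.,
  // 4.5 is reflected about the high edge 3.5 down to 2.5.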
326 
327   // same as reflect_coordinates but also returns the gradient multiplier
328   inline std::pair<Vec, Vec> reflect_coordinates_get_grad(const Vec &in) const {
329     Vec twice_span_vec(twice_span), low_vec(low);
330     Vec in_minus_low = in - low_vec;
331     auto neg_in = in_minus_low < Vec(0);
332     auto abs_in = in_minus_low.abs();
333     auto fdouble_flips = abs_in / twice_span_vec;
334     auto double_flips = fdouble_flips.trunc();
335 
336     auto extra = abs_in - double_flips * twice_span_vec;
337     auto reflected_extra = twice_span_vec - extra;
338     auto one_more_flip = extra > reflected_extra;
339 
340     return std::make_pair(
341       Vec::blendv(extra, reflected_extra, one_more_flip) + low_vec,
342       Vec::blendv(Vec(1), Vec(-1), one_more_flip ^ neg_in)
343     );
344   }
345 };
346 
347 template<typename scalar_t, GridSamplerPadding padding, bool align_corners>
348 struct ComputeLocation;
349 
350 template<typename scalar_t, bool align_corners>
351 struct ComputeLocation<scalar_t, GridSamplerPadding::Zeros, align_corners>
352   : ComputeLocationBase<scalar_t, align_corners> {
353   using Vec = Vectorized<scalar_t>;
354   using ComputeLocationBase<scalar_t, align_corners>::unnormalize;
355   using ComputeLocationBase<scalar_t, align_corners>::scaling_factor;
356 
357   using ComputeLocationBase<scalar_t, align_corners>::ComputeLocationBase;
358 
359   inline Vec apply(const Vec &in) const {
360     return unnormalize(in);
361   }
362 
363   inline Vec compute_coordinates(const Vec &in) const {
364     return in;
365   }
366 
367   inline std::pair<Vec, Vec> apply_get_grad(const Vec &in) const {
368     return std::make_pair(unnormalize(in), Vec(scaling_factor));
369   }
370 };
371 
372 template<typename scalar_t, bool align_corners>
373 struct ComputeLocation<scalar_t, GridSamplerPadding::Border, align_corners>
374   : ComputeLocationBase<scalar_t, align_corners> {
375   using Vec = Vectorized<scalar_t>;
376   using ComputeLocationBase<scalar_t, align_corners>::unnormalize;
377   using ComputeLocationBase<scalar_t, align_corners>::clip_coordinates;
378   using ComputeLocationBase<scalar_t, align_corners>::clip_coordinates_get_grad;
379   using ComputeLocationBase<scalar_t, align_corners>::scaling_factor;
380 
381   using ComputeLocationBase<scalar_t, align_corners>::ComputeLocationBase;
382 
383   inline Vec apply(const Vec &in) const {
384     return clip_coordinates(unnormalize(in));
385   }
386 
387   inline Vec compute_coordinates(const Vec &in) const {
388     return clip_coordinates(in);
389   }
390 
391   inline std::pair<Vec, Vec> apply_get_grad(const Vec &in) const {
392     auto [res, grad_clip] = clip_coordinates_get_grad(unnormalize(in));
393     return std::make_pair(res, grad_clip & Vec(scaling_factor));
394   }
395 };
396 
397 template<typename scalar_t, bool align_corners>
398 struct ComputeLocation<scalar_t, GridSamplerPadding::Reflection, align_corners>
399   : ComputeLocationBase<scalar_t, align_corners> {
400   using Vec = Vectorized<scalar_t>;
401   using ComputeLocationBase<scalar_t, align_corners>::unnormalize;
402   using ComputeLocationBase<scalar_t, align_corners>::clip_coordinates;
403   using ComputeLocationBase<scalar_t, align_corners>::clip_coordinates_get_grad;
404   using ComputeLocationBase<scalar_t, align_corners>::reflect_coordinates;
405   using ComputeLocationBase<scalar_t, align_corners>::reflect_coordinates_get_grad;
406   using ComputeLocationBase<scalar_t, align_corners>::scaling_factor;
407 
408   using ComputeLocationBase<scalar_t, align_corners>::ComputeLocationBase;
409 
410   inline Vec apply(const Vec &in) const {
411     auto res = reflect_coordinates(unnormalize(in));
412     res = clip_coordinates(res);
413     return res;
414   }
415 
416   inline Vec compute_coordinates(const Vec &in) const {
417     auto res = reflect_coordinates(in);
418     res = clip_coordinates(res);
419     return res;
420   }
421 
422   inline std::pair<Vec, Vec> apply_get_grad(const Vec &in) const {
423     auto [res, grad_refl] = reflect_coordinates_get_grad(unnormalize(in));
424     Vec grad(scaling_factor);
425     grad = grad_refl * grad;
426     auto [res2, grad_clip] = clip_coordinates_get_grad(res);
427     grad = grad_clip & grad;
428     return std::make_pair(res2, grad);
429   }
430 };
431 
432 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ApplyGridSample ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
433 // Struct to apply grid sample (reading from input, interpolate, and write to
434 // output).
435 // See NOTE [ Grid Sample CPU Kernels ] for details.
436 
437 template<typename scalar_t>
438 static inline void
439 mask_scatter_add(const scalar_t *src, scalar_t* base_addr,
440                  const int_same_size_t<scalar_t> *offsets,
441                  const int_same_size_t<scalar_t> *mask, int64_t len) {
442   #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
443   # pragma unroll
444   #endif
445   for (const auto i : c10::irange(len)) {
446     if (mask[i] & 0x01) {
447       base_addr[offsets[i]] += src[i];
448     }
449   }
450 }
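// e.g., with len = 3 and mask = {-1, 0, -1} (the all-ones / zero lanes produced
// by the vector comparisons used below), only src[0] and src[2] are accumulated,
// into base_addr[offsets[0]] and base_addr[offsets[2]] respectively.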
451 
452 template<typename scalar_t, int spatial_dim,
453          GridSamplerInterpolation interp,
454          GridSamplerPadding padding,
455          bool align_corners>
456 struct ApplyGridSample;
457 
458 template<typename scalar_t, GridSamplerPadding padding, bool align_corners>
459 struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bilinear,
460                        padding, align_corners> {
461   using Vec = Vectorized<scalar_t>;
462   using integer_t = int_same_size_t<scalar_t>;
463   using iVec = Vectorized<integer_t>;
464 
465   const int64_t inp_H;
466   const int64_t inp_W;
467   const int64_t inp_sH;
468   const int64_t inp_sW;
469   const int64_t C;
470   const int64_t inp_sC;
471   const ComputeLocation<scalar_t, padding, align_corners> compute_H;
472   const ComputeLocation<scalar_t, padding, align_corners> compute_W;
473   const bool must_in_bound = padding != GridSamplerPadding::Zeros;
474 
475   ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
476     : inp_H(input.size(2))
477     , inp_W(input.size(3))
478     , inp_sH(input.stride(2))
479     , inp_sW(input.stride(3))
480     , C(input.size(1))
481     , inp_sC(input.stride(1))
482     , compute_H(input.size(2))
483     , compute_W(input.size(3)) {}
484 
485   inline std::tuple<
486     Vec, Vec, Vec, Vec,       // distances to 4 sides
487     Vec, Vec, Vec, Vec,       // interpolation weights wrt 4 corners
488     Vec, Vec, Vec, Vec,       // in_bound masks
489     iVec, iVec                // y_n and x_w
490   >
491   compute_interp_params(const Vec& x, const Vec& y) const {
492     // get NE, NW, SE, SW pixel values from (x, y),
493     // assuming we get exact integer representation and just use scalar_t.
494     // If we don't, the weights will be garbage anyway.
495     auto x_w = x.floor();
496     auto y_n = y.floor();
497 
498     // get distances to each side
499     auto w = x - x_w;
500     auto e = Vec(1) - w;
501     auto n = y - y_n;
502     auto s = Vec(1) - n;
503 
504     // get interpolation weights for each neighbor
505     // e.g., for the nw corner, the weight is `dist_to_south * dist_to_east`.
506     auto nw = s * e;
507     auto ne = s * w;
508     auto sw = n * e;
509     auto se = n * w;
510 
511     auto i_x_w = convert_to_int_of_same_size(x_w);
512     auto i_y_n = convert_to_int_of_same_size(y_n);
513     auto i_x_e = i_x_w + iVec(1);
514     auto i_y_s = i_y_n + iVec(1);
515 
516     // Use int comparison because it is much faster than float comparison with AVX2
517     // (latency 1 cyc vs. 4 cyc on Skylake).
518     // Avoid using le and ge because those are not implemented in AVX2 and
519     // are actually emulated using multiple instructions.
520     auto w_mask = must_in_bound ? iVec(-1)  // true = all ones
521                                 : (i_x_w > iVec(-1)) & (i_x_w < iVec(inp_W));
522     auto n_mask = must_in_bound ? iVec(-1)  // true = all ones
523                                 : (i_y_n > iVec(-1)) & (i_y_n < iVec(inp_H));
524     auto e_mask = must_in_bound ? (i_x_e < iVec(inp_W))
525                                 : (i_x_e > iVec(-1)) & (i_x_e < iVec(inp_W));
526     auto s_mask = must_in_bound ? (i_y_s < iVec(inp_H))
527                                 : (i_y_s > iVec(-1)) & (i_y_s < iVec(inp_H));
528     auto nw_mask = cast<scalar_t>(must_in_bound ? iVec(-1) : (w_mask & n_mask));
529     auto ne_mask = cast<scalar_t>(e_mask & n_mask);
530     auto sw_mask = cast<scalar_t>(w_mask & s_mask);
531     auto se_mask = cast<scalar_t>(e_mask & s_mask);
532 
533     return std::make_tuple(
534       n, s, w, e,
535       nw, ne, sw, se,
536       nw_mask, ne_mask, sw_mask, se_mask,
537       i_y_n, i_x_w);
538   }
539 
540   inline void forward(TensorAccessor<scalar_t, 3>& out_slice,
541                       const TensorAccessor<const scalar_t, 3>& inp_slice,
542                       int64_t offset, const Vec& grid_x, const Vec& grid_y,
543                       int64_t len) const {
544     auto x = compute_W.apply(grid_x);
545     auto y = compute_H.apply(grid_y);
546 
547     auto interp_params = compute_interp_params(x, y);
548 
549     auto nw = std::get<4>(interp_params);
550     auto ne = std::get<5>(interp_params);
551     auto sw = std::get<6>(interp_params);
552     auto se = std::get<7>(interp_params);
553 
554     auto nw_mask = std::get<8>(interp_params);
555     auto ne_mask = std::get<9>(interp_params);
556     auto sw_mask = std::get<10>(interp_params);
557     auto se_mask = std::get<11>(interp_params);
558 
559     auto i_y_n = std::get<12>(interp_params);
560     auto i_x_w = std::get<13>(interp_params);
561 
562     auto i_nw_offset = i_y_n * iVec(inp_sH) + i_x_w * iVec(inp_sW);
563     auto i_ne_offset = i_nw_offset + iVec(inp_sW);
564     auto i_sw_offset = i_nw_offset + iVec(inp_sH);
565     auto i_se_offset = i_sw_offset + iVec(inp_sW);
566 
567     #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
568     # pragma unroll
569     #endif
570     for (const auto c : c10::irange(C)) {
571       auto inp_slice_C_ptr = inp_slice[c].data();
572 
573       // mask_gather zeros out the mask, so we need to make copies
574       Vec nw_mask_copy = nw_mask;
575       Vec ne_mask_copy = ne_mask;
576       Vec sw_mask_copy = sw_mask;
577       Vec se_mask_copy = se_mask;
578       auto nw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_nw_offset, nw_mask_copy);
579       auto ne_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_ne_offset, ne_mask_copy);
580       auto sw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_sw_offset, sw_mask_copy);
581       auto se_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_se_offset, se_mask_copy);
582 
583       auto interpolated = (nw_val * nw) + (ne_val * ne) + (sw_val * sw) + (se_val * se);
584       interpolated.store(out_slice[c].data() + offset, len);
585     }
586   }
587 
588   template<bool input_requires_grad>
589   inline void backward(TensorAccessor<scalar_t, 3>* gInp_slice_ptr,
590                        TensorAccessor<scalar_t, 3>& gGrid_slice,
591                        const TensorAccessor<const scalar_t, 3>& gOut_slice,
592                        const TensorAccessor<const scalar_t, 3>& inp_slice,
593                        int64_t offset, const Vec& grid_x, const Vec& grid_y,
594                        int64_t len) const {
595     auto [x, gx_mult] = compute_W.apply_get_grad(grid_x);
596     auto [y, gy_mult] = compute_H.apply_get_grad(grid_y);
597 
598     auto [
599       n, s, w, e, nw, ne, sw, se, nw_mask, ne_mask, sw_mask, se_mask,
600       i_y_n, i_x_w] = compute_interp_params(x, y);
601 
602     auto i_nw_offset = i_y_n * iVec(inp_sH) + i_x_w * iVec(inp_sW);
603     auto i_ne_offset = i_nw_offset + iVec(inp_sW);
604     auto i_sw_offset = i_nw_offset + iVec(inp_sH);
605     auto i_se_offset = i_sw_offset + iVec(inp_sW);
606 
607     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
608     integer_t i_nw_mask_arr[iVec::size()];
609     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
610     integer_t i_ne_mask_arr[iVec::size()];
611     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
612     integer_t i_sw_mask_arr[iVec::size()];
613     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
614     integer_t i_se_mask_arr[iVec::size()];
615     nw_mask.store(i_nw_mask_arr);
616     ne_mask.store(i_ne_mask_arr);
617     sw_mask.store(i_sw_mask_arr);
618     se_mask.store(i_se_mask_arr);
619 
620     // i_gInp_*_offset_arr and gInp_corner_arr variables below are unnecessary
621     // when input_requires_grad is false (they are only used within the
622     // if-blocks), but are required to keep the code well-formed.
623 
624     // When reading input values, we used mask_gather. Unfortunately, there is
625     // no mask_scatter_add (the backward of mask_gather) in Intel intrinsics.
626     // So we store the necessary vectors to temporary arrays and use the helper
627     // mask_scatter_add defined above.
628 
629     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
630     integer_t i_gInp_nw_offset_arr[iVec::size()];
631     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
632     integer_t i_gInp_ne_offset_arr[iVec::size()];
633     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
634     integer_t i_gInp_sw_offset_arr[iVec::size()];
635     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
636     integer_t i_gInp_se_offset_arr[iVec::size()];
637     if (input_requires_grad) {
638       auto i_gInp_nw_offset = i_y_n * iVec(inp_W) + i_x_w;
639       auto i_gInp_ne_offset = i_gInp_nw_offset + iVec(1);
640       auto i_gInp_sw_offset = i_gInp_nw_offset + iVec(inp_W);
641       auto i_gInp_se_offset = i_gInp_sw_offset + iVec(1);
642 
643       i_gInp_nw_offset.store(i_gInp_nw_offset_arr);
644       i_gInp_ne_offset.store(i_gInp_ne_offset_arr);
645       i_gInp_sw_offset.store(i_gInp_sw_offset_arr);
646       i_gInp_se_offset.store(i_gInp_se_offset_arr);
647     }
648 
649     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
650     scalar_t gInp_corner_arr[Vec::size()];
651 
652     auto gx = Vec(0), gy = Vec(0);
653     #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
654     # pragma unroll
655     #endif
656     for (const auto c : c10::irange(C)) {
657       auto inp_slice_C_ptr = inp_slice[c].data();
658       auto gOut = Vec::loadu(gOut_slice[c].data() + offset, len);
659 
660       if (input_requires_grad) {
661         TORCH_INTERNAL_ASSERT(gInp_slice_ptr);
662         auto gInp_slice_C_ptr = (*gInp_slice_ptr)[c].data();
663 
664         (nw * gOut).store(gInp_corner_arr);
665         mask_scatter_add(gInp_corner_arr, gInp_slice_C_ptr, i_gInp_nw_offset_arr, i_nw_mask_arr, len);
666         (ne * gOut).store(gInp_corner_arr);
667         mask_scatter_add(gInp_corner_arr, gInp_slice_C_ptr, i_gInp_ne_offset_arr, i_ne_mask_arr, len);
668         (sw * gOut).store(gInp_corner_arr);
669         mask_scatter_add(gInp_corner_arr, gInp_slice_C_ptr, i_gInp_sw_offset_arr, i_sw_mask_arr, len);
670         (se * gOut).store(gInp_corner_arr);
671         mask_scatter_add(gInp_corner_arr, gInp_slice_C_ptr, i_gInp_se_offset_arr, i_se_mask_arr, len);
672       }
673 
674       // mask_gather zeros out the mask, so we need to make copies
675       Vec nw_mask_copy = nw_mask;
676       Vec ne_mask_copy = ne_mask;
677       Vec sw_mask_copy = sw_mask;
678       Vec se_mask_copy = se_mask;
679       auto nw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_nw_offset, nw_mask_copy);
680       auto ne_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_ne_offset, ne_mask_copy);
681       auto sw_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_sw_offset, sw_mask_copy);
682       auto se_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_C_ptr, i_se_offset, se_mask_copy);
683 
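      // Analytic gradient of the bilinear interpolant w.r.t. x and y: with
      // w = x - x_w (so dw/dx = 1 and de/dx = -1), the derivative of
      // nw_val*s*e + ne_val*s*w + sw_val*n*e + se_val*n*w w.r.t. x collapses
      // to (ne_val - nw_val)*s + (se_val - sw_val)*n, and symmetrically for y.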
684       gx = gx + ((ne_val - nw_val) * s + (se_val - sw_val) * n) * gOut;
685       gy = gy + ((sw_val - nw_val) * e + (se_val - ne_val) * w) * gOut;
686     }
687 
688     gx = gx * gx_mult;
689     gy = gy * gy_mult;
690 
691     constexpr int64_t step = Vec::size();
692     auto interleaved_gGrid = interleave2(gx, gy);
693     auto gGrid_ptr = gGrid_slice.data() + offset * 2;
694     std::get<0>(interleaved_gGrid).store(gGrid_ptr,
695                                          std::min(len * 2, step));
696     std::get<1>(interleaved_gGrid).store(gGrid_ptr + step,
697                                          std::max(static_cast<int64_t>(0), len * 2 - step));
698   }
699 };
700 
701 template<typename scalar_t, GridSamplerPadding padding, bool align_corners>
702 struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Nearest,
703                        padding, align_corners> {
704   using Vec = Vectorized<scalar_t>;
705   using integer_t = int_same_size_t<scalar_t>;
706   using iVec = Vectorized<integer_t>;
707 
708   const int64_t inp_H;
709   const int64_t inp_W;
710   const int64_t inp_sH;
711   const int64_t inp_sW;
712   const int64_t C;
713   const int64_t inp_sC;
714   const ComputeLocation<scalar_t, padding, align_corners> compute_H;
715   const ComputeLocation<scalar_t, padding, align_corners> compute_W;
716   const bool must_in_bound = padding != GridSamplerPadding::Zeros;
717 
718   ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
719     : inp_H(input.size(2))
720     , inp_W(input.size(3))
721     , inp_sH(input.stride(2))
722     , inp_sW(input.stride(3))
723     , C(input.size(1))
724     , inp_sC(input.stride(1))
725     , compute_H(input.size(2))
726     , compute_W(input.size(3)) {}
727 
728   inline void forward(TensorAccessor<scalar_t, 3>& out_slice,
729                       const TensorAccessor<const scalar_t, 3>& inp_slice,
730                       int64_t offset, const Vec& grid_x, const Vec& grid_y,
731                       int64_t len) const {
732     auto x = compute_W.apply(grid_x);
733     auto y = compute_H.apply(grid_y);
734 
735     auto x_nearest = x.round();
736     auto y_nearest = y.round();
737 
738     auto i_x_nearest = convert_to_int_of_same_size(x_nearest);
739     auto i_y_nearest = convert_to_int_of_same_size(y_nearest);
740 
741     auto i_mask = must_in_bound ? iVec(-1)
742                                 : (i_x_nearest > iVec(-1)) & (i_x_nearest < iVec(inp_W)) &
743                                   (i_y_nearest > iVec(-1)) & (i_y_nearest < iVec(inp_H));
744     auto mask = cast<scalar_t>(i_mask);
745 
746     auto i_offset = i_y_nearest * iVec(inp_sH) + i_x_nearest * iVec(inp_sW);
747 
748     auto out_ptr = out_slice.data() + offset;
749     auto out_sC = out_slice.stride(0);
750     auto inp_slice_ptr = inp_slice.data();
751     #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
752     # pragma unroll
753     #endif
754     for (int64_t c = 0; c < C; ++c, out_ptr += out_sC, inp_slice_ptr += inp_sC) {
755       // mask_gather zeros out the mask, so we need to make a copy
756       auto mask_copy = mask;
757       auto inp_val = mask_gather<sizeof(scalar_t)>(Vec(0), inp_slice_ptr, i_offset, mask_copy);
758       inp_val.store(static_cast<void*>(out_ptr), len);
759     }
760   }
761 
762   template<bool input_requires_grad>
763   inline void backward(TensorAccessor<scalar_t, 3>* gInp_slice_ptr,
764                        TensorAccessor<scalar_t, 3>& gGrid_slice,
765                        const TensorAccessor<const scalar_t, 3>& gOut_slice,
766                        const TensorAccessor<const scalar_t, 3>& /*inp_slice*/,
767                        int64_t offset, const Vec& grid_x, const Vec& grid_y,
768                        int64_t len) const {
769     if (input_requires_grad) {
770       auto x = compute_W.apply(grid_x);
771       auto y = compute_H.apply(grid_y);
772 
773       auto x_nearest = x.round();
774       auto y_nearest = y.round();
775 
776       auto i_x_nearest = convert_to_int_of_same_size(x_nearest);
777       auto i_y_nearest = convert_to_int_of_same_size(y_nearest);
778 
779       auto i_mask = must_in_bound ? iVec(-1)
780                                   : (i_x_nearest > iVec(-1)) & (i_x_nearest < iVec(inp_W)) &
781                                     (i_y_nearest > iVec(-1)) & (i_y_nearest < iVec(inp_H));
782 
783       auto i_gInp_offset = i_y_nearest * iVec(inp_W) + i_x_nearest;  // gInp is contiguous
784 
785       // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
786       integer_t mask_arr[iVec::size()];
787       i_mask.store(mask_arr);
788       // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
789       integer_t gInp_offset_arr[iVec::size()];
790       i_gInp_offset.store(gInp_offset_arr);
791 
792       #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
793       # pragma unroll
794       #endif
795       for (const auto c : c10::irange(C)) {
796         mask_scatter_add(gOut_slice[c].data() + offset, (*gInp_slice_ptr)[c].data(),
797                         gInp_offset_arr, mask_arr, len);
798       }
799     }
800 
801     // grid has zero gradient in Nearest mode (sampling is piecewise constant in the grid values)
802     auto gGrid_ptr = gGrid_slice.data() + offset * 2;
803     std::memset(gGrid_ptr, 0, sizeof(scalar_t) * len * 2);
804   }
805 };
806 
807 // Use bicubic convolution algorithm. Based on
808 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
809 template<typename scalar_t, GridSamplerPadding padding, bool align_corners>
810 struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
811                        padding, align_corners> {
812   using Vec = Vectorized<scalar_t>;
813   using integer_t = int_same_size_t<scalar_t>;
814   using iVec = Vectorized<integer_t>;
815 
816   const int64_t inp_H;
817   const int64_t inp_W;
818   const int64_t inp_sH;
819   const int64_t inp_sW;
820   const int64_t C;
821   const int64_t inp_sC;
822   const ComputeLocation<scalar_t, padding, align_corners> compute_H;
823   const ComputeLocation<scalar_t, padding, align_corners> compute_W;
824   const bool must_in_bound = padding != GridSamplerPadding::Zeros;
825 
826   // constant used in cubic convolution
827   // could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h
828   const Vec A = Vec(-0.75);
829 
830   ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
831     : inp_H(input.size(2))
832     , inp_W(input.size(3))
833     , inp_sH(input.stride(2))
834     , inp_sW(input.stride(3))
835     , C(input.size(1))
836     , inp_sC(input.stride(1))
837     , compute_H(input.size(2))
838     , compute_W(input.size(3)) {}
839 
840   // Calculate the cubic convolution coefficient
841   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
842   inline void get_cubic_coefficients(Vec (&coeffs)[4], const Vec& tx) const {
843     Vec x;
844     x = tx + Vec(1);  // 1 < x = |-1 - tx| < 2
845     coeffs[0] = ((A * x - Vec(5) * A) * x + Vec(8) * A) * x - Vec(4) * A;
846     x = tx;           // x = |0 - tx| <= 1
847     coeffs[1] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1);
848     x = Vec(1) - tx;  // x = |1 - tx| <= 1
849     coeffs[2] = ((A + Vec(2)) * x - (A + Vec(3))) * x * x + Vec(1);
850     x = Vec(2) - tx;  // 1 < x = |2 - tx| < 2
851     coeffs[3] = ((A * x - Vec(5) * A) * x + Vec(8) * A) * x - Vec(4) * A;
852   }
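  // Sanity check (with A = -0.75): at tx = 0 the coefficients are {0, 1, 0, 0},
  // and for any tx in [0, 1] they sum to 1, so the interpolation reproduces
  // constant signals exactly.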
853 
854   // Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
855   // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
856   inline void get_cubic_coefficients_grad(Vec (&coeffs)[4], const Vec& tx) const {
857     Vec x;
858     x = Vec(-1) - tx; // 1 < x = |-1 - tx| < 2
859     coeffs[0] = (Vec(-3) * A * x - Vec(10) * A ) * x - Vec(8) * A;
860     x = Vec(0) - tx;  // x = |0 - tx| <= 1
861     coeffs[1] = (Vec(-3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x;
862     x = Vec(1) - tx;  // x = |1 - tx| <= 1
863     coeffs[2] = (Vec(3) * (A + Vec(2)) * x - Vec(2) * (A + Vec(3))) * x;
864     x = Vec(2) - tx;  // 1 < x = |2 - tx| < 2
865     coeffs[3] = (Vec(3) * A * x - Vec(10) * A) * x + Vec(8) * A;
866   }
867 
868   inline Vec get_value_bounded(const scalar_t* data, const Vec& x, const Vec& y) const {
869     auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x));
870     auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y));
871 
872     auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W));
873     auto mask_y = must_in_bound ? iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H));
874     auto mask = cast<scalar_t>(mask_x & mask_y);
875 
876     auto offset = iy * iVec(inp_sH) + ix * iVec(inp_sW);
877 
878     auto val = mask_gather<sizeof(scalar_t)>(Vec(0), data, offset, mask);
879     return val;
880   }
881 
882   inline void add_value_bounded(scalar_t* data, int64_t len, const Vec& x, const Vec& y,
883                                const Vec& delta) const {
884 
885     auto ix = convert_to_int_of_same_size(compute_W.compute_coordinates(x));
886     auto iy = convert_to_int_of_same_size(compute_H.compute_coordinates(y));
887 
888     auto mask_x = must_in_bound ? iVec(-1) : (ix > iVec(-1)) & (ix < iVec(inp_W));
889     auto mask_y = must_in_bound ? iVec(-1) : (iy > iVec(-1)) & (iy < iVec(inp_H));
890     auto mask = cast<scalar_t>(mask_x & mask_y);
891 
892     auto i_gInp_offset = iy * iVec(inp_W) + ix;
893     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
894     integer_t i_gInp_offset_arr[iVec::size()];
895     i_gInp_offset.store(i_gInp_offset_arr);
896 
897     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
898     integer_t mask_arr[iVec::size()];
899     mask.store(mask_arr);
900 
901     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
902     scalar_t gInp_corner_arr[Vec::size()];
903     delta.store(gInp_corner_arr);
904 
905     mask_scatter_add(gInp_corner_arr, data, i_gInp_offset_arr, mask_arr, len);
906   }
907 
908   inline void forward(TensorAccessor<scalar_t, 3>& out_slice,
909                       const TensorAccessor<const scalar_t, 3>& inp_slice,
910                       int64_t offset, const Vec& grid_x, const Vec& grid_y,
911                       int64_t len) const {
912 
913     auto x = compute_W.unnormalize(grid_x);
914     auto y = compute_H.unnormalize(grid_y);
915 
916     auto ix = x.floor();
917     auto iy = y.floor();
918 
919     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
920     Vec coeff_x[4];
921     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
922     Vec coeff_y[4];
923     get_cubic_coefficients(coeff_x, x - ix);
924     get_cubic_coefficients(coeff_y, y - iy);
925 
926     #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
927     # pragma unroll
928     #endif
929     for (const auto c : c10::irange(C)) {
930       auto inp_slice_C_ptr = inp_slice[c].data();
931 
932       // Interpolate the 4 values in the x direction
933       // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
934       Vec interp_x[4];
935       for (const auto i : c10::irange(4)) {
936         interp_x[i] =
937           coeff_x[0] * get_value_bounded(inp_slice_C_ptr, ix - Vec(1), iy + Vec(-1 + i)) +
938           coeff_x[1] * get_value_bounded(inp_slice_C_ptr, ix + Vec(0), iy + Vec(-1 + i)) +
939           coeff_x[2] * get_value_bounded(inp_slice_C_ptr, ix + Vec(1), iy + Vec(-1 + i)) +
940           coeff_x[3] * get_value_bounded(inp_slice_C_ptr, ix + Vec(2), iy + Vec(-1 + i));
941       }
942 
943       // Interpolate the 4 values in the y direction
944       auto interpolated = coeff_y[0] * interp_x[0] + coeff_y[1] * interp_x[1] +
945                           coeff_y[2] * interp_x[2] + coeff_y[3] * interp_x[3];
946       interpolated.store(out_slice[c].data() + offset, len);
947     }
948   }
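  // Overall, the forward pass above evaluates the separable 4x4 stencil
  //   out = sum_{r, c in 0..3} coeff_y[r] * coeff_x[c] * in(iy - 1 + r, ix - 1 + c)
  // with out-of-bounds taps resolved by get_value_bounded according to the
  // padding mode.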
949 
950   template<bool input_requires_grad>
951   inline void backward(TensorAccessor<scalar_t, 3>* gInp_slice_ptr,
952                       TensorAccessor<scalar_t, 3>& gGrid_slice,
953                       const TensorAccessor<const scalar_t, 3>& gOut_slice,
954                       const TensorAccessor<const scalar_t, 3>& inp_slice,
955                       int64_t offset, const Vec& grid_x, const Vec& grid_y,
956                       int64_t len) const {
957     Vec x = compute_W.unnormalize(grid_x);
958     Vec y = compute_H.unnormalize(grid_y);
959     Vec gx_mult = Vec(compute_W.scaling_factor);
960     Vec gy_mult = Vec(compute_H.scaling_factor);
961 
962     auto ix = x.floor();
963     auto iy = y.floor();
964 
965     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
966     Vec coeff_x[4];
967     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
968     Vec coeff_y[4];
969     get_cubic_coefficients(coeff_x, x - ix);
970     get_cubic_coefficients(coeff_y, y - iy);
971 
972     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
973     Vec coeff_x_grad[4];
974     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
975     Vec coeff_y_grad[4];
976     get_cubic_coefficients_grad(coeff_x_grad, x - ix);
977     get_cubic_coefficients_grad(coeff_y_grad, y - iy);
978 
979     auto gx = Vec(0), gy = Vec(0);
980     #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
981     # pragma unroll
982     #endif
983     for (const auto c : c10::irange(C)) {
984       auto inp_slice_C_ptr = inp_slice[c].data();
985       auto gOut = Vec::loadu(gOut_slice[c].data() + offset, len);
986 
987       for (const auto i : c10::irange(4)) {
988         for (const auto j : c10::irange(4)) {
989           auto xx = ix + Vec(-1 + i);
990           auto yy = iy + Vec(-1 + j);
991 
992           if (input_requires_grad) {
993             auto gInp_slice_C_ptr = (*gInp_slice_ptr)[c].data();
994             add_value_bounded(gInp_slice_C_ptr, len, xx, yy, gOut * coeff_x[i] * coeff_y[j]);
995           }
996 
997           auto val = get_value_bounded(inp_slice_C_ptr, xx, yy);
998           gx = gx - val * gOut * coeff_x_grad[i] * coeff_y[j];
999           gy = gy - val * gOut * coeff_y_grad[j] * coeff_x[i];
1000         }
1001       }
1002     }
1003 
1004     gx = gx * gx_mult;
1005     gy = gy * gy_mult;
1006 
1007     constexpr int64_t step = Vec::size();
1008     auto interleaved_gGrid = interleave2(gx, gy);
1009     auto gGrid_ptr = gGrid_slice.data() + offset * 2;
1010     std::get<0>(interleaved_gGrid).store(gGrid_ptr,
1011                                          std::min(len * 2, step));
1012     std::get<1>(interleaved_gGrid).store(gGrid_ptr + step,
1013                                          std::max(static_cast<int64_t>(0), len * 2 - step));
1014   }
1015 };
1016 
1017 // ~~~~~~~~~~~~~~~~~~ grid_sample_2d_grid_slice_iterator ~~~~~~~~~~~~~~~~~~~~~~
1018 // Function to apply a vectorized function on a grid slice tensor (without batch
1019 // dimension).
1020 // See NOTE [ Grid Sample CPU Kernels ] for details.
1021 
1022 template<typename scalar_t, typename ApplyFn>
1023 static inline void grid_sample_2d_grid_slice_iterator(
1024     const TensorAccessor<const scalar_t, 3>& grid_slice, const ApplyFn &apply_fn) {
1025   int64_t out_H = grid_slice.size(0);
1026   int64_t out_W = grid_slice.size(1);
1027   int64_t grid_sH = grid_slice.stride(0);
1028   int64_t grid_sW = grid_slice.stride(1);
1029   int64_t grid_sCoor = grid_slice.stride(2);
1030   auto grid_ptr = grid_slice.data();
1031 
1032   using Vec = Vectorized<scalar_t>;
1033   using iVec = Vectorized<int_same_size_t<scalar_t>>;
1034   constexpr int64_t step = Vec::size();
1035 
1036   // Loop over each output pixel in grid.
1037   // We consider the following three cases (after slicing out the batch
1038   // dimension).
1039   // See detailed discussions under each if-case.
1040 
1041   if (at::geometry_is_contiguous({out_H, out_W, 2}, {grid_sH, grid_sW, grid_sCoor})) {
1042     // Case 1:
1043     // Grid is contiguous.
1044     // Strategy: Sequentially load two vectors at the same time, and get,
1045     //           e.g.,  {x0, y0, x1, y1}, {x2, y2, x3, y3}. Then we use
1046     //           at::vec::deinterleave2 to get x and y vectors.
1047     auto total_size = out_H * out_W;
1048     for (int64_t spatial_offset = 0; spatial_offset < total_size; spatial_offset += step) {
1049       auto grid_offset = spatial_offset * 2;
1050       auto len = std::min(step, total_size - spatial_offset);
1051       auto vec1 = Vec::loadu(grid_ptr + grid_offset,
1052                              std::min(step, len * 2));
1053       auto vec2 = Vec::loadu(grid_ptr + grid_offset + step,
1054                              std::max(static_cast<int64_t>(0), len * 2 - step));
1055       auto vec_xy_pair = deinterleave2(vec1, vec2);
1056 
1057       auto x = std::get<0>(vec_xy_pair);
1058       auto y = std::get<1>(vec_xy_pair);
1059 
1060       // make sure that x and y are valid grid sample locations
1061       if (len < step) {
1062         x = Vec::set(Vec(0), x, len);
1063         y = Vec::set(Vec(0), y, len);
1064       }
1065       apply_fn(x, y, spatial_offset, len);
1066     }
1067   } else if (grid_sW == 1 || out_W == 1) {
1068     // Case 2:
1069     // The W dimension is contiguous.
1070     // This can be common, e.g., grid is from a conv net output of shape
1071     // [N, 2, H, W].
1072     // Strategy: Divide into two contiguous slices, each of shape [H, W], one
1073     //           containing the x values and the other the y values. Then we
1074     //           sequentially load a vector from each of them to get the x and y vectors.
1075 
1076     // Function to apply along a contiguous W dimension (or flattened H x W).
1077     auto line_fn = [&](const scalar_t *grid_ptr_x, const scalar_t *grid_ptr_y,
1078                        int64_t out_base_offset, int64_t total_size) {
1079       for (int64_t i = 0; i < total_size; i += step) {
1080         auto len = std::min(step, total_size - i);
1081         auto x = Vec::loadu(grid_ptr_x + i, len);
1082         auto y = Vec::loadu(grid_ptr_y + i, len);
1083         // make sure that x and y are valid grid sample locations
1084         if (len < step) {
1085           x = Vec::set(Vec(0), x, len);
1086           y = Vec::set(Vec(0), y, len);
1087         }
1088         apply_fn(x, y, out_base_offset + i, len);
1089       }
1090     };
1091 
1092     if (at::geometry_is_contiguous({out_H, out_W}, {grid_sH, grid_sW})) {
1093       // If [H, W] is contiguous, apply line_fn once.
1094       line_fn(grid_ptr, grid_ptr + grid_sCoor, 0, out_H * out_W);
1095     } else {
1096       // If only [W] is contiguous, apply line_fn once for each h slice.
1097       auto grid_ptr_NH = grid_ptr;
1098       for (const auto h : c10::irange(out_H)) {
1099         line_fn(grid_ptr_NH, grid_ptr_NH + grid_sCoor, h * out_W, out_W);
1100         grid_ptr_NH += grid_sH;
1101       }
1102     }
1103   } else {
1104     // Case 3:
1105     // General case.
1106     // Strategy: Do a for-loop over H, for each W slice, use
1107     //           at::vec::gather to load the x and y vectors.
1108     int64_t spatial_offset = 0;
1109     const int64_t i_offset_delta = grid_sW * step;
1110 
1111     #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
1112     # pragma unroll
1113     #endif
1114     for (const auto h : c10::irange(out_H)) {
1115       auto grid_ptr_x = grid_ptr + h * grid_sH;
1116       auto grid_ptr_y = grid_ptr_x + grid_sCoor;
1117       auto i_offsets = iVec::arange(0, grid_sW);
1118       #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE)
1119       # pragma unroll
1120       #endif
1121       for (int64_t w = 0; w < out_W; w += step) {
1122         auto len = std::min(step, out_W - w);
1123         if (len < step) {
1124           // prevents illegal memory access, sets the exceeding offsets to zero
1125           i_offsets = iVec::set(iVec(0), i_offsets, len);
1126         }
1127         apply_fn(vec::gather<sizeof(scalar_t)>(grid_ptr_x, i_offsets),
1128                  vec::gather<sizeof(scalar_t)>(grid_ptr_y, i_offsets),
1129                  spatial_offset, len);
1130 
1131         grid_ptr_x += i_offset_delta;
1132         grid_ptr_y += i_offset_delta;
1133         spatial_offset += len;
1134       }
1135     }
1136   }
1137 }
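// As a rough guide: a freshly created grid (e.g., the output of affine_grid,
// which is contiguous with shape [N, H, W, 2]) takes Case 1 above; a grid
// obtained by permuting a [N, 2, H, W] network output typically takes Case 2;
// anything else falls back to the gather-based Case 3.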
1138 
1139 // ~~~~~~~~~~~~~~~~~~~~~~~~~ Grid Sample Kernels ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1140 // Use the structs & functions defined above to calculate grid sample forward
1141 // and backward.
1142 // See NOTE [ Grid Sample CPU Kernels ] for details.
1143 
1144 void grid_sampler_2d_cpu_kernel_impl(
1145     const TensorBase &output, const TensorBase &input, const TensorBase &grid,
1146     int64_t interpolation_mode, int64_t padding_mode, bool align_corners) {
1147   auto N = input.size(0);
1148   auto H = grid.size(1);
1149   auto W = grid.size(2);
1150   auto spatial_size = H * W;
1151   auto grain_size = spatial_size == 0 ? (N + 1)
1152                                       : at::divup(at::internal::GRAIN_SIZE, spatial_size * 4 /* 2d * 2 tensors*/);
1153   if (output.numel() == 0) {
1154     return;
1155   }
1156 
1157 #define HANDLE_CASE(interp, padding, align_corners)                            \
1158   case padding: {                                                              \
1159     ApplyGridSample<scalar_t, 2, interp, padding, align_corners>               \
1160     grid_sample(inp_acc);                                                      \
1161     parallel_for(0, N, grain_size, [&](int64_t begin, int64_t end) {           \
1162       for (const auto n : c10::irange(begin, end)) {                           \
1163         auto out_slice = out_acc[n];                                           \
1164         auto inp_slice = inp_acc[n];                                           \
1165         grid_sample_2d_grid_slice_iterator(                                    \
1166           grid_acc[n],                                                         \
1167           [&](const Vectorized<scalar_t>& grid_x, const Vectorized<scalar_t>& grid_y,  \
1168               int64_t spatial_offset, int64_t len) {                           \
1169             grid_sample.forward(out_slice, inp_slice, spatial_offset,          \
1170                                 grid_x, grid_y, len);                          \
1171           });                                                                  \
1172       }                                                                        \
1173       });                                                                      \
1174     return;                                                                    \
1175   }
1176 
1177 #define HANDLE_INTERP(interp, align_corners)                                   \
1178   case interp: {                                                               \
1179     switch (static_cast<GridSamplerPadding>(padding_mode)) {                   \
1180       HANDLE_CASE(interp, GridSamplerPadding::Zeros, align_corners);           \
1181       HANDLE_CASE(interp, GridSamplerPadding::Border, align_corners);          \
1182       HANDLE_CASE(interp, GridSamplerPadding::Reflection, align_corners);      \
1183     }                                                                          \
1184     return;                                                                    \
1185   }
1186 
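  // Together, the two macros above expand into a switch over padding mode nested
  // inside a switch over interpolation mode, so every combination instantiates
  // ApplyGridSample<scalar_t, 2, interp, padding, align_corners> with all of its
  // template parameters known at compile time before running the parallel slice loop.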
1187   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_cpu_kernel_impl", [&] {
1188     auto out_acc = output.accessor<scalar_t, 4>();
1189     auto inp_acc = input.accessor<const scalar_t, 4>();
1190     auto grid_acc = grid.accessor<const scalar_t, 4>();
1191     if (align_corners) {
1192       switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
1193         HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true);
1194         HANDLE_INTERP(GridSamplerInterpolation::Nearest, true);
1195         HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true);
1196       }
1197     } else {
1198       switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
1199         HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false);
1200         HANDLE_INTERP(GridSamplerInterpolation::Nearest, false);
1201         HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false);
1202       }
1203     }
1204   });
1205 #undef HANDLE_CASE
1206 #undef HANDLE_INTERP
1207 }
1208 
1209 void grid_sampler_2d_backward_cpu_kernel_impl(
1210     const TensorBase &grad_input,
1211     const TensorBase &grad_grid,
1212     const TensorBase &grad_output_,
1213     const TensorBase &input,
1214     const TensorBase &grid,
1215     int64_t interpolation_mode,
1216     int64_t padding_mode,
1217     bool align_corners,
1218     std::array<bool,2> output_mask) {
1219   if (grad_output_.numel() == 0) {
1220     grad_grid.zero_();
1221     return;
1222   }
1223   // grad_output should be contiguous most of the time. Ensuring that it is
1224   // contiguous can greatly simplify this code.
1225   auto grad_output = grad_output_.contiguous();
1226 
1227   // If `input` gradient is not required, we skip computing it -- not needing to create
1228   // the tensor to hold the gradient can markedly increase performance. (`grid` gradient
1229   // is always computed.)
1230   auto input_requires_grad = output_mask[0];
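  // (When this flag is false, grad_input may not even be a defined tensor, which is
  // why its accessor is only taken inside the `if (input_requires_grad)` branch of
  // the dispatch below.)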
1231 
1232   auto N = input.size(0);
1233   auto spatial_size = grid.size(1) * grid.size(2);
1234   auto grain_size = spatial_size == 0 ? (N + 1)
1235                                       : at::divup(at::internal::GRAIN_SIZE, spatial_size * 10 /* 2d * 5 tensors*/);
1236 
1237 #define GINP_SLICE_PTR_true auto gInp_slice = gInp_acc[n]; auto gInp_slice_ptr = &gInp_slice;
1238 #define GINP_SLICE_PTR_false TensorAccessor<scalar_t, 3>* gInp_slice_ptr = nullptr;
1239 #define GINP_SLICE_PTR(input_requires_grad) GINP_SLICE_PTR_##input_requires_grad
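// GINP_SLICE_PTR(input_requires_grad) relies on `input_requires_grad` being a literal
// true/false token at the HANDLE_CASE expansion site: token pasting selects one of the
// two helpers above, so the gradient-input slice (and its pointer) is only materialized
// when the input gradient is requested, matching the backward<input_requires_grad>(...)
// instantiation inside HANDLE_CASE.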
1240 
1241 #define HANDLE_CASE(interp, padding, align_corners, input_requires_grad)         \
1242   case padding: {                                                                \
1243     ApplyGridSample<scalar_t, 2, interp, padding, align_corners>                 \
1244     grid_sample(inp_acc);                                                        \
1245     parallel_for(0, N, grain_size, [&](int64_t begin, int64_t end) {             \
1246       for (const auto n : c10::irange(begin, end)) {                             \
1247         GINP_SLICE_PTR(input_requires_grad)                                      \
1248         auto gGrid_slice = gGrid_acc[n];                                         \
1249         auto gOut_slice = gOut_acc[n];                                           \
1250         auto inp_slice = inp_acc[n];                                             \
1251         grid_sample_2d_grid_slice_iterator(                                      \
1252           grid_acc[n],                                                           \
1253           [&](const Vectorized<scalar_t>& grid_x, const Vectorized<scalar_t>& grid_y,    \
1254               int64_t spatial_offset, int64_t len) {                             \
1255             grid_sample.backward<input_requires_grad>(gInp_slice_ptr, gGrid_slice,       \
1256                                                       gOut_slice, inp_slice,     \
1257                                                       spatial_offset, grid_x,    \
1258                                                       grid_y, len);              \
1259           });                                                                    \
1260       }                                                                          \
1261     });                                                                          \
1262     return;                                                                      \
1263   }
1264 
1265 #define HANDLE_INTERP(interp, align_corners, input_requires_grad)           \
1266   case interp: {                                                            \
1267     switch (static_cast<GridSamplerPadding>(padding_mode)) {                \
1268       HANDLE_CASE(interp, GridSamplerPadding::Zeros, align_corners, input_requires_grad);      \
1269       HANDLE_CASE(interp, GridSamplerPadding::Border, align_corners, input_requires_grad);     \
1270       HANDLE_CASE(interp, GridSamplerPadding::Reflection, align_corners, input_requires_grad); \
1271     }                                                                       \
1272     return;                                                                 \
1273   }
1274 
1275   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_2d_backward_cpu_kernel_impl", [&] {
1276     auto gGrid_acc = grad_grid.accessor<scalar_t, 4>();
1277     auto inp_acc = input.accessor<const scalar_t, 4>();
1278     auto grid_acc = grid.accessor<const scalar_t, 4>();
1279     auto gOut_acc = grad_output.accessor<const scalar_t, 4>();
1280     if (input_requires_grad) {
1281       auto gInp_acc = grad_input.accessor<scalar_t, 4>();
1282       if (align_corners) {
1283         switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
1284           HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true, true);
1285           HANDLE_INTERP(GridSamplerInterpolation::Nearest, true, true);
1286           HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true, true);
1287         }
1288       } else {
1289         switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
1290           HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false, true);
1291           HANDLE_INTERP(GridSamplerInterpolation::Nearest, false, true);
1292           HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false, true);
1293         }
1294       }
1295     } else {
1296       if (align_corners) {
1297         switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
1298           HANDLE_INTERP(GridSamplerInterpolation::Bilinear, true, false);
1299           HANDLE_INTERP(GridSamplerInterpolation::Nearest, true, false);
1300           HANDLE_INTERP(GridSamplerInterpolation::Bicubic, true, false);
1301         }
1302       } else {
1303         switch (static_cast<GridSamplerInterpolation>(interpolation_mode)) {
1304           HANDLE_INTERP(GridSamplerInterpolation::Bilinear, false, false);
1305           HANDLE_INTERP(GridSamplerInterpolation::Nearest, false, false);
1306           HANDLE_INTERP(GridSamplerInterpolation::Bicubic, false, false);
1307         }
1308       }
1309 
1310     }
1311   });
1312 #undef HANDLE_CASE
1313 #undef HANDLE_INTERP
1314 }
1315 
1316 }  // namespace
1317 
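// REGISTER_DISPATCH below binds the two CPU implementations above to their dispatch
// stubs (declared in the corresponding GridSamplerKernel.h header); ATen's cpu/
// kernels are compiled once per supported CPU capability, and the stub picks the
// matching variant at runtime.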
1318 REGISTER_DISPATCH(grid_sampler_2d_cpu_kernel, &grid_sampler_2d_cpu_kernel_impl);
1319 REGISTER_DISPATCH(grid_sampler_2d_backward_cpu_kernel, &grid_sampler_2d_backward_cpu_kernel_impl);
1320 
1321 
1322 }  // namespace at::native
1323