xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/GridSampler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <algorithm>
4 #include <cmath>
5 #include <cstdint>
6 #include <utility>
7 
8 #include <ATen/native/GridSamplerUtils.h>
9 
10 namespace at::native {
11 
12 using detail::GridSamplerInterpolation;
13 using detail::GridSamplerPadding;
14 
// Maps a normalized coordinate (nominally in [-1, +1]) to a pixel index,
// treating pixel `idx` as the area spanning (idx - 0.5) to (idx + 0.5).
//
//   align_corners == true : -1 and +1 land on the centers of the corner pixels
//       -1 --> 0
//       +1 --> size - 1
//       (scale factor is (size - 1) / 2)
//   align_corners == false: -1 and +1 land on the outer edges of the image
//       -1 --> -0.5
//       +1 --> size - 0.5
//       (scale factor is size / 2)
//
// Inputs outside [-1, +1] are not clamped; they extrapolate linearly.
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size,
                                                bool align_corners) {
  if (!align_corners) {
    // [-1, 1] --> [-0.5, size - 0.5]
    return ((coord + 1) * size - 1) / 2;
  }
  // [-1, 1] --> [0, size - 1]
  return ((coord + 1) / 2) * (size - 1);
}
36 
// Same mapping as grid_sampler_unnormalize, but additionally writes the
// derivative `d output / d input` of that mapping to `*grad_in`:
//   align_corners == true : (size - 1) / 2
//   align_corners == false: size / 2
// Used by the backward pass of grid_sampler.
template <typename scalar_t>
static inline scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int64_t size,
                                                         bool align_corners, scalar_t *grad_in) {
  if (!align_corners) {
    // [-1, 1] --> [-0.5, size - 0.5]
    *grad_in = static_cast<scalar_t>(size) / 2;
    return ((coord + 1) * size - 1) / 2;
  }
  // [-1, 1] --> [0, size - 1]
  *grad_in = static_cast<scalar_t>(size - 1) / 2;
  return ((coord + 1) / 2) * (size - 1);
}
54 
// Clamps `in` to the closed range [0, clip_limit - 1].
// The lower bound is applied first, then the upper bound.
template<typename scalar_t>
static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
  const scalar_t lowest = static_cast<scalar_t>(0);
  const scalar_t highest = static_cast<scalar_t>(clip_limit - 1);
  const scalar_t raised = std::max(in, lowest);
  return std::min(highest, raised);
}
60 
// Clamps `in` to [0, clip_limit - 1] like clip_coordinates, and additionally
// writes `d output / d input` to `*grad_in`: 1 strictly inside the range,
// 0 otherwise.
// Used by the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t clip_coordinates_set_grad(scalar_t in, int64_t clip_limit,
                                                 scalar_t *grad_in) {
  // Note that it is important for the gradient calculation that the borders
  // themselves count as out of bounds (hence `<=` and `>=`).
  const scalar_t lowest = static_cast<scalar_t>(0);
  if (in <= lowest) {
    *grad_in = static_cast<scalar_t>(0);
    return lowest;
  }

  const scalar_t highest = static_cast<scalar_t>(clip_limit - 1);
  if (in >= highest) {
    *grad_in = static_cast<scalar_t>(0);
    return highest;
  }

  *grad_in = static_cast<scalar_t>(1);
  return in;
}
83 
// Reflects `in` back and forth between `low` and `high` until it falls
// inside [low, high] (inclusive).
// The bounds are passed as twice their value (`twice_low`, `twice_high`) so
// that half-integer bounds such as -0.5 can be represented as ints.
// When both bounds are equal the result is always 0.
template<typename scalar_t>
static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low,
                                           int64_t twice_high) {
  if (twice_low == twice_high) {
    return static_cast<scalar_t>(0);
  }
  const scalar_t low = static_cast<scalar_t>(twice_low) / 2;
  const scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  // Work with the distance from the lower bound; reflection is symmetric
  // about it, so the sign can be dropped.
  in = std::fabs(in - low);
  // `fmod` returns same sign as `in`, which is non-negative after the `fabs`
  // above.
  const scalar_t extra = std::fmod(in, span);
  // Number of whole spans crossed. This is held in 64 bits because a
  // far-out-of-range coordinate can push the quotient past INT_MAX, and
  // converting such a value to a 32-bit int is undefined behaviour (and in
  // practice yields the wrong parity).
  const int64_t flips = static_cast<int64_t>(std::floor(in / span));
  if (flips % 2 == 0) {
    // Even number of reflections: travelling away from `low`.
    return extra + low;
  }
  // Odd number of reflections: travelling back from `high`.
  return span - extra + low;
}
105 
// Reflects `in` into [low, high] like reflect_coordinates, and additionally
// writes `d output / d input` to `*grad_in`. The derivative is +1 or -1
// depending on how many times the coordinate was mirrored, or 0 when both
// bounds are equal (in which case the result is always 0).
// The bounds are passed as twice their value so that half-integer bounds can
// be represented as ints.
// Used by the backward pass of grid_sampler.
template<typename scalar_t>
static inline scalar_t reflect_coordinates_set_grad(scalar_t in, int64_t twice_low,
                                                    int64_t twice_high, scalar_t *grad_in) {
  if (twice_low == twice_high) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  }
  const scalar_t low = static_cast<scalar_t>(twice_low) / 2;
  const scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = in - low;
  // Taking the absolute value is itself a reflection about `low`; remember
  // its sign so it can be folded into the gradient.
  const int grad_in_mult = (in < static_cast<scalar_t>(0)) ? -1 : 1;
  if (grad_in_mult < 0) {
    in = -in;
  }
  // `fmod` returns same sign as `in`, which is non-negative after the
  // negation above.
  const scalar_t extra = std::fmod(in, span);
  // Number of whole spans crossed. This is held in 64 bits because a
  // far-out-of-range coordinate can push the quotient past INT_MAX, and
  // converting such a value to a 32-bit int is undefined behaviour (and in
  // practice yields the wrong parity and therefore the wrong gradient sign).
  const int64_t flips = static_cast<int64_t>(std::floor(in / span));
  if (flips % 2 == 0) {
    *grad_in = static_cast<scalar_t>(grad_in_mult);
    return extra + low;
  }
  *grad_in = static_cast<scalar_t>(-grad_in_mult);
  return span - extra + low;
}
138 
139 // Mapping the out-of-boundary points back into boundary
140 // This would only affect padding_mode=border or reflection
141 template<typename scalar_t>
compute_coordinates(scalar_t coord,int64_t size,GridSamplerPadding padding_mode,bool align_corners)142 static inline scalar_t compute_coordinates(scalar_t coord, int64_t size,
143                                            GridSamplerPadding padding_mode,
144                                            bool align_corners) {
145   if (padding_mode == GridSamplerPadding::Border) {
146     // clip coordinates to image borders
147     coord = clip_coordinates(coord, size);
148   } else if (padding_mode == GridSamplerPadding::Reflection) {
149     // reflect coordinates by image borders
150     if (align_corners) {
151       coord = reflect_coordinates(coord, 0, 2*(size - 1));
152     } else {
153       coord = reflect_coordinates(coord, -1, 2*size - 1);
154     }
155     // clip coordinates to image borders
156     coord = clip_coordinates(coord, size);
157   }
158   return coord;
159 }
160 
161 // Computes the pixel source index value for a grid coordinate
162 template <typename scalar_t>
grid_sampler_compute_source_index(scalar_t coord,int64_t size,GridSamplerPadding padding_mode,bool align_corners)163 static inline scalar_t grid_sampler_compute_source_index(
164     scalar_t coord,
165     int64_t size,
166     GridSamplerPadding padding_mode,
167     bool align_corners) {
168   coord = grid_sampler_unnormalize(coord, size, align_corners);
169   coord = compute_coordinates(coord, size, padding_mode, align_corners);
170   return coord;
171 }
172 
173 // grid_sampler_compute_source_index_set_grad works similarly to
174 // grid_sampler_compute_source_index except that it also returns the
175 // `d output / d input` via pointer argument `grad_in`.
176 // This is useful in the backward pass of grid_sampler.
177 template <typename scalar_t>
grid_sampler_compute_source_index_set_grad(scalar_t coord,int64_t size,GridSamplerPadding padding_mode,bool align_corners,scalar_t * grad_in)178 static inline scalar_t grid_sampler_compute_source_index_set_grad(
179     scalar_t coord,
180     int64_t size,
181     GridSamplerPadding padding_mode,
182     bool align_corners,
183     scalar_t *grad_in) {
184   scalar_t grad_clip, grad_refl;
185   coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
186   if (padding_mode == GridSamplerPadding::Border) {
187     // clip coordinates to image borders
188     coord = clip_coordinates_set_grad(coord, size, &grad_clip);
189     *grad_in = (*grad_in) * grad_clip;
190   } else if (padding_mode == GridSamplerPadding::Reflection) {
191     // reflect coordinates by image borders
192     if (align_corners) {
193       coord = reflect_coordinates_set_grad(coord, 0, 2*(size - 1), &grad_refl);
194     } else {
195       coord = reflect_coordinates_set_grad(coord, -1, 2*size - 1, &grad_refl);
196     }
197     // clip coordinates to image borders
198     coord = clip_coordinates_set_grad(coord, size, &grad_clip);
199     *grad_in = (*grad_in) * grad_refl * grad_clip;
200   }
201   return coord;
202 }
203 
// Returns true iff (h, w) is a valid index into an H x W plane.
static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
  const bool row_ok = 0 <= h && h < H;
  const bool col_ok = 0 <= w && w < W;
  return row_ok && col_ok;
}
207 
// Returns true iff (d, h, w) is a valid index into a D x H x W volume.
static inline bool within_bounds_3d(int64_t d, int64_t h, int64_t w, int64_t D, int64_t H, int64_t W) {
  const bool depth_ok = 0 <= d && d < D;
  const bool row_ok = 0 <= h && h < H;
  const bool col_ok = 0 <= w && w < W;
  return depth_ok && row_ok && col_ok;
}
211 
212 template<typename scalar_t>
get_value_bounded(const scalar_t * data,scalar_t x,scalar_t y,int64_t W,int64_t H,int64_t sW,int64_t sH,GridSamplerPadding padding_mode,bool align_corners)213 static inline scalar_t get_value_bounded(
214     const scalar_t* data,
215     scalar_t x,
216     scalar_t y,
217     int64_t W,
218     int64_t H,
219     int64_t sW,
220     int64_t sH,
221     GridSamplerPadding padding_mode,
222     bool align_corners) {
223 
224   x = compute_coordinates(x, W, padding_mode, align_corners);
225   y = compute_coordinates(y, H, padding_mode, align_corners);
226 
227   int64_t ix = static_cast<int64_t>(x);
228   int64_t iy = static_cast<int64_t>(y);
229 
230   if (within_bounds_2d(iy, ix, H, W)) {
231     return data[iy * sH + ix * sW];
232   }
233   return static_cast<scalar_t>(0);
234 }
235 
236 template<typename scalar_t>
safe_add_2d(scalar_t * data,int64_t h,int64_t w,int64_t sH,int64_t sW,int64_t H,int64_t W,scalar_t delta)237 static inline void safe_add_2d(scalar_t *data, int64_t h, int64_t w,
238                                int64_t sH, int64_t sW, int64_t H, int64_t W,
239                                scalar_t delta) {
240   if (within_bounds_2d(h, w, H, W)) {
241     data[h * sH + w * sW] += delta;
242   }
243 }
244 
245 template<typename scalar_t>
safe_add_3d(scalar_t * data,int64_t d,int64_t h,int64_t w,int64_t sD,int64_t sH,int64_t sW,int64_t D,int64_t H,int64_t W,scalar_t delta)246 static inline void safe_add_3d(scalar_t *data, int64_t d, int64_t h, int64_t w,
247                                int64_t sD, int64_t sH, int64_t sW,
248                                int64_t D, int64_t H, int64_t W,
249                                scalar_t delta) {
250   if (within_bounds_3d(d, h, w, D, H, W)) {
251     data[d * sD + h * sH + w * sW] += delta;
252   }
253 }
254 
255 template<typename scalar_t>
add_value_bounded(scalar_t * data,scalar_t x,scalar_t y,int64_t W,int64_t H,int64_t sW,int64_t sH,scalar_t delta,GridSamplerPadding padding_mode,bool align_corners)256 static inline void add_value_bounded(
257     scalar_t* data,
258     scalar_t x,
259     scalar_t y,
260     int64_t W,
261     int64_t H,
262     int64_t sW,
263     int64_t sH,
264     scalar_t delta,
265     GridSamplerPadding padding_mode,
266     bool align_corners) {
267 
268   x = compute_coordinates(x, W, padding_mode, align_corners);
269   y = compute_coordinates(y, H, padding_mode, align_corners);
270 
271   int64_t ix = static_cast<int64_t>(x);
272   int64_t iy = static_cast<int64_t>(y);
273 
274   safe_add_2d(data, iy, ix, sH, sW, H, W, delta);
275 }
276 
// Calculates the differential of the cubic convolution, i.e. `d coeff / d x`,
// for the four taps of the kernel. The derivative is evaluated at the signed
// offsets x = -1 - t, -t, 1 - t and 2 - t, and the results are written to
// coeffs[0..3] in that order.
template<typename scalar_t>
static inline void get_cubic_coefficients_grad(
    scalar_t coeffs[4],
    scalar_t t) {

  // Must be the same as forward calculation in
  // aten/src/ATen/native/UpSample.h:get_cubic_upsample_coefficients
  const scalar_t A = -0.75;

  // Outer-left tap: 1 < |x| = |-1 - t| < 2
  const scalar_t x_outer_left = -1 - t;
  coeffs[0] = (-3 * A * x_outer_left - 10 * A ) * x_outer_left - 8 * A;

  // Inner-left tap: |x| = |0 - t| <= 1
  const scalar_t x_inner_left = -t;
  coeffs[1] = (-3 * (A + 2) * x_inner_left - 2 * (A + 3)) * x_inner_left;

  // Inner-right tap: |x| = |1 - t| <= 1
  const scalar_t x_inner_right = 1 - t;
  coeffs[2] = (3 * (A + 2) * x_inner_right - 2 * (A + 3)) * x_inner_right;

  // Outer-right tap: 1 < |x| = |2 - t| < 2
  const scalar_t x_outer_right = 2 - t;
  coeffs[3] = (3 * A * x_outer_right - 10 * A) * x_outer_right + 8 * A;
}
297 
298 }  // namespace at::native
299