xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/UpSampleKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Context.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/TensorIterator.h>
7 #include <ATen/cpu/vec/vec.h>
8 #include <ATen/native/UpSample.h>
9 #include <ATen/native/cpu/utils.h>
10 #include <c10/util/irange.h>
11 #include <ATen/native/cpu/UpSampleKernelAVXAntialias.h>
12 
13 #ifndef AT_PER_OPERATOR_HEADERS
14 #include <ATen/Functions.h>
15 #else
16 #include <ATen/ops/empty.h>
17 #include <ATen/ops/empty_native.h>
18 #include <ATen/ops/ones.h>
19 #endif
20 
21 namespace at::native {
22 namespace {
23 
24 using scale_t = std::vector<std::optional<double>>;
25 
26 // TODO: this file could benefit from a global renaming of its functions /
27 // classes and terms, as well as from adding more comments. In particular:
28 // - It's not obvious that despite their names (and the file name), all these
29 //   kernels don't just do upsampling: they do general interpolation, i.e. they
30 //   also all support downscaling.
31 // - the term "horizontal" or "within dims" or "contiguous dim" refers to the
32 //   last dimension.
33 //   It's not specific to 2D images and applies to 3D (and 1D??) inputs as well.
34 //   Similarly "vertical" or "across dims" refers to all dims that aren't the
35 //   last one. In other kernels these are also referred to as "zero-stride" and
36 //   "non-zero-stride" - we should unify all this.
37 // - the terms "zero-stride" and "non-zero strides" refer to the weights and
38 //   indices, not to the contiguity of input or output
39 // - It's not always clear which kernel is vectorized and which one isn't.
40 // - The functions like _use_vectorized_kernel_cond() should be renamed and
41 //   their description updated, because they're not the only "fork" in the
42 //   code-path where a choice is made between a vectorized kernel vs a
43 //   non-vectorized one. See e.g. upsample_bilinear2d_kernel_impl() where we
44 //   already make a similar check, before the one in
45 //   _use_vectorized_kernel_cond().
46 // - It's not always clear which code is part of a "separable interpolation"
47 //   code-path.
48 // - Some names need to be more specific. For example
49 //   "cpu_upsample_generic_aa()" looks like a super generic name, but the function
50 //   is instead fairly specific - we need to make that clearer.
51 // - Some functions have an "aa" suffix but it doesn't mean that they only
52 //   support antialias. Some of them also support antialias=False now.
53 // - Various comments are outdated. Case in point: the one just below about the
54 //   `Interpolate` struct being used for cpu_upsample_linear:
55 //   cpu_upsample_linear doesn't exist anymore, and these structs are used for
56 //   various modes, *not* just linear.
57 // - It'd be useful to document how interpolation works in general, and in particular state explicitly:
58 //   - that the weights and indices across a given dimension are the same for
59 //     all pixels (hence the benefit of pre-computing them)
60 //   - that it can be "separated", i.e. we can do the horizontal pass and the
61 //     vertical pass independently (and that some kernels are written this way,
62 //     while some aren't.)
63 // - we can probably remove the template over index_t, because it's always
64 //   hard-coded as int64_t
65 
66 
67 // Helper structs and methods for cpu_upsample_linear
68 //
69 // The interpolation methods used below are separable, and as such we can compute the interpolation
70 // independently per dimension in a recursive way. Please refer to #10482 for more context.
71 //
72 // Interpolation structure to compute output value in n-dimensional case.
73 // - recursively compute interpolated output for each dimension
74 // - we rely a lot on the compiler's optimizations so that the implemented operations
75 //   can be automatically factorized and vectorized using SSE and AVX2
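//
// Illustrative sketch (not used by the kernels; all names below are made up for the example):
// for a 2D bilinear case (n = 2, interp_size = 2) the recursion effectively unrolls into an
// inner pass over the last dim and an outer pass over dim -2:
//
//   float bilinear_sketch(const float* src, int64_t ih0, int64_t ih1, int64_t iw0, int64_t iw1,
//                         float wh0, float wh1, float ww0, float ww1) {
//     // interpolate each of the two source rows along the last dim
//     float row0 = src[ih0 + iw0] * ww0 + src[ih0 + iw1] * ww1;
//     float row1 = src[ih1 + iw0] * ww0 + src[ih1 + iw1] * ww1;
//     // then interpolate across dim -2
//     return row0 * wh0 + row1 * wh1;
//   }
//
// In the real Interpolate structs the indices are precomputed byte offsets applied to a char*
// base pointer, but the arithmetic is the same.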
76 template <int n, typename scalar_t, typename opmath_t, typename index_t, int interp_size>
77 struct Interpolate {
78     static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
79       index_t ids = *(index_t*)&data[0][i * strides[0]];
80       opmath_t wts = *(scalar_t*)&data[1][i * strides[1]];
81       opmath_t t = Interpolate<n - 1, scalar_t, opmath_t, index_t, interp_size>::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i);
82       opmath_t output = t * wts;
83       for (const auto j : c10::irange(1, interp_size)) {
84         ids = *(index_t*)&data[2 * j + 0][i * strides[2 * j + 0]];
85         wts = *(scalar_t*)&data[2 * j + 1][i * strides[2 * j + 1]];
86         t = Interpolate<n - 1, scalar_t, opmath_t, index_t, interp_size>::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i);
87         output += t * wts;
88       }
89       return output;
90   }
91 };
92 
93 template <typename scalar_t, typename opmath_t, typename index_t, int interp_size>
94 struct Interpolate<1, scalar_t, opmath_t, index_t, interp_size> {
95     static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
96       index_t ids = *(index_t*)&data[0][i * strides[0]];
97       opmath_t wts = *(scalar_t*)&data[1][i * strides[1]];
98       opmath_t t = *(scalar_t *)&src[ids];
99       opmath_t output = t * wts;
100       for (const auto j : c10::irange(1, interp_size)) {
101         ids = *(index_t*)&data[2 * j + 0][i * strides[2 * j + 0]];
102         wts = *(scalar_t*)&data[2 * j + 1][i * strides[2 * j + 1]];
103         t = *(scalar_t *)&src[ids];
104         output += t * wts;
105       }
106       return output;
107     }
108 };
109 
110 template <int n, typename scalar_t, typename opmath_t, typename index_t>
111 struct Interpolate<n, scalar_t, opmath_t, index_t, 1> {
112     static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
113       index_t ids = *(index_t*)&data[0][i * strides[0]];
114       return Interpolate<n - 1, scalar_t, opmath_t, index_t, 1>::eval(src + ids, &data[2], &strides[2], i);
115   }
116 };
117 
118 template <typename scalar_t, typename opmath_t, typename index_t>
119 struct Interpolate<1, scalar_t, opmath_t, index_t, 1> {
120     static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
121       index_t ids = *(index_t*)&data[0][i * strides[0]];
122       return *(scalar_t *)&src[ids];
123     }
124 };
125 
126 // There is an unexpected 2x slowdown for upsample_trilinear3d channels_first
127 // for both 1 and 6 threads. We have to specialize this case as below:
128 // Once the issue is fixed we can keep the generic implementation and remove:
129 // struct Interpolate<n, scalar_t, index_t, 2> and
130 // struct Interpolate<1, scalar_t, index_t, 2>
131 template <int n, typename scalar_t, typename opmath_t, typename index_t>
132 struct Interpolate<n, scalar_t, opmath_t, index_t, 2> {
133     static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
134         index_t i0 = *(index_t*)&data[0][i * strides[0]];
135         index_t i1 = *(index_t*)&data[2][i * strides[2]];
136         opmath_t w0 = *(scalar_t *)&data[1][i * strides[1]];
137         opmath_t w1 = *(scalar_t *)&data[3][i * strides[3]];
138 
139         opmath_t t0 = Interpolate<n - 1, scalar_t, opmath_t, index_t, 2>::eval(src + i0, &data[4], &strides[4], i);
140         opmath_t t1 = Interpolate<n - 1, scalar_t, opmath_t, index_t, 2>::eval(src + i1, &data[4], &strides[4], i);
141 
142         return t0 * w0 + t1 * w1;
143   }
144 };
145 
146 template <typename scalar_t, typename opmath_t, typename index_t>
147 struct Interpolate<1, scalar_t, opmath_t, index_t, 2> {
148     static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
149         index_t i0 = *(index_t*)&data[0][i * strides[0]];
150         index_t i1 = *(index_t*)&data[2][i * strides[2]];
151         opmath_t w0 = *(scalar_t *)&data[1][i * strides[1]];
152         opmath_t w1 = *(scalar_t *)&data[3][i * strides[3]];
153         opmath_t t0 = *(scalar_t *)&src[i0];
154         opmath_t t1 = *(scalar_t *)&src[i1];
155         return t0 * w0 + t1 * w1;
156     }
157 };
158 
159 template <int n, typename scalar_t, typename index_t, int interp_size>
160 static inline scalar_t interpolate(char* src, char** data, const int64_t* strides, int64_t i) {
161   using opmath_t = at::opmath_type<scalar_t>;
162   return Interpolate<n, scalar_t, opmath_t, index_t, interp_size>::eval(src, data, strides, i);
163 }
164 
165 template <typename scalar_t, typename index_t>
166 static inline scalar_t interpolate_aa_single_dim_zero_strides(
167     char* src,
168     char** data,
169     const index_t ids_stride) {
170   const index_t ids_min = *(index_t*)&data[0][0];
171   const index_t ids_size = *(index_t*)&data[1][0];
172 
173   char* src_min = src + ids_min;
174 
175   scalar_t t = *(scalar_t*)&src_min[0];
176   index_t wts_idx = *(index_t*)&data[4][0];
177   scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
178   scalar_t wts = wts_ptr[0];
179 
180   scalar_t output = t * wts;
181   for (const auto j : c10::irange(1, ids_size)) {
182     wts = wts_ptr[j];
183     t = *(scalar_t*)&src_min[j * ids_stride];
184     output += t * wts;
185   }
186   return output;
187 }
188 
189 template <typename scalar_t, typename index_t>
190 static inline scalar_t interpolate_aa_single_dim(
191     char* src,
192     char** data,
193     const int64_t* strides,
194     int64_t i,
195     const index_t ids_stride) {
196   index_t ids_min = *(index_t*)&data[0][i * strides[0]];
197   index_t ids_size = *(index_t*)&data[1][i * strides[1]];
198 
199   char* src_min = src + ids_min;
200 
201   scalar_t t = *(scalar_t*)&src_min[0];
202   index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
203   scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
204   scalar_t wts = wts_ptr[0];
205 
206   scalar_t output = t * wts;
207   for (const auto j : c10::irange(1, ids_size)) {
208     wts = wts_ptr[j];
209     t = *(scalar_t*)&src_min[j * ids_stride];
210     output += t * wts;
211   }
212   return output;
213 }
214 
215 template<int m>
216 static inline bool is_zero_stride(const int64_t* strides) {
217   bool output = strides[0] == 0;
218   for (const auto i : c10::irange(1, m)) {
219     output &= (strides[i] == 0);
220   }
221   return output;
222 }
223 
224 template <typename scalar_t, typename index_t, int interp_size>
225 static inline bool is_contiguous_stride(const int64_t* strides) {
226   bool output = (strides[0] == sizeof(index_t)) && (strides[1] == sizeof(scalar_t));
227   for (int i=2; i<2 * interp_size; i+=2) {
228     output &= (strides[i] == sizeof(index_t)) && (strides[i + 1] == sizeof(scalar_t));
229   }
230   return output;
231 }
232 
233 // Helper class to recursively check that all input strides corresponding to the interpolated
234 // dimensions are equal to zero, except on a single dimension.
235 //
236 // Inputs: an array of strides of size N, and non_zero_stride_dim which can be -1, 0, 1, 2, ...
237 //   For every dimension other than non_zero_stride_dim we check that all strides are equal to zero;
238 //   for the non_zero_stride_dim dimension, the 4 strides corresponding to index_0, weight_0, index_1
239 //   and weight_1 should be non-zero (contiguous).
240 //
241 // The base check of the recursion verifies whether the 4 strides for one interpolated dimension are
242 // either all zero (see is_zero_stride) or equal to (sizeof(index_t), sizeof(scalar_t), sizeof(index_t),
243 // sizeof(scalar_t)) (see is_contiguous_stride).
244 //
245 // In practice, we have the following cases:
246 // - for ND, float32, channel first, strides are
247 //         dimN-1,              dim1,           dim0
248 //         i0, w0, i1, w1, ..., i0, w0, i1, w1, i0, w0, i1, w1
249 // strides=(0,  0,  0,  0, ...,  0,  0,  0,  0,  4,  4,  4,  4)
250 //
251 // if the size of dim0 is 1 then its strides are 0 and the dim1 strides are equal to 4
252 //
253 // - for ND, float32, channel last, strides are
254 //         dimN-1,         dimN-2,             dim0
255 //         i0, w0, i1, w1, i0, w0, i1, w1, ... i0, w0, i1, w1
256 // strides=(0,  0,  0,  0,  0,  0,  0,  0, ..., 0,  0,  0,  0)
257 //
258 // Using these methods we can hint the compiler to factorize constant indices and weights
259 // in cpu_upsample_linear method
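//
// Illustrative usage (hypothetical stride values): for a 2D case with float32 data, int64
// indices and 2 (index, weight) pairs per dim, a call like
//
//   bool ok = check_almost_all_zero_stride<2, 1, float, int64_t, 2>(&strides[2]);
//
// returns true when the 8 checked strides are (0, 0, 0, 0, 8, 4, 8, 4), i.e. zero strides for
// the outer interpolated dim and (sizeof(int64_t), sizeof(float), ...) for the innermost one.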
260 template <int N, int non_zero_stride_dim, typename scalar_t, typename index_t, int interp_size>
261 struct CheckAlmostAllZeroStrides {
262   static inline bool eval(const int64_t* strides) {
263     // N is dim index: N -> dim0, N-1 -> dim1, ...
264     // non_zero_stride_dim should be out_dims - dim
265     bool output = false;
266     if constexpr (N == non_zero_stride_dim) {
267       output = is_contiguous_stride<scalar_t, index_t, interp_size>(strides);
268     } else {
269       output = is_zero_stride<2 * interp_size>(strides);
270     }
271     return output &&
272       CheckAlmostAllZeroStrides<N - 1, non_zero_stride_dim, scalar_t, index_t, interp_size>::eval(
273         &strides[2 * interp_size]);
274   }
275 };
276 
277 template <int non_zero_stride_dim, typename scalar_t, typename index_t, int interp_size>
278 struct CheckAlmostAllZeroStrides<0, non_zero_stride_dim, scalar_t, index_t, interp_size> {
279   static inline bool eval(const int64_t* /*strides*/) {
280     return true;
281   }
282 };
283 
284 template <int n, int s, typename scalar_t, typename index_t, int interp_size>
285 static inline bool check_almost_all_zero_stride(const int64_t* strides) {
286   return CheckAlmostAllZeroStrides<n, s, scalar_t, index_t, interp_size>::eval(strides);
287 }
288 
289 // Helper method to compute interpolation for nearest, linear, cubic modes
290 template <typename scalar_t, typename index_t, int out_ndims, int interp_size>
291 static inline void basic_loop(char** data, const int64_t* strides, int64_t n) {
292   char* dst = data[0];
293   char* src = data[1];
294   for (const auto i : c10::irange(n)) {
295     *(scalar_t*)&dst[i * strides[0]] = interpolate<out_ndims, scalar_t, index_t, interp_size>(
296         src + i * strides[1], &data[2], &strides[2], i);
297   }
298 }
299 
300 template <typename scalar_t>
301 static inline void basic_loop_aa_vertical(
302     char** data,
303     const int64_t* strides,
304     int64_t n,
305     unsigned int weights_precision) {
306   char* dst = data[0];
307   char* src = data[1];
308   // index stride is constant for the given dimension
309   const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
310 
311   for (const auto i : c10::irange(n)) {
312     *(scalar_t*)&dst[i * strides[0]] =
313         interpolate_aa_single_dim_zero_strides<scalar_t, int64_t>(
314             src + i * strides[1], &data[2], ids_stride);
315   }
316 }
317 
318 template <>
319 inline void basic_loop_aa_vertical<uint8_t>(
320     char** data,
321     const int64_t* strides,
322     int64_t n,
323     unsigned int weights_precision) {
324   // See Note [ Weights computation for uint8_t and multiplication trick ]
325   char* dst = data[0];
326   char* src = data[1];
327 
328   // index stride is constant for the given dimension
329   const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
330   const int64_t ids_size = *(int64_t*)&data[2 + 1][0];
331   const int64_t ids_min = *(int64_t*)&data[2 + 0][0];
332 
333   int64_t i = 0;
334 
335   for (; i<n; i++) {
336 
337     char* src_min = src + i * strides[1] + ids_min;
338 
339     uint8_t t = *(uint8_t*)&src_min[0];
340     int64_t wts_idx = *(int64_t*)&data[2 + 4][0];
341     int16_t* wts_ptr = (int16_t*)&data[2 + 3][wts_idx];
342     int16_t wts = wts_ptr[0];
343 
344     // Intermediate computations are using integer type
345     int output = 1 << (weights_precision - 1);  // accounts for the +0.5 part
346     output += t * wts;
347     for (const auto j : c10::irange(1, ids_size)) {
348       wts = wts_ptr[j];
349       t = *(uint8_t*)&src_min[j * ids_stride];
350       output += t * wts;
351     }
352     *(uint8_t*)&dst[i * strides[0]] = (uint8_t)std::clamp(output >> weights_precision, 0, 255);
353   }
354 }
355 
356 template <typename scalar_t>
357 static inline void basic_loop_aa_horizontal(
358     char** data,
359     const int64_t* strides,
360     int64_t n,
361     unsigned int weights_precision) {
362   char* dst = data[0];
363   char* src = data[1];
364   // index stride is constant for the given dimension
365   const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
366 
367   if (strides[1] == 0) {
368     for (const auto i : c10::irange(n)) {
369       *(scalar_t*)&dst[i * strides[0]] =
370           interpolate_aa_single_dim<scalar_t, int64_t>(
371               src, &data[2], &strides[2], i, ids_stride);
372     }
373   } else {
374     for (const auto i : c10::irange(n)) {
375       *(scalar_t*)&dst[i * strides[0]] =
376           interpolate_aa_single_dim<scalar_t, int64_t>(
377               src + i * strides[1], &data[2], &strides[2], i, ids_stride);
378     }
379   }
380 }
381 
382 template <>
383 inline void basic_loop_aa_horizontal<uint8_t>(
384     char** data,
385     const int64_t* strides,
386     int64_t n,
387     unsigned int weights_precision) {
388   // See Note [ Weights computation for uint8_t and multiplication trick ]
389   char* dst = data[0];
390   char* src = data[1];
391   // index stride is constant for the given dimension
392   const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
393 
394   int64_t i = 0;
395 
396   // Here we are implementing data interpolation within the same line (vs between the lines)
397   // output[x, y] = input[xmin[x], y] * W[x] + input[xmin[x] + 1, y] * W[x + 1] + ... + input[xmin[x] + xsize, y] * W[x + xsize]
398 
399   for (; i<n; i++) {
400 
401     int64_t ids_min = *(int64_t*)&data[2 + 0][i * strides[2 + 0]];
402     int64_t ids_size = *(int64_t*)&data[2 + 1][i * strides[2 + 1]];
403 
404     char* src_min = src + i * strides[1] + ids_min;
405 
406     uint8_t t = *(uint8_t*)&src_min[0];
407     int64_t wts_idx = *(int64_t*)&data[2 + 4][i * strides[2 + 4]];
408     int16_t* wts_ptr = (int16_t*)&data[2 + 3][wts_idx];
409     int16_t wts = wts_ptr[0];
410 
411     // Intermediate computations are using integer type
412     int output = 1 << (weights_precision - 1);  // accounts for the +0.5 part
413     output += t * wts;
414     for (const auto j : c10::irange(1, ids_size)) {
415       wts = wts_ptr[j];
416       t = *(uint8_t*)&src_min[j * ids_stride];
417       output += t * wts;
418     }
419     *(uint8_t*)&dst[i * strides[0]] = (uint8_t)std::clamp(output >> weights_precision, 0, 255);
420   }
421 }
422 
423 // Generic upsampling computation method using TensorIterator for Nd case.
424 // Supports: nearest, linear, cubic modes with interp_size template argument: 1, 2, 4
425 //
426 // Single loop function for 1d, 2d and 3d cases and modes
427 // For N dimensions, output value up to Di dimension can be computed as
428 //
429 // output_i[a] = interpolate(output_{i+1}[a], w_{i+1}[a], output_{i+1}[a+1], w_{i+1}[a+1], ...)
430 // with
431 // output_DN[a] = interpolate(input_DN[a], w_DN[a], input_DN[a+1], w_DN[a+1], ...)
432 // and i - dimension index and a - linear index for spatial coordinates
433 //
434 // The recursive call is implemented with the Interpolate struct, using templates for
435 // loop unrolling at compile time.
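//
// For reference, the TensorIterator operands consumed by the loop below are laid out as
// (a sketch; the iterator itself is configured in upsample_generic_Nd_kernel_impl, not shown here):
//
//   data[0] -> output, data[1] -> (restrided) input,
//   data[2], data[3], ...  -> one (index, weight) tensor pair per interpolation tap and per
//                             interpolated dimension,
//
// which is why basic_loop() and the aa loops above start reading at data[2] / strides[2].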
436 template <typename scalar_t, int out_ndims, int interp_size>
437 void cpu_upsample_generic(at::TensorIterator& iter)
438 {
439   auto loop = [&](char** data, const int64_t* strides, int64_t n) {
440     // special-cases to let the compiler apply compile-time input-specific optimizations
441     if ((strides[0] == sizeof(scalar_t) && (strides[1] == 0) &&
442         // NOLINTNEXTLINE(bugprone-branch-clone)
443         check_almost_all_zero_stride<out_ndims, 1, scalar_t, int64_t, interp_size>(&strides[2]))) {
444       // contiguous channels-first case
445       basic_loop<scalar_t, int64_t, out_ndims, interp_size>(data, strides, n);
446     } else if ((strides[0] == sizeof(scalar_t) && (strides[1] == sizeof(scalar_t)) &&
447                check_almost_all_zero_stride<out_ndims, -1, scalar_t, int64_t, interp_size>(&strides[2]))) {
448       // contiguous channels-last case
449       basic_loop<scalar_t, int64_t, out_ndims, interp_size>(data, strides, n);
450     } else {
451       // fallback
452       basic_loop<scalar_t, int64_t, out_ndims, interp_size>(data, strides, n);
453     }
454   };
455   iter.for_each(loop);
456 }
457 
458 template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
459 void cpu_upsample_nearest_channels_last(
460     const Tensor& output_,
461     const Tensor& input_,
462     const scale_type& scales) {
463   TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
464               " for `output` but got dtype ", output_.dtype());
465 
466   auto input_sizes = input_.sizes().vec();
467   auto output_sizes = output_.sizes().vec();
468   auto ndim = input_sizes.size();
469   TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")
470 
471   auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
472   auto input = input_.contiguous(channels_last_memory_format);
473   auto output = output_.contiguous(channels_last_memory_format);
474 
475   auto input_data = input.const_data_ptr<scalar_t>();
476   auto output_data = output.data_ptr<scalar_t>();
477 
478   int64_t num_batches =  input_sizes[0];
479   int64_t channels =  input_sizes[1];
480   int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
481   int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
482   int64_t input_height = input_sizes[ndim - 2];
483   int64_t output_height = output_sizes[ndim - 2];
484   int64_t input_width = input_sizes[ndim - 1];
485   int64_t output_width = output_sizes[ndim - 1];
486   int64_t numel = output.numel();
487 
488   TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels);
489 
490   using Vec = vec::Vectorized<scalar_t>;
491   auto copy = [](scalar_t* out, const scalar_t* in, int64_t size) {
492     int64_t d = 0;
493     for (; d < size - (size % Vec::size()); d += Vec::size()) {
494       Vec out_vec = Vec::loadu(in + d);
495       out_vec.store(out + d);
496     }
497     for (; d < size; d++) {
498       out[d] = in[d];
499     }
500   };
501 
502   auto loop2d = [&](int64_t begin, int64_t end) {
503     int64_t n = 0;
504     int64_t oh = 0;
505     int64_t ow = 0;
506     data_index_init(begin, n, num_batches, oh, output_height, ow, output_width);
507 
508     for (const auto i : c10::irange(begin, end)) {
509       int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
510       int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
511       scalar_t* output_ptr = output_data + i * channels;
512       const scalar_t* input_ptr = input_data + n * input_height * input_width * channels +
513           ih * input_width * channels + iw * channels;
514       copy(output_ptr, input_ptr, channels);
515       data_index_step(n, num_batches, oh, output_height, ow, output_width);
516     }
517   };
518 
519   auto loop3d = [&](int64_t begin, int64_t end) {
520     int64_t n = 0;
521     int64_t od = 0;
522     int64_t oh = 0;
523     int64_t ow = 0;
524     data_index_init(begin, n, num_batches, od, output_depth, oh, output_height, ow, output_width);
525 
526     for (const auto i : c10::irange(begin, end)) {
527       int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
528       int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
529       int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
530       scalar_t* output_ptr = output_data + i * channels;
531       const scalar_t* input_ptr = input_data + n * input_depth * input_height * input_width * channels +
532           id * input_height * input_width * channels +
533           ih * input_width * channels + iw * channels;
534       copy(output_ptr, input_ptr, channels);
535       data_index_step(n, num_batches, od, output_depth, oh, output_height, ow, output_width);
536     }
537   };
538 
539   if (ndim == 4) {
540     // upsample nearest 2d
541     at::parallel_for(0, numel / channels, at::internal::GRAIN_SIZE / channels, loop2d);
542   } else {
543     // upsample nearest 3d
544     TORCH_INTERNAL_ASSERT(ndim == 5);
545     at::parallel_for(0, numel / channels, at::internal::GRAIN_SIZE / channels, loop3d);
546   }
547 
548   if (!output_.is_contiguous(channels_last_memory_format)) {
549     output_.copy_(output);
550   }
551 }
552 
553 template <typename scalar_t, typename accscalar_t>
554 inline VecType<scalar_t> interpolate(const scalar_t* t, accscalar_t w) {
555   return VecType<scalar_t>::loadu(t) * VecType<scalar_t>(w);
556 }
557 
558 template <typename scalar_t, typename accscalar_t, typename... Args>
559 inline VecType<scalar_t> interpolate(const scalar_t* t, accscalar_t w, Args... args) {
560   return VecType<scalar_t>::loadu(t) * VecType<scalar_t>(w) + interpolate(args...);
561 }
562 
563 template <typename scalar_t, typename scale_type>
564 void cpu_upsample_linear_channels_last(
565     const Tensor& output_,
566     const Tensor& input_,
567     bool align_corners,
568     const scale_type& scales) {
569   TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
570               " for `output` but got dtype ", output_.dtype());
571 
572   auto input_sizes = input_.sizes().vec();
573   auto output_sizes = output_.sizes().vec();
574   auto ndim = input_sizes.size();
575   TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")
576 
577   auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
578   auto input = input_.contiguous(channels_last_memory_format);
579   auto output = output_.contiguous(channels_last_memory_format);
580 
581   auto input_data = input.const_data_ptr<scalar_t>();
582   auto output_data = output.data_ptr<scalar_t>();
583 
584   int64_t num_batches =  input_sizes[0];
585   int64_t channels =  input_sizes[1];
586   int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
587   int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
588   int64_t input_height = input_sizes[ndim - 2];
589   int64_t output_height = output_sizes[ndim - 2];
590   int64_t input_width = input_sizes[ndim - 1];
591   int64_t output_width = output_sizes[ndim - 1];
592 
593   TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels);
594   int64_t output_slice_size = output_depth * output_height * output_width * channels;
595 
596   using opmath_t = at::opmath_type<scalar_t>;
597   using Vec = vec::Vectorized<scalar_t>;
598   auto loop2d = [&](int64_t begin, int64_t end) {
599     const auto height_scale = area_pixel_compute_scale<opmath_t>(
600         input_height, output_height, align_corners, scales[0]);
601     const auto width_scale = area_pixel_compute_scale<opmath_t>(
602         input_width, output_width, align_corners, scales[1]);
603 
604     auto input_indexr = [=](int64_t n, int64_t h, int64_t w) {
605       return input_data + n * input_height * input_width * channels +
606           h * input_width * channels + w * channels;
607     };
608 
609     int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
610     opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
611     for (const auto n : c10::irange(begin, end)) {
612       for (const auto oh : c10::irange(output_height)) {
613         compute_source_index_and_lambda(
614             ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
615         for (const auto ow : c10::irange(output_width)) {
616           compute_source_index_and_lambda(
617               iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
618 
619           scalar_t* out = output_data + n * output_slice_size +
620               oh * output_width * channels + ow * channels;
621           const scalar_t* i00 = input_indexr(n, ih0, iw0);
622           const scalar_t* i01 = input_indexr(n, ih0, iw1);
623           const scalar_t* i10 = input_indexr(n, ih1, iw0);
624           const scalar_t* i11 = input_indexr(n, ih1, iw1);
625           opmath_t w00 = h0lambda * w0lambda;
626           opmath_t w01 = h0lambda * w1lambda;
627           opmath_t w10 = h1lambda * w0lambda;
628           opmath_t w11 = h1lambda * w1lambda;
629 
630           int64_t size = channels;
631           int64_t d = 0;
632           for (; d < size - (size % Vec::size()); d += Vec::size()) {
633             auto out_vec = interpolate(i00 + d, w00, i01 + d, w01, i10 + d, w10, i11 + d, w11);
634             out_vec.store(out + d);
635           }
636           for (; d < size; d++) {
637             out[d] = i00[d] * w00 + i01[d] * w01 + i10[d] * w10 + i11[d] * w11;
638           }
639         }
640       }
641     }
642   };
643 
644   auto loop3d = [&](int64_t begin, int64_t end) {
645     const auto depth_scale = area_pixel_compute_scale<opmath_t>(
646         input_depth, output_depth, align_corners, scales[0]);
647     const auto height_scale = area_pixel_compute_scale<opmath_t>(
648         input_height, output_height, align_corners, scales[1]);
649     const auto width_scale = area_pixel_compute_scale<opmath_t>(
650         input_width, output_width, align_corners, scales[2]);
651 
652     auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w) {
653       return input_data + n * input_depth * input_height * input_width * channels +
654           d * input_height * input_width * channels +
655           h * input_width * channels + w * channels;
656     };
657 
658     int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
659     opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
660     for (const auto n : c10::irange(begin, end)) {
661       for (const auto od : c10::irange(output_depth)) {
662         compute_source_index_and_lambda(
663             id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
664         for (const auto oh : c10::irange(output_height)) {
665           compute_source_index_and_lambda(
666               ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
667           for (const auto ow : c10::irange(output_width)) {
668             compute_source_index_and_lambda(
669                 iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
670 
671             scalar_t* out = output_data + n * output_slice_size +
672                 od * output_height * output_width * channels +
673                 oh * output_width * channels + ow * channels;
674             const scalar_t* i000 = input_indexr(n, id0, ih0, iw0);
675             const scalar_t* i001 = input_indexr(n, id0, ih0, iw1);
676             const scalar_t* i010 = input_indexr(n, id0, ih1, iw0);
677             const scalar_t* i011 = input_indexr(n, id0, ih1, iw1);
678             const scalar_t* i100 = input_indexr(n, id1, ih0, iw0);
679             const scalar_t* i101 = input_indexr(n, id1, ih0, iw1);
680             const scalar_t* i110 = input_indexr(n, id1, ih1, iw0);
681             const scalar_t* i111 = input_indexr(n, id1, ih1, iw1);
682             opmath_t w000 = d0lambda * h0lambda * w0lambda;
683             opmath_t w001 = d0lambda * h0lambda * w1lambda;
684             opmath_t w010 = d0lambda * h1lambda * w0lambda;
685             opmath_t w011 = d0lambda * h1lambda * w1lambda;
686             opmath_t w100 = d1lambda * h0lambda * w0lambda;
687             opmath_t w101 = d1lambda * h0lambda * w1lambda;
688             opmath_t w110 = d1lambda * h1lambda * w0lambda;
689             opmath_t w111 = d1lambda * h1lambda * w1lambda;
690 
691             int64_t size = channels;
692             int64_t d = 0;
693             for (; d < size - (size % Vec::size()); d += Vec::size()) {
694               auto out_vec = interpolate(
695                   i000 + d, w000, i001 + d, w001, i010 + d, w010, i011 + d, w011,
696                   i100 + d, w100, i101 + d, w101, i110 + d, w110, i111 + d, w111);
697               out_vec.store(out + d);
698             }
699             for (; d < size; d++) {
700               out[d] =
701                   i000[d] * w000 + i001[d] * w001 + i010[d] * w010 + i011[d] * w011 +
702                   i100[d] * w100 + i101[d] * w101 + i110[d] * w110 + i111[d] * w111;
703             }
704           }
705         }
706       }
707     }
708   };
709 
710   if (ndim == 4) {
711     // upsample bilinear 2d
712     at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
713   } else {
714     // upsample trilinear 3d
715     TORCH_INTERNAL_ASSERT(ndim == 5);
716     at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
717   }
718 
719   if (!output_.is_contiguous(channels_last_memory_format)) {
720     output_.copy_(output);
721   }
722 }
723 
724 // Helper structs to use with upsample_generic_Nd_kernel_impl
725 struct HelperInterpBase {
726 
727   static inline void init_indices_weights(
728     at::ScalarType output_type,
729     std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
730     int64_t reshape_dim, int interp_size
731   ) {
732 
733     auto new_shape = std::vector<int64_t>(ndims, 1);
734     new_shape[reshape_dim] = output_size;
735 
736     for (const auto j C10_UNUSED : c10::irange(interp_size)) {
737       output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
738       output.emplace_back(empty(new_shape, CPU(output_type)));
739     }
740   }
741 
742   // This is a helper function for the _compute_index_ranges_weights method. It computes two int64
743   // scalars (the source index min and the size) and a list of weights (of size max_interp_size)
744   // for interpolation with antialiasing=true mode. It returns the maximal weight value.
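  //
  // Worked example (illustrative numbers, assuming an input of at least 3 elements): downscaling
  // by 2 with the linear aa_filter (interp_size = 2, so support = 2.0 and max_interp_size = 5).
  // For output index i = 0: center = 2 * 0.5 = 1.0, xmin = 0, xsize = 3, and the unnormalized
  // weights are filter((j - 0.5) * 0.5) for j = 0, 1, 2, i.e. 0.75, 0.75, 0.25; after
  // normalization they become 3/7, 3/7, 1/7, and the remaining max_interp_size - xsize weight
  // slots are zero-filled.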
745   template <typename scalar_t, typename aa_filter_fn_t>
746   static inline scalar_t _compute_indices_min_size_weights_aa(
747     const int64_t i, const int64_t input_size, const scalar_t scale, const scalar_t support,
748     scalar_t* wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn,
749     int64_t& xmin, int64_t& xsize
750   ) {
751 
752     scalar_t center = scale * (i + 0.5);
753     scalar_t total_w = 0.0;
754     scalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
755     xmin = std::max(
756         static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
757     xsize = std::min(
758         static_cast<int64_t>(center + support + 0.5), input_size) - xmin;
759     // There are rare cases when due to precision xsize can be larger than max_interp_size by one.
760     // We have to clip the value
761     xsize = std::clamp(xsize, static_cast<int64_t>(0), max_interp_size);
762 
763     int64_t j = 0;
764     for (; j < xsize; j++) {
765       scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
766       wt_ptr[j] = w;
767       total_w += w;
768     }
769 
770     scalar_t wt_max = 0.0;
771     if (total_w != 0.0) {
772       for (j = 0; j < xsize; j++) {
773         wt_ptr[j] /= total_w;
774         wt_max = std::max(wt_max, wt_ptr[j]);
775       }
776     }
777 
778     for (; j < max_interp_size; j++) {
779       wt_ptr[j] = static_cast<scalar_t>(0.0);
780     }
781     return wt_max;
782   }
783 
784   // This is a helper function for the _compute_index_ranges_weights method. It computes two int64
785   // scalars (the source index min and the size) and a list of weights (of size max_interp_size)
786   // for interpolation with antialiasing=false mode. It returns the maximal weight value.
787   // This function is templated on scalar_t for the type of the scale and weights, but it is only used for
788   // bilinear/bicubic modes on uint8 input with antialiasing=false (in this case scalar_t is double).
789   // For float input types we use the upsample_generic_Nd_kernel_impl and compute_indices_weights methods.
790   template <typename scalar_t, typename aa_filter_fn_t>
791   static inline scalar_t _compute_indices_min_size_weights(
792     const int64_t i, const int64_t input_size, const scalar_t scale,
793     scalar_t* wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn,
794     bool align_corners, int64_t& index_min, int64_t& index_size
795   ) {
796     // Notes: we do not use opmath_t in this method as f16 and other smaller float types are not routed here.
797     // Typical usage of this method is with scalar_t = double when computing indices and weights for uint8 input.
798     // The code below partly adapts the indices and lambda computation from the compute_indices_weights method,
799     // and the index_min/index_size computation from _compute_indices_min_size_weights_aa.
800 
801     bool cubic = max_interp_size > 2;
802     const auto real_input_index = area_pixel_compute_source_index<scalar_t>(
803         scale, i, align_corners, /*cubic=*/cubic);
804 
805     scalar_t lambda;
806     int64_t input_index = 0;
807     guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
808 
809     const auto support = static_cast<int64_t>(max_interp_size * 0.5);
810     const auto unbound_index_min = input_index - support + 1;
811     const auto unbound_index_max = input_index + support + 1;
812     index_min = std::max(unbound_index_min, static_cast<int64_t>(0));
813     index_size = std::min(unbound_index_max, input_size) - index_min;
814     // There are rare cases when, due to precision, index_size can be larger than max_interp_size by one.
815     // We have to clip the value
816     index_size = std::clamp(index_size, static_cast<int64_t>(0), max_interp_size);
817 
818     // Below, the weights are computed using filter_fn, accumulating values for indices that are out of bounds.
819     // For example, for bicubic mode and output index i = 0, we have input_index = -1,
820     // then unbound_index_min = -2 and unbound_index_max = 1 => the unbounded input indices are [-2, -1, 0, 1] and
821     // the valid input indices are [0, 1].
822     // For the unbounded input indices we compute four non-zero weight values [w0, w1, w2, w3], and as only two
823     // weights can be used with the valid input indices, we accumulate the values as follows: [w0 + w1 + w2, w3, 0.0, 0.0].
824     // This is equivalent to the float path which would compute indices as [0, 0, 0, 1] and weights as [w0, w1, w2, w3].
825     // A similar accumulation should be done for unbounded indices larger than the input size.
826     auto w_index = 0;
827     scalar_t wt_max = 0.0;
828     for (const auto j : c10::irange(max_interp_size)) {
829       // initialize weights value as we will accumulate below
830       wt_ptr[j] = 0.0;
831 
832       scalar_t w = filter_fn(static_cast<scalar_t>(j + 1 - support) - lambda);
833       if (unbound_index_min + j <= 0) {
834         w_index = 0;
835       } else if (unbound_index_min + j >= input_size - 1) {
836         w_index = index_size - 1;
837       }
838       wt_ptr[w_index] += w;
839       wt_max = std::max(wt_max, wt_ptr[w_index]);
840       w_index++;
841     }
842 
843     return wt_max;
844   }
845 
846   // Note [ Support for antialias=False as a subcase of antialias=True ]
847   // This function was originally written with the hard assumption that
848   // antialias=True and it was later extended to support antialias=False.
849   // The only difference between aa and no-aa is in how the
850   // weights and indices are computed (and their number). In aa their number is
851   // variable but with no-aa, they're fixed to interp_size. The same "filters"
852   // can be used otherwise. HOWEVER, support for antialias=False here may not be
853   // fully optimized: the code assumes an arbitrary number of weights and
854   // indices, but this can be optimized further when aa=False since we know
855   // their actual dimensions.
856   template <typename scalar_t, typename aa_filter_fn_t, int weight_index_stride=sizeof(scalar_t)>
857   static inline std::tuple<std::vector<Tensor>, int, scalar_t> _compute_index_ranges_weights(
858     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
859     int64_t reshape_dim, scalar_t scale,
860     int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, bool align_corners
861   ) {
862 
863     std::vector<Tensor> output;
864 
865     scalar_t support;
866     int max_interp_size = 0;
867     if (antialias) {
868         support = (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
869         max_interp_size = (int) std::ceil(support) * 2 + 1;
870     } else {
871         support = interp_size * 0.5;
872         max_interp_size = interp_size;
873     }
874 
875     auto new_shape = std::vector<int64_t>(ndims, 1);
876     new_shape[reshape_dim] = output_size;
877 
878     // Bounds approach as in PIL: xmin/xmax
879     output.emplace_back(
880         empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
881     output.emplace_back(
882         empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
883     output.emplace_back(
884         empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
885 
886     {
887       // Weights
888       new_shape[reshape_dim] = output_size * max_interp_size;
889       auto wts = empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>()));
890       auto strides = wts.strides().vec();
891       strides[reshape_dim] = 0;
892       new_shape[reshape_dim] = output_size;
893       wts = wts.as_strided(new_shape, strides);
894       output.emplace_back(wts);
895       // Weights indices
896       output.emplace_back(
897           empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
898     }
899 
900     int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
901     int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
902     int64_t* idx_ptr_stride = output[2].data_ptr<int64_t>();
903     scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
904     int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
905 
906     scalar_t wt_max = 0.0;
907     for (const auto i : c10::irange(output_size)) {
908       int64_t xmin = 0, xsize = 0;
909       scalar_t wt_max_i;
910       if (antialias) {
911         wt_max_i = HelperInterpBase::_compute_indices_min_size_weights_aa(
912             i,
913             input_size,
914             scale,
915             support,
916             wt_ptr + i * max_interp_size,
917             max_interp_size,
918             aa_filter_fn,
919             xmin,
920             xsize);
921       } else {
922         wt_max_i = HelperInterpBase::_compute_indices_min_size_weights(
923             i,
924             input_size,
925             scale,
926             wt_ptr + i * max_interp_size,
927             max_interp_size,
928             aa_filter_fn,
929             align_corners,
930             xmin,
931             xsize);
932       }
933       wt_max = std::max(wt_max, wt_max_i);
934 
935       idx_ptr_xmin[i] = xmin * stride;
936       idx_ptr_size[i] = xsize;
937       idx_ptr_stride[i] = stride;
938       wt_idx_ptr[i] = i * max_interp_size * weight_index_stride;
939     }
940     return {output, max_interp_size, wt_max};
941   }
942 
943   /*
944   NOTE [ Weights computation for uint8_t and multiplication trick ]
945   When the input/output dtype is uint8_t, we still compute the interpolation
946   weights as double, but then convert them to int16 via some conversion logic
947   detailed below. This allows us to compute all interpolation operation (sum of
948   multiplications) as ints instead of floats. The result is converted back into
949   uint8 in basic_loop_aa_horizontal<uint8_t> (and vertical)
950 
951   In essence the idea is to avoid a multiplication between a float (the
952   weight) and an int (the pixel value) and instead run a multiplication between
953   2 ints:
954 
955   ```py
956   COEF_PREC = 16
957 
958   def mul(a:float, b:int) -> Tuple[float, int]:
959     # return a * b, round(a * b)
960     actual = a * b
961 
962     assert a > 0  # I'm lazy
963     int_a = floor(0.5 + a * (1 << COEF_PREC))
964     with_trick = ((int_a * b) + (1 << (COEF_PREC - 1))) >> COEF_PREC
965 
966     return actual, with_trick  # round(actual) == with_trick!!
967   ```
968 
969   Here's how it works:
970   N == COEF_PREC
971   1 << N == 2**N
972   floor(0.5 + x) == round(x)
973 
974   So the operation is something like
975 
976   int_a = round(a * 2**N)  -- let's just say it's `a * 2**N` for simplicity
977 
978   res = ((int_a * b) + (1 << (N - 1))) >> N
979       = floor((a * 2**N * b + 2**(N - 1)) / 2**N)
980       = floor(a * b + 0.5)
981       = round(a * b)
982       = what we wanted
983   */
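  /*
  A minimal C++ sketch of the same trick for a single positive weight `w` (double) and pixel
  value `t` (uint8_t), with a hypothetical fixed precision kPrec (the real code derives
  weights_precision from the maximal weight in _compute_index_ranges_int16_weights below):

  ```cpp
  constexpr unsigned int kPrec = 15;                            // assumed precision
  int16_t int_w = static_cast<int16_t>(std::floor(0.5 + w * (1 << kPrec)));
  int acc = 1 << (kPrec - 1);                                   // the "+0.5" rounding term
  acc += int_w * t;                                             // int16 * uint8 fits in an int
  uint8_t out = static_cast<uint8_t>(std::clamp(acc >> kPrec, 0, 255));
  ```

  basic_loop_aa_horizontal<uint8_t> / basic_loop_aa_vertical<uint8_t> above do essentially this,
  accumulating int_w * t over all taps of an output pixel before the final shift and clamp.
  */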
984   template <typename aa_filter_fn_t>
985   static inline std::tuple<std::vector<Tensor>, int, unsigned int> _compute_index_ranges_int16_weights(
986     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
987     int64_t reshape_dim, bool align_corners, const std::optional<double>& opt_scale,
988     int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, bool align_i32=false
989   ) {
990 
991     double scale = area_pixel_compute_scale<double>(
992         input_size, output_size, align_corners, opt_scale);
993 
994     auto [indices_weights, aligned_interp_size, wt_max] = HelperInterpBase::_compute_index_ranges_weights<double, aa_filter_fn_t, sizeof(int16_t)>(
995         input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners);
996     interp_size = aligned_interp_size;
997 
998     // Rescale float weights to int16 and compute weights precision
999     auto weights_f64 = indices_weights[3];
1000     double * data_f64 = weights_f64. template data_ptr<double>();
1001 
1002     unsigned int weights_precision = 0;
1003     for (weights_precision = 0; weights_precision < 22; ++weights_precision) {
1004         int next_value = (int) (0.5 + wt_max * (1 << (weights_precision + 1)));
1005         if (next_value >= (1 << 15))
1006             break;
1007     }
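    // For example (illustrative value): if wt_max were 0.8, the loop above stops at
    // weights_precision = 15, since round(0.8 * 2^16) = 52429 >= 2^15 while
    // round(0.8 * 2^15) = 26214 still fits into int16; i.e. we pick the largest precision for
    // which the rescaled maximal weight stays below 1 << 15.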
1008 
1009     // Rescale float values to int16
1010     int16_t * data_i16 = (int16_t *) data_f64;
1011 
1012     if (align_i32) {
1013       // We should respect int32 alignment as we will load int16 data as int32
1014       // See ImagingResampleHorizontalConvolution8u4x, mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x]);
1015       // compute aligned_interp_size = nearest pair value to interp_size
1016       while (aligned_interp_size % sizeof(int32_t) != 0) {
1017         aligned_interp_size += 1;
1018       }
1019       // assert that we won't go out of bounds
1020       TORCH_INTERNAL_ASSERT(aligned_interp_size * sizeof(int16_t) < interp_size * sizeof(double));
1021     }
1022 
1023     for (const auto j : c10::irange(output_size)) {
1024       for (const auto k : c10::irange(interp_size)) {
1025         double v = data_f64[j * interp_size + k] * (1 << weights_precision);
1026         data_i16[j * aligned_interp_size + k] = (v < 0) ? (int) (-0.5 + v) : (int) (0.5 + v);
1027       }
1028     }
1029 
1030     return {indices_weights, aligned_interp_size, weights_precision};
1031   }
1032 };
1033 
1034 struct HelperInterpNearest : public HelperInterpBase {
1035   // This structure implements an outdated and buggy method to compute indices
1036   // for nearest neighbour interpolation.
1037   // We keep this structure for backward compatibility and consider it deprecated.
1038   // See HelperInterpNearestExact for the replacement.
1039 
1040   static const int interp_size = 1;
1041 
1042   static inline void init_indices_weights(
1043     at::ScalarType output_type,
1044     std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
1045     int64_t reshape_dim, int interp_size
1046   ) {
1047     auto new_shape = std::vector<int64_t>(ndims, 1);
1048     new_shape[reshape_dim] = output_size;
1049 
1050     for (const auto j C10_UNUSED : c10::irange(interp_size)) {
1051       output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
1052       // Defines weights for consistency, but not used
1053       output.emplace_back(at::ones(new_shape, CPU(output_type)));
1054     }
1055   }
1056 
1057   // Compute nearest mode indices and weights for each interpolated dimension
1058   // indices_weights = {
1059   //      {indices_0, 1.0, },  // dim -n
1060   //      {indices_0, 1.0, },  // dim -(n-1)
1061   //      ...
1062   //      {indices_0, 1.0, },  // dim -1
1063   // }
1064   // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1065   // fit input/output tensors.
1066   // Indices already contain the strides to optimize the computations.
1067   static inline std::vector<Tensor> compute_indices_weights(
1068     at::ScalarType scalar_type,
1069     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
1070     int64_t reshape_dim, bool align_corners, const std::optional<double>& opt_scale
1071   ) {
1072 
1073     TORCH_INTERNAL_ASSERT(!align_corners);
1074     std::vector<Tensor> output;
1075     HelperInterpNearest::init_indices_weights(
1076       scalar_type, output, output_size, ndims, reshape_dim, HelperInterpNearest::interp_size);
1077 
1078     AT_DISPATCH_FLOATING_TYPES_AND2(
1079       kBFloat16, kHalf, scalar_type, "compute_indices_weights_nearest", [&] {
1080         using opmath_t = at::opmath_type<scalar_t>;
1081         opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1082 
1083         auto input_index_ptr = output[0].data_ptr<int64_t>();
1084         int64_t input_index;
1085 
1086         // Indices are computed as follows:
1087         // scale = 1.0 * isize / osize
1088         // index_f32 = (output_index) * scale
1089         // input_index = floor(index_f32)
1090         // Same as OpenCV INTER_NEAREST
1091         for (const auto i : c10::irange(output_size)) {
1092           const auto real_input_index =
1093               area_pixel_compute_source_index<opmath_t>(
1094                   scale, i, /*align_corners=*/true, /*cubic=*/false);
1095           input_index = static_cast<int64_t>(floorf(real_input_index));
1096           input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
1097         }
1098       }
1099     );
1100     return output;
1101   }
1102 
1103 };
1104 
1105 struct HelperInterpNearestExact : public HelperInterpNearest {
1106 
1107   // Compute nearest mode indices and weights for each interpolated dimension
1108   // indices_weights = {
1109   //      {indices_0, 1.0, },  // dim -n
1110   //      {indices_0, 1.0, },  // dim -(n-1)
1111   //      ...
1112   //      {indices_0, 1.0, },  // dim -1
1113   // }
1114   // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1115   // fit input/output tensors.
1116   // Indices already contain the strides to optimize the computations.
1117   static inline std::vector<Tensor> compute_indices_weights(
1118     at::ScalarType scalar_type,
1119     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
1120     int64_t reshape_dim, bool align_corners, const std::optional<double>& opt_scale
1121   ) {
1122 
1123     TORCH_INTERNAL_ASSERT(!align_corners);
1124     std::vector<Tensor> output;
1125     HelperInterpNearest::init_indices_weights(
1126       scalar_type, output, output_size, ndims, reshape_dim, HelperInterpNearest::interp_size);
1127 
1128     AT_DISPATCH_FLOATING_TYPES_AND2(
1129       kBFloat16, kHalf, scalar_type, "compute_indices_weights_nearest", [&] {
1130         using opmath_t = at::opmath_type<scalar_t>;
1131         opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1132 
1133         auto input_index_ptr = output[0].data_ptr<int64_t>();
1134         int64_t input_index;
1135 
1136         // Indices should be computed as follows:
1137         // scale = 1.0 * isize / osize
1138         // index_f32 = (output_index + 0.5) * scale - 0.5
1139         // input_index = round(index_f32)
1140         // Same as Pillow and Scikit-Image/Scipy ndi.zoom
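        //
        // For instance (illustrative sizes): with isize = 3 and osize = 2 (scale = 1.5) this
        // yields indices [round(0.25), round(1.75)] = [0, 2], whereas the legacy
        // HelperInterpNearest above computes [floor(0 * 1.5), floor(1 * 1.5)] = [0, 1].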
1141         for (const auto i : c10::irange(output_size)) {
1142           const auto real_input_index =
1143               area_pixel_compute_source_index<opmath_t>(
1144                   scale, i, /*align_corners=*/align_corners, /*cubic=*/false);
1145           input_index = static_cast<int64_t>(floorf(real_input_index + 0.5));
1146           input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
1147         }
1148       }
1149     );
1150     return output;
1151   }
1152 };
1153 
1154 struct HelperInterpLinear : public HelperInterpBase {
1155 
1156   static const int interp_size = 2;
1157 
1158   // Compute indices and weights for each interpolated dimension
1159   // indices_weights = {
1160   //      {indices_0, weights_0, indices_1, weights_1},  // dim -n
1161   //      {indices_0, weights_0, indices_1, weights_1},  // dim -(n-1)
1162   //      ...
1163   //      {indices_0, weights_0, indices_1, weights_1},  // dim -1
1164   // }
1165   // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1166   // fit input/output tensors.
1167   // Indices already contain the strides to optimize the computations.
1168   static inline std::vector<Tensor> compute_indices_weights(
1169     at::ScalarType scalar_type,
1170     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
1171     bool align_corners, const std::optional<double>& opt_scale
1172   ) {
1173     std::vector<Tensor> output;
1174     HelperInterpLinear::init_indices_weights(
1175       scalar_type, output, output_size, ndims, reshape_dim, HelperInterpLinear::interp_size);
1176     AT_DISPATCH_FLOATING_TYPES_AND2(
1177       kBFloat16, kHalf, scalar_type, "compute_indices_weights_linear", [&] {
1178         using opmath_t = at::opmath_type<scalar_t>;
1179         opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1180 
1181         auto input_index0_ptr = output[0].data_ptr<int64_t>();
1182         auto lambda0_ptr = output[1].data_ptr<scalar_t>();
1183         auto input_index1_ptr = output[2].data_ptr<int64_t>();
1184         auto lambda1_ptr = output[3].data_ptr<scalar_t>();
1185 
1186         for (const auto i : c10::irange(output_size)) {
1187 
1188           compute_source_index_and_lambda<scalar_t, opmath_t>(
1189             input_index0_ptr[i], input_index1_ptr[i],
1190             lambda0_ptr[i], lambda1_ptr[i],
1191             scale, i, input_size, output_size, align_corners
1192           );
1193           // put stride into indices
1194           // index values correspond to input indices (0, 1, 2, 3, ...);
1195           // once multiplied by the input stride, the maximum possible value is
1196           // input_size[dim-1] * input_size[dim-2] * ... for the given dimension.
1197           input_index0_ptr[i] *= stride;
1198           input_index1_ptr[i] *= stride;
1199         }
1200       }
1201     );
1202     return output;
1203   }
1204 
1205   // taken from
1206   // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
1207   // src/libImaging/Resample.c#L20-L29
1208   template<typename scalar_t>
1209   static inline scalar_t aa_filter(scalar_t x) {
1210     x = std::abs(x);
1211     if (x < 1.0) {
1212       return 1.0 - x;
1213     }
1214     return 0.0;
1215   }
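  // Editorial note: aa_filter above is the triangle ("tent") filter; e.g.
  // aa_filter(0.25) evaluates to 0.75, and any |x| >= 1 evaluates to 0.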
1216 
1217   static inline std::vector<Tensor> compute_index_ranges_weights(
1218     at::ScalarType scalar_type,
1219     int64_t input_size,
1220     int64_t output_size,
1221     int64_t stride,
1222     int64_t ndims,
1223     int64_t reshape_dim,
1224     bool align_corners,
1225     const std::optional<double>& opt_scale,
1226     bool antialias
1227   ) {
1228 
1229     std::vector<Tensor> indices_weights;
1230     AT_DISPATCH_FLOATING_TYPES(
1231       scalar_type, "compute_index_ranges_weights", [&] {
1232 
1233         scalar_t scale = area_pixel_compute_scale<scalar_t>(
1234             input_size, output_size, align_corners, opt_scale);
1235 
1236         auto interp_size = HelperInterpLinear::interp_size;
1237 
1238         indices_weights = std::get<0>(HelperInterpLinear::_compute_index_ranges_weights<scalar_t>(
1239             input_size,
1240             output_size,
1241             stride,
1242             ndims,
1243             reshape_dim,
1244             scale,
1245             interp_size,
1246             &HelperInterpLinear::aa_filter<scalar_t>,
1247             /*antialias=*/antialias,
1248             /*align_corners=*/align_corners));
1249       }
1250     );
1251     return indices_weights;
1252   }
1253 
1254   static inline std::tuple<std::vector<Tensor>, int, unsigned int> compute_index_ranges_int16_weights(
1255     int64_t input_size,
1256     int64_t output_size,
1257     int64_t stride,
1258     int64_t ndims,
1259     int64_t reshape_dim,
1260     bool align_corners,
1261     const std::optional<double>& opt_scale,
1262     bool antialias,
1263     bool align_i32=false
1264   ) {
1265 
1266     auto interp_size = HelperInterpLinear::interp_size;
1267     auto fn = HelperInterpLinear::aa_filter<double>;
1268     return HelperInterpLinear::_compute_index_ranges_int16_weights(
1269         input_size, output_size, stride, ndims, reshape_dim,
1270         align_corners, opt_scale, interp_size, fn, antialias, align_i32);
1271   }
1272 };
1273 
1274 struct HelperInterpCubic : public HelperInterpBase {
1275 
1276   static const int interp_size = 4;
1277 
1278   // Compute indices and weights for each interpolated dimension
1279   // indices_weights = {
1280   //      {indices_0, weights_0, indices_1, weights_1, ..., indices_3, weights_3},  // dim -n
1281   //      {indices_0, weights_0, indices_1, weights_1, ..., indices_3, weights_3},  // dim -(n-1)
1282   //      ...
1283   //      {indices_0, weights_0, indices_1, weights_1, ..., indices_3, weights_3},  // dim -1
1284   // }
1285   // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1286   // fit input/output tensors.
1287   // The indices already contain the strides, to speed up the computations
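  // Editorial sketch (not in the original source): with interp_size == 4 each
  // dim holds four (indices_j, weights_j) pairs, and the output is effectively
  //   out[i] = sum over j in [0, 4) of weights_j[i] * in[indices_j[i]]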
1288   static inline std::vector<Tensor> compute_indices_weights(
1289     at::ScalarType scalar_type,
1290     int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
1291     bool align_corners, const std::optional<double>& opt_scale
1292   ) {
1293     std::vector<Tensor> output;
1294     HelperInterpCubic::init_indices_weights(
1295       scalar_type, output, output_size, ndims, reshape_dim, HelperInterpCubic::interp_size);
1296 
1297     AT_DISPATCH_FLOATING_TYPES_AND2(
1298       kBFloat16, kHalf, scalar_type, "compute_indices_weights_cubic", [&] {
1299         using opmath_t = at::opmath_type<scalar_t>;
1300         opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1301 
1302         int64_t input_index;
1303         int64_t zero = static_cast<int64_t>(0);
1304         // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1305         opmath_t coeffs[4];
1306 
1307         int64_t * idx_ptr;
1308         scalar_t * wt_ptr;
1309         for (const auto i : c10::irange(output_size)) {
1310           const auto real_input_index =
1311               area_pixel_compute_source_index<opmath_t>(
1312                   scale, i, align_corners, /*cubic=*/true);
1313           opmath_t lambda;
1314           guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
1315           get_cubic_upsample_coefficients<opmath_t>(coeffs, lambda);
1316 
1317           for (const auto j : c10::irange(interp_size)) {
1318             idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
1319             idx_ptr[i] = static_cast<int64_t>(std::max(std::min(input_index + j - 1, input_size - 1), zero)) * stride;
1320             wt_ptr = output[2 * j + 1].data_ptr<scalar_t>();
1321             wt_ptr[i] = coeffs[j];
1322           }
1323         }
1324       }
1325     );
1326     return output;
1327   }
1328 
1329   // taken from
1330   // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
1331   // src/libImaging/Resample.c#L46-L62
1332   template<typename scalar_t, bool use_keys_cubic=true>
1333   static inline scalar_t aa_filter(scalar_t x) {
1334     // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
1335     // a = -0.5 was proposed by R. Keys in "Cubic convolution interpolation for digital image processing"
1336     // We are using -0.5 for bicubic, antialiasing=true (compatibility with PIL)
1337     // and using -0.75 for bicubic, antialiasing=false (compatibility with OpenCV)
1338     constexpr scalar_t a = use_keys_cubic ? -0.5 : -0.75;
1339 
1340     x = std::abs(x);
1341     if (x < 1.0) {
1342         return cubic_convolution1(x, a);
1343     }
1344     if (x < 2.0) {
1345         return cubic_convolution2(x, a);
1346     }
1347     return 0.0;
1348   }
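  // Editorial note: for both values of `a` above this is an interpolating
  // kernel: aa_filter(0) == 1 and aa_filter(1) == aa_filter(2) == 0, so
  // sampling exactly on the input grid reproduces the original pixel values.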
1349 
1350   static inline std::vector<Tensor> compute_index_ranges_weights(
1351     at::ScalarType scalar_type,
1352     int64_t input_size,
1353     int64_t output_size,
1354     int64_t stride,
1355     int64_t ndims,
1356     int64_t reshape_dim,
1357     bool align_corners,
1358     const std::optional<double>& opt_scale,
1359     bool antialias
1360   ) {
1361 
1362     std::vector<Tensor> indices_weights;
1363     AT_DISPATCH_FLOATING_TYPES(
1364       scalar_type, "compute_index_ranges_weights", [&] {
1365 
1366         scalar_t scale = area_pixel_compute_scale<scalar_t>(
1367             input_size, output_size, align_corners, opt_scale);
1368 
1369         auto interp_size = HelperInterpCubic::interp_size;
1370 
1371         indices_weights = std::get<0>(HelperInterpCubic::_compute_index_ranges_weights<scalar_t>(
1372             input_size,
1373             output_size,
1374             stride,
1375             ndims,
1376             reshape_dim,
1377             scale,
1378             interp_size,
1379             &HelperInterpCubic::aa_filter<scalar_t>,
1380             /*antialias=*/antialias,
1381             /*align_corners=*/align_corners));
1382       }
1383     );
1384     return indices_weights;
1385   }
1386 
1387   static inline std::tuple<std::vector<Tensor>, int, unsigned int> compute_index_ranges_int16_weights(
1388     int64_t input_size,
1389     int64_t output_size,
1390     int64_t stride,
1391     int64_t ndims,
1392     int64_t reshape_dim,
1393     bool align_corners,
1394     const std::optional<double>& opt_scale,
1395     bool antialias,
1396     bool align_i32=false
1397   ) {
1398 
1399     auto interp_size = HelperInterpCubic::interp_size;
1400     // We have to use the -0.75 constant when aa is False so that this uint8
1401     // path is as close as possible to float results.
1402     auto fn = antialias ? HelperInterpCubic::aa_filter<double, true> : HelperInterpCubic::aa_filter<double, false>;
1403     return HelperInterpCubic::_compute_index_ranges_int16_weights(
1404         input_size, output_size, stride, ndims, reshape_dim,
1405         align_corners, opt_scale, interp_size, fn, antialias, align_i32);
1406   }
1407 
1408 };
1409 
1410 // Generic upsampling interpolation kernel for N-d case.
1411 // Input is assumed to be like NCHW, NCL or NCKHW - the interpolated spatial dimensions
1412 // are those following the batch size N and the number of channels C.
1413 //
1414 // Internally, it uses TensorIterator to optimize the computations.
1415 // - out_ndims is the number of interpolated dims: 1, 2, 3
1416 // - scale_type is template type for scales, typically std::optional<double>
1417 // - template<typename> class F is one of the above structs to compute indices and weights
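// Editorial example (mirrors the wrappers further below): a bilinear 2D resize
// on the generic path is dispatched as
//   upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
//       output, input, align_corners, {scales_h, scales_w});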
1418 template <int out_ndims, typename scale_type, class F>
1419 void upsample_generic_Nd_kernel_impl(
1420     const Tensor& output,
1421     const Tensor& input,
1422     bool align_corners,
1423     const scale_type& scales) {
1424 
1425 
1426   // input can be NCHW, NCL or NCKHW
1427   auto shape = input.sizes().vec();
1428   auto strides = input.strides().vec();
1429   auto oshape = output.sizes();
1430 
1431   TORCH_INTERNAL_ASSERT(
1432     shape.size() == oshape.size() && shape.size() == 2 + out_ndims
1433   );
1434   TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
1435 
1436   for (const auto i : c10::irange(out_ndims)) {
1437     shape[i + 2] = oshape[i + 2];
1438     strides[i + 2] = 0;
1439   }
1440   auto restrided_input = input.as_strided(shape, strides);
1441 
1442 
1443   constexpr int interp_size = F::interp_size;
1444   auto input_scalar_type = input.scalar_type();
1445   if ((interp_size == 1 && input_scalar_type == at::ScalarType::Byte)) {
1446     // nearest also supports uint8 tensors, but we have to use float
1447     // as the scalar type for compute_indices_weights
1448     input_scalar_type = at::ScalarType::Float;
1449   }
1450 
1451   std::vector<std::vector<Tensor>> indices_weights;
1452   indices_weights.reserve(out_ndims);
1453   for (const auto i : c10::irange(out_ndims)) {
1454     indices_weights.emplace_back(
1455       F::compute_indices_weights(
1456         input_scalar_type, input.size(i + 2), oshape[i + 2],
1457         input.stride(i + 2) * input.element_size(),
1458         input.dim(), i + 2, align_corners, scales[i]
1459       )
1460     );
1461   }
1462 
1463   TensorIteratorConfig config;
1464   config.check_all_same_dtype(false)
1465     .declare_static_dtype_and_device(input.scalar_type(), input.device())
1466     .add_output(output)
1467     .add_const_input(restrided_input);
1468 
1469   for (auto & idx_weight: indices_weights) {
1470     for (auto& tensor : idx_weight) {
1471       config.add_const_input(tensor);
1472     }
1473   }
1474 
1475   auto iter = config.build();
1476 
1477   if (interp_size > 1) {
1478     // Nearest also supports uint8 tensors, so we need to handle them separately
1479     AT_DISPATCH_FLOATING_TYPES_AND2(
1480         kBFloat16, kHalf, iter.dtype(), "upsample_generic_Nd", [&] {
1481         // MSVC can not catch constexpr int interp_size here
1482         constexpr int mode = F::interp_size;
1483         cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
1484     });
1485   } else {
1486     AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf,
1487         iter.dtype(), "upsample_generic_Nd", [&] {
1488         constexpr int mode = F::interp_size;
1489         cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
1490     });
1491   }
1492 }
1493 
1494 template <typename scalar_t, bool is_horizontal>
1495 void cpu_upsample_generic_aa(at::TensorIterator& iter, unsigned int weights_precision) {
1496 
1497   auto loop = [&](char** data, const int64_t* strides, int64_t n) {
1498     if constexpr (is_horizontal) {
1499 
1500       // Strides are : X 0 | 8 8 8 0 8  (Channels first)
1501       // Strides are : X X | 0 0 0 0 0  (Channels last)
1502       basic_loop_aa_horizontal<scalar_t>(data, strides, n, weights_precision);
1503     } else {
1504       // Strides are : X Y | 0 0 0 0 0 (Channels first)
1505       // Strides are : X X | 0 0 0 0 0 (Channels last)
1506       // upsampling data between contiguous dimensions (aka vertical resampling)
1507       basic_loop_aa_vertical<scalar_t>(data, strides, n, weights_precision);
1508     }
1509   };
1510 
1511   iter.for_each(loop);
1512 }
1513 
1514 template <int out_ndims, typename scale_type, class F, bool is_horizontal>
1515 void _separable_upsample_generic_Nd_kernel_impl_single_dim(
1516     const Tensor& output,
1517     const Tensor& input,
1518     int interp_dim,
1519     bool align_corners,
1520     const scale_type& scales,
1521     bool antialias) {
1522 
1523   // input can be NCHW, NCL or NCKHW
1524   auto shape = input.sizes().vec();
1525   auto strides = input.strides().vec();
1526   auto oshape = output.sizes();
1527 
1528   TORCH_INTERNAL_ASSERT(
1529       shape.size() == oshape.size() && shape.size() == 2 + out_ndims);
1530   TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
1531 
1532   for (const auto i : c10::irange(out_ndims)) {
1533     shape[i + 2] = oshape[i + 2];
1534   }
1535   strides[interp_dim] = 0;
1536   auto restrided_input = input.as_strided(shape, strides);
1537 
1538   auto input_scalar_type = input.scalar_type();
1539 
1540   std::vector<Tensor> indices_weights;
1541   unsigned int weights_precision = 0;
1542 
1543   if (input_scalar_type == at::kByte) {
1544     // This is a special branch to provide uint8 dtype support for bilinear and bicubic modes only
1545     TORCH_INTERNAL_ASSERT(F::interp_size == 2 || F::interp_size == 4);
1546     int unused = 0;
1547     std::tie(indices_weights, unused, weights_precision) =
1548       F::compute_index_ranges_int16_weights(
1549         input.size(interp_dim), oshape[interp_dim],
1550         input.stride(interp_dim) * input.element_size(),
1551         input.dim(), interp_dim, align_corners, scales[interp_dim - 2],
1552         antialias);
1553     TORCH_INTERNAL_ASSERT(weights_precision > 0);
1554   } else {
1555     indices_weights =
1556       F::compute_index_ranges_weights(
1557         input_scalar_type, input.size(interp_dim), oshape[interp_dim],
1558         input.stride(interp_dim) * input.element_size(),
1559         input.dim(), interp_dim, align_corners, scales[interp_dim - 2],
1560         antialias);
1561   }
1562 
1563   TensorIteratorConfig config;
1564   config.check_all_same_dtype(false)
1565       .declare_static_dtype_and_device(input.scalar_type(), input.device())
1566       .add_output(output)
1567       .add_const_input(restrided_input);
1568 
1569   for (auto& tensor : indices_weights) {
1570     config.add_const_input(tensor);
1571   }
1572 
1573   auto iter = config.build();
1574 
1575   AT_DISPATCH_FLOATING_TYPES_AND(
1576       at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd_aa", [&] {
1577         cpu_upsample_generic_aa<scalar_t, is_horizontal>(iter, weights_precision);
1578       });
1579 }
1580 
1581 // Generic separable upsampling interpolation kernel for N-d case with anti-aliasing.
1582 // It also supports antialias=False, but only when
1583 // (dtype == uint8 and mode in ("bilinear", "bicubic")): this is used as a
1584 // fallback in these settings when AVX isn't supported.
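// Editorial note: the implementation below resizes the last ("horizontal") dim
// first, then the remaining ("vertical") dims one by one, allocating a
// temporary buffer between passes whenever more than one dim changes size.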
1585 template <int out_ndims, typename scale_type, class F>
1586 void separable_upsample_generic_Nd_kernel_impl(
1587     const Tensor& output,
1588     const Tensor& input,
1589     bool align_corners,
1590     const scale_type& scales,
1591     bool antialias) {
1592 
1593   auto output_shape = output.sizes();
1594   auto input_shape = input.sizes();
1595   auto temp_oshape = input_shape.vec();
1596 
1597   if (output_shape == input_shape) {
1598     output.copy_(input);
1599     return;
1600   }
1601 
1602   at::Tensor temp_output, temp_input = input;
1603 
1604   int interp_dim = 0;
1605   // Precompute the number of single-dim resize invocations
1606   // to avoid copying the temporary buffer to the output
1607   int num_single_dim_ops = 0;
1608   for (const auto i : c10::irange(out_ndims)) {
1609     interp_dim = 2 + out_ndims - 1 - i;
1610     if (output_shape[interp_dim] != input_shape[interp_dim]) {
1611       num_single_dim_ops += 1;
1612     }
1613   }
1614 
1615   // upsampling data within the contiguous dimension (aka horizontal resampling)
1616   interp_dim = 2 + out_ndims - 1;
1617   if (output_shape[interp_dim] != input_shape[interp_dim]) {
1618 
1619     num_single_dim_ops -= 1;
1620     if (num_single_dim_ops > 0) {
1621       temp_oshape[interp_dim] = output_shape[interp_dim];
1622       temp_output = at::empty(temp_oshape, input.options());
1623     } else {
1624       temp_output = output;
1625     }
1626 
1627     _separable_upsample_generic_Nd_kernel_impl_single_dim<
1628         out_ndims,
1629         scale_t,
1630         F,
1631         true>(
1632         temp_output, temp_input, interp_dim, align_corners, scales, antialias);
1633     temp_input = temp_output;
1634   }
1635 
1636   // upsampling data between contiguous dimensions (aka vertical resampling)
1637   for (const auto i : c10::irange(1, out_ndims)) {
1638     interp_dim = 2 + out_ndims - 1 - i;
1639     if (output_shape[interp_dim] != input_shape[interp_dim]) {
1640 
1641       num_single_dim_ops -= 1;
1642       if (num_single_dim_ops > 0) {
1643         temp_oshape[interp_dim] = output_shape[interp_dim];
1644         temp_output = at::empty(temp_oshape, input.options());
1645       } else {
1646         temp_output = output;
1647       }
1648 
1649       _separable_upsample_generic_Nd_kernel_impl_single_dim<
1650           out_ndims,
1651           scale_t,
1652           F,
1653           false>(
1654           temp_output, temp_input, interp_dim, align_corners, scales, antialias);
1655       temp_input = temp_output;
1656     }
1657   }
1658 }
1659 
1660 void upsample_nearest1d_kernel_impl(
1661     const Tensor& output,
1662     const Tensor& input,
1663     std::optional<double> scales_w) {
1664   upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpNearest>(
1665       output, input, false, {scales_w});
1666 }
1667 
1668 void _upsample_nearest_exact1d_kernel_impl(
1669     const Tensor& output,
1670     const Tensor& input,
1671     std::optional<double> scales_w) {
1672   upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpNearestExact>(
1673     output, input, false, {scales_w});
1674 }
1675 
1676 int _use_vectorized_kernel_cond_2d(
1677     const Tensor& output,
1678     const Tensor& input) {
1679       // This condition is used to know whether we should dispatch to a vectorized
1680       // kernel, or to the more general upsample_generic_Nd_kernel_impl(). For now,
1681       // the vectorized kernels are only optimized for channels_last and when C >= 4
1682       // (shape = NCHW). For a very wide range of use-cases (typically image or mask
1683       // resizing where we have C < 4), using upsample_generic_Nd_kernel_impl() is
1684       // actually faster. On top of that, benchmarks showed that this also depends on
1685       // the *output* size (output_H + output_W), for both upsampling and
1686       // downsampling. The current 128 threshold was determined through benchmarks.
1687       return ((input.is_contiguous(at::MemoryFormat::ChannelsLast)) && (input.size(1) > 3)) || ((output.size(-2) + output.size(-1)) <= 128);
1688 }
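// Editorial example: a contiguous (non-channels_last) C == 3 image resized to
// 512x512 takes the generic path here, while a channels_last input with C >= 4,
// or any output with H_out + W_out <= 128, takes the vectorized path.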
1689 
1690 int _use_vectorized_kernel_cond_3d(
1691     // Similar to _use_vectorized_kernel_cond_2d() but for 3d resampling (e.g. videos)
1692     // Note that unlike the 2d case, this is not subject to small output size
1693     // overhead - hence the absence of the 128 threshold in the condition.
1694     const Tensor& output,
1695     const Tensor& input) {
1696       return ((input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) && (input.size(1) > 3));
1697 }
1698 
1699 
1700 void upsample_nearest2d_kernel_impl(
1701     const Tensor& output,
1702     const Tensor& input,
1703     std::optional<double> scales_h,
1704     std::optional<double> scales_w) {
1705   if (_use_vectorized_kernel_cond_2d(output, input)) {
1706     AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf,
1707         input.scalar_type(), "upsample_nearest2d_channels_last", [&] {
1708       cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_idx>(output, input, {scales_h, scales_w});
1709     });
1710   } else {
1711     upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpNearest>(
1712       output, input, false, {scales_h, scales_w});
1713   }
1714 }
1715 
1716 void _upsample_nearest_exact2d_kernel_impl(
1717     const Tensor& output,
1718     const Tensor& input,
1719     std::optional<double> scales_h,
1720     std::optional<double> scales_w) {
1721   if (_use_vectorized_kernel_cond_2d(output, input)) {
1722     AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest2d_channels_last", [&] {
1723       cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_exact_idx>(output, input, {scales_h, scales_w});
1724     });
1725   } else {
1726     upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpNearestExact>(
1727       output, input, false, {scales_h, scales_w});
1728   }
1729 }
1730 
1731 void upsample_nearest3d_kernel_impl(
1732     const Tensor& output,
1733     const Tensor& input,
1734     std::optional<double> scales_d,
1735     std::optional<double> scales_h,
1736     std::optional<double> scales_w) {
1737   if (_use_vectorized_kernel_cond_3d(output, input)) {
1738     AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf,
1739         input.scalar_type(), "upsample_nearest3d_channels_last", [&] {
1740       cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_idx>(output, input, {scales_d, scales_h, scales_w});
1741     });
1742   } else {
1743     upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpNearest>(
1744       output, input, false, {scales_d, scales_h, scales_w});
1745   }
1746 }
1747 
1748 void _upsample_nearest_exact3d_kernel_impl(
1749     const Tensor& output,
1750     const Tensor& input,
1751     std::optional<double> scales_d,
1752     std::optional<double> scales_h,
1753     std::optional<double> scales_w) {
1754   if (_use_vectorized_kernel_cond_3d(output, input)) {
1755     AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest3d_channels_last", [&] {
1756       cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_exact_idx>(output, input, {scales_d, scales_h, scales_w});
1757     });
1758   } else {
1759     upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpNearestExact>(
1760       output, input, false, {scales_d, scales_h, scales_w});
1761   }
1762 }
1763 
1764 void upsample_linear1d_kernel_impl(
1765     const Tensor& output,
1766     const Tensor& input,
1767     bool align_corners,
1768     std::optional<double> scales_w) {
1769   upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpLinear>(
1770     output, input, align_corners, {scales_w});
1771 }
1772 
1773 
1774 void upsample_bilinear2d_kernel_impl_float(
1775     const Tensor& output,
1776     const Tensor& input,
1777     bool align_corners,
1778     std::optional<double> scales_h,
1779     std::optional<double> scales_w) {
1780 
1781   // See note above about _use_vectorized_kernel_cond_2d(output, input). The extra condition is present
1782   // because benchmarks showed that with only 1 thread, images (C == 3) were
1783   // slightly faster with the vectorized kernel than with the generic one.
1784   // That's not the case for masks though (C == 1), which strongly benefit from
1785   // using the generic kernel.
1786   if ((_use_vectorized_kernel_cond_2d(output, input)) || (at::get_num_threads() == 1 && input.size(1) == 3)) {
1787     AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "upsample_bilinear2d_channels_last", [&] {
1788       cpu_upsample_linear_channels_last<scalar_t, scale_t>(output, input, align_corners, {scales_h, scales_w});
1789     });
1790   } else {
1791     upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1792       output, input, align_corners, {scales_h, scales_w});
1793   }
1794 }
1795 
1796 void upsample_bilinear2d_kernel_impl(
1797     const Tensor& output,
1798     const Tensor& input,
1799     bool align_corners,
1800     std::optional<double> scales_h,
1801     std::optional<double> scales_w) {
1802 
1803   if (input.dtype() == at::kByte){
1804     #ifdef CPU_CAPABILITY_AVX2
1805       if (input.size(1) <= 4) {
1806         upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpLinear>(input,
1807           output, align_corners, {scales_h, scales_w},
1808           /*antialias=*/false);
1809       } else {
1810         separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1811           output, input, align_corners, {scales_h, scales_w},
1812           /*antialias=*/false);
1813       }
1814     #else  // CPU_CAPABILITY_AVX2
1815       separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1816         output, input, align_corners, {scales_h, scales_w},
1817         /*antialias=*/false);
1818     #endif  // CPU_CAPABILITY_AVX2
1819   } else {
1820     upsample_bilinear2d_kernel_impl_float(output, input, align_corners, scales_h, scales_w);
1821   }
1822 }
1823 
1824 
1825 void upsample_bilinear2d_aa_kernel_impl(
1826     const Tensor& output,
1827     const Tensor& input,
1828     bool align_corners,
1829     std::optional<double> scales_h,
1830     std::optional<double> scales_w) {
1831 #ifdef CPU_CAPABILITY_AVX2
1832   if (input.dtype() == at::kByte && input.size(1) <= 4) {
1833     upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpLinear>(
1834       input, output, align_corners, {scales_h, scales_w},
1835       /*antialias=*/true);
1836   } else {
1837     separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1838         output, input, align_corners, {scales_h, scales_w},
1839         /*antialias=*/true);
1840   }
1841 #else // CPU_CAPABILITY_AVX2
1842   separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1843       output, input, align_corners, {scales_h, scales_w},
1844       /*antialias=*/true);
1845 #endif // CPU_CAPABILITY_AVX2
1846 }
1847 
1848 void upsample_trilinear3d_kernel_impl(
1849     const Tensor& output,
1850     const Tensor& input,
1851     bool align_corners,
1852     std::optional<double> scales_d,
1853     std::optional<double> scales_h,
1854     std::optional<double> scales_w) {
1855   if ((_use_vectorized_kernel_cond_3d(output, input))) {
1856     AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "upsample_trilinear3d_channels_last", [&] {
1857       cpu_upsample_linear_channels_last<scalar_t, scale_t>(output, input, align_corners, {scales_d, scales_h, scales_w});
1858     });
1859   } else {
1860     upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpLinear>(
1861       output, input, align_corners, {scales_d, scales_h, scales_w});
1862   }
1863 }
1864 
1865 void upsample_bicubic2d_kernel_impl(
1866     const Tensor& output,
1867     const Tensor& input,
1868     bool align_corners,
1869     std::optional<double> scales_h,
1870     std::optional<double> scales_w) {
1871 
1872   if (input.dtype() == at::kByte){
1873     #ifdef CPU_CAPABILITY_AVX2
1874       if (input.size(1) <= 4) {
1875         upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpCubic>(input,
1876           output, align_corners, {scales_h, scales_w},
1877           /*antialias=*/false);
1878       } else {
1879         separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1880           output, input, align_corners, {scales_h, scales_w},
1881           /*antialias=*/false);
1882       }
1883     #else  // CPU_CAPABILITY_AVX2
1884       separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1885         output, input, align_corners, {scales_h, scales_w},
1886         /*antialias=*/false);
1887     #endif  // CPU_CAPABILITY_AVX2
1888   }
1889   else {
1890     upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1891       output, input, align_corners, {scales_h, scales_w});
1892   }
1893 }
1894 
1895 void upsample_bicubic2d_aa_kernel_impl(
1896     const Tensor& output,
1897     const Tensor& input,
1898     bool align_corners,
1899     std::optional<double> scales_h,
1900     std::optional<double> scales_w) {
1901 
1902 #ifdef CPU_CAPABILITY_AVX2
1903   if (input.dtype() == at::kByte && input.size(1) <= 4) {
1904     upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpCubic>(
1905       input, output, align_corners, {scales_h, scales_w},
1906       /*antialias=*/true);
1907   } else {
1908     separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1909         output, input, align_corners, {scales_h, scales_w},
1910         /*antialias=*/true);
1911   }
1912 #else // CPU_CAPABILITY_AVX2
1913   separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1914       output, input, align_corners, {scales_h, scales_w},
1915       /*antialias=*/true);
1916 #endif // CPU_CAPABILITY_AVX2
1917 }
1918 
1919 template <
1920     typename scalar_t,
1921     typename scale_type,
1922     class F>
1923 void cpu_upsample_genNd_backward_aa(
1924     const Tensor& grad_input_,
1925     const Tensor& grad_output_,
1926     bool align_corners,
1927     const scale_type& scales) {
1928   TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
1929               " for `grad_input` but got dtype ", grad_input_.dtype());
1930 
1931   auto grad_output = grad_output_.contiguous();
1932   auto grad_input = grad_input_.contiguous();
1933 
1934   auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
1935   auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
1936   auto input_sizes = grad_input.sizes().vec();
1937   auto output_sizes = grad_output.sizes().vec();
1938   auto ndim = input_sizes.size();
1939 
1940   // treat nbatch and channels as one dimension
1941   int64_t channels = input_sizes[0] * input_sizes[1];
1942   int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
1943   int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
1944   int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
1945   int64_t input_width = input_sizes[ndim - 1];
1946   int64_t output_width = output_sizes[ndim - 1];
1947 
1948   int64_t output_slice_size = output_depth * output_height * output_width;
1949   int interp_size = F::interp_size;
1950 
1951   auto loop2d = [&](int64_t begin, int64_t end) {
1952     const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
1953         input_height, output_height, align_corners, scales[0]);
1954     const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
1955         input_width, output_width, align_corners, scales[1]);
1956 
1957     auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
1958       return grad_input_data + c * input_height * input_width +
1959           h * input_width + w;
1960     };
1961 
1962     const scalar_t support_h = (height_scale >= 1.0)
1963         ? (interp_size * 0.5) * height_scale
1964         : interp_size * 0.5;
1965     const scalar_t support_w = (width_scale >= 1.0)
1966         ? (interp_size * 0.5) * width_scale
1967         : interp_size * 0.5;
1968 
1969     const int interp_height = (int)ceilf(support_h) * 2 + 1;
1970     const int interp_width = (int)ceilf(support_w) * 2 + 1;
1971 
1972     std::vector<scalar_t> wx(interp_width, 0.0);
1973     std::vector<scalar_t> wy(interp_height, 0.0);
1974 
1975     int64_t xmin = 0, ymin = 0;
1976     int64_t xsize = 0, ysize = 0;
1977 
1978     typedef scalar_t (*aa_filter_fn_t)(scalar_t);
1979     aa_filter_fn_t filter_fn = &F::aa_filter;
1980 
1981     for (const auto oh : c10::irange(output_height)) {
1982       F::_compute_indices_min_size_weights_aa(
1983           oh,
1984           input_height,
1985           height_scale,
1986           support_h,
1987           wy.data(),
1988           interp_height,
1989           filter_fn,
1990           ymin,
1991           ysize);
1992 
1993       for (const auto ow : c10::irange(output_width)) {
1994         F::_compute_indices_min_size_weights_aa(
1995             ow,
1996             input_width,
1997             width_scale,
1998             support_w,
1999             wx.data(),
2000             interp_width,
2001             filter_fn,
2002             xmin,
2003             xsize);
2004 
2005         for (const auto c : c10::irange(begin, end)) {
2006           scalar_t grad_output_value =
2007               grad_output_data[c * output_slice_size + oh * output_width + ow];
2008 
2009           for (const auto y : c10::irange(ysize)) {
2010             for (const auto x : c10::irange(xsize)) {
2011               *input_indexr(c, ymin + y, xmin + x) +=
2012                   wx[x] * wy[y] * grad_output_value;
2013             }
2014           }
2015         }
2016       }
2017     }
2018   };
2019 
2020   if (ndim == 4) {
2021     // upsample bilinear 2d
2022     at::parallel_for(
2023         0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
2024   } else {
2025     TORCH_CHECK(false, "Unsupported tensor ndim");
2026   }
2027 
2028   if (!grad_input_.is_contiguous()) {
2029     grad_input_.copy_(grad_input);
2030   }
2031 }
2032 
2033 void upsample_bilinear2d_aa_backward_kernel_impl(
2034     const Tensor& grad_input,
2035     const Tensor& grad_output,
2036     bool align_corners,
2037     std::optional<double> scales_h,
2038     std::optional<double> scales_w) {
2039   AT_DISPATCH_FLOATING_TYPES(
2040       grad_output.scalar_type(), "upsample_bilinear2d_aa_backward_cpu", [&] {
2041         cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpLinear>(
2042             grad_input, grad_output, align_corners, {scales_h, scales_w});
2043       });
2044 }
2045 
2046 void upsample_bicubic2d_aa_backward_kernel_impl(
2047     const Tensor& grad_input,
2048     const Tensor& grad_output,
2049     bool align_corners,
2050     std::optional<double> scales_h,
2051     std::optional<double> scales_w) {
2052   AT_DISPATCH_FLOATING_TYPES(
2053       grad_output.scalar_type(), "upsample_bicubic2d_aa_backward_cpu", [&] {
2054         cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpCubic>(
2055             grad_input, grad_output, align_corners, {scales_h, scales_w});
2056       });
2057 }
2058 
2059 } // anonymous namespace
2060 
2061 REGISTER_DISPATCH(upsample_nearest1d_kernel, &upsample_nearest1d_kernel_impl);
2062 REGISTER_DISPATCH(_upsample_nearest_exact1d_kernel, &_upsample_nearest_exact1d_kernel_impl);
2063 REGISTER_DISPATCH(upsample_nearest2d_kernel, &upsample_nearest2d_kernel_impl);
2064 REGISTER_DISPATCH(_upsample_nearest_exact2d_kernel, &_upsample_nearest_exact2d_kernel_impl);
2065 REGISTER_DISPATCH(upsample_nearest3d_kernel, &upsample_nearest3d_kernel_impl);
2066 REGISTER_DISPATCH(_upsample_nearest_exact3d_kernel, &_upsample_nearest_exact3d_kernel_impl);
2067 
2068 REGISTER_DISPATCH(upsample_linear1d_kernel, &upsample_linear1d_kernel_impl);
2069 REGISTER_DISPATCH(upsample_bilinear2d_kernel, &upsample_bilinear2d_kernel_impl);
2070 REGISTER_DISPATCH(_upsample_bilinear2d_aa_kernel, &upsample_bilinear2d_aa_kernel_impl);
2071 REGISTER_DISPATCH(_upsample_bilinear2d_aa_backward_kernel, &upsample_bilinear2d_aa_backward_kernel_impl);
2072 REGISTER_DISPATCH(upsample_trilinear3d_kernel, &upsample_trilinear3d_kernel_impl);
2073 
2074 REGISTER_DISPATCH(upsample_bicubic2d_kernel, &upsample_bicubic2d_kernel_impl);
2075 REGISTER_DISPATCH(_upsample_bicubic2d_aa_kernel, &upsample_bicubic2d_aa_kernel_impl);
2076 REGISTER_DISPATCH(_upsample_bicubic2d_aa_backward_kernel, &upsample_bicubic2d_aa_backward_kernel_impl);
2077 } // namespace at::native
2078