1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Context.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/TensorIterator.h>
7 #include <ATen/cpu/vec/vec.h>
8 #include <ATen/native/UpSample.h>
9 #include <ATen/native/cpu/utils.h>
10 #include <c10/util/irange.h>
11 #include <ATen/native/cpu/UpSampleKernelAVXAntialias.h>
12
13 #ifndef AT_PER_OPERATOR_HEADERS
14 #include <ATen/Functions.h>
15 #else
16 #include <ATen/ops/empty.h>
17 #include <ATen/ops/empty_native.h>
18 #include <ATen/ops/ones.h>
19 #endif
20
21 namespace at::native {
22 namespace {
23
24 using scale_t = std::vector<std::optional<double>>;
25
26 // TODO: this file could benefit from a global renaming of its functions /
27 // classes and terms, as well as from adding more comments. In particular:
28 // - It's not obvious that despite their names (and the file name), all these
29 // kernels don't just do upsampling: they do general interpolation, i.e. they
30 // also all support downscaling.
31 // - the term "horizontal" or "within dims" or "contiguous dim" refers to the
32 // last dimension.
33 // It's not specific to 2D images and applies to 3D (and 1D??) inputs as well.
34 // Similarly "vertical" or "across dims" refers to all dims that aren't the
35 // last one. In other kernels these are also referred to as "zero-stride" and
36 // "non-zero-stride" - we should unify all this.
37 // - the terms "zero-stride" and "non-zero strides" refer to the weights and
38 // indices, not to the contiguity of input or output
39 // - It's not always clear which kernel is vectorized and which one isn't.
40 // - The functions like _use_vectorized_kernel_cond() should be renamed and
41 // their description updated, because they're not the only "fork" in the
42 // code-path where a choice is made between a vectorized kernel vs a
43 // non-vectorized one. See e.g. upsample_bilinear2d_kernel_impl() where we
44 // already make a similar check, before the one in
45 // _use_vectorized_kernel_cond().
46 // - It's not always clear which code is part of a "separable interpolation"
47 // code-path.
48 // - Some names need to be more specific. For example
49 // "cpu_upsample_generic_aa()" looks like a super generic name, but the function
50 // is instead fairly specific - we need to make that clearer.
51 // - Some functions have a "aa" suffix but it doesn't mean that they only
52 // support antialias. Some of them also support antialias=False now.
53 // - Various comments are outdated. Case in point: the one just below about the
54 // `Interpolate` struct being used for cpu_upsample_linear:
55 // cpu_upsample_linear doesn't exist anymore, and these structs are used for
56 // various modes, *not* just linear.
57 // - It'd be useful to document how interpolation works in general, and in particular state explicitly:
58 // - that the weights and indices across a given dimension are the same for
59 // all pixels (hence the benefit of pre-computing them)
60 //   - that it can be "separated", i.e. we can do the horizontal pass and the
61 //     vertical pass independently (and that some kernels are written this way,
62 //     while some aren't.) An illustrative sketch is given right after this list.
63 // - we can probably remove the template over index_t, because it's always
64 // hard-coded as int64_t
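//
// Illustrative sketch (not part of the kernels below; the names are hypothetical):
// "separable" means a 2D resize can be done as a horizontal 1D pass followed by a
// vertical 1D pass. For a single bilinear output pixel, given pre-computed source
// indices (ih0, ih1, iw0, iw1) and lambdas (h0l, h1l, w0l, w1l), this is roughly:
//
//   float bilinear_one_pixel(const float* in, int64_t W,
//                            int64_t ih0, int64_t ih1, float h0l, float h1l,
//                            int64_t iw0, int64_t iw1, float w0l, float w1l) {
//     // horizontal pass on the two participating input rows
//     float row0 = w0l * in[ih0 * W + iw0] + w1l * in[ih0 * W + iw1];
//     float row1 = w0l * in[ih1 * W + iw0] + w1l * in[ih1 * W + iw1];
//     // vertical pass combining the two horizontal results
//     return h0l * row0 + h1l * row1;
//   }
//
// The real kernels below pre-compute the indices/weights once per dimension and
// stream them through a TensorIterator instead of recomputing them per pixel.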
65
66
67 // Helper structs and methods for cpu_upsample_linear
68 //
69 // The interpolation methods used below are separable, and as such we can compute the interpolation
70 // independently per dimension in a recursive way. Please refer to #10482 for more context.
71 //
72 // Interpolation structure to compute output value in n-dimensional case.
73 // - recursively compute interpolated output for each dimension
74 // - we rely a lot on the compiler's optimizations so that the implemented operations
75 //   can be automatically factorized and vectorized using SSE and AVX2
76 template <int n, typename scalar_t, typename opmath_t, typename index_t, int interp_size>
77 struct Interpolate {
78   static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
79 index_t ids = *(index_t*)&data[0][i * strides[0]];
80 opmath_t wts = *(scalar_t*)&data[1][i * strides[1]];
81 opmath_t t = Interpolate<n - 1, scalar_t, opmath_t, index_t, interp_size>::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i);
82 opmath_t output = t * wts;
83 for (const auto j : c10::irange(1, interp_size)) {
84 ids = *(index_t*)&data[2 * j + 0][i * strides[2 * j + 0]];
85 wts = *(scalar_t*)&data[2 * j + 1][i * strides[2 * j + 1]];
86 t = Interpolate<n - 1, scalar_t, opmath_t, index_t, interp_size>::eval(src + ids, &data[2 * interp_size], &strides[2 * interp_size], i);
87 output += t * wts;
88 }
89 return output;
90 }
91 };
92
93 template <typename scalar_t, typename opmath_t, typename index_t, int interp_size>
94 struct Interpolate<1, scalar_t, opmath_t, index_t, interp_size> {
95   static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
96 index_t ids = *(index_t*)&data[0][i * strides[0]];
97 opmath_t wts = *(scalar_t*)&data[1][i * strides[1]];
98 opmath_t t = *(scalar_t *)&src[ids];
99 opmath_t output = t * wts;
100 for (const auto j : c10::irange(1, interp_size)) {
101 ids = *(index_t*)&data[2 * j + 0][i * strides[2 * j + 0]];
102 wts = *(scalar_t*)&data[2 * j + 1][i * strides[2 * j + 1]];
103 t = *(scalar_t *)&src[ids];
104 output += t * wts;
105 }
106 return output;
107 }
108 };
109
110 template <int n, typename scalar_t, typename opmath_t, typename index_t>
111 struct Interpolate<n, scalar_t, opmath_t, index_t, 1> {
112   static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
113 index_t ids = *(index_t*)&data[0][i * strides[0]];
114 return Interpolate<n - 1, scalar_t, opmath_t, index_t, 1>::eval(src + ids, &data[2], &strides[2], i);
115 }
116 };
117
118 template <typename scalar_t, typename opmath_t, typename index_t>
119 struct Interpolate<1, scalar_t, opmath_t, index_t, 1> {
120   static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
121 index_t ids = *(index_t*)&data[0][i * strides[0]];
122 return *(scalar_t *)&src[ids];
123 }
124 };
125
126 // There is an unexpected 2x slowdown for upsample_trilinear3d channels_first
127 // for both 1 and 6 threads. We have to specialize this case as below.
128 // Once the issue is fixed we can keep the generic implementation and remove:
129 // struct Interpolate<n, scalar_t, opmath_t, index_t, 2> and
130 // struct Interpolate<1, scalar_t, opmath_t, index_t, 2>
131 template <int n, typename scalar_t, typename opmath_t, typename index_t>
132 struct Interpolate<n, scalar_t, opmath_t, index_t, 2> {
133   static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
134 index_t i0 = *(index_t*)&data[0][i * strides[0]];
135 index_t i1 = *(index_t*)&data[2][i * strides[2]];
136 opmath_t w0 = *(scalar_t *)&data[1][i * strides[1]];
137 opmath_t w1 = *(scalar_t *)&data[3][i * strides[3]];
138
139 opmath_t t0 = Interpolate<n - 1, scalar_t, opmath_t, index_t, 2>::eval(src + i0, &data[4], &strides[4], i);
140 opmath_t t1 = Interpolate<n - 1, scalar_t, opmath_t, index_t, 2>::eval(src + i1, &data[4], &strides[4], i);
141
142 return t0 * w0 + t1 * w1;
143 }
144 };
145
146 template <typename scalar_t, typename opmath_t, typename index_t>
147 struct Interpolate<1, scalar_t, opmath_t, index_t, 2> {
148   static inline opmath_t eval(char* src, char** data, const int64_t* strides, int64_t i) {
149 index_t i0 = *(index_t*)&data[0][i * strides[0]];
150 index_t i1 = *(index_t*)&data[2][i * strides[2]];
151 opmath_t w0 = *(scalar_t *)&data[1][i * strides[1]];
152 opmath_t w1 = *(scalar_t *)&data[3][i * strides[3]];
153 opmath_t t0 = *(scalar_t *)&src[i0];
154 opmath_t t1 = *(scalar_t *)&src[i1];
155 return t0 * w0 + t1 * w1;
156 }
157 };
158
159 template <int n, typename scalar_t, typename index_t, int interp_size>
160 static inline scalar_t interpolate(char* src, char** data, const int64_t* strides, int64_t i) {
161 using opmath_t = at::opmath_type<scalar_t>;
162 return Interpolate<n, scalar_t, opmath_t, index_t, interp_size>::eval(src, data, strides, i);
163 }
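// Illustrative data layout (a restatement of how basic_loop below feeds Interpolate;
// the concrete names are only for clarity): for a 2D bilinear case (out_ndims = 2,
// interp_size = 2), data[0]/data[1] are the output/input pointers and the remaining
// pointers come in (index, weight) pairs, 2 * interp_size entries per interpolated
// dimension, outermost interpolated dimension first:
//
//   data = { dst, src,
//            idx_h0, wt_h0, idx_h1, wt_h1,    // dim -2 (height), consumed at the outer level
//            idx_w0, wt_w0, idx_w1, wt_w1 };  // dim -1 (width), consumed by the recursion
//
// so for one output element the recursion evaluates
//   wt_h0 * (wt_w0 * src[idx_h0 + idx_w0] + wt_w1 * src[idx_h0 + idx_w1]) +
//   wt_h1 * (wt_w0 * src[idx_h1 + idx_w0] + wt_w1 * src[idx_h1 + idx_w1])
// where the indices already include the input strides (in bytes, since src is a char*).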
164
165 template <typename scalar_t, typename index_t>
166 static inline scalar_t interpolate_aa_single_dim_zero_strides(
167 char* src,
168 char** data,
169 const index_t ids_stride) {
170 const index_t ids_min = *(index_t*)&data[0][0];
171 const index_t ids_size = *(index_t*)&data[1][0];
172
173 char* src_min = src + ids_min;
174
175 scalar_t t = *(scalar_t*)&src_min[0];
176 index_t wts_idx = *(index_t*)&data[4][0];
177 scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
178 scalar_t wts = wts_ptr[0];
179
180 scalar_t output = t * wts;
181 for (const auto j : c10::irange(1, ids_size)) {
182 wts = wts_ptr[j];
183 t = *(scalar_t*)&src_min[j * ids_stride];
184 output += t * wts;
185 }
186 return output;
187 }
188
189 template <typename scalar_t, typename index_t>
190 static inline scalar_t interpolate_aa_single_dim(
191 char* src,
192 char** data,
193 const int64_t* strides,
194 int64_t i,
195 const index_t ids_stride) {
196 index_t ids_min = *(index_t*)&data[0][i * strides[0]];
197 index_t ids_size = *(index_t*)&data[1][i * strides[1]];
198
199 char* src_min = src + ids_min;
200
201 scalar_t t = *(scalar_t*)&src_min[0];
202 index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
203 scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
204 scalar_t wts = wts_ptr[0];
205
206 scalar_t output = t * wts;
207 for (const auto j : c10::irange(1, ids_size)) {
208 wts = wts_ptr[j];
209 t = *(scalar_t*)&src_min[j * ids_stride];
210 output += t * wts;
211 }
212 return output;
213 }
214
215 template<int m>
216 static inline bool is_zero_stride(const int64_t* strides) {
217 bool output = strides[0] == 0;
218 for (const auto i : c10::irange(1, m)) {
219 output &= (strides[i] == 0);
220 }
221 return output;
222 }
223
224 template <typename scalar_t, typename index_t, int interp_size>
225 static inline bool is_contiguous_stride(const int64_t* strides) {
226 bool output = (strides[0] == sizeof(index_t)) && (strides[1] == sizeof(scalar_t));
227 for (int i=2; i<2 * interp_size; i+=2) {
228 output &= (strides[i] == sizeof(index_t)) && (strides[i + 1] == sizeof(scalar_t));
229 }
230 return output;
231 }
232
233 // Helper class to recursively check if all input strides corresponding to interpolated dimensions
234 // are equal zero except on a single dimension.
235 //
236 // Inputs: array of strides of size N, non_zero_stride_dim which can be -1, 0, 1, 2, ...
237 // For every dimension other than non_zero_stride_dim we check that all its strides are equal zero;
238 // for the non_zero_stride_dim dimension, the 4 strides corresponding to index_0, weight_0, index_1 and weight_1
239 // should be non-zero (contiguous).
240 //
241 // Unit check of the recursion is to verify whether 4 strides for one interpolated dimension are either zero,
242 // see method is_zero_stride, or (sizeof(index_t), sizeof(scalar_t), sizeof(index_t), sizeof(scalar_t)), see
243 // method is_contiguous_stride.
244 //
245 // In practice, we have the following cases:
246 // - for ND, float32, channel first, strides are
247 // dimN-1, dim1, dim0
248 // i0, w0, i1, w1, ..., i0, w0, i1, w1, i0, w0, i1, w1
249 // strides=(0, 0, 0, 0, ..., 0, 0, 0, 0, 4, 4, 4, 4)
250 //
251 // if size dim0 is 1 then its strides are 0 and dim1 strides are equal 4
252 //
253 // - for ND, float32, channel last, strides are
254 // dimN-1, dimN-2, dim0
255 // i0, w0, i1, w1, i0, w0, i1, w1, ... i0, w0, i1, w1
256 // strides=(0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0)
257 //
258 // Using these methods we can hint the compiler to factorize constant indices and weights
259 // in cpu_upsample_linear method
260 template <int N, int non_zero_stride_dim, typename scalar_t, typename index_t, int interp_size>
261 struct CheckAlmostAllZeroStrides {
262   static inline bool eval(const int64_t* strides) {
263 // N is dim index: N -> dim0, N-1 -> dim1, ...
264 // non_zero_stride_dim should be out_dims - dim
265 bool output = false;
266 if constexpr (N == non_zero_stride_dim) {
267 output = is_contiguous_stride<scalar_t, index_t, interp_size>(strides);
268 } else {
269 output = is_zero_stride<2 * interp_size>(strides);
270 }
271 return output &&
272 CheckAlmostAllZeroStrides<N - 1, non_zero_stride_dim, scalar_t, index_t, interp_size>::eval(
273 &strides[2 * interp_size]);
274 }
275 };
276
277 template <int non_zero_stride_dim, typename scalar_t, typename index_t, int interp_size>
278 struct CheckAlmostAllZeroStrides<0, non_zero_stride_dim, scalar_t, index_t, interp_size> {
279   static inline bool eval(const int64_t* /*strides*/) {
280 return true;
281 }
282 };
283
284 template <int n, int s, typename scalar_t, typename index_t, int interp_size>
285 static inline bool check_almost_all_zero_stride(const int64_t* strides) {
286 return CheckAlmostAllZeroStrides<n, s, scalar_t, index_t, interp_size>::eval(strides);
287 }
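// Illustrative usage (mirrors the checks in cpu_upsample_generic below; the concrete
// template arguments are only an example): for a 2D float kernel with interp_size = 2,
// the contiguous channels-first fast path is selected with
//
//   check_almost_all_zero_stride<2, 1, float, int64_t, 2>(&strides[2]);
//
// i.e. the last (contiguous) dimension must have the
// (sizeof(index_t), sizeof(scalar_t), ...) stride pattern while the other interpolated
// dimension must have all-zero strides; the channels-last path uses non_zero_stride_dim = -1
// so that every interpolated dimension is required to have zero strides.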
288
289 // Helper method to compute interpolation for nearest, linear, cubic modes
290 template <typename scalar_t, typename index_t, int out_ndims, int interp_size>
291 static inline void basic_loop(char** data, const int64_t* strides, int64_t n) {
292 char* dst = data[0];
293 char* src = data[1];
294 for (const auto i : c10::irange(n)) {
295 *(scalar_t*)&dst[i * strides[0]] = interpolate<out_ndims, scalar_t, index_t, interp_size>(
296 src + i * strides[1], &data[2], &strides[2], i);
297 }
298 }
299
300 template <typename scalar_t>
301 static inline void basic_loop_aa_vertical(
302 char** data,
303 const int64_t* strides,
304 int64_t n,
305 unsigned int weights_precision) {
306 char* dst = data[0];
307 char* src = data[1];
308 // index stride is constant for the given dimension
309 const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
310
311 for (const auto i : c10::irange(n)) {
312 *(scalar_t*)&dst[i * strides[0]] =
313 interpolate_aa_single_dim_zero_strides<scalar_t, int64_t>(
314 src + i * strides[1], &data[2], ids_stride);
315 }
316 }
317
318 template <>
319 inline void basic_loop_aa_vertical<uint8_t>(
320 char** data,
321 const int64_t* strides,
322 int64_t n,
323 unsigned int weights_precision) {
324 // See Note [ Weights computation for uint8_t and multiplication trick ]
325 char* dst = data[0];
326 char* src = data[1];
327
328 // index stride is constant for the given dimension
329 const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
330 const int64_t ids_size = *(int64_t*)&data[2 + 1][0];
331 const int64_t ids_min = *(int64_t*)&data[2 + 0][0];
332
333 int64_t i = 0;
334
335 for (; i<n; i++) {
336
337 char* src_min = src + i * strides[1] + ids_min;
338
339 uint8_t t = *(uint8_t*)&src_min[0];
340 int64_t wts_idx = *(int64_t*)&data[2 + 4][0];
341 int16_t* wts_ptr = (int16_t*)&data[2 + 3][wts_idx];
342 int16_t wts = wts_ptr[0];
343
344 // Intermediate computations are using integer type
345 int output = 1 << (weights_precision - 1); // accounts for the +0.5 part
346 output += t * wts;
347 for (const auto j : c10::irange(1, ids_size)) {
348 wts = wts_ptr[j];
349 t = *(uint8_t*)&src_min[j * ids_stride];
350 output += t * wts;
351 }
352 *(uint8_t*)&dst[i * strides[0]] = (uint8_t)std::clamp(output >> weights_precision, 0, 255);
353 }
354 }
355
356 template <typename scalar_t>
357 static inline void basic_loop_aa_horizontal(
358 char** data,
359 const int64_t* strides,
360 int64_t n,
361 unsigned int weights_precision) {
362 char* dst = data[0];
363 char* src = data[1];
364 // index stride is constant for the given dimension
365 const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
366
367 if (strides[1] == 0) {
368 for (const auto i : c10::irange(n)) {
369 *(scalar_t*)&dst[i * strides[0]] =
370 interpolate_aa_single_dim<scalar_t, int64_t>(
371 src, &data[2], &strides[2], i, ids_stride);
372 }
373 } else {
374 for (const auto i : c10::irange(n)) {
375 *(scalar_t*)&dst[i * strides[0]] =
376 interpolate_aa_single_dim<scalar_t, int64_t>(
377 src + i * strides[1], &data[2], &strides[2], i, ids_stride);
378 }
379 }
380 }
381
382 template <>
383 inline void basic_loop_aa_horizontal<uint8_t>(
384 char** data,
385 const int64_t* strides,
386 int64_t n,
387 unsigned int weights_precision) {
388 // See Note [ Weights computation for uint8_t and multiplication trick ]
389 char* dst = data[0];
390 char* src = data[1];
391 // index stride is constant for the given dimension
392 const int64_t ids_stride = *(int64_t*)&data[2 + 2][0];
393
394 int64_t i = 0;
395
396 // Here we are implementing data interpolation within the same line (vs between the lines)
397 // output[x, y] = input[xmin[x], y] * W[x] + input[xmin[x] + 1, y] * W[x + 1] + ... + input[xmin[x] + xsize, y] * W[x + xsize]
398
399 for (; i<n; i++) {
400
401 int64_t ids_min = *(int64_t*)&data[2 + 0][i * strides[2 + 0]];
402 int64_t ids_size = *(int64_t*)&data[2 + 1][i * strides[2 + 1]];
403
404 char* src_min = src + i * strides[1] + ids_min;
405
406 uint8_t t = *(uint8_t*)&src_min[0];
407 int64_t wts_idx = *(int64_t*)&data[2 + 4][i * strides[2 + 4]];
408 int16_t* wts_ptr = (int16_t*)&data[2 + 3][wts_idx];
409 int16_t wts = wts_ptr[0];
410
411 // Intermediate computations are using integer type
412 int output = 1 << (weights_precision - 1); // accounts for the +0.5 part
413 output += t * wts;
414 for (const auto j : c10::irange(1, ids_size)) {
415 wts = wts_ptr[j];
416 t = *(uint8_t*)&src_min[j * ids_stride];
417 output += t * wts;
418 }
419 *(uint8_t*)&dst[i * strides[0]] = (uint8_t)std::clamp(output >> weights_precision, 0, 255);
420 }
421 }
422
423 // Generic upsampling computation method using TensorIterator for Nd case.
424 // Supports: nearest, linear, cubic modes with interp_size template argument: 1, 2, 4
425 //
426 // Single loop function for 1d, 2d and 3d cases and modes
427 // For N dimensions, output value up to Di dimension can be computed as
428 //
429 // output_i[a] = interpolate(output_{i+1}[a], w_{i+1}[a], output_{i+1}[a+1], w_{i+1}[a+1], ...)
430 // with
431 // output_DN[a] = interpolate(input_DN[a], w_DN[a], input_DN[a+1], w_DN[a+1], ...)
432 // and i - dimension index and a - linear index for spatial coordinates
433 //
434 // The recursive call is implemented with the Interpolate struct using templates for
435 // compile-time loop unrolling.
436 template <typename scalar_t, int out_ndims, int interp_size>
437 void cpu_upsample_generic(at::TensorIterator& iter)
438 {
439 auto loop = [&](char** data, const int64_t* strides, int64_t n) {
440 // special-cases to let the compiler apply compile-time input-specific optimizations
441 if ((strides[0] == sizeof(scalar_t) && (strides[1] == 0) &&
442 // NOLINTNEXTLINE(bugprone-branch-clone)
443 check_almost_all_zero_stride<out_ndims, 1, scalar_t, int64_t, interp_size>(&strides[2]))) {
444 // contiguous channels-first case
445 basic_loop<scalar_t, int64_t, out_ndims, interp_size>(data, strides, n);
446 } else if ((strides[0] == sizeof(scalar_t) && (strides[1] == sizeof(scalar_t)) &&
447 check_almost_all_zero_stride<out_ndims, -1, scalar_t, int64_t, interp_size>(&strides[2]))) {
448 // contiguous channels-last case
449 basic_loop<scalar_t, int64_t, out_ndims, interp_size>(data, strides, n);
450 } else {
451 // fallback
452 basic_loop<scalar_t, int64_t, out_ndims, interp_size>(data, strides, n);
453 }
454 };
455 iter.for_each(loop);
456 }
457
458 template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
459 void cpu_upsample_nearest_channels_last(
460 const Tensor& output_,
461 const Tensor& input_,
462 const scale_type& scales) {
463 TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
464 " for `output` but got dtype ", output_.dtype());
465
466 auto input_sizes = input_.sizes().vec();
467 auto output_sizes = output_.sizes().vec();
468 auto ndim = input_sizes.size();
469 TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")
470
471 auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
472 auto input = input_.contiguous(channels_last_memory_format);
473 auto output = output_.contiguous(channels_last_memory_format);
474
475 auto input_data = input.const_data_ptr<scalar_t>();
476 auto output_data = output.data_ptr<scalar_t>();
477
478 int64_t num_batches = input_sizes[0];
479 int64_t channels = input_sizes[1];
480 int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
481 int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
482 int64_t input_height = input_sizes[ndim - 2];
483 int64_t output_height = output_sizes[ndim - 2];
484 int64_t input_width = input_sizes[ndim - 1];
485 int64_t output_width = output_sizes[ndim - 1];
486 int64_t numel = output.numel();
487
488 TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels);
489
490 using Vec = vec::Vectorized<scalar_t>;
491 auto copy = [](scalar_t* out, const scalar_t* in, int64_t size) {
492 int64_t d = 0;
493 for (; d < size - (size % Vec::size()); d += Vec::size()) {
494 Vec out_vec = Vec::loadu(in + d);
495 out_vec.store(out + d);
496 }
497 for (; d < size; d++) {
498 out[d] = in[d];
499 }
500 };
501
502 auto loop2d = [&](int64_t begin, int64_t end) {
503 int64_t n = 0;
504 int64_t oh = 0;
505 int64_t ow = 0;
506 data_index_init(begin, n, num_batches, oh, output_height, ow, output_width);
507
508 for (const auto i : c10::irange(begin, end)) {
509 int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
510 int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
511 scalar_t* output_ptr = output_data + i * channels;
512 const scalar_t* input_ptr = input_data + n * input_height * input_width * channels +
513 ih * input_width * channels + iw * channels;
514 copy(output_ptr, input_ptr, channels);
515 data_index_step(n, num_batches, oh, output_height, ow, output_width);
516 }
517 };
518
519 auto loop3d = [&](int64_t begin, int64_t end) {
520 int64_t n = 0;
521 int64_t od = 0;
522 int64_t oh = 0;
523 int64_t ow = 0;
524 data_index_init(begin, n, num_batches, od, output_depth, oh, output_height, ow, output_width);
525
526 for (const auto i : c10::irange(begin, end)) {
527 int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
528 int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
529 int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
530 scalar_t* output_ptr = output_data + i * channels;
531 const scalar_t* input_ptr = input_data + n * input_depth * input_height * input_width * channels +
532 id * input_height * input_width * channels +
533 ih * input_width * channels + iw * channels;
534 copy(output_ptr, input_ptr, channels);
535 data_index_step(n, num_batches, od, output_depth, oh, output_height, ow, output_width);
536 }
537 };
538
539 if (ndim == 4) {
540 // upsample nearest 2d
541 at::parallel_for(0, numel / channels, at::internal::GRAIN_SIZE / channels, loop2d);
542 } else {
543 // upsample nearest 3d
544 TORCH_INTERNAL_ASSERT(ndim == 5);
545 at::parallel_for(0, numel / channels, at::internal::GRAIN_SIZE / channels, loop3d);
546 }
547
548 if (!output_.is_contiguous(channels_last_memory_format)) {
549 output_.copy_(output);
550 }
551 }
552
553 template <typename scalar_t, typename accscalar_t>
554 inline VecType<scalar_t> interpolate(const scalar_t* t, accscalar_t w) {
555 return VecType<scalar_t>::loadu(t) * VecType<scalar_t>(w);
556 }
557
558 template <typename scalar_t, typename accscalar_t, typename... Args>
559 inline VecType<scalar_t> interpolate(const scalar_t* t, accscalar_t w, Args... args) {
560 return VecType<scalar_t>::loadu(t) * VecType<scalar_t>(w) + interpolate(args...);
561 }
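// Illustrative usage (matches the calls in the channels-last loops below):
//   interpolate(i00 + d, w00, i01 + d, w01, i10 + d, w10, i11 + d, w11)
// expands to
//   VecType<scalar_t>::loadu(i00 + d) * VecType<scalar_t>(w00) + VecType<scalar_t>::loadu(i01 + d) * VecType<scalar_t>(w01)
//   + VecType<scalar_t>::loadu(i10 + d) * VecType<scalar_t>(w10) + VecType<scalar_t>::loadu(i11 + d) * VecType<scalar_t>(w11),
// i.e. a vectorized weighted sum over Vec::size() channels at a time.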
562
563 template <typename scalar_t, typename scale_type>
564 void cpu_upsample_linear_channels_last(
565 const Tensor& output_,
566 const Tensor& input_,
567 bool align_corners,
568 const scale_type& scales) {
569 TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
570 " for `output` but got dtype ", output_.dtype());
571
572 auto input_sizes = input_.sizes().vec();
573 auto output_sizes = output_.sizes().vec();
574 auto ndim = input_sizes.size();
575 TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")
576
577 auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
578 auto input = input_.contiguous(channels_last_memory_format);
579 auto output = output_.contiguous(channels_last_memory_format);
580
581 auto input_data = input.const_data_ptr<scalar_t>();
582 auto output_data = output.data_ptr<scalar_t>();
583
584 int64_t num_batches = input_sizes[0];
585 int64_t channels = input_sizes[1];
586 int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
587 int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
588 int64_t input_height = input_sizes[ndim - 2];
589 int64_t output_height = output_sizes[ndim - 2];
590 int64_t input_width = input_sizes[ndim - 1];
591 int64_t output_width = output_sizes[ndim - 1];
592
593 TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels);
594 int64_t output_slice_size = output_depth * output_height * output_width * channels;
595
596 using opmath_t = at::opmath_type<scalar_t>;
597 using Vec = vec::Vectorized<scalar_t>;
598 auto loop2d = [&](int64_t begin, int64_t end) {
599 const auto height_scale = area_pixel_compute_scale<opmath_t>(
600 input_height, output_height, align_corners, scales[0]);
601 const auto width_scale = area_pixel_compute_scale<opmath_t>(
602 input_width, output_width, align_corners, scales[1]);
603
604 auto input_indexr = [=](int64_t n, int64_t h, int64_t w) {
605 return input_data + n * input_height * input_width * channels +
606 h * input_width * channels + w * channels;
607 };
608
609 int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
610 opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
611 for (const auto n : c10::irange(begin, end)) {
612 for (const auto oh : c10::irange(output_height)) {
613 compute_source_index_and_lambda(
614 ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
615 for (const auto ow : c10::irange(output_width)) {
616 compute_source_index_and_lambda(
617 iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
618
619 scalar_t* out = output_data + n * output_slice_size +
620 oh * output_width * channels + ow * channels;
621 const scalar_t* i00 = input_indexr(n, ih0, iw0);
622 const scalar_t* i01 = input_indexr(n, ih0, iw1);
623 const scalar_t* i10 = input_indexr(n, ih1, iw0);
624 const scalar_t* i11 = input_indexr(n, ih1, iw1);
625 opmath_t w00 = h0lambda * w0lambda;
626 opmath_t w01 = h0lambda * w1lambda;
627 opmath_t w10 = h1lambda * w0lambda;
628 opmath_t w11 = h1lambda * w1lambda;
629
630 int64_t size = channels;
631 int64_t d = 0;
632 for (; d < size - (size % Vec::size()); d += Vec::size()) {
633 auto out_vec = interpolate(i00 + d, w00, i01 + d, w01, i10 + d, w10, i11 + d, w11);
634 out_vec.store(out + d);
635 }
636 for (; d < size; d++) {
637 out[d] = i00[d] * w00 + i01[d] * w01 + i10[d] * w10 + i11[d] * w11;
638 }
639 }
640 }
641 }
642 };
643
644 auto loop3d = [&](int64_t begin, int64_t end) {
645 const auto depth_scale = area_pixel_compute_scale<opmath_t>(
646 input_depth, output_depth, align_corners, scales[0]);
647 const auto height_scale = area_pixel_compute_scale<opmath_t>(
648 input_height, output_height, align_corners, scales[1]);
649 const auto width_scale = area_pixel_compute_scale<opmath_t>(
650 input_width, output_width, align_corners, scales[2]);
651
652 auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w) {
653 return input_data + n * input_depth * input_height * input_width * channels +
654 d * input_height * input_width * channels +
655 h * input_width * channels + w * channels;
656 };
657
658 int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
659 opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
660 for (const auto n : c10::irange(begin, end)) {
661 for (const auto od : c10::irange(output_depth)) {
662 compute_source_index_and_lambda(
663 id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
664 for (const auto oh : c10::irange(output_height)) {
665 compute_source_index_and_lambda(
666 ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
667 for (const auto ow : c10::irange(output_width)) {
668 compute_source_index_and_lambda(
669 iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
670
671 scalar_t* out = output_data + n * output_slice_size +
672 od * output_height * output_width * channels +
673 oh * output_width * channels + ow * channels;
674 const scalar_t* i000 = input_indexr(n, id0, ih0, iw0);
675 const scalar_t* i001 = input_indexr(n, id0, ih0, iw1);
676 const scalar_t* i010 = input_indexr(n, id0, ih1, iw0);
677 const scalar_t* i011 = input_indexr(n, id0, ih1, iw1);
678 const scalar_t* i100 = input_indexr(n, id1, ih0, iw0);
679 const scalar_t* i101 = input_indexr(n, id1, ih0, iw1);
680 const scalar_t* i110 = input_indexr(n, id1, ih1, iw0);
681 const scalar_t* i111 = input_indexr(n, id1, ih1, iw1);
682 opmath_t w000 = d0lambda * h0lambda * w0lambda;
683 opmath_t w001 = d0lambda * h0lambda * w1lambda;
684 opmath_t w010 = d0lambda * h1lambda * w0lambda;
685 opmath_t w011 = d0lambda * h1lambda * w1lambda;
686 opmath_t w100 = d1lambda * h0lambda * w0lambda;
687 opmath_t w101 = d1lambda * h0lambda * w1lambda;
688 opmath_t w110 = d1lambda * h1lambda * w0lambda;
689 opmath_t w111 = d1lambda * h1lambda * w1lambda;
690
691 int64_t size = channels;
692 int64_t d = 0;
693 for (; d < size - (size % Vec::size()); d += Vec::size()) {
694 auto out_vec = interpolate(
695 i000 + d, w000, i001 + d, w001, i010 + d, w010, i011 + d, w011,
696 i100 + d, w100, i101 + d, w101, i110 + d, w110, i111 + d, w111);
697 out_vec.store(out + d);
698 }
699 for (; d < size; d++) {
700 out[d] =
701 i000[d] * w000 + i001[d] * w001 + i010[d] * w010 + i011[d] * w011 +
702 i100[d] * w100 + i101[d] * w101 + i110[d] * w110 + i111[d] * w111;
703 }
704 }
705 }
706 }
707 }
708 };
709
710 if (ndim == 4) {
711     // upsample linear 2d
712 at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
713 } else {
714     // upsample linear 3d
715 TORCH_INTERNAL_ASSERT(ndim == 5);
716 at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
717 }
718
719 if (!output_.is_contiguous(channels_last_memory_format)) {
720 output_.copy_(output);
721 }
722 }
723
724 // Helper structs to use with upsample_generic_Nd_kernel_impl
725 struct HelperInterpBase {
726
727   static inline void init_indices_weights(
728 at::ScalarType output_type,
729 std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
730 int64_t reshape_dim, int interp_size
731 ) {
732
733 auto new_shape = std::vector<int64_t>(ndims, 1);
734 new_shape[reshape_dim] = output_size;
735
736 for (const auto j C10_UNUSED : c10::irange(interp_size)) {
737 output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
738 output.emplace_back(empty(new_shape, CPU(output_type)));
739 }
740 }
741
742   // This is a helper function for the _compute_index_ranges_weights method. It computes
743   // two int64 scalars, the source index min and the source size, and a list of weights (of size max_interp_size)
744   // for interpolation with antialiasing=true mode. It returns the maximal weight value.
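  // Illustrative worked example (not executed; numbers chosen for clarity): for a bilinear
  // downscale by 2 (scale = 2.0, interp_size = 2, triangle filter) we get support = 2.0 and,
  // for output index i = 0, center = 1.0, so
  //   xmin  = max(int64(1.0 - 2.0 + 0.5), 0) = 0
  //   xsize = min(int64(1.0 + 2.0 + 0.5), input_size) - xmin = 3   (assuming input_size >= 3)
  // and the unnormalized weights filter((j - 0.5) * 0.5) for j = 0, 1, 2 are {0.75, 0.75, 0.25},
  // which normalize to {3/7, 3/7, 1/7}; the returned wt_max is 3/7.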
745 template <typename scalar_t, typename aa_filter_fn_t>
746   static inline scalar_t _compute_indices_min_size_weights_aa(
747 const int64_t i, const int64_t input_size, const scalar_t scale, const scalar_t support,
748 scalar_t* wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn,
749 int64_t& xmin, int64_t& xsize
750 ) {
751
752 scalar_t center = scale * (i + 0.5);
753 scalar_t total_w = 0.0;
754 scalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
755 xmin = std::max(
756 static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
757 xsize = std::min(
758 static_cast<int64_t>(center + support + 0.5), input_size) - xmin;
759     // There are rare cases when, due to precision, xsize can be larger than max_interp_size by one.
760     // We have to clip the value.
761 xsize = std::clamp(xsize, static_cast<int64_t>(0), max_interp_size);
762
763 int64_t j = 0;
764 for (; j < xsize; j++) {
765 scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
766 wt_ptr[j] = w;
767 total_w += w;
768 }
769
770 scalar_t wt_max = 0.0;
771 if (total_w != 0.0) {
772 for (j = 0; j < xsize; j++) {
773 wt_ptr[j] /= total_w;
774 wt_max = std::max(wt_max, wt_ptr[j]);
775 }
776 }
777
778 for (; j < max_interp_size; j++) {
779 wt_ptr[j] = static_cast<scalar_t>(0.0);
780 }
781 return wt_max;
782 }
783
784   // This is a helper function for the _compute_index_ranges_weights method. It computes
785   // two int64 scalars, the source index min and the source size, and a list of weights (of size max_interp_size)
786   // for interpolation with antialiasing=false mode. It returns the maximal weight value.
787   // This function is templated with scalar_t for the type of the scale and weights, but is only used for
788   // bilinear/bicubic modes on uint8 input with antialiasing=false (in this case scalar_t is double).
789   // For float input types we are using the upsample_generic_Nd_kernel_impl and compute_indices_weights methods.
790 template <typename scalar_t, typename aa_filter_fn_t>
791   static inline scalar_t _compute_indices_min_size_weights(
792 const int64_t i, const int64_t input_size, const scalar_t scale,
793 scalar_t* wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn,
794 bool align_corners, int64_t& index_min, int64_t& index_size
795 ) {
796     // Note: we do not use opmath_t in this method as f16 and other smaller float types are not routed here.
797     // Typical usage of this method is with scalar_t = double when computing indices and weights for uint8 input.
798     // The code below partly adapts the indices and lambda computation from the compute_indices_weights method and
799     // index_min/index_size from _compute_indices_min_size_weights_aa.
800
801 bool cubic = max_interp_size > 2;
802 const auto real_input_index = area_pixel_compute_source_index<scalar_t>(
803 scale, i, align_corners, /*cubic=*/cubic);
804
805 scalar_t lambda;
806 int64_t input_index = 0;
807 guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
808
809 const auto support = static_cast<int64_t>(max_interp_size * 0.5);
810 const auto unbound_index_min = input_index - support + 1;
811 const auto unbound_index_max = input_index + support + 1;
812 index_min = std::max(unbound_index_min, static_cast<int64_t>(0));
813 index_size = std::min(unbound_index_max, input_size) - index_min;
814     // There are rare cases when, due to precision, index_size can be larger than max_interp_size by one.
815     // We have to clip the value.
816 index_size = std::clamp(index_size, static_cast<int64_t>(0), max_interp_size);
817
818 // Below the weights are computed using filter_fn and accumulating values for indices being out of bounds
819 // For example, for bicubic mode for output index i = 0, we have input_index = -1,
820 // then we have unbound_index_min = -2 and unbound_index_max = 1 => unbounded input indices are [-2, -1, 0, 1] and
821 // valid input indices will be [0, 1]
822     // For unbounded input indices we compute four non-zero weight values [w0, w1, w2, w3] and, as only two weights
823     // can be used with valid input indices, we accumulate values in the following way: [w0 + w1 + w2, w3, 0.0, 0.0].
824     // This is equivalent to the float path which would compute indices as [0, 0, 0, 1] and weights as [w0, w1, w2, w3].
825     // A similar accumulation should be done for unbounded indices larger than the input size.
826 auto w_index = 0;
827 scalar_t wt_max = 0.0;
828 for (const auto j : c10::irange(max_interp_size)) {
829 // initialize weights value as we will accumulate below
830 wt_ptr[j] = 0.0;
831
832 scalar_t w = filter_fn(static_cast<scalar_t>(j + 1 - support) - lambda);
833 if (unbound_index_min + j <= 0) {
834 w_index = 0;
835 } else if (unbound_index_min + j >= input_size - 1) {
836 w_index = index_size - 1;
837 }
838 wt_ptr[w_index] += w;
839 wt_max = std::max(wt_max, wt_ptr[w_index]);
840 w_index++;
841 }
842
843 return wt_max;
844 }
845
846 // Note [ Support for antialias=False as a subcase of antialias=True ]
847 // This function was originally written with the hard assumption that
848 // antialias=True and it was later extended to support antialias=False.
849 // The only difference between aa and no-aa is in how the
850 // weights and indices are computed (and their number). In aa their number is
851 // variable but with no-aa, they're fixed to interp_size. The same "filters"
852   // can be used otherwise. HOWEVER, support for antialias=False here may not be
853   // optimal: the code assumes an arbitrary number of weights and
854   // indices, but this can be optimized further when aa=False since we know
855   // their actual dimensions.
856 template <typename scalar_t, typename aa_filter_fn_t, int weight_index_stride=sizeof(scalar_t)>
857   static inline std::tuple<std::vector<Tensor>, int, scalar_t> _compute_index_ranges_weights(
858 int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
859 int64_t reshape_dim, scalar_t scale,
860 int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, bool align_corners
861 ) {
862
863 std::vector<Tensor> output;
864
865 scalar_t support;
866 int max_interp_size = 0;
867 if (antialias) {
868 support = (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
869 max_interp_size = (int) std::ceil(support) * 2 + 1;
870 } else {
871 support = interp_size * 0.5;
872 max_interp_size = interp_size;
873 }
874
875 auto new_shape = std::vector<int64_t>(ndims, 1);
876 new_shape[reshape_dim] = output_size;
877
878 // Bounds approach as in PIL: xmin/xmax
879 output.emplace_back(
880 empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
881 output.emplace_back(
882 empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
883 output.emplace_back(
884 empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
885
886 {
887 // Weights
888 new_shape[reshape_dim] = output_size * max_interp_size;
889 auto wts = empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>()));
890 auto strides = wts.strides().vec();
891 strides[reshape_dim] = 0;
892 new_shape[reshape_dim] = output_size;
893 wts = wts.as_strided(new_shape, strides);
894 output.emplace_back(wts);
895 // Weights indices
896 output.emplace_back(
897 empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
898 }
899
900 int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
901 int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
902 int64_t* idx_ptr_stride = output[2].data_ptr<int64_t>();
903 scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
904 int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
905
906 scalar_t wt_max = 0.0;
907 for (const auto i : c10::irange(output_size)) {
908 int64_t xmin = 0, xsize = 0;
909 scalar_t wt_max_i;
910 if (antialias) {
911 wt_max_i = HelperInterpBase::_compute_indices_min_size_weights_aa(
912 i,
913 input_size,
914 scale,
915 support,
916 wt_ptr + i * max_interp_size,
917 max_interp_size,
918 aa_filter_fn,
919 xmin,
920 xsize);
921 } else {
922 wt_max_i = HelperInterpBase::_compute_indices_min_size_weights(
923 i,
924 input_size,
925 scale,
926 wt_ptr + i * max_interp_size,
927 max_interp_size,
928 aa_filter_fn,
929 align_corners,
930 xmin,
931 xsize);
932 }
933 wt_max = std::max(wt_max, wt_max_i);
934
935 idx_ptr_xmin[i] = xmin * stride;
936 idx_ptr_size[i] = xsize;
937 idx_ptr_stride[i] = stride;
938 wt_idx_ptr[i] = i * max_interp_size * weight_index_stride;
939 }
940 return {output, max_interp_size, wt_max};
941 }
942
943 /*
944 NOTE [ Weights computation for uint8_t and multiplication trick ]
945 When the input/output dtype is uint8_t, we still compute the interpolation
946 weights as double, but then convert them to int16 via some conversion logic
947   detailed below. This allows us to compute all interpolation operations (sums of
948   multiplications) as ints instead of floats. The result is converted back into
949   uint8 in basic_loop_aa_horizontal<uint8_t> (and vertical).
950
951 In essence the idea is to avoid a multiplication between a float (the
952 weight) and an int (the pixel value) and instead run a multiplication between
953 2 ints:
954
955 ```py
956 COEF_PREC = 16
957
958 def mul(a:float, b:int) -> Tuple[float, int]:
959 # return a * b, round(a * b)
960 actual = a * b
961
962 assert a > 0 # I'm lazy
963 int_a = floor(0.5 + a * (1 << COEF_PREC))
964 with_trick = ((int_a * b) + (1 << (COEF_PREC - 1))) >> COEF_PREC
965
966 return actual, with_trick # round(actual) == with_trick!!
967 ```
968
969 Here's how it works:
970   N == COEF_PREC
971 1 << N == 2**N
972 floor(0.5 + x) == round(x)
973
974 So the operation is something like
975
976 int_a = round(a * 2**N) -- let's just say it's `a * 2**N` for simplicity
977
978   res = ((int_a * b) + (1 << (N - 1))) >> N
979       = (a * 2**N * b + 2**(N - 1)) // 2**N      (the >> N is a floor division)
980       = floor(a * b + 0.5)
981       = round(a * b)
982       = what we wanted
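
  A worked example with illustrative numbers: a = 0.3, b = 200, COEF_PREC = 16:
    int_a      = floor(0.5 + 0.3 * 65536) = 19661
    with_trick = (19661 * 200 + 32768) >> 16 = 3964968 >> 16 = 60
  which matches round(0.3 * 200) = round(60.0) = 60.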
983 */
984 template <typename aa_filter_fn_t>
985   static inline std::tuple<std::vector<Tensor>, int, unsigned int> _compute_index_ranges_int16_weights(
986 int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
987 int64_t reshape_dim, bool align_corners, const std::optional<double>& opt_scale,
988 int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, bool align_i32=false
989 ) {
990
991 double scale = area_pixel_compute_scale<double>(
992 input_size, output_size, align_corners, opt_scale);
993
994 auto [indices_weights, aligned_interp_size, wt_max] = HelperInterpBase::_compute_index_ranges_weights<double, aa_filter_fn_t, sizeof(int16_t)>(
995 input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners);
996 interp_size = aligned_interp_size;
997
998 // Rescale float weights to int16 and compute weights precision
999 auto weights_f64 = indices_weights[3];
1000 double * data_f64 = weights_f64. template data_ptr<double>();
1001
1002 unsigned int weights_precision = 0;
1003 for (weights_precision = 0; weights_precision < 22; ++weights_precision) {
1004 int next_value = (int) (0.5 + wt_max * (1 << (weights_precision + 1)));
1005 if (next_value >= (1 << 15))
1006 break;
1007 }
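  // Illustrative example (not executed): with wt_max = 0.5 the loop above stops at
  // weights_precision = 15, i.e. the largest precision for which wt_max * (1 << weights_precision)
  // is still below (1 << 15), so the rescaled weights below fit into int16.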
1008
1009 // Rescale float values to int16
1010 int16_t * data_i16 = (int16_t *) data_f64;
1011
1012 if (align_i32) {
1013 // We should respect int32 alignment as we will load int16 data as int32
1014 // See ImagingResampleHorizontalConvolution8u4x, mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x]);
1015     // compute aligned_interp_size: round interp_size up to the nearest multiple of sizeof(int32_t)
1016 while (aligned_interp_size % sizeof(int32_t) != 0) {
1017 aligned_interp_size += 1;
1018 }
1019     // assert that we won't go out of bounds
1020 TORCH_INTERNAL_ASSERT(aligned_interp_size * sizeof(int16_t) < interp_size * sizeof(double));
1021 }
1022
1023 for (const auto j : c10::irange(output_size)) {
1024 for (const auto k : c10::irange(interp_size)) {
1025 double v = data_f64[j * interp_size + k] * (1 << weights_precision);
1026 data_i16[j * aligned_interp_size + k] = (v < 0) ? (int) (-0.5 + v) : (int) (0.5 + v);
1027 }
1028 }
1029
1030 return {indices_weights, aligned_interp_size, weights_precision};
1031 }
1032 };
1033
1034 struct HelperInterpNearest : public HelperInterpBase {
1035   // This structure implements an outdated and buggy method to compute indices
1036   // for nearest neighbour interpolation.
1037   // We keep this structure for BC and consider it deprecated.
1038 // See HelperInterpNearestExact as replacement
1039
1040 static const int interp_size = 1;
1041
1042   static inline void init_indices_weights(
1043 at::ScalarType output_type,
1044 std::vector<Tensor> & output, int64_t output_size, int64_t ndims,
1045 int64_t reshape_dim, int interp_size
1046 ) {
1047 auto new_shape = std::vector<int64_t>(ndims, 1);
1048 new_shape[reshape_dim] = output_size;
1049
1050 for (const auto j C10_UNUSED : c10::irange(interp_size)) {
1051 output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
1052 // Defines weights for consistency, but not used
1053 output.emplace_back(at::ones(new_shape, CPU(output_type)));
1054 }
1055 }
1056
1057 // Compute nearest mode indices and weights for each interpolated dimension
1058 // indices_weights = {
1059 // {indices_0, 1.0, }, // dim -n
1060 // {indices_0, 1.0, }, // dim -(n-1)
1061 // ...
1062 // {indices_0, 1.0, }, // dim -1
1063 // }
1064 // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1065 // fit input/output tensors.
1066 // Indices are already containing the strides to optimize the computations
1067   static inline std::vector<Tensor> compute_indices_weights(
1068 at::ScalarType scalar_type,
1069 int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
1070 int64_t reshape_dim, bool align_corners, const std::optional<double>& opt_scale
1071 ) {
1072
1073 TORCH_INTERNAL_ASSERT(!align_corners);
1074 std::vector<Tensor> output;
1075 HelperInterpNearest::init_indices_weights(
1076 scalar_type, output, output_size, ndims, reshape_dim, HelperInterpNearest::interp_size);
1077
1078 AT_DISPATCH_FLOATING_TYPES_AND2(
1079 kBFloat16, kHalf, scalar_type, "compute_indices_weights_nearest", [&] {
1080 using opmath_t = at::opmath_type<scalar_t>;
1081 opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1082
1083 auto input_index_ptr = output[0].data_ptr<int64_t>();
1084 int64_t input_index;
1085
1086           // Indices are computed as follows:
1087 // scale = 1.0 * isize / osize
1088 // index_f32 = (output_index) * scale
1089 // input_index = floor(index_f32)
1090 // Same as OpenCV INTER_NEAREST
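          // Illustrative example (not executed, assuming no explicit scale factor is passed):
          // isize = 3, osize = 5 gives scale = 0.6 and input indices
          // floor({0, 0.6, 1.2, 1.8, 2.4}) = {0, 0, 1, 1, 2}.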
1091 for (const auto i : c10::irange(output_size)) {
1092 const auto real_input_index =
1093 area_pixel_compute_source_index<opmath_t>(
1094 scale, i, /*align_corners=*/true, /*cubic=*/false);
1095 input_index = static_cast<int64_t>(floorf(real_input_index));
1096 input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
1097 }
1098 }
1099 );
1100 return output;
1101 }
1102
1103 };
1104
1105 struct HelperInterpNearestExact : public HelperInterpNearest {
1106
1107 // Compute nearest mode indices and weights for each interpolated dimension
1108 // indices_weights = {
1109 // {indices_0, 1.0, }, // dim -n
1110 // {indices_0, 1.0, }, // dim -(n-1)
1111 // ...
1112 // {indices_0, 1.0, }, // dim -1
1113 // }
1114 // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1115 // fit input/output tensors.
1116 // Indices are already containing the strides to optimize the computations
1117   static inline std::vector<Tensor> compute_indices_weights(
1118 at::ScalarType scalar_type,
1119 int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
1120 int64_t reshape_dim, bool align_corners, const std::optional<double>& opt_scale
1121 ) {
1122
1123 TORCH_INTERNAL_ASSERT(!align_corners);
1124 std::vector<Tensor> output;
1125 HelperInterpNearest::init_indices_weights(
1126 scalar_type, output, output_size, ndims, reshape_dim, HelperInterpNearest::interp_size);
1127
1128 AT_DISPATCH_FLOATING_TYPES_AND2(
1129 kBFloat16, kHalf, scalar_type, "compute_indices_weights_nearest", [&] {
1130 using opmath_t = at::opmath_type<scalar_t>;
1131 opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1132
1133 auto input_index_ptr = output[0].data_ptr<int64_t>();
1134 int64_t input_index;
1135
1136           // Indices should be computed as follows:
1137 // scale = 1.0 * isize / osize
1138 // index_f32 = (output_index + 0.5) * scale - 0.5
1139 // input_index = round(index_f32)
1140 // Same as Pillow and Scikit-Image/Scipy ndi.zoom
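          // Illustrative example (not executed, assuming no explicit scale factor is passed):
          // isize = 3, osize = 5 gives scale = 0.6, index_f32 = {-0.2, 0.4, 1.0, 1.6, 2.2},
          // and rounded input indices {0, 0, 1, 2, 2} (compare with the legacy
          // HelperInterpNearest above, which produces {0, 0, 1, 1, 2}).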
1141 for (const auto i : c10::irange(output_size)) {
1142 const auto real_input_index =
1143 area_pixel_compute_source_index<opmath_t>(
1144 scale, i, /*align_corners=*/align_corners, /*cubic=*/false);
1145 input_index = static_cast<int64_t>(floorf(real_input_index + 0.5));
1146 input_index_ptr[i] = static_cast<int64_t>(std::min(input_index, input_size - 1)) * stride;
1147 }
1148 }
1149 );
1150 return output;
1151 }
1152 };
1153
1154 struct HelperInterpLinear : public HelperInterpBase {
1155
1156 static const int interp_size = 2;
1157
1158 // Compute indices and weights for each interpolated dimension
1159 // indices_weights = {
1160 // {indices_0, weights_0, indices_1, weights_1}, // dim -n
1161 // {indices_0, weights_0, indices_1, weights_1}, // dim -(n-1)
1162 // ...
1163 // {indices_0, weights_0, indices_1, weights_1}, // dim -1
1164 // }
1165 // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1166 // fit input/output tensors.
1167 // Indices are already containing the strides to optimize the computations
1168   static inline std::vector<Tensor> compute_indices_weights(
1169 at::ScalarType scalar_type,
1170 int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
1171 bool align_corners, const std::optional<double>& opt_scale
1172 ) {
1173 std::vector<Tensor> output;
1174 HelperInterpLinear::init_indices_weights(
1175 scalar_type, output, output_size, ndims, reshape_dim, HelperInterpLinear::interp_size);
1176 AT_DISPATCH_FLOATING_TYPES_AND2(
1177 kBFloat16, kHalf, scalar_type, "compute_indices_weights_linear", [&] {
1178 using opmath_t = at::opmath_type<scalar_t>;
1179 opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1180
1181 auto input_index0_ptr = output[0].data_ptr<int64_t>();
1182 auto lambda0_ptr = output[1].data_ptr<scalar_t>();
1183 auto input_index1_ptr = output[2].data_ptr<int64_t>();
1184 auto lambda1_ptr = output[3].data_ptr<scalar_t>();
1185
1186 for (const auto i : c10::irange(output_size)) {
1187
1188 compute_source_index_and_lambda<scalar_t, opmath_t>(
1189 input_index0_ptr[i], input_index1_ptr[i],
1190 lambda0_ptr[i], lambda1_ptr[i],
1191 scale, i, input_size, output_size, align_corners
1192 );
1193           // bake the stride into the indices:
1194           // index values correspond to input indices (0, 1, 2, 3, ...);
1195           // when multiplied by the input stride, the maximum possible value is
1196           // input_size[dim-1] * input_size[dim-2] * ... for the given dimension.
1197 input_index0_ptr[i] *= stride;
1198 input_index1_ptr[i] *= stride;
1199 }
1200 }
1201 );
1202 return output;
1203 }
1204
1205 // taken from
1206 // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
1207 // src/libImaging/Resample.c#L20-L29
1208 template<typename scalar_t>
1209   static inline scalar_t aa_filter(scalar_t x) {
1210 x = std::abs(x);
1211 if (x < 1.0) {
1212 return 1.0 - x;
1213 }
1214 return 0.0;
1215 }
1216
1217   static inline std::vector<Tensor> compute_index_ranges_weights(
1218 at::ScalarType scalar_type,
1219 int64_t input_size,
1220 int64_t output_size,
1221 int64_t stride,
1222 int64_t ndims,
1223 int64_t reshape_dim,
1224 bool align_corners,
1225 const std::optional<double>& opt_scale,
1226 bool antialias
1227 ) {
1228
1229 std::vector<Tensor> indices_weights;
1230 AT_DISPATCH_FLOATING_TYPES(
1231 scalar_type, "compute_index_ranges_weights", [&] {
1232
1233 scalar_t scale = area_pixel_compute_scale<scalar_t>(
1234 input_size, output_size, align_corners, opt_scale);
1235
1236 auto interp_size = HelperInterpLinear::interp_size;
1237
1238 indices_weights = std::get<0>(HelperInterpLinear::_compute_index_ranges_weights<scalar_t>(
1239 input_size,
1240 output_size,
1241 stride,
1242 ndims,
1243 reshape_dim,
1244 scale,
1245 interp_size,
1246 &HelperInterpLinear::aa_filter<scalar_t>,
1247 /*antialias=*/antialias,
1248 /*align_corners=*/align_corners));
1249 }
1250 );
1251 return indices_weights;
1252 }
1253
1254   static inline std::tuple<std::vector<Tensor>, int, unsigned int> compute_index_ranges_int16_weights(
1255 int64_t input_size,
1256 int64_t output_size,
1257 int64_t stride,
1258 int64_t ndims,
1259 int64_t reshape_dim,
1260 bool align_corners,
1261 const std::optional<double>& opt_scale,
1262 bool antialias,
1263 bool align_i32=false
1264 ) {
1265
1266 auto interp_size = HelperInterpLinear::interp_size;
1267 auto fn = HelperInterpLinear::aa_filter<double>;
1268 return HelperInterpLinear::_compute_index_ranges_int16_weights(
1269 input_size, output_size, stride, ndims, reshape_dim,
1270 align_corners, opt_scale, interp_size, fn, antialias, align_i32);
1271 }
1272 };
1273
1274 struct HelperInterpCubic : public HelperInterpBase {
1275
1276 static const int interp_size = 4;
1277
1278 // Compute indices and weights for each interpolated dimension
1279 // indices_weights = {
1280 // {indices_0, weights_0, indices_1, weights_1, ..., indices_3, weights_3}, // dim -n
1281 // {indices_0, weights_0, indices_1, weights_1, ..., indices_3, weights_3}, // dim -(n-1)
1282 // ...
1283 // {indices_0, weights_0, indices_1, weights_1, ..., indices_3, weights_3}, // dim -1
1284 // }
1285 // Indices and weights are reshaped as (1, 1, ..., N, ..., 1, 1) to
1286 // fit input/output tensors.
1287 // Indices are already containing the strides to optimize the computations
1288   static inline std::vector<Tensor> compute_indices_weights(
1289 at::ScalarType scalar_type,
1290 int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims, int64_t reshape_dim,
1291 bool align_corners, const std::optional<double>& opt_scale
1292 ) {
1293 std::vector<Tensor> output;
1294 HelperInterpCubic::init_indices_weights(
1295 scalar_type, output, output_size, ndims, reshape_dim, HelperInterpCubic::interp_size);
1296
1297 AT_DISPATCH_FLOATING_TYPES_AND2(
1298 kBFloat16, kHalf, scalar_type, "compute_indices_weights_cubic", [&] {
1299 using opmath_t = at::opmath_type<scalar_t>;
1300 opmath_t scale = area_pixel_compute_scale<opmath_t>(input_size, output_size, align_corners, opt_scale);
1301
1302 int64_t input_index;
1303 int64_t zero = static_cast<int64_t>(0);
1304 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
1305 opmath_t coeffs[4];
1306
1307 int64_t * idx_ptr;
1308 scalar_t * wt_ptr;
1309 for (const auto i : c10::irange(output_size)) {
1310 const auto real_input_index =
1311 area_pixel_compute_source_index<opmath_t>(
1312 scale, i, align_corners, /*cubic=*/true);
1313 opmath_t lambda;
1314 guard_index_and_lambda(real_input_index, input_size, input_index, lambda);
1315 get_cubic_upsample_coefficients<opmath_t>(coeffs, lambda);
1316
1317 for (const auto j : c10::irange(interp_size)) {
1318 idx_ptr = output[2 * j + 0].data_ptr<int64_t>();
1319 idx_ptr[i] = static_cast<int64_t>(std::max(std::min(input_index + j - 1, input_size - 1), zero)) * stride;
1320 wt_ptr = output[2 * j + 1].data_ptr<scalar_t>();
1321 wt_ptr[i] = coeffs[j];
1322 }
1323 }
1324 }
1325 );
1326 return output;
1327 }
1328
1329 // taken from
1330 // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
1331 // src/libImaging/Resample.c#L46-L62
1332 template<typename scalar_t, bool use_keys_cubic=true>
1333   static inline scalar_t aa_filter(scalar_t x) {
1334 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
1335 // a = -0.5 was proposed by R. Keys in "Cubic convolution interpolation for digital image processing"
1336     // We use -0.5 for bicubic with antialiasing=true (for compatibility with PIL)
1337     // and -0.75 for bicubic with antialiasing=false (for compatibility with OpenCV).
1338 constexpr scalar_t a = use_keys_cubic ? -0.5 : -0.75;
1339
1340 x = std::abs(x);
1341 if (x < 1.0) {
1342 return cubic_convolution1(x, a);
1343 }
1344 if (x < 2.0) {
1345 return cubic_convolution2(x, a);
1346 }
1347 return 0.0;
1348 }
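  // For reference, the piecewise kernel evaluated above (via cubic_convolution1/2
  // from UpSample.h) is the standard cubic convolution kernel:
  //   k(x) = (a + 2)|x|^3 - (a + 3)|x|^2 + 1    for |x| < 1
  //   k(x) = a|x|^3 - 5a|x|^2 + 8a|x| - 4a      for 1 <= |x| < 2
  //   k(x) = 0                                  otherwise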
1349
1350   static inline std::vector<Tensor> compute_index_ranges_weights(
1351 at::ScalarType scalar_type,
1352 int64_t input_size,
1353 int64_t output_size,
1354 int64_t stride,
1355 int64_t ndims,
1356 int64_t reshape_dim,
1357 bool align_corners,
1358 const std::optional<double>& opt_scale,
1359 bool antialias
1360 ) {
1361
1362 std::vector<Tensor> indices_weights;
1363 AT_DISPATCH_FLOATING_TYPES(
1364 scalar_type, "compute_index_ranges_weights", [&] {
1365
1366 scalar_t scale = area_pixel_compute_scale<scalar_t>(
1367 input_size, output_size, align_corners, opt_scale);
1368
1369 auto interp_size = HelperInterpCubic::interp_size;
1370
1371 indices_weights = std::get<0>(HelperInterpCubic::_compute_index_ranges_weights<scalar_t>(
1372 input_size,
1373 output_size,
1374 stride,
1375 ndims,
1376 reshape_dim,
1377 scale,
1378 interp_size,
1379 &HelperInterpCubic::aa_filter<scalar_t>,
1380 /*antialias=*/antialias,
1381 /*align_corners=*/align_corners));
1382 }
1383 );
1384 return indices_weights;
1385 }
1386
1387   static inline std::tuple<std::vector<Tensor>, int, unsigned int> compute_index_ranges_int16_weights(
1388 int64_t input_size,
1389 int64_t output_size,
1390 int64_t stride,
1391 int64_t ndims,
1392 int64_t reshape_dim,
1393 bool align_corners,
1394 const std::optional<double>& opt_scale,
1395 bool antialias,
1396 bool align_i32=false
1397 ) {
1398
1399 auto interp_size = HelperInterpCubic::interp_size;
1400 // We have to use the -0.75 constant when aa is False so that this uint8
1401 // path is as close as possible to float results.
1402 auto fn = antialias ? HelperInterpCubic::aa_filter<double, true> : HelperInterpCubic::aa_filter<double, false>;
1403 return HelperInterpCubic::_compute_index_ranges_int16_weights(
1404 input_size, output_size, stride, ndims, reshape_dim,
1405 align_corners, opt_scale, interp_size, fn, antialias, align_i32);
1406 }
1407
1408 };
1409
1410 // Generic upsampling interpolation kernel for N-d case.
1411 // Input is assumed to be like NCHW, NCL or NCKHW: the interpolated spatial
1412 // dimensions are the trailing ones, after the batch size N and the number of channels C.
1413 //
1414 // Internally, it uses TensorIterator to optimize the computations.
1415 // - out_ndims is the number of interpolated dims: 1, 2, 3
1416 // - scale_type is template type for scales, typically std::optional<double>
1417 // - template<typename> class F is one of the above structs to compute indices and weights
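// For example, the bilinear 2D float path further below instantiates it as:
//   upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
//       output, input, align_corners, {scales_h, scales_w});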
1418 template <int out_ndims, typename scale_type, class F>
1419 void upsample_generic_Nd_kernel_impl(
1420 const Tensor& output,
1421 const Tensor& input,
1422 bool align_corners,
1423 const scale_type& scales) {
1424
1425
1426 // input can be NCHW, NCL or NCKHW
1427 auto shape = input.sizes().vec();
1428 auto strides = input.strides().vec();
1429 auto oshape = output.sizes();
1430
1431 TORCH_INTERNAL_ASSERT(
1432 shape.size() == oshape.size() && shape.size() == 2 + out_ndims
1433 );
1434 TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
1435
1436 for (const auto i : c10::irange(out_ndims)) {
1437 shape[i + 2] = oshape[i + 2];
1438 strides[i + 2] = 0;
1439 }
1440 auto restrided_input = input.as_strided(shape, strides);
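  // The interpolated dims of the input are given the *output* sizes and a stride of
  // 0, so that TensorIterator below iterates over the output spatial grid while the
  // actual input elements are gathered through the precomputed, pre-strided indices.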
1441
1442
1443 constexpr int interp_size = F::interp_size;
1444 auto input_scalar_type = input.scalar_type();
1445 if ((interp_size == 1 && input_scalar_type == at::ScalarType::Byte)) {
1446     // nearest also supports uint8 tensors, but we have to use float
1447     // with compute_indices_weights
1448 input_scalar_type = at::ScalarType::Float;
1449 }
1450
1451 std::vector<std::vector<Tensor>> indices_weights;
1452 indices_weights.reserve(out_ndims);
1453 for (const auto i : c10::irange(out_ndims)) {
1454 indices_weights.emplace_back(
1455 F::compute_indices_weights(
1456 input_scalar_type, input.size(i + 2), oshape[i + 2],
1457 input.stride(i + 2) * input.element_size(),
1458 input.dim(), i + 2, align_corners, scales[i]
1459 )
1460 );
1461 }
1462
1463 TensorIteratorConfig config;
1464 config.check_all_same_dtype(false)
1465 .declare_static_dtype_and_device(input.scalar_type(), input.device())
1466 .add_output(output)
1467 .add_const_input(restrided_input);
1468
1469 for (auto & idx_weight: indices_weights) {
1470 for (auto& tensor : idx_weight) {
1471 config.add_const_input(tensor);
1472 }
1473 }
1474
1475 auto iter = config.build();
1476
1477 if (interp_size > 1) {
1478     // Nearest also supports uint8 tensors, so we need to handle them separately.
1479 AT_DISPATCH_FLOATING_TYPES_AND2(
1480 kBFloat16, kHalf, iter.dtype(), "upsample_generic_Nd", [&] {
1481         // MSVC cannot capture the constexpr int interp_size here
1482 constexpr int mode = F::interp_size;
1483 cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
1484 });
1485 } else {
1486 AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf,
1487 iter.dtype(), "upsample_generic_Nd", [&] {
1488 constexpr int mode = F::interp_size;
1489 cpu_upsample_generic<scalar_t, out_ndims, mode>(iter);
1490 });
1491 }
1492 }
1493
1494 template <typename scalar_t, bool is_horizontal>
1495 void cpu_upsample_generic_aa(at::TensorIterator& iter, unsigned int weights_precision) {
1496
1497 auto loop = [&](char** data, const int64_t* strides, int64_t n) {
1498 if constexpr (is_horizontal) {
1499
1500 // Strides are : X 0 | 8 8 8 0 8 (Channels first)
1501 // Strides are : X X | 0 0 0 0 0 (Channels last)
1502 basic_loop_aa_horizontal<scalar_t>(data, strides, n, weights_precision);
1503 } else {
1504 // Strides are : X Y | 0 0 0 0 0 (Channels first)
1505 // Strides are : X X | 0 0 0 0 0 (Channels last)
1506 // upsampling data between contiguous dimensions (aka vertical resampling)
1507 basic_loop_aa_vertical<scalar_t>(data, strides, n, weights_precision);
1508 }
1509 };
1510
1511 iter.for_each(loop);
1512 }
1513
1514 template <int out_ndims, typename scale_type, class F, bool is_horizontal>
1515 void _separable_upsample_generic_Nd_kernel_impl_single_dim(
1516 const Tensor& output,
1517 const Tensor& input,
1518 int interp_dim,
1519 bool align_corners,
1520 const scale_type& scales,
1521 bool antialias) {
1522
1523 // input can be NCHW, NCL or NCKHW
1524 auto shape = input.sizes().vec();
1525 auto strides = input.strides().vec();
1526 auto oshape = output.sizes();
1527
1528 TORCH_INTERNAL_ASSERT(
1529 shape.size() == oshape.size() && shape.size() == 2 + out_ndims);
1530 TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
1531
1532 for (const auto i : c10::irange(out_ndims)) {
1533 shape[i + 2] = oshape[i + 2];
1534 }
1535 strides[interp_dim] = 0;
1536 auto restrided_input = input.as_strided(shape, strides);
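  // Only the stride of the dim being interpolated is zeroed: each invocation of this
  // function resamples a single dim, and the other dims are traversed as-is.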
1537
1538 auto input_scalar_type = input.scalar_type();
1539
1540 std::vector<Tensor> indices_weights;
1541 unsigned int weights_precision = 0;
1542
1543 if (input_scalar_type == at::kByte) {
1544 // This is a special branch to provide uint8 dtype support for bilinear and bicubic modes only
1545 TORCH_INTERNAL_ASSERT(F::interp_size == 2 || F::interp_size == 4);
1546 int unused = 0;
1547 std::tie(indices_weights, unused, weights_precision) =
1548 F::compute_index_ranges_int16_weights(
1549 input.size(interp_dim), oshape[interp_dim],
1550 input.stride(interp_dim) * input.element_size(),
1551 input.dim(), interp_dim, align_corners, scales[interp_dim - 2],
1552 antialias);
1553 TORCH_INTERNAL_ASSERT(weights_precision > 0);
1554 } else {
1555 indices_weights =
1556 F::compute_index_ranges_weights(
1557 input_scalar_type, input.size(interp_dim), oshape[interp_dim],
1558 input.stride(interp_dim) * input.element_size(),
1559 input.dim(), interp_dim, align_corners, scales[interp_dim - 2],
1560 antialias);
1561 }
1562
1563 TensorIteratorConfig config;
1564 config.check_all_same_dtype(false)
1565 .declare_static_dtype_and_device(input.scalar_type(), input.device())
1566 .add_output(output)
1567 .add_const_input(restrided_input);
1568
1569 for (auto& tensor : indices_weights) {
1570 config.add_const_input(tensor);
1571 }
1572
1573 auto iter = config.build();
1574
1575 AT_DISPATCH_FLOATING_TYPES_AND(
1576 at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd_aa", [&] {
1577 cpu_upsample_generic_aa<scalar_t, is_horizontal>(iter, weights_precision);
1578 });
1579 }
1580
1581 // Generic separable upsampling interpolation kernel for N-d case with anti-aliasing.
1582 // It also supports antialias=False iff
1583 // (dtype == uint8 and mode in ("bilinear", "bicubic")): this is used as a
1584 // fallback in these settings when AVX isn't supported.
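// "Separable" means the N-d resize is decomposed into successive single-dim passes
// (one _separable_upsample_generic_Nd_kernel_impl_single_dim() call per resized dim),
// which is valid because the interpolation weights factorize across dimensions.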
1585 template <int out_ndims, typename scale_type, class F>
1586 void separable_upsample_generic_Nd_kernel_impl(
1587 const Tensor& output,
1588 const Tensor& input,
1589 bool align_corners,
1590 const scale_type& scales,
1591 bool antialias) {
1592
1593 auto output_shape = output.sizes();
1594 auto input_shape = input.sizes();
1595 auto temp_oshape = input_shape.vec();
1596
1597 if (output_shape == input_shape) {
1598 output.copy_(input);
1599 return;
1600 }
1601
1602 at::Tensor temp_output, temp_input = input;
1603
1604 int interp_dim = 0;
1605   // Precompute the number of single-dim resize invocations so that the last
1606   // pass can write directly into `output` instead of copying a temporary buffer into it.
1607 int num_single_dim_ops = 0;
1608 for (const auto i : c10::irange(out_ndims)) {
1609 interp_dim = 2 + out_ndims - 1 - i;
1610 if (output_shape[interp_dim] != input_shape[interp_dim]) {
1611 num_single_dim_ops += 1;
1612 }
1613 }
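  // The passes below ping-pong through temporary buffers: the horizontal (last dim)
  // pass runs first, then the remaining dims; once num_single_dim_ops reaches zero
  // the pass writes directly into `output`, skipping the final copy.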
1614
1615 // upsampling data within the contiguous dimension (aka horizontal resampling)
1616 interp_dim = 2 + out_ndims - 1;
1617 if (output_shape[interp_dim] != input_shape[interp_dim]) {
1618
1619 num_single_dim_ops -= 1;
1620 if (num_single_dim_ops > 0) {
1621 temp_oshape[interp_dim] = output_shape[interp_dim];
1622 temp_output = at::empty(temp_oshape, input.options());
1623 } else {
1624 temp_output = output;
1625 }
1626
1627 _separable_upsample_generic_Nd_kernel_impl_single_dim<
1628 out_ndims,
1629 scale_t,
1630 F,
1631 true>(
1632 temp_output, temp_input, interp_dim, align_corners, scales, antialias);
1633 temp_input = temp_output;
1634 }
1635
1636 // upsampling data between contiguous dimensions (aka vertical resampling)
1637 for (const auto i : c10::irange(1, out_ndims)) {
1638 interp_dim = 2 + out_ndims - 1 - i;
1639 if (output_shape[interp_dim] != input_shape[interp_dim]) {
1640
1641 num_single_dim_ops -= 1;
1642 if (num_single_dim_ops > 0) {
1643 temp_oshape[interp_dim] = output_shape[interp_dim];
1644 temp_output = at::empty(temp_oshape, input.options());
1645 } else {
1646 temp_output = output;
1647 }
1648
1649 _separable_upsample_generic_Nd_kernel_impl_single_dim<
1650 out_ndims,
1651 scale_t,
1652 F,
1653 false>(
1654 temp_output, temp_input, interp_dim, align_corners, scales, antialias);
1655 temp_input = temp_output;
1656 }
1657 }
1658 }
1659
1660 void upsample_nearest1d_kernel_impl(
1661 const Tensor& output,
1662 const Tensor& input,
1663 std::optional<double> scales_w) {
1664 upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpNearest>(
1665 output, input, false, {scales_w});
1666 }
1667
1668 void _upsample_nearest_exact1d_kernel_impl(
1669 const Tensor& output,
1670 const Tensor& input,
1671 std::optional<double> scales_w) {
1672 upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpNearestExact>(
1673 output, input, false, {scales_w});
1674 }
1675
1676 int _use_vectorized_kernel_cond_2d(
1677 const Tensor& output,
1678 const Tensor& input) {
1679 // This condition is used to know whether we should dispatch to a vectorized
1680 // kernel, or to the more general upsample_generic_Nd_kernel_impl(). For now,
1681 // the vectorized kernels are only optimized for channels_last and when C >= 4
1682 // (shape = NCHW). For a very wide range of use-cases (typically image or mask
1683 // resizing where we have C < 4), using upsample_generic_Nd_kernel_impl() is
1684 // actually faster. On top of that, benchmarks showed that this also depends on
1685 // the *output* size (output_H + output_W), for both upsampling and
1686 // downsampling. The current 128 threshold was determined through benchmarks.
1687 return ((input.is_contiguous(at::MemoryFormat::ChannelsLast)) && (input.size(1) > 3)) || ((output.size(-2) + output.size(-1)) <= 128);
1688 }
1689
1690 int _use_vectorized_kernel_cond_3d(
1691 // Similar to _use_vectorized_kernel_cond_2d() but for 3d resampling (e.g. videos)
1692     // Note that unlike the 2d case, this is not subject to the small-output-size
1693     // overhead - hence the absence of the 128 threshold in the condition.
1694 const Tensor& output,
1695 const Tensor& input) {
1696 return ((input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) && (input.size(1) > 3));
1697 }
1698
1699
1700 void upsample_nearest2d_kernel_impl(
1701 const Tensor& output,
1702 const Tensor& input,
1703 std::optional<double> scales_h,
1704 std::optional<double> scales_w) {
1705 if (_use_vectorized_kernel_cond_2d(output, input)) {
1706 AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf,
1707 input.scalar_type(), "upsample_nearest2d_channels_last", [&] {
1708 cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_idx>(output, input, {scales_h, scales_w});
1709 });
1710 } else {
1711 upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpNearest>(
1712 output, input, false, {scales_h, scales_w});
1713 }
1714 }
1715
1716 void _upsample_nearest_exact2d_kernel_impl(
1717 const Tensor& output,
1718 const Tensor& input,
1719 std::optional<double> scales_h,
1720 std::optional<double> scales_w) {
1721 if (_use_vectorized_kernel_cond_2d(output, input)) {
1722 AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest2d_channels_last", [&] {
1723 cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_exact_idx>(output, input, {scales_h, scales_w});
1724 });
1725 } else {
1726 upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpNearestExact>(
1727 output, input, false, {scales_h, scales_w});
1728 }
1729 }
1730
1731 void upsample_nearest3d_kernel_impl(
1732 const Tensor& output,
1733 const Tensor& input,
1734 std::optional<double> scales_d,
1735 std::optional<double> scales_h,
1736 std::optional<double> scales_w) {
1737 if (_use_vectorized_kernel_cond_3d(output, input)) {
1738 AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf,
1739 input.scalar_type(), "upsample_nearest3d_channels_last", [&] {
1740 cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_idx>(output, input, {scales_d, scales_h, scales_w});
1741 });
1742 } else {
1743 upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpNearest>(
1744 output, input, false, {scales_d, scales_h, scales_w});
1745 }
1746 }
1747
1748 void _upsample_nearest_exact3d_kernel_impl(
1749 const Tensor& output,
1750 const Tensor& input,
1751 std::optional<double> scales_d,
1752 std::optional<double> scales_h,
1753 std::optional<double> scales_w) {
1754 if (_use_vectorized_kernel_cond_3d(output, input)) {
1755 AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, input.scalar_type(), "upsample_nearest3d_channels_last", [&] {
1756 cpu_upsample_nearest_channels_last<scalar_t, scale_t, nearest_exact_idx>(output, input, {scales_d, scales_h, scales_w});
1757 });
1758 } else {
1759 upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpNearestExact>(
1760 output, input, false, {scales_d, scales_h, scales_w});
1761 }
1762 }
1763
1764 void upsample_linear1d_kernel_impl(
1765 const Tensor& output,
1766 const Tensor& input,
1767 bool align_corners,
1768 std::optional<double> scales_w) {
1769 upsample_generic_Nd_kernel_impl<1, scale_t, HelperInterpLinear>(
1770 output, input, align_corners, {scales_w});
1771 }
1772
1773
1774 void upsample_bilinear2d_kernel_impl_float(
1775 const Tensor& output,
1776 const Tensor& input,
1777 bool align_corners,
1778 std::optional<double> scales_h,
1779 std::optional<double> scales_w) {
1780
1781   // See note above about _use_vectorized_kernel_cond_2d(output, input). The extra condition is present
1782 // because benchmarks showed that with only 1 thread, images (C == 3) were
1783 // slightly faster with the vectorized kernel than with the generic one.
1784 // That's not the case for masks though (C == 1), which strongly benefit from
1785 // using the generic kernel.
1786 if ((_use_vectorized_kernel_cond_2d(output, input)) || (at::get_num_threads() == 1 && input.size(1) == 3)) {
1787 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "upsample_bilinear2d_channels_last", [&] {
1788 cpu_upsample_linear_channels_last<scalar_t, scale_t>(output, input, align_corners, {scales_h, scales_w});
1789 });
1790 } else {
1791 upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1792 output, input, align_corners, {scales_h, scales_w});
1793 }
1794 }
1795
1796 void upsample_bilinear2d_kernel_impl(
1797 const Tensor& output,
1798 const Tensor& input,
1799 bool align_corners,
1800 std::optional<double> scales_h,
1801 std::optional<double> scales_w) {
1802
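  // uint8 inputs take a dedicated path: the AVX2-vectorized kernel when the channel
  // count is small (C <= 4, e.g. RGB/RGBA images) and AVX2 is available, otherwise
  // the separable int16-weights kernel. All other dtypes go through the float path.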
1803 if (input.dtype() == at::kByte){
1804 #ifdef CPU_CAPABILITY_AVX2
1805 if (input.size(1) <= 4) {
1806 upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpLinear>(input,
1807 output, align_corners, {scales_h, scales_w},
1808 /*antialias=*/false);
1809 } else {
1810 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1811 output, input, align_corners, {scales_h, scales_w},
1812 /*antialias=*/false);
1813 }
1814 #else // CPU_CAPABILITY_AVX2
1815 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1816 output, input, align_corners, {scales_h, scales_w},
1817 /*antialias=*/false);
1818 #endif // CPU_CAPABILITY_AVX2
1819 } else {
1820 upsample_bilinear2d_kernel_impl_float(output, input, align_corners, scales_h, scales_w);
1821 }
1822 }
1823
1824
1825 void upsample_bilinear2d_aa_kernel_impl(
1826 const Tensor& output,
1827 const Tensor& input,
1828 bool align_corners,
1829 std::optional<double> scales_h,
1830 std::optional<double> scales_w) {
1831 #ifdef CPU_CAPABILITY_AVX2
1832 if (input.dtype() == at::kByte && input.size(1) <= 4) {
1833 upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpLinear>(
1834 input, output, align_corners, {scales_h, scales_w},
1835 /*antialias=*/true);
1836 } else {
1837 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1838 output, input, align_corners, {scales_h, scales_w},
1839 /*antialias=*/true);
1840 }
1841 #else // CPU_CAPABILITY_AVX2
1842 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
1843 output, input, align_corners, {scales_h, scales_w},
1844 /*antialias=*/true);
1845 #endif // CPU_CAPABILITY_AVX2
1846 }
1847
1848 void upsample_trilinear3d_kernel_impl(
1849 const Tensor& output,
1850 const Tensor& input,
1851 bool align_corners,
1852 std::optional<double> scales_d,
1853 std::optional<double> scales_h,
1854 std::optional<double> scales_w) {
1855 if ((_use_vectorized_kernel_cond_3d(output, input))) {
1856 AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "upsample_trilinear3d_channels_last", [&] {
1857 cpu_upsample_linear_channels_last<scalar_t, scale_t>(output, input, align_corners, {scales_d, scales_h, scales_w});
1858 });
1859 } else {
1860 upsample_generic_Nd_kernel_impl<3, scale_t, HelperInterpLinear>(
1861 output, input, align_corners, {scales_d, scales_h, scales_w});
1862 }
1863 }
1864
1865 void upsample_bicubic2d_kernel_impl(
1866 const Tensor& output,
1867 const Tensor& input,
1868 bool align_corners,
1869 std::optional<double> scales_h,
1870 std::optional<double> scales_w) {
1871
1872 if (input.dtype() == at::kByte){
1873 #ifdef CPU_CAPABILITY_AVX2
1874 if (input.size(1) <= 4) {
1875 upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpCubic>(input,
1876 output, align_corners, {scales_h, scales_w},
1877 /*antialias=*/false);
1878 } else {
1879 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1880 output, input, align_corners, {scales_h, scales_w},
1881 /*antialias=*/false);
1882 }
1883 #else // CPU_CAPABILITY_AVX2
1884 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1885 output, input, align_corners, {scales_h, scales_w},
1886 /*antialias=*/false);
1887 #endif // CPU_CAPABILITY_AVX2
1888 }
1889 else {
1890 upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1891 output, input, align_corners, {scales_h, scales_w});
1892 }
1893 }
1894
1895 void upsample_bicubic2d_aa_kernel_impl(
1896 const Tensor& output,
1897 const Tensor& input,
1898 bool align_corners,
1899 std::optional<double> scales_h,
1900 std::optional<double> scales_w) {
1901
1902 #ifdef CPU_CAPABILITY_AVX2
1903 if (input.dtype() == at::kByte && input.size(1) <= 4) {
1904 upsample_avx_bilinear_bicubic_uint8<scale_t, HelperInterpCubic>(
1905 input, output, align_corners, {scales_h, scales_w},
1906 /*antialias=*/true);
1907 } else {
1908 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1909 output, input, align_corners, {scales_h, scales_w},
1910 /*antialias=*/true);
1911 }
1912 #else // CPU_CAPABILITY_AVX2
1913 separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpCubic>(
1914 output, input, align_corners, {scales_h, scales_w},
1915 /*antialias=*/true);
1916 #endif // CPU_CAPABILITY_AVX2
1917 }
1918
1919 template <
1920 typename scalar_t,
1921 typename scale_type,
1922 class F>
1923 void cpu_upsample_genNd_backward_aa(
1924 const Tensor& grad_input_,
1925 const Tensor& grad_output_,
1926 bool align_corners,
1927 const scale_type& scales) {
1928 TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
1929 " for `grad_input` but got dtype ", grad_input_.dtype());
1930
1931 auto grad_output = grad_output_.contiguous();
1932 auto grad_input = grad_input_.contiguous();
1933
1934 auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
1935 auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
1936 auto input_sizes = grad_input.sizes().vec();
1937 auto output_sizes = grad_output.sizes().vec();
1938 auto ndim = input_sizes.size();
1939
1940 // treat nbatch and channels as one dimension
1941 int64_t channels = input_sizes[0] * input_sizes[1];
1942 int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
1943 int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
1944 int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
1945 int64_t input_width = input_sizes[ndim - 1];
1946 int64_t output_width = output_sizes[ndim - 1];
1947
1948 int64_t output_slice_size = output_depth * output_height * output_width;
1949 int interp_size = F::interp_size;
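  // output_slice_size is the number of output elements per flattened (N*C) channel;
  // it is used both to index grad_output below and to size the parallel grain.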
1950
1951 auto loop2d = [&](int64_t begin, int64_t end) {
1952 const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
1953 input_height, output_height, align_corners, scales[0]);
1954 const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
1955 input_width, output_width, align_corners, scales[1]);
1956
1957 auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
1958 return grad_input_data + c * input_height * input_width +
1959 h * input_width + w;
1960 };
1961
1962 const scalar_t support_h = (height_scale >= 1.0)
1963 ? (interp_size * 0.5) * height_scale
1964 : interp_size * 0.5;
1965 const scalar_t support_w = (width_scale >= 1.0)
1966 ? (interp_size * 0.5) * width_scale
1967 : interp_size * 0.5;
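    // When downscaling (scale >= 1) the filter support is stretched by the scale
    // factor - this is what provides the antialiasing. interp_size * 0.5 is the
    // half-support of the base filter: 1 for (bi)linear, 2 for (bi)cubic.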
1968
1969 const int interp_height = (int)ceilf(support_h) * 2 + 1;
1970 const int interp_width = (int)ceilf(support_w) * 2 + 1;
1971
1972 std::vector<scalar_t> wx(interp_width, 0.0);
1973 std::vector<scalar_t> wy(interp_height, 0.0);
1974
1975 int64_t xmin = 0, ymin = 0;
1976 int64_t xsize = 0, ysize = 0;
1977
1978 typedef scalar_t (*aa_filter_fn_t)(scalar_t);
1979 aa_filter_fn_t filter_fn = &F::aa_filter;
1980
1981 for (const auto oh : c10::irange(output_height)) {
1982 F::_compute_indices_min_size_weights_aa(
1983 oh,
1984 input_height,
1985 height_scale,
1986 support_h,
1987 wy.data(),
1988 interp_height,
1989 filter_fn,
1990 ymin,
1991 ysize);
1992
1993 for (const auto ow : c10::irange(output_width)) {
1994 F::_compute_indices_min_size_weights_aa(
1995 ow,
1996 input_width,
1997 width_scale,
1998 support_w,
1999 wx.data(),
2000 interp_width,
2001 filter_fn,
2002 xmin,
2003 xsize);
2004
2005 for (const auto c : c10::irange(begin, end)) {
2006 scalar_t grad_output_value =
2007 grad_output_data[c * output_slice_size + oh * output_width + ow];
2008
2009 for (const auto y : c10::irange(ysize)) {
2010 for (const auto x : c10::irange(xsize)) {
2011 *input_indexr(c, ymin + y, xmin + x) +=
2012 wx[x] * wy[y] * grad_output_value;
2013 }
2014 }
2015 }
2016 }
2017 }
2018 };
2019
2020 if (ndim == 4) {
2021     // 2d case (shared by the bilinear and bicubic antialiased backward kernels)
2022 at::parallel_for(
2023 0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
2024 } else {
2025 TORCH_CHECK(false, "Unsupported tensor ndim");
2026 }
2027
2028 if (!grad_input_.is_contiguous()) {
2029 grad_input_.copy_(grad_input);
2030 }
2031 }
2032
2033 void upsample_bilinear2d_aa_backward_kernel_impl(
2034 const Tensor& grad_input,
2035 const Tensor& grad_output,
2036 bool align_corners,
2037 std::optional<double> scales_h,
2038 std::optional<double> scales_w) {
2039 AT_DISPATCH_FLOATING_TYPES(
2040 grad_output.scalar_type(), "upsample_bilinear2d_aa_backward_cpu", [&] {
2041 cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpLinear>(
2042 grad_input, grad_output, align_corners, {scales_h, scales_w});
2043 });
2044 }
2045
2046 void upsample_bicubic2d_aa_backward_kernel_impl(
2047 const Tensor& grad_input,
2048 const Tensor& grad_output,
2049 bool align_corners,
2050 std::optional<double> scales_h,
2051 std::optional<double> scales_w) {
2052 AT_DISPATCH_FLOATING_TYPES(
2053 grad_output.scalar_type(), "upsample_bicubic2d_aa_backward_cpu", [&] {
2054 cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpCubic>(
2055 grad_input, grad_output, align_corners, {scales_h, scales_w});
2056 });
2057 }
2058
2059 } // anonymous namespace
2060
2061 REGISTER_DISPATCH(upsample_nearest1d_kernel, &upsample_nearest1d_kernel_impl);
2062 REGISTER_DISPATCH(_upsample_nearest_exact1d_kernel, &_upsample_nearest_exact1d_kernel_impl);
2063 REGISTER_DISPATCH(upsample_nearest2d_kernel, &upsample_nearest2d_kernel_impl);
2064 REGISTER_DISPATCH(_upsample_nearest_exact2d_kernel, &_upsample_nearest_exact2d_kernel_impl);
2065 REGISTER_DISPATCH(upsample_nearest3d_kernel, &upsample_nearest3d_kernel_impl);
2066 REGISTER_DISPATCH(_upsample_nearest_exact3d_kernel, &_upsample_nearest_exact3d_kernel_impl);
2067
2068 REGISTER_DISPATCH(upsample_linear1d_kernel, &upsample_linear1d_kernel_impl);
2069 REGISTER_DISPATCH(upsample_bilinear2d_kernel, &upsample_bilinear2d_kernel_impl);
2070 REGISTER_DISPATCH(_upsample_bilinear2d_aa_kernel, &upsample_bilinear2d_aa_kernel_impl);
2071 REGISTER_DISPATCH(_upsample_bilinear2d_aa_backward_kernel, &upsample_bilinear2d_aa_backward_kernel_impl);
2072 REGISTER_DISPATCH(upsample_trilinear3d_kernel, &upsample_trilinear3d_kernel_impl);
2073
2074 REGISTER_DISPATCH(upsample_bicubic2d_kernel, &upsample_bicubic2d_kernel_impl);
2075 REGISTER_DISPATCH(_upsample_bicubic2d_aa_kernel, &upsample_bicubic2d_aa_kernel_impl);
2076 REGISTER_DISPATCH(_upsample_bicubic2d_aa_backward_kernel, &upsample_bicubic2d_aa_backward_kernel_impl);
2077 } // namespace at::native
2078