#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/UpSample.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <c10/util/irange.h>
#include <ATen/cpu/vec/vec.h>

namespace at::native {
namespace {

using scale_t = std::vector<std::optional<double>>;

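// Vectorized accumulation helpers used by the channels-last backward kernels
// below. Two overloads are selected via SFINAE:
//   * float/double: the accumulator type matches scalar_t, so gradients are
//     summed directly with Vectorized<acc_t>;
//   * reduced floating point (BFloat16/Half) with a float accumulator: the
//     reduced-precision gradients are widened with convert_to_float and
//     accumulated into a float buffer to preserve precision.
// The scalar tail loop handles the remainder when `size` is not a multiple of
// the vector width.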
template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_nonconst_t> || !std::is_same_v<acc_t, float>, int> = 0>
void inline nearest_channels_last_acc(acc_t* gin, scalar_t* gout, int64_t size) {
  static_assert(std::is_same_v<acc_t, scalar_nonconst_t>,
              "acc data type of Upsample backward should be same as scalar_t for float or double on CPU.");
  using Vec = Vectorized<acc_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d);
    gin_vec.store(gin + d);
  }
  for (; d < size; d++) {
    gin[d] += gout[d];
  }
}

template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_nonconst_t> && std::is_same_v<acc_t, float>, int> = 0>
void inline nearest_channels_last_acc(acc_t* gin, scalar_t* gout, int64_t size) {
  using bVec = Vectorized<scalar_nonconst_t>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gout_bvec = bVec::loadu(gout + d);
    auto [gout_fvec0, gout_fvec1] = convert_to_float<scalar_nonconst_t>(gout_bvec);
    fVec gin_fvec0 = fVec::loadu(gin + d) + gout_fvec0;
    fVec gin_fvec1 = fVec::loadu(gin + d + fVec::size()) + gout_fvec1;
    gin_fvec0.store(gin + d);
    gin_fvec1.store(gin + d + fVec::size());
  }
  for (; d < size; d++) {
    gin[d] += gout[d];
  }
}

template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_nonconst_t> || !std::is_same_v<acc_t, float>, int> = 0>
void inline linear_channels_last_acc(acc_t* gin, const scalar_t* gout, acc_t w, int64_t size) {
  static_assert(std::is_same_v<acc_t, scalar_nonconst_t>,
              "acc data type of Upsample backward should be same as scalar_t for float or double on CPU.");
  using Vec = Vectorized<acc_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec gin_vec = Vec::loadu(gin + d) + Vec(w) * Vec::loadu(gout + d);
    gin_vec.store(gin + d);
  }
  for (; d < size; d++) {
    gin[d] += w * gout[d];
  }
}

template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_nonconst_t> && std::is_same_v<acc_t, float>, int> = 0>
void inline linear_channels_last_acc(acc_t* gin, const scalar_t* gout, acc_t w, int64_t size) {
  using bVec = Vectorized<scalar_nonconst_t>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gout_bvec = bVec::loadu(gout + d);
    auto [gout_fvec0, gout_fvec1] = convert_to_float<scalar_nonconst_t>(gout_bvec);
    fVec gin_fvec0 = fVec::loadu(gin + d) + fVec(w) * gout_fvec0;
    fVec gin_fvec1 = fVec::loadu(gin + d + fVec::size()) + fVec(w) * gout_fvec1;
    gin_fvec0.store(gin + d);
    gin_fvec1.store(gin + d + fVec::size());
  }
  for (; d < size; d++) {
    gin[d] += w * gout[d];
  }
}

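// Backward of nearest(-exact) upsampling for contiguous tensors. Batch and
// channel are collapsed into one dimension and parallelized over; every
// grad_output element is scatter-added into the nearest grad_input location
// given by nearest_idx_fn. For reduced-precision dtypes (opmath_t != scalar_t)
// each slice is accumulated in a zero-initialized float buffer and then
// written back through apply_grad_input.
// Worked example (1d, 2x scale factor, input_width = 4, output_width = 8):
// output columns {0, 1} both map back to input column 0, {2, 3} to column 1,
// and so on, so each grad_input element receives the sum of two grad_output
// elements.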
template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
void cpu_upsample_nearest_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();

  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  int64_t output_slice_size = output_depth * output_height * output_width;
  int64_t input_slice_size = input_depth * input_height * input_width;

  using opmath_t = at::opmath_type<scalar_t>;
  auto loop1d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto ow : c10::irange(output_width)) {
        int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[0]);
        int64_t output_offset = c * output_slice_size + ow;
        acc_data_ptr[input_offset + iw] += grad_output_data[output_offset];
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
        for (const auto ow : c10::irange(output_width)) {
          int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
          int64_t output_offset = c * output_slice_size + oh * output_width + ow;
          acc_data_ptr[input_offset + ih * input_width + iw] += grad_output_data[output_offset];
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
          for (const auto ow : c10::irange(output_width)) {
            int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
            int64_t output_offset = c * output_slice_size +
                od * output_height * output_width + oh * output_width + ow;
            acc_data_ptr[input_offset + id * input_height * input_width + ih * input_width + iw] +=
              grad_output_data[output_offset];
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 3) {
    // upsample nearest 1d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop1d);
  } else if (ndim == 4) {
    // upsample nearest 2d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop2d);
  } else {
    // upsample nearest 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop3d);
  }

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

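// Channels-last (NHWC / NDHWC) variant of the nearest backward kernel.
// Parallelization is over the batch dimension; because channels form the
// innermost, contiguous dimension, each output spatial position contributes a
// whole channel vector, which is accumulated with nearest_channels_last_acc.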
template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
void cpu_upsample_nearest_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto ndim = grad_output_.ndimension();
  TORCH_CHECK(ndim >= 4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.");

  auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  auto grad_output = grad_output_.contiguous(channels_last_memory_format);
  auto grad_input = grad_input_.contiguous(channels_last_memory_format);

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();

  int64_t num_batches = input_sizes[0];
  int64_t channels = input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = input_sizes[ndim - 2];
  int64_t output_height = output_sizes[ndim - 2];
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];
  int64_t input_slice_size = input_depth * input_height * input_width * channels;

  using opmath_t = at::opmath_type<scalar_t>;
  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
        for (const auto ow : c10::irange(output_width)) {
          int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
          const scalar_t* grad_output_ptr = grad_output_data +
              (n * output_height * output_width + oh * output_width + ow) * channels;
          opmath_t* buffer_ptr = acc_data_ptr + input_offset + (ih * input_width + iw) * channels;
          nearest_channels_last_acc(buffer_ptr, grad_output_ptr, channels);
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (int64_t od = 0; od < output_depth; od++) {
        int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
        for (int64_t oh = 0; oh < output_height; oh++) {
          int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
          for (int64_t ow = 0; ow < output_width; ow++) {
            int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
            const scalar_t* grad_output_ptr = grad_output_data +
                (n * output_depth * output_height * output_width +
                od * output_height * output_width + oh * output_width + ow) * channels;

            opmath_t* buffer_ptr = acc_data_ptr + input_offset + (id * input_height * input_width + ih * input_width + iw) * channels;
            nearest_channels_last_acc(buffer_ptr, grad_output_ptr, channels);
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 4) {
    // upsample nearest 2d
    at::parallel_for(0, num_batches, 0, loop2d);
  } else {
    // upsample nearest 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, num_batches, 0, loop3d);
  }

  if (!grad_input_.is_contiguous(channels_last_memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

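// Dispatch stubs for the nearest backward kernels. AT_DISPATCH_FLOATING_TYPES_AND2
// instantiates the templates for float, double, BFloat16 and Half. The 2d/3d
// variants route a channels-last contiguous grad_output to the channels-last
// kernel; all other layouts use the contiguous kernel.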
void upsample_nearest1d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest1d_backward", [&] {
    cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_w});
  });
}

void _upsample_nearest_exact1d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact1d_backward", [&] {
    cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_w});
  });
}

void upsample_nearest2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest2d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest2d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  }
}

void _upsample_nearest_exact2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact2d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact2d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  }
}

void upsample_nearest3d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest3d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest3d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  }
}

void _upsample_nearest_exact3d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact3d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact3d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  }
}

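// Backward of linear/bilinear/trilinear upsampling for contiguous tensors.
// For each output element, compute_source_index_and_lambda yields the two
// neighboring input indices and interpolation weights per spatial dimension;
// the output gradient is then distributed to 2 (1d), 4 (2d) or 8 (3d) input
// locations, weighted by the product of the per-dimension lambdas, e.g. in 2d
// grad_input[ih0][iw0] += h0lambda * w0lambda * grad_output[oh][ow].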
template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();

  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  int64_t input_slice_size = input_depth * input_height * input_width;
  int64_t output_slice_size = output_depth * output_height * output_width;
  using opmath_t = at::opmath_type<scalar_t>;
  auto loop1d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[0]);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t iw0, iw1;
    opmath_t w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto ow : c10::irange(output_width)) {
        compute_source_index_and_lambda(
            iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
        opmath_t grad_output_value = grad_output_data[c * output_slice_size + ow];
        acc_data_ptr[input_offset + iw0] += w0lambda * grad_output_value; /* i0 */
        acc_data_ptr[input_offset + iw1] += w1lambda * grad_output_value; /* i1 */
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t ih0, ih1, iw0, iw1;
    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
          opmath_t grad_output_value = grad_output_data[c * output_slice_size + oh * output_width + ow];
          acc_data_ptr[input_offset + ih0 * input_width + iw0] += h0lambda * w0lambda * grad_output_value; /* i00 */
          acc_data_ptr[input_offset + ih0 * input_width + iw1] += h0lambda * w1lambda * grad_output_value; /* i01 */
          acc_data_ptr[input_offset + ih1 * input_width + iw0] += h1lambda * w0lambda * grad_output_value; /* i10 */
          acc_data_ptr[input_offset + ih1 * input_width + iw1] += h1lambda * w1lambda * grad_output_value; /* i11 */
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t id0, id1, ih0, ih1, iw0, iw1;
    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
            opmath_t grad_output_value = grad_output_data[c * output_slice_size +
                od * output_height * output_width + oh * output_width + ow];
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih0 * input_width + iw0] += d0lambda * h0lambda * w0lambda * grad_output_value; /* i000 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih0 * input_width + iw1] += d0lambda * h0lambda * w1lambda * grad_output_value; /* i001 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih1 * input_width + iw0] += d0lambda * h1lambda * w0lambda * grad_output_value; /* i010 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih1 * input_width + iw1] += d0lambda * h1lambda * w1lambda * grad_output_value; /* i011 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih0 * input_width + iw0] += d1lambda * h0lambda * w0lambda * grad_output_value; /* i100 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih0 * input_width + iw1] += d1lambda * h0lambda * w1lambda * grad_output_value; /* i101 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih1 * input_width + iw0] += d1lambda * h1lambda * w0lambda * grad_output_value; /* i110 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih1 * input_width + iw1] += d1lambda * h1lambda * w1lambda * grad_output_value; /* i111 */
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 3) {
    // upsample linear 1d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 2, loop1d);
  } else if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
  } else {
    // upsample trilinear 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
  }

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

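// Channels-last variant of the linear backward kernel: the same lambda-weighted
// scatter as above, but parallelized over batches and vectorized across the
// contiguous channel dimension via linear_channels_last_acc.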
template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto ndim = grad_output_.ndimension();
  TORCH_CHECK(ndim >= 4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.");

  auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  auto grad_output = grad_output_.contiguous(channels_last_memory_format);
  auto grad_input = grad_input_.contiguous(channels_last_memory_format);

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();

  int64_t num_batches = input_sizes[0];
  int64_t channels = input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = input_sizes[ndim - 2];
  int64_t output_height = output_sizes[ndim - 2];
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];
  int64_t input_slice_size = input_depth * input_height * input_width * channels;
  using opmath_t = at::opmath_type<scalar_t>;

  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);

    auto input_indexr = [=](int64_t n, int64_t h, int64_t w, int64_t offset) {
      return acc_data_ptr + offset + (h * input_width + w) * channels;
    };

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t ih0, ih1, iw0, iw1;
    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
          const scalar_t* grad_output_ptr = grad_output_data +
              (n * output_height * output_width + oh * output_width + ow) * channels;
          linear_channels_last_acc(input_indexr(n, ih0, iw0, input_offset), grad_output_ptr, h0lambda * w0lambda, channels); /* i00 */
          linear_channels_last_acc(input_indexr(n, ih0, iw1, input_offset), grad_output_ptr, h0lambda * w1lambda, channels); /* i01 */
          linear_channels_last_acc(input_indexr(n, ih1, iw0, input_offset), grad_output_ptr, h1lambda * w0lambda, channels); /* i10 */
          linear_channels_last_acc(input_indexr(n, ih1, iw1, input_offset), grad_output_ptr, h1lambda * w1lambda, channels); /* i11 */
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);

    auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w, int64_t offset) {
      return acc_data_ptr + offset + (d * input_height * input_width + h * input_width + w) * channels;
    };

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t id0, id1, ih0, ih1, iw0, iw1;
    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
            const scalar_t* grad_output_ptr = grad_output_data + (n * output_depth * output_height * output_width +
                od * output_height * output_width + oh * output_width + ow) * channels;
            linear_channels_last_acc(input_indexr(n, id0, ih0, iw0, input_offset), grad_output_ptr, d0lambda * h0lambda * w0lambda, channels); /* i000 */
            linear_channels_last_acc(input_indexr(n, id0, ih0, iw1, input_offset), grad_output_ptr, d0lambda * h0lambda * w1lambda, channels); /* i001 */
            linear_channels_last_acc(input_indexr(n, id0, ih1, iw0, input_offset), grad_output_ptr, d0lambda * h1lambda * w0lambda, channels); /* i010 */
            linear_channels_last_acc(input_indexr(n, id0, ih1, iw1, input_offset), grad_output_ptr, d0lambda * h1lambda * w1lambda, channels); /* i011 */
            linear_channels_last_acc(input_indexr(n, id1, ih0, iw0, input_offset), grad_output_ptr, d1lambda * h0lambda * w0lambda, channels); /* i100 */
            linear_channels_last_acc(input_indexr(n, id1, ih0, iw1, input_offset), grad_output_ptr, d1lambda * h0lambda * w1lambda, channels); /* i101 */
            linear_channels_last_acc(input_indexr(n, id1, ih1, iw0, input_offset), grad_output_ptr, d1lambda * h1lambda * w0lambda, channels); /* i110 */
            linear_channels_last_acc(input_indexr(n, id1, ih1, iw1, input_offset), grad_output_ptr, d1lambda * h1lambda * w1lambda, channels); /* i111 */
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(0, num_batches, 0, loop2d);
  } else {
    // upsample trilinear 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, num_batches, 0, loop3d);
  }

  if (!grad_input_.is_contiguous(channels_last_memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

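// Dispatch stubs for the linear backward kernels, mirroring the nearest ones
// above: a channels-last contiguous grad_output is routed to the channels-last
// kernel, everything else to the contiguous kernel.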
void upsample_linear1d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    std::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_linear1d_backward", [&] {
    cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_w});
  });
}

void upsample_bilinear2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_bilinear2d_backward_channels_last", [&] {
      cpu_upsample_linear_backward_channels_last<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_bilinear2d_backward", [&] {
      cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_h, scales_w});
    });
  }
}

void upsample_trilinear3d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_trilinear3d_backward_channels_last", [&] {
      cpu_upsample_linear_backward_channels_last<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_d, scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_trilinear3d_backward", [&] {
      cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_d, scales_h, scales_w});
    });
  }
}

} // anonymous namespace

REGISTER_DISPATCH(upsample_nearest1d_backward_kernel, &upsample_nearest1d_backward_kernel_impl);
REGISTER_DISPATCH(_upsample_nearest_exact1d_backward_kernel, &_upsample_nearest_exact1d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_nearest2d_backward_kernel, &upsample_nearest2d_backward_kernel_impl);
REGISTER_DISPATCH(_upsample_nearest_exact2d_backward_kernel, &_upsample_nearest_exact2d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_nearest3d_backward_kernel, &upsample_nearest3d_backward_kernel_impl);
REGISTER_DISPATCH(_upsample_nearest_exact3d_backward_kernel, &_upsample_nearest_exact3d_backward_kernel_impl);

REGISTER_DISPATCH(upsample_linear1d_backward_kernel, &upsample_linear1d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_bilinear2d_backward_kernel, &upsample_bilinear2d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_trilinear3d_backward_kernel, &upsample_trilinear3d_backward_kernel_impl);

} // namespace at::native