#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <cstring>
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/UpSample.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <c10/util/irange.h>
#include <ATen/cpu/vec/vec.h>

namespace at::native {
namespace {

using scale_t = std::vector<std::optional<double>>;

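// Channels-last accumulation helpers. Two SFINAE overloads are provided:
//  - for float/double, acc_t matches scalar_t and the gradient is accumulated
//    directly with Vectorized<acc_t>;
//  - for reduced floating point types (BFloat16/Half), acc_t is float and the
//    reduced-precision gradient is widened with convert_to_float before being
//    accumulated into the float buffer.
// A scalar tail loop handles the remainder that does not fill a full vector.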
template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_nonconst_t> || !std::is_same_v<acc_t, float>, int> = 0>
void inline nearest_channels_last_acc(acc_t* gin, scalar_t* gout, int64_t size) {
  static_assert(std::is_same_v<acc_t, scalar_nonconst_t>,
              "acc data type of Upsample backward should be same as scalar_t for float or double on CPU.");
  using Vec = Vectorized<acc_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d);
    gin_vec.store(gin + d);
  }
  for (; d < size; d++) {
    gin[d] += gout[d];
  }
}

template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_nonconst_t> && std::is_same_v<acc_t, float>, int> = 0>
void inline nearest_channels_last_acc(acc_t* gin, scalar_t* gout, int64_t size) {
  using bVec = Vectorized<scalar_nonconst_t>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gout_bvec = bVec::loadu(gout + d);
    auto [gout_fvec0, gout_fvec1] = convert_to_float<scalar_nonconst_t>(gout_bvec);
    fVec gin_fvec0 = fVec::loadu(gin + d) + gout_fvec0;
    fVec gin_fvec1 = fVec::loadu(gin + d + fVec::size()) + gout_fvec1;
    gin_fvec0.store(gin + d);
    gin_fvec1.store(gin + d + fVec::size());
  }
  for (; d < size; d++) {
    gin[d] += gout[d];
  }
}

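// Same pattern as nearest_channels_last_acc, but each grad_output element is
// scaled by the interpolation weight `w` before being accumulated, which is
// what the (bi/tri)linear backward channels-last kernels below need.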
template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<!is_reduced_floating_point_v<scalar_nonconst_t> || !std::is_same_v<acc_t, float>, int> = 0>
void inline linear_channels_last_acc(acc_t* gin, const scalar_t* gout, acc_t w, int64_t size) {
  static_assert(std::is_same_v<acc_t, scalar_nonconst_t>,
              "acc data type of Upsample backward should be same as scalar_t for float or double on CPU.");
  using Vec = Vectorized<acc_t>;
  int64_t d = 0;
  for (; d < size - (size % Vec::size()); d += Vec::size()) {
    Vec gin_vec = Vec::loadu(gin + d) + Vec(w) * Vec::loadu(gout + d);
    gin_vec.store(gin + d);
  }
  for (; d < size; d++) {
    gin[d] += w * gout[d];
  }
}

template <typename acc_t, typename scalar_t,
          typename scalar_nonconst_t = std::remove_const_t<scalar_t>,
          typename std::enable_if_t<is_reduced_floating_point_v<scalar_nonconst_t> && std::is_same_v<acc_t, float>, int> = 0>
void inline linear_channels_last_acc(acc_t* gin, const scalar_t* gout, acc_t w, int64_t size) {
  using bVec = Vectorized<scalar_nonconst_t>;
  using fVec = Vectorized<float>;
  int64_t d = 0;
  for (; d < size - (size % bVec::size()); d += bVec::size()) {
    bVec gout_bvec = bVec::loadu(gout + d);
    auto [gout_fvec0, gout_fvec1] = convert_to_float<scalar_nonconst_t>(gout_bvec);
    fVec gin_fvec0 = fVec::loadu(gin + d) + fVec(w) * gout_fvec0;
    fVec gin_fvec1 = fVec::loadu(gin + d + fVec::size()) + fVec(w) * gout_fvec1;
    gin_fvec0.store(gin + d);
    gin_fvec1.store(gin + d + fVec::size());
  }
  for (; d < size; d++) {
    gin[d] += w * gout[d];
  }
}

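// Backward of nearest upsampling for contiguous tensors. Each output gradient
// element is scatter-added into the input-gradient position selected by
// nearest_idx_fn. When scalar_t is a reduced floating point type, a per-slice
// float (opmath_t) buffer is accumulated first and then written back to
// grad_input via apply_grad_input; otherwise the kernel accumulates directly
// into grad_input. Parallelization is over the flattened (batch * channels)
// dimension.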
template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
void cpu_upsample_nearest_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();

  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  int64_t output_slice_size = output_depth * output_height * output_width;
  int64_t input_slice_size = input_depth * input_height * input_width;

  using opmath_t = at::opmath_type<scalar_t>;
  auto loop1d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto ow : c10::irange(output_width)) {
        int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[0]);
        int64_t output_offset = c * output_slice_size + ow;
        acc_data_ptr[input_offset + iw] += grad_output_data[output_offset];
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
        for (const auto ow : c10::irange(output_width)) {
          int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
          int64_t output_offset = c * output_slice_size + oh * output_width + ow;
          acc_data_ptr[input_offset + ih * input_width + iw] += grad_output_data[output_offset];
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
          for (const auto ow : c10::irange(output_width)) {
            int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
            int64_t output_offset = c * output_slice_size +
                od * output_height * output_width + oh * output_width + ow;
            acc_data_ptr[input_offset + id * input_height * input_width + ih * input_width + iw] +=
                grad_output_data[output_offset];
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 3) {
    // upsample nearest 1d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop1d);
  } else if (ndim == 4) {
    // upsample nearest 2d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop2d);
  } else {
    // upsample nearest 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop3d);
  }

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

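// Channels-last variant of the nearest backward kernel. The innermost
// dimension is channels, so a whole channel vector is accumulated per output
// location via nearest_channels_last_acc, and parallelization is over the
// batch dimension.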
template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
void cpu_upsample_nearest_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto ndim = grad_output_.ndimension();
  TORCH_CHECK(ndim >= 4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.");

  auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  auto grad_output = grad_output_.contiguous(channels_last_memory_format);
  auto grad_input = grad_input_.contiguous(channels_last_memory_format);

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();

  int64_t num_batches = input_sizes[0];
  int64_t channels = input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = input_sizes[ndim - 2];
  int64_t output_height = output_sizes[ndim - 2];
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];
  int64_t input_slice_size = input_depth * input_height * input_width * channels;

  using opmath_t = at::opmath_type<scalar_t>;
  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
        for (const auto ow : c10::irange(output_width)) {
          int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
          const scalar_t* grad_output_ptr = grad_output_data +
              (n * output_height * output_width + oh * output_width + ow) * channels;
          opmath_t* buffer_ptr = acc_data_ptr + input_offset + (ih * input_width + iw) * channels;
          nearest_channels_last_acc(buffer_ptr, grad_output_ptr, channels);
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (int64_t od = 0; od < output_depth; od++) {
        int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
        for (int64_t oh = 0; oh < output_height; oh++) {
          int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
          for (int64_t ow = 0; ow < output_width; ow++) {
            int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
            const scalar_t* grad_output_ptr = grad_output_data +
                (n * output_depth * output_height * output_width +
                 od * output_height * output_width + oh * output_width + ow) * channels;
            opmath_t* buffer_ptr = acc_data_ptr + input_offset + (id * input_height * input_width + ih * input_width + iw) * channels;
            nearest_channels_last_acc(buffer_ptr, grad_output_ptr, channels);
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 4) {
    // upsample nearest 2d
    at::parallel_for(0, num_batches, 0, loop2d);
  } else {
    // upsample nearest 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, num_batches, 0, loop3d);
  }

  if (!grad_input_.is_contiguous(channels_last_memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

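// Dispatch entry points for the nearest and nearest-exact backward kernels.
// Each one picks the channels-last or contiguous implementation based on the
// memory format of grad_output and dispatches over the floating point types
// plus BFloat16 and Half.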
void upsample_nearest1d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest1d_backward", [&] {
    cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_w});
  });
}

void _upsample_nearest_exact1d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact1d_backward", [&] {
    cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_w});
  });
}

void upsample_nearest2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest2d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest2d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  }
}

void _upsample_nearest_exact2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact2d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact2d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_h, scales_w});
    });
  }
}

void upsample_nearest3d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest3d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_nearest3d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  }
}

void _upsample_nearest_exact3d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact3d_backward_cl", [&] {
      cpu_upsample_nearest_backward_channels_last<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "_upsample_nearest_exact3d_backward", [&] {
      cpu_upsample_nearest_backward<scalar_t, scale_t, nearest_exact_idx>(grad_input, grad_output, {scales_d, scales_h, scales_w});
    });
  }
}

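// Backward of linear/bilinear/trilinear upsampling for contiguous tensors.
// For every output location, compute_source_index_and_lambda yields the two
// neighboring input indices per spatial dimension together with their
// interpolation weights, and the output gradient is distributed to the 2, 4,
// or 8 contributing input positions. Reduced floating point types again go
// through a per-slice float buffer that is flushed with apply_grad_input.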
template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();

  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  int64_t input_slice_size = input_depth * input_height * input_width;
  int64_t output_slice_size = output_depth * output_height * output_width;
  using opmath_t = at::opmath_type<scalar_t>;
  auto loop1d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[0]);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t iw0, iw1;
    opmath_t w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto ow : c10::irange(output_width)) {
        compute_source_index_and_lambda(
            iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
        opmath_t grad_output_value = grad_output_data[c * output_slice_size + ow];
        acc_data_ptr[input_offset + iw0] += w0lambda * grad_output_value; /* i0 */
        acc_data_ptr[input_offset + iw1] += w1lambda * grad_output_value; /* i1 */
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t ih0, ih1, iw0, iw1;
    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
          opmath_t grad_output_value = grad_output_data[c * output_slice_size + oh * output_width + ow];
          acc_data_ptr[input_offset + ih0 * input_width + iw0] += h0lambda * w0lambda * grad_output_value; /* i00 */
          acc_data_ptr[input_offset + ih0 * input_width + iw1] += h0lambda * w1lambda * grad_output_value; /* i01 */
          acc_data_ptr[input_offset + ih1 * input_width + iw0] += h1lambda * w0lambda * grad_output_value; /* i10 */
          acc_data_ptr[input_offset + ih1 * input_width + iw1] += h1lambda * w1lambda * grad_output_value; /* i11 */
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t id0, id1, ih0, ih1, iw0, iw1;
    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
            opmath_t grad_output_value = grad_output_data[c * output_slice_size +
                od * output_height * output_width + oh * output_width + ow];
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih0 * input_width + iw0] += d0lambda * h0lambda * w0lambda * grad_output_value; /* i000 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih0 * input_width + iw1] += d0lambda * h0lambda * w1lambda * grad_output_value; /* i001 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih1 * input_width + iw0] += d0lambda * h1lambda * w0lambda * grad_output_value; /* i010 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih1 * input_width + iw1] += d0lambda * h1lambda * w1lambda * grad_output_value; /* i011 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih0 * input_width + iw0] += d1lambda * h0lambda * w0lambda * grad_output_value; /* i100 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih0 * input_width + iw1] += d1lambda * h0lambda * w1lambda * grad_output_value; /* i101 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih1 * input_width + iw0] += d1lambda * h1lambda * w0lambda * grad_output_value; /* i110 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih1 * input_width + iw1] += d1lambda * h1lambda * w1lambda * grad_output_value; /* i111 */
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 3) {
    // upsample linear 1d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 2, loop1d);
  } else if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
  } else {
    // upsample trilinear 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
  }

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

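// Channels-last variant of the linear backward kernel. Interpolation indices
// and weights are computed per output location, and linear_channels_last_acc
// adds the weighted grad_output channel vector to each of the contributing
// input rows; parallelization is over the batch dimension.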
template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto ndim = grad_output_.ndimension();
  TORCH_CHECK(ndim >= 4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.");

  auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  auto grad_output = grad_output_.contiguous(channels_last_memory_format);
  auto grad_input = grad_input_.contiguous(channels_last_memory_format);

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();

  int64_t num_batches = input_sizes[0];
  int64_t channels = input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = input_sizes[ndim - 2];
  int64_t output_height = output_sizes[ndim - 2];
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];
  int64_t input_slice_size = input_depth * input_height * input_width * channels;
  using opmath_t = at::opmath_type<scalar_t>;

  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);

    auto input_indexr = [=](int64_t n, int64_t h, int64_t w, int64_t offset){
      return acc_data_ptr + offset + (h * input_width + w) * channels;
    };

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t ih0, ih1, iw0, iw1;
    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
          const scalar_t* grad_output_ptr = grad_output_data +
              (n * output_height * output_width + oh * output_width + ow) * channels;
          linear_channels_last_acc(input_indexr(n, ih0, iw0, input_offset), grad_output_ptr, h0lambda * w0lambda, channels); /* i00 */
          linear_channels_last_acc(input_indexr(n, ih0, iw1, input_offset), grad_output_ptr, h0lambda * w1lambda, channels); /* i01 */
          linear_channels_last_acc(input_indexr(n, ih1, iw0, input_offset), grad_output_ptr, h1lambda * w0lambda, channels); /* i10 */
          linear_channels_last_acc(input_indexr(n, ih1, iw1, input_offset), grad_output_ptr, h1lambda * w1lambda, channels); /* i11 */
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);

    auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w, int64_t offset) {
      return acc_data_ptr + offset + (d * input_height * input_width + h * input_width + w) * channels;
    };

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t id0, id1, ih0, ih1, iw0, iw1;
    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
            const scalar_t* grad_output_ptr = grad_output_data + (n * output_depth * output_height * output_width +
                od * output_height * output_width + oh * output_width + ow) * channels;
            linear_channels_last_acc(input_indexr(n, id0, ih0, iw0, input_offset), grad_output_ptr, d0lambda * h0lambda * w0lambda, channels); /* i000 */
            linear_channels_last_acc(input_indexr(n, id0, ih0, iw1, input_offset), grad_output_ptr, d0lambda * h0lambda * w1lambda, channels); /* i001 */
            linear_channels_last_acc(input_indexr(n, id0, ih1, iw0, input_offset), grad_output_ptr, d0lambda * h1lambda * w0lambda, channels); /* i010 */
            linear_channels_last_acc(input_indexr(n, id0, ih1, iw1, input_offset), grad_output_ptr, d0lambda * h1lambda * w1lambda, channels); /* i011 */
            linear_channels_last_acc(input_indexr(n, id1, ih0, iw0, input_offset), grad_output_ptr, d1lambda * h0lambda * w0lambda, channels); /* i100 */
            linear_channels_last_acc(input_indexr(n, id1, ih0, iw1, input_offset), grad_output_ptr, d1lambda * h0lambda * w1lambda, channels); /* i101 */
            linear_channels_last_acc(input_indexr(n, id1, ih1, iw0, input_offset), grad_output_ptr, d1lambda * h1lambda * w0lambda, channels); /* i110 */
            linear_channels_last_acc(input_indexr(n, id1, ih1, iw1, input_offset), grad_output_ptr, d1lambda * h1lambda * w1lambda, channels); /* i111 */
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(0, num_batches, 0, loop2d);
  } else {
    // upsample trilinear 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, num_batches, 0, loop3d);
  }

  if (!grad_input_.is_contiguous(channels_last_memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

void upsample_linear1d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    std::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_linear1d_backward", [&] {
    cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_w});
  });
}

void upsample_bilinear2d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_bilinear2d_backward_channels_last", [&] {
      cpu_upsample_linear_backward_channels_last<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_bilinear2d_backward", [&] {
      cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_h, scales_w});
    });
  }
}

void upsample_trilinear3d_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  if (grad_output.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_trilinear3d_backward_channels_last", [&] {
      cpu_upsample_linear_backward_channels_last<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_d, scales_h, scales_w});
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "upsample_trilinear3d_backward", [&] {
      cpu_upsample_linear_backward<scalar_t, scale_t>(grad_input, grad_output, align_corners, {scales_d, scales_h, scales_w});
    });
  }
}

} // anonymous namespace

REGISTER_DISPATCH(upsample_nearest1d_backward_kernel, &upsample_nearest1d_backward_kernel_impl);
REGISTER_DISPATCH(_upsample_nearest_exact1d_backward_kernel, &_upsample_nearest_exact1d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_nearest2d_backward_kernel, &upsample_nearest2d_backward_kernel_impl);
REGISTER_DISPATCH(_upsample_nearest_exact2d_backward_kernel, &_upsample_nearest_exact2d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_nearest3d_backward_kernel, &upsample_nearest3d_backward_kernel_impl);
REGISTER_DISPATCH(_upsample_nearest_exact3d_backward_kernel, &_upsample_nearest_exact3d_backward_kernel_impl);

REGISTER_DISPATCH(upsample_linear1d_backward_kernel, &upsample_linear1d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_bilinear2d_backward_kernel, &upsample_bilinear2d_backward_kernel_impl);
REGISTER_DISPATCH(upsample_trilinear3d_backward_kernel, &upsample_trilinear3d_backward_kernel_impl);

} // namespace at::native