// Adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/UpSample.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/cuda/LaunchUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_upsample_bicubic2d_aa_backward_native.h>
#include <ATen/ops/_upsample_bicubic2d_aa_native.h>
#include <ATen/ops/_upsample_bilinear2d_aa_backward_native.h>
#include <ATen/ops/_upsample_bilinear2d_aa_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/upsample_bilinear2d_backward_native.h>
#include <ATen/ops/upsample_bilinear2d_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {
namespace {

template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bilinear2d_out_frame(
    const int n,
    const accscalar_t rheight,
    const accscalar_t rwidth,
    const bool align_corners,
    const PackedTensorAccessor<const scalar_t, 4> idata,
    PackedTensorAccessor<scalar_t, 4> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int height1 = idata.size(2);
  const int width1 = idata.size(3);
  const int width2 = odata.size(3);

  if (index < n) {
    const int w2 = index % width2; // 0:width2-1
    const int h2 = index / width2; // 0:height2-1
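    // Each thread handles one output spatial location (h2, w2); the
    // interpolation weights computed below are then reused for every
    // (batch, channel) pair in the loop further down.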

    const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
        rheight, h2, align_corners, /*cubic=*/false);
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const accscalar_t h1lambda = h1r - h1;
    const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
    //
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
    //
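    // Bilinear interpolation combines the four nearest source pixels:
    //   val = h0lambda * (w0lambda * I[h1][w1]     + w1lambda * I[h1][w1+w1p])
    //       + h1lambda * (w0lambda * I[h1+h1p][w1] + w1lambda * I[h1+h1p][w1+w1p])
    // Illustrative example (numbers assumed, not from the original): if
    // h1r = 1.25 and w1r = 2.5, then h1 = 1, w1 = 2, h1lambda = 0.25,
    // w1lambda = 0.5, so the output is 0.375*I[1][2] + 0.375*I[1][3]
    //                                + 0.125*I[2][2] + 0.125*I[2][3].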
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; ++c) {
        const accscalar_t val = h0lambda *
                (w0lambda * idata[n][c][h1][w1] +
                 w1lambda * idata[n][c][h1][w1 + w1p]) +
            h1lambda *
                (w0lambda * idata[n][c][h1 + h1p][w1] +
                 w1lambda * idata[n][c][h1 + h1p][w1 + w1p]);
        odata[n][c][h2][w2] = static_cast<scalar_t>(val);
      }
    }
  }
}

template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bilinear2d_nhwc_out_frame(
    const accscalar_t rheight,
    const accscalar_t rwidth,
    const bool align_corners,
    const int channels,
    const int height1,
    const int width1,
    const int height2,
    const int width2,
    const scalar_t* idata,
    scalar_t* odata,
    const int out_numel) {

  const int index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < out_numel) {
    const int c = index % channels;
    const int w2 = (index / channels) % width2;
    const int h2 = (index / channels / width2) % height2;
    const int n = index / channels / width2 / height2;
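    // Channels-last layout: the flat index decomposes as
    //   index == ((n * height2 + h2) * width2 + w2) * channels + c,
    // so consecutive threads handle consecutive channels of the same output
    // pixel, which keeps reads and writes coalesced in channels_last memory.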

    const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
        rheight, h2, align_corners, /*cubic=*/false);
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const accscalar_t h1lambda = h1r - h1;
    const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;

    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;

    const accscalar_t val = h0lambda * (
        w0lambda * idata[idx_cl(n, h1, w1, c, height1, width1, channels)] +
        w1lambda * idata[idx_cl(n, h1, w1 + w1p, c, height1, width1, channels)]
      ) + h1lambda * (
        w0lambda * idata[idx_cl(n, h1 + h1p, w1, c, height1, width1, channels)] +
        w1lambda * idata[idx_cl(n, h1 + h1p, w1 + w1p, c, height1, width1, channels)]
      );
    odata[idx_cl(n, h2, w2, c, height2, width2, channels)] = static_cast<scalar_t>(val);
  }
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bilinear2d_backward_out_frame(
    const size_t nc,
    const int height1,
    const int width1,
    const int height2,
    const int width2,
    const accscalar_t rheight,
    const accscalar_t rwidth,
    const bool align_corners,
    scalar_t* __restrict__ idata,
    const scalar_t* __restrict__ odata) {
  const size_t o_numel = nc * width2 * height2;
  const size_t i_numel = nc * width1 * height1;
  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
       index += blockDim.x * gridDim.x) {
    size_t index_temp = index;
    const int w2 = index_temp % width2; // 0:width2-1
    index_temp /= width2;
    const int h2 = index_temp % height2; // 0:height2-1
    const size_t nc = index_temp / height2;
    //
    const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
        rheight, h2, align_corners, /*cubic=*/false);
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const accscalar_t h1lambda = h1r - h1;
    const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
    //
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
    //
    const scalar_t d2val = odata[index];
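    // Adjoint of the forward interpolation: the gradient of this output
    // element is scattered to the four input pixels that contributed to it,
    // weighted by the same bilinear lambdas. Several output elements can map
    // to the same input pixel, hence the atomic adds below, which also make
    // this kernel nondeterministic.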
    fastAtomicAdd(
        idata,
        idx(nc, height1, width1, h1, w1),
        i_numel,
        static_cast<scalar_t>(h0lambda * w0lambda * d2val),
        true);
    fastAtomicAdd(
        idata,
        idx(nc, height1, width1, h1, w1 + w1p),
        i_numel,
        static_cast<scalar_t>(h0lambda * w1lambda * d2val),
        true);
    fastAtomicAdd(
        idata,
        idx(nc, height1, width1, h1 + h1p, w1),
        i_numel,
        static_cast<scalar_t>(h1lambda * w0lambda * d2val),
        true);
    fastAtomicAdd(
        idata,
        idx(nc, height1, width1, h1 + h1p, w1 + w1p),
        i_numel,
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
        true);
  }
}

template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_bilinear2d_backward_nhwc_out_frame(
    const int height1,
    const int width1,
    const int height2,
    const int width2,
    const accscalar_t rheight,
    const accscalar_t rwidth,
    const bool align_corners,
    scalar_t* __restrict__ idata,
    const scalar_t* __restrict__ odata,
    const int channels,
    const size_t o_numel,
    const size_t i_numel) {

  const int index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < o_numel) {
    const int c = index % channels;
    const int w2 = (index / channels) % width2;
    const int h2 = (index / channels / width2) % height2;
    const int n = index / channels / width2 / height2;

    const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
        rheight, h2, align_corners, /*cubic=*/false);
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const accscalar_t h1lambda = h1r - h1;
    const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;

    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;

    const scalar_t d2val = odata[index];
    fastAtomicAdd(
        idata,
        idx_cl(n, h1, w1, c, height1, width1, channels),
        i_numel,
        static_cast<scalar_t>(h0lambda * w0lambda * d2val),
        true);
    fastAtomicAdd(
        idata,
        idx_cl(n, h1, w1 + w1p, c, height1, width1, channels),
        i_numel,
        static_cast<scalar_t>(h0lambda * w1lambda * d2val),
        true);
    fastAtomicAdd(
        idata,
        idx_cl(n, h1 + h1p, w1, c, height1, width1, channels),
        i_numel,
        static_cast<scalar_t>(h1lambda * w0lambda * d2val),
        true);
    fastAtomicAdd(
        idata,
        idx_cl(n, h1 + h1p, w1 + w1p, c, height1, width1, channels),
        i_numel,
        static_cast<scalar_t>(h1lambda * w1lambda * d2val),
        true);
  }
}

static void upsample_bilinear2d_out_cuda_template(
    const Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  int output_height = output_size[0];
  int output_width = output_size[1];

  int channels = input.size(1);
  int input_height = input.size(2);
  int input_width = input.size(3);

  const auto memory_format = input.suggest_memory_format();

  if (input.sizes() == output.sizes()) {
    output.copy_(input);
    return;
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      input.scalar_type(), "upsample_bilinear2d_out_frame", [&] {
    // heuristic: only use channels_last path when it's faster than the contiguous path
    if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 16 && \
          output.is_contiguous(memory_format)) {
      using accscalar_t = at::acc_type<scalar_t, true>;

      TORCH_CHECK(input.numel() < std::numeric_limits<int>::max(),
        "upsample_bilinear2d_nhwc only supports input tensors with less than INT_MAX elements, but got ", input.sizes());
      TORCH_CHECK(output.numel() < std::numeric_limits<int>::max(),
        "upsample_bilinear2d_nhwc only supports output tensors with less than INT_MAX elements, but got ", output.sizes());

      const int channels = input.size(1);
      const int height1 = input.size(2);
      const int width1 = input.size(3);
      const int height2 = output.size(2);
      const int width2 = output.size(3);

      // const int num_kernels = output_height * output_width;
      const int num_kernels = output.numel();
      const int num_threads = std::min(
          at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);

      at::Tensor input_cl = input.contiguous(at::MemoryFormat::ChannelsLast);

      const scalar_t* idata = input_cl.const_data_ptr<scalar_t>();
      scalar_t* odata = output.mutable_data_ptr<scalar_t>();

      const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
          input_height, output_height, align_corners, scales_h);
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      upsample_bilinear2d_nhwc_out_frame<scalar_t, accscalar_t>
        <<<ceil_div(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          rheight, rwidth, align_corners,
          channels,
          height1,
          width1,
          height2,
          width2,
          idata, odata,
          output.numel());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      // non-channels_last case, not necessarily contiguous
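      // One thread per output spatial location (h2, w2); the kernel itself
      // loops over batch and channels, so num_kernels does not include them
      // (unlike the channels_last path above, which uses output.numel()).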
      const int num_kernels = output_height * output_width;
      const int num_threads = std::min(
          at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      using accscalar_t = at::acc_type<scalar_t, true>;

      auto idata = input.packed_accessor64<const scalar_t, 4>();
      auto odata = output.packed_accessor64<scalar_t, 4>();

      const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
          input_height, output_height, align_corners, scales_h);
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      upsample_bilinear2d_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, num_threads),
             num_threads,
             0,
             stream>>>(
              num_kernels, rheight, rwidth, align_corners, idata, odata);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}

static void upsample_bilinear2d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg});

  int output_height = output_size[0];
  int output_width = output_size[1];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_height = input_size[2];
  int input_width = input_size[3];

  if (grad_input.numel() == 0) {
    return;
  }

  const auto memory_format = grad_output_.suggest_memory_format();

  // Initialization to zero is required here: we launch one thread per output
  // element and atomicAdd into the input gradient, so with sparse sampling the
  // threads do not necessarily cover the whole input tensor.
  grad_input.zero_();

  const size_t num_kernels = nbatch * channels * output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (grad_output_.sizes() == grad_input.sizes()) {
    grad_input.copy_(grad_output_);
    return;
  }

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
    if (memory_format == at::MemoryFormat::ChannelsLast && channels >= 4 && \
          grad_input.is_contiguous(memory_format)) {
      using accscalar_t = at::acc_type<scalar_t, true>;

      Tensor grad_output = grad_output_.contiguous(at::MemoryFormat::ChannelsLast);

      auto idata = grad_input.mutable_data_ptr<scalar_t>();
      auto odata = grad_output.const_data_ptr<scalar_t>();

      const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
          input_height, output_height, align_corners, scales_h);
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
              input_height,
              input_width,
              output_height,
              output_width,
              rheight,
              rwidth,
              align_corners,
              idata,
              odata,
              channels,
              grad_output.numel(),
              grad_input.numel());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      using accscalar_t = at::acc_type<scalar_t, true>;

      // This is needed for non-contiguous tensors.
      Tensor grad_input_c = grad_input.is_contiguous() ? grad_input : at::zeros(grad_input.sizes(), grad_input.options());
      Tensor grad_output = grad_output_.contiguous();

      auto idata = grad_input_c.mutable_data_ptr<scalar_t>();
      auto odata = grad_output.const_data_ptr<scalar_t>();

      const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
          input_height, output_height, align_corners, scales_h);
      const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
          input_width, output_width, align_corners, scales_w);

      upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
          <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
             num_threads,
             0,
             stream>>>(
              nbatch * channels,
              input_height,
              input_width,
              output_height,
              output_width,
              rheight,
              rwidth,
              align_corners,
              idata,
              odata);
      C10_CUDA_KERNEL_LAUNCH_CHECK();

      if (!grad_input.is_contiguous()) {
          grad_input.copy_(grad_input_c);
      }
    }
  });
}

// Code for upsampling with antialias
template <typename scalar_t, typename accscalar_t, typename InterpFilter>
C10_LAUNCH_BOUNDS_1(256) // 256 performs better than 1024
__global__ void upsample_gen2d_aa_out_frame(
    const accscalar_t height_scale,
    const accscalar_t width_scale,
    const PackedTensorAccessor64<const scalar_t, 4> idata,
    PackedTensorAccessor64<scalar_t, 4> odata,
    const InterpFilter & interp_filter) {

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int input_height = idata.size(2);
  const int input_width = idata.size(3);
  const int output_height = odata.size(2);
  const int output_width = odata.size(3);

  const int output_x = threadIdx.x + blockIdx.x * blockDim.x;
  const int output_y = threadIdx.y + blockIdx.y * blockDim.y;

  if (output_x >= output_width || output_y >= output_height) {
    return;
  }

  const accscalar_t half = 0.5;
  const accscalar_t support_h = static_cast<accscalar_t>(
      (height_scale >= 1.0) ? (interp_filter.size * half) * height_scale : interp_filter.size * half);
  const accscalar_t support_w = static_cast<accscalar_t>(
      (width_scale >= 1.0) ? (interp_filter.size * half) * width_scale : interp_filter.size * half);

  const int interp_height = (int)ceilf(support_h) * 2 + 1;
  const int interp_width = (int)ceilf(support_w) * 2 + 1;

  // Setup weights and a buffer using shared memory
  extern __shared__ int smem[];
  scalar_t* wx = reinterpret_cast<scalar_t*>(smem) + interp_width * threadIdx.x;
  scalar_t* wy = reinterpret_cast<scalar_t*>(smem) + interp_width * blockDim.x + interp_height * threadIdx.y;
  const int offset = interp_width * blockDim.x + interp_height * blockDim.y;
  scalar_t *buffer2 = reinterpret_cast<scalar_t*>(smem) + offset + \
      interp_height * (threadIdx.x + threadIdx.y * blockDim.x);
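  // Per-block shared memory layout (in units of scalar_t):
  //   [ wx: interp_width * blockDim.x ][ wy: interp_height * blockDim.y ]
  //   [ buffer2: interp_height * blockDim.x * blockDim.y ]
  // Each threadIdx.x owns one wx slot, each threadIdx.y one wy slot, and each
  // thread one buffer2 column of interp_height values.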

  // Compute weights and kernel spans
  int xmin, xsize, ymin, ysize;
  accscalar_t xcenter, ycenter;
  upsample_antialias::_compute_weights_span(
      output_x, input_width, width_scale, support_w, xmin, xsize, xcenter);
  upsample_antialias::_compute_weights_span(
      output_y, input_height, height_scale, support_h, ymin, ysize, ycenter);

  if (threadIdx.y == 0)
  {
    // All threadIdx.y have the same wx weights
    upsample_antialias::_compute_weights<scalar_t, accscalar_t>(
        wx,
        width_scale,
        interp_width,
        interp_filter,
        xmin - xcenter,
        xsize);
  }

  if (threadIdx.x == 0)
  {
    // All threadIdx.x have the same wy weights
    upsample_antialias::_compute_weights<scalar_t, accscalar_t>(
        wy,
        height_scale,
        interp_height,
        interp_filter,
        ymin - ycenter,
        ysize);
  }

  __syncthreads();

  const scalar_t * buffer1;

  // Parallelized across batch/channels
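  // Separable anti-aliased resampling: for each (n, c), first filter the
  // ysize relevant input rows along x with weights wx into buffer2, then
  // filter buffer2 along y with weights wy to produce one output pixel.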
  for (int i = blockIdx.z; i < batchsize * channels; i += gridDim.z) {
    int n = i / channels;
    int c = i % channels;
    // interpolate on y-axis for ymin to ymin + ysize
    for (int y = 0; y < ysize; y++) {
      buffer1 = &(idata[n][c][ymin + y][xmin]);
      buffer2[y] = static_cast<scalar_t>(
          upsample_antialias::interpolate_aa_single_dim<scalar_t, accscalar_t>(
              buffer1, wx, xsize));
    }
    odata[n][c][output_y][output_x] = static_cast<scalar_t>(
        upsample_antialias::interpolate_aa_single_dim<scalar_t, accscalar_t>(
            buffer2, wy, ysize));
  }
}

// Code for upsampling with antialias
template <typename scalar_t, typename accscalar_t, typename InterpFilter>
C10_LAUNCH_BOUNDS_1(256) // 256 performs better than 1024
__global__ void upsample_gen2d_aa_backward_out_frame(
    const accscalar_t height_scale,
    const accscalar_t width_scale,
    PackedTensorAccessor64<scalar_t, 4> idata,
    const PackedTensorAccessor64<const scalar_t, 4> odata,
    const InterpFilter & interp_filter) {

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int input_height = idata.size(2);
  const int input_width = idata.size(3);
  const int output_height = odata.size(2);
  const int output_width = odata.size(3);

  const int output_x = threadIdx.x + blockIdx.x * blockDim.x;
  const int output_y = threadIdx.y + blockIdx.y * blockDim.y;

  if (output_x >= output_width || output_y >= output_height) {
    return;
  }

  // special case: output just copy
  if (input_height == output_height && input_width == output_width) {
    for (int i = blockIdx.z; i < batchsize * channels; i += gridDim.z) {
      int n = i / channels;
      int c = i % channels;
      const scalar_t val = odata[n][c][output_y][output_x];
      idata[n][c][output_y][output_x] = val;
    }
    return;
  }

  const accscalar_t support_h = static_cast<accscalar_t>(
      (height_scale >= 1.0) ? (interp_filter.size * 0.5) * height_scale
                            : interp_filter.size * 0.5);
  const accscalar_t support_w = static_cast<accscalar_t>(
      (width_scale >= 1.0) ? (interp_filter.size * 0.5) * width_scale
                           : interp_filter.size * 0.5);

  const int interp_height = (int)ceilf(support_h) * 2 + 1;
  const int interp_width = (int)ceilf(support_w) * 2 + 1;

  // Setup weights using shared memory
  extern __shared__ int smem[];
  scalar_t* wx = reinterpret_cast<scalar_t*>(smem) + interp_width * threadIdx.x;
  scalar_t* wy = reinterpret_cast<scalar_t*>(smem) + interp_width * blockDim.x + interp_height * threadIdx.y;
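  // Per-block shared memory layout (in units of scalar_t):
  //   [ wx: interp_width * blockDim.x ][ wy: interp_height * blockDim.y ]
  // Unlike the forward kernel, no per-thread row buffer is needed here.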

  // Compute weights and kernel spans
  int xmin, xsize, ymin, ysize;
  accscalar_t xcenter, ycenter;
  upsample_antialias::_compute_weights_span(
      output_x, input_width, width_scale, support_w, xmin, xsize, xcenter);
  upsample_antialias::_compute_weights_span(
      output_y, input_height, height_scale, support_h, ymin, ysize, ycenter);

  if (threadIdx.y == 0)
  {
    // All threadIdx.y have the same wx weights
    upsample_antialias::_compute_weights<scalar_t, accscalar_t>(
        wx,
        width_scale,
        interp_width,
        interp_filter,
        xmin - xcenter,
        xsize);
  }

  if (threadIdx.x == 0)
  {
    // All threadIdx.x have the same wy weights
    upsample_antialias::_compute_weights<scalar_t, accscalar_t>(
        wy,
        height_scale,
        interp_height,
        interp_filter,
        ymin - ycenter,
        ysize);
  }

  __syncthreads();

  // Parallelized across batch/channels
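  // Backward of the separable filter: the gradient of each output pixel is
  // scattered to the xsize * ysize input pixels in its support, weighted by
  // wx[x] * wy[y]. The accumulation uses atomic adds, which is why this
  // backward is flagged as nondeterministic.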
  for (int i = blockIdx.z; i < batchsize * channels; i += gridDim.z) {
    int n = i / channels;
    int c = i % channels;
    scalar_t out_value = odata[n][c][output_y][output_x];
    for (int y = 0; y < ysize; y++) {
      for (int x = 0; x < xsize; x++) {
        upsample_increment_value_bounded<scalar_t, accscalar_t>(
            idata,
            n,
            c,
            input_height,
            input_width,
            ymin + y,
            xmin + x,
            wx[x] * wy[y] * out_value);
      }
    }
  }
}

// In the code below the InterpFilter template parameter distinguishes between bilinear and bicubic interpolation
// InterpFilter as BilinearFilterFunctor <--> bilinear
// InterpFilter as BicubicFilterFunctor <--> bicubic
template<typename InterpFilter>
static void upsample_gen2d_aa_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU("upsample_gen2d_aa_out_cuda", {input_arg, output_arg});

  // TODO: remove this when the cuda kernel is updated to support the channels_last memory format.
  // This is a temporary hack to prevent a silent correctness issue when calling this kernel
  // with tensors in channels_last format.
  auto output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
  auto input = input_.contiguous();

  int output_height = output_size[0];
  int output_width = output_size[1];

  int input_height = input.size(2);
  int input_width = input.size(3);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  size_t sharedMemPerBlock = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
  int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
  int maxThreadsPerBlock = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 256);
  int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;
  int block_x = std::min<int>(maxThreadsDim[0], at::cuda::warp_size());
  int grid_x = std::min<int>(maxGridSize[0], ceil_div(output_width, block_x));

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      input.scalar_type(), "upsample_bilinear2d_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.packed_accessor64<const scalar_t, 4>();
        auto odata = output_c.packed_accessor64<scalar_t, 4>();

        const accscalar_t height_scale = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t width_scale = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        // We are using shared memory to store weights wx, wy and a buffer of size wy unique per thread
        // Let's compute block_y size depending on given height_scale and width_scale
        // We have the following relationship:
        // shmem_size / sizeofdtype =
        //  interp_width * block_x +   <-- wx allocation
        //  interp_height * block_y * (block_x + 1)   <-- wy and buffer allocations
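        // Illustrative sizing example (numbers assumed, not from the original):
        // with float (4 bytes), 48 KB of shared memory, block_x = 32 and
        // interp_width = interp_height = 5:
        //   numer = 49152 / 4 - 5 * 32 = 12128, denom = 5 * 33 = 165,
        //   lastPow2(12128 / 165) = lastPow2(73) = 64,
        // then block_y = min(maxThreadsPerBlock / block_x, 64) = min(8, 64) = 8.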

        auto interp_filter = InterpFilter();
        const int interp_height = 1 + 2 * (int)ceilf(
            (height_scale >= 1.0) ? interp_filter.size * 0.5 * height_scale : interp_filter.size * 0.5);
        const int interp_width = 1 + 2 * (int)ceilf(
            (width_scale >= 1.0) ? interp_filter.size * 0.5 * width_scale : interp_filter.size * 0.5);

        int numer = sharedMemPerBlock * 1.0 / sizeof(scalar_t) - interp_width * block_x;
        int denom = interp_height * (block_x + 1);
        int block_y = lastPow2((unsigned int) (numer / denom));
        block_y = std::min<int>(maxThreadsPerBlock / block_x, block_y);
        const dim3 block(block_x, block_y);

        int grid_y = std::min<int>(maxGridSize[1], ceil_div(output_height, block_y));
        int grid_z = std::min<int>(maxGridSize[2], input.size(0) * input.size(1));
        const dim3 grid(grid_x, grid_y, grid_z);

        // Compute actual size of required shared memory and verify if we can allocate it
        // - wx and wy size:
        size_t weights_per_block = interp_width * block_x + interp_height * block_y;
        // - buffer size:
        weights_per_block += interp_height * block_y * block_x;
        size_t shmem_size = weights_per_block * sizeof(scalar_t);
        TORCH_CHECK(
            shmem_size <= sharedMemPerBlock,
            "Provided interpolation parameters can not be handled with current algorithm implementation. ",
            "Please reduce the scale factor. Too much shared memory required: ",
            shmem_size, " vs ", sharedMemPerBlock);

        upsample_gen2d_aa_out_frame<scalar_t, accscalar_t>
            <<<grid,
               block,
               shmem_size,
               stream>>>(height_scale, width_scale, idata, odata, interp_filter);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });

  if (!output.is_contiguous()) {
      output.copy_(output_c);
  }
}

// In the code below the InterpFilter template parameter distinguishes between bilinear and bicubic interpolation
// InterpFilter as BilinearFilterFunctor <--> bilinear
// InterpFilter as BicubicFilterFunctor <--> bicubic
template<typename InterpFilter>
static void upsample_gen2d_aa_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {

  // Inspired from UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(
      "upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});

  int output_height = output_size[0];
  int output_width = output_size[1];

  int input_height = input_size[2];
  int input_width = input_size[3];

  Tensor grad_output = grad_output_.contiguous();

  grad_input.zero_();

  const int num_threads = std::min(at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 256);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int* maxThreadsDim = at::cuda::getCurrentDeviceProperties()->maxThreadsDim;
  int block_x = std::min<int>(maxThreadsDim[0], at::cuda::warp_size());
  int block_y = std::min<int>(maxThreadsDim[1], num_threads / block_x);
  const dim3 block(block_x, block_y);

  int* maxGridSize = at::cuda::getCurrentDeviceProperties()->maxGridSize;
  int grid_x = std::min<int>(maxGridSize[0], ceil_div(output_width, block_x));
  int grid_y = std::min<int>(maxGridSize[1], ceil_div(output_height, block_y));
  int grid_z = std::min<int>(maxGridSize[2], input_size[0] * input_size[1]);
  const dim3 grid(grid_x, grid_y, grid_z);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.packed_accessor64<scalar_t, 4>();
        auto odata = grad_output.packed_accessor64<const scalar_t, 4>();

        const accscalar_t height_scale = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t width_scale = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        auto interp_filter = InterpFilter();
        const int interp_height = 1 + 2 * (int)ceilf(
            (height_scale >= 1.0) ? interp_filter.size * 0.5 * height_scale : interp_filter.size * 0.5);
        const int interp_width = 1 + 2 * (int)ceilf(
            (width_scale >= 1.0) ? interp_filter.size * 0.5 * width_scale : interp_filter.size * 0.5);

        size_t weights_per_block = interp_width * block_x + interp_height * block_y;
        size_t shmem_size = weights_per_block * sizeof(scalar_t);
        size_t sharedMemPerBlock = at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
        TORCH_CHECK(
            shmem_size <= sharedMemPerBlock,
            "Provided interpolation parameters can not be handled with current algorithm implementation. ",
            "Please reduce the scale factor. Too much shared memory required: ",
            shmem_size, " vs ", sharedMemPerBlock);

        upsample_gen2d_aa_backward_out_frame<scalar_t, accscalar_t>
            <<<grid,
               block,
               shmem_size,
               stream>>>(height_scale, width_scale, idata, odata, interp_filter);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

} // namespace

TORCH_IMPL_FUNC(upsample_bilinear2d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_bilinear2d_out_cuda_template(output, input, output_size, align_corners, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("upsample_bilinear2d_backward_out_cuda");
  upsample_bilinear2d_backward_out_cuda_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {

  upsample_gen2d_aa_out_cuda_template<upsample_antialias::BilinearFilterFunctor>(
      output, input, output_size, align_corners, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("upsample_bilinear2d_aa_backward_out_cuda");
  upsample_gen2d_aa_backward_out_cuda_template<upsample_antialias::BilinearFilterFunctor>(
      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
}

// We define bicubic anti-alias function implementations in this file instead of
// UpSampleBicubic2d.cu as we are using a single generic implementation
TORCH_IMPL_FUNC(_upsample_bicubic2d_aa_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_gen2d_aa_out_cuda_template<upsample_antialias::BicubicFilterFunctor>(
      output, input, output_size, align_corners, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_bicubic2d_aa_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("upsample_bicubic2d_aa_backward_out_cuda");
  upsample_gen2d_aa_backward_out_cuda_template<upsample_antialias::BicubicFilterFunctor>(
      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w);
}

} // namespace at::native