#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/UpSample.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_upsample_nearest_exact2d.h>
#include <ATen/ops/_upsample_nearest_exact2d_backward.h>
#include <ATen/ops/_upsample_nearest_exact2d_backward_native.h>
#include <ATen/ops/_upsample_nearest_exact2d_native.h>
#include <ATen/ops/upsample_nearest2d.h>
#include <ATen/ops/upsample_nearest2d_backward.h>
#include <ATen/ops/upsample_nearest2d_backward_native.h>
#include <ATen/ops/upsample_nearest2d_native.h>
#endif
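
// When AT_PER_OPERATOR_HEADERS is defined, only the per-operator headers for the
// ops this file actually uses are included, which keeps rebuild dependencies
// small; otherwise the monolithic ATen/Functions.h and ATen/NativeFunctions.h
// are pulled in instead.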

namespace at::meta {

TORCH_META_FUNC(upsample_nearest2d) (
    const Tensor& input, IntArrayRef output_size, std::optional<double> scales_h, std::optional<double> scales_w
) {
  auto full_output_size = native::upsample_2d_common_check(input.sizes(), output_size);

  // Allow for empty batch size but not other dimensions
  TORCH_CHECK(
      input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
      "Non-empty 4D data tensor expected but got a tensor with sizes ",
      input.sizes());

  set_output_raw_strided(0, full_output_size, {}, input.options().memory_format(input.suggest_memory_format()));
}

TORCH_META_FUNC(_upsample_nearest_exact2d) (
  const Tensor& input, IntArrayRef output_size, std::optional<double> scales_h, std::optional<double> scales_w
) {
  auto full_output_size = native::upsample_2d_common_check(input.sizes(), output_size);

  // Allow for empty batch size but not other dimensions
  TORCH_CHECK(
      input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
      "Non-empty 4D data tensor expected but got a tensor with sizes ",
      input.sizes());

  set_output_raw_strided(0, full_output_size, {}, input.options().memory_format(input.suggest_memory_format()));
}
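
// Both forward meta functions only validate the input shape and allocate the
// output with the suggested memory format; the index math lives in the kernels.
// The "-exact" variant is intended to match the nearest-neighbor convention of
// Pillow/OpenCV/scikit-image (roughly src = floor((dst + 0.5) * scale)), while
// upsample_nearest2d keeps the legacy src = floor(dst * scale) mapping for
// backward compatibility.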

TORCH_META_FUNC(upsample_nearest2d_backward) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w
) {
  auto full_output_size = native::upsample_2d_common_check(input_size, output_size);

  TORCH_CHECK(
      grad_output.dim() == 4,
      "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim());

  for (const auto i : c10::irange(4)) {
    TORCH_CHECK(
        grad_output.size(i) == full_output_size[i],
        "Expected grad_output to have the same shape as output;",
        " output.size(", i, ") = ", full_output_size[i],
        " but got grad_output.size(", i, ") = ", grad_output.size(i));
  }

  set_output_raw_strided(0, input_size, {}, grad_output.options().memory_format(grad_output.suggest_memory_format()));
}

TORCH_META_FUNC(_upsample_nearest_exact2d_backward) (
  const Tensor& grad_output,
  IntArrayRef output_size,
  IntArrayRef input_size,
  std::optional<double> scales_h,
  std::optional<double> scales_w
) {
  auto full_output_size = native::upsample_2d_common_check(input_size, output_size);

  TORCH_CHECK(
      grad_output.dim() == 4,
      "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim());

  for (const auto i : c10::irange(4)) {
    TORCH_CHECK(
        grad_output.size(i) == full_output_size[i],
        "Expected grad_output to have the same shape as output;",
        " output.size(", i, ") = ", full_output_size[i],
        " but got grad_output.size(", i, ") = ", grad_output.size(i));
  }

  set_output_raw_strided(0, input_size, {}, grad_output.options().memory_format(grad_output.suggest_memory_format()));
}
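
// The backward meta functions re-derive the forward output shape from
// (input_size, output_size) and check grad_output against it dimension by
// dimension, so a mismatched gradient is rejected before any kernel runs;
// grad_input is then allocated with the shape of the forward input.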

} // namespace at::meta

namespace at::native {

TORCH_IMPL_FUNC(upsample_nearest2d_out_cpu) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output
) {
  upsample_nearest2d_kernel(kCPU, output, input, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_cpu) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output
) {
  _upsample_nearest_exact2d_kernel(kCPU, output, input, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cpu) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  grad_input.zero_();
  upsample_nearest2d_backward_kernel(kCPU, grad_input, grad_output, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_cpu) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  grad_input.zero_();
  _upsample_nearest_exact2d_backward_kernel(kCPU, grad_input, grad_output, scales_h, scales_w);
}
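
// The CPU impl functions above are thin wrappers: the structured-kernel
// machinery has already allocated output / grad_input via the meta functions,
// so the forward paths just invoke the dispatch stubs and the backward paths
// zero the gradient buffer first so the kernels can accumulate into it.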

using at::native::upsample::compute_output_size;
using at::native::upsample::get_scale_value;

Tensor upsample_nearest2d(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
  auto scale_h = get_scale_value(scale_factors, 0);
  auto scale_w = get_scale_value(scale_factors, 1);
  return at::upsample_nearest2d(input, osize, scale_h, scale_w);
}

Tensor _upsample_nearest_exact2d(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
  auto scale_h = get_scale_value(scale_factors, 0);
  auto scale_w = get_scale_value(scale_factors, 1);
  return at::_upsample_nearest_exact2d(input, osize, scale_h, scale_w);
}
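
// These overloads accept either an explicit output_size or per-dimension
// scale_factors (as used by torch.nn.functional.interpolate) and resolve them
// to a concrete output size before forwarding to the primary operator. A rough
// usage sketch from C++ (illustrative only, not part of this file):
//
//   at::Tensor img = at::rand({1, 3, 8, 8});
//   // Explicit output size, no per-axis scale overrides:
//   at::Tensor up = at::upsample_nearest2d(
//       img, /*output_size=*/{16, 16},
//       /*scales_h=*/std::nullopt, /*scales_w=*/std::nullopt);
//   // The scale_factors overload would instead take std::nullopt for
//   // output_size and an ArrayRef<double> of {2.0, 2.0}.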

DEFINE_DISPATCH(upsample_nearest2d_kernel);
DEFINE_DISPATCH(_upsample_nearest_exact2d_kernel);
DEFINE_DISPATCH(upsample_nearest2d_backward_kernel);
DEFINE_DISPATCH(_upsample_nearest_exact2d_backward_kernel);
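
// DEFINE_DISPATCH provides the storage for each DispatchStub declared in
// UpSample.h; the device-specific implementations are expected to be registered
// separately (e.g. via REGISTER_DISPATCH in the CPU kernel translation units).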

} // namespace at::native