xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/UpSampleNearest3d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorMeta.h>
4 #include <ATen/native/UpSample.h>
5 #include <c10/util/irange.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/_upsample_nearest_exact3d.h>
12 #include <ATen/ops/_upsample_nearest_exact3d_backward.h>
13 #include <ATen/ops/_upsample_nearest_exact3d_backward_native.h>
14 #include <ATen/ops/_upsample_nearest_exact3d_native.h>
15 #include <ATen/ops/upsample_nearest3d.h>
16 #include <ATen/ops/upsample_nearest3d_backward.h>
17 #include <ATen/ops/upsample_nearest3d_backward_native.h>
18 #include <ATen/ops/upsample_nearest3d_native.h>
19 #endif
20 
21 namespace at::meta {
22 
TORCH_META_FUNC(upsample_nearest3d)23 TORCH_META_FUNC(upsample_nearest3d) (
24     const Tensor& input,
25     IntArrayRef output_size,
26     std::optional<double> scales_d,
27     std::optional<double> scales_h,
28     std::optional<double> scales_w
29 ) {
30   auto full_output_size = native::upsample_3d_common_check(input.sizes(), output_size);
31 
32   // Allow for empty batch size but not other dimensions
33   TORCH_CHECK(
34       input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
35       "Non-empty 5D data tensor expected but got a tensor with sizes ",
36       input.sizes());
37 
38   set_output_raw_strided(0, full_output_size, {}, input.options().memory_format(input.suggest_memory_format()));
39 }
40 
TORCH_META_FUNC(_upsample_nearest_exact3d)41 TORCH_META_FUNC(_upsample_nearest_exact3d) (
42   const Tensor& input,
43   IntArrayRef output_size,
44   std::optional<double> scales_d,
45   std::optional<double> scales_h,
46   std::optional<double> scales_w
47 ) {
48   auto full_output_size = native::upsample_3d_common_check(input.sizes(), output_size);
49 
50   // Allow for empty batch size but not other dimensions
51   TORCH_CHECK(
52       input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
53       "Non-empty 5D data tensor expected but got a tensor with sizes ",
54       input.sizes());
55 
56   set_output_raw_strided(0, full_output_size, {}, input.options().memory_format(input.suggest_memory_format()));
57 }
58 
TORCH_META_FUNC(upsample_nearest3d_backward)59 TORCH_META_FUNC(upsample_nearest3d_backward) (
60     const Tensor& grad_output,
61     IntArrayRef output_size,
62     IntArrayRef input_size,
63     std::optional<double> scales_d,
64     std::optional<double> scales_h,
65     std::optional<double> scales_w
66 ) {
67   auto full_output_size = native::upsample_3d_common_check(input_size, output_size);
68 
69   TORCH_CHECK(
70       grad_output.dim() == 5,
71       "Expected grad_output to be a tensor of dimension 5 but got: dimension ", grad_output.dim());
72 
73   for (const auto i : c10::irange(5)) {
74     TORCH_CHECK(
75         grad_output.size(i) == full_output_size[i],
76         "Expected grad_output to have the same shape as output;",
77         " output.size(", i, ") = ", full_output_size[i],
78         " but got grad_output.size(", i, ") = ", grad_output.size(i));
79   }
80 
81   set_output_raw_strided(0, input_size, {}, grad_output.options());
82 }
83 
TORCH_META_FUNC(_upsample_nearest_exact3d_backward)84 TORCH_META_FUNC(_upsample_nearest_exact3d_backward) (
85   const Tensor& grad_output,
86   IntArrayRef output_size,
87   IntArrayRef input_size,
88   std::optional<double> scales_d,
89   std::optional<double> scales_h,
90   std::optional<double> scales_w
91 ) {
92   auto full_output_size = native::upsample_3d_common_check(input_size, output_size);
93 
94   TORCH_CHECK(
95       grad_output.dim() == 5,
96       "Expected grad_output to be a tensor of dimension 5 but got: dimension ", grad_output.dim());
97 
98   for (const auto i : c10::irange(5)) {
99     TORCH_CHECK(
100         grad_output.size(i) == full_output_size[i],
101         "Expected grad_output to have the same shape as output;",
102         " output.size(", i, ") = ", full_output_size[i],
103         " but got grad_output.size(", i, ") = ", grad_output.size(i));
104   }
105 
106   set_output_raw_strided(0, input_size, {}, grad_output.options());
107 }
108 
109 } // namespace at::meta
110 
111 namespace at::native {
112 
TORCH_IMPL_FUNC(upsample_nearest3d_out_cpu)113 TORCH_IMPL_FUNC(upsample_nearest3d_out_cpu) (
114     const Tensor& input,
115     IntArrayRef output_size,
116     std::optional<double> scales_d,
117     std::optional<double> scales_h,
118     std::optional<double> scales_w,
119     const Tensor& output
120 ) {
121   upsample_nearest3d_kernel(kCPU, output, input, scales_d, scales_h, scales_w);
122 }
123 
TORCH_IMPL_FUNC(_upsample_nearest_exact3d_out_cpu)124 TORCH_IMPL_FUNC(_upsample_nearest_exact3d_out_cpu) (
125     const Tensor& input,
126     IntArrayRef output_size,
127     std::optional<double> scales_d,
128     std::optional<double> scales_h,
129     std::optional<double> scales_w,
130     const Tensor& output
131 ) {
132   _upsample_nearest_exact3d_kernel(kCPU, output, input, scales_d, scales_h, scales_w);
133 }
134 
TORCH_IMPL_FUNC(upsample_nearest3d_backward_out_cpu)135 TORCH_IMPL_FUNC(upsample_nearest3d_backward_out_cpu) (
136     const Tensor& grad_output,
137     IntArrayRef output_size,
138     IntArrayRef input_size,
139     std::optional<double> scales_d,
140     std::optional<double> scales_h,
141     std::optional<double> scales_w,
142     const Tensor& grad_input) {
143   grad_input.zero_();
144   upsample_nearest3d_backward_kernel(kCPU, grad_input, grad_output, scales_d, scales_h, scales_w);
145 }
146 
TORCH_IMPL_FUNC(_upsample_nearest_exact3d_backward_out_cpu)147 TORCH_IMPL_FUNC(_upsample_nearest_exact3d_backward_out_cpu) (
148     const Tensor& grad_output,
149     IntArrayRef output_size,
150     IntArrayRef input_size,
151     std::optional<double> scales_d,
152     std::optional<double> scales_h,
153     std::optional<double> scales_w,
154     const Tensor& grad_input) {
155   grad_input.zero_();
156   _upsample_nearest_exact3d_backward_kernel(kCPU, grad_input, grad_output, scales_d, scales_h, scales_w);
157 }
158 
159 // vec variants
160 
161 using at::native::upsample::compute_output_size;
162 using at::native::upsample::get_scale_value;
163 
upsample_nearest3d(const Tensor & input,at::OptionalIntArrayRef output_size,std::optional<ArrayRef<double>> scale_factors)164 Tensor upsample_nearest3d(
165     const Tensor& input,
166     at::OptionalIntArrayRef output_size,
167     std::optional<ArrayRef<double>> scale_factors) {
168   auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
169   auto scale_d = get_scale_value(scale_factors, 0);
170   auto scale_h = get_scale_value(scale_factors, 1);
171   auto scale_w = get_scale_value(scale_factors, 2);
172   return at::upsample_nearest3d(input, osize, scale_d, scale_h, scale_w);
173 }
174 
_upsample_nearest_exact3d(const Tensor & input,at::OptionalIntArrayRef output_size,std::optional<ArrayRef<double>> scale_factors)175 Tensor _upsample_nearest_exact3d(
176     const Tensor& input,
177     at::OptionalIntArrayRef output_size,
178     std::optional<ArrayRef<double>> scale_factors) {
179   auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
180   auto scale_d = get_scale_value(scale_factors, 0);
181   auto scale_h = get_scale_value(scale_factors, 1);
182   auto scale_w = get_scale_value(scale_factors, 2);
183   return at::_upsample_nearest_exact3d(input, osize, scale_d, scale_h, scale_w);
184 }
185 
186 DEFINE_DISPATCH(upsample_nearest3d_kernel);
187 DEFINE_DISPATCH(_upsample_nearest_exact3d_kernel);
188 DEFINE_DISPATCH(upsample_nearest3d_backward_kernel);
189 DEFINE_DISPATCH(_upsample_nearest_exact3d_backward_kernel);
190 
191 } // namespace at::native
192