1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorMeta.h>
4 #include <ATen/native/UpSample.h>
5 #include <c10/util/accumulate.h>
6 #include <c10/util/irange.h>
7
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/_upsample_nearest_exact2d.h>
13 #include <ATen/ops/_upsample_nearest_exact2d_backward.h>
14 #include <ATen/ops/_upsample_nearest_exact2d_backward_native.h>
15 #include <ATen/ops/_upsample_nearest_exact2d_native.h>
16 #include <ATen/ops/upsample_nearest2d.h>
17 #include <ATen/ops/upsample_nearest2d_backward.h>
18 #include <ATen/ops/upsample_nearest2d_backward_native.h>
19 #include <ATen/ops/upsample_nearest2d_native.h>
20 #endif
21
22 namespace at::meta {
23
// Meta function for upsample_nearest2d.
//
// Runs native::upsample_2d_common_check on the input sizes and the requested
// output_size, rejects inputs that are empty in any dimension other than the
// first, and declares output 0 with the size returned by that check, using
// the input's options combined with its suggested memory format.
//
// scales_h and scales_w are part of the signature but are not read here.
TORCH_META_FUNC(upsample_nearest2d) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w
) {
  const auto in_sizes = input.sizes();
  const auto out_sizes = native::upsample_2d_common_check(in_sizes, output_size);

  // A zero-element input is accepted only when the product of all sizes after
  // the first is non-zero, i.e. only the leading (batch) dimension may be 0.
  const bool has_elements = input.numel() != 0;
  TORCH_CHECK(
      has_elements ||
          c10::multiply_integers(in_sizes.begin() + 1, in_sizes.end()) != 0,
      "Non-empty 4D data tensor expected but got a tensor with sizes ",
      in_sizes);

  const auto mem_format = input.suggest_memory_format();
  const auto out_options = input.options().memory_format(mem_format);
  set_output_raw_strided(0, out_sizes, {}, out_options);
}
37
// Meta function for _upsample_nearest_exact2d.
//
// Performs the same validation as the non-exact variant visible in this file:
// native::upsample_2d_common_check yields the full output size, a zero-element
// input is tolerated only when its first dimension is the empty one, and
// output 0 is declared with that size and the input's suggested memory format.
//
// scales_h and scales_w are accepted but not consulted in this function.
TORCH_META_FUNC(_upsample_nearest_exact2d) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w
) {
  const auto input_dims = input.sizes();
  const auto result_dims =
      native::upsample_2d_common_check(input_dims, output_size);

  // Only the leading dimension is allowed to be zero: when the tensor has no
  // elements, the product of the remaining sizes must still be non-zero.
  const bool non_empty = input.numel() != 0;
  TORCH_CHECK(
      non_empty ||
          c10::multiply_integers(input_dims.begin() + 1, input_dims.end()) != 0,
      "Non-empty 4D data tensor expected but got a tensor with sizes ",
      input_dims);

  const auto layout = input.suggest_memory_format();
  set_output_raw_strided(
      0, result_dims, {}, input.options().memory_format(layout));
}
51
// Meta function for upsample_nearest2d_backward.
//
// Derives the expected forward-output shape from input_size and output_size
// via native::upsample_2d_common_check, requires grad_output to be 4-D and to
// match that shape in every dimension, then declares output 0 (the input
// gradient) with shape input_size and grad_output's options combined with its
// suggested memory format.
//
// scales_h and scales_w are not read here.
TORCH_META_FUNC(upsample_nearest2d_backward) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w
) {
  const auto expected_sizes =
      native::upsample_2d_common_check(input_size, output_size);

  const auto grad_rank = grad_output.dim();
  TORCH_CHECK(
      grad_rank == 4,
      "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_rank);

  // Compare grad_output against the expected output shape one axis at a time
  // so the error message can name the first offending dimension.
  for (const auto d : c10::irange(4)) {
    const auto actual = grad_output.size(d);
    const auto expected = expected_sizes[d];
    TORCH_CHECK(
        actual == expected,
        "Expected grad_output to have the same shape as output;",
        " output.size(", d, ") = ", expected,
        " but got grad_output.size(", d, ") = ", actual);
  }

  const auto mem_format = grad_output.suggest_memory_format();
  const auto grad_input_options = grad_output.options().memory_format(mem_format);
  set_output_raw_strided(0, input_size, {}, grad_input_options);
}
75
// Meta function for _upsample_nearest_exact2d_backward.
//
// Validates grad_output against the shape that
// native::upsample_2d_common_check computes from input_size and output_size
// (4 dimensions, each equal to the expected size) and declares output 0 with
// shape input_size, inheriting grad_output's options and suggested memory
// format.
//
// scales_h and scales_w are accepted but unused in this function.
TORCH_META_FUNC(_upsample_nearest_exact2d_backward) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w
) {
  const auto forward_out_sizes =
      native::upsample_2d_common_check(input_size, output_size);

  const auto ndim = grad_output.dim();
  TORCH_CHECK(
      ndim == 4,
      "Expected grad_output to be a tensor of dimension 4 but got: dimension ", ndim);

  // Per-axis shape check; fails on the first axis whose size differs.
  for (const auto axis : c10::irange(4)) {
    const auto got = grad_output.size(axis);
    const auto want = forward_out_sizes[axis];
    TORCH_CHECK(
        got == want,
        "Expected grad_output to have the same shape as output;",
        " output.size(", axis, ") = ", want,
        " but got grad_output.size(", axis, ") = ", got);
  }

  const auto layout = grad_output.suggest_memory_format();
  set_output_raw_strided(
      0, input_size, {}, grad_output.options().memory_format(layout));
}
99
100 } // namespace at::meta
101
102 namespace at::native {
103
// CPU implementation body for upsample_nearest2d.
//
// Forwards `output`, `input` and both optional scales to
// upsample_nearest2d_kernel, selecting the kCPU entry. `output_size` is part
// of the signature but is not passed on.
TORCH_IMPL_FUNC(upsample_nearest2d_out_cpu) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output
) {
  upsample_nearest2d_kernel(
      kCPU,
      output,
      input,
      scales_h,
      scales_w);
}
113
// CPU implementation body for _upsample_nearest_exact2d.
//
// Hands `output`, `input` and the optional height/width scales to
// _upsample_nearest_exact2d_kernel using the kCPU entry. `output_size` is not
// forwarded to the kernel.
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_out_cpu) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output
) {
  _upsample_nearest_exact2d_kernel(
      kCPU,
      output,
      input,
      scales_h,
      scales_w);
}
123
// CPU implementation body for upsample_nearest2d_backward.
//
// Zeroes `grad_input` in place and then invokes
// upsample_nearest2d_backward_kernel (kCPU entry) with `grad_input`,
// `grad_output` and the optional scales. `output_size` and `input_size` are
// not passed to the kernel.
TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cpu) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input
) {
  // Clear the destination before the kernel runs.
  // NOTE(review): presumably the kernel accumulates into grad_input -- confirm.
  grad_input.zero_();

  upsample_nearest2d_backward_kernel(
      kCPU,
      grad_input,
      grad_output,
      scales_h,
      scales_w);
}
134
// CPU implementation body for _upsample_nearest_exact2d_backward.
//
// Resets `grad_input` to zero in place, then calls
// _upsample_nearest_exact2d_backward_kernel (kCPU entry) with `grad_input`,
// `grad_output` and the optional scales. Neither `output_size` nor
// `input_size` is forwarded.
TORCH_IMPL_FUNC(_upsample_nearest_exact2d_backward_out_cpu) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input
) {
  // Start from an all-zero gradient buffer.
  // NOTE(review): presumably required because the kernel adds into
  // grad_input rather than overwriting it -- confirm.
  grad_input.zero_();

  _upsample_nearest_exact2d_backward_kernel(
      kCPU,
      grad_input,
      grad_output,
      scales_h,
      scales_w);
}
145
146 using at::native::upsample::compute_output_size;
147 using at::native::upsample::get_scale_value;
148
// Overload of upsample_nearest2d that takes an optional explicit output size
// and an optional list of scale factors.
//
// Resolves the concrete output size with compute_output_size, reads the
// scale at index 0 (used as the height scale) and index 1 (used as the width
// scale) with get_scale_value, and forwards everything to the
// (input, size, scale_h, scale_w) overload of at::upsample_nearest2d.
Tensor upsample_nearest2d(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  const auto resolved_size =
      compute_output_size(input.sizes(), output_size, scale_factors);
  const auto height_scale = get_scale_value(scale_factors, 0);
  const auto width_scale = get_scale_value(scale_factors, 1);
  return at::upsample_nearest2d(
      input, resolved_size, height_scale, width_scale);
}
158
// Overload of _upsample_nearest_exact2d that takes an optional explicit
// output size and an optional list of scale factors.
//
// The output size comes from compute_output_size; the scales at index 0 and
// index 1 (obtained through get_scale_value) are passed on as the height and
// width scales of the (input, size, scale_h, scale_w) overload of
// at::_upsample_nearest_exact2d.
Tensor _upsample_nearest_exact2d(
    const Tensor& input,
    at::OptionalIntArrayRef output_size,
    std::optional<ArrayRef<double>> scale_factors) {
  const auto target_size =
      compute_output_size(input.sizes(), output_size, scale_factors);
  const auto h_scale = get_scale_value(scale_factors, 0);
  const auto w_scale = get_scale_value(scale_factors, 1);
  return at::_upsample_nearest_exact2d(input, target_size, h_scale, w_scale);
}
168
169 DEFINE_DISPATCH(upsample_nearest2d_kernel);
170 DEFINE_DISPATCH(_upsample_nearest_exact2d_kernel);
171 DEFINE_DISPATCH(upsample_nearest2d_backward_kernel);
172 DEFINE_DISPATCH(_upsample_nearest_exact2d_backward_kernel);
173
174 } // namespace at::native
175