1 // Adapted from interp.cpp from Caffe util by Pauline Luc
2 // Originally developed by George Papandreou
3 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
4
5 #include <ATen/core/Tensor.h>
6 #include <ATen/TensorMeta.h>
7 #include <ATen/native/UpSample.h>
8 #include <c10/util/irange.h>
9
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/_upsample_bilinear2d_aa.h>
15 #include <ATen/ops/_upsample_bilinear2d_aa_backward.h>
16 #include <ATen/ops/_upsample_bilinear2d_aa_backward_native.h>
17 #include <ATen/ops/_upsample_bilinear2d_aa_native.h>
18 #include <ATen/ops/upsample_bilinear2d.h>
19 #include <ATen/ops/upsample_bilinear2d_backward.h>
20 #include <ATen/ops/upsample_bilinear2d_backward_native.h>
21 #include <ATen/ops/upsample_bilinear2d_native.h>
22 #endif
23
24 namespace at::meta {
25
TORCH_META_FUNC(upsample_bilinear2d)26 TORCH_META_FUNC(upsample_bilinear2d) (
27 const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional<double> scales_h, std::optional<double> scales_w
28 ) {
29 auto full_output_size = native::upsample_2d_common_check(input.sizes(), output_size);
30
31 // Allow for empty batch size but not other dimensions
32 TORCH_CHECK(
33 input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
34 "Non-empty 4D data tensor expected but got a tensor with sizes ",
35 input.sizes());
36
37 set_output_raw_strided(0, full_output_size, {}, input.options().memory_format(input.suggest_memory_format()));
38 }
39
TORCH_META_FUNC(upsample_bilinear2d_backward)40 TORCH_META_FUNC(upsample_bilinear2d_backward) (
41 const Tensor& grad_output,
42 IntArrayRef output_size,
43 IntArrayRef input_size,
44 bool align_corners,
45 std::optional<double> scales_h,
46 std::optional<double> scales_w
47 ) {
48 auto full_output_size = native::upsample_2d_common_check(input_size, output_size);
49
50 TORCH_CHECK(
51 grad_output.dim() == 4,
52 "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim());
53
54 for (const auto i : c10::irange(4)) {
55 TORCH_CHECK(
56 grad_output.size(i) == full_output_size[i],
57 "Expected grad_output to have the same shape as output;",
58 " output.size(", i, ") = ", full_output_size[i],
59 " but got grad_output.size(", i, ") = ", grad_output.size(i));
60 }
61
62 set_output_raw_strided(0, input_size, {}, grad_output.options().memory_format(grad_output.suggest_memory_format()));
63 }
64
TORCH_META_FUNC(_upsample_bilinear2d_aa)65 TORCH_META_FUNC(_upsample_bilinear2d_aa) (
66 const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional<double> scales_h, std::optional<double> scales_w
67 ) {
68 auto full_output_size = native::upsample_2d_common_check(input.sizes(), output_size);
69
70 // Allow for empty batch size but not other dimensions
71 TORCH_CHECK(
72 input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
73 "Non-empty 4D data tensor expected but got a tensor with sizes ",
74 input.sizes());
75
76 set_output_raw_strided(0, full_output_size, {}, input.options().memory_format(input.suggest_memory_format()));
77 }
78
TORCH_META_FUNC(_upsample_bilinear2d_aa_backward)79 TORCH_META_FUNC(_upsample_bilinear2d_aa_backward) (
80 const Tensor& grad_output,
81 IntArrayRef output_size,
82 IntArrayRef input_size,
83 bool align_corners,
84 std::optional<double> scales_h,
85 std::optional<double> scales_w
86 ) {
87 auto full_output_size = native::upsample_2d_common_check(input_size, output_size);
88
89 TORCH_CHECK(
90 grad_output.dim() == 4,
91 "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim());
92
93 for (const auto i : c10::irange(4)) {
94 TORCH_CHECK(
95 grad_output.size(i) == full_output_size[i],
96 "Expected grad_output to have the same shape as output;",
97 " output.size(", i, ") = ", full_output_size[i],
98 " but got grad_output.size(", i, ") = ", grad_output.size(i));
99 }
100
101 set_output_raw_strided(0, input_size, {}, grad_output.options().memory_format(grad_output.suggest_memory_format()));
102 }
103
104 } // namespace at::meta
105
106 namespace at::native {
107
TORCH_IMPL_FUNC(upsample_bilinear2d_out_cpu)108 TORCH_IMPL_FUNC(upsample_bilinear2d_out_cpu) (
109 const Tensor& input,
110 IntArrayRef output_size,
111 bool align_corners,
112 std::optional<double> scales_h,
113 std::optional<double> scales_w,
114 const Tensor& output
115 ) {
116 upsample_bilinear2d_kernel(kCPU, output, input, align_corners, scales_h, scales_w);
117 }
118
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_cpu)119 TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_cpu) (
120 const Tensor& grad_output,
121 IntArrayRef output_size,
122 IntArrayRef input_size,
123 bool align_corners,
124 std::optional<double> scales_h,
125 std::optional<double> scales_w,
126 const Tensor& grad_input
127 ) {
128 grad_input.zero_();
129 upsample_bilinear2d_backward_kernel(kCPU, grad_input, grad_output, align_corners, scales_h, scales_w);
130 }
131
132
TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_cpu)133 TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_cpu) (
134 const Tensor& input,
135 IntArrayRef output_size,
136 bool align_corners,
137 std::optional<double> scales_h,
138 std::optional<double> scales_w,
139 const Tensor& output
140 ) {
141 _upsample_bilinear2d_aa_kernel(kCPU, output, input, align_corners, scales_h, scales_w);
142 }
143
TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_cpu)144 TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_cpu) (
145 const Tensor& grad_output,
146 IntArrayRef output_size,
147 IntArrayRef input_size,
148 bool align_corners,
149 std::optional<double> scales_h,
150 std::optional<double> scales_w,
151 const Tensor& grad_input
152 ) {
153 grad_input.zero_();
154 _upsample_bilinear2d_aa_backward_kernel(kCPU, grad_input, grad_output, align_corners, scales_h, scales_w);
155 }
156
157 using at::native::upsample::compute_output_size;
158 using at::native::upsample::get_scale_value;
159
upsample_bilinear2d(const Tensor & input,at::OptionalIntArrayRef output_size,bool align_corners,std::optional<ArrayRef<double>> scale_factors)160 Tensor upsample_bilinear2d(
161 const Tensor& input,
162 at::OptionalIntArrayRef output_size,
163 bool align_corners,
164 std::optional<ArrayRef<double>> scale_factors) {
165 auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
166 auto scale_h = get_scale_value(scale_factors, 0);
167 auto scale_w = get_scale_value(scale_factors, 1);
168 return at::upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w);
169 }
170
_upsample_bilinear2d_aa(const Tensor & input,at::OptionalIntArrayRef output_size,bool align_corners,std::optional<ArrayRef<double>> scale_factors)171 Tensor _upsample_bilinear2d_aa(
172 const Tensor& input,
173 at::OptionalIntArrayRef output_size,
174 bool align_corners,
175 std::optional<ArrayRef<double>> scale_factors) {
176 auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
177 auto scale_h = get_scale_value(scale_factors, 0);
178 auto scale_w = get_scale_value(scale_factors, 1);
179 return at::_upsample_bilinear2d_aa(input, osize, align_corners, scale_h, scale_w);
180 }
181
182 DEFINE_DISPATCH(upsample_bilinear2d_kernel);
183 DEFINE_DISPATCH(upsample_bilinear2d_backward_kernel);
184 DEFINE_DISPATCH(_upsample_bilinear2d_aa_kernel);
185 DEFINE_DISPATCH(_upsample_bilinear2d_aa_backward_kernel);
186
187 } // namespace at::native
188