1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/native/UpSample.h>
6 #include <ATen/native/cpu/utils.h>
7
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/_empty_affine_quantized.h>
13 #include <ATen/ops/_upsample_nearest_exact2d_native.h>
14 #include <ATen/ops/upsample_nearest2d_native.h>
15 #endif
16
17 #include <c10/util/irange.h>
18
#include <algorithm>
#include <cstring>
#include <memory>
20
21
22 namespace at {
23 namespace native {
24
25 // Define a typedef to dispatch to nearest_idx or nearest_exact_idx
26 typedef int64_t (*nn_compute_source_index_fn_t)(const float, int64_t, int64_t);
27
28 // at::native functions for the native_functions.yaml
29 template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
upsample_nearest2d_out_frame(scalar_t * odata,scalar_t * idata,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width,int64_t nbatch,int64_t channels,std::optional<double> scales_h,std::optional<double> scales_w)30 static void upsample_nearest2d_out_frame(
31 scalar_t* odata,
32 scalar_t* idata,
33 int64_t input_height,
34 int64_t input_width,
35 int64_t output_height,
36 int64_t output_width,
37 int64_t nbatch,
38 int64_t channels,
39 std::optional<double> scales_h,
40 std::optional<double> scales_w) {
41 float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
42 float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
43
44 channels = channels * nbatch;
45 if (channels == 0 || output_height == 0 || output_width == 0) {
46 return;
47 }
48 auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata);
49 auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata);
50
51 // special case: just copy
52 if (input_height == output_height && input_width == output_width) {
53 std::memcpy(o_p, i_p, channels * input_height * input_width * sizeof(typename scalar_t::underlying));
54 return;
55 }
56
57 std::unique_ptr<int64_t []> input_offset_arr(new int64_t[output_width]);
58 int64_t* input_offset = input_offset_arr.get();
59
60 for (const auto w2 : c10::irange(output_width)) {
61 const int64_t w1 = nn_compute_source_index_fn(width_scale, w2, input_width);
62 input_offset[w2] = w1;
63 }
64
65 int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, output_width);
66 at::parallel_for(0, channels * output_height, grain_size, [&](int64_t begin, int64_t end) {
67 int64_t nc{0}, h2{0};
68 data_index_init(begin, nc, channels, h2, output_height);
69
70 for (const auto i : c10::irange(begin, end)) {
71 const int64_t h1 = nn_compute_source_index_fn(height_scale, h2, input_height);
72 const auto* pos1 = &i_p[nc * input_height * input_width + h1 * input_width];
73 auto* pos2 = &o_p[i * output_width];
74
75 for (const auto w2 : c10::irange(output_width)) {
76 const int64_t w1 = input_offset[w2];
77 pos2[w2] = pos1[w1];
78 }
79
80 data_index_step(nc, channels, h2, output_height);
81 }
82 });
83 }
84
85 template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
upsample_nearest2d_out_frame_nhwc(scalar_t * odata,scalar_t * idata,int64_t input_height,int64_t input_width,int64_t output_height,int64_t output_width,int64_t nbatch,int64_t channels,std::optional<double> scales_h,std::optional<double> scales_w)86 static void upsample_nearest2d_out_frame_nhwc(
87 scalar_t* odata,
88 scalar_t* idata,
89 int64_t input_height,
90 int64_t input_width,
91 int64_t output_height,
92 int64_t output_width,
93 int64_t nbatch,
94 int64_t channels,
95 std::optional<double> scales_h,
96 std::optional<double> scales_w) {
97 float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
98 float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
99
100 at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
101 int64_t b{0}, h2{0}, w2{0};
102 data_index_init(begin, b, nbatch, h2, output_height, w2, output_width);
103
104 for (const auto i : c10::irange(begin, end)) {
105 auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata + b * input_height * input_width * channels);
106 auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata + i * channels);
107
108 const int64_t h1 = nn_compute_source_index_fn(height_scale, h2, input_height);
109 const int64_t w1 = nn_compute_source_index_fn(width_scale, w2, input_width);
110
111 const auto* pos1 = &i_p[(h1 * input_width + w1)*channels];
112 auto* pos2 = &o_p[0];
113 std::memcpy(pos2, pos1, channels * sizeof(typename scalar_t::underlying));
114
115 data_index_step(b, nbatch, h2, output_height, w2, output_width);
116 }
117 });
118 }
119
120 template <nn_compute_source_index_fn_t nn_compute_source_index_fn>
_upsample_nearest2d_quantized_cpu(const Tensor & input,IntArrayRef output_size,std::optional<double> scales_h,std::optional<double> scales_w)121 Tensor _upsample_nearest2d_quantized_cpu(
122 const Tensor& input,
123 IntArrayRef output_size,
124 std::optional<double> scales_h,
125 std::optional<double> scales_w) {
126 TORCH_CHECK(
127 output_size.size() == 2,
128 "It is expected output_size equals to 2, but got size ",
129 output_size.size());
130
131 TORCH_CHECK(
132 input.dim() == 4,
133 "Non-empty 4D data tensor expected but got a tensor with sizes ",
134 input.sizes());
135
136 int64_t output_height = output_size[0];
137 int64_t output_width = output_size[1];
138
139 int64_t nbatch = input.size(0);
140 int64_t channels = input.size(1);
141 int64_t input_height = input.size(2);
142 int64_t input_width = input.size(3);
143 AT_ASSERT(input_width > 0 && output_width > 0);
144 if (input.is_contiguous(c10::MemoryFormat::ChannelsLast)) {
145 Tensor output = at::_empty_affine_quantized(
146 {nbatch, channels, output_height, output_width},
147 input.options().memory_format(input.suggest_memory_format()),
148 input.q_scale(),
149 input.q_zero_point(),
150 std::nullopt);
151
152 // special case: just copy
153 if (input_height == output_height && input_width == output_width) {
154 output.copy_(input);
155 return output;
156 }
157
158 AT_DISPATCH_QINT_TYPES(input.scalar_type(), "upsample_nearest2d", [&] {
159 auto* idata = static_cast<scalar_t*>(input.data_ptr());
160 auto* odata = static_cast<scalar_t*>(output.data_ptr());
161 upsample_nearest2d_out_frame_nhwc<scalar_t, nn_compute_source_index_fn>(
162 odata,
163 idata,
164 input_height,
165 input_width,
166 output_height,
167 output_width,
168 nbatch,
169 channels,
170 scales_h,
171 scales_w);
172 });
173 return output;
174 } else {
175 Tensor output = at::_empty_affine_quantized(
176 {nbatch, channels, output_height, output_width},
177 input.options(),
178 input.q_scale(),
179 input.q_zero_point());
180
181 auto input_contig = input.contiguous();
182
183 AT_DISPATCH_QINT_TYPES(input_contig.scalar_type(), "upsample_nearest2d", [&] {
184 auto* idata = static_cast<scalar_t*>(input_contig.data_ptr());
185 auto* odata = static_cast<scalar_t*>(output.data_ptr());
186 upsample_nearest2d_out_frame<scalar_t, nn_compute_source_index_fn>(
187 odata,
188 idata,
189 input_height,
190 input_width,
191 output_height,
192 output_width,
193 nbatch,
194 channels,
195 scales_h,
196 scales_w);
197 });
198 return output;
199 }
200 }
201
202 using at::native::upsample::compute_output_size;
203 using at::native::upsample::get_scale_value;
204
// Public quantized-CPU entry point for upsample_nearest2d: uses the
// classic "nearest" source-index rounding rule.
Tensor upsample_nearest2d_quantized_cpu(
    const Tensor& input,
    IntArrayRef osize,
    std::optional<double> scale_h,
    std::optional<double> scale_w) {
  return _upsample_nearest2d_quantized_cpu<nearest_neighbor_compute_source_index>(
      input, osize, scale_h, scale_w);
}
212
// Public quantized-CPU entry point for _upsample_nearest_exact2d: uses the
// "nearest-exact" source-index rounding rule.
Tensor _upsample_nearest_exact2d_quantized_cpu(
    const Tensor& input,
    IntArrayRef osize,
    std::optional<double> scale_h,
    std::optional<double> scale_w) {
  return _upsample_nearest2d_quantized_cpu<nearest_neighbor_exact_compute_source_index>(
      input, osize, scale_h, scale_w);
}
220
221 } // namespace native
222 } // namespace at
223