// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/cpu/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/_upsample_nearest_exact2d_native.h>
#include <ATen/ops/upsample_nearest2d_native.h>
#endif

#include <c10/util/irange.h>

#include <cstring>


namespace at {
namespace native {

// Typedef used to dispatch to either the nearest or the nearest-exact
// source-index computation.
typedef int64_t (*nn_compute_source_index_fn_t)(const float, int64_t, int64_t);

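// Illustrative example (not part of the original file): assuming the index
// helpers in ATen/native/UpSample.h compute
//   nearest:        src = floor(dst * scale)
//   nearest-exact:  src = floor((dst + 0.5) * scale)
// with scale = input_size / output_size (both clamped to input_size - 1),
// upsampling a width of 2 to 3 (scale = 2/3) maps output columns {0, 1, 2} to
//   nearest:        {0, 0, 1}
//   nearest-exact:  {0, 1, 1}
// i.e. nearest-exact samples at pixel centers instead of left edges.
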
// at::native functions referenced by native_functions.yaml
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest2d_out_frame(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
  float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);

  channels = channels * nbatch;
  if (channels == 0 || output_height == 0 || output_width == 0) {
    return;
  }
  auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata);
  auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata);

  // special case: just copy
  if (input_height == output_height && input_width == output_width) {
    std::memcpy(o_p, i_p, channels * input_height * input_width * sizeof(typename scalar_t::underlying));
    return;
  }

  // Precompute the output-column -> input-column mapping once; it is
  // identical for every output row.
  std::unique_ptr<int64_t []> input_offset_arr(new int64_t[output_width]);
  int64_t* input_offset = input_offset_arr.get();

  for (const auto w2 : c10::irange(output_width)) {
    const int64_t w1 = nn_compute_source_index_fn(width_scale, w2, input_width);
    input_offset[w2] = w1;
  }

  // Parallelize over (channel, output row) pairs; each chunk covers roughly
  // GRAIN_SIZE output elements.
  int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, output_width);
  at::parallel_for(0, channels * output_height, grain_size, [&](int64_t begin, int64_t end) {
    int64_t nc{0}, h2{0};
    data_index_init(begin, nc, channels, h2, output_height);

    for (const auto i : c10::irange(begin, end)) {
      const int64_t h1 = nn_compute_source_index_fn(height_scale, h2, input_height);
      const auto* pos1 = &i_p[nc * input_height * input_width + h1 * input_width];
      auto* pos2 = &o_p[i * output_width];

      for (const auto w2 : c10::irange(output_width)) {
        const int64_t w1 = input_offset[w2];
        pos2[w2] = pos1[w1];
      }

      data_index_step(nc, channels, h2, output_height);
    }
  });
}

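// For reference, the parallel_for above is a flattened form of the nested
// loop below (illustrative sketch only; data_index_init/data_index_step from
// ATen/native/cpu/utils.h keep (nc, h2) in sync with the flat index i):
//
//   for (int64_t nc = 0; nc < channels; ++nc) {
//     for (int64_t h2 = 0; h2 < output_height; ++h2) {
//       const int64_t h1 = nn_compute_source_index_fn(height_scale, h2, input_height);
//       for (int64_t w2 = 0; w2 < output_width; ++w2) {
//         o_p[(nc * output_height + h2) * output_width + w2] =
//             i_p[(nc * input_height + h1) * input_width + input_offset[w2]];
//       }
//     }
//   }
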
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest2d_out_frame_nhwc(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
  float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);

  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t b{0}, h2{0}, w2{0};
    data_index_init(begin, b, nbatch, h2, output_height, w2, output_width);

    for (const auto i : c10::irange(begin, end)) {
      auto* i_p = reinterpret_cast<typename scalar_t::underlying*>(idata + b * input_height * input_width * channels);
      auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata + i * channels);

      const int64_t h1 = nn_compute_source_index_fn(height_scale, h2, input_height);
      const int64_t w1 = nn_compute_source_index_fn(width_scale, w2, input_width);

      // In NHWC the channel values of one pixel are contiguous, so each
      // output pixel is a single memcpy of `channels` elements.
      const auto* pos1 = &i_p[(h1 * input_width + w1) * channels];
      auto* pos2 = &o_p[0];
      std::memcpy(pos2, pos1, channels * sizeof(typename scalar_t::underlying));

      data_index_step(b, nbatch, h2, output_height, w2, output_width);
    }
  });
}

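// Layout sketch (illustrative, not part of the original file): with NHWC
// (ChannelsLast) storage, element (b, h, w, c) lives at flat offset
//   ((b * H + h) * W + w) * C + c
// so the flat pixel index i = (b * output_height + h2) * output_width + w2
// produced by data_index_init above addresses output pixel (b, h2, w2)
// directly at odata + i * channels.
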
template <nn_compute_source_index_fn_t nn_compute_source_index_fn>
Tensor _upsample_nearest2d_quantized_cpu(
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TORCH_CHECK(
      output_size.size() == 2,
      "Expected output_size to have 2 elements, but got ",
      output_size.size());

  TORCH_CHECK(
      input.dim() == 4,
      "Non-empty 4D data tensor expected but got a tensor with sizes ",
      input.sizes());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  AT_ASSERT(input_width > 0 && output_width > 0);
  // NHWC fast path: keep the ChannelsLast layout and copy whole pixels.
  if (input.is_contiguous(c10::MemoryFormat::ChannelsLast)) {
    Tensor output = at::_empty_affine_quantized(
        {nbatch, channels, output_height, output_width},
        input.options().memory_format(input.suggest_memory_format()),
        input.q_scale(),
        input.q_zero_point(),
        std::nullopt);

    // special case: just copy
    if (input_height == output_height && input_width == output_width) {
      output.copy_(input);
      return output;
    }

    AT_DISPATCH_QINT_TYPES(input.scalar_type(), "upsample_nearest2d", [&] {
      auto* idata = static_cast<scalar_t*>(input.data_ptr());
      auto* odata = static_cast<scalar_t*>(output.data_ptr());
      upsample_nearest2d_out_frame_nhwc<scalar_t, nn_compute_source_index_fn>(
          odata,
          idata,
          input_height,
          input_width,
          output_height,
          output_width,
          nbatch,
          channels,
          scales_h,
          scales_w);
    });
    return output;
  } else {
    // NCHW path: materialize a contiguous input first.
    Tensor output = at::_empty_affine_quantized(
        {nbatch, channels, output_height, output_width},
        input.options(),
        input.q_scale(),
        input.q_zero_point());

    auto input_contig = input.contiguous();

    AT_DISPATCH_QINT_TYPES(input_contig.scalar_type(), "upsample_nearest2d", [&] {
      auto* idata = static_cast<scalar_t*>(input_contig.data_ptr());
      auto* odata = static_cast<scalar_t*>(output.data_ptr());
      upsample_nearest2d_out_frame<scalar_t, nn_compute_source_index_fn>(
          odata,
          idata,
          input_height,
          input_width,
          output_height,
          output_width,
          nbatch,
          channels,
          scales_h,
          scales_w);
    });
    return output;
  }
}

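// Illustrative usage (a sketch, not part of this file): these kernels are
// reached when calling nearest-neighbor upsampling on a quantized CPU tensor,
// e.g.
//
//   at::Tensor q = at::quantize_per_tensor(
//       at::rand({1, 3, 4, 4}), /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8);
//   at::Tensor out = at::upsample_nearest2d(q, /*output_size=*/{8, 8});
//
// Because nearest-neighbor only moves values, the output reuses the input's
// quantization scale and zero point, as seen in _empty_affine_quantized above.
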
using at::native::upsample::compute_output_size;
using at::native::upsample::get_scale_value;

Tensor upsample_nearest2d_quantized_cpu(
    const Tensor& input,
    IntArrayRef osize,
    std::optional<double> scale_h,
    std::optional<double> scale_w) {
  return _upsample_nearest2d_quantized_cpu<nearest_neighbor_compute_source_index>(input, osize, scale_h, scale_w);
}

Tensor _upsample_nearest_exact2d_quantized_cpu(
    const Tensor& input,
    IntArrayRef osize,
    std::optional<double> scale_h,
    std::optional<double> scale_w) {
  return _upsample_nearest2d_quantized_cpu<nearest_neighbor_exact_compute_source_index>(input, osize, scale_h, scale_w);
}

} // namespace native
} // namespace at