xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/UpSample.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/upsample_bilinear2d_native.h>
#endif

#include <cstring>

namespace at {
namespace native {
namespace {

// Precomputed interpolation parameters along the width dimension.
struct UpsampleBilinearParamW {
  int64_t w1, w1p;
  float w0lambda, w1lambda;

  UpsampleBilinearParamW(int64_t w1, int64_t w1p, float w0lambda, float w1lambda)
    : w1(w1)
    , w1p(w1p)
    , w0lambda(w0lambda)
    , w1lambda(w1lambda) {}
};

// at::native functions registered in native_functions.yaml
template <typename scalar_t>
static void upsample_bilinear2d_out_frame(
    Tensor& output,
    const Tensor& input,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  auto* idata = static_cast<const scalar_t*>(input.const_data_ptr());
  auto* odata = static_cast<scalar_t*>(output.data_ptr());

  channels = channels * nbatch;
  if (channels == 0 || output_height == 0 || output_width == 0) {
    return;
  }
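  // Quantized tensors store raw integers; scalar_t::underlying is the plain
  // integer type (e.g. uint8_t for c10::quint8) backing the quantized dtype.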
  auto* i_p = reinterpret_cast<const typename scalar_t::underlying*>(idata);
  auto* o_p = reinterpret_cast<typename scalar_t::underlying*>(odata);

  // special case: just copy
  if (input_height == output_height && input_width == output_width) {
    std::memcpy(
        o_p,
        i_p,
        channels * input_height * input_width *
            sizeof(typename scalar_t::underlying));
    return;
  }

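  // area_pixel_compute_scale returns the input/output coordinate ratio used to
  // map output indices back to source indices, honoring align_corners and any
  // explicitly supplied scale factors.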
  const auto rheight = area_pixel_compute_scale<float>(
      input_height, output_height, align_corners, scales_h);

  const auto rwidth = area_pixel_compute_scale<float>(
      input_width, output_width, align_corners, scales_w);

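  // Requantization ratio between the two quantized domains: a raw value q from
  // the input maps to round((q - input_zp) / output_scale) + output_zp in the
  // output, because the real value it represents is (q - input_zp) * input_scale.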
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float output_scale = output.q_scale() / input.q_scale();

  const int64_t input_q_zero_point = input.q_zero_point();
  const int64_t output_q_zero_point = output.q_zero_point();

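  // Precompute, for every output column w2: the left source column w1, the
  // offset w1p to its right neighbor (0 at the right border), and the linear
  // interpolation weights w0lambda / w1lambda.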
  std::vector<UpsampleBilinearParamW> params_w;
  params_w.reserve(output_width);
  for (const auto w2 : c10::irange(output_width)) {
    const auto w1r = area_pixel_compute_source_index<float>(
        rwidth, w2, align_corners, /*cubic=*/false);

    const int64_t w1 = w1r;
    const int64_t w1p = (w1 < input_width - 1) ? 1 : 0;

    const float w1lambda = w1r - w1;
    const float w0lambda = static_cast<float>(1.) - w1lambda;

    params_w.emplace_back(w1, w1p, w0lambda, w1lambda);
  }

  // Compared with 'nearest', each output element reads 4 input points and does
  // extra multiplies and adds, so shrink the grain size by a factor of 16.
  int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, output_width) / 16;
  at::parallel_for(0, channels * output_height, grain_size, [&](int64_t begin, int64_t end) {
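    // Each flat index in [begin, end) is one output row; decompose it into a
    // combined batch*channel index nc and an output row index h2.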
    int64_t nc{0}, h2{0};
    data_index_init(begin, nc, channels, h2, output_height);

    for (const auto i : c10::irange(begin, end)) {
      const auto h1r = area_pixel_compute_source_index<float>(
          rheight, h2, align_corners, /*cubic=*/false);

      const int64_t h1 = h1r;
      const int64_t h1p = (h1 < input_height - 1) ? 1 : 0;

      const float h1lambda = h1r - h1;
      const float h0lambda = static_cast<float>(1.) - h1lambda;

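      // i_ptr: start of the current (batch, channel) input plane;
      // pos2: start of output row i in the flattened output buffer.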
      const auto* i_ptr = &i_p[nc * input_height * input_width];
      auto* pos2 = &o_p[i * output_width];

      for (const auto w2 : c10::irange(output_width)) {
        const auto& param_w = params_w[w2];
        const int64_t w1 = param_w.w1;
        const int64_t w1p = param_w.w1p;
        const float w0lambda = param_w.w0lambda;
        const float w1lambda = param_w.w1lambda;

        const auto* pos1 = i_ptr + h1 * input_width + w1;

        float result = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) +
            h1lambda *
                (w0lambda * pos1[h1p * input_width] +
                 w1lambda * pos1[h1p * input_width + w1p]) - input_q_zero_point;
        // requantize into the output's quantized domain
        pos2[w2] = at::native::quantize_val<scalar_t>(
                      output_scale, output_q_zero_point, result)
                      .val_;
      }

      data_index_step(nc, channels, h2, output_height);
    }
  });
}

} // namespace

Tensor upsample_bilinear2d_quantized_cpu(
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TORCH_CHECK(
      output_size.size() == 2,
      "Expected output_size to have 2 elements, but got size ",
      output_size.size());

  TORCH_CHECK(
      input.dim() == 4,
      "Non-empty 4D data tensor expected but got a tensor with sizes ",
      input.sizes());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t input_height = input.size(2);
  int64_t input_width = input.size(3);
  AT_ASSERT(input_width > 0 && output_width > 0);

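  // Channels-last input goes through a dedicated NHWC kernel via the dispatch
  // stub; any other layout falls back to the generic NCHW path below.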
  if (input.is_contiguous(c10::MemoryFormat::ChannelsLast)) {
    Tensor output = at::_empty_affine_quantized(
        {nbatch, channels, output_height, output_width},
        input.options().memory_format(input.suggest_memory_format()),
        input.q_scale(),
        input.q_zero_point(),
        std::nullopt);

    qupsample_bilinear2d_nhwc_stub(
        input.device().type(),
        output,
        input,
        input_height,
        input_width,
        output_height,
        output_width,
        nbatch,
        channels,
        align_corners,
        scales_h,
        scales_w);
    return output;
  } else {
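    // The output is allocated with the input's scale and zero point, so the
    // frame kernel's requantization step is effectively an identity rescale.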
    Tensor output = at::_empty_affine_quantized(
        {nbatch, channels, output_height, output_width},
        input.options(),
        input.q_scale(),
        input.q_zero_point());

    auto input_contig = input.contiguous();
    AT_DISPATCH_QINT_TYPES(
        input_contig.scalar_type(), "upsample_bilinear2d", [&] {
          upsample_bilinear2d_out_frame<scalar_t>(
              output,
              input_contig,
              input_height,
              input_width,
              output_height,
              output_width,
              nbatch,
              channels,
              align_corners,
              scales_h,
              scales_w);
        });
    return output;
  }
}

DEFINE_DISPATCH(qupsample_bilinear2d_nhwc_stub);
} // namespace native
} // namespace at
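
A minimal usage sketch (not part of the file above): the quantized CPU kernel is reached through the ordinary upsample_bilinear2d operator whenever its input is a per-tensor affine quantized CPU tensor. The scale, zero point, and sizes here are arbitrary illustration values.

#include <ATen/ATen.h>

int main() {
  // Quantize a float tensor to quint8 with an arbitrary scale/zero point.
  at::Tensor x = at::rand({1, 3, 4, 4});
  at::Tensor qx = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8);

  // Dispatches to upsample_bilinear2d_quantized_cpu for quantized CPU inputs.
  at::Tensor qy = at::upsample_bilinear2d(
      qx, /*output_size=*/{8, 8}, /*align_corners=*/false,
      std::nullopt, std::nullopt);

  // Inspect the result in float space.
  at::Tensor y = qy.dequantize();
  return 0;
}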