#pragma once
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/Atomic.cuh>

#include <c10/util/ArrayRef.h>
#include <c10/util/SmallVector.h>
#include <c10/util/OptionalArrayRef.h>

#include <math.h>
#include <optional>

namespace at {
namespace native {

namespace upsample {
// TODO: Remove duplicate declaration.
TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
    c10::IntArrayRef input_size, // Full input tensor size.
    at::OptionalIntArrayRef output_size,
    std::optional<c10::ArrayRef<double>> scale_factors);
} // namespace upsample

namespace upsample_cuda {

// TODO: Remove duplication with Upsample.h (CPU).
inline std::optional<double> get_scale_value(std::optional<c10::ArrayRef<double>> scales, int idx) {
  if (!scales) {
    return std::nullopt;
  }
  return scales->at(idx);
}

} // namespace upsample_cuda

/* TODO: move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
  return a < b ? a : b;
}

template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
  return a > b ? a : b;
}
// NOTE [ Nearest neighbor upsampling kernel implementation ]
//
// The nearest neighbor upsampling kernel implementation is symmetrical as
// expected. We launch kernels with threads mapping to the destination tensor
// (which the kernels write to), and each thread reads data from the source
// tensor. This means:
// 1. In the forward kernel,
//      src_xxx refers to properties of input tensors;
//      dst_xxx refers to properties of output tensors;
//      scale_factor is the ratio of src_size to dst_size;
// 2. In the backward kernel,
//      src_xxx refers to properties of grad_output tensors;
//      dst_xxx refers to properties of grad_input tensors;
//      scale_factor is the ratio of src_size to dst_size;
//
// Because of this, we need to take the reciprocal of the scale defined by the
// upsample layer during the forward pass. The motivation is to avoid slow
// division in the kernel code, so we can use faster multiplication instead.
// This is not necessary during the backward pass, since the scale_factor is
// already the reciprocal of the corresponding scale_factor used in the forward
// pass due to the swap of the source and destination tensors.
//
// Similarly, since the mapping from grad_input to grad_output during backward
// is the reverse of the mapping of output to input, we need to have opposite
// mapping functions to compute the source index.
// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename accscalar_t>
__host__ __forceinline__ accscalar_t compute_scales_value(
    const std::optional<double> scale,
    int64_t src_size,
    int64_t dst_size) {
  // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
  return (scale.has_value() && scale.value() > 0.)
      ? (accscalar_t)(1.0 / scale.value())
      : (accscalar_t)src_size / dst_size;
}

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename accscalar_t>
__host__ __forceinline__ accscalar_t compute_scales_value_backwards(
    const std::optional<double> scale,
    int64_t src_size,
    int64_t dst_size) {
  // FIXME: remove magic > 0 after we ensure no models were serialized with -1 defaults.
  return (scale.has_value() && scale.value() > 0.)
      ? (accscalar_t)scale.value()
      : (accscalar_t)src_size / dst_size;
}
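
// Worked example (illustrative, not tied to any specific kernel): upsampling a
// 1-D signal from src_size = 4 to dst_size = 8 with an explicit scale factor
// of 2.0. In the forward pass, compute_scales_value<float>(2.0, 4, 8) returns
// 1.0 / 2.0 = 0.5f, so the kernel can multiply output indices by 0.5 instead
// of dividing by 2. In the backward pass the roles of source and destination
// are swapped, and compute_scales_value_backwards<float>(2.0, 8, 4) returns
// 2.0f, which is already the reciprocal of the forward scale. When no explicit
// scale is given (std::nullopt), both fall back to src_size / dst_size.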

template <typename accscalar_t>
__host__ __forceinline__ accscalar_t area_pixel_compute_scale(
    int input_size,
    int output_size,
    bool align_corners,
    const std::optional<double> scale) {
  if (align_corners) {
    if (output_size > 1) {
      return (accscalar_t)(input_size - 1) / (output_size - 1);
    } else {
      return static_cast<accscalar_t>(0);
    }
  } else {
    return compute_scales_value<accscalar_t>(scale, input_size, output_size);
  }
}

template <typename accscalar_t>
__device__ __forceinline__ accscalar_t area_pixel_compute_source_index(
    accscalar_t scale,
    int dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  } else {
    accscalar_t src_idx = scale * (dst_index + static_cast<accscalar_t>(0.5)) -
        static_cast<accscalar_t>(0.5);
    // See Note [Follow Opencv resize logic]
    return (!cubic && src_idx < static_cast<accscalar_t>(0))
        ? static_cast<accscalar_t>(0)
        : src_idx;
  }
}
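
// Worked example (illustrative): mapping output pixels of a size-8 output back
// to a size-4 input. With align_corners = false and no explicit scale,
// area_pixel_compute_scale<float>(4, 8, false, std::nullopt) = 4 / 8 = 0.5, and
// output index 3 maps to 0.5 * (3 + 0.5) - 0.5 = 1.25 in the input (half-pixel
// centers); output index 0 maps to -0.25, which is clamped to 0 for the
// non-cubic filters. With align_corners = true the scale is
// (4 - 1) / (8 - 1) = 3/7, so the first and last output pixels map exactly onto
// the first and last input pixels (7 * 3/7 = 3).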

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
__device__ __forceinline__ int nearest_neighbor_compute_source_index(
    const float scale,
    int dst_index,
    int input_size) {
  // index_f32 = (output_index) * scale
  // input_index = round(index_f32)
  // Same as a buggy OpenCV INTER_NEAREST
  // We keep this method for backward compatibility and consider it deprecated.
  // See nearest_neighbor_exact_compute_source_index as the replacement.
  const int src_index =
      min(static_cast<int>(floorf((dst_index) * scale)), input_size - 1);
  return src_index;
}

__device__ __forceinline__ int nearest_neighbor_exact_compute_source_index(
    const float scale,
    int dst_index,
    int input_size) {
  // index_f32 = (output_index + 0.5) * scale - 0.5
  // input_index = round(index_f32)
  // Same as Pillow and Scikit-Image/Scipy ndi.zoom
  const int src_index =
      min(static_cast<int>(floorf((dst_index + static_cast<float>(0.5)) * scale)), input_size - 1);
  return src_index;
}
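
// Worked example (illustrative): upsampling from input_size = 4 to
// output_size = 6 gives scale = 4 / 6 = 2/3. The legacy mapping produces
// source indices floor(d * 2/3) = [0, 0, 1, 2, 2, 3] for d = 0..5, while the
// "exact" (Pillow-style, half-pixel) mapping produces
// floor((d + 0.5) * 2/3) = [0, 1, 1, 2, 3, 3]; both are clamped to
// input_size - 1.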

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
__device__ __forceinline__ int nearest_neighbor_bw_compute_source_index(
    const float scale,
    int dst_index,
    int output_size) {
  // Equivalent to buggy OpenCV INTER_NEAREST
  // We keep this method for backward compatibility and consider it deprecated.
  // See nearest_neighbor_exact_bw_compute_source_index as the replacement.
  const int src_index =
      min(static_cast<int>(ceilf(dst_index * scale)), output_size);
  return src_index;
}

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
__device__ __forceinline__ int nearest_neighbor_exact_bw_compute_source_index(
    const float scale,
    int dst_index,
    int output_size) {
  // Equivalent to Pillow and Scikit-Image/Scipy ndi.zoom
  const int src_index =
      min(static_cast<int>(ceilf(dst_index * scale - static_cast<float>(0.5))), output_size);
  return src_index;
}
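
// Worked example (illustrative; the exact call pattern lives in the backward
// kernels rather than in this header): for a forward nearest upsample from
// size 4 to size 8, the backward pass sees scale = src_size / dst_size =
// 8 / 4 = 2. A grad_input index i can then accumulate grad_output over the
// half-open range
//   [nearest_neighbor_bw_compute_source_index(2, i, 8),
//    nearest_neighbor_bw_compute_source_index(2, i + 1, 8));
// for i = 1 that is [ceil(2), ceil(4)) = [2, 4), which matches the forward
// mapping where output indices 2 and 3 both read input index 1
// (floor(2 * 0.5) = floor(3 * 0.5) = 1).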

/* Used by UpSampleBicubic2d.cu */
template <typename scalar_t>
__device__ __forceinline__ scalar_t upsample_get_value_bounded(
    const PackedTensorAccessor64<const scalar_t, 4>& data,
    int batch,
    int channel,
    int height,
    int width,
    int y,
    int x) {
  int access_y = max(min(y, height - 1), 0);
  int access_x = max(min(x, width - 1), 0);
  return data[batch][channel][access_y][access_x];
}

/* Used by UpSampleBicubic2d.cu */
template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ void upsample_increment_value_bounded(
    PackedTensorAccessor64<scalar_t, 4>& data,
    int batch,
    int channel,
    int height,
    int width,
    int y,
    int x,
    accscalar_t value) {
  int access_y = max(min(y, height - 1), 0);
  int access_x = max(min(x, width - 1), 0);
  /* TODO: result here is truncated to scalar_t,
     check: https://github.com/pytorch/pytorch/pull/19630#discussion_r281426912
   */
  gpuAtomicAddNoReturn(
      &data[batch][channel][access_y][access_x], static_cast<scalar_t>(value));
}

// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename accscalar_t>
__device__ __forceinline__ accscalar_t cubic_convolution1(
    accscalar_t x,
    accscalar_t A) {
  return ((A + 2) * x - (A + 3)) * x * x + 1;
}

template <typename accscalar_t>
__device__ __forceinline__ accscalar_t cubic_convolution2(
    accscalar_t x,
    accscalar_t A) {
  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}

template <typename accscalar_t>
__device__ __forceinline__ void get_cubic_upsampling_coefficients(
    accscalar_t coeffs[4],
    accscalar_t t) {
  accscalar_t A = -0.75;

  accscalar_t x1 = t;
  coeffs[0] = cubic_convolution2<accscalar_t>(x1 + 1.0, A);
  coeffs[1] = cubic_convolution1<accscalar_t>(x1, A);

  // opposite coefficients
  accscalar_t x2 = 1.0 - t;
  coeffs[2] = cubic_convolution1<accscalar_t>(x2, A);
  coeffs[3] = cubic_convolution2<accscalar_t>(x2 + 1.0, A);
}
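
// Worked example (illustrative): cubic_convolution1 evaluates the |x| <= 1
// branch of the cubic convolution kernel referenced above, and
// cubic_convolution2 the 1 < |x| < 2 branch, with A = -0.75. At t = 0 the
// coefficients are [W(1), W(0), W(1), W(2)] = [0, 1, 0, 0], so cubic_interp1d
// below reduces to x1 exactly. At t = 0.5 they are
// [-0.09375, 0.59375, 0.59375, -0.09375], which sum to 1 as expected of an
// interpolating kernel.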

template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ accscalar_t cubic_interp1d(
    scalar_t x0,
    scalar_t x1,
    scalar_t x2,
    scalar_t x3,
    accscalar_t t) {
  accscalar_t coeffs[4];
  get_cubic_upsampling_coefficients<accscalar_t>(coeffs, t);

  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}

namespace upsample_antialias {

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
struct BilinearFilterFunctor {

  template <typename accscalar_t>
  __device__ accscalar_t operator()(accscalar_t x) const {
    if (x < 0) {
      x = -x;
    }
    if (x < 1) {
      return 1 - x;
    }
    return 0;
  }

  static const int size = 2;
};

// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L46-L62
struct BicubicFilterFunctor {

  template <typename accscalar_t>
  __device__ accscalar_t operator()(accscalar_t x) const {
    // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
    const accscalar_t a = -0.5;
    if (x < 0) {
      x = -x;
    }
    if (x < 1) {
      return ((a + 2) * x - (a + 3)) * x * x + 1;
    }
    if (x < 2) {
      return (((x - 5) * x + 8) * x - 4) * a;
    }
    return 0;
  }

  static const int size = 4;
};

template <typename accscalar_t>
__device__ __forceinline__ void _compute_weights_span(
    const int i,
    const int input_size,
    const accscalar_t scale,
    const accscalar_t support,
    int& xmin,
    int& xsize,
    accscalar_t& center) {
  center = scale * (i + static_cast<accscalar_t>(0.5));
  xmin = max(static_cast<int>(center - support + static_cast<accscalar_t>(0.5)), static_cast<int>(0));
  xsize = min(static_cast<int>(center + support + static_cast<accscalar_t>(0.5)), input_size) - xmin;
}

template <typename scalar_t, typename accscalar_t, typename interp_filter_t>
__device__ __forceinline__ void _compute_weights(
    scalar_t* wt_ptr,
    const accscalar_t scale,
    int interp_size,
    const interp_filter_t& interp_filter,
    accscalar_t xmin_m_center,
    int xsize) {

  accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
  accscalar_t total_w = 0.0;
  int j = 0;
  for (j = 0; j < xsize; j++) {
    accscalar_t w = interp_filter((j + xmin_m_center + static_cast<accscalar_t>(0.5)) * invscale);
    wt_ptr[j] = static_cast<scalar_t>(w);
    total_w += w;
  }
  for (j = 0; j < xsize; j++) {
    if (total_w != 0.0) {
      wt_ptr[j] /= total_w;
    }
  }
  for (; j < interp_size; j++) {
    wt_ptr[j] = static_cast<scalar_t>(0.0);
  }
}
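
// Worked example (illustrative; the support value is chosen by the callers in
// the antialias kernels, typically (interp_size / 2) * scale when
// downsampling): shrinking a row of 8 pixels to 4 with BilinearFilterFunctor
// gives scale = 2 and support = 2. For output pixel i = 1, _compute_weights_span
// yields center = 2 * 1.5 = 3.0, xmin = 1, xsize = 4, so the output pixel reads
// input pixels 1..4. _compute_weights then evaluates the filter at
// (j + xmin - center + 0.5) * (1 / scale) for j = 0..3, i.e. at
// -0.75, -0.25, 0.25, 0.75, giving raw weights 0.25, 0.75, 0.75, 0.25, which
// normalize to 0.125, 0.375, 0.375, 0.125; interpolate_aa_single_dim below
// computes the corresponding weighted sum over the source pixels.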

template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ accscalar_t interpolate_aa_single_dim(
    const scalar_t* src,
    const scalar_t* weights,
    int size) {
  scalar_t t = static_cast<accscalar_t>(*src);
  scalar_t wts = static_cast<accscalar_t>(weights[0]);
  accscalar_t output = t * wts;

  int j = 1;
  for (; j < size; j++) {
    wts = static_cast<accscalar_t>(weights[j]);
    t = static_cast<accscalar_t>(*(src + j));
    output += t * wts;
  }
  return output;
}

} // namespace upsample_antialias

} // namespace native
} // namespace at