#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cpu/DepthwiseConvKernel.h>
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif

#ifdef __ARM_NEON__
#include <arm_neon.h>
#elif defined(__riscv_v_intrinsic) && __riscv_v_intrinsic >= 12000
#include <riscv_vector.h>
#endif

namespace at::native {
namespace {

struct Arguments final {
  // Input layer dimensions
  int64_t batch;
  int64_t in_rows;
  int64_t in_cols;
  int64_t stride;
  int64_t pad_rows;
  int64_t pad_cols;

  // Output layer dimensions
  int64_t out_rows;
  int64_t out_cols;
  int64_t out_channels;
};

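// Standard convolution output shape for NCHW input with dilation 1:
// each spatial dimension maps to (in - kernel + 2 * padding) / stride + 1,
// and the channel dimension comes from weight_size[0].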
inline std::vector<int64_t> calculate_conv_output_size(
    const IntArrayRef input_size,
    const IntArrayRef weight_size,
    const IntArrayRef stride,
    const IntArrayRef padding) {
  const auto calc_output_dimension = [](
    const int64_t input, const int64_t kernel, const int64_t stride, const int64_t padding) {
    return 1 + (input - kernel + 2 * padding) / stride;
  };

  return std::vector<int64_t> {
    input_size[0],
    weight_size[0],
    calc_output_dimension(input_size[2], weight_size[2], stride[0], padding[0]),
    calc_output_dimension(input_size[3], weight_size[3], stride[1], padding[1]),
  };
}

#ifdef __ARM_NEON__

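// Winograd F(2, 3) data transform, applied in place to the four rows of a
// 4x4 tile: computes B^T * d with
//   B^T = [ 1  0 -1  0 ]
//         [ 0  1  1  0 ]
//         [ 0 -1  1  0 ]
//         [ 0  1  0 -1 ]
// Each float32x4_t holds one row, so combining d0..d3 lane-wise transforms
// every column at once; the caller transposes and applies it again to cover
// the other axis.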
inline void winograd_f2k3_input_transform_inplace__neon(
    float32x4_t* const d0,
    float32x4_t* const d1,
    float32x4_t* const d2,
    float32x4_t* const d3) {
  const float32x4_t wd0 = *d0 - *d2;
  const float32x4_t wd1 = *d1 + *d2;
  const float32x4_t wd2 = -*d1 + *d2;
  const float32x4_t wd3 = *d1 - *d3;
  *d0 = wd0;
  *d1 = wd1;
  *d2 = wd2;
  *d3 = wd3;
}

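// Winograd F(2, 3) output transform: computes A^T * m with
//   A^T = [ 1  1  1  0 ]
//         [ 0  1 -1 -1 ]
// Only m0 and m1 receive results; m2 and m3 are read-only inputs.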
inline void winograd_f2k3_output_transform_inplace__neon(
    float32x4_t* const m0,
    float32x4_t* const m1,
    const float32x4_t* const m2,
    const float32x4_t* const m3) {
  *m0 = *m0 + *m1 + *m2;
  *m1 = *m1 - *m2 - *m3;
}

inline float32x4_t
vmuladdq_f32(const float32x4_t c, const float32x4_t a, const float32x4_t b) {
#if defined(__aarch64__)
  return vfmaq_f32(c, a, b);
#else
  return vmlaq_f32(c, a, b);
#endif
}

inline float32x4_t
vmulsubq_f32(const float32x4_t c, const float32x4_t a, const float32x4_t b) {
#if defined(__aarch64__)
  return vfmsq_f32(c, a, b);
#else
  return vmlsq_f32(c, a, b);
#endif
}

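// Winograd F(2, 3) kernel transform: computes G * g with
//   G = [  1    0    0  ]
//       [ 1/2  1/2  1/2 ]
//       [ 1/2 -1/2  1/2 ]
//       [  0    0    1  ]
// expanding a 3-tap filter row into the 4-point transform domain.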
inline void winograd_f2k3_kernel_transform__neon(
    const float32x4_t g0,
    const float32x4_t g1,
    const float32x4_t g2,
    float32x4_t* const transform0,
    float32x4_t* const transform1,
    float32x4_t* const transform2,
    float32x4_t* const transform3) {
  const float32x4_t const_half = vdupq_n_f32(0.5f);
  float32x4_t half_g0_plus_g2 = const_half * (g0 + g2);
  *transform0 = g0;
  *transform1 = vmuladdq_f32(half_g0_plus_g2, const_half, g1);
  *transform2 = vmulsubq_f32(half_g0_plus_g2, const_half, g1);
  *transform3 = g2;
}

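// vst4q_f32 stores the four rows with 4-way interleaving (lane 0 of every
// row, then lane 1, ...), so writing into a stack temporary and returning it
// as float32x4x4_t yields the 4x4 transpose.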
inline float32x4x4_t v4f_transpose4x4__neon(const float32x4x4_t m) {
  float32x4x4_t ret;
  vst4q_f32((float*)(&ret), m);
  return ret;
}

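// Single-plane F(2x2, 3x3) Winograd convolution: each 4x4 input tile is
// transformed (rows, transpose, rows again), multiplied element-wise with the
// pre-transformed 4x4 kernel tile, bias-adjusted, and transformed back into a
// 2x2 output tile, so 16 multiplies stand in for the 36 a direct 3x3
// convolution would need per 2x2 output block.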
void convolution_depthwise3x3_winograd_impl(
    const Arguments& args,
    const float* const input,
    const float* const kernel,
    const float* const bias,
    float* const output) {
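  // The bias is folded into the transform domain: vbias carries it in lane 1
  // only, and the TILE macro adds it to row 1 of the product tile, i.e. at
  // position [1][1]. Since A^T[i][1] * A[1][j] == 1 for every output
  // position, the bias lands exactly once on each of the four outputs.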
  const float32x4_t vbias = vsetq_lane_f32(*bias, vdupq_n_f32(0.0), 1);
  float32x4x4_t kernel_tile;

  {
    const float32x4_t g0 = vld1q_f32(kernel);
    const float32x4_t g1 = vld1q_f32(kernel + 3);
    // g2[3] is junk
    const float32x4_t g2 =
        vextq_f32(vld1q_f32(kernel + 5), vld1q_f32(kernel + 5), 1);
    float32x4x4_t w;
    winograd_f2k3_kernel_transform__neon(
        g0, g1, g2, &w.val[0], &w.val[1], &w.val[2], &w.val[3]);
    w = v4f_transpose4x4__neon(w);

    winograd_f2k3_kernel_transform__neon(
        w.val[0],
        w.val[1],
        w.val[2],
        &kernel_tile.val[0],
        &kernel_tile.val[1],
        &kernel_tile.val[2],
        &kernel_tile.val[3]);
  }

#define TILE \
  winograd_f2k3_input_transform_inplace__neon( \
      &input_tile.val[0], \
      &input_tile.val[1], \
      &input_tile.val[2], \
      &input_tile.val[3]); \
  input_tile = v4f_transpose4x4__neon(input_tile); \
  winograd_f2k3_input_transform_inplace__neon( \
      &input_tile.val[0], \
      &input_tile.val[1], \
      &input_tile.val[2], \
      &input_tile.val[3]); \
 \
  for (const auto row : c10::irange(4)) { \
    input_tile.val[row] = \
        vmulq_f32(input_tile.val[row], kernel_tile.val[row]); \
  } \
 \
  input_tile.val[1] = input_tile.val[1] + vbias; \
  winograd_f2k3_output_transform_inplace__neon( \
      &input_tile.val[0], \
      &input_tile.val[1], \
      &input_tile.val[2], \
      &input_tile.val[3]); \
  input_tile = v4f_transpose4x4__neon(input_tile); \
  winograd_f2k3_output_transform_inplace__neon( \
      &input_tile.val[0], \
      &input_tile.val[1], \
      &input_tile.val[2], \
      &input_tile.val[3])

  // Non-padded regime.

  // Iterate over non-padded output tiles.
  // TODO: avoid spilling W by breaking out the non-padded vs padded case.
  for (int64_t oth = 0; oth < (args.out_rows + 1) / 2; ++oth) {
    for (int64_t otw = 0; otw < (args.out_cols + 1) / 2; ++otw) {
      // load input tile for [oth, otw];
      int64_t ih = oth * 2 - args.pad_rows;
      int64_t iw = otw * 2 - args.pad_cols;
      // fast-path, all accesses in-bounds
      if (C10_LIKELY(
              ih >= 0 && iw >= 0 && ih + 3 < args.in_rows &&
              iw + 3 < args.in_cols && 2 * oth + 1 < args.out_rows &&
              2 * otw + 1 < args.out_cols
              )) {
        float32x4x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          input_tile.val[row] =
              vld1q_f32(input + (ih + row) * args.in_cols + iw);
        }

        TILE;

        for (const auto row : c10::irange(2)) {
          vst1_f32(
              output + (oth * 2 + row) * args.out_cols + otw * 2,
              vget_low_f32(input_tile.val[row]));
        }
      } else {
        float block[4][4];
        for (const auto row : c10::irange(4)) {
          for (const auto col : c10::irange(4)) {
            if (ih + row >= 0 && iw + col >= 0 && ih + row < args.in_rows &&
                iw + col < args.in_cols) {
              block[row][col] = input[(ih + row) * args.in_cols + iw + col];
            } else {
              block[row][col] = 0.0;
            }
          }
        }

        float32x4x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          input_tile.val[row] = vld1q_f32(&block[row][0]);
        }

        TILE;

        float oblock[2][2];
        for (const auto row : c10::irange(2)) {
          vst1_f32(&oblock[row][0], vget_low_f32(input_tile.val[row]));
        }
        for (const auto row : c10::irange(2)) {
          for (const auto col : c10::irange(2)) {
            if (2 * oth + row < args.out_rows &&
                2 * otw + col < args.out_cols) {
              output[(2 * oth + row) * args.out_cols + 2 * otw + col] =
                  oblock[row][col];
            }
          }
        }
      }
    }
  }
}

#elif defined(__riscv_v_intrinsic) && __riscv_v_intrinsic >= 12000

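// The RVV path mirrors the NEON implementation above: vfloat32m1_t vectors
// are used with a fixed vl of 4 so that each vector holds one row of a 4x4
// tile, and the transforms below apply the same F(2, 3) matrices (B^T, G,
// A^T) documented for the NEON variants.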
inline void winograd_f2k3_input_transform_inplace__rvv(
    vfloat32m1x4_t* input_tile_val) {
  const vfloat32m1_t d0 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 0);
  const vfloat32m1_t d1 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 1);
  const vfloat32m1_t d2 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 2);
  const vfloat32m1_t d3 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 3);

  const vfloat32m1_t wd0 = __riscv_vfsub_vv_f32m1(d0, d2, 4);
  const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
  const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
  const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);

  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
}

inline void winograd_f2k3_output_transform_inplace__rvv(
    vfloat32m1x4_t* input_tile_val) {
  const vfloat32m1_t m0 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 0);
  const vfloat32m1_t m1 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 1);
  const vfloat32m1_t m2 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 2);
  const vfloat32m1_t m3 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 3);

  const vfloat32m1_t m0_plus_m1 = __riscv_vfadd_vv_f32m1(m0, m1, 4);
  const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
  const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
  const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);

  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
}

inline vfloat32m1_t
vmuladdq_f32(const vfloat32m1_t c, const vfloat32m1_t a, const vfloat32m1_t b) {
  return __riscv_vfmacc_vv_f32m1(c, a, b, 4);
}

inline vfloat32m1_t
vmulsubq_f32(const vfloat32m1_t c, const vfloat32m1_t a, const vfloat32m1_t b) {
  return __riscv_vfnmsac_vv_f32m1(c, a, b, 4);
}

inline void winograd_f2k3_kernel_transform__rvv(
    const vfloat32m1_t g0,
    const vfloat32m1_t g1,
    const vfloat32m1_t g2,
    vfloat32m1x4_t* const transform) {
  const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
  const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
  vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);

  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
}

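// Segmented-store counterpart of the NEON transpose: __riscv_vsseg4e32
// interleaves the four rows into the temporary, so reading it back as
// vfloat32m1x4_t yields the 4x4 transpose.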
inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {
  vfloat32m1x4_t ret;
  __riscv_vsseg4e32_v_f32m1x4((float*)(&ret), m, 4);
  return ret;
}

void convolution_depthwise3x3_winograd_impl(
    const Arguments& args,
    const float* const input,
    const float* const kernel,
    const float* const bias,
    float* const output) {

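  // Build a mask that selects only element 1 and merge the bias into an
  // otherwise-zero vector; this matches the NEON
  // vsetq_lane_f32(*bias, vdupq_n_f32(0.0), 1) above, and the TILE macro adds
  // it at position [1][1] of the product tile.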
  vbool32_t mask = __riscv_vreinterpret_v_u32m1_b32(__riscv_vmv_v_x_u32m1((uint32_t)(1 << 1), 2));
  const vfloat32m1_t vbias = __riscv_vfmerge_vfm_f32m1(__riscv_vfmv_v_f_f32m1(0.0, 4), *bias, mask, 4);
  vfloat32m1x4_t kernel_tile;

  {
    const vfloat32m1_t g0 = __riscv_vle32_v_f32m1(kernel, 4);
    const vfloat32m1_t g1 = __riscv_vle32_v_f32m1(kernel + 3, 4);
    // g2[3] is junk
    vfloat32m1_t a_slidedown = __riscv_vslidedown_vx_f32m1(__riscv_vle32_v_f32m1(kernel + 5, 4), 1, 4);
    const vfloat32m1_t g2 =
        __riscv_vslideup_vx_f32m1(a_slidedown, __riscv_vle32_v_f32m1(kernel + 5, 4), 3, 4);
    vfloat32m1x4_t w;

    winograd_f2k3_kernel_transform__rvv(
        g0, g1, g2, &w);

    w = v4f_transpose4x4__rvv(w);

    winograd_f2k3_kernel_transform__rvv(
        __riscv_vget_v_f32m1x4_f32m1(w, 0),
        __riscv_vget_v_f32m1x4_f32m1(w, 1),
        __riscv_vget_v_f32m1x4_f32m1(w, 2),
        &kernel_tile);
  }

#define TILE \
  winograd_f2k3_input_transform_inplace__rvv( \
      &input_tile); \
  input_tile = v4f_transpose4x4__rvv(input_tile); \
  winograd_f2k3_input_transform_inplace__rvv( \
      &input_tile); \
 \
  for (const auto row : c10::irange(4)) { \
    vfloat32m1_t input_mul_kernel = \
        __riscv_vfmul_vv_f32m1( \
            __riscv_vle32_v_f32m1((float*)&input_tile + row * 4, 4), \
            __riscv_vle32_v_f32m1((float*)&kernel_tile + row * 4, 4), \
            4); \
    __riscv_vse32_v_f32m1( \
        (float*)&input_tile + row * 4, \
        input_mul_kernel, \
        4); \
  } \
 \
  vfloat32m1_t val = __riscv_vget_v_f32m1x4_f32m1(input_tile, 1); \
  vfloat32m1_t val_add_vbias = __riscv_vfadd_vv_f32m1(val, vbias, 4); \
  input_tile = __riscv_vset_v_f32m1_f32m1x4(input_tile, 1, val_add_vbias); \
  winograd_f2k3_output_transform_inplace__rvv( \
      &input_tile); \
  input_tile = v4f_transpose4x4__rvv(input_tile); \
  winograd_f2k3_output_transform_inplace__rvv( \
      &input_tile)

  // Non-padded regime.

  // Iterate over non-padded output tiles.
  // TODO: avoid spilling W by breaking out the non-padded vs padded case.
  for (int64_t oth = 0; oth < (args.out_rows + 1) / 2; ++oth) {
    for (int64_t otw = 0; otw < (args.out_cols + 1) / 2; ++otw) {
      // load input tile for [oth, otw];
      int64_t ih = oth * 2 - args.pad_rows;
      int64_t iw = otw * 2 - args.pad_cols;
      // fast-path, all accesses in-bounds
      if (C10_LIKELY(
              ih >= 0 && iw >= 0 && ih + 3 < args.in_rows &&
              iw + 3 < args.in_cols && 2 * oth + 1 < args.out_rows &&
              2 * otw + 1 < args.out_cols
              )) {
        vfloat32m1x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          __riscv_vse32_v_f32m1(
              (float*)&input_tile + row * 4,
              __riscv_vle32_v_f32m1(input + (ih + row) * args.in_cols + iw, 4),
              4);
        }

        TILE;

        for (const auto row : c10::irange(2)) {
          __riscv_vse32_v_f32m1(
              output + (oth * 2 + row) * args.out_cols + otw * 2,
              __riscv_vle32_v_f32m1((float*)&input_tile + row * 4, 2),
              2);
        }
      } else {
        float block[4][4];
        for (const auto row : c10::irange(4)) {
          for (const auto col : c10::irange(4)) {
            if (ih + row >= 0 && iw + col >= 0 && ih + row < args.in_rows &&
                iw + col < args.in_cols) {
              block[row][col] = input[(ih + row) * args.in_cols + iw + col];
            } else {
              block[row][col] = 0.0;
            }
          }
        }

        vfloat32m1x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          __riscv_vse32_v_f32m1(
              (float*)&input_tile + row * 4,
              __riscv_vle32_v_f32m1(&block[row][0], 4),
              4);
        }

        TILE;

        float oblock[2][2];
        for (const auto row : c10::irange(2)) {
          __riscv_vse32_v_f32m1(
              &oblock[row][0],
              __riscv_vle32_v_f32m1((float*)&input_tile + row * 4, 2),
              2);
        }
        for (const auto row : c10::irange(2)) {
          for (const auto col : c10::irange(2)) {
            if (2 * oth + row < args.out_rows &&
                2 * otw + col < args.out_cols) {
              output[(2 * oth + row) * args.out_cols + 2 * otw + col] =
                  oblock[row][col];
            }
          }
        }
      }
    }
  }
}

#else

void convolution_depthwise3x3_winograd_impl(
    const Arguments&,
    const float* const,
    const float* const,
    const float* const,
    float* const) {
}

#endif /* __ARM_NEON__ || __riscv_v_intrinsic */

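// Entry point used by the dispatch stub: allocates the output, packs shapes
// into Arguments, and runs one single-plane Winograd kernel per
// (batch, output channel) pair via at::parallel_for. For flat index k,
// g = k % out_channels picks the 3x3 filter and bias, and
// i = k / (out_channels / groups) picks the input plane; in the usual
// depthwise case (groups == input channels) out_channels / groups is the
// depth multiplier.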
Tensor _convolution_depthwise3x3_winograd(
    const Tensor & input,
    const Tensor & kernel,
    const Tensor & bias_potentially_undefined,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const int64_t groups)
{
  const IntArrayRef input_sizes = input.sizes();
  const IntArrayRef kernel_sizes = kernel.sizes();

  Tensor output = at::empty(
    calculate_conv_output_size(input_sizes, kernel_sizes, stride, padding),
    input.options());

  const IntArrayRef output_sizes = output.sizes();

  const Arguments args {
      input_sizes[0],    // Input N
      input_sizes[2],    // Input H
      input_sizes[3],    // Input W
      stride[0],         // Stride
      padding[0],        // Padding Rows
      padding[1],        // Padding Columns
      output_sizes[2],   // Output H
      output_sizes[3],   // Output W
      output_sizes[1],   // Output C
  };

  const int64_t input_hxw = args.in_rows * args.in_cols;
  const int64_t output_hxw = args.out_rows * args.out_cols;

  const Tensor bias = bias_potentially_undefined.defined() ?
                      bias_potentially_undefined :
                      at::zeros({kernel_sizes[0]}, input.options());

  auto input_data = input.const_data_ptr<float>();
  auto kernel_data = kernel.const_data_ptr<float>();
  auto bias_data = bias.const_data_ptr<float>();
  auto output_data = output.data_ptr<float>();

  at::parallel_for(0, args.batch * args.out_channels, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      const int64_t g = k % args.out_channels;
      const int64_t i = k / (args.out_channels / groups);
      convolution_depthwise3x3_winograd_impl(
          args,
          input_data + i * input_hxw,
          kernel_data + g * 3 * 3,
          bias_data + g,
          output_data + k * output_hxw);
    }
  });

  return output;
}

}  // namespace

ALSO_REGISTER_AVX512_DISPATCH(convolution_depthwise3x3_winograd_stub, &_convolution_depthwise3x3_winograd);

}  // namespace at::native