// aten/src/ATen/native/cpu/DepthwiseConvKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cpu/DepthwiseConvKernel.h>
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif

#ifdef __ARM_NEON__
#include <arm_neon.h>
#elif defined(__riscv_v_intrinsic) && __riscv_v_intrinsic>=12000
#include <riscv_vector.h>
#endif

namespace at::native {
namespace {

struct Arguments final {
  // Input layer dimensions
  int64_t batch;
  int64_t in_rows;
  int64_t in_cols;

  // Convolution stride and padding
  int64_t stride;
  int64_t pad_rows;
  int64_t pad_cols;

  // Output layer dimensions
  int64_t out_rows;
  int64_t out_cols;
  int64_t out_channels;
};

inline std::vector<int64_t> calculate_conv_output_size(
    const IntArrayRef input_size,
    const IntArrayRef weight_size,
    const IntArrayRef stride,
    const IntArrayRef padding) {
  const auto calc_output_dimension = [](
    const int64_t input, const int64_t kernel, const int64_t stride, const int64_t padding) {
    return 1 + (input - kernel + 2 * padding) / stride;
  };

  return std::vector<int64_t> {
    input_size[0],
    weight_size[0],
    calc_output_dimension(input_size[2], weight_size[2], stride[0], padding[0]),
    calc_output_dimension(input_size[3], weight_size[3], stride[1], padding[1]),
  };
}
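
// For reference: calc_output_dimension is the standard convolution
// output-size formula, out = (in - kernel + 2 * pad) / stride + 1.
// A quick worked example (values chosen purely for illustration):
// in = 7, kernel = 3, pad = 1, stride = 2 gives out = (7 - 3 + 2) / 2 + 1 = 4.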

#ifdef __ARM_NEON__

inline void winograd_f2k3_input_transform_inplace__neon(
    float32x4_t* const d0,
    float32x4_t* const d1,
    float32x4_t* const d2,
    float32x4_t* const d3) {
  const float32x4_t wd0 = *d0 - *d2;
  const float32x4_t wd1 = *d1 + *d2;
  const float32x4_t wd2 = -*d1 + *d2;
  const float32x4_t wd3 = *d1 - *d3;
  *d0 = wd0;
  *d1 = wd1;
  *d2 = wd2;
  *d3 = wd3;
}
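
// In Winograd F(2x2, 3x3) notation this computes wd = B^T * d over the four
// rows (or, after a transpose, the four columns) of a 4x4 data tile, with
//   B^T = [ 1  0 -1  0 ]
//         [ 0  1  1  0 ]
//         [ 0 -1  1  0 ]
//         [ 0  1  0 -1 ].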

inline void winograd_f2k3_output_transform_inplace__neon(
    float32x4_t* const m0,
    float32x4_t* const m1,
    const float32x4_t* const m2,
    const float32x4_t* const m3) {
  *m0 = *m0 + *m1 + *m2;
  *m1 = *m1 - *m2 - *m3;
}
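
// The corresponding output transform is m' = A^T * m with
//   A^T = [ 1  1  1  0 ]
//         [ 0  1 -1 -1 ],
// mapping four Winograd-domain values back to two output values. Only m0 and
// m1 are rewritten; m2 and m3 are consumed as inputs.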

inline float32x4_t
vmuladdq_f32(const float32x4_t c, const float32x4_t a, const float32x4_t b) {
#if defined(__aarch64__)
  return vfmaq_f32(c, a, b);
#else
  return vmlaq_f32(c, a, b);
#endif
}

inline float32x4_t
vmulsubq_f32(const float32x4_t c, const float32x4_t a, const float32x4_t b) {
#if defined(__aarch64__)
  return vfmsq_f32(c, a, b);
#else
  return vmlsq_f32(c, a, b);
#endif
}
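
// Both helpers compute c +/- a * b. On AArch64 the fused vfmaq/vfmsq forms
// are used (single rounding step); on 32-bit ARM the vmlaq/vmlsq
// multiply-accumulate forms are used instead, which may round the product
// separately.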

inline void winograd_f2k3_kernel_transform__neon(
    const float32x4_t g0,
    const float32x4_t g1,
    const float32x4_t g2,
    float32x4_t* const transform0,
    float32x4_t* const transform1,
    float32x4_t* const transform2,
    float32x4_t* const transform3) {
  const float32x4_t const_half = vdupq_n_f32(0.5f);
  float32x4_t half_g0_plus_g2 = const_half * (g0 + g2);
  *transform0 = g0;
  *transform1 = vmuladdq_f32(half_g0_plus_g2, const_half, g1);
  *transform2 = vmulsubq_f32(half_g0_plus_g2, const_half, g1);
  *transform3 = g2;
}
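
// This is the F(2x2, 3x3) kernel transform w = G * g, with
//   G = [  1    0    0  ]
//       [ 1/2  1/2  1/2 ]
//       [ 1/2 -1/2  1/2 ]
//       [  0    0    1  ].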

inline float32x4x4_t v4f_transpose4x4__neon(const float32x4x4_t m) {
  float32x4x4_t ret;
  vst4q_f32((float*)(&ret), m);
  return ret;
}
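
// vst4q_f32 stores the four vectors interleaved (element j of vector i lands
// at offset j * 4 + i), so writing into a float32x4x4_t and returning it by
// value yields the transpose of the 4x4 tile.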

void convolution_depthwise3x3_winograd_impl(
    const Arguments& args,
    const float* const input,
    const float* const kernel,
    const float* const bias,
    float* const output) {
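  // The bias is placed in lane 1 of an otherwise-zero vector and later added
  // to row 1 of the Winograd-domain tile, i.e. to position (1, 1). Since
  // A^T[i][1] == 1 and A[1][j] == 1 for all i, j in {0, 1}, that position
  // contributes exactly once to every element of the 2x2 output, so the bias
  // is added once per output pixel.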
  const float32x4_t vbias = vsetq_lane_f32(*bias, vdupq_n_f32(0.0), 1);
  float32x4x4_t kernel_tile;

  {
    const float32x4_t g0 = vld1q_f32(kernel);
    const float32x4_t g1 = vld1q_f32(kernel + 3);
    // g2[3] is junk
    const float32x4_t g2 =
        vextq_f32(vld1q_f32(kernel + 5), vld1q_f32(kernel + 5), 1);
    float32x4x4_t w;
    winograd_f2k3_kernel_transform__neon(
        g0, g1, g2, &w.val[0], &w.val[1], &w.val[2], &w.val[3]);
    w = v4f_transpose4x4__neon(w);

    winograd_f2k3_kernel_transform__neon(
        w.val[0],
        w.val[1],
        w.val[2],
        &kernel_tile.val[0],
        &kernel_tile.val[1],
        &kernel_tile.val[2],
        &kernel_tile.val[3]);
  }

#define TILE                                                  \
  winograd_f2k3_input_transform_inplace__neon(                \
      &input_tile.val[0],                                     \
      &input_tile.val[1],                                     \
      &input_tile.val[2],                                     \
      &input_tile.val[3]);                                    \
  input_tile = v4f_transpose4x4__neon(input_tile);            \
  winograd_f2k3_input_transform_inplace__neon(                \
      &input_tile.val[0],                                     \
      &input_tile.val[1],                                     \
      &input_tile.val[2],                                     \
      &input_tile.val[3]);                                    \
                                                              \
  for (const auto row : c10::irange(4)) {                     \
    input_tile.val[row] =                                     \
        vmulq_f32(input_tile.val[row], kernel_tile.val[row]); \
  }                                                           \
                                                              \
  input_tile.val[1] = input_tile.val[1] + vbias;              \
  winograd_f2k3_output_transform_inplace__neon(               \
      &input_tile.val[0],                                     \
      &input_tile.val[1],                                     \
      &input_tile.val[2],                                     \
      &input_tile.val[3]);                                    \
  input_tile = v4f_transpose4x4__neon(input_tile);            \
  winograd_f2k3_output_transform_inplace__neon(               \
      &input_tile.val[0],                                     \
      &input_tile.val[1],                                     \
      &input_tile.val[2],                                     \
      &input_tile.val[3])
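
// TILE evaluates one F(2x2, 3x3) Winograd tile end to end:
//   Y = A^T [ (G g G^T) (*) (B^T d B) ] A,
// where (*) is the elementwise product. Each transform/transpose/transform
// pair realizes one two-sided matrix product, kernel_tile already holds
// G g G^T, and the bias is injected in the Winograd domain as described
// above. The 2x2 result lands in the low halves of input_tile.val[0..1].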

  // Iterate over all 2x2 output tiles. The fully in-bounds case takes the
  // vectorized fast path; tiles touching padding or the ragged edge go
  // through the scalar staging buffers below.
  // TODO: avoid spilling W by breaking out the non-padded vs padded case.
  for (int64_t oth = 0; oth < (args.out_rows + 1) / 2; ++oth) {
    for (int64_t otw = 0; otw < (args.out_cols + 1) / 2; ++otw) {
      // load input tile for [oth, otw];
      int64_t ih = oth * 2 - args.pad_rows;
      int64_t iw = otw * 2 - args.pad_cols;
      // fast-path, all accesses in-bounds
      if (C10_LIKELY(
              ih >= 0 && iw >= 0 && ih + 3 < args.in_rows &&
                  iw + 3 < args.in_cols && 2 * oth + 1 < args.out_rows &&
                  2 * otw + 1 < args.out_cols
              )) {
        float32x4x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          input_tile.val[row] =
              vld1q_f32(input + (ih + row) * args.in_cols + iw);
        }

        TILE;

        for (const auto row : c10::irange(2)) {
          vst1_f32(
              output + (oth * 2 + row) * args.out_cols + otw * 2,
              vget_low_f32(input_tile.val[row]));
        }
      } else {
        float block[4][4];
        for (const auto row : c10::irange(4)) {
          for (const auto col : c10::irange(4)) {
            if (ih + row >= 0 && iw + col >= 0 && ih + row < args.in_rows &&
                iw + col < args.in_cols) {
              block[row][col] = input[(ih + row) * args.in_cols + iw + col];
            } else {
              block[row][col] = 0.0;
            }
          }
        }

        float32x4x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          input_tile.val[row] = vld1q_f32(&block[row][0]);
        }

        TILE;

        float oblock[2][2];
        for (const auto row : c10::irange(2)) {
          vst1_f32(&oblock[row][0], vget_low_f32(input_tile.val[row]));
        }
        for (const auto row : c10::irange(2)) {
          for (const auto col : c10::irange(2)) {
            if (2 * oth + row < args.out_rows &&
                2 * otw + col < args.out_cols) {
              output[(2 * oth + row) * args.out_cols + 2 * otw + col] =
                  oblock[row][col];
            }
          }
        }
      }
    }
  }
}

#elif defined(__riscv_v_intrinsic) && __riscv_v_intrinsic>=12000

inline void winograd_f2k3_input_transform_inplace__rvv(
    vfloat32m1x4_t* input_tile_val) {
  const vfloat32m1_t d0 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 0);
  const vfloat32m1_t d1 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 1);
  const vfloat32m1_t d2 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 2);
  const vfloat32m1_t d3 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 3);

  const vfloat32m1_t wd0 = __riscv_vfsub_vv_f32m1(d0, d2, 4);
  const vfloat32m1_t wd1 = __riscv_vfadd_vv_f32m1(d1, d2, 4);
  const vfloat32m1_t wd2 = __riscv_vfsub_vv_f32m1(d2, d1, 4);
  const vfloat32m1_t wd3 = __riscv_vfsub_vv_f32m1(d1, d3, 4);

  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wd0);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wd1);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 2, wd2);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 3, wd3);
}
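
// Same B^T input transform as the NEON path above, expressed with explicit
// tuple get/set and a fixed vl of 4.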

inline void winograd_f2k3_output_transform_inplace__rvv(
    vfloat32m1x4_t* input_tile_val) {
  const vfloat32m1_t m0 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 0);
  const vfloat32m1_t m1 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 1);
  const vfloat32m1_t m2 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 2);
  const vfloat32m1_t m3 = __riscv_vget_v_f32m1x4_f32m1(*input_tile_val, 3);

  const vfloat32m1_t m0_plus_m1 = __riscv_vfadd_vv_f32m1(m0, m1, 4);
  const vfloat32m1_t wm0 = __riscv_vfadd_vv_f32m1(m0_plus_m1, m2, 4);
  const vfloat32m1_t m1_sub_m2 = __riscv_vfsub_vv_f32m1(m1, m2, 4);
  const vfloat32m1_t wm1 = __riscv_vfsub_vv_f32m1(m1_sub_m2, m3, 4);

  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 0, wm0);
  *input_tile_val = __riscv_vset_v_f32m1_f32m1x4(*input_tile_val, 1, wm1);
}

inline vfloat32m1_t
vmuladdq_f32(const vfloat32m1_t c, const vfloat32m1_t a, const vfloat32m1_t b) {
  return __riscv_vfmacc_vv_f32m1(c, a, b, 4);
}

inline vfloat32m1_t
vmulsubq_f32(const vfloat32m1_t c, const vfloat32m1_t a, const vfloat32m1_t b) {
  return __riscv_vfnmsac_vv_f32m1(c, a, b, 4);
}
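
// vfmacc computes c + a * b and vfnmsac computes c - a * b, matching the
// semantics of the NEON helpers of the same names.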

inline void winograd_f2k3_kernel_transform__rvv(
    const vfloat32m1_t g0,
    const vfloat32m1_t g1,
    const vfloat32m1_t g2,
    vfloat32m1x4_t* const transform) {
  const vfloat32m1_t const_half = __riscv_vfmv_v_f_f32m1(0.5f, 4);
  const vfloat32m1_t g0_plus_g2 = __riscv_vfadd_vv_f32m1(g0, g2, 4);
  vfloat32m1_t half_g0_plus_g2 = __riscv_vfmul_vv_f32m1(const_half, g0_plus_g2, 4);

  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 0, g0);
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 1, vmuladdq_f32(half_g0_plus_g2, const_half, g1));
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 2, vmulsubq_f32(half_g0_plus_g2, const_half, g1));
  *transform = __riscv_vset_v_f32m1_f32m1x4(*transform, 3, g2);
}

inline vfloat32m1x4_t v4f_transpose4x4__rvv(const vfloat32m1x4_t m) {
  vfloat32m1x4_t ret;
  __riscv_vsseg4e32_v_f32m1x4((float*)(&ret), m, 4);
  return ret;
}
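
// Like the NEON vst4q trick: the segmented store writes the four vectors
// interleaved, so the buffer read back as a vfloat32m1x4_t is the transposed
// 4x4 tile. This appears to rely on the tuple occupying 16 contiguous floats
// in memory, i.e. on the vl = 4 assumption used throughout this file.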

void convolution_depthwise3x3_winograd_impl(
    const Arguments& args,
    const float* const input,
    const float* const kernel,
    const float* const bias,
    float* const output) {

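  // As in the NEON path, build a mask selecting lane 1 and merge the bias
  // into an otherwise-zero vector, so the bias enters the Winograd-domain
  // tile at position (1, 1) and is added exactly once per output element.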
  vbool32_t mask = __riscv_vreinterpret_v_u32m1_b32(__riscv_vmv_v_x_u32m1((uint32_t)(1 << 1), 2));
  const vfloat32m1_t vbias = __riscv_vfmerge_vfm_f32m1(__riscv_vfmv_v_f_f32m1(0.0, 4), *bias, mask, 4);
  vfloat32m1x4_t kernel_tile;

  {
    const vfloat32m1_t g0 = __riscv_vle32_v_f32m1(kernel, 4);
    const vfloat32m1_t g1 = __riscv_vle32_v_f32m1(kernel + 3, 4);
    // g2[3] is junk
    vfloat32m1_t a_slidedown = __riscv_vslidedown_vx_f32m1(__riscv_vle32_v_f32m1(kernel + 5, 4), 1, 4);
    const vfloat32m1_t g2 =
        __riscv_vslideup_vx_f32m1(a_slidedown, __riscv_vle32_v_f32m1(kernel + 5, 4), 3, 4);
    vfloat32m1x4_t w;

    winograd_f2k3_kernel_transform__rvv(
        g0, g1, g2, &w);

    w = v4f_transpose4x4__rvv(w);

    winograd_f2k3_kernel_transform__rvv(
        __riscv_vget_v_f32m1x4_f32m1(w, 0),
        __riscv_vget_v_f32m1x4_f32m1(w, 1),
        __riscv_vget_v_f32m1x4_f32m1(w, 2),
        &kernel_tile);
  }

#define TILE                                                                   \
  winograd_f2k3_input_transform_inplace__rvv(                                  \
      &input_tile);                                                            \
  input_tile = v4f_transpose4x4__rvv(input_tile);                              \
  winograd_f2k3_input_transform_inplace__rvv(                                  \
      &input_tile);                                                            \
                                                                               \
  for (const auto row : c10::irange(4)) {                                      \
    vfloat32m1_t input_mul_kernel =                                            \
         __riscv_vfmul_vv_f32m1(                                               \
           __riscv_vle32_v_f32m1((float*)&input_tile + row * 4, 4),            \
           __riscv_vle32_v_f32m1((float*)&kernel_tile + row * 4, 4),           \
           4);                                                                 \
    __riscv_vse32_v_f32m1(                                                     \
      (float*)&input_tile + row * 4,                                           \
      input_mul_kernel,                                                        \
      4);                                                                      \
  }                                                                            \
                                                                               \
  vfloat32m1_t val = __riscv_vget_v_f32m1x4_f32m1(input_tile, 1);              \
  vfloat32m1_t val_add_vbias = __riscv_vfadd_vv_f32m1(val, vbias, 4);          \
  input_tile = __riscv_vset_v_f32m1_f32m1x4(input_tile, 1, val_add_vbias);     \
  winograd_f2k3_output_transform_inplace__rvv(                                 \
      &input_tile);                                                            \
  input_tile = v4f_transpose4x4__rvv(input_tile);                              \
  winograd_f2k3_output_transform_inplace__rvv(                                 \
      &input_tile)
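
// Same tile pipeline as the NEON TILE macro above. The (float*)&input_tile
// arithmetic addresses the tuple as four contiguous groups of four floats,
// which again assumes vl = 4 vectors laid out back to back in memory.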

  // Iterate over all 2x2 output tiles. The fully in-bounds case takes the
  // vectorized fast path; tiles touching padding or the ragged edge go
  // through the scalar staging buffers below.
  // TODO: avoid spilling W by breaking out the non-padded vs padded case.
  for (int64_t oth = 0; oth < (args.out_rows + 1) / 2; ++oth) {
    for (int64_t otw = 0; otw < (args.out_cols + 1) / 2; ++otw) {
      // load input tile for [oth, otw];
      int64_t ih = oth * 2 - args.pad_rows;
      int64_t iw = otw * 2 - args.pad_cols;
      // fast-path, all accesses in-bounds
      if (C10_LIKELY(
              ih >= 0 && iw >= 0 && ih + 3 < args.in_rows &&
                  iw + 3 < args.in_cols && 2 * oth + 1 < args.out_rows &&
                  2 * otw + 1 < args.out_cols
              )) {
        vfloat32m1x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          __riscv_vse32_v_f32m1(
            (float*)&input_tile + row * 4,
            __riscv_vle32_v_f32m1(input + (ih + row) * args.in_cols + iw, 4),
            4);
        }

        TILE;

        for (const auto row : c10::irange(2)) {
          __riscv_vse32_v_f32m1(
              output + (oth * 2 + row) * args.out_cols + otw * 2,
              __riscv_vle32_v_f32m1((float*)&input_tile + row * 4, 2),
              2);
        }
      } else {
        float block[4][4];
        for (const auto row : c10::irange(4)) {
          for (const auto col : c10::irange(4)) {
            if (ih + row >= 0 && iw + col >= 0 && ih + row < args.in_rows &&
                iw + col < args.in_cols) {
              block[row][col] = input[(ih + row) * args.in_cols + iw + col];
            } else {
              block[row][col] = 0.0;
            }
          }
        }

        vfloat32m1x4_t input_tile;
        for (const auto row : c10::irange(4)) {
          __riscv_vse32_v_f32m1(
            (float*)&input_tile + row * 4,
            __riscv_vle32_v_f32m1(&block[row][0], 4),
            4);
        }

        TILE;

        float oblock[2][2];
        for (const auto row : c10::irange(2)) {
          __riscv_vse32_v_f32m1(
            &oblock[row][0],
            __riscv_vle32_v_f32m1((float*)&input_tile + row * 4, 2),
            2);
        }
        for (const auto row : c10::irange(2)) {
          for (const auto col : c10::irange(2)) {
            if (2 * oth + row < args.out_rows &&
                2 * otw + col < args.out_cols) {
              output[(2 * oth + row) * args.out_cols + 2 * otw + col] =
                  oblock[row][col];
            }
          }
        }
      }
    }
  }
}

#else

void convolution_depthwise3x3_winograd_impl(
    const Arguments&,
    const float* const,
    const float* const,
    const float* const,
    float* const) {
}

#endif /* __ARM_NEON__ / __riscv_v_intrinsic */
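
// On targets with neither NEON nor RVV the impl above is an empty stub;
// presumably the dispatch layer only routes depthwise 3x3 convolutions here
// on platforms covered by one of the vectorized paths.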

Tensor _convolution_depthwise3x3_winograd(
    const Tensor & input,
    const Tensor & kernel,
    const Tensor & bias_potentially_undefined,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const int64_t groups)
{
  const IntArrayRef input_sizes = input.sizes();
  const IntArrayRef kernel_sizes = kernel.sizes();

  Tensor output = at::empty(
    calculate_conv_output_size(input_sizes, kernel_sizes, stride, padding),
    input.options());

  const IntArrayRef output_sizes = output.sizes();

  const Arguments args {
      input_sizes[0],     // Input N
      input_sizes[2],     // Input H
      input_sizes[3],     // Input W
      stride[0],          // Stride
      padding[0],         // Padding Rows
      padding[1],         // Padding Columns
      output_sizes[2],    // Output H
      output_sizes[3],    // Output W
      output_sizes[1],    // Output C
  };

  const int64_t input_hxw = args.in_rows * args.in_cols;
  const int64_t output_hxw = args.out_rows * args.out_cols;

  const Tensor bias = bias_potentially_undefined.defined() ?
                      bias_potentially_undefined :
                      at::zeros({kernel_sizes[0]}, input.options());

  auto input_data = input.const_data_ptr<float>();
  auto kernel_data = kernel.const_data_ptr<float>();
  auto bias_data = bias.const_data_ptr<float>();
  auto output_data = output.data_ptr<float>();

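  // Parallelize over all batch x out_channels output planes. For a flat index
  // k: g = k % out_channels picks the 3x3 kernel and bias, while
  // i = k / (out_channels / groups) picks the input plane, so several output
  // channels share one input channel when the depth multiplier
  // (out_channels / groups) is greater than 1.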
  at::parallel_for(0, args.batch * args.out_channels, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      const int64_t g = k % args.out_channels;
      const int64_t i = k / (args.out_channels / groups);
      convolution_depthwise3x3_winograd_impl(
          args,
          input_data + i * input_hxw,
          kernel_data + g * 3 * 3,
          bias_data + g,
          output_data + k * output_hxw);
    }
  });

  return output;
}

}  // namespace

ALSO_REGISTER_AVX512_DISPATCH(convolution_depthwise3x3_winograd_stub, &_convolution_depthwise3x3_winograd);

}  // namespace at::native