1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <math.h>
8*4bdc9457SAndroid Build Coastguard Worker #include <stdbool.h>
9*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
11*4bdc9457SAndroid Build Coastguard Worker #include <stdlib.h>
12*4bdc9457SAndroid Build Coastguard Worker #include <string.h>
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/common.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/compute.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/indirection.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h>
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/params.h>
25*4bdc9457SAndroid Build Coastguard Worker
26*4bdc9457SAndroid Build Coastguard Worker
xnn_create_convolution2d_nchw_f32(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t kernel_height,uint32_t kernel_width,uint32_t subsampling_height,uint32_t subsampling_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t groups,size_t group_input_channels,size_t group_output_channels,size_t input_channel_stride,size_t output_channel_stride,const float * kernel,const float * bias,float output_min,float output_max,uint32_t flags,xnn_caches_t caches,xnn_operator_t * convolution_op_out)27*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_create_convolution2d_nchw_f32(
28*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_top,
29*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_right,
30*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_bottom,
31*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_left,
32*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_height,
33*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_width,
34*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_height,
35*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_width,
36*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_height,
37*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_width,
38*4bdc9457SAndroid Build Coastguard Worker uint32_t groups,
39*4bdc9457SAndroid Build Coastguard Worker size_t group_input_channels,
40*4bdc9457SAndroid Build Coastguard Worker size_t group_output_channels,
41*4bdc9457SAndroid Build Coastguard Worker size_t input_channel_stride,
42*4bdc9457SAndroid Build Coastguard Worker size_t output_channel_stride,
43*4bdc9457SAndroid Build Coastguard Worker const float* kernel,
44*4bdc9457SAndroid Build Coastguard Worker const float* bias,
45*4bdc9457SAndroid Build Coastguard Worker float output_min,
46*4bdc9457SAndroid Build Coastguard Worker float output_max,
47*4bdc9457SAndroid Build Coastguard Worker uint32_t flags,
48*4bdc9457SAndroid Build Coastguard Worker xnn_caches_t caches,
49*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t* convolution_op_out)
50*4bdc9457SAndroid Build Coastguard Worker {
51*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = NULL;
52*4bdc9457SAndroid Build Coastguard Worker enum xnn_status status = xnn_status_uninitialized;
53*4bdc9457SAndroid Build Coastguard Worker
54*4bdc9457SAndroid Build Coastguard Worker if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
55*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
56*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
57*4bdc9457SAndroid Build Coastguard Worker goto error;
58*4bdc9457SAndroid Build Coastguard Worker }
59*4bdc9457SAndroid Build Coastguard Worker
60*4bdc9457SAndroid Build Coastguard Worker status = xnn_status_invalid_parameter;
61*4bdc9457SAndroid Build Coastguard Worker
62*4bdc9457SAndroid Build Coastguard Worker if (kernel_width == 0 || kernel_height == 0) {
63*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
64*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
65*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), kernel_width, kernel_height);
66*4bdc9457SAndroid Build Coastguard Worker goto error;
67*4bdc9457SAndroid Build Coastguard Worker }
68*4bdc9457SAndroid Build Coastguard Worker
69*4bdc9457SAndroid Build Coastguard Worker if (subsampling_width == 0 || subsampling_height == 0) {
70*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
71*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "x%" PRIu32 " subsampling: subsampling dimensions must be non-zero",
72*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), subsampling_width, subsampling_height);
73*4bdc9457SAndroid Build Coastguard Worker goto error;
74*4bdc9457SAndroid Build Coastguard Worker }
75*4bdc9457SAndroid Build Coastguard Worker
76*4bdc9457SAndroid Build Coastguard Worker if (dilation_width == 0 || dilation_height == 0) {
77*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
78*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
79*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), dilation_width, dilation_height);
80*4bdc9457SAndroid Build Coastguard Worker goto error;
81*4bdc9457SAndroid Build Coastguard Worker }
82*4bdc9457SAndroid Build Coastguard Worker
83*4bdc9457SAndroid Build Coastguard Worker if (groups == 0) {
84*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
85*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 " groups: number of groups must be non-zero",
86*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), groups);
87*4bdc9457SAndroid Build Coastguard Worker goto error;
88*4bdc9457SAndroid Build Coastguard Worker }
89*4bdc9457SAndroid Build Coastguard Worker
90*4bdc9457SAndroid Build Coastguard Worker if (group_input_channels == 0) {
91*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
92*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %zu input channels per group: number of channels must be non-zero",
93*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), group_input_channels);
94*4bdc9457SAndroid Build Coastguard Worker goto error;
95*4bdc9457SAndroid Build Coastguard Worker }
96*4bdc9457SAndroid Build Coastguard Worker
97*4bdc9457SAndroid Build Coastguard Worker if (group_output_channels == 0) {
98*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
99*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %zu output channels per group: number of channels must be non-zero",
100*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), group_output_channels);
101*4bdc9457SAndroid Build Coastguard Worker goto error;
102*4bdc9457SAndroid Build Coastguard Worker }
103*4bdc9457SAndroid Build Coastguard Worker
104*4bdc9457SAndroid Build Coastguard Worker const size_t input_channels = groups * group_input_channels;
105*4bdc9457SAndroid Build Coastguard Worker if (input_channel_stride < input_channels) {
106*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
107*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with input channel stride of %zu: "
108*4bdc9457SAndroid Build Coastguard Worker "stride must be at least as large as the number of input channels (%" PRIu32 "x%zu)",
109*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32),
110*4bdc9457SAndroid Build Coastguard Worker input_channel_stride, groups, group_input_channels);
111*4bdc9457SAndroid Build Coastguard Worker goto error;
112*4bdc9457SAndroid Build Coastguard Worker }
113*4bdc9457SAndroid Build Coastguard Worker
114*4bdc9457SAndroid Build Coastguard Worker const size_t output_channels = groups * group_output_channels;
115*4bdc9457SAndroid Build Coastguard Worker if (output_channel_stride < output_channels) {
116*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
117*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with output channel stride of %zu: "
118*4bdc9457SAndroid Build Coastguard Worker "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
119*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32),
120*4bdc9457SAndroid Build Coastguard Worker output_channel_stride, groups, group_output_channels);
121*4bdc9457SAndroid Build Coastguard Worker goto error;
122*4bdc9457SAndroid Build Coastguard Worker }
123*4bdc9457SAndroid Build Coastguard Worker
124*4bdc9457SAndroid Build Coastguard Worker if (isnan(output_min)) {
125*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
126*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
127*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
128*4bdc9457SAndroid Build Coastguard Worker goto error;
129*4bdc9457SAndroid Build Coastguard Worker }
130*4bdc9457SAndroid Build Coastguard Worker
131*4bdc9457SAndroid Build Coastguard Worker if (isnan(output_max)) {
132*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
133*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
134*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
135*4bdc9457SAndroid Build Coastguard Worker goto error;
136*4bdc9457SAndroid Build Coastguard Worker }
137*4bdc9457SAndroid Build Coastguard Worker
138*4bdc9457SAndroid Build Coastguard Worker if (output_min >= output_max) {
139*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
140*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
141*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), output_min, output_max);
142*4bdc9457SAndroid Build Coastguard Worker goto error;
143*4bdc9457SAndroid Build Coastguard Worker }
144*4bdc9457SAndroid Build Coastguard Worker
145*4bdc9457SAndroid Build Coastguard Worker if ((flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) != 0 && group_input_channels != 1) {
146*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
147*4bdc9457SAndroid Build Coastguard Worker "failed to create depthwise %s operator with %zu input channels per group: "
148*4bdc9457SAndroid Build Coastguard Worker "depthwise convolution must have exactly 1 input channel per group",
149*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), group_input_channels);
150*4bdc9457SAndroid Build Coastguard Worker goto error;
151*4bdc9457SAndroid Build Coastguard Worker }
152*4bdc9457SAndroid Build Coastguard Worker
153*4bdc9457SAndroid Build Coastguard Worker status = xnn_status_unsupported_parameter;
154*4bdc9457SAndroid Build Coastguard Worker
155*4bdc9457SAndroid Build Coastguard Worker enum xnn_ukernel_type ukernel_type;
156*4bdc9457SAndroid Build Coastguard Worker struct dwconv2d_chw_parameters* dwconv2d_parameters = NULL;
157*4bdc9457SAndroid Build Coastguard Worker // Supported cases:
158*4bdc9457SAndroid Build Coastguard Worker // + 1x1 convolution (no groups)
159*4bdc9457SAndroid Build Coastguard Worker // + 3x3 stride-2 with 3 input channels and NHWC input layout
160*4bdc9457SAndroid Build Coastguard Worker // + 3x3 stride-2 depthwise convolution with horizontal padding 1 & no vertical padding
161*4bdc9457SAndroid Build Coastguard Worker // + 3x3 stride-1 depthwise convolution with horizontal padding 1 & no vertical padding
162*4bdc9457SAndroid Build Coastguard Worker // + 5x5 stride-2 depthwise convolution with horizontal padding 2 & no vertical padding
163*4bdc9457SAndroid Build Coastguard Worker // + 5x5 stride-1 depthwise convolution with horizontal padding 2 & no vertical padding
164*4bdc9457SAndroid Build Coastguard Worker const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
165*4bdc9457SAndroid Build Coastguard Worker const bool is_1x1 = kernel_width == 1 && kernel_height == 1 && subsampling_height == 1 && subsampling_width == 1;
166*4bdc9457SAndroid Build Coastguard Worker const bool is_3x3 = kernel_width == 3 && kernel_height == 3 && dilation_height == 1 && dilation_width == 1;
167*4bdc9457SAndroid Build Coastguard Worker const bool is_5x5 = kernel_width == 5 && kernel_height == 5 && dilation_height == 1 && dilation_width == 1;
168*4bdc9457SAndroid Build Coastguard Worker const bool nhwc_input = (flags & XNN_FLAG_INPUT_NHWC) != 0;
169*4bdc9457SAndroid Build Coastguard Worker if (is_1x1 && !any_padding && !nhwc_input && groups == 1) {
170*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_spmm;
171*4bdc9457SAndroid Build Coastguard Worker } else if (is_3x3 && subsampling_height == 2 && subsampling_width == 2 &&
172*4bdc9457SAndroid Build Coastguard Worker input_padding_top == 1 && input_padding_left == 1 && input_padding_bottom == 1 && input_padding_right == 1 &&
173*4bdc9457SAndroid Build Coastguard Worker nhwc_input && groups == 1)
174*4bdc9457SAndroid Build Coastguard Worker {
175*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_conv2d_hwc2chw;
176*4bdc9457SAndroid Build Coastguard Worker } else if (is_3x3 && subsampling_height == 1 && subsampling_width == 1 &&
177*4bdc9457SAndroid Build Coastguard Worker input_padding_top == 1 && input_padding_left == 1 && input_padding_bottom == 1 && input_padding_right == 1 &&
178*4bdc9457SAndroid Build Coastguard Worker !nhwc_input && group_input_channels == 1 && group_output_channels == 1)
179*4bdc9457SAndroid Build Coastguard Worker {
180*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_dwconv;
181*4bdc9457SAndroid Build Coastguard Worker dwconv2d_parameters = &xnn_params.f32.dwconv2d_chw_3x3;
182*4bdc9457SAndroid Build Coastguard Worker } else if (is_3x3 && subsampling_height == 2 && subsampling_width == 2 &&
183*4bdc9457SAndroid Build Coastguard Worker (input_padding_top == 0 || input_padding_top == 1) && input_padding_left == 1 && input_padding_bottom == 1 && input_padding_right == 1 &&
184*4bdc9457SAndroid Build Coastguard Worker !nhwc_input && group_input_channels == 1 && group_output_channels == 1)
185*4bdc9457SAndroid Build Coastguard Worker {
186*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_dwconv;
187*4bdc9457SAndroid Build Coastguard Worker dwconv2d_parameters = &xnn_params.f32.dwconv2d_chw_3x3s2;
188*4bdc9457SAndroid Build Coastguard Worker } else if (is_5x5 && subsampling_height == 1 && subsampling_width == 1 &&
189*4bdc9457SAndroid Build Coastguard Worker input_padding_top == 2 && input_padding_left == 2 && input_padding_bottom == 2 && input_padding_right == 2 &&
190*4bdc9457SAndroid Build Coastguard Worker !nhwc_input && group_input_channels == 1 && group_output_channels == 1)
191*4bdc9457SAndroid Build Coastguard Worker {
192*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_dwconv;
193*4bdc9457SAndroid Build Coastguard Worker dwconv2d_parameters = &xnn_params.f32.dwconv2d_chw_5x5;
194*4bdc9457SAndroid Build Coastguard Worker } else if (is_5x5 && subsampling_height == 2 && subsampling_width == 2 &&
195*4bdc9457SAndroid Build Coastguard Worker (input_padding_top == 1 || input_padding_top == 2) && input_padding_left == 2 && input_padding_bottom == 2 && input_padding_right == 2 &&
196*4bdc9457SAndroid Build Coastguard Worker !nhwc_input && group_input_channels == 1 && group_output_channels == 1)
197*4bdc9457SAndroid Build Coastguard Worker {
198*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_dwconv;
199*4bdc9457SAndroid Build Coastguard Worker dwconv2d_parameters = &xnn_params.f32.dwconv2d_chw_5x5s2;
200*4bdc9457SAndroid Build Coastguard Worker } else {
201*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
202*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "x%" PRIu32 " kernel, %"PRIu32 "x%" PRIu32 " subsampling, %"PRIu32 "x%" PRIu32 " dilation"
203*4bdc9457SAndroid Build Coastguard Worker ", %" PRIu32 "+%" PRIu32 "x%" PRIu32 "+%" PRIu32" padding, %" PRIu32 "x%zu input channels, and %" PRIu32 "x%zu output channels: "
204*4bdc9457SAndroid Build Coastguard Worker "only selected convolution parameters are supported",
205*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32),
206*4bdc9457SAndroid Build Coastguard Worker kernel_width, kernel_height, subsampling_width, subsampling_height, dilation_width, dilation_height,
207*4bdc9457SAndroid Build Coastguard Worker input_padding_top, input_padding_left, input_padding_bottom, input_padding_right,
208*4bdc9457SAndroid Build Coastguard Worker groups, group_input_channels, groups, group_output_channels);
209*4bdc9457SAndroid Build Coastguard Worker goto error;
210*4bdc9457SAndroid Build Coastguard Worker }
211*4bdc9457SAndroid Build Coastguard Worker
212*4bdc9457SAndroid Build Coastguard Worker status = xnn_status_out_of_memory;
213*4bdc9457SAndroid Build Coastguard Worker
214*4bdc9457SAndroid Build Coastguard Worker convolution_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
215*4bdc9457SAndroid Build Coastguard Worker if (convolution_op == NULL) {
216*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
217*4bdc9457SAndroid Build Coastguard Worker "failed to allocate %zu bytes for %s operator descriptor",
218*4bdc9457SAndroid Build Coastguard Worker sizeof(struct xnn_operator), xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
219*4bdc9457SAndroid Build Coastguard Worker goto error;
220*4bdc9457SAndroid Build Coastguard Worker }
221*4bdc9457SAndroid Build Coastguard Worker
222*4bdc9457SAndroid Build Coastguard Worker if (caches != NULL && ukernel_type != xnn_ukernel_type_spmm) {
223*4bdc9457SAndroid Build Coastguard Worker convolution_op->weights_cache = caches->weights_cache;
224*4bdc9457SAndroid Build Coastguard Worker }
225*4bdc9457SAndroid Build Coastguard Worker
226*4bdc9457SAndroid Build Coastguard Worker switch (ukernel_type) {
227*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_spmm:
228*4bdc9457SAndroid Build Coastguard Worker {
229*4bdc9457SAndroid Build Coastguard Worker assert(kernel_height == 1);
230*4bdc9457SAndroid Build Coastguard Worker assert(kernel_width == 1);
231*4bdc9457SAndroid Build Coastguard Worker assert(groups == 1);
232*4bdc9457SAndroid Build Coastguard Worker
233*4bdc9457SAndroid Build Coastguard Worker size_t num_nonzeroes = 0;
234*4bdc9457SAndroid Build Coastguard Worker size_t num_nonzero_blocks2 = 0;
235*4bdc9457SAndroid Build Coastguard Worker size_t num_nonzero_blocks4 = 0;
236*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < round_down_po2(group_output_channels, 4); oc += 4) {
237*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels; ic++) {
238*4bdc9457SAndroid Build Coastguard Worker const size_t row0_nonzero = (size_t) (kernel[oc * group_input_channels + ic] != 0.0f);
239*4bdc9457SAndroid Build Coastguard Worker const size_t row1_nonzero = (size_t) (kernel[(oc + 1) * group_input_channels + ic] != 0.0f);
240*4bdc9457SAndroid Build Coastguard Worker const size_t row2_nonzero = (size_t) (kernel[(oc + 2) * group_input_channels + ic] != 0.0f);
241*4bdc9457SAndroid Build Coastguard Worker const size_t row3_nonzero = (size_t) (kernel[(oc + 3) * group_input_channels + ic] != 0.0f);
242*4bdc9457SAndroid Build Coastguard Worker num_nonzeroes += row0_nonzero + row1_nonzero + row2_nonzero + row3_nonzero;
243*4bdc9457SAndroid Build Coastguard Worker num_nonzero_blocks2 += (row0_nonzero | row1_nonzero) + (row2_nonzero | row3_nonzero);
244*4bdc9457SAndroid Build Coastguard Worker num_nonzero_blocks4 += (row0_nonzero | row1_nonzero | row2_nonzero | row3_nonzero);
245*4bdc9457SAndroid Build Coastguard Worker }
246*4bdc9457SAndroid Build Coastguard Worker }
247*4bdc9457SAndroid Build Coastguard Worker const size_t num_block4_nonzeroes = num_nonzeroes;
248*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = round_down_po2(group_output_channels, 4); oc < round_down_po2(group_output_channels, 2); oc += 2) {
249*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels; ic++) {
250*4bdc9457SAndroid Build Coastguard Worker const size_t row0_nonzero = (size_t) (kernel[oc * group_input_channels + ic] != 0.0f);
251*4bdc9457SAndroid Build Coastguard Worker const size_t row1_nonzero = (size_t) (kernel[(oc + 1) * group_input_channels + ic] != 0.0f);
252*4bdc9457SAndroid Build Coastguard Worker num_nonzeroes += row0_nonzero + row1_nonzero;
253*4bdc9457SAndroid Build Coastguard Worker num_nonzero_blocks2 += (row0_nonzero | row1_nonzero);
254*4bdc9457SAndroid Build Coastguard Worker }
255*4bdc9457SAndroid Build Coastguard Worker }
256*4bdc9457SAndroid Build Coastguard Worker const size_t num_block2_nonzeroes = num_nonzeroes;
257*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = round_down_po2(group_output_channels, 2); oc < group_output_channels; oc++) {
258*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels; ic++) {
259*4bdc9457SAndroid Build Coastguard Worker num_nonzeroes += (size_t) (kernel[oc * group_input_channels + ic] != 0.0f);
260*4bdc9457SAndroid Build Coastguard Worker }
261*4bdc9457SAndroid Build Coastguard Worker }
262*4bdc9457SAndroid Build Coastguard Worker size_t output_channels_block_size = 1;
263*4bdc9457SAndroid Build Coastguard Worker size_t num_output_channel_blocks = group_output_channels;
264*4bdc9457SAndroid Build Coastguard Worker size_t num_nonzero_values = num_nonzeroes;
265*4bdc9457SAndroid Build Coastguard Worker size_t num_nonzero_blocks = num_nonzeroes;
266*4bdc9457SAndroid Build Coastguard Worker const struct spmm_parameters* spmm_parameters = &xnn_params.f32.spmm;
267*4bdc9457SAndroid Build Coastguard Worker if (num_block4_nonzeroes * 5 >= num_nonzero_blocks4 * 18 && xnn_params.f32.spmm4.ukernel != NULL) {
268*4bdc9457SAndroid Build Coastguard Worker // 4-channel blocks have 90%+ non-zeroes
269*4bdc9457SAndroid Build Coastguard Worker
270*4bdc9457SAndroid Build Coastguard Worker output_channels_block_size = 4;
271*4bdc9457SAndroid Build Coastguard Worker num_output_channel_blocks = num_output_channel_blocks / 4 + num_output_channel_blocks % 4;
272*4bdc9457SAndroid Build Coastguard Worker spmm_parameters = &xnn_params.f32.spmm4;
273*4bdc9457SAndroid Build Coastguard Worker // Non-zeroes which don't fit into whole 4-channel blocks, processed one-by-one
274*4bdc9457SAndroid Build Coastguard Worker const size_t num_remaining_nonzeroes = num_nonzeroes - num_block4_nonzeroes;
275*4bdc9457SAndroid Build Coastguard Worker num_nonzero_values = num_nonzero_blocks4 * 4 + num_remaining_nonzeroes;
276*4bdc9457SAndroid Build Coastguard Worker num_nonzero_blocks = num_nonzero_blocks4 + num_remaining_nonzeroes;
277*4bdc9457SAndroid Build Coastguard Worker } else if (num_block2_nonzeroes * 5 >= num_nonzero_blocks2 * 9 && xnn_params.f32.spmm2.ukernel != NULL) {
278*4bdc9457SAndroid Build Coastguard Worker // 2-channel blocks have 90%+ non-zeroes
279*4bdc9457SAndroid Build Coastguard Worker
280*4bdc9457SAndroid Build Coastguard Worker output_channels_block_size = 2;
281*4bdc9457SAndroid Build Coastguard Worker num_output_channel_blocks = num_output_channel_blocks / 2 + num_output_channel_blocks % 2;
282*4bdc9457SAndroid Build Coastguard Worker spmm_parameters = &xnn_params.f32.spmm2;
283*4bdc9457SAndroid Build Coastguard Worker // Non-zeroes which don't fit into whole 2-channel blocks, processed one-by-one
284*4bdc9457SAndroid Build Coastguard Worker const size_t num_remaining_nonzeroes = num_nonzeroes - num_block2_nonzeroes;
285*4bdc9457SAndroid Build Coastguard Worker num_nonzero_values = num_nonzero_blocks2 * 2 + num_remaining_nonzeroes;
286*4bdc9457SAndroid Build Coastguard Worker num_nonzero_blocks = num_nonzero_blocks2 + num_remaining_nonzeroes;
287*4bdc9457SAndroid Build Coastguard Worker }
288*4bdc9457SAndroid Build Coastguard Worker
289*4bdc9457SAndroid Build Coastguard Worker // Sparse representation of weights consists of four components:
290*4bdc9457SAndroid Build Coastguard Worker // 1. An array of float values storing non-zero kernel elements, and all (group_output_channels) bias elements.
291*4bdc9457SAndroid Build Coastguard Worker // All elements within non-zero block are assumed to be non-zero.
292*4bdc9457SAndroid Build Coastguard Worker // 2. An array of int32_t values storing increment for input pointer after each processed tile. This array is
293*4bdc9457SAndroid Build Coastguard Worker // derived from scaled difference in array 2 using parameters to setup function.
294*4bdc9457SAndroid Build Coastguard Worker // 3. An array of uint32_t values storing the number of non-zero kernel elements per each output channel.
295*4bdc9457SAndroid Build Coastguard Worker // 4. An array of int32_t values storing scaled [by sizeof(input element)] difference between input channels
296*4bdc9457SAndroid Build Coastguard Worker // corresponding to successive non-zero blocks.
297*4bdc9457SAndroid Build Coastguard Worker const size_t packed_weights_size = num_output_channel_blocks * sizeof(uint32_t) +
298*4bdc9457SAndroid Build Coastguard Worker (num_nonzero_blocks * 2) * sizeof(int32_t) + (num_nonzero_values + group_output_channels) * sizeof(float);
299*4bdc9457SAndroid Build Coastguard Worker
300*4bdc9457SAndroid Build Coastguard Worker convolution_op->packed_weights.pointer = xnn_allocate_simd_memory(packed_weights_size);
301*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->packed_weights.pointer == NULL) {
302*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
303*4bdc9457SAndroid Build Coastguard Worker "failed to allocate %zu bytes for %s operator packed weights",
304*4bdc9457SAndroid Build Coastguard Worker packed_weights_size, xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
305*4bdc9457SAndroid Build Coastguard Worker goto error;
306*4bdc9457SAndroid Build Coastguard Worker }
307*4bdc9457SAndroid Build Coastguard Worker convolution_op->num_nonzero_values = num_nonzero_values;
308*4bdc9457SAndroid Build Coastguard Worker convolution_op->num_nonzero_blocks = num_nonzero_blocks;
309*4bdc9457SAndroid Build Coastguard Worker convolution_op->num_output_channel_blocks = num_output_channel_blocks;
310*4bdc9457SAndroid Build Coastguard Worker
311*4bdc9457SAndroid Build Coastguard Worker float* nonzero_values = convolution_op->packed_weights.pointer;
312*4bdc9457SAndroid Build Coastguard Worker int32_t* input_increments = (int32_t*) (nonzero_values + num_nonzero_values + group_output_channels);
313*4bdc9457SAndroid Build Coastguard Worker uint32_t* output_channel_nonzeros = (uint32_t*) (input_increments + num_nonzero_blocks);
314*4bdc9457SAndroid Build Coastguard Worker int32_t* input_channel_diffs = (int32_t*) (output_channel_nonzeros + num_output_channel_blocks);
315*4bdc9457SAndroid Build Coastguard Worker memset(output_channel_nonzeros, 0, num_output_channel_blocks * sizeof(uint32_t));
316*4bdc9457SAndroid Build Coastguard Worker
317*4bdc9457SAndroid Build Coastguard Worker status = xnn_status_unsupported_parameter;
318*4bdc9457SAndroid Build Coastguard Worker
319*4bdc9457SAndroid Build Coastguard Worker size_t first_ic = 0, last_ic = 0;
320*4bdc9457SAndroid Build Coastguard Worker bool first_nonzero = true;
321*4bdc9457SAndroid Build Coastguard Worker for (size_t ocb = 0; ocb < round_down_po2(group_output_channels, output_channels_block_size); ocb += output_channels_block_size) {
322*4bdc9457SAndroid Build Coastguard Worker if XNN_LIKELY(bias != NULL) {
323*4bdc9457SAndroid Build Coastguard Worker for (size_t oco = 0; oco < output_channels_block_size; oco++) {
324*4bdc9457SAndroid Build Coastguard Worker *nonzero_values++ = bias[ocb + oco];
325*4bdc9457SAndroid Build Coastguard Worker }
326*4bdc9457SAndroid Build Coastguard Worker } else {
327*4bdc9457SAndroid Build Coastguard Worker for (size_t oco = 0; oco < output_channels_block_size; oco++) {
328*4bdc9457SAndroid Build Coastguard Worker *nonzero_values++ = 0.0f;
329*4bdc9457SAndroid Build Coastguard Worker }
330*4bdc9457SAndroid Build Coastguard Worker }
331*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels; ic++) {
332*4bdc9457SAndroid Build Coastguard Worker bool is_nonzero_block = false;
333*4bdc9457SAndroid Build Coastguard Worker for (size_t oco = 0; oco < output_channels_block_size; oco++) {
334*4bdc9457SAndroid Build Coastguard Worker is_nonzero_block |= (kernel[(ocb + oco) * group_input_channels + ic] != 0.0f);
335*4bdc9457SAndroid Build Coastguard Worker }
336*4bdc9457SAndroid Build Coastguard Worker if (is_nonzero_block) {
337*4bdc9457SAndroid Build Coastguard Worker for (size_t oco = 0; oco < output_channels_block_size; oco++) {
338*4bdc9457SAndroid Build Coastguard Worker *nonzero_values++ = kernel[(ocb + oco) * group_input_channels + ic];
339*4bdc9457SAndroid Build Coastguard Worker }
340*4bdc9457SAndroid Build Coastguard Worker if (first_nonzero) {
341*4bdc9457SAndroid Build Coastguard Worker first_ic = ic;
342*4bdc9457SAndroid Build Coastguard Worker } else {
343*4bdc9457SAndroid Build Coastguard Worker const int64_t diff = (int64_t) ((uint64_t) ic - (uint64_t) last_ic) * (int64_t) sizeof(float);
344*4bdc9457SAndroid Build Coastguard Worker if (diff != (int64_t) (int32_t) diff) {
345*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to convert kernel to sparse representation: "
346*4bdc9457SAndroid Build Coastguard Worker "scaled difference in input channels exceeds int32_t range");
347*4bdc9457SAndroid Build Coastguard Worker goto error;
348*4bdc9457SAndroid Build Coastguard Worker }
349*4bdc9457SAndroid Build Coastguard Worker *input_channel_diffs++ = (int32_t) diff;
350*4bdc9457SAndroid Build Coastguard Worker }
351*4bdc9457SAndroid Build Coastguard Worker first_nonzero = false;
352*4bdc9457SAndroid Build Coastguard Worker last_ic = ic;
353*4bdc9457SAndroid Build Coastguard Worker *output_channel_nonzeros += 1;
354*4bdc9457SAndroid Build Coastguard Worker }
355*4bdc9457SAndroid Build Coastguard Worker }
356*4bdc9457SAndroid Build Coastguard Worker output_channel_nonzeros += 1;
357*4bdc9457SAndroid Build Coastguard Worker }
358*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = round_down_po2(group_output_channels, output_channels_block_size); oc < group_output_channels; oc++) {
359*4bdc9457SAndroid Build Coastguard Worker if XNN_LIKELY(bias != NULL) {
360*4bdc9457SAndroid Build Coastguard Worker *nonzero_values++ = bias[oc];
361*4bdc9457SAndroid Build Coastguard Worker } else {
362*4bdc9457SAndroid Build Coastguard Worker *nonzero_values++ = 0.0f;
363*4bdc9457SAndroid Build Coastguard Worker }
364*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels; ic++) {
365*4bdc9457SAndroid Build Coastguard Worker const float weight = kernel[oc * group_input_channels + ic];
366*4bdc9457SAndroid Build Coastguard Worker if (weight != 0.0f) {
367*4bdc9457SAndroid Build Coastguard Worker *nonzero_values++ = weight;
368*4bdc9457SAndroid Build Coastguard Worker if (first_nonzero) {
369*4bdc9457SAndroid Build Coastguard Worker first_ic = ic;
370*4bdc9457SAndroid Build Coastguard Worker } else {
371*4bdc9457SAndroid Build Coastguard Worker const int64_t diff = (int64_t) ((uint64_t) ic - (uint64_t) last_ic) * (int64_t) sizeof(float);
372*4bdc9457SAndroid Build Coastguard Worker if (diff != (int64_t) (int32_t) diff) {
373*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to convert kernel to sparse representation: "
374*4bdc9457SAndroid Build Coastguard Worker "scaled difference in input channels exceeds int32_t range");
375*4bdc9457SAndroid Build Coastguard Worker goto error;
376*4bdc9457SAndroid Build Coastguard Worker }
377*4bdc9457SAndroid Build Coastguard Worker *input_channel_diffs++ = (int32_t) diff;
378*4bdc9457SAndroid Build Coastguard Worker }
379*4bdc9457SAndroid Build Coastguard Worker first_nonzero = false;
380*4bdc9457SAndroid Build Coastguard Worker last_ic = ic;
381*4bdc9457SAndroid Build Coastguard Worker *output_channel_nonzeros += 1;
382*4bdc9457SAndroid Build Coastguard Worker }
383*4bdc9457SAndroid Build Coastguard Worker }
384*4bdc9457SAndroid Build Coastguard Worker output_channel_nonzeros += 1;
385*4bdc9457SAndroid Build Coastguard Worker }
386*4bdc9457SAndroid Build Coastguard Worker // If there are any non-zero elements, we have to return to the initial input channel.
387*4bdc9457SAndroid Build Coastguard Worker if (!first_nonzero) {
388*4bdc9457SAndroid Build Coastguard Worker const int64_t diff = (int64_t) ((uint64_t) first_ic - (uint64_t) last_ic) * (int64_t) sizeof(float);
389*4bdc9457SAndroid Build Coastguard Worker if (diff != (int64_t) (int32_t) diff) {
390*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to convert kernel to sparse representation: "
391*4bdc9457SAndroid Build Coastguard Worker "scaled difference in input channels exceeds int32_t range");
392*4bdc9457SAndroid Build Coastguard Worker goto error;
393*4bdc9457SAndroid Build Coastguard Worker }
394*4bdc9457SAndroid Build Coastguard Worker *input_channel_diffs++ = (int32_t) diff;
395*4bdc9457SAndroid Build Coastguard Worker }
396*4bdc9457SAndroid Build Coastguard Worker convolution_op->first_input_channel = first_ic;
397*4bdc9457SAndroid Build Coastguard Worker
398*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.spmm = (struct xnn_ukernel_spmm) {
399*4bdc9457SAndroid Build Coastguard Worker .function = spmm_parameters->ukernel,
400*4bdc9457SAndroid Build Coastguard Worker .mr = spmm_parameters->mr,
401*4bdc9457SAndroid Build Coastguard Worker };
402*4bdc9457SAndroid Build Coastguard Worker
403*4bdc9457SAndroid Build Coastguard Worker break;
404*4bdc9457SAndroid Build Coastguard Worker }
405*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_conv2d_hwc2chw:
406*4bdc9457SAndroid Build Coastguard Worker {
407*4bdc9457SAndroid Build Coastguard Worker assert(groups == 1);
408*4bdc9457SAndroid Build Coastguard Worker
409*4bdc9457SAndroid Build Coastguard Worker const size_t packed_group_output_channels =
410*4bdc9457SAndroid Build Coastguard Worker round_up(group_output_channels, xnn_params.f32.conv_hwc2chw_3x3c3s2.output_channel_tile);
411*4bdc9457SAndroid Build Coastguard Worker const size_t packed_weights_size = groups * packed_group_output_channels *
412*4bdc9457SAndroid Build Coastguard Worker (group_input_channels * kernel_height * kernel_width + 1 /* bias */) * sizeof(float);
413*4bdc9457SAndroid Build Coastguard Worker size_t aligned_total_weights_size = round_up_po2(packed_weights_size, XNN_ALLOCATION_ALIGNMENT);
414*4bdc9457SAndroid Build Coastguard Worker void* weights_ptr = xnn_get_pointer_to_write_weights(
415*4bdc9457SAndroid Build Coastguard Worker convolution_op, aligned_total_weights_size, 0);
416*4bdc9457SAndroid Build Coastguard Worker if (weights_ptr == NULL) {
417*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to reserve or allocate %zu bytes for %s operator conv2d_hwc2chw packed weights",
418*4bdc9457SAndroid Build Coastguard Worker aligned_total_weights_size,
419*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
420*4bdc9457SAndroid Build Coastguard Worker goto error;
421*4bdc9457SAndroid Build Coastguard Worker }
422*4bdc9457SAndroid Build Coastguard Worker
423*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_dconv_oki_w(
424*4bdc9457SAndroid Build Coastguard Worker group_output_channels,
425*4bdc9457SAndroid Build Coastguard Worker group_input_channels,
426*4bdc9457SAndroid Build Coastguard Worker xnn_params.f32.conv_hwc2chw_3x3c3s2.output_channel_tile,
427*4bdc9457SAndroid Build Coastguard Worker kernel_height, kernel_width,
428*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr, NULL);
429*4bdc9457SAndroid Build Coastguard Worker
430*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache(convolution_op)) {
431*4bdc9457SAndroid Build Coastguard Worker convolution_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
432*4bdc9457SAndroid Build Coastguard Worker convolution_op->weights_cache, weights_ptr, aligned_total_weights_size);
433*4bdc9457SAndroid Build Coastguard Worker }
434*4bdc9457SAndroid Build Coastguard Worker
435*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.conv2d = (struct xnn_ukernel_conv2d) {
436*4bdc9457SAndroid Build Coastguard Worker .hwc2chw_function = xnn_params.f32.conv_hwc2chw_3x3c3s2.ukernel_with_symm_padding,
437*4bdc9457SAndroid Build Coastguard Worker .output_height_tile = xnn_params.f32.conv_hwc2chw_3x3c3s2.output_height_tile,
438*4bdc9457SAndroid Build Coastguard Worker .output_channel_tile = xnn_params.f32.conv_hwc2chw_3x3c3s2.output_channel_tile,
439*4bdc9457SAndroid Build Coastguard Worker };
440*4bdc9457SAndroid Build Coastguard Worker
441*4bdc9457SAndroid Build Coastguard Worker break;
442*4bdc9457SAndroid Build Coastguard Worker }
443*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_dwconv:
444*4bdc9457SAndroid Build Coastguard Worker {
445*4bdc9457SAndroid Build Coastguard Worker assert(dwconv2d_parameters != NULL);
446*4bdc9457SAndroid Build Coastguard Worker assert(group_input_channels == 1);
447*4bdc9457SAndroid Build Coastguard Worker assert(group_output_channels == 1);
448*4bdc9457SAndroid Build Coastguard Worker
449*4bdc9457SAndroid Build Coastguard Worker const size_t packed_weights_size = groups * (kernel_height * kernel_width + 1 /* bias */) * sizeof(float);
450*4bdc9457SAndroid Build Coastguard Worker size_t aligned_total_weights_size = round_up_po2(packed_weights_size, XNN_ALLOCATION_ALIGNMENT);
451*4bdc9457SAndroid Build Coastguard Worker void* weights_ptr = xnn_get_pointer_to_write_weights(
452*4bdc9457SAndroid Build Coastguard Worker convolution_op, aligned_total_weights_size, 0);
453*4bdc9457SAndroid Build Coastguard Worker if (weights_ptr == NULL) {
454*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to reserve or allocate %zu bytes for %s operator dwconv packed weights",
455*4bdc9457SAndroid Build Coastguard Worker aligned_total_weights_size,
456*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
457*4bdc9457SAndroid Build Coastguard Worker goto error;
458*4bdc9457SAndroid Build Coastguard Worker }
459*4bdc9457SAndroid Build Coastguard Worker
460*4bdc9457SAndroid Build Coastguard Worker if (flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) {
461*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_chw_dwconv_hwg_w(
462*4bdc9457SAndroid Build Coastguard Worker kernel_height * kernel_width, groups,
463*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr, NULL);
464*4bdc9457SAndroid Build Coastguard Worker } else {
465*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_chw_dwconv_ghw_w(
466*4bdc9457SAndroid Build Coastguard Worker kernel_height * kernel_width, groups,
467*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr, NULL);
468*4bdc9457SAndroid Build Coastguard Worker }
469*4bdc9457SAndroid Build Coastguard Worker
470*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache(convolution_op)) {
471*4bdc9457SAndroid Build Coastguard Worker convolution_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
472*4bdc9457SAndroid Build Coastguard Worker convolution_op->weights_cache, weights_ptr, aligned_total_weights_size);
473*4bdc9457SAndroid Build Coastguard Worker }
474*4bdc9457SAndroid Build Coastguard Worker
475*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.dwconv2d = (struct xnn_ukernel_dwconv2d) {
476*4bdc9457SAndroid Build Coastguard Worker .chw_function = dwconv2d_parameters->ukernel,
477*4bdc9457SAndroid Build Coastguard Worker .output_width_tile = dwconv2d_parameters->output_width_tile,
478*4bdc9457SAndroid Build Coastguard Worker };
479*4bdc9457SAndroid Build Coastguard Worker
480*4bdc9457SAndroid Build Coastguard Worker break;
481*4bdc9457SAndroid Build Coastguard Worker }
482*4bdc9457SAndroid Build Coastguard Worker default:
483*4bdc9457SAndroid Build Coastguard Worker XNN_UNREACHABLE;
484*4bdc9457SAndroid Build Coastguard Worker }
485*4bdc9457SAndroid Build Coastguard Worker
486*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_top = input_padding_top;
487*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_right = input_padding_right;
488*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_bottom = input_padding_bottom;
489*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_left = input_padding_left;
490*4bdc9457SAndroid Build Coastguard Worker
491*4bdc9457SAndroid Build Coastguard Worker convolution_op->kernel_height = kernel_height;
492*4bdc9457SAndroid Build Coastguard Worker convolution_op->kernel_width = kernel_width;
493*4bdc9457SAndroid Build Coastguard Worker convolution_op->stride_height = subsampling_height;
494*4bdc9457SAndroid Build Coastguard Worker convolution_op->stride_width = subsampling_width;
495*4bdc9457SAndroid Build Coastguard Worker convolution_op->dilation_height = dilation_height;
496*4bdc9457SAndroid Build Coastguard Worker convolution_op->dilation_width = dilation_width;
497*4bdc9457SAndroid Build Coastguard Worker convolution_op->groups = groups;
498*4bdc9457SAndroid Build Coastguard Worker convolution_op->group_input_channels = group_input_channels;
499*4bdc9457SAndroid Build Coastguard Worker convolution_op->group_output_channels = group_output_channels;
500*4bdc9457SAndroid Build Coastguard Worker convolution_op->input_pixel_stride = input_channel_stride;
501*4bdc9457SAndroid Build Coastguard Worker convolution_op->output_pixel_stride = output_channel_stride;
502*4bdc9457SAndroid Build Coastguard Worker
503*4bdc9457SAndroid Build Coastguard Worker if (ukernel_type == xnn_ukernel_type_dwconv) {
504*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_chw_params(&convolution_op->params.f32_chw, 0, output_min, output_max);
505*4bdc9457SAndroid Build Coastguard Worker } else {
506*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_params(&convolution_op->params.f32_minmax, output_min, output_max);
507*4bdc9457SAndroid Build Coastguard Worker }
508*4bdc9457SAndroid Build Coastguard Worker
509*4bdc9457SAndroid Build Coastguard Worker convolution_op->type = xnn_operator_type_convolution_nchw_f32;
510*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.type = ukernel_type;
511*4bdc9457SAndroid Build Coastguard Worker convolution_op->flags = flags;
512*4bdc9457SAndroid Build Coastguard Worker
513*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_invalid;
514*4bdc9457SAndroid Build Coastguard Worker
515*4bdc9457SAndroid Build Coastguard Worker *convolution_op_out = convolution_op;
516*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
517*4bdc9457SAndroid Build Coastguard Worker
518*4bdc9457SAndroid Build Coastguard Worker error:
519*4bdc9457SAndroid Build Coastguard Worker xnn_delete_operator(convolution_op);
520*4bdc9457SAndroid Build Coastguard Worker return status;
521*4bdc9457SAndroid Build Coastguard Worker }
522*4bdc9457SAndroid Build Coastguard Worker
// Configures a previously created NCHW convolution operator for a concrete
// batch size and input height/width. It derives the output spatial size,
// builds the compute context for the microkernel type selected at creation
// time (sparse spmm, conv2d_hwc2chw, or depthwise dwconv), and marks the
// operator ready to run.
//
// Parameters:
//   convolution_op           - operator created for NCHW convolution.
//   batch_size               - number of images; 0 marks the operator as skipped.
//   input_height/input_width - spatial size of each input image; must be non-zero.
//   input, output            - base pointers of the input and output tensors.
//   log2_input_element_size  - log2 of the input element size in bytes.
//   log2_filter_element_size - log2 of the packed filter element size in bytes
//                              (used for the dwconv per-channel weights stride).
//   bias_element_size        - size in bytes of one packed bias element (dwconv).
//   log2_output_element_size - log2 of the output element size in bytes.
//   params                   - microkernel params copied into the spmm/conv2d context.
//   chw_params               - CHW params; updated with input_width, then copied
//                              into the dwconv context.
//   num_threads              - thread count used to size work tiles for spmm and
//                              conv2d_hwc2chw (ignored when XNN_TEST_MODE is set).
//
// Returns:
//   xnn_status_success               - operator is ready (or skipped when batch_size == 0).
//   xnn_status_uninitialized         - XNNPACK has not been initialized.
//   xnn_status_invalid_parameter     - input_height or input_width is zero.
//   xnn_status_invalid_state         - the attached weights cache is not finalized.
//   xnn_status_unsupported_parameter - a sparse input increment exceeds int32_t range.
//   xnn_status_out_of_memory         - the zero-padding buffer could not be allocated.
// On every failure path the operator state remains xnn_run_state_invalid.
static enum xnn_status setup_convolution2d_nchw(
  xnn_operator_t convolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const void* input,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  void* chw_params,
  size_t num_threads)
{
  // Invalidate first so that any early error return leaves the operator unusable.
  convolution_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup %s operator with %zux%zu input: input dimensions must be non-zero",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32), input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  if (batch_size == 0) {
    convolution_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  if (convolution_op->weights_cache != NULL && !xnn_weights_cache_is_finalized(convolution_op->weights_cache)) {
    xnn_log_error("failed to setup %s operator: weights cache is not finalized",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
    return xnn_status_invalid_state;
  }

  convolution_op->batch_size = batch_size;
  convolution_op->input_height = input_height;
  convolution_op->input_width = input_width;
  convolution_op->input = input;
  convolution_op->output = output;

  const size_t output_height = xnn_compute_convolution_output_dimension(
    convolution_op->padding_top + input_height + convolution_op->padding_bottom,
    convolution_op->kernel_height,
    convolution_op->dilation_height,
    convolution_op->stride_height);
  const size_t output_width = xnn_compute_convolution_output_dimension(
    convolution_op->padding_left + input_width + convolution_op->padding_right,
    convolution_op->kernel_width,
    convolution_op->dilation_width,
    convolution_op->stride_width);

  // Byte distance between consecutive images in the batch.
  const size_t input_batch_stride = (input_height * input_width * convolution_op->input_pixel_stride) << log2_input_element_size;
  const size_t output_batch_stride = (output_height * output_width * convolution_op->output_pixel_stride) << log2_output_element_size;
  switch (convolution_op->ukernel.type) {
    case xnn_ukernel_type_spmm:
    {
      // Sizes of the sparse weight representation recorded at creation time.
      const size_t num_nonzero_values = convolution_op->num_nonzero_values;
      const size_t num_nonzero_blocks = convolution_op->num_nonzero_blocks;
      const size_t num_output_channel_blocks = convolution_op->num_output_channel_blocks;

      // Packed weights are laid out back to back:
      //   [float nonzero values (+ group_output_channels extra slots, presumably biases)]
      //   [int32_t input increments, one per nonzero block]
      //   [uint32_t nonzero counts, one per output channel block]
      //   [int32_t input channel diffs, one per nonzero block]
      float* nonzero_values = packed_weights(convolution_op);
      int32_t* input_increments = (int32_t*) (nonzero_values + num_nonzero_values + convolution_op->group_output_channels);
      uint32_t* output_channel_nonzeros = (uint32_t*) (input_increments + num_nonzero_blocks);
      int32_t* input_channel_diffs = (int32_t*) (output_channel_nonzeros + num_output_channel_blocks);

      // The channel diffs were pre-scaled by sizeof(float) at creation. Scaling
      // them by the spatial size turns them into byte increments over the NCHW
      // input, which depend on the input size and so are computed here.
      const size_t input_size = input_height * input_width;
      for (size_t i = 0; i < num_nonzero_blocks; i++) {
        const int32_t diff = input_channel_diffs[i];
        const int64_t increment = (int64_t) diff * input_size;
        if ((int64_t) (int32_t) increment != increment) {
          xnn_log_error(
            "failed to setup %s operator with sparse kernel representation: input increment exceeds int32_t range",
            xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
          return xnn_status_unsupported_parameter;
        }
        input_increments[i] = (int32_t) increment;
      }

      convolution_op->context.spmm = (struct spmm_context) {
        .n = convolution_op->group_output_channels,
        .scaled_m = input_size * sizeof(float),
        // Start at the first input channel that has any nonzero weight.
        .input = (const void*) ((uintptr_t) input + (convolution_op->first_input_channel * input_size * sizeof(float))),
        .nonzero_weights = nonzero_values,
        .input_increments = input_increments,
        .output_channel_nonzeros = output_channel_nonzeros,
        .output = output,
        .batched_input_stride = input_batch_stride,
        .batched_output_stride = output_batch_stride,
        .ukernel = convolution_op->ukernel.spmm.function,
      };
      memcpy(&convolution_op->context.spmm.params, params, sizeof(convolution_op->context.spmm.params));

      const size_t mr = convolution_op->ukernel.spmm.mr;
      #if XNN_TEST_MODE
        const size_t mc = mr;
      #else
        size_t mc = input_size;
        if (num_threads > 1) {
          // Shrink the tile so the work is split into more pieces than threads.
          // NOTE(review): this rounds ceil(mc / (max_mc * mr)) up to a multiple of
          // mr rather than rounding max_mc itself; confirm this is the intended tiling.
          const size_t target_tiles_per_thread = 5;
          const size_t max_mc = divide_round_up(input_size, num_threads * target_tiles_per_thread);
          if (max_mc < mc) {
            mc = min(mc, divide_round_up(mc, max_mc * mr) * mr);
          }
        }
      #endif
      convolution_op->compute.type = xnn_parallelization_type_2d_tile_1d;
      convolution_op->compute.task_2d_tile_1d = (pthreadpool_task_2d_tile_1d_t) xnn_compute_spmm;
      convolution_op->compute.range[0] = batch_size;
      convolution_op->compute.range[1] = input_size * sizeof(float);
      convolution_op->compute.tile[0] = mc * sizeof(float);
      convolution_op->state = xnn_run_state_ready;

      return xnn_status_success;
    }
    case xnn_ukernel_type_conv2d_hwc2chw:
    {
      // Zero row used for padding: one input row of input_width pixels with
      // group_input_channels channels each, plus XNN_EXTRA_BYTES of slack.
      const size_t zero_size = (input_width * convolution_op->group_input_channels << log2_input_element_size) + XNN_EXTRA_BYTES;
      void* zero_buffer = xnn_reallocate_memory(convolution_op->zero_buffer, zero_size);
      if (zero_buffer == NULL) {
        // Report the size that was actually requested, not the operator struct size.
        xnn_log_error(
          "failed to allocate %zu bytes for %s operator zero padding",
          zero_size, xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
        return xnn_status_out_of_memory;
      }
      memset(zero_buffer, 0, zero_size);
      convolution_op->zero_buffer = zero_buffer;

      convolution_op->context.conv2d = (struct conv2d_context) {
        .input_height = input_height,
        .input_width = input_width,
        .input = input,
        .input_batch_stride = input_batch_stride,
        .zero = zero_buffer,
        .packed_weights = packed_weights(convolution_op),
        .output = output,
        .output_batch_stride = output_batch_stride,
        .input_padding_top = convolution_op->padding_top,
        .output_channels = convolution_op->group_output_channels,
        .output_height_stride = output_width << log2_output_element_size,
        .output_channel_stride = output_height * output_width << log2_output_element_size,
        .hwc2chw_ukernel = convolution_op->ukernel.conv2d.hwc2chw_function,
      };
      memcpy(&convolution_op->context.conv2d.params, params, sizeof(convolution_op->context.conv2d.params));

      const size_t output_height_tile = convolution_op->ukernel.conv2d.output_height_tile;
      #if XNN_TEST_MODE
        size_t output_height_slice = output_height_tile;
      #else
        size_t output_height_slice = output_height;
        if (num_threads > 1) {
          // Same tiling heuristic as the spmm path, applied to output rows.
          const size_t target_tiles_per_thread = 5;
          const size_t max_output_height_slice = divide_round_up(output_height, num_threads * target_tiles_per_thread);
          if (max_output_height_slice < output_height_slice) {
            output_height_slice = min(output_height_slice,
              divide_round_up(output_height_slice, max_output_height_slice * output_height_tile) * output_height_tile);
          }
        }
      #endif
      convolution_op->compute.type = xnn_parallelization_type_2d_tile_1d;
      convolution_op->compute.task_2d_tile_1d = (pthreadpool_task_2d_tile_1d_t) xnn_compute_conv2d_hwc2chw;
      convolution_op->compute.range[0] = batch_size;
      convolution_op->compute.range[1] = output_height;
      convolution_op->compute.tile[0] = output_height_slice;
      convolution_op->state = xnn_run_state_ready;

      return xnn_status_success;
    }
    case xnn_ukernel_type_dwconv:
    {
      // Zero row used for padding: one input row plus 2 * XNN_EXTRA_BYTES of slack.
      const size_t zero_size = (input_width << log2_input_element_size) + 2 * XNN_EXTRA_BYTES;
      void* zero_buffer = xnn_reallocate_memory(convolution_op->zero_buffer, zero_size);
      if (zero_buffer == NULL) {
        // Report the size that was actually requested, not the operator struct size.
        xnn_log_error(
          "failed to allocate %zu bytes for %s operator zero padding",
          zero_size, xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32));
        return xnn_status_out_of_memory;
      }
      memset(zero_buffer, 0, zero_size);
      convolution_op->zero_buffer = zero_buffer;

      // CHW params depend on the input width, so they are refreshed on every setup.
      xnn_update_f32_chw_params((union xnn_f32_chw_params*) chw_params, (uint32_t) input_width);
      convolution_op->context.dwconv2d = (struct dwconv2d_context) {
        .input_height = input_height,
        .input_width = input_width << log2_input_element_size,
        .input = input,
        .zero = zero_buffer,
        .input_padding_top = convolution_op->padding_top,
        .input_channel_stride = input_height * input_width << log2_input_element_size,
        .input_batch_stride = input_batch_stride,
        .packed_weights = packed_weights(convolution_op),
        // Each channel's packed weights: one bias followed by kernel_height * kernel_width taps.
        .weights_channel_stride = bias_element_size +
          (convolution_op->kernel_height * convolution_op->kernel_width << log2_filter_element_size),
        .output = output,
        .output_channel_stride = output_height * output_width << log2_output_element_size,
        .output_batch_stride = output_batch_stride,
        .chw_ukernel = convolution_op->ukernel.dwconv2d.chw_function,
      };
      memcpy(&convolution_op->context.dwconv2d.params, chw_params, sizeof(convolution_op->context.dwconv2d.params));

      convolution_op->compute.type = xnn_parallelization_type_2d;
      convolution_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_dwconv2d_chw;
      convolution_op->compute.range[0] = batch_size;
      convolution_op->compute.range[1] = convolution_op->groups;
      convolution_op->state = xnn_run_state_ready;

      return xnn_status_success;
    }
    default:
      XNN_UNREACHABLE;
  }
}
745*4bdc9457SAndroid Build Coastguard Worker
// Prepares an FP32 NCHW 2D convolution operator to run on `batch_size` images
// of `input_height` x `input_width` pixels, reading from `input` and writing
// to `output`.
//
// If `convolution_op` was not created as an FP32 NCHW convolution, an error is
// logged and xnn_status_invalid_parameter is returned. Otherwise the work is
// delegated to the shared setup_convolution2d_nchw helper. The helper receives
// FP32 element sizes, the operator's f32 minmax and f32 chw params, and the
// thread count reported for `threadpool`. Its status is returned unchanged.
enum xnn_status xnn_setup_convolution2d_nchw_f32(
  xnn_operator_t convolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const float* input,
  float* output,
  pthreadpool_t threadpool)
{
  // Input, filter and output elements are all floats: log2(sizeof(float)) == 2.
  enum { LOG2_SIZEOF_FLOAT = 2 };

  if (convolution_op->type == xnn_operator_type_convolution_nchw_f32) {
    return setup_convolution2d_nchw(
      convolution_op,
      batch_size, input_height, input_width,
      input, output,
      LOG2_SIZEOF_FLOAT /* log2(sizeof(input element)) */,
      LOG2_SIZEOF_FLOAT /* log2(sizeof(filter element)) */,
      sizeof(float) /* sizeof(bias element), in bytes (not log2) */,
      LOG2_SIZEOF_FLOAT /* log2(sizeof(output element)) */,
      &convolution_op->params.f32_minmax,
      &convolution_op->params.f32_chw,
      pthreadpool_get_threads_count(threadpool));
  }

  // The operator was created by a different constructor; refuse to set it up.
  xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
    xnn_operator_type_to_string(xnn_operator_type_convolution_nchw_f32),
    xnn_operator_type_to_string(convolution_op->type));
  return xnn_status_invalid_parameter;
}
774