// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <math.h>
11*4bdc9457SAndroid Build Coastguard Worker #include <stdbool.h>
12*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
13*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <stdlib.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <string.h>
16*4bdc9457SAndroid Build Coastguard Worker
17*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
18*4bdc9457SAndroid Build Coastguard Worker
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h>
22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/common.h>
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/compute.h>
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/indirection.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
28*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h>
29*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/params.h>
30*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/post-operation.h>
31*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
32*4bdc9457SAndroid Build Coastguard Worker
33*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_ENABLE_GEMM_M_SPECIALIZATION
34*4bdc9457SAndroid Build Coastguard Worker #error "XNN_ENABLE_GEMM_M_SPECIALIZATION is not defined"
35*4bdc9457SAndroid Build Coastguard Worker #endif
36*4bdc9457SAndroid Build Coastguard Worker
// Output extent along one spatial dimension when TensorFlow SAME padding is
// in effect: the input extent divided by the subsampling factor, rounded up.
// The division itself is delegated to divide_round_up().
static inline size_t compute_output_dimension_with_tf_same_padding(
    size_t input_dimension,
    size_t subsampling_dimension)
{
  const size_t output_dimension = divide_round_up(input_dimension, subsampling_dimension);
  return output_dimension;
}
43*4bdc9457SAndroid Build Coastguard Worker
find_dwconv_ukernel(size_t kernel_size,const struct dwconv_parameters * ukernel,size_t num_ukernels)44*4bdc9457SAndroid Build Coastguard Worker static inline const struct dwconv_parameters* find_dwconv_ukernel(
45*4bdc9457SAndroid Build Coastguard Worker size_t kernel_size,
46*4bdc9457SAndroid Build Coastguard Worker const struct dwconv_parameters* ukernel,
47*4bdc9457SAndroid Build Coastguard Worker size_t num_ukernels)
48*4bdc9457SAndroid Build Coastguard Worker {
49*4bdc9457SAndroid Build Coastguard Worker while (num_ukernels-- != 0) {
50*4bdc9457SAndroid Build Coastguard Worker if (ukernel->primary_tile == kernel_size) {
51*4bdc9457SAndroid Build Coastguard Worker return ukernel;
52*4bdc9457SAndroid Build Coastguard Worker }
53*4bdc9457SAndroid Build Coastguard Worker ukernel++;
54*4bdc9457SAndroid Build Coastguard Worker }
55*4bdc9457SAndroid Build Coastguard Worker return NULL;
56*4bdc9457SAndroid Build Coastguard Worker }
57*4bdc9457SAndroid Build Coastguard Worker
58*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT
// Converts an offset into the operator's code cache buffer to an absolute
// address: the buffer's start address plus offset.
// op->code_cache is dereferenced without a NULL check.
static inline uintptr_t cached_code_at_offset(xnn_operator_t op, size_t offset)
{
  const uintptr_t code_start = (uintptr_t) op->code_cache->cache.code.start;
  return code_start + offset;
}
63*4bdc9457SAndroid Build Coastguard Worker
get_generated_gemm(struct xnn_hmp_gemm_codegen generators,struct jit_gemm_params * jit_gemm_params,size_t mr,size_t group_output_channels,size_t nr,size_t group_input_channels,size_t log2_input_element_size,struct xnn_code_cache * code_cache)64*4bdc9457SAndroid Build Coastguard Worker static size_t get_generated_gemm(
65*4bdc9457SAndroid Build Coastguard Worker struct xnn_hmp_gemm_codegen generators,
66*4bdc9457SAndroid Build Coastguard Worker struct jit_gemm_params *jit_gemm_params,
67*4bdc9457SAndroid Build Coastguard Worker size_t mr,
68*4bdc9457SAndroid Build Coastguard Worker size_t group_output_channels,
69*4bdc9457SAndroid Build Coastguard Worker size_t nr,
70*4bdc9457SAndroid Build Coastguard Worker size_t group_input_channels,
71*4bdc9457SAndroid Build Coastguard Worker size_t log2_input_element_size,
72*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_cache* code_cache)
73*4bdc9457SAndroid Build Coastguard Worker {
74*4bdc9457SAndroid Build Coastguard Worker size_t offset = XNN_CACHE_NOT_FOUND;
75*4bdc9457SAndroid Build Coastguard Worker xnn_jit_gemm_code_generator_function generator = generators.function[XNN_UARCH_DEFAULT];
76*4bdc9457SAndroid Build Coastguard Worker if (generator == NULL) {
77*4bdc9457SAndroid Build Coastguard Worker goto error;
78*4bdc9457SAndroid Build Coastguard Worker }
79*4bdc9457SAndroid Build Coastguard Worker
80*4bdc9457SAndroid Build Coastguard Worker enum xnn_status status = xnn_status_success;
81*4bdc9457SAndroid Build Coastguard Worker
82*4bdc9457SAndroid Build Coastguard Worker status = xnn_reserve_code_memory(&code_cache->cache.code, XNN_DEFAULT_MICROKERNEL_SIZE);
83*4bdc9457SAndroid Build Coastguard Worker if (xnn_status_success != status) {
84*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to ensure sufficient space in the code buffer for a microkernel");
85*4bdc9457SAndroid Build Coastguard Worker goto error;
86*4bdc9457SAndroid Build Coastguard Worker }
87*4bdc9457SAndroid Build Coastguard Worker
88*4bdc9457SAndroid Build Coastguard Worker const size_t old_size = code_cache->cache.code.size;
89*4bdc9457SAndroid Build Coastguard Worker void* old_code = (uint8_t*) code_cache->cache.code.start + old_size;
90*4bdc9457SAndroid Build Coastguard Worker status = generator(&code_cache->cache.code, mr, group_output_channels % nr,
91*4bdc9457SAndroid Build Coastguard Worker group_input_channels << log2_input_element_size,
92*4bdc9457SAndroid Build Coastguard Worker jit_gemm_params);
93*4bdc9457SAndroid Build Coastguard Worker
94*4bdc9457SAndroid Build Coastguard Worker if (xnn_status_success != status) {
95*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to generate GEMM microkernel");
96*4bdc9457SAndroid Build Coastguard Worker goto error;
97*4bdc9457SAndroid Build Coastguard Worker }
98*4bdc9457SAndroid Build Coastguard Worker
99*4bdc9457SAndroid Build Coastguard Worker const size_t new_size = code_cache->cache.code.size;
100*4bdc9457SAndroid Build Coastguard Worker return xnn_get_or_insert_code_cache(code_cache, old_code, new_size - old_size);
101*4bdc9457SAndroid Build Coastguard Worker
102*4bdc9457SAndroid Build Coastguard Worker error:
103*4bdc9457SAndroid Build Coastguard Worker return offset;
104*4bdc9457SAndroid Build Coastguard Worker }
105*4bdc9457SAndroid Build Coastguard Worker
// Generates GEMM microkernels for every mr in [1, max_mr] and records each
// result in gemm_cases[mr - 1].generated_code_offset[XNN_UARCH_DEFAULT].
// mr == 1 uses generators.gemm1 and is always generated; larger values use
// generators.gemm. Does nothing when the operator has no code cache.
static void generate_gemms_up_to_max_mr(
    size_t max_mr,
    struct gemm_codegens generators,
    struct jit_gemm_params *jit_gemm_params,
    size_t group_output_channels,
    size_t nr,
    size_t group_input_channels,
    size_t log2_input_element_size,
    xnn_operator_t convolution_op)
{
  assert(XNN_MAX_MR >= max_mr);
  if (convolution_op->code_cache != NULL) {
    // The single-row case has its own dedicated generator.
    convolution_op->ukernel.gemm.gemm_cases[0].generated_code_offset[XNN_UARCH_DEFAULT] =
      get_generated_gemm(
        generators.gemm1, jit_gemm_params, 1, group_output_channels, nr,
        group_input_channels, log2_input_element_size, convolution_op->code_cache);

    size_t mr = 2;
    while (mr <= max_mr) {
      const size_t case_index = mr - 1;
      convolution_op->ukernel.gemm.gemm_cases[case_index].generated_code_offset[XNN_UARCH_DEFAULT] =
        get_generated_gemm(
          generators.gemm, jit_gemm_params, mr, group_output_channels, nr,
          group_input_channels, log2_input_element_size, convolution_op->code_cache);
      mr++;
    }
  }
}
129*4bdc9457SAndroid Build Coastguard Worker
get_generated_igemm(struct xnn_hmp_igemm_codegen generators,struct jit_gemm_params * jit_gemm_params,size_t group_output_channels,size_t nr,size_t group_input_channels,size_t log2_input_element_size,size_t kernel_size,size_t mr,struct xnn_code_cache * code_cache)130*4bdc9457SAndroid Build Coastguard Worker static size_t get_generated_igemm(
131*4bdc9457SAndroid Build Coastguard Worker struct xnn_hmp_igemm_codegen generators,
132*4bdc9457SAndroid Build Coastguard Worker struct jit_gemm_params *jit_gemm_params,
133*4bdc9457SAndroid Build Coastguard Worker size_t group_output_channels,
134*4bdc9457SAndroid Build Coastguard Worker size_t nr,
135*4bdc9457SAndroid Build Coastguard Worker size_t group_input_channels,
136*4bdc9457SAndroid Build Coastguard Worker size_t log2_input_element_size,
137*4bdc9457SAndroid Build Coastguard Worker size_t kernel_size,
138*4bdc9457SAndroid Build Coastguard Worker size_t mr,
139*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_cache* code_cache)
140*4bdc9457SAndroid Build Coastguard Worker {
141*4bdc9457SAndroid Build Coastguard Worker size_t offset = XNN_CACHE_NOT_FOUND;
142*4bdc9457SAndroid Build Coastguard Worker xnn_jit_igemm_code_generator_function generator = generators.function[XNN_UARCH_DEFAULT];
143*4bdc9457SAndroid Build Coastguard Worker if (generator == NULL) {
144*4bdc9457SAndroid Build Coastguard Worker goto error;
145*4bdc9457SAndroid Build Coastguard Worker }
146*4bdc9457SAndroid Build Coastguard Worker enum xnn_status status = xnn_status_success;
147*4bdc9457SAndroid Build Coastguard Worker
148*4bdc9457SAndroid Build Coastguard Worker status = xnn_reserve_code_memory(&code_cache->cache.code, XNN_DEFAULT_MICROKERNEL_SIZE);
149*4bdc9457SAndroid Build Coastguard Worker if (xnn_status_success != status) {
150*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to ensure sufficient space in code buffer for microkernel");
151*4bdc9457SAndroid Build Coastguard Worker goto error;
152*4bdc9457SAndroid Build Coastguard Worker }
153*4bdc9457SAndroid Build Coastguard Worker
154*4bdc9457SAndroid Build Coastguard Worker const size_t old_size = code_cache->cache.code.size;
155*4bdc9457SAndroid Build Coastguard Worker void* old_code = (uint8_t*) code_cache->cache.code.start + old_size;
156*4bdc9457SAndroid Build Coastguard Worker status = generator(&code_cache->cache.code, mr, group_output_channels % nr,
157*4bdc9457SAndroid Build Coastguard Worker group_input_channels << log2_input_element_size,
158*4bdc9457SAndroid Build Coastguard Worker kernel_size * mr * sizeof(void*), jit_gemm_params);
159*4bdc9457SAndroid Build Coastguard Worker if (status != xnn_status_success) {
160*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to generate IGEMM microkernel");
161*4bdc9457SAndroid Build Coastguard Worker goto error;
162*4bdc9457SAndroid Build Coastguard Worker }
163*4bdc9457SAndroid Build Coastguard Worker
164*4bdc9457SAndroid Build Coastguard Worker const size_t new_size = code_cache->cache.code.size;
165*4bdc9457SAndroid Build Coastguard Worker return xnn_get_or_insert_code_cache(code_cache, old_code, new_size - old_size);
166*4bdc9457SAndroid Build Coastguard Worker
167*4bdc9457SAndroid Build Coastguard Worker error:
168*4bdc9457SAndroid Build Coastguard Worker return offset;
169*4bdc9457SAndroid Build Coastguard Worker }
170*4bdc9457SAndroid Build Coastguard Worker
// Generates IGEMM microkernels for every mr in [1, max_mr] and records each
// result in igemm_cases[mr - 1].generated_code_offset[XNN_UARCH_DEFAULT].
// mr == 1 uses generators.igemm1 and is always generated; larger values use
// generators.igemm. Does nothing when the operator has no code cache.
static void generate_igemms_up_to_max_mr(
    size_t max_mr,
    struct gemm_codegens generators,
    struct jit_gemm_params *jit_gemm_params,
    size_t group_output_channels,
    size_t nr,
    size_t group_input_channels,
    size_t log2_input_element_size,
    size_t kernel_size,
    xnn_operator_t convolution_op)
{
  assert(XNN_MAX_MR >= max_mr);
  if (convolution_op->code_cache != NULL) {
    // The single-row case has its own dedicated generator.
    convolution_op->ukernel.igemm.igemm_cases[0].generated_code_offset[XNN_UARCH_DEFAULT] =
      get_generated_igemm(
        generators.igemm1, jit_gemm_params, group_output_channels, nr,
        group_input_channels, log2_input_element_size, kernel_size, 1,
        convolution_op->code_cache);

    size_t mr = 2;
    while (mr <= max_mr) {
      const size_t case_index = mr - 1;
      convolution_op->ukernel.igemm.igemm_cases[case_index].generated_code_offset[XNN_UARCH_DEFAULT] =
        get_generated_igemm(
          generators.igemm, jit_gemm_params, group_output_channels, nr,
          group_input_channels, log2_input_element_size, kernel_size, mr,
          convolution_op->code_cache);
      mr++;
    }
  }
}
195*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_PLATFORM_JIT
196*4bdc9457SAndroid Build Coastguard Worker
create_convolution2d_nhwc(uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t kernel_height,uint32_t kernel_width,uint32_t subsampling_height,uint32_t subsampling_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t groups,size_t group_input_channels,size_t group_output_channels,size_t input_channel_stride,size_t output_channel_stride,const void * kernel,const void * bias,uint32_t flags,uint32_t log2_input_element_size,uint32_t log2_filter_element_size,uint32_t bias_element_size,xnn_pack_vmulcaddc_w_function pack_vmulcaddc_w,xnn_pack_dwconv_hwg_w_function pack_dwconv_hwg_w,xnn_pack_dwconv_ghw_w_function pack_dwconv_ghw_w,xnn_pack_gemm_goi_w_function pack_gemm_goi_w,xnn_pack_conv_kgo_w_function pack_conv_kgo_w,xnn_pack_conv_goki_w_function pack_conv_goki_w,const void * packing_params,int input_padding_byte,int packed_weights_padding_byte,size_t extra_weights_bytes,xnn_init_qc8_scale_params_fn init_scale_params,const float * scale_params,const void * gemm_params,size_t gemm_params_size,const void * dwconv_params,size_t dwconv_params_size,const void * vmulcaddc_params,size_t vmulcaddc_params_size,const struct gemm_parameters * gemm_parameters,const struct dwconv_parameters * dwconv_ukernel,const struct vmulcaddc_parameters * vmulcaddc_parameters,struct jit_gemm_params * jit_gemm_params,bool linear_activation,bool relu_activation,uint32_t datatype_init_flags,enum xnn_operator_type operator_type,size_t num_post_operations,void * post_operation_params,xnn_caches_t caches,xnn_operator_t * convolution_op_out)197*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status create_convolution2d_nhwc(
198*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_top,
199*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_right,
200*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_bottom,
201*4bdc9457SAndroid Build Coastguard Worker uint32_t input_padding_left,
202*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_height,
203*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_width,
204*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_height,
205*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_width,
206*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_height,
207*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_width,
208*4bdc9457SAndroid Build Coastguard Worker uint32_t groups,
209*4bdc9457SAndroid Build Coastguard Worker size_t group_input_channels,
210*4bdc9457SAndroid Build Coastguard Worker size_t group_output_channels,
211*4bdc9457SAndroid Build Coastguard Worker size_t input_channel_stride,
212*4bdc9457SAndroid Build Coastguard Worker size_t output_channel_stride,
213*4bdc9457SAndroid Build Coastguard Worker const void* kernel,
214*4bdc9457SAndroid Build Coastguard Worker const void* bias,
215*4bdc9457SAndroid Build Coastguard Worker uint32_t flags,
216*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_input_element_size,
217*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_filter_element_size,
218*4bdc9457SAndroid Build Coastguard Worker uint32_t bias_element_size,
219*4bdc9457SAndroid Build Coastguard Worker xnn_pack_vmulcaddc_w_function pack_vmulcaddc_w,
220*4bdc9457SAndroid Build Coastguard Worker xnn_pack_dwconv_hwg_w_function pack_dwconv_hwg_w,
221*4bdc9457SAndroid Build Coastguard Worker xnn_pack_dwconv_ghw_w_function pack_dwconv_ghw_w,
222*4bdc9457SAndroid Build Coastguard Worker xnn_pack_gemm_goi_w_function pack_gemm_goi_w,
223*4bdc9457SAndroid Build Coastguard Worker xnn_pack_conv_kgo_w_function pack_conv_kgo_w,
224*4bdc9457SAndroid Build Coastguard Worker xnn_pack_conv_goki_w_function pack_conv_goki_w,
225*4bdc9457SAndroid Build Coastguard Worker const void* packing_params,
226*4bdc9457SAndroid Build Coastguard Worker int input_padding_byte,
227*4bdc9457SAndroid Build Coastguard Worker int packed_weights_padding_byte,
228*4bdc9457SAndroid Build Coastguard Worker size_t extra_weights_bytes,
229*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_scale_params_fn init_scale_params,
230*4bdc9457SAndroid Build Coastguard Worker const float* scale_params,
231*4bdc9457SAndroid Build Coastguard Worker const void* gemm_params,
232*4bdc9457SAndroid Build Coastguard Worker size_t gemm_params_size,
233*4bdc9457SAndroid Build Coastguard Worker const void* dwconv_params,
234*4bdc9457SAndroid Build Coastguard Worker size_t dwconv_params_size,
235*4bdc9457SAndroid Build Coastguard Worker const void* vmulcaddc_params,
236*4bdc9457SAndroid Build Coastguard Worker size_t vmulcaddc_params_size,
237*4bdc9457SAndroid Build Coastguard Worker const struct gemm_parameters* gemm_parameters,
238*4bdc9457SAndroid Build Coastguard Worker const struct dwconv_parameters* dwconv_ukernel,
239*4bdc9457SAndroid Build Coastguard Worker const struct vmulcaddc_parameters* vmulcaddc_parameters,
240*4bdc9457SAndroid Build Coastguard Worker struct jit_gemm_params* jit_gemm_params,
241*4bdc9457SAndroid Build Coastguard Worker bool linear_activation,
242*4bdc9457SAndroid Build Coastguard Worker bool relu_activation,
243*4bdc9457SAndroid Build Coastguard Worker uint32_t datatype_init_flags,
244*4bdc9457SAndroid Build Coastguard Worker enum xnn_operator_type operator_type,
245*4bdc9457SAndroid Build Coastguard Worker size_t num_post_operations,
246*4bdc9457SAndroid Build Coastguard Worker void* post_operation_params,
247*4bdc9457SAndroid Build Coastguard Worker xnn_caches_t caches,
248*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t* convolution_op_out)
249*4bdc9457SAndroid Build Coastguard Worker {
250*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = NULL;
251*4bdc9457SAndroid Build Coastguard Worker enum xnn_status status = xnn_status_uninitialized;
252*4bdc9457SAndroid Build Coastguard Worker
253*4bdc9457SAndroid Build Coastguard Worker if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
254*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
255*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator: XNNPACK is not initialized",
256*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type));
257*4bdc9457SAndroid Build Coastguard Worker goto error;
258*4bdc9457SAndroid Build Coastguard Worker }
259*4bdc9457SAndroid Build Coastguard Worker
260*4bdc9457SAndroid Build Coastguard Worker status = xnn_status_unsupported_hardware;
261*4bdc9457SAndroid Build Coastguard Worker
262*4bdc9457SAndroid Build Coastguard Worker if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
263*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
264*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator: operations on data type are not supported",
265*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type));
266*4bdc9457SAndroid Build Coastguard Worker goto error;
267*4bdc9457SAndroid Build Coastguard Worker }
268*4bdc9457SAndroid Build Coastguard Worker
269*4bdc9457SAndroid Build Coastguard Worker status = xnn_status_invalid_parameter;
270*4bdc9457SAndroid Build Coastguard Worker
271*4bdc9457SAndroid Build Coastguard Worker if (kernel_width == 0 || kernel_height == 0) {
272*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
273*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
274*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type), kernel_width, kernel_height);
275*4bdc9457SAndroid Build Coastguard Worker goto error;
276*4bdc9457SAndroid Build Coastguard Worker }
277*4bdc9457SAndroid Build Coastguard Worker
278*4bdc9457SAndroid Build Coastguard Worker if (subsampling_width == 0 || subsampling_height == 0) {
279*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
280*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "x%" PRIu32 " subsampling: subsampling dimensions must be non-zero",
281*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type), subsampling_width, subsampling_height);
282*4bdc9457SAndroid Build Coastguard Worker goto error;
283*4bdc9457SAndroid Build Coastguard Worker }
284*4bdc9457SAndroid Build Coastguard Worker
285*4bdc9457SAndroid Build Coastguard Worker if (dilation_width == 0 || dilation_height == 0) {
286*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
287*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
288*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type), dilation_width, dilation_height);
289*4bdc9457SAndroid Build Coastguard Worker goto error;
290*4bdc9457SAndroid Build Coastguard Worker }
291*4bdc9457SAndroid Build Coastguard Worker
292*4bdc9457SAndroid Build Coastguard Worker if (groups == 0) {
293*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
294*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 " groups: number of groups must be non-zero",
295*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type), groups);
296*4bdc9457SAndroid Build Coastguard Worker goto error;
297*4bdc9457SAndroid Build Coastguard Worker }
298*4bdc9457SAndroid Build Coastguard Worker
299*4bdc9457SAndroid Build Coastguard Worker if (group_input_channels == 0) {
300*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
301*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %zu input channels per group: number of channels must be non-zero",
302*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type), group_input_channels);
303*4bdc9457SAndroid Build Coastguard Worker goto error;
304*4bdc9457SAndroid Build Coastguard Worker }
305*4bdc9457SAndroid Build Coastguard Worker
306*4bdc9457SAndroid Build Coastguard Worker if (group_output_channels == 0) {
307*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
308*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %zu output channels per group: number of channels must be non-zero",
309*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type), group_output_channels);
310*4bdc9457SAndroid Build Coastguard Worker goto error;
311*4bdc9457SAndroid Build Coastguard Worker }
312*4bdc9457SAndroid Build Coastguard Worker
313*4bdc9457SAndroid Build Coastguard Worker const size_t input_channels = groups * group_input_channels;
314*4bdc9457SAndroid Build Coastguard Worker if (input_channel_stride < input_channels) {
315*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
316*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with input channel stride of %zu: "
317*4bdc9457SAndroid Build Coastguard Worker "stride must be at least as large as the number of input channels (%" PRIu32 "x%zu)",
318*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type),
319*4bdc9457SAndroid Build Coastguard Worker input_channel_stride, groups, group_input_channels);
320*4bdc9457SAndroid Build Coastguard Worker goto error;
321*4bdc9457SAndroid Build Coastguard Worker }
322*4bdc9457SAndroid Build Coastguard Worker
323*4bdc9457SAndroid Build Coastguard Worker const size_t output_channels = groups * group_output_channels;
324*4bdc9457SAndroid Build Coastguard Worker if (output_channel_stride < output_channels) {
325*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
326*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with output channel stride of %zu: "
327*4bdc9457SAndroid Build Coastguard Worker "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
328*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type),
329*4bdc9457SAndroid Build Coastguard Worker output_channel_stride, groups, group_output_channels);
330*4bdc9457SAndroid Build Coastguard Worker goto error;
331*4bdc9457SAndroid Build Coastguard Worker }
332*4bdc9457SAndroid Build Coastguard Worker
333*4bdc9457SAndroid Build Coastguard Worker if ((flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) != 0 && group_input_channels != 1) {
334*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
335*4bdc9457SAndroid Build Coastguard Worker "failed to create depthwise %s operator with %zu input channels per group: "
336*4bdc9457SAndroid Build Coastguard Worker "depthwise convolution must have exactly 1 input channel per group",
337*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type), group_input_channels);
338*4bdc9457SAndroid Build Coastguard Worker goto error;
339*4bdc9457SAndroid Build Coastguard Worker }
340*4bdc9457SAndroid Build Coastguard Worker
341*4bdc9457SAndroid Build Coastguard Worker const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
342*4bdc9457SAndroid Build Coastguard Worker if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0) {
343*4bdc9457SAndroid Build Coastguard Worker if (any_padding) {
344*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
345*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator with %" PRIu32 "+%" PRIu32 "x%" PRIu32 "+%" PRIu32" padding: "
346*4bdc9457SAndroid Build Coastguard Worker "TensorFlow SAME padding can't be combined with explicit padding specification",
347*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(operator_type),
348*4bdc9457SAndroid Build Coastguard Worker input_padding_top, input_padding_left, input_padding_bottom, input_padding_right);
349*4bdc9457SAndroid Build Coastguard Worker goto error;
350*4bdc9457SAndroid Build Coastguard Worker }
351*4bdc9457SAndroid Build Coastguard Worker }
352*4bdc9457SAndroid Build Coastguard Worker
353*4bdc9457SAndroid Build Coastguard Worker status = xnn_status_out_of_memory;
354*4bdc9457SAndroid Build Coastguard Worker
355*4bdc9457SAndroid Build Coastguard Worker convolution_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
356*4bdc9457SAndroid Build Coastguard Worker if (convolution_op == NULL) {
357*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
358*4bdc9457SAndroid Build Coastguard Worker "failed to allocate %zu bytes for %s operator descriptor",
359*4bdc9457SAndroid Build Coastguard Worker sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
360*4bdc9457SAndroid Build Coastguard Worker goto error;
361*4bdc9457SAndroid Build Coastguard Worker }
362*4bdc9457SAndroid Build Coastguard Worker
363*4bdc9457SAndroid Build Coastguard Worker if (caches != NULL) {
364*4bdc9457SAndroid Build Coastguard Worker convolution_op->weights_cache = caches->weights_cache;
365*4bdc9457SAndroid Build Coastguard Worker convolution_op->code_cache = caches->code_cache;
366*4bdc9457SAndroid Build Coastguard Worker }
367*4bdc9457SAndroid Build Coastguard Worker
368*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_size = kernel_height * kernel_width;
369*4bdc9457SAndroid Build Coastguard Worker
370*4bdc9457SAndroid Build Coastguard Worker enum xnn_ukernel_type ukernel_type = xnn_ukernel_type_default;
371*4bdc9457SAndroid Build Coastguard Worker const bool unit_subsampling = (subsampling_width | subsampling_height) == 1;
372*4bdc9457SAndroid Build Coastguard Worker if (group_input_channels == 1 && group_output_channels == 1 && kernel_size == 1 && unit_subsampling && !any_padding && vmulcaddc_parameters != NULL) {
373*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_vmulcaddc;
374*4bdc9457SAndroid Build Coastguard Worker } else if (group_input_channels == 1 && group_output_channels == 1 && dwconv_ukernel != NULL)
375*4bdc9457SAndroid Build Coastguard Worker {
376*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_dwconv;
377*4bdc9457SAndroid Build Coastguard Worker } else if (kernel_size == 1 && unit_subsampling && !any_padding) {
378*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_gemm;
379*4bdc9457SAndroid Build Coastguard Worker } else {
380*4bdc9457SAndroid Build Coastguard Worker ukernel_type = xnn_ukernel_type_igemm;
381*4bdc9457SAndroid Build Coastguard Worker }
382*4bdc9457SAndroid Build Coastguard Worker assert(ukernel_type != xnn_ukernel_type_default);
383*4bdc9457SAndroid Build Coastguard Worker
384*4bdc9457SAndroid Build Coastguard Worker if (num_post_operations != 0 && ukernel_type != xnn_ukernel_type_gemm) {
385*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
386*4bdc9457SAndroid Build Coastguard Worker "convolution with post operations not support for these parameters: "
387*4bdc9457SAndroid Build Coastguard Worker "kernel_size: %zu unit_subsampling: %d padding: %d",
388*4bdc9457SAndroid Build Coastguard Worker kernel_size, unit_subsampling, any_padding);
389*4bdc9457SAndroid Build Coastguard Worker goto error;
390*4bdc9457SAndroid Build Coastguard Worker }
391*4bdc9457SAndroid Build Coastguard Worker
392*4bdc9457SAndroid Build Coastguard Worker size_t zero_size = 0;
393*4bdc9457SAndroid Build Coastguard Worker switch (ukernel_type) {
394*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_vmulcaddc:
395*4bdc9457SAndroid Build Coastguard Worker {
396*4bdc9457SAndroid Build Coastguard Worker assert(vmulcaddc_parameters != NULL);
397*4bdc9457SAndroid Build Coastguard Worker assert(vmulcaddc_params != NULL);
398*4bdc9457SAndroid Build Coastguard Worker
399*4bdc9457SAndroid Build Coastguard Worker const size_t c_stride = round_up_po2(groups, vmulcaddc_parameters->channel_tile);
400*4bdc9457SAndroid Build Coastguard Worker const size_t packed_weights_size = ((UINT32_C(1) << log2_filter_element_size) + bias_element_size) * c_stride;
401*4bdc9457SAndroid Build Coastguard Worker size_t aligned_total_weights_size = round_up_po2(packed_weights_size, XNN_ALLOCATION_ALIGNMENT);
402*4bdc9457SAndroid Build Coastguard Worker void* weights_ptr = xnn_get_pointer_to_write_weights(
403*4bdc9457SAndroid Build Coastguard Worker convolution_op, aligned_total_weights_size, packed_weights_padding_byte);
404*4bdc9457SAndroid Build Coastguard Worker if (weights_ptr == NULL) {
405*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to reserve or allocated %zu bytes for %s operator vmulcaddc packed weights",
406*4bdc9457SAndroid Build Coastguard Worker aligned_total_weights_size, xnn_operator_type_to_string(operator_type));
407*4bdc9457SAndroid Build Coastguard Worker goto error;
408*4bdc9457SAndroid Build Coastguard Worker }
409*4bdc9457SAndroid Build Coastguard Worker
410*4bdc9457SAndroid Build Coastguard Worker pack_vmulcaddc_w(
411*4bdc9457SAndroid Build Coastguard Worker groups, vmulcaddc_parameters->channel_tile,
412*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr, packing_params);
413*4bdc9457SAndroid Build Coastguard Worker
414*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache(convolution_op)) {
415*4bdc9457SAndroid Build Coastguard Worker convolution_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
416*4bdc9457SAndroid Build Coastguard Worker convolution_op->weights_cache, weights_ptr, aligned_total_weights_size);
417*4bdc9457SAndroid Build Coastguard Worker }
418*4bdc9457SAndroid Build Coastguard Worker
419*4bdc9457SAndroid Build Coastguard Worker memcpy(&convolution_op->params, vmulcaddc_params, vmulcaddc_params_size);
420*4bdc9457SAndroid Build Coastguard Worker
421*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.vmulcaddc = (struct xnn_ukernel_vmulcaddc) {
422*4bdc9457SAndroid Build Coastguard Worker .function = vmulcaddc_parameters->ukernel,
423*4bdc9457SAndroid Build Coastguard Worker .mr = vmulcaddc_parameters->row_tile,
424*4bdc9457SAndroid Build Coastguard Worker };
425*4bdc9457SAndroid Build Coastguard Worker break;
426*4bdc9457SAndroid Build Coastguard Worker }
427*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_dwconv:
428*4bdc9457SAndroid Build Coastguard Worker {
429*4bdc9457SAndroid Build Coastguard Worker assert(dwconv_ukernel != NULL);
430*4bdc9457SAndroid Build Coastguard Worker assert(dwconv_ukernel->primary_tile == kernel_size);
431*4bdc9457SAndroid Build Coastguard Worker
432*4bdc9457SAndroid Build Coastguard Worker const size_t c_stride = round_up_po2(groups, dwconv_ukernel->channel_tile);
433*4bdc9457SAndroid Build Coastguard Worker const size_t packed_weights_size = ((kernel_size << log2_filter_element_size) + bias_element_size + extra_weights_bytes) * c_stride;
434*4bdc9457SAndroid Build Coastguard Worker size_t aligned_total_weights_size = round_up_po2(packed_weights_size, XNN_ALLOCATION_ALIGNMENT);
435*4bdc9457SAndroid Build Coastguard Worker void* weights_ptr = xnn_get_pointer_to_write_weights(
436*4bdc9457SAndroid Build Coastguard Worker convolution_op, aligned_total_weights_size, packed_weights_padding_byte);
437*4bdc9457SAndroid Build Coastguard Worker if (weights_ptr == NULL) {
438*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to reserve or allocated %zu bytes for %s operator dwconv packed weights",
439*4bdc9457SAndroid Build Coastguard Worker aligned_total_weights_size, xnn_operator_type_to_string(operator_type));
440*4bdc9457SAndroid Build Coastguard Worker goto error;
441*4bdc9457SAndroid Build Coastguard Worker }
442*4bdc9457SAndroid Build Coastguard Worker memcpy(&convolution_op->params, dwconv_params, dwconv_params_size);
443*4bdc9457SAndroid Build Coastguard Worker
444*4bdc9457SAndroid Build Coastguard Worker if (flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) {
445*4bdc9457SAndroid Build Coastguard Worker pack_dwconv_hwg_w(
446*4bdc9457SAndroid Build Coastguard Worker dwconv_ukernel->primary_tile,
447*4bdc9457SAndroid Build Coastguard Worker kernel_height, kernel_width,
448*4bdc9457SAndroid Build Coastguard Worker groups, dwconv_ukernel->channel_tile,
449*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr,
450*4bdc9457SAndroid Build Coastguard Worker dwconv_ukernel->channel_tile * extra_weights_bytes,
451*4bdc9457SAndroid Build Coastguard Worker packing_params);
452*4bdc9457SAndroid Build Coastguard Worker } else {
453*4bdc9457SAndroid Build Coastguard Worker pack_dwconv_ghw_w(
454*4bdc9457SAndroid Build Coastguard Worker dwconv_ukernel->primary_tile,
455*4bdc9457SAndroid Build Coastguard Worker kernel_height, kernel_width,
456*4bdc9457SAndroid Build Coastguard Worker groups, dwconv_ukernel->channel_tile,
457*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr,
458*4bdc9457SAndroid Build Coastguard Worker dwconv_ukernel->channel_tile * extra_weights_bytes,
459*4bdc9457SAndroid Build Coastguard Worker packing_params);
460*4bdc9457SAndroid Build Coastguard Worker }
461*4bdc9457SAndroid Build Coastguard Worker
462*4bdc9457SAndroid Build Coastguard Worker if (scale_params != NULL) {
463*4bdc9457SAndroid Build Coastguard Worker assert(init_scale_params != NULL);
464*4bdc9457SAndroid Build Coastguard Worker
465*4bdc9457SAndroid Build Coastguard Worker init_scale_params(
466*4bdc9457SAndroid Build Coastguard Worker groups, dwconv_ukernel->channel_tile,
467*4bdc9457SAndroid Build Coastguard Worker dwconv_ukernel->channel_tile * ((kernel_size << log2_filter_element_size) + bias_element_size + extra_weights_bytes),
468*4bdc9457SAndroid Build Coastguard Worker scale_params,
469*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) weights_ptr + dwconv_ukernel->channel_tile * ((kernel_size << log2_filter_element_size) + bias_element_size)));
470*4bdc9457SAndroid Build Coastguard Worker }
471*4bdc9457SAndroid Build Coastguard Worker
472*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache(convolution_op)) {
473*4bdc9457SAndroid Build Coastguard Worker convolution_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
474*4bdc9457SAndroid Build Coastguard Worker convolution_op->weights_cache, weights_ptr, aligned_total_weights_size);
475*4bdc9457SAndroid Build Coastguard Worker }
476*4bdc9457SAndroid Build Coastguard Worker
477*4bdc9457SAndroid Build Coastguard Worker const union dwconv_fused_ukernels* ukernels = &dwconv_ukernel->minmax;
478*4bdc9457SAndroid Build Coastguard Worker if (linear_activation && dwconv_ukernel->linear.unipass != NULL) {
479*4bdc9457SAndroid Build Coastguard Worker ukernels = &dwconv_ukernel->linear;
480*4bdc9457SAndroid Build Coastguard Worker }
481*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.dwconv = (struct xnn_ukernel_dwconv) {
482*4bdc9457SAndroid Build Coastguard Worker .unipass_function = ukernels->unipass,
483*4bdc9457SAndroid Build Coastguard Worker .primary_tile = dwconv_ukernel->primary_tile,
484*4bdc9457SAndroid Build Coastguard Worker .incremental_tile = dwconv_ukernel->incremental_tile,
485*4bdc9457SAndroid Build Coastguard Worker };
486*4bdc9457SAndroid Build Coastguard Worker
487*4bdc9457SAndroid Build Coastguard Worker zero_size = XNN_EXTRA_BYTES + (c_stride << log2_input_element_size);
488*4bdc9457SAndroid Build Coastguard Worker break;
489*4bdc9457SAndroid Build Coastguard Worker }
490*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_gemm:
491*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_igemm:
492*4bdc9457SAndroid Build Coastguard Worker {
493*4bdc9457SAndroid Build Coastguard Worker const uint32_t nr = gemm_parameters->nr;
494*4bdc9457SAndroid Build Coastguard Worker const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
495*4bdc9457SAndroid Build Coastguard Worker const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;
496*4bdc9457SAndroid Build Coastguard Worker const size_t n_stride = round_up(group_output_channels, nr);
497*4bdc9457SAndroid Build Coastguard Worker const size_t k_stride = round_up_po2(group_input_channels, kr * sr);
498*4bdc9457SAndroid Build Coastguard Worker
499*4bdc9457SAndroid Build Coastguard Worker const size_t packed_group_weights_size = ((kernel_size * k_stride << log2_filter_element_size) + bias_element_size + extra_weights_bytes) * n_stride;
500*4bdc9457SAndroid Build Coastguard Worker const size_t aligned_total_weights_size = round_up_po2(packed_group_weights_size * groups, XNN_ALLOCATION_ALIGNMENT);
501*4bdc9457SAndroid Build Coastguard Worker void* weights_ptr = xnn_get_pointer_to_write_weights(
502*4bdc9457SAndroid Build Coastguard Worker convolution_op, aligned_total_weights_size, packed_weights_padding_byte);
503*4bdc9457SAndroid Build Coastguard Worker if (weights_ptr == NULL) {
504*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to reserve or allocated %zu bytes for %s operator gemm packed weights",
505*4bdc9457SAndroid Build Coastguard Worker aligned_total_weights_size, xnn_operator_type_to_string(operator_type));
506*4bdc9457SAndroid Build Coastguard Worker goto error;
507*4bdc9457SAndroid Build Coastguard Worker }
508*4bdc9457SAndroid Build Coastguard Worker memcpy(&convolution_op->params, gemm_params, gemm_params_size);
509*4bdc9457SAndroid Build Coastguard Worker convolution_op->num_post_operation_params = num_post_operations;
510*4bdc9457SAndroid Build Coastguard Worker convolution_op->post_operation_params = post_operation_params;
511*4bdc9457SAndroid Build Coastguard Worker
512*4bdc9457SAndroid Build Coastguard Worker const struct gemm_fused_ukernels* gemm_ukernels = &gemm_parameters->minmax;
513*4bdc9457SAndroid Build Coastguard Worker const uint32_t mr = gemm_parameters->mr;
514*4bdc9457SAndroid Build Coastguard Worker if (linear_activation && gemm_parameters->linear.gemm[mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
515*4bdc9457SAndroid Build Coastguard Worker gemm_ukernels = &gemm_parameters->linear;
516*4bdc9457SAndroid Build Coastguard Worker } else if (relu_activation && gemm_parameters->relu.gemm[mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
517*4bdc9457SAndroid Build Coastguard Worker gemm_ukernels = &gemm_parameters->relu;
518*4bdc9457SAndroid Build Coastguard Worker }
519*4bdc9457SAndroid Build Coastguard Worker switch (ukernel_type) {
520*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_gemm:
521*4bdc9457SAndroid Build Coastguard Worker pack_gemm_goi_w(
522*4bdc9457SAndroid Build Coastguard Worker groups, group_output_channels, group_input_channels,
523*4bdc9457SAndroid Build Coastguard Worker nr, kr, sr,
524*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr, gemm_parameters->nr * extra_weights_bytes, packing_params);
525*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.gemm = (struct xnn_ukernel_gemm) {
526*4bdc9457SAndroid Build Coastguard Worker .mr = mr,
527*4bdc9457SAndroid Build Coastguard Worker .nr = nr,
528*4bdc9457SAndroid Build Coastguard Worker .kr = kr,
529*4bdc9457SAndroid Build Coastguard Worker .sr = sr,
530*4bdc9457SAndroid Build Coastguard Worker };
531*4bdc9457SAndroid Build Coastguard Worker
532*4bdc9457SAndroid Build Coastguard Worker assert(XNN_MAX_MR >= mr);
533*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < mr; i++) {
534*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.gemm.gemm_cases[i] = gemm_ukernels->gemm[i];
535*4bdc9457SAndroid Build Coastguard Worker }
536*4bdc9457SAndroid Build Coastguard Worker
537*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT
538*4bdc9457SAndroid Build Coastguard Worker generate_gemms_up_to_max_mr(
539*4bdc9457SAndroid Build Coastguard Worker mr, gemm_parameters->generator, jit_gemm_params, group_output_channels, nr,
540*4bdc9457SAndroid Build Coastguard Worker group_input_channels, log2_input_element_size, convolution_op);
541*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_PLATFORM_JIT
542*4bdc9457SAndroid Build Coastguard Worker
543*4bdc9457SAndroid Build Coastguard Worker break;
544*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_igemm:
545*4bdc9457SAndroid Build Coastguard Worker if (flags & XNN_FLAG_DEPTHWISE_CONVOLUTION) {
546*4bdc9457SAndroid Build Coastguard Worker pack_conv_kgo_w(
547*4bdc9457SAndroid Build Coastguard Worker groups, group_output_channels, kernel_size,
548*4bdc9457SAndroid Build Coastguard Worker nr, kr, sr,
549*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr, gemm_parameters->nr * extra_weights_bytes, packing_params);
550*4bdc9457SAndroid Build Coastguard Worker } else {
551*4bdc9457SAndroid Build Coastguard Worker pack_conv_goki_w(
552*4bdc9457SAndroid Build Coastguard Worker groups, group_output_channels, kernel_size, group_input_channels,
553*4bdc9457SAndroid Build Coastguard Worker nr, kr, sr,
554*4bdc9457SAndroid Build Coastguard Worker kernel, bias, weights_ptr, gemm_parameters->nr * extra_weights_bytes, packing_params);
555*4bdc9457SAndroid Build Coastguard Worker }
556*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
557*4bdc9457SAndroid Build Coastguard Worker .mr = mr,
558*4bdc9457SAndroid Build Coastguard Worker .nr = nr,
559*4bdc9457SAndroid Build Coastguard Worker .kr = kr,
560*4bdc9457SAndroid Build Coastguard Worker .sr = sr,
561*4bdc9457SAndroid Build Coastguard Worker };
562*4bdc9457SAndroid Build Coastguard Worker
563*4bdc9457SAndroid Build Coastguard Worker assert(XNN_MAX_MR >= mr);
564*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < mr; i++) {
565*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.igemm.igemm_cases[i] = gemm_ukernels->igemm[i];
566*4bdc9457SAndroid Build Coastguard Worker }
567*4bdc9457SAndroid Build Coastguard Worker
568*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT
569*4bdc9457SAndroid Build Coastguard Worker generate_igemms_up_to_max_mr(
570*4bdc9457SAndroid Build Coastguard Worker mr, gemm_parameters->generator, jit_gemm_params, group_output_channels, nr,
571*4bdc9457SAndroid Build Coastguard Worker group_input_channels, log2_input_element_size, kernel_size, convolution_op);
572*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_PLATFORM_JIT
573*4bdc9457SAndroid Build Coastguard Worker
574*4bdc9457SAndroid Build Coastguard Worker break;
575*4bdc9457SAndroid Build Coastguard Worker default:
576*4bdc9457SAndroid Build Coastguard Worker XNN_UNREACHABLE;
577*4bdc9457SAndroid Build Coastguard Worker }
578*4bdc9457SAndroid Build Coastguard Worker
579*4bdc9457SAndroid Build Coastguard Worker if (scale_params != NULL) {
580*4bdc9457SAndroid Build Coastguard Worker assert(init_scale_params != NULL);
581*4bdc9457SAndroid Build Coastguard Worker
582*4bdc9457SAndroid Build Coastguard Worker void* group_weights = (void*)
583*4bdc9457SAndroid Build Coastguard Worker ((uintptr_t) weights_ptr + gemm_parameters->nr * ((kernel_size * k_stride << log2_filter_element_size) + bias_element_size));
584*4bdc9457SAndroid Build Coastguard Worker const size_t weights_stride = (kernel_size * k_stride << log2_filter_element_size) + bias_element_size + extra_weights_bytes;
585*4bdc9457SAndroid Build Coastguard Worker for (uint32_t group = 0; group < groups; group++) {
586*4bdc9457SAndroid Build Coastguard Worker init_scale_params(
587*4bdc9457SAndroid Build Coastguard Worker group_output_channels, gemm_parameters->nr,
588*4bdc9457SAndroid Build Coastguard Worker gemm_parameters->nr * weights_stride,
589*4bdc9457SAndroid Build Coastguard Worker scale_params, group_weights);
590*4bdc9457SAndroid Build Coastguard Worker scale_params += group_output_channels;
591*4bdc9457SAndroid Build Coastguard Worker group_weights = (void*) ((uintptr_t) group_weights + n_stride * weights_stride);
592*4bdc9457SAndroid Build Coastguard Worker }
593*4bdc9457SAndroid Build Coastguard Worker }
594*4bdc9457SAndroid Build Coastguard Worker
595*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache(convolution_op)) {
596*4bdc9457SAndroid Build Coastguard Worker convolution_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
597*4bdc9457SAndroid Build Coastguard Worker convolution_op->weights_cache, weights_ptr, aligned_total_weights_size);
598*4bdc9457SAndroid Build Coastguard Worker }
599*4bdc9457SAndroid Build Coastguard Worker
600*4bdc9457SAndroid Build Coastguard Worker zero_size = XNN_EXTRA_BYTES + (k_stride << log2_input_element_size);
601*4bdc9457SAndroid Build Coastguard Worker break;
602*4bdc9457SAndroid Build Coastguard Worker }
603*4bdc9457SAndroid Build Coastguard Worker default:
604*4bdc9457SAndroid Build Coastguard Worker XNN_UNREACHABLE;
605*4bdc9457SAndroid Build Coastguard Worker }
606*4bdc9457SAndroid Build Coastguard Worker
607*4bdc9457SAndroid Build Coastguard Worker const bool tf_same_padding = (flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0 && kernel_size != 1;
608*4bdc9457SAndroid Build Coastguard Worker if (any_padding || tf_same_padding) {
609*4bdc9457SAndroid Build Coastguard Worker convolution_op->zero_buffer = xnn_allocate_simd_memory(zero_size);
610*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->zero_buffer == NULL) {
611*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
612*4bdc9457SAndroid Build Coastguard Worker "failed to allocate %zu bytes for %s operator zero padding",
613*4bdc9457SAndroid Build Coastguard Worker zero_size, xnn_operator_type_to_string(operator_type));
614*4bdc9457SAndroid Build Coastguard Worker goto error;
615*4bdc9457SAndroid Build Coastguard Worker }
616*4bdc9457SAndroid Build Coastguard Worker memset(convolution_op->zero_buffer, input_padding_byte, zero_size);
617*4bdc9457SAndroid Build Coastguard Worker }
618*4bdc9457SAndroid Build Coastguard Worker
619*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_top = input_padding_top;
620*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_right = input_padding_right;
621*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_bottom = input_padding_bottom;
622*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_left = input_padding_left;
623*4bdc9457SAndroid Build Coastguard Worker
624*4bdc9457SAndroid Build Coastguard Worker convolution_op->kernel_height = kernel_height;
625*4bdc9457SAndroid Build Coastguard Worker convolution_op->kernel_width = kernel_width;
626*4bdc9457SAndroid Build Coastguard Worker convolution_op->stride_height = subsampling_height;
627*4bdc9457SAndroid Build Coastguard Worker convolution_op->stride_width = subsampling_width;
628*4bdc9457SAndroid Build Coastguard Worker convolution_op->dilation_height = dilation_height;
629*4bdc9457SAndroid Build Coastguard Worker convolution_op->dilation_width = dilation_width;
630*4bdc9457SAndroid Build Coastguard Worker convolution_op->groups = groups;
631*4bdc9457SAndroid Build Coastguard Worker convolution_op->group_input_channels = group_input_channels;
632*4bdc9457SAndroid Build Coastguard Worker convolution_op->group_output_channels = group_output_channels;
633*4bdc9457SAndroid Build Coastguard Worker convolution_op->input_pixel_stride = input_channel_stride;
634*4bdc9457SAndroid Build Coastguard Worker convolution_op->output_pixel_stride = output_channel_stride;
635*4bdc9457SAndroid Build Coastguard Worker
636*4bdc9457SAndroid Build Coastguard Worker convolution_op->type = operator_type;
637*4bdc9457SAndroid Build Coastguard Worker convolution_op->ukernel.type = ukernel_type;
638*4bdc9457SAndroid Build Coastguard Worker convolution_op->flags = flags & ~XNN_FLAG_TENSORFLOW_SAME_PADDING;
639*4bdc9457SAndroid Build Coastguard Worker if (tf_same_padding) {
640*4bdc9457SAndroid Build Coastguard Worker convolution_op->flags |= XNN_FLAG_TENSORFLOW_SAME_PADDING;
641*4bdc9457SAndroid Build Coastguard Worker }
642*4bdc9457SAndroid Build Coastguard Worker
643*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_invalid;
644*4bdc9457SAndroid Build Coastguard Worker
645*4bdc9457SAndroid Build Coastguard Worker *convolution_op_out = convolution_op;
646*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
647*4bdc9457SAndroid Build Coastguard Worker
648*4bdc9457SAndroid Build Coastguard Worker error:
649*4bdc9457SAndroid Build Coastguard Worker xnn_delete_operator(convolution_op);
650*4bdc9457SAndroid Build Coastguard Worker return status;
651*4bdc9457SAndroid Build Coastguard Worker }
652*4bdc9457SAndroid Build Coastguard Worker
// Creates a 2D convolution operator for NHWC tensors with unsigned 8-bit
// quantized (QU8) inputs, weights, and outputs.
//
// The function validates the quantization parameters, prepares the QU8 packing
// and microkernel parameters, and then delegates operator construction to
// create_convolution2d_nhwc.
//
// Parameters:
//   input_padding_*, kernel_*, subsampling_*, dilation_*, groups,
//   group_input_channels, group_output_channels, input_channel_stride,
//   output_channel_stride, flags, caches, convolution_op_out:
//       forwarded unchanged to create_convolution2d_nhwc.
//   input_zero_point:  stored in the packing parameters and used as the input
//                      padding byte.
//   kernel_zero_point: stored in the packing parameters, passed to the
//                      microkernel parameter initializers, and used as the
//                      packed weights padding byte.
//   input_scale, kernel_scale, output_scale:
//       each must be positive and normal; together they define the
//       requantization scale input_scale * kernel_scale / output_scale.
//   kernel, bias:      forwarded to create_convolution2d_nhwc together with the
//                      QU8 packing functions.
//   output_zero_point, output_min, output_max:
//       passed to the microkernel parameter initializers; output_min must be
//       strictly below output_max.
//
// Returns:
//   xnn_status_invalid_parameter     if a scale is not positive and normal, or
//                                    if output_min >= output_max;
//   xnn_status_unsupported_parameter if the requantization scale is >= 256.0;
//   otherwise the status returned by create_convolution2d_nhwc.
enum xnn_status xnn_create_convolution2d_nhwc_qu8(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_channel_stride,
    size_t output_channel_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* convolution_op_out)
{
  // isnormal() is false for zero, subnormal, infinite, and NaN values, so each
  // check below accepts only finite, normalized, strictly positive scales.
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  // An empty or single-value output range is rejected.
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Combined scale applied when converting accumulators to the output
  // quantization; values of 256 or more are reported as unsupported.
  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qu8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  // Zero points handed to the weight packing functions.
  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };

  // GEMM/IGEMM microkernel parameters. They stay uninitialized when no QU8
  // init function is registered.
  // NOTE(review): presumably create_convolution2d_nhwc rejects that
  // configuration through XNN_INIT_FLAG_QU8 before reading them - confirm.
  union xnn_qu8_conv_minmax_params gemm_params;
  if XNN_LIKELY(xnn_params.qu8.gemm.init.qu8 != NULL) {
    xnn_params.qu8.gemm.init.qu8(&gemm_params,
      kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
  }

  // Depthwise microkernel parameters, initialized only when a QU8 depthwise
  // microkernel is found for this kernel size (kernel_height * kernel_width).
  union xnn_qu8_conv_minmax_params dwconv_params;
  const struct dwconv_parameters* dwconv_ukernel =
    find_dwconv_ukernel(kernel_height * kernel_width, xnn_params.qu8.dwconv, XNN_MAX_QU8_DWCONV_UKERNELS);
  if XNN_LIKELY(dwconv_ukernel != NULL) {
    dwconv_ukernel->init.qu8(&dwconv_params,
      kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
  }

  // QU8 has no vmulcaddc path, no per-channel scale parameters, no JIT GEMM
  // parameters, and no post operations, hence the NULL/0 arguments below.
  return create_convolution2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    kernel_height, kernel_width,
    subsampling_height, subsampling_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_channel_stride, output_channel_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_vmulcaddc_w_function) NULL,
    (xnn_pack_dwconv_hwg_w_function) xnn_pack_qu8_dwconv_hwg_w,
    (xnn_pack_dwconv_ghw_w_function) xnn_pack_qu8_dwconv_ghw_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qu8_gemm_goi_w,
    (xnn_pack_conv_kgo_w_function) xnn_pack_qu8_conv_kgo_w,
    (xnn_pack_conv_goki_w_function) xnn_pack_qu8_conv_goki_w,
    &packing_params, input_zero_point /* input padding byte */, kernel_zero_point /* packed weights padding byte */,
    0 /* extra weights bytes */, NULL /* init scale params fn */, NULL /* scale params */,
    &gemm_params, sizeof(gemm_params),
    &dwconv_params, sizeof(dwconv_params),
    NULL /* vmulcaddc params */, 0,
    &xnn_params.qu8.gemm, dwconv_ukernel, NULL /* vmulcaddc parameters */,
    NULL /* jit_gemm_params */,
    false /* linear activation */, false /* relu activation */, XNN_INIT_FLAG_QU8,
    xnn_operator_type_convolution_nhwc_qu8,
    0, NULL,
    caches,
    convolution_op_out);
}
771*4bdc9457SAndroid Build Coastguard Worker
// Creates a convolution_nhwc_qs8 operator: int8_t input, int8_t kernel,
// int32_t bias and a single kernel scale shared by all output channels.
//
// The quantization arguments are validated here; the remaining arguments are
// passed through unchanged to create_convolution2d_nhwc(), which performs the
// actual construction and writes the result through convolution_op_out.
//
// Returns:
//   xnn_status_invalid_parameter     - input_scale, kernel_scale or output_scale is not a
//                                      positive normal float, or output_min >= output_max.
//   xnn_status_unsupported_parameter - input_scale * kernel_scale / output_scale >= 256.
//   Otherwise, whatever create_convolution2d_nhwc() returns.
enum xnn_status xnn_create_convolution2d_nhwc_qs8(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_channel_stride,
    size_t output_channel_stride,
    int8_t input_zero_point,
    float input_scale,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* convolution_op_out)
{
  // Each scale has to be a strictly positive, normal floating-point number
  // (this rejects zero, negatives, subnormals, infinities and NaN).
  if (!(input_scale > 0.0f && isnormal(input_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qs8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (!(kernel_scale > 0.0f && isnormal(kernel_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qs8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (!(output_scale > 0.0f && isnormal(output_scale))) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  // The clamping range must contain more than one value.
  if (!(output_min < output_max)) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Combined scale applied when converting accumulators to the output type.
  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (256.0f <= requantization_scale) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qs8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  const struct xnn_qs8_packing_params weights_packing = { .input_zero_point = input_zero_point, };

  // GEMM micro-kernel parameters; only filled in when an init routine is registered.
  union xnn_qs8_conv_minmax_params gemm_minmax;
  if XNN_LIKELY(xnn_params.qs8.gemm.init.qs8 != NULL) {
    xnn_params.qs8.gemm.init.qs8(
      &gemm_minmax, requantization_scale, output_zero_point, output_min, output_max);
  }

  // Depthwise micro-kernel parameters; only filled in when a micro-kernel was
  // found for this kernel_height * kernel_width.
  union xnn_qs8_conv_minmax_params dwconv_minmax;
  const struct dwconv_parameters* dwconv_ukernel =
    find_dwconv_ukernel(kernel_height * kernel_width, xnn_params.qs8.dwconv, XNN_MAX_QS8_DWCONV_UKERNELS);
  if XNN_LIKELY(dwconv_ukernel != NULL) {
    dwconv_ukernel->init.qs8(
      &dwconv_minmax, requantization_scale, output_zero_point, output_min, output_max);
  }

  // Weight packing routines for this data type.
  const xnn_pack_dwconv_hwg_w_function pack_dwconv_hwg_w =
    (xnn_pack_dwconv_hwg_w_function) xnn_pack_qs8_dwconv_hwg_w;
  const xnn_pack_dwconv_ghw_w_function pack_dwconv_ghw_w =
    (xnn_pack_dwconv_ghw_w_function) xnn_pack_qs8_dwconv_ghw_w;
  const xnn_pack_gemm_goi_w_function pack_gemm_goi_w =
    (xnn_pack_gemm_goi_w_function) xnn_pack_qs8_gemm_goi_w;
  const xnn_pack_conv_kgo_w_function pack_conv_kgo_w =
    (xnn_pack_conv_kgo_w_function) xnn_pack_qs8_conv_kgo_w;
  const xnn_pack_conv_goki_w_function pack_conv_goki_w =
    (xnn_pack_conv_goki_w_function) xnn_pack_qs8_conv_goki_w;

  return create_convolution2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    kernel_height, kernel_width,
    subsampling_height, subsampling_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_channel_stride, output_channel_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_vmulcaddc_w_function) NULL,
    pack_dwconv_hwg_w,
    pack_dwconv_ghw_w,
    pack_gemm_goi_w,
    pack_conv_kgo_w,
    pack_conv_goki_w,
    &weights_packing,
    input_zero_point /* input padding byte */,
    0 /* packed weights padding byte */,
    0 /* extra weights bytes */,
    NULL /* init scale params fn */,
    NULL /* scale params */,
    &gemm_minmax, sizeof(gemm_minmax),
    &dwconv_minmax, sizeof(dwconv_minmax),
    NULL /* vmulcaddc params */, 0,
    &xnn_params.qs8.gemm, dwconv_ukernel, NULL /* vmulcaddc parameters */,
    NULL /* jit_gemm_params */,
    false /* linear activation */, false /* relu activation */, XNN_INIT_FLAG_QS8,
    xnn_operator_type_convolution_nhwc_qs8,
    0, NULL,
    caches,
    convolution_op_out);
}
885*4bdc9457SAndroid Build Coastguard Worker
// Creates a convolution_nhwc_qc8 operator: int8_t input, int8_t kernel,
// int32_t bias and one kernel scale per output channel.
//
// kernel_scale is read at indices [0, groups * group_output_channels).
// The quantization arguments are validated here; the remaining arguments are
// passed through unchanged to create_convolution2d_nhwc(), which performs the
// actual construction and writes the result through convolution_op_out.
//
// Returns:
//   xnn_status_invalid_parameter     - input_scale, output_scale or any kernel_scale[i] is not
//                                      a positive normal float, or output_min >= output_max.
//   xnn_status_out_of_memory         - the temporary per-channel requantization scale array
//                                      could not be allocated.
//   xnn_status_unsupported_parameter - input_scale * kernel_scale[i] / output_scale >= 256
//                                      for some output channel i.
//   Otherwise, whatever create_convolution2d_nhwc() returns.
enum xnn_status xnn_create_convolution2d_nhwc_qc8(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_channel_stride,
    size_t output_channel_stride,
    int8_t input_zero_point,
    float input_scale,
    const float* kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* convolution_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qc8), input_scale);
    return xnn_status_invalid_parameter;
  }

  // Total number of output channels across all groups; one kernel scale each.
  const size_t num_output_channels = groups * group_output_channels;

  for (size_t output_channel = 0; output_channel < num_output_channels; output_channel++) {
    if (kernel_scale[output_channel] <= 0.0f || !isnormal(kernel_scale[output_channel])) {
      xnn_log_error(
        "failed to create %s operator with %.7g kernel scale in output channel #%zu: "
        "scale must be finite, normalized, and positive",
        xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qc8), kernel_scale[output_channel],
        output_channel);
      return xnn_status_invalid_parameter;
    }
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qc8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qc8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // The scratch array size is controlled by the caller, so it is taken from
  // the heap rather than the stack: an oversized request then fails with an
  // error status instead of overflowing the stack. At least one element is
  // requested so that a zero channel count still reaches
  // create_convolution2d_nhwc() with a valid pointer.
  const size_t requantization_scale_size =
    (num_output_channels != 0 ? num_output_channels : 1) * sizeof(float);
  float* requantization_scale = xnn_allocate_simd_memory(requantization_scale_size);
  if (requantization_scale == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator requantization scales",
      requantization_scale_size,
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qc8));
    return xnn_status_out_of_memory;
  }

  for (size_t output_channel = 0; output_channel < num_output_channels; output_channel++) {
    requantization_scale[output_channel] = input_scale * kernel_scale[output_channel] / output_scale;
    if (requantization_scale[output_channel] >= 256.0f) {
      xnn_log_error(
        "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale in output channel #%zu: "
        "requantization scale %.7g is greater or equal to 256.0",
        xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_qc8),
        input_scale, kernel_scale[output_channel], output_scale,
        output_channel, requantization_scale[output_channel]);
      xnn_release_simd_memory(requantization_scale);
      return xnn_status_unsupported_parameter;
    }
  }

  const struct xnn_qs8_packing_params packing_params = { .input_zero_point = input_zero_point, };

  // GEMM micro-kernel parameters; only filled in when an init routine is registered.
  union xnn_qc8_conv_minmax_params gemm_params;
  if XNN_LIKELY(xnn_params.qc8.gemm.init.qc8 != NULL) {
    xnn_params.qc8.gemm.init.qc8(&gemm_params,
      output_zero_point, output_min, output_max);
  }

  // Depthwise micro-kernel parameters; only filled in when a micro-kernel was
  // found for this kernel_height * kernel_width.
  union xnn_qc8_conv_minmax_params dwconv_params;
  const struct dwconv_parameters* dwconv_ukernel =
    find_dwconv_ukernel(kernel_height * kernel_width, xnn_params.qc8.dwconv, XNN_MAX_QC8_DWCONV_UKERNELS);
  if XNN_LIKELY(dwconv_ukernel != NULL) {
    dwconv_ukernel->init.qc8(&dwconv_params,
      output_zero_point, output_min, output_max);
  }

  const enum xnn_status status = create_convolution2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    kernel_height, kernel_width,
    subsampling_height, subsampling_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_channel_stride, output_channel_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_vmulcaddc_w_function) NULL,
    (xnn_pack_dwconv_hwg_w_function) xnn_pack_qs8_dwconv_hwg_w,
    (xnn_pack_dwconv_ghw_w_function) xnn_pack_qs8_dwconv_ghw_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_qs8_gemm_goi_w,
    (xnn_pack_conv_kgo_w_function) xnn_pack_qs8_conv_kgo_w,
    (xnn_pack_conv_goki_w_function) xnn_pack_qs8_conv_goki_w,
    &packing_params, input_zero_point /* input padding byte */, 0 /* packed weights padding byte */,
    sizeof(float) /* extra weights bytes */, xnn_init_qc8_scale_fp32_params, requantization_scale,
    &gemm_params, sizeof(gemm_params),
    &dwconv_params, sizeof(dwconv_params),
    NULL /* vmulcaddc params */, 0,
    &xnn_params.qc8.gemm, dwconv_ukernel, NULL /* vmulcaddc parameters */,
    NULL /* jit_gemm_params */,
    false /* linear activation */, false /* relu activation */, XNN_INIT_FLAG_QC8,
    xnn_operator_type_convolution_nhwc_qc8,
    0, NULL,
    caches,
    convolution_op_out);

  // The scales were only ever valid for the duration of the call above (they
  // previously lived on this function's stack), so they can be released now.
  xnn_release_simd_memory(requantization_scale);
  return status;
}
1007*4bdc9457SAndroid Build Coastguard Worker
// Creates a convolution_nhwc_f16 operator.
//
// output_min / output_max are supplied as float, converted to IEEE
// half-precision bit patterns, and validated after that rounding. When
// XNN_FLAG_FP32_STATIC_WEIGHTS is set in flags, the f32-to-f16 packing
// routines are used in place of the f16 ones. All remaining arguments are
// passed through unchanged to create_convolution2d_nhwc(), which performs the
// actual construction and writes the result through convolution_op_out.
//
// Returns:
//   xnn_status_invalid_parameter - output_min or output_max is NaN, or the rounded lower
//                                  bound is not below the rounded upper bound.
//   Otherwise, whatever create_convolution2d_nhwc() returns.
enum xnn_status xnn_create_convolution2d_nhwc_f16(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_channel_stride,
    size_t output_channel_stride,
    const void* kernel,
    const void* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* convolution_op_out)
{
  // NaN bounds are rejected first, lower bound before upper bound.
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

  // Convert each bound to half precision and back, so the range check sees
  // the values after rounding rather than the float values that were passed in.
  const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min);
  const float rounded_output_min = fp16_ieee_to_fp32_value(output_min_as_half);
  const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max);
  const float rounded_output_max = fp16_ieee_to_fp32_value(output_max_as_half);
  if (rounded_output_max <= rounded_output_min) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f16), rounded_output_min, rounded_output_max);
    return xnn_status_invalid_parameter;
  }

  // GEMM micro-kernel parameters; only filled in when an init routine is registered.
  union xnn_f16_minmax_params gemm_minmax;
  if XNN_LIKELY(xnn_params.f16.gemm.init.f16 != NULL) {
    xnn_params.f16.gemm.init.f16(&gemm_minmax, output_min_as_half, output_max_as_half);
  }

  // Depthwise micro-kernel parameters; only filled in when a micro-kernel was
  // found for this kernel_height * kernel_width.
  union xnn_f16_minmax_params dwconv_minmax;
  const struct dwconv_parameters* dwconv_ukernel =
    find_dwconv_ukernel(kernel_height * kernel_width, xnn_params.f16.dwconv, XNN_MAX_F16_DWCONV_UKERNELS);
  if XNN_LIKELY(dwconv_ukernel != NULL) {
    dwconv_ukernel->init.f16(&dwconv_minmax, output_min_as_half, output_max_as_half);
  }

  // VMULCADDC micro-kernel parameters; only filled in when an init routine is registered.
  union xnn_f16_minmax_params vmulcaddc_minmax;
  if XNN_LIKELY(xnn_params.f16.vmulcaddc.init.f16 != NULL) {
    xnn_params.f16.vmulcaddc.init.f16(&vmulcaddc_minmax, output_min_as_half, output_max_as_half);
  }

  // Select the packing routines according to XNN_FLAG_FP32_STATIC_WEIGHTS.
  const bool use_f32_weights = (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) != 0;
  const xnn_pack_vmulcaddc_w_function pack_vmulcaddc_w = use_f32_weights
    ? (xnn_pack_vmulcaddc_w_function) xnn_pack_f32_to_f16_vmulcaddc_w
    : (xnn_pack_vmulcaddc_w_function) xnn_pack_f16_vmulcaddc_w;
  const xnn_pack_dwconv_hwg_w_function pack_dwconv_hwg_w = use_f32_weights
    ? (xnn_pack_dwconv_hwg_w_function) xnn_pack_f32_to_f16_dwconv_hwg_w
    : (xnn_pack_dwconv_hwg_w_function) xnn_pack_f16_dwconv_hwg_w;
  const xnn_pack_dwconv_ghw_w_function pack_dwconv_ghw_w = use_f32_weights
    ? (xnn_pack_dwconv_ghw_w_function) xnn_pack_f32_to_f16_dwconv_ghw_w
    : (xnn_pack_dwconv_ghw_w_function) xnn_pack_f16_dwconv_ghw_w;
  const xnn_pack_gemm_goi_w_function pack_gemm_goi_w = use_f32_weights
    ? (xnn_pack_gemm_goi_w_function) xnn_pack_f32_to_f16_gemm_goi_w
    : (xnn_pack_gemm_goi_w_function) xnn_pack_f16_gemm_goi_w;
  const xnn_pack_conv_kgo_w_function pack_conv_kgo_w = use_f32_weights
    ? (xnn_pack_conv_kgo_w_function) xnn_pack_f32_to_f16_conv_kgo_w
    : (xnn_pack_conv_kgo_w_function) xnn_pack_f16_conv_kgo_w;
  const xnn_pack_conv_goki_w_function pack_conv_goki_w = use_f32_weights
    ? (xnn_pack_conv_goki_w_function) xnn_pack_f32_to_f16_conv_goki_w
    : (xnn_pack_conv_goki_w_function) xnn_pack_f16_conv_goki_w;

  return create_convolution2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    kernel_height, kernel_width,
    subsampling_height, subsampling_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_channel_stride, output_channel_stride,
    kernel, bias, flags,
    1 /* log2(sizeof(input element)) = log2(sizeof(uint16_t)) */,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    pack_vmulcaddc_w,
    pack_dwconv_hwg_w,
    pack_dwconv_ghw_w,
    pack_gemm_goi_w,
    pack_conv_kgo_w,
    pack_conv_goki_w,
    NULL /* packing params */,
    0 /* input padding byte */,
    0 /* packed weights padding byte */,
    0 /* extra weights bytes */,
    NULL /* init scale params fn */,
    NULL /* scale params */,
    &gemm_minmax, sizeof(gemm_minmax),
    &dwconv_minmax, sizeof(dwconv_minmax),
    &vmulcaddc_minmax, sizeof(vmulcaddc_minmax),
    &xnn_params.f16.gemm, dwconv_ukernel, &xnn_params.f16.vmulcaddc,
    NULL /* jit_gemm_params */,
    false /* linear activation */, false /* relu activation */, XNN_INIT_FLAG_F16,
    xnn_operator_type_convolution_nhwc_f16,
    0, NULL,
    caches,
    convolution_op_out);
}
1120*4bdc9457SAndroid Build Coastguard Worker
// Creates a 2D convolution operator for FP32 tensors in NHWC layout.
//
// The function validates the [output_min, output_max] clamping range, picks the GEMM and
// depthwise-convolution micro-kernel configurations, initializes their min/max parameters, and
// forwards everything (together with the FP32 weight-packing functions) to
// create_convolution2d_nhwc().
//
// Returns xnn_status_invalid_parameter when either bound is NaN or when output_min is not strictly
// below output_max. Otherwise returns the status produced by create_convolution2d_nhwc().
enum xnn_status xnn_create_convolution2d_nhwc_f32(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_channel_stride,
    size_t output_channel_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* convolution_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // An unbounded range means no clamping; a [0, +inf) range is a plain ReLU.
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  const bool relu_activation = (output_max == INFINITY) && (output_min == 0.0f);

  const struct gemm_parameters* gemm_parameters = &xnn_params.f32.gemm;
  if (gemm_parameters->nr > group_output_channels) {
    // Default micro-kernel is suboptimal: its NR tile is wider than the number of output channels
    // in a group. Fall back to the alternative (gemm2) configuration when it provides an IGEMM
    // micro-kernel.
    //
    // The table is probed at gemm2's own (mr - 1) slot, matching the gemm_cases[mr - 1] indexing
    // used by the setup code in this file. Indexing it with the default configuration's mr looked
    // at an unrelated slot and could read past the end of the table.
    // NOTE(review): assumes minmax.igemm[] is keyed by (mr - 1) like gemm_cases[] - confirm
    // against the table definition.
    const struct gemm_parameters* gemm2_parameters = &xnn_params.f32.gemm2;
    const uint32_t gemm2_mr = gemm2_parameters->mr;
    // Guard against a zero mr so that the index cannot underflow.
    if (gemm2_mr != 0 &&
        gemm2_parameters->minmax.igemm[gemm2_mr - 1].function[XNN_UARCH_DEFAULT] != NULL)
    {
      gemm_parameters = gemm2_parameters;
    }
  }

  // Filled in only when the selected GEMM configuration provides an init function.
  union xnn_f32_minmax_params gemm_params;
  if XNN_LIKELY(gemm_parameters->init.f32 != NULL) {
    gemm_parameters->init.f32(&gemm_params, output_min, output_max);
  }

  struct jit_gemm_params jit_gemm_params = {
    .f32_minmax = {
      .min = output_min,
      .max = output_max
    }
  };

  // Look up a depthwise micro-kernel for a kernel of kernel_height * kernel_width elements;
  // dwconv_params is filled in only when such a micro-kernel exists.
  union xnn_f32_minmax_params dwconv_params;
  const struct dwconv_parameters* dwconv_ukernel =
    find_dwconv_ukernel(kernel_height * kernel_width, xnn_params.f32.dwconv, XNN_MAX_F32_DWCONV_UKERNELS);
  if XNN_LIKELY(dwconv_ukernel != NULL) {
    dwconv_ukernel->init.f32(&dwconv_params, output_min, output_max);
  }

  union xnn_f32_minmax_params vmulcaddc_params;
  if XNN_LIKELY(xnn_params.f32.vmulcaddc.init.f32 != NULL) {
    xnn_params.f32.vmulcaddc.init.f32(&vmulcaddc_params, output_min, output_max);
  }

  return create_convolution2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    kernel_height, kernel_width,
    subsampling_height, subsampling_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_channel_stride, output_channel_stride,
    kernel, bias, flags,
    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    (xnn_pack_vmulcaddc_w_function) xnn_pack_f32_vmulcaddc_w,
    (xnn_pack_dwconv_hwg_w_function) xnn_pack_f32_dwconv_hwg_w,
    (xnn_pack_dwconv_ghw_w_function) xnn_pack_f32_dwconv_ghw_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
    (xnn_pack_conv_kgo_w_function) xnn_pack_f32_conv_kgo_w,
    (xnn_pack_conv_goki_w_function) xnn_pack_f32_conv_goki_w,
    NULL /* packing params */, 0 /* input padding byte */, 0 /* packed weights padding byte */,
    0 /* extra weights bytes */, NULL /* init scale params fn */, NULL /* scale params */,
    &gemm_params, sizeof(gemm_params),
    &dwconv_params, sizeof(dwconv_params),
    &vmulcaddc_params, sizeof(vmulcaddc_params),
    gemm_parameters, dwconv_ukernel, &xnn_params.f32.vmulcaddc,
    &jit_gemm_params,
    linear_activation, relu_activation, XNN_INIT_FLAG_F32,
    xnn_operator_type_convolution_nhwc_f32,
    0, NULL,
    caches,
    convolution_op_out);
}
1232*4bdc9457SAndroid Build Coastguard Worker
// Creates an FP32 NHWC 2D convolution operator with a chain of fused post operations.
//
// The convolution itself is created with linear activation (output range [-INFINITY, +INFINITY]).
// The post operations are packed by allocate_and_initialize_post_operation_params() and passed,
// together with the JIT GEMM parameters, to create_convolution2d_nhwc(), whose status is returned.
//
// When XNN_ENABLE_JIT is off, an error is logged and xnn_status_invalid_parameter is returned
// before anything else happens.
enum xnn_status xnn_create_fused_convolution2d_nhwc_f32(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t subsampling_height,
    uint32_t subsampling_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_channel_stride,
    size_t output_channel_stride,
    const float* kernel,
    const float* bias,
    size_t num_post_operations,
    struct xnn_post_operation* post_operations,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* convolution_op_out)
{
  #if !XNN_ENABLE_JIT
    xnn_log_error(
      "failed to create %s operator: convolution with post operations available only if JIT is enabled",
      xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f32));
    return xnn_status_invalid_parameter;
  #endif

  // No clamping inside the convolution: callers express any clamping as a post operation.
  const float output_min = -INFINITY;
  const float output_max = INFINITY;

  struct jit_gemm_params jit_params = {
    .f32_minmax = { .min = output_min, .max = output_max },
    .num_post_operations = num_post_operations,
    .post_operations = post_operations,
  };

  // NOTE(review): ownership of this buffer presumably passes to create_convolution2d_nhwc() -
  // confirm, including on its failure paths.
  char* packed_post_op_params =
    allocate_and_initialize_post_operation_params(num_post_operations, post_operations);

  // The fused path always uses the default FP32 GEMM configuration.
  const struct gemm_parameters* gemm_config = &xnn_params.f32.gemm;

  union xnn_f32_minmax_params gemm_minmax;
  if XNN_LIKELY(gemm_config->init.f32 != NULL) {
    gemm_config->init.f32(&gemm_minmax, output_min, output_max);
  }

  // The depthwise parameters are populated only when a matching micro-kernel was found.
  const struct dwconv_parameters* dwconv_config =
    find_dwconv_ukernel(kernel_height * kernel_width, xnn_params.f32.dwconv, XNN_MAX_F32_DWCONV_UKERNELS);
  union xnn_f32_minmax_params dwconv_minmax;
  if XNN_LIKELY(dwconv_config != NULL) {
    dwconv_config->init.f32(&dwconv_minmax, output_min, output_max);
  }

  union xnn_f32_minmax_params vmulcaddc_minmax;
  if XNN_LIKELY(xnn_params.f32.vmulcaddc.init.f32 != NULL) {
    xnn_params.f32.vmulcaddc.init.f32(&vmulcaddc_minmax, output_min, output_max);
  }

  return create_convolution2d_nhwc(
    input_padding_top, input_padding_right,
    input_padding_bottom, input_padding_left,
    kernel_height, kernel_width,
    subsampling_height, subsampling_width,
    dilation_height, dilation_width,
    groups,
    group_input_channels, group_output_channels,
    input_channel_stride, output_channel_stride,
    kernel, bias,
    flags,
    /* log2(sizeof(input element)) = log2(sizeof(float)) */ 2,
    /* log2(sizeof(filter element)) = log2(sizeof(float)) */ 2,
    /* sizeof(bias element) */ sizeof(float),
    (xnn_pack_vmulcaddc_w_function) xnn_pack_f32_vmulcaddc_w,
    (xnn_pack_dwconv_hwg_w_function) xnn_pack_f32_dwconv_hwg_w,
    (xnn_pack_dwconv_ghw_w_function) xnn_pack_f32_dwconv_ghw_w,
    (xnn_pack_gemm_goi_w_function) xnn_pack_f32_gemm_goi_w,
    (xnn_pack_conv_kgo_w_function) xnn_pack_f32_conv_kgo_w,
    (xnn_pack_conv_goki_w_function) xnn_pack_f32_conv_goki_w,
    /* packing params */ NULL,
    /* input padding byte */ 0,
    /* packed weights padding byte */ 0,
    /* extra weights bytes */ 0,
    /* init scale params fn */ NULL,
    /* scale params */ NULL,
    &gemm_minmax, sizeof(gemm_minmax),
    &dwconv_minmax, sizeof(dwconv_minmax),
    &vmulcaddc_minmax, sizeof(vmulcaddc_minmax),
    gemm_config, dwconv_config, &xnn_params.f32.vmulcaddc,
    &jit_params,
    /* linear_activation */ true,
    /* relu_activation */ false,
    XNN_INIT_FLAG_F32,
    xnn_operator_type_convolution_nhwc_f32,
    num_post_operations, packed_post_op_params,
    caches,
    convolution_op_out);
}
1326*4bdc9457SAndroid Build Coastguard Worker
setup_convolution2d_nhwc(xnn_operator_t convolution_op,enum xnn_operator_type expected_operator_type,size_t batch_size,size_t input_height,size_t input_width,const void * input,void * output,uint32_t datatype_init_flags,uint32_t log2_input_element_size,uint32_t log2_filter_element_size,uint32_t extra_weights_elements_size,uint32_t log2_output_element_size,size_t num_threads)1327*4bdc9457SAndroid Build Coastguard Worker static enum xnn_status setup_convolution2d_nhwc(
1328*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op,
1329*4bdc9457SAndroid Build Coastguard Worker enum xnn_operator_type expected_operator_type,
1330*4bdc9457SAndroid Build Coastguard Worker size_t batch_size,
1331*4bdc9457SAndroid Build Coastguard Worker size_t input_height,
1332*4bdc9457SAndroid Build Coastguard Worker size_t input_width,
1333*4bdc9457SAndroid Build Coastguard Worker const void* input,
1334*4bdc9457SAndroid Build Coastguard Worker void* output,
1335*4bdc9457SAndroid Build Coastguard Worker uint32_t datatype_init_flags,
1336*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_input_element_size,
1337*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_filter_element_size,
1338*4bdc9457SAndroid Build Coastguard Worker uint32_t extra_weights_elements_size,
1339*4bdc9457SAndroid Build Coastguard Worker uint32_t log2_output_element_size,
1340*4bdc9457SAndroid Build Coastguard Worker size_t num_threads)
1341*4bdc9457SAndroid Build Coastguard Worker {
1342*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->type != expected_operator_type) {
1343*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1344*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(expected_operator_type),
1345*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(convolution_op->type));
1346*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_parameter;
1347*4bdc9457SAndroid Build Coastguard Worker }
1348*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_invalid;
1349*4bdc9457SAndroid Build Coastguard Worker
1350*4bdc9457SAndroid Build Coastguard Worker if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
1351*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
1352*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(convolution_op->type));
1353*4bdc9457SAndroid Build Coastguard Worker return xnn_status_uninitialized;
1354*4bdc9457SAndroid Build Coastguard Worker }
1355*4bdc9457SAndroid Build Coastguard Worker
1356*4bdc9457SAndroid Build Coastguard Worker if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
1357*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
1358*4bdc9457SAndroid Build Coastguard Worker "failed to create %s operator: operations on data type are not supported",
1359*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(convolution_op->type));
1360*4bdc9457SAndroid Build Coastguard Worker return xnn_status_unsupported_hardware;
1361*4bdc9457SAndroid Build Coastguard Worker }
1362*4bdc9457SAndroid Build Coastguard Worker
1363*4bdc9457SAndroid Build Coastguard Worker if (input_width == 0 || input_height == 0) {
1364*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
1365*4bdc9457SAndroid Build Coastguard Worker "failed to setup %s operator with %zux%zu input: input dimensions must be non-zero",
1366*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(convolution_op->type), input_width, input_height);
1367*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_parameter;
1368*4bdc9457SAndroid Build Coastguard Worker }
1369*4bdc9457SAndroid Build Coastguard Worker
1370*4bdc9457SAndroid Build Coastguard Worker if (batch_size == 0) {
1371*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_skip;
1372*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1373*4bdc9457SAndroid Build Coastguard Worker }
1374*4bdc9457SAndroid Build Coastguard Worker
1375*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->weights_cache != NULL && !xnn_weights_cache_is_finalized(convolution_op->weights_cache)) {
1376*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to setup %s operator: weights cache is not finalized",
1377*4bdc9457SAndroid Build Coastguard Worker xnn_operator_type_to_string(convolution_op->type));
1378*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_state;
1379*4bdc9457SAndroid Build Coastguard Worker }
1380*4bdc9457SAndroid Build Coastguard Worker
1381*4bdc9457SAndroid Build Coastguard Worker convolution_op->batch_size = batch_size;
1382*4bdc9457SAndroid Build Coastguard Worker convolution_op->input_height = input_height;
1383*4bdc9457SAndroid Build Coastguard Worker convolution_op->input_width = input_width;
1384*4bdc9457SAndroid Build Coastguard Worker convolution_op->input = input;
1385*4bdc9457SAndroid Build Coastguard Worker
1386*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
1387*4bdc9457SAndroid Build Coastguard Worker convolution_op->output_height = compute_output_dimension_with_tf_same_padding(
1388*4bdc9457SAndroid Build Coastguard Worker input_height, convolution_op->stride_height);
1389*4bdc9457SAndroid Build Coastguard Worker convolution_op->output_width = compute_output_dimension_with_tf_same_padding(
1390*4bdc9457SAndroid Build Coastguard Worker input_width, convolution_op->stride_width);
1391*4bdc9457SAndroid Build Coastguard Worker
1392*4bdc9457SAndroid Build Coastguard Worker const uint32_t effective_kernel_height = (convolution_op->kernel_height - 1) * convolution_op->dilation_height + 1;
1393*4bdc9457SAndroid Build Coastguard Worker const uint32_t effective_kernel_width = (convolution_op->kernel_width - 1) * convolution_op->dilation_width + 1;
1394*4bdc9457SAndroid Build Coastguard Worker const size_t total_padding_height =
1395*4bdc9457SAndroid Build Coastguard Worker (convolution_op->output_height - 1) * convolution_op->stride_height + effective_kernel_height - input_height;
1396*4bdc9457SAndroid Build Coastguard Worker const size_t total_padding_width =
1397*4bdc9457SAndroid Build Coastguard Worker (convolution_op->output_width - 1) * convolution_op->stride_width + effective_kernel_width - input_width;
1398*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_top = total_padding_height / 2;
1399*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_left = total_padding_width / 2;
1400*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_bottom = total_padding_height - convolution_op->padding_top;
1401*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_right = total_padding_width - convolution_op->padding_left;
1402*4bdc9457SAndroid Build Coastguard Worker } else {
1403*4bdc9457SAndroid Build Coastguard Worker convolution_op->output_height = xnn_compute_convolution_output_dimension(
1404*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_top + input_height + convolution_op->padding_bottom,
1405*4bdc9457SAndroid Build Coastguard Worker convolution_op->kernel_height,
1406*4bdc9457SAndroid Build Coastguard Worker convolution_op->dilation_height,
1407*4bdc9457SAndroid Build Coastguard Worker convolution_op->stride_height);
1408*4bdc9457SAndroid Build Coastguard Worker convolution_op->output_width = xnn_compute_convolution_output_dimension(
1409*4bdc9457SAndroid Build Coastguard Worker convolution_op->padding_left + input_width + convolution_op->padding_right,
1410*4bdc9457SAndroid Build Coastguard Worker convolution_op->kernel_width,
1411*4bdc9457SAndroid Build Coastguard Worker convolution_op->dilation_width,
1412*4bdc9457SAndroid Build Coastguard Worker convolution_op->stride_width);
1413*4bdc9457SAndroid Build Coastguard Worker }
1414*4bdc9457SAndroid Build Coastguard Worker convolution_op->output = output;
1415*4bdc9457SAndroid Build Coastguard Worker
1416*4bdc9457SAndroid Build Coastguard Worker switch (convolution_op->ukernel.type) {
1417*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_gemm:
1418*4bdc9457SAndroid Build Coastguard Worker {
1419*4bdc9457SAndroid Build Coastguard Worker // Convolution maps directly to GEMM and doesn't use indirection buffer.
1420*4bdc9457SAndroid Build Coastguard Worker
1421*4bdc9457SAndroid Build Coastguard Worker const size_t output_height = convolution_op->output_height;
1422*4bdc9457SAndroid Build Coastguard Worker const size_t output_width = convolution_op->output_width;
1423*4bdc9457SAndroid Build Coastguard Worker const size_t output_size = output_height * output_width;
1424*4bdc9457SAndroid Build Coastguard Worker const size_t batch_output_size = batch_size * output_size;
1425*4bdc9457SAndroid Build Coastguard Worker
1426*4bdc9457SAndroid Build Coastguard Worker const size_t groups = convolution_op->groups;
1427*4bdc9457SAndroid Build Coastguard Worker const size_t group_input_channels = convolution_op->group_input_channels;
1428*4bdc9457SAndroid Build Coastguard Worker const size_t w_stride = extra_weights_elements_size +
1429*4bdc9457SAndroid Build Coastguard Worker (round_up_po2(group_input_channels, convolution_op->ukernel.gemm.kr * convolution_op->ukernel.gemm.sr) << log2_filter_element_size);
1430*4bdc9457SAndroid Build Coastguard Worker const size_t group_output_channels = convolution_op->group_output_channels;
1431*4bdc9457SAndroid Build Coastguard Worker
1432*4bdc9457SAndroid Build Coastguard Worker uint32_t mr = convolution_op->ukernel.gemm.mr;
1433*4bdc9457SAndroid Build Coastguard Worker const uint32_t nr = convolution_op->ukernel.gemm.nr;
1434*4bdc9457SAndroid Build Coastguard Worker struct xnn_hmp_gemm_ukernel *gemm_cases = convolution_op->ukernel.gemm.gemm_cases;
1435*4bdc9457SAndroid Build Coastguard Worker
1436*4bdc9457SAndroid Build Coastguard Worker #if XNN_ENABLE_GEMM_M_SPECIALIZATION
1437*4bdc9457SAndroid Build Coastguard Worker mr = xnn_get_heuristic_mr_gemm(batch_output_size, mr, nr, gemm_cases);
1438*4bdc9457SAndroid Build Coastguard Worker #else
1439*4bdc9457SAndroid Build Coastguard Worker if (batch_output_size == 1 && gemm_cases[0].function[XNN_UARCH_DEFAULT] != NULL) {
1440*4bdc9457SAndroid Build Coastguard Worker mr = 1;
1441*4bdc9457SAndroid Build Coastguard Worker }
1442*4bdc9457SAndroid Build Coastguard Worker #endif
1443*4bdc9457SAndroid Build Coastguard Worker
1444*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT
1445*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->code_cache != NULL) {
1446*4bdc9457SAndroid Build Coastguard Worker const size_t jit_code_offset = gemm_cases[mr - 1].generated_code_offset[XNN_UARCH_DEFAULT];
1447*4bdc9457SAndroid Build Coastguard Worker if (jit_code_offset != XNN_CACHE_NOT_FOUND) {
1448*4bdc9457SAndroid Build Coastguard Worker gemm_cases[mr - 1].function[XNN_UARCH_DEFAULT] =
1449*4bdc9457SAndroid Build Coastguard Worker (xnn_gemm_ukernel_function) cached_code_at_offset(convolution_op, jit_code_offset);
1450*4bdc9457SAndroid Build Coastguard Worker }
1451*4bdc9457SAndroid Build Coastguard Worker }
1452*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_PLATFORM_JIT
1453*4bdc9457SAndroid Build Coastguard Worker struct xnn_hmp_gemm_ukernel gemm_ukernel = gemm_cases[mr - 1];
1454*4bdc9457SAndroid Build Coastguard Worker
1455*4bdc9457SAndroid Build Coastguard Worker convolution_op->context.gemm = (struct gemm_context) {
1456*4bdc9457SAndroid Build Coastguard Worker .k_scaled = group_input_channels << log2_input_element_size,
1457*4bdc9457SAndroid Build Coastguard Worker .a = input,
1458*4bdc9457SAndroid Build Coastguard Worker .a_stride = convolution_op->input_pixel_stride << log2_input_element_size,
1459*4bdc9457SAndroid Build Coastguard Worker .packed_w = packed_weights(convolution_op),
1460*4bdc9457SAndroid Build Coastguard Worker .w_stride = w_stride,
1461*4bdc9457SAndroid Build Coastguard Worker .wg_stride = w_stride * round_up(group_output_channels, nr),
1462*4bdc9457SAndroid Build Coastguard Worker .c = output,
1463*4bdc9457SAndroid Build Coastguard Worker .cm_stride = convolution_op->output_pixel_stride << log2_output_element_size,
1464*4bdc9457SAndroid Build Coastguard Worker .cn_stride = nr << log2_output_element_size,
1465*4bdc9457SAndroid Build Coastguard Worker .cg_stride = group_output_channels << log2_output_element_size,
1466*4bdc9457SAndroid Build Coastguard Worker .log2_csize = log2_output_element_size,
1467*4bdc9457SAndroid Build Coastguard Worker .ukernel = gemm_ukernel,
1468*4bdc9457SAndroid Build Coastguard Worker };
1469*4bdc9457SAndroid Build Coastguard Worker memcpy(&convolution_op->context.gemm.params, &convolution_op->params, sizeof(convolution_op->context.gemm.params));
1470*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->num_post_operation_params == 0) {
1471*4bdc9457SAndroid Build Coastguard Worker convolution_op->context.gemm.fused_params = &convolution_op->context.gemm.params;
1472*4bdc9457SAndroid Build Coastguard Worker } else {
1473*4bdc9457SAndroid Build Coastguard Worker convolution_op->context.gemm.fused_params = convolution_op->post_operation_params;
1474*4bdc9457SAndroid Build Coastguard Worker }
1475*4bdc9457SAndroid Build Coastguard Worker
1476*4bdc9457SAndroid Build Coastguard Worker #if XNN_TEST_MODE
1477*4bdc9457SAndroid Build Coastguard Worker const size_t nc = nr;
1478*4bdc9457SAndroid Build Coastguard Worker #else
1479*4bdc9457SAndroid Build Coastguard Worker size_t nc = group_output_channels;
1480*4bdc9457SAndroid Build Coastguard Worker if (num_threads > 1) {
1481*4bdc9457SAndroid Build Coastguard Worker const size_t num_other_tiles = groups * divide_round_up(batch_output_size, mr);
1482*4bdc9457SAndroid Build Coastguard Worker const size_t target_tiles_per_thread = 5;
1483*4bdc9457SAndroid Build Coastguard Worker const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
1484*4bdc9457SAndroid Build Coastguard Worker if (max_nc < nc) {
1485*4bdc9457SAndroid Build Coastguard Worker nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
1486*4bdc9457SAndroid Build Coastguard Worker }
1487*4bdc9457SAndroid Build Coastguard Worker }
1488*4bdc9457SAndroid Build Coastguard Worker #endif
1489*4bdc9457SAndroid Build Coastguard Worker if (groups == 1) {
1490*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
1491*4bdc9457SAndroid Build Coastguard Worker if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
1492*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d_with_uarch;
1493*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
1494*4bdc9457SAndroid Build Coastguard Worker } else {
1495*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
1496*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
1497*4bdc9457SAndroid Build Coastguard Worker }
1498*4bdc9457SAndroid Build Coastguard Worker #else
1499*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
1500*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
1501*4bdc9457SAndroid Build Coastguard Worker #endif
1502*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = batch_output_size;
1503*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[1] = group_output_channels;
1504*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[0] = mr;
1505*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[1] = nc;
1506*4bdc9457SAndroid Build Coastguard Worker } else {
1507*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
1508*4bdc9457SAndroid Build Coastguard Worker if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
1509*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
1510*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_gemm;
1511*4bdc9457SAndroid Build Coastguard Worker } else {
1512*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
1513*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
1514*4bdc9457SAndroid Build Coastguard Worker }
1515*4bdc9457SAndroid Build Coastguard Worker #else
1516*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
1517*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
1518*4bdc9457SAndroid Build Coastguard Worker #endif
1519*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = groups;
1520*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[1] = batch_output_size;
1521*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[2] = group_output_channels;
1522*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[0] = mr;
1523*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[1] = nc;
1524*4bdc9457SAndroid Build Coastguard Worker }
1525*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_ready;
1526*4bdc9457SAndroid Build Coastguard Worker
1527*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1528*4bdc9457SAndroid Build Coastguard Worker }
1529*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_igemm:
1530*4bdc9457SAndroid Build Coastguard Worker {
1531*4bdc9457SAndroid Build Coastguard Worker const size_t groups = convolution_op->groups;
1532*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_height = convolution_op->kernel_height;
1533*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_width = convolution_op->kernel_width;
1534*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_size = kernel_height * kernel_width;
1535*4bdc9457SAndroid Build Coastguard Worker const size_t output_height = convolution_op->output_height;
1536*4bdc9457SAndroid Build Coastguard Worker const size_t output_width = convolution_op->output_width;
1537*4bdc9457SAndroid Build Coastguard Worker const size_t output_size = output_height * output_width;
1538*4bdc9457SAndroid Build Coastguard Worker
1539*4bdc9457SAndroid Build Coastguard Worker uint32_t mr = convolution_op->ukernel.igemm.mr;
1540*4bdc9457SAndroid Build Coastguard Worker const uint32_t nr = convolution_op->ukernel.igemm.nr;
1541*4bdc9457SAndroid Build Coastguard Worker struct xnn_hmp_igemm_ukernel* igemm_cases = convolution_op->ukernel.igemm.igemm_cases;
1542*4bdc9457SAndroid Build Coastguard Worker
1543*4bdc9457SAndroid Build Coastguard Worker #if XNN_ENABLE_GEMM_M_SPECIALIZATION
1544*4bdc9457SAndroid Build Coastguard Worker mr = xnn_get_heuristic_mr_igemm(output_size, mr, nr, igemm_cases);
1545*4bdc9457SAndroid Build Coastguard Worker #else
1546*4bdc9457SAndroid Build Coastguard Worker if (output_size == 1 && igemm_cases[0].function[XNN_UARCH_DEFAULT] != NULL) {
1547*4bdc9457SAndroid Build Coastguard Worker mr = 1;
1548*4bdc9457SAndroid Build Coastguard Worker }
1549*4bdc9457SAndroid Build Coastguard Worker #endif
1550*4bdc9457SAndroid Build Coastguard Worker
1551*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT
1552*4bdc9457SAndroid Build Coastguard Worker if (convolution_op->code_cache != NULL) {
1553*4bdc9457SAndroid Build Coastguard Worker const size_t jit_code_offset = igemm_cases[mr - 1].generated_code_offset[XNN_UARCH_DEFAULT];
1554*4bdc9457SAndroid Build Coastguard Worker if (jit_code_offset != XNN_CACHE_NOT_FOUND) {
1555*4bdc9457SAndroid Build Coastguard Worker igemm_cases[mr - 1].function[XNN_UARCH_DEFAULT] =
1556*4bdc9457SAndroid Build Coastguard Worker (xnn_igemm_ukernel_function) cached_code_at_offset(convolution_op, jit_code_offset);
1557*4bdc9457SAndroid Build Coastguard Worker }
1558*4bdc9457SAndroid Build Coastguard Worker }
1559*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_PLATFORM_JIT
1560*4bdc9457SAndroid Build Coastguard Worker struct xnn_hmp_igemm_ukernel igemm_ukernel = igemm_cases[mr - 1];
1561*4bdc9457SAndroid Build Coastguard Worker
1562*4bdc9457SAndroid Build Coastguard Worker const size_t tiled_output_size = round_up(output_size, mr);
1563*4bdc9457SAndroid Build Coastguard Worker const size_t indirection_buffer_size = sizeof(void*) * kernel_size * tiled_output_size;
1564*4bdc9457SAndroid Build Coastguard Worker
1565*4bdc9457SAndroid Build Coastguard Worker if (input_height != convolution_op->last_input_height ||
1566*4bdc9457SAndroid Build Coastguard Worker input_width != convolution_op->last_input_width)
1567*4bdc9457SAndroid Build Coastguard Worker {
1568*4bdc9457SAndroid Build Coastguard Worker const void** indirection_buffer = (const void**) xnn_reallocate_memory((void*) convolution_op->indirection_buffer, indirection_buffer_size);
1569*4bdc9457SAndroid Build Coastguard Worker if (indirection_buffer == NULL) {
1570*4bdc9457SAndroid Build Coastguard Worker xnn_log_error(
1571*4bdc9457SAndroid Build Coastguard Worker "failed to allocate %zu bytes for %s operator indirection buffer",
1572*4bdc9457SAndroid Build Coastguard Worker indirection_buffer_size, xnn_operator_type_to_string(convolution_op->type));
1573*4bdc9457SAndroid Build Coastguard Worker return xnn_status_out_of_memory;
1574*4bdc9457SAndroid Build Coastguard Worker }
1575*4bdc9457SAndroid Build Coastguard Worker convolution_op->indirection_buffer = indirection_buffer;
1576*4bdc9457SAndroid Build Coastguard Worker convolution_op->last_input = input;
1577*4bdc9457SAndroid Build Coastguard Worker convolution_op->last_input_height = input_height;
1578*4bdc9457SAndroid Build Coastguard Worker convolution_op->last_input_width = input_width;
1579*4bdc9457SAndroid Build Coastguard Worker
1580*4bdc9457SAndroid Build Coastguard Worker xnn_indirection_init_conv2d(convolution_op, mr, log2_input_element_size);
1581*4bdc9457SAndroid Build Coastguard Worker }
1582*4bdc9457SAndroid Build Coastguard Worker
1583*4bdc9457SAndroid Build Coastguard Worker const size_t group_input_channels = convolution_op->group_input_channels;
1584*4bdc9457SAndroid Build Coastguard Worker const size_t w_stride = extra_weights_elements_size +
1585*4bdc9457SAndroid Build Coastguard Worker (round_up_po2(group_input_channels, convolution_op->ukernel.igemm.kr * convolution_op->ukernel.igemm.sr) * kernel_size << log2_filter_element_size);
1586*4bdc9457SAndroid Build Coastguard Worker const size_t group_output_channels = convolution_op->group_output_channels;
1587*4bdc9457SAndroid Build Coastguard Worker convolution_op->context.igemm = (struct igemm_context) {
1588*4bdc9457SAndroid Build Coastguard Worker .ks = kernel_size,
1589*4bdc9457SAndroid Build Coastguard Worker .ks_scaled = kernel_size * mr * sizeof(void*),
1590*4bdc9457SAndroid Build Coastguard Worker .kc = group_input_channels << log2_input_element_size,
1591*4bdc9457SAndroid Build Coastguard Worker .w_stride = w_stride,
1592*4bdc9457SAndroid Build Coastguard Worker .indirect_a = convolution_op->indirection_buffer,
1593*4bdc9457SAndroid Build Coastguard Worker .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) convolution_op->last_input),
1594*4bdc9457SAndroid Build Coastguard Worker .zero = convolution_op->zero_buffer,
1595*4bdc9457SAndroid Build Coastguard Worker .packed_w = packed_weights(convolution_op),
1596*4bdc9457SAndroid Build Coastguard Worker .c = convolution_op->output,
1597*4bdc9457SAndroid Build Coastguard Worker .cm_stride = convolution_op->output_pixel_stride << log2_output_element_size,
1598*4bdc9457SAndroid Build Coastguard Worker .cn_stride = nr << log2_output_element_size,
1599*4bdc9457SAndroid Build Coastguard Worker .ga_stride = group_input_channels << log2_input_element_size,
1600*4bdc9457SAndroid Build Coastguard Worker .gw_stride = w_stride * round_up(group_output_channels, nr),
1601*4bdc9457SAndroid Build Coastguard Worker .gc_stride = group_output_channels << log2_output_element_size,
1602*4bdc9457SAndroid Build Coastguard Worker .ba_stride = input_height * input_width * convolution_op->input_pixel_stride << log2_input_element_size,
1603*4bdc9457SAndroid Build Coastguard Worker .bc_stride = output_size * convolution_op->output_pixel_stride << log2_output_element_size,
1604*4bdc9457SAndroid Build Coastguard Worker .log2_csize = log2_output_element_size,
1605*4bdc9457SAndroid Build Coastguard Worker .ukernel = igemm_ukernel,
1606*4bdc9457SAndroid Build Coastguard Worker };
1607*4bdc9457SAndroid Build Coastguard Worker memcpy(&convolution_op->context.igemm.params, &convolution_op->params, sizeof(convolution_op->context.igemm.params));
1608*4bdc9457SAndroid Build Coastguard Worker
1609*4bdc9457SAndroid Build Coastguard Worker #if XNN_TEST_MODE
1610*4bdc9457SAndroid Build Coastguard Worker const size_t nc = nr;
1611*4bdc9457SAndroid Build Coastguard Worker #else
1612*4bdc9457SAndroid Build Coastguard Worker size_t nc = group_output_channels;
1613*4bdc9457SAndroid Build Coastguard Worker if (num_threads > 1) {
1614*4bdc9457SAndroid Build Coastguard Worker const size_t num_other_tiles = groups * batch_size * divide_round_up(output_size, mr);
1615*4bdc9457SAndroid Build Coastguard Worker const size_t target_tiles_per_thread = 5;
1616*4bdc9457SAndroid Build Coastguard Worker const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
1617*4bdc9457SAndroid Build Coastguard Worker if (max_nc < nc) {
1618*4bdc9457SAndroid Build Coastguard Worker nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
1619*4bdc9457SAndroid Build Coastguard Worker }
1620*4bdc9457SAndroid Build Coastguard Worker }
1621*4bdc9457SAndroid Build Coastguard Worker #endif
1622*4bdc9457SAndroid Build Coastguard Worker if (groups == 1) {
1623*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
1624*4bdc9457SAndroid Build Coastguard Worker if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
1625*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1626*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
1627*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_batch_hmp_igemm;
1628*4bdc9457SAndroid Build Coastguard Worker } else {
1629*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d_with_uarch;
1630*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_igemm;
1631*4bdc9457SAndroid Build Coastguard Worker }
1632*4bdc9457SAndroid Build Coastguard Worker } else {
1633*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1634*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
1635*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_batch_igemm;
1636*4bdc9457SAndroid Build Coastguard Worker } else {
1637*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
1638*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_igemm;
1639*4bdc9457SAndroid Build Coastguard Worker }
1640*4bdc9457SAndroid Build Coastguard Worker }
1641*4bdc9457SAndroid Build Coastguard Worker #else
1642*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1643*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
1644*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_batch_igemm;
1645*4bdc9457SAndroid Build Coastguard Worker } else {
1646*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
1647*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_igemm;
1648*4bdc9457SAndroid Build Coastguard Worker }
1649*4bdc9457SAndroid Build Coastguard Worker #endif
1650*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1651*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = batch_size;
1652*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[1] = output_size;
1653*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[2] = group_output_channels;
1654*4bdc9457SAndroid Build Coastguard Worker } else {
1655*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = output_size;
1656*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[1] = group_output_channels;
1657*4bdc9457SAndroid Build Coastguard Worker }
1658*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[0] = mr;
1659*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[1] = nc;
1660*4bdc9457SAndroid Build Coastguard Worker } else {
1661*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
1662*4bdc9457SAndroid Build Coastguard Worker if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
1663*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1664*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d_with_uarch;
1665*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_4d_tile_2d_with_id = (pthreadpool_task_4d_tile_2d_with_id_t) xnn_compute_hmp_grouped_batch_igemm;
1666*4bdc9457SAndroid Build Coastguard Worker } else {
1667*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
1668*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_igemm;
1669*4bdc9457SAndroid Build Coastguard Worker }
1670*4bdc9457SAndroid Build Coastguard Worker } else {
1671*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1672*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
1673*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm;
1674*4bdc9457SAndroid Build Coastguard Worker } else {
1675*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
1676*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_igemm;
1677*4bdc9457SAndroid Build Coastguard Worker }
1678*4bdc9457SAndroid Build Coastguard Worker }
1679*4bdc9457SAndroid Build Coastguard Worker #else
1680*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1681*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
1682*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm;
1683*4bdc9457SAndroid Build Coastguard Worker } else {
1684*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
1685*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_igemm;
1686*4bdc9457SAndroid Build Coastguard Worker }
1687*4bdc9457SAndroid Build Coastguard Worker #endif
1688*4bdc9457SAndroid Build Coastguard Worker if (batch_size > 1) {
1689*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = batch_size;
1690*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[1] = groups;
1691*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[2] = output_size;
1692*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[3] = group_output_channels;
1693*4bdc9457SAndroid Build Coastguard Worker } else {
1694*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = groups;
1695*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[1] = output_size;
1696*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[2] = group_output_channels;
1697*4bdc9457SAndroid Build Coastguard Worker }
1698*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[0] = mr;
1699*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[1] = nc;
1700*4bdc9457SAndroid Build Coastguard Worker }
1701*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_ready;
1702*4bdc9457SAndroid Build Coastguard Worker
1703*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1704*4bdc9457SAndroid Build Coastguard Worker }
1705*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_dwconv:
1706*4bdc9457SAndroid Build Coastguard Worker {
1707*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_height = convolution_op->kernel_height;
1708*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_width = convolution_op->kernel_width;
1709*4bdc9457SAndroid Build Coastguard Worker const size_t kernel_size = kernel_height * kernel_width;
1710*4bdc9457SAndroid Build Coastguard Worker const size_t output_height = convolution_op->output_height;
1711*4bdc9457SAndroid Build Coastguard Worker const size_t output_width = convolution_op->output_width;
1712*4bdc9457SAndroid Build Coastguard Worker const size_t step_width = convolution_op->dilation_width == 1 ? convolution_op->stride_width : kernel_width;
1713*4bdc9457SAndroid Build Coastguard Worker const size_t step_height = kernel_size + (output_width - 1) * step_width * kernel_height;
1714*4bdc9457SAndroid Build Coastguard Worker const size_t primary_tile = convolution_op->ukernel.dwconv.primary_tile;
1715*4bdc9457SAndroid Build Coastguard Worker if (input_height != convolution_op->last_input_height || input_width != convolution_op->last_input_width) {
1716*4bdc9457SAndroid Build Coastguard Worker // Micro-kernel will read (primary_tile - kernel_size) elements after the end of indirection buffer.
1717*4bdc9457SAndroid Build Coastguard Worker const size_t indirection_buffer_size =
1718*4bdc9457SAndroid Build Coastguard Worker sizeof(void*) * (primary_tile - kernel_size + output_height * step_height);
1719*4bdc9457SAndroid Build Coastguard Worker
1720*4bdc9457SAndroid Build Coastguard Worker const void** indirection_buffer =
1721*4bdc9457SAndroid Build Coastguard Worker (const void**) xnn_reallocate_memory(convolution_op->indirection_buffer, indirection_buffer_size);
1722*4bdc9457SAndroid Build Coastguard Worker if (indirection_buffer == NULL) {
1723*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to allocate %zu bytes for %s operator indirection buffer",
1724*4bdc9457SAndroid Build Coastguard Worker indirection_buffer_size, xnn_operator_type_to_string(convolution_op->type));
1725*4bdc9457SAndroid Build Coastguard Worker return xnn_status_out_of_memory;
1726*4bdc9457SAndroid Build Coastguard Worker }
1727*4bdc9457SAndroid Build Coastguard Worker convolution_op->indirection_buffer = indirection_buffer;
1728*4bdc9457SAndroid Build Coastguard Worker
1729*4bdc9457SAndroid Build Coastguard Worker xnn_indirection_init_dwconv2d(convolution_op, step_height, step_width, primary_tile, log2_input_element_size);
1730*4bdc9457SAndroid Build Coastguard Worker
1731*4bdc9457SAndroid Build Coastguard Worker convolution_op->last_input = input;
1732*4bdc9457SAndroid Build Coastguard Worker convolution_op->last_input_height = input_height;
1733*4bdc9457SAndroid Build Coastguard Worker convolution_op->last_input_width = input_width;
1734*4bdc9457SAndroid Build Coastguard Worker }
1735*4bdc9457SAndroid Build Coastguard Worker
1736*4bdc9457SAndroid Build Coastguard Worker const size_t groups = convolution_op->groups;
1737*4bdc9457SAndroid Build Coastguard Worker convolution_op->context.dwconv = (struct dwconv_context) {
1738*4bdc9457SAndroid Build Coastguard Worker .indirect_input = convolution_op->indirection_buffer,
1739*4bdc9457SAndroid Build Coastguard Worker .indirect_input_width_stride = kernel_height * step_width * sizeof(void*),
1740*4bdc9457SAndroid Build Coastguard Worker .indirect_input_height_stride = step_height * sizeof(void*),
1741*4bdc9457SAndroid Build Coastguard Worker .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) convolution_op->last_input),
1742*4bdc9457SAndroid Build Coastguard Worker .input_batch_stride = (input_height * input_width * convolution_op->input_pixel_stride) << log2_input_element_size,
1743*4bdc9457SAndroid Build Coastguard Worker .packed_weights = packed_weights(convolution_op),
1744*4bdc9457SAndroid Build Coastguard Worker .output = convolution_op->output,
1745*4bdc9457SAndroid Build Coastguard Worker .output_batch_stride = (output_height * output_width * convolution_op->output_pixel_stride) << log2_output_element_size,
1746*4bdc9457SAndroid Build Coastguard Worker .output_height_stride = (output_width * convolution_op->output_pixel_stride) << log2_output_element_size,
1747*4bdc9457SAndroid Build Coastguard Worker .output_width = output_width,
1748*4bdc9457SAndroid Build Coastguard Worker .groups = groups,
1749*4bdc9457SAndroid Build Coastguard Worker .zero = convolution_op->zero_buffer,
1750*4bdc9457SAndroid Build Coastguard Worker .output_increment = (convolution_op->output_pixel_stride - groups) << log2_output_element_size,
1751*4bdc9457SAndroid Build Coastguard Worker .unipass_ukernel = convolution_op->ukernel.dwconv.unipass_function,
1752*4bdc9457SAndroid Build Coastguard Worker };
1753*4bdc9457SAndroid Build Coastguard Worker memcpy(&convolution_op->context.dwconv.params, &convolution_op->params, sizeof(convolution_op->context.dwconv.params));
1754*4bdc9457SAndroid Build Coastguard Worker
1755*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_2d;
1756*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_dwconv_unipass;
1757*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = batch_size;
1758*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[1] = output_height;
1759*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_ready;
1760*4bdc9457SAndroid Build Coastguard Worker
1761*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1762*4bdc9457SAndroid Build Coastguard Worker }
1763*4bdc9457SAndroid Build Coastguard Worker case xnn_ukernel_type_vmulcaddc:
1764*4bdc9457SAndroid Build Coastguard Worker {
1765*4bdc9457SAndroid Build Coastguard Worker const size_t batch_output_size = batch_size * convolution_op->output_height * convolution_op->output_width;
1766*4bdc9457SAndroid Build Coastguard Worker
1767*4bdc9457SAndroid Build Coastguard Worker convolution_op->context.vmulcaddc = (struct vmulcaddc_context) {
1768*4bdc9457SAndroid Build Coastguard Worker .n = convolution_op->groups << log2_input_element_size,
1769*4bdc9457SAndroid Build Coastguard Worker .x = input,
1770*4bdc9457SAndroid Build Coastguard Worker .x_stride = convolution_op->input_pixel_stride << log2_input_element_size,
1771*4bdc9457SAndroid Build Coastguard Worker .w = packed_weights(convolution_op),
1772*4bdc9457SAndroid Build Coastguard Worker .y = output,
1773*4bdc9457SAndroid Build Coastguard Worker .y_stride = convolution_op->output_pixel_stride << log2_output_element_size,
1774*4bdc9457SAndroid Build Coastguard Worker .ukernel = convolution_op->ukernel.vmulcaddc.function,
1775*4bdc9457SAndroid Build Coastguard Worker };
1776*4bdc9457SAndroid Build Coastguard Worker memcpy(&convolution_op->context.vmulcaddc.params, &convolution_op->params, sizeof(convolution_op->context.vmulcaddc.params));
1777*4bdc9457SAndroid Build Coastguard Worker
1778*4bdc9457SAndroid Build Coastguard Worker #if XNN_TEST_MODE
1779*4bdc9457SAndroid Build Coastguard Worker const size_t mc = convolution_op->ukernel.vmulcaddc.mr;
1780*4bdc9457SAndroid Build Coastguard Worker #else
1781*4bdc9457SAndroid Build Coastguard Worker size_t mc = batch_output_size;
1782*4bdc9457SAndroid Build Coastguard Worker if (num_threads > 1) {
1783*4bdc9457SAndroid Build Coastguard Worker const size_t target_tiles_per_thread = 5;
1784*4bdc9457SAndroid Build Coastguard Worker const size_t max_mc = divide_round_up(batch_output_size, num_threads * target_tiles_per_thread);
1785*4bdc9457SAndroid Build Coastguard Worker if (max_mc < mc) {
1786*4bdc9457SAndroid Build Coastguard Worker const uint32_t mr = convolution_op->ukernel.vmulcaddc.mr;
1787*4bdc9457SAndroid Build Coastguard Worker mc = min(mc, divide_round_up(mc, max_mc * mr) * mr);
1788*4bdc9457SAndroid Build Coastguard Worker }
1789*4bdc9457SAndroid Build Coastguard Worker }
1790*4bdc9457SAndroid Build Coastguard Worker #endif
1791*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.type = xnn_parallelization_type_1d_tile_1d;
1792*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_vmulcaddc;
1793*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.range[0] = batch_output_size;
1794*4bdc9457SAndroid Build Coastguard Worker convolution_op->compute.tile[0] = mc;
1795*4bdc9457SAndroid Build Coastguard Worker convolution_op->state = xnn_run_state_ready;
1796*4bdc9457SAndroid Build Coastguard Worker
1797*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1798*4bdc9457SAndroid Build Coastguard Worker }
1799*4bdc9457SAndroid Build Coastguard Worker default:
1800*4bdc9457SAndroid Build Coastguard Worker XNN_UNREACHABLE;
1801*4bdc9457SAndroid Build Coastguard Worker }
1802*4bdc9457SAndroid Build Coastguard Worker }
1803*4bdc9457SAndroid Build Coastguard Worker
// Sets up a QU8 NHWC 2D convolution operator (uint8_t input and output) by
// forwarding to setup_convolution2d_nhwc with the QU8 operator type, the QU8
// init flag and the element sizes for this data type.
//
// Returns whatever status setup_convolution2d_nhwc reports.
enum xnn_status xnn_setup_convolution2d_nhwc_qu8(
    xnn_operator_t convolution_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  // Input, filter and output elements are all one byte wide:
  // log2(sizeof(uint8_t)) == 0.
  enum { log2_uint8_size = 0 };
  // sizeof(extra weights elements)
  const size_t extra_weights_size = sizeof(int32_t);
  // Thread count of the supplied pool, used by the shared setup routine.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_convolution2d_nhwc(
    convolution_op, xnn_operator_type_convolution_nhwc_qu8,
    batch_size, input_height, input_width,
    input, output,
    XNN_INIT_FLAG_QU8,
    log2_uint8_size /* input element */,
    log2_uint8_size /* filter element */,
    extra_weights_size,
    log2_uint8_size /* output element */,
    num_threads);
}
1824*4bdc9457SAndroid Build Coastguard Worker
// Sets up a QS8 NHWC 2D convolution operator (int8_t input and output) by
// forwarding to setup_convolution2d_nhwc with the QS8 operator type, the QS8
// init flag and the element sizes for this data type.
//
// Returns whatever status setup_convolution2d_nhwc reports.
enum xnn_status xnn_setup_convolution2d_nhwc_qs8(
    xnn_operator_t convolution_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  // Input, filter and output elements are all one byte wide:
  // log2(sizeof(int8_t)) == 0.
  enum { log2_int8_size = 0 };
  // sizeof(extra weights elements)
  const size_t extra_weights_size = sizeof(int32_t);
  // Thread count of the supplied pool, used by the shared setup routine.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_convolution2d_nhwc(
    convolution_op, xnn_operator_type_convolution_nhwc_qs8,
    batch_size, input_height, input_width,
    input, output,
    XNN_INIT_FLAG_QS8,
    log2_int8_size /* input element */,
    log2_int8_size /* filter element */,
    extra_weights_size,
    log2_int8_size /* output element */,
    num_threads);
}
1845*4bdc9457SAndroid Build Coastguard Worker
// Public setup entry point for the QC8 NHWC 2D convolution operator.
//
// Thin wrapper: forwards the caller's arguments to setup_convolution2d_nhwc()
// together with the QC8-specific operator type, datatype init flag and element
// sizes, and returns that helper's status unchanged.
enum xnn_status xnn_setup_convolution2d_nhwc_qc8(
    xnn_operator_t convolution_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  // Element sizes are handed over as log2(bytes); int8_t is one byte, hence 0.
  enum {
    log2_input_element_size = 0,   /* log2(sizeof(int8_t)) */
    log2_filter_element_size = 0,  /* log2(sizeof(int8_t)) */
    log2_output_element_size = 0   /* log2(sizeof(int8_t)) */
  };
  // Size in bytes of the extra weights elements: one int32_t plus one float.
  // NOTE(review): presumably a bias and a per-channel scale -- confirm against
  // the packing code.
  const size_t extra_weights_elements_size = sizeof(int32_t) + sizeof(float);
  // Thread count is queried from the caller-supplied thread pool.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_convolution2d_nhwc(
    convolution_op, xnn_operator_type_convolution_nhwc_qc8,
    batch_size, input_height, input_width,
    input, output,
    XNN_INIT_FLAG_QC8,
    log2_input_element_size,
    log2_filter_element_size,
    extra_weights_elements_size,
    log2_output_element_size,
    num_threads);
}
1866*4bdc9457SAndroid Build Coastguard Worker
// Public setup entry point for the F16 NHWC 2D convolution operator.
//
// Thin wrapper: forwards the caller's arguments to setup_convolution2d_nhwc()
// together with the F16-specific operator type, datatype init flag and element
// sizes, and returns that helper's status unchanged. Input and output are
// untyped pointers; element sizes below are those of uint16_t.
enum xnn_status xnn_setup_convolution2d_nhwc_f16(
    xnn_operator_t convolution_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  // Element sizes are handed over as log2(bytes); uint16_t is two bytes,
  // hence 1.
  enum {
    log2_input_element_size = 1,   /* log2(sizeof(uint16_t)) */
    log2_filter_element_size = 1,  /* log2(sizeof(uint16_t)) */
    log2_output_element_size = 1   /* log2(sizeof(uint16_t)) */
  };
  // Size in bytes of the extra weights elements.
  const size_t extra_weights_elements_size = sizeof(uint16_t);
  // Thread count is queried from the caller-supplied thread pool.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_convolution2d_nhwc(
    convolution_op, xnn_operator_type_convolution_nhwc_f16,
    batch_size, input_height, input_width,
    input, output,
    XNN_INIT_FLAG_F16,
    log2_input_element_size,
    log2_filter_element_size,
    extra_weights_elements_size,
    log2_output_element_size,
    num_threads);
}
1887*4bdc9457SAndroid Build Coastguard Worker
// Public setup entry point for the F32 NHWC 2D convolution operator.
//
// Thin wrapper: forwards the caller's arguments to setup_convolution2d_nhwc()
// together with the F32-specific operator type, datatype init flag and element
// sizes, and returns that helper's status unchanged.
enum xnn_status xnn_setup_convolution2d_nhwc_f32(
    xnn_operator_t convolution_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  // Element sizes are handed over as log2(bytes); the value 2 corresponds to a
  // four-byte float.
  enum {
    log2_input_element_size = 2,   /* log2(sizeof(float)) */
    log2_filter_element_size = 2,  /* log2(sizeof(float)) */
    log2_output_element_size = 2   /* log2(sizeof(float)) */
  };
  // Size in bytes of the extra weights elements.
  const size_t extra_weights_elements_size = sizeof(float);
  // Thread count is queried from the caller-supplied thread pool.
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_convolution2d_nhwc(
    convolution_op, xnn_operator_type_convolution_nhwc_f32,
    batch_size, input_height, input_width,
    input, output,
    XNN_INIT_FLAG_F32,
    log2_input_element_size,
    log2_filter_element_size,
    extra_weights_elements_size,
    log2_output_element_size,
    num_threads);
}
1908