// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <math.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/indirection.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>
#include <xnnpack/pack.h>
#include <xnnpack/params.h>

#ifndef XNN_ENABLE_GEMM_M_SPECIALIZATION
#error "XNN_ENABLE_GEMM_M_SPECIALIZATION is not defined"
#endif

static enum xnn_status create_deconvolution2d_nhwc(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    const void* kernel,
    const void* bias,
    uint32_t flags,
    uint32_t log2_input_element_size,
    uint32_t log2_filter_element_size,
    uint32_t bias_element_size,
    xnn_pack_conv_goki_w_function pack_conv_goki_w,
    xnn_pack_deconv_goki_w_function pack_deconv_goki_w,
    const void* packing_params,
    int input_padding_byte,
    int packed_weights_padding_byte,
    const void* params,
    size_t params_size,
    const struct gemm_parameters* gemm_parameters,
    const struct gemm_fused_ukernels* gemm_ukernels,
    enum xnn_operator_type operator_type,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  xnn_operator_t deconvolution_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (kernel_width == 0 || kernel_height == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), kernel_width, kernel_height);
    goto error;
  }

  if (stride_width == 0 || stride_height == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " stride: stride dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), stride_width, stride_height);
    goto error;
  }

  if (dilation_width == 0 || dilation_height == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), dilation_width, dilation_height);
    goto error;
  }

  if (groups == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 " groups: number of groups must be non-zero",
      xnn_operator_type_to_string(operator_type), groups);
    goto error;
  }

  if (group_input_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu input channels per group: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_input_channels);
    goto error;
  }

  if (group_output_channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu output channels per group: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), group_output_channels);
    goto error;
  }

  const size_t input_channels = groups * group_input_channels;
  if (input_pixel_stride < input_channels) {
    xnn_log_error(
      "failed to create %s operator with input pixel stride of %zu: "
      "stride must be at least as large as the number of input channels (%" PRIu32 "x%zu)",
      xnn_operator_type_to_string(operator_type),
      input_pixel_stride, groups, group_input_channels);
    goto error;
  }

  const size_t output_channels = groups * group_output_channels;
  if (output_pixel_stride < output_channels) {
    xnn_log_error(
      "failed to create %s operator with output pixel stride of %zu: "
      "stride must be at least as large as the number of output channels (%" PRIu32 "x%zu)",
      xnn_operator_type_to_string(operator_type),
      output_pixel_stride, groups, group_output_channels);
    goto error;
  }

  status = xnn_status_out_of_memory;

  deconvolution_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (deconvolution_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  if (caches != NULL) {
    deconvolution_op->weights_cache = caches->weights_cache;
  }

  const uint32_t mr = gemm_parameters->mr;
  const uint32_t nr = gemm_parameters->nr;
  const uint32_t kr = UINT32_C(1) << gemm_parameters->log2_kr;
  const uint32_t sr = UINT32_C(1) << gemm_parameters->log2_sr;

  const uint32_t n_stride = round_up(group_output_channels, nr);
  const uint32_t k_stride = round_up_po2(group_input_channels, kr * sr);
  const uint32_t kernel_size = kernel_height * kernel_width;
  enum xnn_ukernel_type ukernel_type = xnn_ukernel_type_igemm;
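  // Size accounting for the packed weights in the IGEMM path: per group, each of the n_stride
  // (rounded up to NR) output channels carries one bias element plus kernel_size * k_stride
  // filter elements, with input channels padded up to a multiple of KR * SR.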
  size_t packed_group_weights_size = (((kernel_size * k_stride) << log2_filter_element_size) + bias_element_size) * n_stride;
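  // A strided deconvolution without dilation can be decomposed into stride_height * stride_width
  // independent subconvolutions, one per phase of the stride, each producing a disjoint subset of
  // the output pixels; this avoids multiplying by the zeros that a naive zero-stuffed convolution
  // would insert. The stride <= kernel checks guarantee that every subkernel has at least one tap.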
  if (max(stride_height, stride_width) > 1 && max(dilation_height, dilation_width) == 1 && stride_width <= kernel_width && stride_height <= kernel_height) {
    ukernel_type = xnn_ukernel_type_subconv2d;
    const size_t subkernels = stride_height * stride_width;
    packed_group_weights_size = n_stride *
      (((kernel_size * k_stride) << log2_filter_element_size) + bias_element_size * subkernels);

    const size_t subconvolution_buffer_size = sizeof(struct subconvolution_params) * subkernels;
    deconvolution_op->subconvolution_buffer = xnn_allocate_zero_memory(subconvolution_buffer_size);
    if (deconvolution_op->subconvolution_buffer == NULL) {
      xnn_log_error(
        "failed to allocate %zu bytes for %s operator subconvolution buffer",
        subconvolution_buffer_size, xnn_operator_type_to_string(operator_type));
      goto error;
    }

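    // Each (offset_y, offset_x) phase owns the kernel taps whose coordinates are congruent to the
    // offset modulo the stride, hence the ceiling divisions below. For example, a 4x4 kernel with
    // 2x2 stride splits into four 2x2 subkernels.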
    struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
    for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
      for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
        const size_t subkernel_height = divide_round_up(kernel_height - offset_y, stride_height);
        const size_t subkernel_width = divide_round_up(kernel_width - offset_x, stride_width);
        const size_t subkernel_size = subkernel_height * subkernel_width;

        subconvolution_params->indirection_x_stride = sizeof(void*) * subkernel_size;
        subconvolution_params->w_stride = bias_element_size + ((k_stride * subkernel_size) << log2_filter_element_size);
        subconvolution_params++;
      }
    }
  }

  const size_t aligned_total_weights_size = round_up_po2(packed_group_weights_size * groups, XNN_ALLOCATION_ALIGNMENT);
  void* weights_ptr = xnn_get_pointer_to_write_weights(
      deconvolution_op, aligned_total_weights_size, packed_weights_padding_byte);
  if (weights_ptr == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator packed weights",
      aligned_total_weights_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }

  switch (ukernel_type) {
    case xnn_ukernel_type_igemm:
      pack_conv_goki_w(
        groups, group_output_channels, kernel_size, group_input_channels,
        nr, kr, sr,
        kernel, bias, weights_ptr,
        0 /* extra bytes */,
        packing_params);
      break;
    case xnn_ukernel_type_subconv2d:
      pack_deconv_goki_w(
        groups, group_output_channels, kernel_height, kernel_width, group_input_channels,
        stride_height, stride_width,
        nr, kr, sr,
        kernel, bias, weights_ptr, deconvolution_op->subconvolution_buffer,
        packing_params);
      // We assume that the first subconvolution params' weights pointer points to the start of
      // the packed weights; this is used later to detect whether the weights cache has moved.
      assert(deconvolution_op->subconvolution_buffer->weights == weights_ptr);
      break;
    default:
      XNN_UNREACHABLE;
  }

  if (use_weights_cache(deconvolution_op)) {
    deconvolution_op->packed_weights.offset = xnn_get_or_insert_weights_cache(
        deconvolution_op->weights_cache, weights_ptr, aligned_total_weights_size);
  }

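  // The zero buffer stands in for out-of-bounds input pixels: indirection entries that would
  // point outside the input are redirected to it, so the micro-kernels need no bounds checks.
  // It is filled with input_padding_byte (the input zero point for quantized operators).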
  const size_t zero_size = (k_stride << log2_input_element_size) + XNN_EXTRA_BYTES;
  deconvolution_op->zero_buffer = xnn_allocate_simd_memory(zero_size);
  if (deconvolution_op->zero_buffer == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator zero padding",
      zero_size, xnn_operator_type_to_string(operator_type));
    goto error;
  }
  memset(deconvolution_op->zero_buffer, input_padding_byte, zero_size);

  deconvolution_op->padding_top = output_padding_top;
  deconvolution_op->padding_right = output_padding_right;
  deconvolution_op->padding_bottom = output_padding_bottom;
  deconvolution_op->padding_left = output_padding_left;

  deconvolution_op->kernel_height = kernel_height;
  deconvolution_op->kernel_width = kernel_width;
  deconvolution_op->stride_height = stride_height;
  deconvolution_op->stride_width = stride_width;
  deconvolution_op->dilation_height = dilation_height;
  deconvolution_op->dilation_width = dilation_width;
  deconvolution_op->groups = groups;
  deconvolution_op->group_input_channels = group_input_channels;
  deconvolution_op->group_output_channels = group_output_channels;
  deconvolution_op->input_pixel_stride = input_pixel_stride;
  deconvolution_op->output_pixel_stride = output_pixel_stride;

  memcpy(&deconvolution_op->params, params, params_size);
  deconvolution_op->type = operator_type;
  deconvolution_op->ukernel.type = ukernel_type;
  deconvolution_op->ukernel.igemm = (struct xnn_ukernel_igemm) {
    .mr = mr,
    .nr = nr,
    .kr = kr,
    .sr = sr,
  };

  assert(XNN_MAX_MR >= mr);
  for (size_t i = 0; i < mr; i++) {
    if (gemm_ukernels->gemm[i].function[XNN_UARCH_DEFAULT] != NULL) {
      deconvolution_op->ukernel.igemm.gemm_cases[i] = gemm_ukernels->gemm[i];
    }
    if (gemm_ukernels->igemm[i].function[XNN_UARCH_DEFAULT] != NULL) {
      deconvolution_op->ukernel.igemm.igemm_cases[i] = gemm_ukernels->igemm[i];
    }
  }

  deconvolution_op->state = xnn_run_state_invalid;

  *deconvolution_op_out = deconvolution_op;
  return xnn_status_success;

error:
  xnn_delete_operator(deconvolution_op);
  return status;
}

enum xnn_status xnn_create_deconvolution2d_nhwc_qs8(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    int8_t input_zero_point,
    float input_scale,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

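  // Requantization maps the int32 accumulator back to the quantized output domain:
  // output = clamp(round(accumulator * input_scale * kernel_scale / output_scale) + output_zero_point),
  // so the combined scale below must stay below 256.0 for the fixed-point schemes used by the
  // micro-kernels.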
  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  union xnn_qs8_conv_minmax_params params;
  if XNN_LIKELY(xnn_params.qs8.gemm.init.qs8 != NULL) {
    xnn_params.qs8.gemm.init.qs8(&params,
      requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qs8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
  };
  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_conv_goki_w_function) xnn_pack_qs8_conv_goki_w,
    (xnn_pack_deconv_goki_w_function) xnn_pack_qs8_deconv_goki_w,
    &packing_params, input_zero_point /* input padding byte */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qs8.gemm, &xnn_params.qs8.gemm.minmax,
    xnn_operator_type_deconvolution_nhwc_qs8,
    caches,
    deconvolution_op_out);
}

enum xnn_status xnn_create_deconvolution2d_nhwc_qu8(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t kernel_zero_point,
    float kernel_scale,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), input_scale);
    return xnn_status_invalid_parameter;
  }

  if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g kernel scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), kernel_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), output_scale);
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const float requantization_scale = input_scale * kernel_scale / output_scale;
  if (requantization_scale >= 256.0f) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: "
      "requantization scale %.7g is greater or equal to 256.0",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8),
      input_scale, kernel_scale, output_scale, requantization_scale);
    return xnn_status_unsupported_parameter;
  }

  union xnn_qu8_conv_minmax_params params;
  if XNN_LIKELY(xnn_params.qu8.gemm.init.qu8 != NULL) {
    xnn_params.qu8.gemm.init.qu8(&params,
      kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max);
  }
  const struct xnn_qu8_packing_params packing_params = {
    .input_zero_point = input_zero_point,
    .kernel_zero_point = kernel_zero_point,
  };
  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
    0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
    sizeof(int32_t) /* sizeof(bias element) */,
    (xnn_pack_conv_goki_w_function) xnn_pack_qu8_conv_goki_w,
    (xnn_pack_deconv_goki_w_function) xnn_pack_qu8_deconv_goki_w,
    &packing_params, input_zero_point /* input padding byte */, kernel_zero_point /* packed weights padding byte */,
    &params, sizeof(params),
    &xnn_params.qu8.gemm, &xnn_params.qu8.gemm.minmax,
    xnn_operator_type_deconvolution_nhwc_qu8,
    caches,
    deconvolution_op_out);
}

enum xnn_status xnn_create_deconvolution2d_nhwc_f16(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    const void* kernel,
    const void* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_F16) != XNN_INIT_FLAG_F16) {
    xnn_log_error("failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16));
    return xnn_status_unsupported_hardware;
  }

  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

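  // Round the requested bounds to fp16 before validating them, so that the range checked here is
  // exactly the range the fp16 micro-kernels will clamp to.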
  const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min);
  const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max);
  output_min = fp16_ieee_to_fp32_value(output_min_as_half);
  output_max = fp16_ieee_to_fp32_value(output_max_as_half);
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const struct gemm_parameters* gemm_parameters = &xnn_params.f16.gemm;
  const struct gemm_fused_ukernels* gemm_ukernels = &gemm_parameters->minmax;
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  if (linear_activation && gemm_parameters->linear.gemm[gemm_parameters->mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernels = &gemm_parameters->linear;
  }

  union xnn_f16_minmax_params params;
  if XNN_LIKELY(xnn_params.f16.gemm.init.f16 != NULL) {
    gemm_parameters->init.f16(&params, output_min_as_half, output_max_as_half);
  }

  xnn_pack_conv_goki_w_function pack_conv_goki_w = (xnn_pack_conv_goki_w_function) xnn_pack_f16_conv_goki_w;
  xnn_pack_deconv_goki_w_function pack_deconv_goki_w = (xnn_pack_deconv_goki_w_function) xnn_pack_f16_deconv_goki_w;
  if (flags & XNN_FLAG_FP32_STATIC_WEIGHTS) {
    pack_conv_goki_w = (xnn_pack_conv_goki_w_function) xnn_pack_f32_to_f16_conv_goki_w;
    pack_deconv_goki_w = (xnn_pack_deconv_goki_w_function) xnn_pack_f32_to_f16_deconv_goki_w;
  }

  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    1 /* log2(sizeof(input element)) = log2(sizeof(uint16_t)) */,
    1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
    sizeof(uint16_t) /* sizeof(bias element) */,
    pack_conv_goki_w,
    pack_deconv_goki_w,
    NULL /* packing params */, 0 /* input padding byte */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    gemm_parameters, gemm_ukernels,
    xnn_operator_type_deconvolution_nhwc_f16,
    caches,
    deconvolution_op_out);
}

enum xnn_status xnn_create_deconvolution2d_nhwc_f32(
    uint32_t output_padding_top,
    uint32_t output_padding_right,
    uint32_t output_padding_bottom,
    uint32_t output_padding_left,
    uint32_t kernel_height,
    uint32_t kernel_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    const float* kernel,
    const float* bias,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_caches_t caches,
    xnn_operator_t* deconvolution_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s operator with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s operator with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  const struct gemm_parameters* gemm_parameters = &xnn_params.f32.gemm;
  if (gemm_parameters->nr > group_output_channels) {
    // Default micro-kernel is suboptimal. Try to find a better micro-kernel.
    const struct gemm_parameters* gemm2_parameters = &xnn_params.f32.gemm2;
    if (gemm2_parameters->minmax.igemm[gemm2_parameters->mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
      gemm_parameters = gemm2_parameters;
    }
  }

  const struct gemm_fused_ukernels* gemm_ukernels = &gemm_parameters->minmax;
  const bool linear_activation = (output_max == INFINITY) && (output_min == -output_max);
  if (linear_activation && gemm_parameters->linear.gemm[gemm_parameters->mr - 1].function[XNN_UARCH_DEFAULT] != NULL) {
    gemm_ukernels = &gemm_parameters->linear;
  }

  union xnn_f32_minmax_params params;
  if XNN_LIKELY(xnn_params.f32.gemm.init.f32 != NULL) {
    gemm_parameters->init.f32(&params, output_min, output_max);
  }
  return create_deconvolution2d_nhwc(
    output_padding_top, output_padding_right, output_padding_bottom, output_padding_left,
    kernel_height, kernel_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    groups, group_input_channels, group_output_channels,
    input_pixel_stride, output_pixel_stride,
    kernel, bias, flags,
    2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
    2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
    sizeof(float) /* sizeof(bias element) */,
    (xnn_pack_conv_goki_w_function) xnn_pack_f32_conv_goki_w,
    (xnn_pack_deconv_goki_w_function) xnn_pack_f32_deconv_goki_w,
    NULL /* packing params */, 0 /* input padding byte */, 0 /* packed weights padding byte */,
    &params, sizeof(params),
    gemm_parameters, gemm_ukernels,
    xnn_operator_type_deconvolution_nhwc_f32,
    caches,
    deconvolution_op_out);
}

static enum xnn_status setup_conv_path(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const void* input,
  size_t output_height,
  size_t output_width,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads)
{
  assert(deconvolution_op->ukernel.type == xnn_ukernel_type_igemm);

  const size_t kernel_height = deconvolution_op->kernel_height;
  const size_t kernel_width = deconvolution_op->kernel_width;
  const size_t kernel_size = kernel_height * kernel_width;

  const size_t groups = deconvolution_op->groups;
  const size_t output_size = output_height * output_width;
  size_t mr = deconvolution_op->ukernel.igemm.mr;

  struct xnn_hmp_igemm_ukernel igemm_ukernel = deconvolution_op->ukernel.igemm.igemm_cases[mr - 1];
  if (output_size == 1 && deconvolution_op->ukernel.igemm.igemm_cases[0].function[XNN_UARCH_DEFAULT] != NULL) {
    mr = 1;
    igemm_ukernel = deconvolution_op->ukernel.igemm.igemm_cases[0];
  }

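  // The indirection buffer holds one input-pixel pointer per kernel tap for every output pixel
  // (rounded up to the MR tile), so it only needs to be rebuilt when the input dimensions change.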
  const size_t tiled_output_size = round_up(output_size, mr);
  const size_t indirection_buffer_size = sizeof(void*) * kernel_size * tiled_output_size;

  if (input_height != deconvolution_op->last_input_height ||
      input_width != deconvolution_op->last_input_width)
  {
    const void** indirection_buffer = (const void**) xnn_reallocate_memory(deconvolution_op->indirection_buffer, indirection_buffer_size);
    if (indirection_buffer == NULL) {
      xnn_log_error(
        "failed to allocate %zu bytes for %s operator indirection buffer",
        indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));
      return xnn_status_out_of_memory;
    }
    deconvolution_op->indirection_buffer = indirection_buffer;
    deconvolution_op->last_input = input;
    deconvolution_op->last_input_height = input_height;
    deconvolution_op->last_input_width = input_width;

    xnn_indirection_init_deconv2d(deconvolution_op, mr, log2_input_element_size);
  }

  const size_t group_input_channels = deconvolution_op->group_input_channels;
  const size_t group_output_channels = deconvolution_op->group_output_channels;
  const uint32_t nr = deconvolution_op->ukernel.igemm.nr;

  const size_t w_stride = bias_element_size +
    (round_up_po2(group_input_channels, deconvolution_op->ukernel.igemm.kr * deconvolution_op->ukernel.igemm.sr) * kernel_size << log2_filter_element_size);
  deconvolution_op->context.igemm = (struct igemm_context) {
      .ks = kernel_size,
      .ks_scaled = kernel_size * mr * sizeof(void*),
      .kc = group_input_channels << log2_input_element_size,
      .w_stride = w_stride,
      .indirect_a = deconvolution_op->indirection_buffer,
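      // a_offset shifts the pointers cached in the indirection buffer from last_input to the
      // current input, so the buffer need not be rebuilt when only the input pointer changed.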
      .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input),
      .zero = deconvolution_op->zero_buffer,
      .packed_w = packed_weights(deconvolution_op),
      .c = deconvolution_op->output,
      .cm_stride = deconvolution_op->output_pixel_stride << log2_output_element_size,
      .cn_stride = nr << log2_output_element_size,
      .ga_stride = group_input_channels << log2_input_element_size,
      .gw_stride = w_stride * round_up(group_output_channels, nr),
      .gc_stride = group_output_channels << log2_output_element_size,
      .ba_stride = input_height * input_width * deconvolution_op->input_pixel_stride << log2_input_element_size,
      .bc_stride = output_size * deconvolution_op->output_pixel_stride << log2_output_element_size,
      .log2_csize = log2_output_element_size,
      .ukernel = igemm_ukernel,
  };
  memcpy(&deconvolution_op->context.igemm.params, params, params_size);

  #if XNN_TEST_MODE
    const size_t nc = nr;
  #else
    size_t nc = group_output_channels;
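    // Heuristic: aim for about target_tiles_per_thread tiles per thread; when the other
    // parallelization dimensions already supply enough tiles, shrink the N-dimension tile
    // (in multiples of NR) to keep threads balanced.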
    if (num_threads > 1) {
      const size_t num_other_tiles = groups * batch_size * divide_round_up(output_size, mr);
      const size_t target_tiles_per_thread = 5;
      const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
      if (max_nc < nc) {
        nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
      }
    }
  #endif
  if (groups == 1) {
    #if XNN_MAX_UARCH_TYPES > 1
      if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
        if (batch_size > 1) {
          deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
          deconvolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_batch_hmp_igemm;
        } else {
          deconvolution_op->compute.type = xnn_parallelization_type_2d_tile_2d_with_uarch;
          deconvolution_op->compute.task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_igemm;
        }
      } else {
        if (batch_size > 1) {
          deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
          deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_batch_igemm;
        } else {
          deconvolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
          deconvolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_igemm;
        }
      }
    #else
      if (batch_size > 1) {
        deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
        deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_batch_igemm;
      } else {
        deconvolution_op->compute.type = xnn_parallelization_type_2d_tile_2d;
        deconvolution_op->compute.task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_igemm;
      }
    #endif
    if (batch_size > 1) {
      deconvolution_op->compute.range[0] = batch_size;
      deconvolution_op->compute.range[1] = output_size;
      deconvolution_op->compute.range[2] = group_output_channels;
    } else {
      deconvolution_op->compute.range[0] = output_size;
      deconvolution_op->compute.range[1] = group_output_channels;
    }
    deconvolution_op->compute.tile[0] = mr;
    deconvolution_op->compute.tile[1] = nc;
  } else {
    #if XNN_MAX_UARCH_TYPES > 1
      if (xnn_is_hmp_igemm_ukernel(igemm_ukernel)) {
        if (batch_size > 1) {
          deconvolution_op->compute.type = xnn_parallelization_type_4d_tile_2d_with_uarch;
          deconvolution_op->compute.task_4d_tile_2d_with_id = (pthreadpool_task_4d_tile_2d_with_id_t) xnn_compute_hmp_grouped_batch_igemm;
        } else {
          deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d_with_uarch;
          deconvolution_op->compute.task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_igemm;
        }
      } else {
        if (batch_size > 1) {
          deconvolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
          deconvolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm;
        } else {
          deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
          deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_igemm;
        }
      }
    #else
      if (batch_size > 1) {
        deconvolution_op->compute.type = xnn_parallelization_type_4d_tile_2d;
        deconvolution_op->compute.task_4d_tile_2d = (pthreadpool_task_4d_tile_2d_t) xnn_compute_grouped_batch_igemm;
      } else {
        deconvolution_op->compute.type = xnn_parallelization_type_3d_tile_2d;
        deconvolution_op->compute.task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_igemm;
      }
    #endif
    if (batch_size > 1) {
      deconvolution_op->compute.range[0] = batch_size;
      deconvolution_op->compute.range[1] = groups;
      deconvolution_op->compute.range[2] = output_size;
      deconvolution_op->compute.range[3] = group_output_channels;
    } else {
      deconvolution_op->compute.range[0] = groups;
      deconvolution_op->compute.range[1] = output_size;
      deconvolution_op->compute.range[2] = group_output_channels;
    }
    deconvolution_op->compute.tile[0] = mr;
    deconvolution_op->compute.tile[1] = nc;
  }
  deconvolution_op->state = xnn_run_state_ready;
  return xnn_status_success;
}

static enum xnn_status setup_subconv2d_path(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const void* input,
  size_t output_height,
  size_t output_width,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads,
  bool use_gemm)
{
  assert(deconvolution_op->ukernel.type == xnn_ukernel_type_subconv2d);

  const size_t kernel_height = deconvolution_op->kernel_height;
  const size_t kernel_width = deconvolution_op->kernel_width;
  const size_t kernel_size = kernel_height * kernel_width;
  const size_t stride_height = deconvolution_op->stride_height;
  const size_t stride_width = deconvolution_op->stride_width;
  const size_t output_height_positions = divide_round_up(output_height, stride_height);
  const size_t output_width_positions = divide_round_up(output_width, stride_width);

  const size_t groups = deconvolution_op->groups;
  const size_t output_size = output_height * output_width;
  uint32_t mr = deconvolution_op->ukernel.igemm.mr;
  const uint32_t nr = deconvolution_op->ukernel.igemm.nr;
  #if XNN_ENABLE_GEMM_M_SPECIALIZATION
    mr = xnn_get_heuristic_mr_igemm(
      output_width_positions, mr, nr, deconvolution_op->ukernel.igemm.igemm_cases);
  #endif

  const size_t input_pixel_stride = deconvolution_op->input_pixel_stride << log2_input_element_size;
  const size_t output_pixel_stride = deconvolution_op->output_pixel_stride << log2_output_element_size;

  const bool any_size_change =
    input_height != deconvolution_op->last_input_height ||
    input_width != deconvolution_op->last_input_width ||
    output_height != deconvolution_op->last_output_height ||
    output_width != deconvolution_op->last_output_width;

  if (deconvolution_op->weights_cache != NULL) {
    void* packed_weights_ptr = packed_weights(deconvolution_op);
    struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
    if (packed_weights_ptr != subconvolution_params->weights) {
      // The weights cache moved; update all weights pointers.
      const ptrdiff_t diff = (uintptr_t) packed_weights_ptr - (uintptr_t) subconvolution_params->weights;
      for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
        for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
          subconvolution_params->weights = (void*) ((uintptr_t) subconvolution_params->weights + diff);
          ++subconvolution_params;
        }
      }
    }
  }

  if (any_size_change || output != deconvolution_op->last_output) {
    // Initialize subconvolution parameters which depend on output dimensions or MR.
    struct subconvolution_params* subconvolution_params = deconvolution_op->subconvolution_buffer;
    const size_t modulo_padding_top = deconvolution_op->padding_top % stride_height;
    const size_t modulo_padding_left = deconvolution_op->padding_left % stride_width;
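    // Subconvolution (offset_y, offset_x) writes the output lattice starting at
    // (output_y_start, output_x_start) and stepping by the stride; subtract_modulo folds the
    // output padding into the phase, and slice_height/slice_width count the lattice points that
    // fall inside the output.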
    for (size_t offset_y = 0; offset_y < stride_height; offset_y++) {
      for (size_t offset_x = 0; offset_x < stride_width; offset_x++) {
        const size_t output_x_start = subtract_modulo(offset_x, modulo_padding_left, stride_width);
        const size_t output_y_start = subtract_modulo(offset_y, modulo_padding_top, stride_height);
        subconvolution_params->scaled_kernel_size = mr * subconvolution_params->indirection_x_stride;
        subconvolution_params->slice_width = divide_round_up(output_width - output_x_start, stride_width);
        subconvolution_params->slice_height = divide_round_up(output_height - output_y_start, stride_height);
        subconvolution_params->output =
          (void*) ((uintptr_t) output + ((output_y_start * output_width + output_x_start) * output_pixel_stride));
        ++subconvolution_params;
      }
    }
    deconvolution_op->last_output = output;
  }

  if (any_size_change) {
    if (!use_gemm) {
      const size_t indirection_buffer_size = sizeof(void*) *
        kernel_size * output_height * stride_width * round_up(divide_round_up(output_width, stride_width), mr);

      const void** indirection_buffer =
        (const void**) xnn_reallocate_memory(deconvolution_op->indirection_buffer, indirection_buffer_size);
      if (indirection_buffer == NULL) {
        xnn_log_error(
          "failed to allocate %zu bytes for %s operator indirection buffer",
          indirection_buffer_size, xnn_operator_type_to_string(deconvolution_op->type));
        return xnn_status_out_of_memory;
      }
      deconvolution_op->indirection_buffer = indirection_buffer;
      deconvolution_op->last_input = input;

      xnn_indirection_init_subconv2d(deconvolution_op, mr, log2_input_element_size);
    }
    deconvolution_op->last_input_height = input_height;
    deconvolution_op->last_input_width = input_width;
    deconvolution_op->last_output_height = output_height;
    deconvolution_op->last_output_width = output_width;
  }

  const size_t group_input_channels = deconvolution_op->group_input_channels;
  const size_t group_output_channels = deconvolution_op->group_output_channels;
  const uint32_t kr = deconvolution_op->ukernel.igemm.kr;
  const uint32_t sr = deconvolution_op->ukernel.igemm.sr;
  const size_t w_stride = stride_height * stride_width * bias_element_size +
    (round_up_po2(group_input_channels, kr * sr) * kernel_size << log2_filter_element_size);
  if (use_gemm) {
    deconvolution_op->context.subgemm = (struct subgemm_context) {
        .subconvolution_params = deconvolution_op->subconvolution_buffer,
        .kc = group_input_channels << log2_input_element_size,
        .a = input,
        .ax_stride = input_pixel_stride,
        .ay_stride = input_width * input_pixel_stride,
        .cx_stride = stride_width * output_pixel_stride,
        .cy_stride = stride_height * output_width * output_pixel_stride,
        .cn_stride = nr << log2_output_element_size,
        .ga_stride = group_input_channels << log2_input_element_size,
        .gw_stride = w_stride * round_up(group_output_channels, nr),
        .gc_stride = group_output_channels << log2_output_element_size,
        .ba_stride = input_height * input_width * input_pixel_stride,
        .bc_stride = output_size * output_pixel_stride,
        .log2_csize = log2_output_element_size,
        .ukernel = deconvolution_op->ukernel.igemm.gemm_cases[mr - 1],
    };
    memcpy(&deconvolution_op->context.subgemm.params, params, params_size);
  } else {
    deconvolution_op->context.subconv = (struct subconv_context) {
        .subconvolution_params = deconvolution_op->subconvolution_buffer,
        .kc = group_input_channels << log2_input_element_size,
        .a_offset = (size_t) ((uintptr_t) input - (uintptr_t) deconvolution_op->last_input),
        .zero = deconvolution_op->zero_buffer,
        .cx_stride = stride_width * output_pixel_stride,
        .cy_stride = stride_height * output_width * output_pixel_stride,
        .cn_stride = nr << log2_output_element_size,
        .ga_stride = group_input_channels << log2_input_element_size,
        .gw_stride = w_stride * round_up(group_output_channels, nr),
        .gc_stride = group_output_channels << log2_output_element_size,
        .ba_stride = input_height * input_width * input_pixel_stride,
        .bc_stride = output_size * output_pixel_stride,
        .log2_csize = log2_output_element_size,
        .ukernel = deconvolution_op->ukernel.igemm.igemm_cases[mr - 1],
    };
    memcpy(&deconvolution_op->context.subconv.params, params, params_size);
  }

  #if XNN_TEST_MODE
    const size_t nc = nr;
  #else
    size_t nc = group_output_channels;
    if (num_threads > 1) {
      const size_t num_other_tiles = groups * stride_height * stride_width *
        output_height_positions * divide_round_up(output_width_positions, mr);
      const size_t target_tiles_per_thread = 5;
      const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
      if (max_nc < nc) {
        nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
      }
    }
  #endif

  if (groups == 1) {
    deconvolution_op->compute.type = xnn_parallelization_type_5d_tile_2d;
    deconvolution_op->compute.task_5d_tile_2d = use_gemm ?
      (pthreadpool_task_5d_tile_2d_t) xnn_compute_subgemm2d : (pthreadpool_task_5d_tile_2d_t) xnn_compute_subconv2d;
    deconvolution_op->compute.range[0] = batch_size;
    deconvolution_op->compute.range[1] = stride_height * stride_width;
    deconvolution_op->compute.range[2] = output_height_positions;
    deconvolution_op->compute.range[3] = output_width_positions;
    deconvolution_op->compute.range[4] = group_output_channels;
    deconvolution_op->compute.tile[0] = mr;
    deconvolution_op->compute.tile[1] = nc;
  } else {
    deconvolution_op->compute.type = xnn_parallelization_type_6d_tile_2d;
    deconvolution_op->compute.task_6d_tile_2d = use_gemm ?
      (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_subgemm2d : (pthreadpool_task_6d_tile_2d_t) xnn_compute_grouped_subconv2d;
    deconvolution_op->compute.range[0] = batch_size;
    deconvolution_op->compute.range[1] = groups;
    deconvolution_op->compute.range[2] = stride_height * stride_width;
    deconvolution_op->compute.range[3] = output_height_positions;
    deconvolution_op->compute.range[4] = output_width_positions;
    deconvolution_op->compute.range[5] = group_output_channels;
    deconvolution_op->compute.tile[0] = mr;
    deconvolution_op->compute.tile[1] = nc;
  }

  deconvolution_op->state = xnn_run_state_ready;
  return xnn_status_success;
}

static enum xnn_status setup_deconvolution2d_nhwc(
  xnn_operator_t deconvolution_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  uint32_t adjustment_height,
  uint32_t adjustment_width,
  const void* input,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_filter_element_size,
  uint32_t bias_element_size,
  uint32_t log2_output_element_size,
  const void* params,
  size_t params_size,
  size_t num_threads)
{
  deconvolution_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(deconvolution_op->type));
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup %s operator with %zux%zu input: input dimensions must be non-zero",
      xnn_operator_type_to_string(deconvolution_op->type), input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  if (adjustment_height >= deconvolution_op->stride_height) {
    xnn_log_error(
      "failed to setup %s operator with %" PRIu32 " height adjustment: "
      "height adjustment must be smaller than height stride (%" PRIu32 ")",
      xnn_operator_type_to_string(deconvolution_op->type), adjustment_height, deconvolution_op->stride_height);
    return xnn_status_invalid_parameter;
  }

  if (adjustment_width >= deconvolution_op->stride_width) {
    xnn_log_error(
      "failed to setup %s operator with %" PRIu32 " width adjustment: "
      "width adjustment must be smaller than width stride (%" PRIu32 ")",
      xnn_operator_type_to_string(deconvolution_op->type), adjustment_width, deconvolution_op->stride_width);
    return xnn_status_invalid_parameter;
  }

  if (batch_size == 0) {
    deconvolution_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  if (deconvolution_op->weights_cache != NULL && !xnn_weights_cache_is_finalized(deconvolution_op->weights_cache)) {
    xnn_log_error("failed to setup %s operator: weights cache is not finalized",
                  xnn_operator_type_to_string(deconvolution_op->type));
    return xnn_status_invalid_state;
  }

  deconvolution_op->batch_size = batch_size;
  deconvolution_op->input_height = input_height;
  deconvolution_op->input_width = input_width;
  deconvolution_op->input = input;
  deconvolution_op->output = output;

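  // The output size follows the usual deconvolution formula; xnn_compute_deconvolution_output_dimension
  // effectively evaluates stride * (input - 1) + adjustment + (kernel - 1) * dilation + 1 - padding,
  // saturated at zero.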
  deconvolution_op->output_height = xnn_compute_deconvolution_output_dimension(
      input_height, deconvolution_op->padding_top + deconvolution_op->padding_bottom,
      adjustment_height, deconvolution_op->kernel_height, deconvolution_op->dilation_height, deconvolution_op->stride_height);
  deconvolution_op->output_width = xnn_compute_deconvolution_output_dimension(
      input_width, deconvolution_op->padding_left + deconvolution_op->padding_right,
      adjustment_width, deconvolution_op->kernel_width, deconvolution_op->dilation_width, deconvolution_op->stride_width);

  switch (deconvolution_op->ukernel.type) {
    case xnn_ukernel_type_igemm:
      return setup_conv_path(
        deconvolution_op,
        batch_size,
        input_height, input_width, input,
        deconvolution_op->output_height, deconvolution_op->output_width, output,
        log2_input_element_size, log2_filter_element_size, bias_element_size, log2_output_element_size,
        params, params_size, num_threads);
    case xnn_ukernel_type_subconv2d:
    {
      const size_t mr = deconvolution_op->ukernel.igemm.mr;
1118       const bool no_padding = (deconvolution_op->padding_top | deconvolution_op->padding_right | deconvolution_op->padding_bottom | deconvolution_op->padding_left) == 0;
1119       const bool no_adjustment = (adjustment_height | adjustment_width) == 0;
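      /* With no padding, no adjustment, and a kernel that exactly tiles the
       * stride, each input pixel writes a disjoint kernel_height x
       * kernel_width output block, so every subconvolution reduces to a
       * plain GEMM; this fast path also requires that a GEMM microkernel
       * exists for the maximal MR.
       */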
1120       const bool use_gemm = no_padding && no_adjustment &&
1121         deconvolution_op->kernel_height == deconvolution_op->stride_height &&
1122         deconvolution_op->kernel_width == deconvolution_op->stride_width &&
1123         deconvolution_op->ukernel.igemm.gemm_cases[mr - 1].function[XNN_UARCH_DEFAULT] != NULL;
1124       return setup_subconv2d_path(
1125         deconvolution_op,
1126         batch_size,
1127         input_height, input_width, input,
1128         deconvolution_op->output_height, deconvolution_op->output_width, output,
1129         log2_input_element_size, log2_filter_element_size, bias_element_size, log2_output_element_size,
1130         params, params_size, num_threads, use_gemm);
1131     }
1132     default:
1133       XNN_UNREACHABLE;
1134   }
1135 }
1136 
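/* The public setup entry points below differ only in element types: each one
 * validates the operator type, then forwards to setup_deconvolution2d_nhwc
 * with the log2 element sizes, bias element size, and packed minmax params
 * appropriate for that data type (QS8, QU8, F16, F32).
 */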
1137 enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8(
1138     xnn_operator_t deconvolution_op,
1139     size_t batch_size,
1140     size_t input_height,
1141     size_t input_width,
1142     uint32_t adjustment_height,
1143     uint32_t adjustment_width,
1144     const int8_t* input,
1145     int8_t* output,
1146     pthreadpool_t threadpool)
1147 {
1148   if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_qs8) {
1149     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1150       xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qs8),
1151       xnn_operator_type_to_string(deconvolution_op->type));
1152     return xnn_status_invalid_parameter;
1153   }
1154 
1155   return setup_deconvolution2d_nhwc(
1156     deconvolution_op,
1157     batch_size, input_height, input_width,
1158     adjustment_height, adjustment_width,
1159     input, output,
1160     0 /* log2(sizeof(input element)) = log2(sizeof(int8_t)) */,
1161     0 /* log2(sizeof(filter element)) = log2(sizeof(int8_t)) */,
1162     sizeof(int32_t) /* sizeof(bias element) */,
1163     0 /* log2(sizeof(output element)) = log2(sizeof(int8_t)) */,
1164     &deconvolution_op->params.qs8_conv_minmax, sizeof(deconvolution_op->params.qs8_conv_minmax),
1165     pthreadpool_get_threads_count(threadpool));
1166 }
1167 
1168 enum xnn_status xnn_setup_deconvolution2d_nhwc_qu8(
1169     xnn_operator_t deconvolution_op,
1170     size_t batch_size,
1171     size_t input_height,
1172     size_t input_width,
1173     uint32_t adjustment_height,
1174     uint32_t adjustment_width,
1175     const uint8_t* input,
1176     uint8_t* output,
1177     pthreadpool_t threadpool)
1178 {
1179   if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_qu8) {
1180     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1181       xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_qu8),
1182       xnn_operator_type_to_string(deconvolution_op->type));
1183     return xnn_status_invalid_parameter;
1184   }
1185 
1186   return setup_deconvolution2d_nhwc(
1187     deconvolution_op,
1188     batch_size, input_height, input_width,
1189     adjustment_height, adjustment_width,
1190     input, output,
1191     0 /* log2(sizeof(input element)) = log2(sizeof(uint8_t)) */,
1192     0 /* log2(sizeof(filter element)) = log2(sizeof(uint8_t)) */,
1193     sizeof(int32_t) /* sizeof(bias element) */,
1194     0 /* log2(sizeof(output element)) = log2(sizeof(uint8_t)) */,
1195     &deconvolution_op->params.qu8_conv_minmax, sizeof(deconvolution_op->params.qu8_conv_minmax),
1196     pthreadpool_get_threads_count(threadpool));
1197 }
1198 
1199 enum xnn_status xnn_setup_deconvolution2d_nhwc_f16(
1200     xnn_operator_t deconvolution_op,
1201     size_t batch_size,
1202     size_t input_height,
1203     size_t input_width,
1204     uint32_t adjustment_height,
1205     uint32_t adjustment_width,
1206     const void* input,
1207     void* output,
1208     pthreadpool_t threadpool)
1209 {
1210   if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_f16) {
1211     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1212       xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f16),
1213       xnn_operator_type_to_string(deconvolution_op->type));
1214     return xnn_status_invalid_parameter;
1215   }
1216 
1217   return setup_deconvolution2d_nhwc(
1218     deconvolution_op,
1219     batch_size, input_height, input_width,
1220     adjustment_height, adjustment_width,
1221     input, output,
1222     1 /* log2(sizeof(input element)) = log2(sizeof(uint16_t)) */,
1223     1 /* log2(sizeof(filter element)) = log2(sizeof(uint16_t)) */,
1224     sizeof(uint16_t) /* sizeof(bias element) */,
1225     1 /* log2(sizeof(output element)) = log2(sizeof(uint16_t)) */,
1226     &deconvolution_op->params.f16_minmax, sizeof(deconvolution_op->params.f16_minmax),
1227     pthreadpool_get_threads_count(threadpool));
1228 }
1229 
1230 enum xnn_status xnn_setup_deconvolution2d_nhwc_f32(
1231     xnn_operator_t deconvolution_op,
1232     size_t batch_size,
1233     size_t input_height,
1234     size_t input_width,
1235     uint32_t adjustment_height,
1236     uint32_t adjustment_width,
1237     const float* input,
1238     float* output,
1239     pthreadpool_t threadpool)
1240 {
1241   if (deconvolution_op->type != xnn_operator_type_deconvolution_nhwc_f32) {
1242     xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
1243       xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32),
1244       xnn_operator_type_to_string(deconvolution_op->type));
1245     return xnn_status_invalid_parameter;
1246   }
1247 
1248   return setup_deconvolution2d_nhwc(
1249     deconvolution_op,
1250     batch_size, input_height, input_width,
1251     adjustment_height, adjustment_width,
1252     input, output,
1253     2 /* log2(sizeof(input element)) = log2(sizeof(float)) */,
1254     2 /* log2(sizeof(filter element)) = log2(sizeof(float)) */,
1255     sizeof(float) /* sizeof(bias element) */,
1256     2 /* log2(sizeof(output element)) = log2(sizeof(float)) */,
1257     &deconvolution_op->params.f32_minmax, sizeof(deconvolution_op->params.f32_minmax),
1258     pthreadpool_get_threads_count(threadpool));
1259 }
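/* Example usage (illustrative sketch): a caller creates the operator once,
 * binds concrete shapes and buffers with the type-specific setup call, then
 * executes it on a threadpool. Creation arguments are elided here, and the
 * input/output buffers and threadpool are assumed to come from the caller.
 *
 *   xnn_operator_t deconv = NULL;
 *   xnn_create_deconvolution2d_nhwc_f32(..., &deconv);
 *   xnn_setup_deconvolution2d_nhwc_f32(
 *       deconv, batch_size, input_height, input_width,
 *       0, 0,  // no height/width adjustment
 *       input, output, threadpool);
 *   xnn_run_operator(deconv, threadpool);
 *   xnn_delete_operator(deconv);
 */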
1260