1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
11*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
12*4bdc9457SAndroid Build Coastguard Worker #include <string.h>
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/common.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/params.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/compute.h>
22*4bdc9457SAndroid Build Coastguard Worker
23*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposec_2d(const struct transpose_context * context,size_t i,size_t j,size_t tile_i,size_t tile_j)24*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposec_2d(
25*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
26*4bdc9457SAndroid Build Coastguard Worker size_t i,
27*4bdc9457SAndroid Build Coastguard Worker size_t j,
28*4bdc9457SAndroid Build Coastguard Worker size_t tile_i,
29*4bdc9457SAndroid Build Coastguard Worker size_t tile_j)
30*4bdc9457SAndroid Build Coastguard Worker {
31*4bdc9457SAndroid Build Coastguard Worker const size_t log2_element_size = context->log2_element_size;
32*4bdc9457SAndroid Build Coastguard Worker
33*4bdc9457SAndroid Build Coastguard Worker context->const_size_ukernel(
34*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->x + (i << log2_element_size) + j * context->input_stride[1]),
35*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->y + (j << log2_element_size) + i * context->output_stride[0]),
36*4bdc9457SAndroid Build Coastguard Worker context->input_stride[1],
37*4bdc9457SAndroid Build Coastguard Worker context->output_stride[0],
38*4bdc9457SAndroid Build Coastguard Worker tile_i,
39*4bdc9457SAndroid Build Coastguard Worker tile_j);
40*4bdc9457SAndroid Build Coastguard Worker }
41*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposec_3d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t tile_j,size_t tile_k)42*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposec_3d(
43*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
44*4bdc9457SAndroid Build Coastguard Worker size_t i,
45*4bdc9457SAndroid Build Coastguard Worker size_t j,
46*4bdc9457SAndroid Build Coastguard Worker size_t k,
47*4bdc9457SAndroid Build Coastguard Worker size_t tile_j,
48*4bdc9457SAndroid Build Coastguard Worker size_t tile_k)
49*4bdc9457SAndroid Build Coastguard Worker {
50*4bdc9457SAndroid Build Coastguard Worker const size_t log2_element_size = context->log2_element_size;
51*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[2];
52*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[1];
53*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*) ((uintptr_t) context->x +
54*4bdc9457SAndroid Build Coastguard Worker (i * context->input_stride[0] + j * context->input_stride[1]) + k * ld_input);
55*4bdc9457SAndroid Build Coastguard Worker void* y = (void*) ((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
56*4bdc9457SAndroid Build Coastguard Worker (k << log2_element_size));
57*4bdc9457SAndroid Build Coastguard Worker
58*4bdc9457SAndroid Build Coastguard Worker context->const_size_ukernel(
59*4bdc9457SAndroid Build Coastguard Worker x,
60*4bdc9457SAndroid Build Coastguard Worker y,
61*4bdc9457SAndroid Build Coastguard Worker ld_input,
62*4bdc9457SAndroid Build Coastguard Worker ld_output,
63*4bdc9457SAndroid Build Coastguard Worker tile_j,
64*4bdc9457SAndroid Build Coastguard Worker tile_k);
65*4bdc9457SAndroid Build Coastguard Worker }
66*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposec_4d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t l,size_t tile_k,size_t tile_l)67*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposec_4d(
68*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
69*4bdc9457SAndroid Build Coastguard Worker size_t i,
70*4bdc9457SAndroid Build Coastguard Worker size_t j,
71*4bdc9457SAndroid Build Coastguard Worker size_t k,
72*4bdc9457SAndroid Build Coastguard Worker size_t l,
73*4bdc9457SAndroid Build Coastguard Worker size_t tile_k,
74*4bdc9457SAndroid Build Coastguard Worker size_t tile_l)
75*4bdc9457SAndroid Build Coastguard Worker {
76*4bdc9457SAndroid Build Coastguard Worker const size_t log2_element_size = context->log2_element_size;
77*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[3];
78*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[2];
79*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*) ((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
80*4bdc9457SAndroid Build Coastguard Worker k * context->input_stride[2] + l * ld_input);
81*4bdc9457SAndroid Build Coastguard Worker void* y = (void*) ((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
82*4bdc9457SAndroid Build Coastguard Worker k * context->output_stride[2] + (l << log2_element_size));
83*4bdc9457SAndroid Build Coastguard Worker
84*4bdc9457SAndroid Build Coastguard Worker context->const_size_ukernel(
85*4bdc9457SAndroid Build Coastguard Worker x,
86*4bdc9457SAndroid Build Coastguard Worker y,
87*4bdc9457SAndroid Build Coastguard Worker ld_input,
88*4bdc9457SAndroid Build Coastguard Worker ld_output,
89*4bdc9457SAndroid Build Coastguard Worker tile_k,
90*4bdc9457SAndroid Build Coastguard Worker tile_l);
91*4bdc9457SAndroid Build Coastguard Worker }
92*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposec_5d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t l,size_t m,size_t tile_l,size_t tile_m)93*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposec_5d(
94*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
95*4bdc9457SAndroid Build Coastguard Worker size_t i,
96*4bdc9457SAndroid Build Coastguard Worker size_t j,
97*4bdc9457SAndroid Build Coastguard Worker size_t k,
98*4bdc9457SAndroid Build Coastguard Worker size_t l,
99*4bdc9457SAndroid Build Coastguard Worker size_t m,
100*4bdc9457SAndroid Build Coastguard Worker size_t tile_l,
101*4bdc9457SAndroid Build Coastguard Worker size_t tile_m)
102*4bdc9457SAndroid Build Coastguard Worker {
103*4bdc9457SAndroid Build Coastguard Worker const size_t log2_element_size = context->log2_element_size;
104*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[4];
105*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[3];
106*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
107*4bdc9457SAndroid Build Coastguard Worker k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
108*4bdc9457SAndroid Build Coastguard Worker void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
109*4bdc9457SAndroid Build Coastguard Worker k * context->output_stride[2] + l * context->output_stride[3] + (m << log2_element_size));
110*4bdc9457SAndroid Build Coastguard Worker
111*4bdc9457SAndroid Build Coastguard Worker context->const_size_ukernel(
112*4bdc9457SAndroid Build Coastguard Worker x,
113*4bdc9457SAndroid Build Coastguard Worker y,
114*4bdc9457SAndroid Build Coastguard Worker ld_input,
115*4bdc9457SAndroid Build Coastguard Worker ld_output,
116*4bdc9457SAndroid Build Coastguard Worker tile_l,
117*4bdc9457SAndroid Build Coastguard Worker tile_m);
118*4bdc9457SAndroid Build Coastguard Worker }
119*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposec_6d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t l,size_t m,size_t n,size_t tile_m,size_t tile_n)120*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposec_6d(
121*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
122*4bdc9457SAndroid Build Coastguard Worker size_t i,
123*4bdc9457SAndroid Build Coastguard Worker size_t j,
124*4bdc9457SAndroid Build Coastguard Worker size_t k,
125*4bdc9457SAndroid Build Coastguard Worker size_t l,
126*4bdc9457SAndroid Build Coastguard Worker size_t m,
127*4bdc9457SAndroid Build Coastguard Worker size_t n,
128*4bdc9457SAndroid Build Coastguard Worker size_t tile_m,
129*4bdc9457SAndroid Build Coastguard Worker size_t tile_n)
130*4bdc9457SAndroid Build Coastguard Worker {
131*4bdc9457SAndroid Build Coastguard Worker const size_t log2_element_size = context->log2_element_size;
132*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[5];
133*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[4];
134*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
135*4bdc9457SAndroid Build Coastguard Worker k * context->input_stride[2] + l * context->input_stride[3] +
136*4bdc9457SAndroid Build Coastguard Worker m * context->input_stride[4] + n * ld_input);
137*4bdc9457SAndroid Build Coastguard Worker void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
138*4bdc9457SAndroid Build Coastguard Worker k * context->output_stride[2] + l * context->output_stride[3] + m * context->output_stride[4] +
139*4bdc9457SAndroid Build Coastguard Worker (n << log2_element_size));
140*4bdc9457SAndroid Build Coastguard Worker
141*4bdc9457SAndroid Build Coastguard Worker context->const_size_ukernel(
142*4bdc9457SAndroid Build Coastguard Worker x,
143*4bdc9457SAndroid Build Coastguard Worker y,
144*4bdc9457SAndroid Build Coastguard Worker ld_input,
145*4bdc9457SAndroid Build Coastguard Worker ld_output,
146*4bdc9457SAndroid Build Coastguard Worker tile_m,
147*4bdc9457SAndroid Build Coastguard Worker tile_n);
148*4bdc9457SAndroid Build Coastguard Worker }
149*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposev_2d(const struct transpose_context * context,size_t i,size_t j,size_t tile_i,size_t tile_j)150*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposev_2d(
151*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
152*4bdc9457SAndroid Build Coastguard Worker size_t i,
153*4bdc9457SAndroid Build Coastguard Worker size_t j,
154*4bdc9457SAndroid Build Coastguard Worker size_t tile_i,
155*4bdc9457SAndroid Build Coastguard Worker size_t tile_j)
156*4bdc9457SAndroid Build Coastguard Worker {
157*4bdc9457SAndroid Build Coastguard Worker const size_t element_size = context->element_size;
158*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[1];
159*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[0];
160*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*) ((uintptr_t) context->x +
161*4bdc9457SAndroid Build Coastguard Worker i * context->input_stride[0] + j * ld_input);
162*4bdc9457SAndroid Build Coastguard Worker void* y = (void*) ((uintptr_t) context->y + context->output_stride[1] * j + i * context->output_stride[0]);
163*4bdc9457SAndroid Build Coastguard Worker
164*4bdc9457SAndroid Build Coastguard Worker context->variable_size_ukernel(
165*4bdc9457SAndroid Build Coastguard Worker x,
166*4bdc9457SAndroid Build Coastguard Worker y,
167*4bdc9457SAndroid Build Coastguard Worker ld_input,
168*4bdc9457SAndroid Build Coastguard Worker ld_output,
169*4bdc9457SAndroid Build Coastguard Worker context->input_stride[0],
170*4bdc9457SAndroid Build Coastguard Worker context->output_stride[1],
171*4bdc9457SAndroid Build Coastguard Worker element_size,
172*4bdc9457SAndroid Build Coastguard Worker tile_i,
173*4bdc9457SAndroid Build Coastguard Worker tile_j);
174*4bdc9457SAndroid Build Coastguard Worker }
175*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposev_3d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t tile_j,size_t tile_k)176*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposev_3d(
177*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
178*4bdc9457SAndroid Build Coastguard Worker size_t i,
179*4bdc9457SAndroid Build Coastguard Worker size_t j,
180*4bdc9457SAndroid Build Coastguard Worker size_t k,
181*4bdc9457SAndroid Build Coastguard Worker size_t tile_j,
182*4bdc9457SAndroid Build Coastguard Worker size_t tile_k)
183*4bdc9457SAndroid Build Coastguard Worker {
184*4bdc9457SAndroid Build Coastguard Worker const size_t element_size = context->element_size;
185*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[2];
186*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[1];
187*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
188*4bdc9457SAndroid Build Coastguard Worker k * ld_input);
189*4bdc9457SAndroid Build Coastguard Worker void* y = (void*)((uintptr_t)context->y + i * context->output_stride[0] + j * context->output_stride[1] +
190*4bdc9457SAndroid Build Coastguard Worker k * context->output_stride[2]);
191*4bdc9457SAndroid Build Coastguard Worker
192*4bdc9457SAndroid Build Coastguard Worker context->variable_size_ukernel(
193*4bdc9457SAndroid Build Coastguard Worker x,
194*4bdc9457SAndroid Build Coastguard Worker y,
195*4bdc9457SAndroid Build Coastguard Worker ld_input,
196*4bdc9457SAndroid Build Coastguard Worker ld_output,
197*4bdc9457SAndroid Build Coastguard Worker context->input_stride[1],
198*4bdc9457SAndroid Build Coastguard Worker context->output_stride[2],
199*4bdc9457SAndroid Build Coastguard Worker element_size,
200*4bdc9457SAndroid Build Coastguard Worker tile_j,
201*4bdc9457SAndroid Build Coastguard Worker tile_k);
202*4bdc9457SAndroid Build Coastguard Worker }
203*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposev_4d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t l,size_t tile_k,size_t tile_l)204*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposev_4d(
205*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
206*4bdc9457SAndroid Build Coastguard Worker size_t i,
207*4bdc9457SAndroid Build Coastguard Worker size_t j,
208*4bdc9457SAndroid Build Coastguard Worker size_t k,
209*4bdc9457SAndroid Build Coastguard Worker size_t l,
210*4bdc9457SAndroid Build Coastguard Worker size_t tile_k,
211*4bdc9457SAndroid Build Coastguard Worker size_t tile_l)
212*4bdc9457SAndroid Build Coastguard Worker {
213*4bdc9457SAndroid Build Coastguard Worker const size_t element_size = context->element_size;
214*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[3];
215*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[2];
216*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
217*4bdc9457SAndroid Build Coastguard Worker k * context->input_stride[2] + l * ld_input);
218*4bdc9457SAndroid Build Coastguard Worker void* y = (void*)((uintptr_t)context->y + context->output_stride[3] * l + i * context->output_stride[0] +
219*4bdc9457SAndroid Build Coastguard Worker j * context->output_stride[1] + k * context->output_stride[2]);
220*4bdc9457SAndroid Build Coastguard Worker
221*4bdc9457SAndroid Build Coastguard Worker context->variable_size_ukernel(
222*4bdc9457SAndroid Build Coastguard Worker x,
223*4bdc9457SAndroid Build Coastguard Worker y,
224*4bdc9457SAndroid Build Coastguard Worker ld_input,
225*4bdc9457SAndroid Build Coastguard Worker ld_output,
226*4bdc9457SAndroid Build Coastguard Worker context->input_stride[2],
227*4bdc9457SAndroid Build Coastguard Worker context->output_stride[3],
228*4bdc9457SAndroid Build Coastguard Worker element_size,
229*4bdc9457SAndroid Build Coastguard Worker tile_k,
230*4bdc9457SAndroid Build Coastguard Worker tile_l);
231*4bdc9457SAndroid Build Coastguard Worker }
232*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposev_5d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t l,size_t m,size_t tile_l,size_t tile_m)233*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposev_5d(
234*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
235*4bdc9457SAndroid Build Coastguard Worker size_t i,
236*4bdc9457SAndroid Build Coastguard Worker size_t j,
237*4bdc9457SAndroid Build Coastguard Worker size_t k,
238*4bdc9457SAndroid Build Coastguard Worker size_t l,
239*4bdc9457SAndroid Build Coastguard Worker size_t m,
240*4bdc9457SAndroid Build Coastguard Worker size_t tile_l,
241*4bdc9457SAndroid Build Coastguard Worker size_t tile_m)
242*4bdc9457SAndroid Build Coastguard Worker {
243*4bdc9457SAndroid Build Coastguard Worker const size_t element_size = context->element_size;
244*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[4];
245*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[3];
246*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
247*4bdc9457SAndroid Build Coastguard Worker k * context->input_stride[2] + l * context->input_stride[3] + m * ld_input);
248*4bdc9457SAndroid Build Coastguard Worker void* y = (void*)((uintptr_t)context->y + context->output_stride[4] * m + i * context->output_stride[0] +
249*4bdc9457SAndroid Build Coastguard Worker j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3]);
250*4bdc9457SAndroid Build Coastguard Worker
251*4bdc9457SAndroid Build Coastguard Worker context->variable_size_ukernel(
252*4bdc9457SAndroid Build Coastguard Worker x,
253*4bdc9457SAndroid Build Coastguard Worker y,
254*4bdc9457SAndroid Build Coastguard Worker ld_input,
255*4bdc9457SAndroid Build Coastguard Worker ld_output,
256*4bdc9457SAndroid Build Coastguard Worker context->input_stride[3],
257*4bdc9457SAndroid Build Coastguard Worker context->output_stride[4],
258*4bdc9457SAndroid Build Coastguard Worker element_size,
259*4bdc9457SAndroid Build Coastguard Worker tile_l,
260*4bdc9457SAndroid Build Coastguard Worker tile_m);
261*4bdc9457SAndroid Build Coastguard Worker }
262*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_transposev_6d(const struct transpose_context * context,size_t i,size_t j,size_t k,size_t l,size_t m,size_t n,size_t tile_m,size_t tile_n)263*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_transposev_6d(
264*4bdc9457SAndroid Build Coastguard Worker const struct transpose_context* context,
265*4bdc9457SAndroid Build Coastguard Worker size_t i,
266*4bdc9457SAndroid Build Coastguard Worker size_t j,
267*4bdc9457SAndroid Build Coastguard Worker size_t k,
268*4bdc9457SAndroid Build Coastguard Worker size_t l,
269*4bdc9457SAndroid Build Coastguard Worker size_t m,
270*4bdc9457SAndroid Build Coastguard Worker size_t n,
271*4bdc9457SAndroid Build Coastguard Worker size_t tile_m,
272*4bdc9457SAndroid Build Coastguard Worker size_t tile_n)
273*4bdc9457SAndroid Build Coastguard Worker {
274*4bdc9457SAndroid Build Coastguard Worker const size_t element_size = context->element_size;
275*4bdc9457SAndroid Build Coastguard Worker const size_t ld_input = context->input_stride[5];
276*4bdc9457SAndroid Build Coastguard Worker const size_t ld_output = context->output_stride[4];
277*4bdc9457SAndroid Build Coastguard Worker const void* x = (const void*)((uintptr_t)context->x + i * context->input_stride[0] + j * context->input_stride[1] +
278*4bdc9457SAndroid Build Coastguard Worker k * context->input_stride[2] + l * context->input_stride[3] +
279*4bdc9457SAndroid Build Coastguard Worker m * context->input_stride[4] + n * ld_input);
280*4bdc9457SAndroid Build Coastguard Worker void* y = (void*)((uintptr_t)context->y + context->output_stride[5] * n + i * context->output_stride[0] +
281*4bdc9457SAndroid Build Coastguard Worker j * context->output_stride[1] + k * context->output_stride[2] + l * context->output_stride[3] +
282*4bdc9457SAndroid Build Coastguard Worker m * context->output_stride[4]);
283*4bdc9457SAndroid Build Coastguard Worker
284*4bdc9457SAndroid Build Coastguard Worker context->variable_size_ukernel(
285*4bdc9457SAndroid Build Coastguard Worker x,
286*4bdc9457SAndroid Build Coastguard Worker y,
287*4bdc9457SAndroid Build Coastguard Worker ld_input,
288*4bdc9457SAndroid Build Coastguard Worker ld_output,
289*4bdc9457SAndroid Build Coastguard Worker context->input_stride[4],
290*4bdc9457SAndroid Build Coastguard Worker context->output_stride[5],
291*4bdc9457SAndroid Build Coastguard Worker element_size,
292*4bdc9457SAndroid Build Coastguard Worker tile_m,
293*4bdc9457SAndroid Build Coastguard Worker tile_n);
294*4bdc9457SAndroid Build Coastguard Worker }
295*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_grouped_gemm(const struct gemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)296*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_grouped_gemm(
297*4bdc9457SAndroid Build Coastguard Worker const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
298*4bdc9457SAndroid Build Coastguard Worker size_t group_index,
299*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_start,
300*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_start,
301*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_size,
302*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_size)
303*4bdc9457SAndroid Build Coastguard Worker {
304*4bdc9457SAndroid Build Coastguard Worker const size_t k_scaled = context->k_scaled;
305*4bdc9457SAndroid Build Coastguard Worker const size_t a_stride = context->a_stride;
306*4bdc9457SAndroid Build Coastguard Worker const size_t cm_stride = context->cm_stride;
307*4bdc9457SAndroid Build Coastguard Worker
308*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
309*4bdc9457SAndroid Build Coastguard Worker mr_block_size,
310*4bdc9457SAndroid Build Coastguard Worker nr_block_size,
311*4bdc9457SAndroid Build Coastguard Worker k_scaled,
312*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled),
313*4bdc9457SAndroid Build Coastguard Worker a_stride,
314*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride),
315*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride),
316*4bdc9457SAndroid Build Coastguard Worker cm_stride,
317*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
318*4bdc9457SAndroid Build Coastguard Worker &context->params);
319*4bdc9457SAndroid Build Coastguard Worker }
320*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_gemm(const struct gemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)321*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_gemm(
322*4bdc9457SAndroid Build Coastguard Worker const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
323*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_start,
324*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_start,
325*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_size,
326*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_size)
327*4bdc9457SAndroid Build Coastguard Worker {
328*4bdc9457SAndroid Build Coastguard Worker const size_t a_stride = context->a_stride;
329*4bdc9457SAndroid Build Coastguard Worker const size_t cm_stride = context->cm_stride;
330*4bdc9457SAndroid Build Coastguard Worker
331*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
332*4bdc9457SAndroid Build Coastguard Worker mr_block_size,
333*4bdc9457SAndroid Build Coastguard Worker nr_block_size,
334*4bdc9457SAndroid Build Coastguard Worker context->k_scaled,
335*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
336*4bdc9457SAndroid Build Coastguard Worker a_stride,
337*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
338*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
339*4bdc9457SAndroid Build Coastguard Worker cm_stride,
340*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
341*4bdc9457SAndroid Build Coastguard Worker context->fused_params);
342*4bdc9457SAndroid Build Coastguard Worker }
343*4bdc9457SAndroid Build Coastguard Worker
// Runs the sparse-matrix-multiply microkernel on a block of `mr_block_size`
// starting at `mr_block_start`, within batch item `batch_index`.
//
// Input and output pointers are advanced in bytes by the batch offset
// (batched_input_stride / batched_output_stride) and then by
// `mr_block_start`, which is added unscaled.
// NOTE(review): this implies the caller passes mr_block_start already in
// bytes. Confirm against the parallelization setup.
// The sparse weight description (nonzero_weights, input_increments,
// output_channel_nonzeros), `n`, `scaled_m` and `params` are forwarded
// as-is.
void xnn_compute_spmm(
    const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t mr_block_start,
    size_t mr_block_size)
{
  const uintptr_t input_batch = (uintptr_t) context->input + batch_index * context->batched_input_stride;
  const uintptr_t output_batch = (uintptr_t) context->output + batch_index * context->batched_output_stride;

  context->ukernel(
      mr_block_size,
      context->n,
      (const void*) (input_batch + mr_block_start),
      context->nonzero_weights,
      context->input_increments,
      context->output_channel_nonzeros,
      (void*) (output_batch + mr_block_start),
      context->scaled_m,
      &context->params);
}
361*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_grouped_batch_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)362*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_grouped_batch_igemm(
363*4bdc9457SAndroid Build Coastguard Worker const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
364*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
365*4bdc9457SAndroid Build Coastguard Worker size_t group_index,
366*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_start,
367*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_start,
368*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_size,
369*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_size)
370*4bdc9457SAndroid Build Coastguard Worker {
371*4bdc9457SAndroid Build Coastguard Worker const size_t ks = context->ks;
372*4bdc9457SAndroid Build Coastguard Worker const size_t cm_stride = context->cm_stride;
373*4bdc9457SAndroid Build Coastguard Worker
374*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
375*4bdc9457SAndroid Build Coastguard Worker mr_block_size,
376*4bdc9457SAndroid Build Coastguard Worker nr_block_size,
377*4bdc9457SAndroid Build Coastguard Worker context->kc,
378*4bdc9457SAndroid Build Coastguard Worker context->ks_scaled,
379*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
380*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
381*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
382*4bdc9457SAndroid Build Coastguard Worker cm_stride,
383*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
384*4bdc9457SAndroid Build Coastguard Worker context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
385*4bdc9457SAndroid Build Coastguard Worker context->zero,
386*4bdc9457SAndroid Build Coastguard Worker &context->params);
387*4bdc9457SAndroid Build Coastguard Worker }
388*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_grouped_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t group_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)389*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_grouped_igemm(
390*4bdc9457SAndroid Build Coastguard Worker const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
391*4bdc9457SAndroid Build Coastguard Worker size_t group_index,
392*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_start,
393*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_start,
394*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_size,
395*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_size)
396*4bdc9457SAndroid Build Coastguard Worker {
397*4bdc9457SAndroid Build Coastguard Worker const size_t ks = context->ks;
398*4bdc9457SAndroid Build Coastguard Worker const size_t cm_stride = context->cm_stride;
399*4bdc9457SAndroid Build Coastguard Worker
400*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
401*4bdc9457SAndroid Build Coastguard Worker mr_block_size,
402*4bdc9457SAndroid Build Coastguard Worker nr_block_size,
403*4bdc9457SAndroid Build Coastguard Worker context->kc,
404*4bdc9457SAndroid Build Coastguard Worker context->ks_scaled,
405*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
406*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride),
407*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->c + group_index * context->gc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
408*4bdc9457SAndroid Build Coastguard Worker cm_stride,
409*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
410*4bdc9457SAndroid Build Coastguard Worker context->a_offset + group_index * context->ga_stride,
411*4bdc9457SAndroid Build Coastguard Worker context->zero,
412*4bdc9457SAndroid Build Coastguard Worker &context->params);
413*4bdc9457SAndroid Build Coastguard Worker }
414*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_batch_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)415*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_batch_igemm(
416*4bdc9457SAndroid Build Coastguard Worker const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
417*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
418*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_start,
419*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_start,
420*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_size,
421*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_size)
422*4bdc9457SAndroid Build Coastguard Worker {
423*4bdc9457SAndroid Build Coastguard Worker const size_t ks = context->ks;
424*4bdc9457SAndroid Build Coastguard Worker const size_t cm_stride = context->cm_stride;
425*4bdc9457SAndroid Build Coastguard Worker
426*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
427*4bdc9457SAndroid Build Coastguard Worker mr_block_size,
428*4bdc9457SAndroid Build Coastguard Worker nr_block_size,
429*4bdc9457SAndroid Build Coastguard Worker context->kc,
430*4bdc9457SAndroid Build Coastguard Worker context->ks_scaled,
431*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
432*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
433*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->c + batch_index * context->bc_stride + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
434*4bdc9457SAndroid Build Coastguard Worker cm_stride,
435*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
436*4bdc9457SAndroid Build Coastguard Worker context->a_offset + batch_index * context->ba_stride,
437*4bdc9457SAndroid Build Coastguard Worker context->zero,
438*4bdc9457SAndroid Build Coastguard Worker &context->params);
439*4bdc9457SAndroid Build Coastguard Worker }
440*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_igemm(const struct igemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t mr_block_start,size_t nr_block_start,size_t mr_block_size,size_t nr_block_size)441*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_igemm(
442*4bdc9457SAndroid Build Coastguard Worker const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
443*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_start,
444*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_start,
445*4bdc9457SAndroid Build Coastguard Worker size_t mr_block_size,
446*4bdc9457SAndroid Build Coastguard Worker size_t nr_block_size)
447*4bdc9457SAndroid Build Coastguard Worker {
448*4bdc9457SAndroid Build Coastguard Worker const size_t ks = context->ks;
449*4bdc9457SAndroid Build Coastguard Worker const size_t cm_stride = context->cm_stride;
450*4bdc9457SAndroid Build Coastguard Worker
451*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
452*4bdc9457SAndroid Build Coastguard Worker mr_block_size,
453*4bdc9457SAndroid Build Coastguard Worker nr_block_size,
454*4bdc9457SAndroid Build Coastguard Worker context->kc,
455*4bdc9457SAndroid Build Coastguard Worker context->ks_scaled,
456*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) context->indirect_a + mr_block_start * ks * sizeof(void*)),
457*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
458*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
459*4bdc9457SAndroid Build Coastguard Worker cm_stride,
460*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
461*4bdc9457SAndroid Build Coastguard Worker context->a_offset,
462*4bdc9457SAndroid Build Coastguard Worker context->zero,
463*4bdc9457SAndroid Build Coastguard Worker &context->params);
464*4bdc9457SAndroid Build Coastguard Worker }
465*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_grouped_subgemm2d(const struct subgemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t group_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)466*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_grouped_subgemm2d(
467*4bdc9457SAndroid Build Coastguard Worker const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
468*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
469*4bdc9457SAndroid Build Coastguard Worker size_t group_index,
470*4bdc9457SAndroid Build Coastguard Worker size_t subkernel_index,
471*4bdc9457SAndroid Build Coastguard Worker size_t slice_y,
472*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_start,
473*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_start,
474*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_max,
475*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_size)
476*4bdc9457SAndroid Build Coastguard Worker {
477*4bdc9457SAndroid Build Coastguard Worker const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
478*4bdc9457SAndroid Build Coastguard Worker
479*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
480*4bdc9457SAndroid Build Coastguard Worker return;
481*4bdc9457SAndroid Build Coastguard Worker }
482*4bdc9457SAndroid Build Coastguard Worker
483*4bdc9457SAndroid Build Coastguard Worker const size_t slice_width = subconvolution_params->slice_width;
484*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_x_start >= slice_width) {
485*4bdc9457SAndroid Build Coastguard Worker return;
486*4bdc9457SAndroid Build Coastguard Worker }
487*4bdc9457SAndroid Build Coastguard Worker const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
488*4bdc9457SAndroid Build Coastguard Worker
489*4bdc9457SAndroid Build Coastguard Worker const size_t ax_stride = context->ax_stride;
490*4bdc9457SAndroid Build Coastguard Worker const size_t cx_stride = context->cx_stride;
491*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
492*4bdc9457SAndroid Build Coastguard Worker slice_x_size,
493*4bdc9457SAndroid Build Coastguard Worker nc_block_size,
494*4bdc9457SAndroid Build Coastguard Worker context->kc,
495*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->a + group_index * context->ga_stride + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
496*4bdc9457SAndroid Build Coastguard Worker ax_stride,
497*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
498*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
499*4bdc9457SAndroid Build Coastguard Worker cx_stride,
500*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
501*4bdc9457SAndroid Build Coastguard Worker &context->params);
502*4bdc9457SAndroid Build Coastguard Worker }
503*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_subgemm2d(const struct subgemm_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)504*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_subgemm2d(
505*4bdc9457SAndroid Build Coastguard Worker const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)],
506*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
507*4bdc9457SAndroid Build Coastguard Worker size_t subkernel_index,
508*4bdc9457SAndroid Build Coastguard Worker size_t slice_y,
509*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_start,
510*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_start,
511*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_max,
512*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_size)
513*4bdc9457SAndroid Build Coastguard Worker {
514*4bdc9457SAndroid Build Coastguard Worker const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
515*4bdc9457SAndroid Build Coastguard Worker
516*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
517*4bdc9457SAndroid Build Coastguard Worker return;
518*4bdc9457SAndroid Build Coastguard Worker }
519*4bdc9457SAndroid Build Coastguard Worker
520*4bdc9457SAndroid Build Coastguard Worker const size_t slice_width = subconvolution_params->slice_width;
521*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_x_start >= slice_width) {
522*4bdc9457SAndroid Build Coastguard Worker return;
523*4bdc9457SAndroid Build Coastguard Worker }
524*4bdc9457SAndroid Build Coastguard Worker const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
525*4bdc9457SAndroid Build Coastguard Worker
526*4bdc9457SAndroid Build Coastguard Worker const size_t ax_stride = context->ax_stride;
527*4bdc9457SAndroid Build Coastguard Worker const size_t cx_stride = context->cx_stride;
528*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
529*4bdc9457SAndroid Build Coastguard Worker slice_x_size,
530*4bdc9457SAndroid Build Coastguard Worker nc_block_size,
531*4bdc9457SAndroid Build Coastguard Worker context->kc,
532*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->a + slice_y * context->ay_stride + slice_x_start * ax_stride + batch_index * context->ba_stride),
533*4bdc9457SAndroid Build Coastguard Worker ax_stride,
534*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
535*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
536*4bdc9457SAndroid Build Coastguard Worker cx_stride,
537*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
538*4bdc9457SAndroid Build Coastguard Worker &context->params);
539*4bdc9457SAndroid Build Coastguard Worker }
540*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_grouped_subconv2d(const struct subconv_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t group_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)541*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_grouped_subconv2d(
542*4bdc9457SAndroid Build Coastguard Worker const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
543*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
544*4bdc9457SAndroid Build Coastguard Worker size_t group_index,
545*4bdc9457SAndroid Build Coastguard Worker size_t subkernel_index,
546*4bdc9457SAndroid Build Coastguard Worker size_t slice_y,
547*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_start,
548*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_start,
549*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_max,
550*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_size)
551*4bdc9457SAndroid Build Coastguard Worker {
552*4bdc9457SAndroid Build Coastguard Worker const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
553*4bdc9457SAndroid Build Coastguard Worker
554*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
555*4bdc9457SAndroid Build Coastguard Worker return;
556*4bdc9457SAndroid Build Coastguard Worker }
557*4bdc9457SAndroid Build Coastguard Worker
558*4bdc9457SAndroid Build Coastguard Worker const size_t slice_width = subconvolution_params->slice_width;
559*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_x_start >= slice_width) {
560*4bdc9457SAndroid Build Coastguard Worker return;
561*4bdc9457SAndroid Build Coastguard Worker }
562*4bdc9457SAndroid Build Coastguard Worker const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
563*4bdc9457SAndroid Build Coastguard Worker
564*4bdc9457SAndroid Build Coastguard Worker const size_t cx_stride = context->cx_stride;
565*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
566*4bdc9457SAndroid Build Coastguard Worker slice_x_size,
567*4bdc9457SAndroid Build Coastguard Worker nc_block_size,
568*4bdc9457SAndroid Build Coastguard Worker context->kc,
569*4bdc9457SAndroid Build Coastguard Worker subconvolution_params->scaled_kernel_size,
570*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
571*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride + group_index * context->gw_stride),
572*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) subconvolution_params->output + group_index * context->gc_stride + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
573*4bdc9457SAndroid Build Coastguard Worker cx_stride,
574*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
575*4bdc9457SAndroid Build Coastguard Worker context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
576*4bdc9457SAndroid Build Coastguard Worker context->zero,
577*4bdc9457SAndroid Build Coastguard Worker &context->params);
578*4bdc9457SAndroid Build Coastguard Worker }
579*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_subconv2d(const struct subconv_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t subkernel_index,size_t slice_y,size_t slice_x_start,size_t nc_block_start,size_t slice_x_max,size_t nc_block_size)580*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_subconv2d(
581*4bdc9457SAndroid Build Coastguard Worker const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)],
582*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
583*4bdc9457SAndroid Build Coastguard Worker size_t subkernel_index,
584*4bdc9457SAndroid Build Coastguard Worker size_t slice_y,
585*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_start,
586*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_start,
587*4bdc9457SAndroid Build Coastguard Worker size_t slice_x_max,
588*4bdc9457SAndroid Build Coastguard Worker size_t nc_block_size)
589*4bdc9457SAndroid Build Coastguard Worker {
590*4bdc9457SAndroid Build Coastguard Worker const struct subconvolution_params* subconvolution_params = &context->subconvolution_params[subkernel_index];
591*4bdc9457SAndroid Build Coastguard Worker
592*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_y >= subconvolution_params->slice_height) {
593*4bdc9457SAndroid Build Coastguard Worker return;
594*4bdc9457SAndroid Build Coastguard Worker }
595*4bdc9457SAndroid Build Coastguard Worker
596*4bdc9457SAndroid Build Coastguard Worker const size_t slice_width = subconvolution_params->slice_width;
597*4bdc9457SAndroid Build Coastguard Worker if XNN_UNLIKELY(slice_x_start >= slice_width) {
598*4bdc9457SAndroid Build Coastguard Worker return;
599*4bdc9457SAndroid Build Coastguard Worker }
600*4bdc9457SAndroid Build Coastguard Worker const size_t slice_x_size = min(slice_x_max, slice_width - slice_x_start);
601*4bdc9457SAndroid Build Coastguard Worker
602*4bdc9457SAndroid Build Coastguard Worker const size_t cx_stride = context->cx_stride;
603*4bdc9457SAndroid Build Coastguard Worker context->ukernel.function[XNN_UARCH_DEFAULT](
604*4bdc9457SAndroid Build Coastguard Worker slice_x_size,
605*4bdc9457SAndroid Build Coastguard Worker nc_block_size,
606*4bdc9457SAndroid Build Coastguard Worker context->kc,
607*4bdc9457SAndroid Build Coastguard Worker subconvolution_params->scaled_kernel_size,
608*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) subconvolution_params->indirection_buffer + slice_y * subconvolution_params->indirection_y_stride + slice_x_start * subconvolution_params->indirection_x_stride),
609*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) subconvolution_params->weights + nc_block_start * subconvolution_params->w_stride),
610*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) subconvolution_params->output + slice_y * context->cy_stride + slice_x_start * cx_stride + batch_index * context->bc_stride + (nc_block_start << context->log2_csize)),
611*4bdc9457SAndroid Build Coastguard Worker cx_stride,
612*4bdc9457SAndroid Build Coastguard Worker context->cn_stride,
613*4bdc9457SAndroid Build Coastguard Worker context->a_offset + batch_index * context->ba_stride,
614*4bdc9457SAndroid Build Coastguard Worker context->zero,
615*4bdc9457SAndroid Build Coastguard Worker &context->params);
616*4bdc9457SAndroid Build Coastguard Worker }
617*4bdc9457SAndroid Build Coastguard Worker
// Runs the HWC->CHW convolution kernel over output rows
// [output_y_start, output_y_start + output_y_slice) of one batch element.
// Input and output base pointers are advanced by the per-batch byte strides.
// All other geometry (input size, padding, channel count, output strides) is
// forwarded from the context unchanged.
void xnn_compute_conv2d_hwc2chw(
    const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y_start,
    size_t output_y_slice)
{
  const void* batch_input =
      (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
  void* batch_output =
      (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
  const size_t output_y_end = output_y_start + output_y_slice;

  context->hwc2chw_ukernel(
      context->input_height, context->input_width,
      output_y_start, output_y_end,
      batch_input, context->zero, context->packed_weights, batch_output,
      context->input_padding_top,
      context->output_channels,
      context->output_height_stride, context->output_channel_stride,
      &context->params);
}
639*4bdc9457SAndroid Build Coastguard Worker
// Runs the single-pass depthwise-convolution kernel for output row `output_y`
// of batch element `batch_index`.
//
// The indirection buffer row is selected with indirect_input_height_stride and
// is shared by all batch elements. The batch is selected by adding
// batch_index * input_batch_stride to the input offset and by advancing the
// output pointer by the batch and row strides.
void xnn_compute_dwconv_unipass(
    const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const uintptr_t row_indirection_address =
      (uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride;
  const uintptr_t row_output_address = (uintptr_t) context->output +
      batch_index * context->output_batch_stride +
      output_y * context->output_height_stride;
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;

  context->unipass_ukernel(
      context->groups,
      context->output_width,
      (const void**) row_indirection_address,
      context->packed_weights,
      (void*) row_output_address,
      context->indirect_input_width_stride,
      context->output_increment,
      batch_input_offset,
      context->zero,
      &context->params);
}
658*4bdc9457SAndroid Build Coastguard Worker
// Runs the CHW depthwise-convolution kernel on a single channel of a single
// batch element. Input and output are advanced by the channel and batch byte
// strides. The packed weights are advanced by channel * weights_channel_stride
// bytes.
void xnn_compute_dwconv2d_chw(
    const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel)
{
  const void* plane_input = (const void*) ((uintptr_t) context->input +
      batch_index * context->input_batch_stride +
      channel * context->input_channel_stride);
  const void* plane_weights = (const void*) ((uintptr_t) context->packed_weights +
      channel * context->weights_channel_stride);
  void* plane_output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride +
      channel * context->output_channel_stride);

  context->chw_ukernel(
      context->input_height, context->input_width,
      plane_input, plane_weights, context->zero, plane_output,
      context->input_padding_top,
      &context->params);
}
674*4bdc9457SAndroid Build Coastguard Worker
// Runs the single-pass argmax-pooling kernel for output row `output_y` of
// batch element `batch_index`.
//
// The kernel receives three things:
//   - the indirection-buffer row, which is shared across the batch;
//   - the batch-adjusted input offset;
//   - matching output and uint32_t index rows, each located with its own
//     batch and height strides.
void xnn_compute_argmax_pooling_unipass(
    const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;

  const void** row_indirection = (const void**) ((uintptr_t) context->indirect_input +
      output_y * context->indirect_input_height_stride);
  void* row_output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride +
      output_y * context->output_height_stride);
  uint32_t* row_index = (uint32_t*) ((uintptr_t) context->index +
      batch_index * context->index_batch_stride +
      output_y * context->index_height_stride);

  context->unipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      row_indirection,
      batch_input_offset,
      row_output,
      row_index,
      context->input_increment,
      context->output_increment);
}
693*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_argmax_pooling_multipass(const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t output_y)694*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_argmax_pooling_multipass(
695*4bdc9457SAndroid Build Coastguard Worker const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
696*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
697*4bdc9457SAndroid Build Coastguard Worker size_t output_y)
698*4bdc9457SAndroid Build Coastguard Worker {
699*4bdc9457SAndroid Build Coastguard Worker const void** indirect_input = (const void**) ((uintptr_t) context->indirect_input +
700*4bdc9457SAndroid Build Coastguard Worker output_y * context->indirect_input_height_stride);
701*4bdc9457SAndroid Build Coastguard Worker const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
702*4bdc9457SAndroid Build Coastguard Worker void* output = (void*) ((uintptr_t) context->output +
703*4bdc9457SAndroid Build Coastguard Worker batch_index * context->output_batch_stride + output_y * context->output_height_stride);
704*4bdc9457SAndroid Build Coastguard Worker uint32_t* index = (uint32_t*) ((uintptr_t) context->index +
705*4bdc9457SAndroid Build Coastguard Worker batch_index * context->index_batch_stride + output_y * context->index_height_stride);
706*4bdc9457SAndroid Build Coastguard Worker
707*4bdc9457SAndroid Build Coastguard Worker void* multipass_accumulation_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(float) + XNN_EXTRA_BYTES);
708*4bdc9457SAndroid Build Coastguard Worker void* multipass_index_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(uint32_t) + XNN_EXTRA_BYTES);
709*4bdc9457SAndroid Build Coastguard Worker
710*4bdc9457SAndroid Build Coastguard Worker context->multipass_ukernel(
711*4bdc9457SAndroid Build Coastguard Worker context->output_width, context->pooling_size, context->channels,
712*4bdc9457SAndroid Build Coastguard Worker indirect_input, input_offset, multipass_accumulation_buffer, multipass_index_buffer, output, index,
713*4bdc9457SAndroid Build Coastguard Worker context->input_increment, context->output_increment);
714*4bdc9457SAndroid Build Coastguard Worker }
715*4bdc9457SAndroid Build Coastguard Worker
// Runs the max-pooling kernel for output row `output_y` of batch element
// `batch_index`.
//
// The indirection-buffer row is shared across the batch. The batch is
// selected through the input offset and the output pointer's batch stride.
void xnn_compute_max_pooling(
    const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const uintptr_t indirection_row_address =
      (uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride;
  const uintptr_t output_row_address = (uintptr_t) context->output +
      batch_index * context->output_batch_stride +
      output_y * context->output_height_stride;
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;

  context->ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      (const void**) indirection_row_address,
      batch_input_offset,
      (void*) output_row_address,
      context->input_increment,
      context->output_increment,
      &context->params);
}
733*4bdc9457SAndroid Build Coastguard Worker
// Runs the unpooling kernel for the input pixel at (input_y, input_x).
//
// Three pointers are located from the same (y, x) coordinates, each using its
// own height and width strides:
//   - the input pixel;
//   - the corresponding uint32_t index entry;
//   - the indirect-output entry.
// pooling_size, channels and fill_value are forwarded from the context.
void xnn_compute_unpooling(
    const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t input_y,
    size_t input_x)
{
  const uintptr_t input_address = (uintptr_t) context->input +
      input_y * context->input_height_stride +
      input_x * context->input_width_stride;
  const uintptr_t index_address = (uintptr_t) context->index +
      input_y * context->index_height_stride +
      input_x * context->index_width_stride;
  const uintptr_t indirect_output_address = (uintptr_t) context->indirect_output +
      input_y * context->indirect_output_height_stride +
      input_x * context->indirect_output_width_stride;

  context->ukernel(
      context->pooling_size,
      context->channels,
      context->fill_value,
      (const void*) input_address,
      (const uint32_t*) index_address,
      (void**) indirect_output_address);
}
753*4bdc9457SAndroid Build Coastguard Worker
// Runs the single-pass average-pooling kernel for output row `output_y` of
// batch element `batch_index`.
//
// The kernel receives:
//   - the indirection-buffer row, which is shared across the batch;
//   - the batch-adjusted input offset;
//   - the context's zero buffer;
//   - the output row, located with the batch and height strides.
void xnn_compute_average_pooling_unipass(
    const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const size_t batch_input_offset = context->input_offset + batch_index * context->input_batch_stride;

  const void** row_indirection = (const void**) ((uintptr_t) context->indirect_input +
      output_y * context->indirect_input_height_stride);
  void* row_output = (void*) ((uintptr_t) context->output +
      batch_index * context->output_batch_stride +
      output_y * context->output_height_stride);

  context->unipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      row_indirection,
      batch_input_offset,
      context->zero,
      row_output,
      context->input_increment,
      context->output_increment,
      &context->params);
}
771*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_average_pooling_multipass(const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t output_y)772*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_average_pooling_multipass(
773*4bdc9457SAndroid Build Coastguard Worker const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
774*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
775*4bdc9457SAndroid Build Coastguard Worker size_t output_y)
776*4bdc9457SAndroid Build Coastguard Worker {
777*4bdc9457SAndroid Build Coastguard Worker const void** indirect_input =
778*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
779*4bdc9457SAndroid Build Coastguard Worker const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
780*4bdc9457SAndroid Build Coastguard Worker void* output = (void*) ((uintptr_t) context->output +
781*4bdc9457SAndroid Build Coastguard Worker batch_index * context->output_batch_stride + output_y * context->output_height_stride);
782*4bdc9457SAndroid Build Coastguard Worker
783*4bdc9457SAndroid Build Coastguard Worker void* multipass_buffer =
784*4bdc9457SAndroid Build Coastguard Worker XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
785*4bdc9457SAndroid Build Coastguard Worker
786*4bdc9457SAndroid Build Coastguard Worker context->multipass_ukernel(
787*4bdc9457SAndroid Build Coastguard Worker context->output_width, context->pooling_size, context->channels,
788*4bdc9457SAndroid Build Coastguard Worker indirect_input, input_offset, context->zero, multipass_buffer, output,
789*4bdc9457SAndroid Build Coastguard Worker context->input_increment, context->output_increment,
790*4bdc9457SAndroid Build Coastguard Worker &context->params);
791*4bdc9457SAndroid Build Coastguard Worker }
792*4bdc9457SAndroid Build Coastguard Worker
// Pixelwise average pooling, single-pass variant: handles output row `output_y`
// of batch element `batch_index`. Unlike plain average pooling, the kernel also
// receives a row of context->pixelwise_buffer. That row is selected by output_y
// only, with no batch component.
void xnn_compute_pixelwise_average_pooling_unipass(
    const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t output_y)
{
  const uintptr_t indirection_row =
      (uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride;
  const uintptr_t pixelwise_row =
      (uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride;
  const uintptr_t output_row = (uintptr_t) context->output
      + batch_index * context->output_batch_stride
      + output_y * context->output_height_stride;
  const size_t batch_input_offset =
      context->input_offset + batch_index * context->input_batch_stride;

  context->unipass_ukernel(
      context->output_width,
      context->pooling_size,
      context->channels,
      (const void**) indirection_row,
      batch_input_offset,
      context->zero,
      (const void*) pixelwise_row,
      (void*) output_row,
      context->input_increment,
      context->output_increment,
      &context->params);
}
812*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_pixelwise_average_pooling_multipass(const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index,size_t output_y)813*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_pixelwise_average_pooling_multipass(
814*4bdc9457SAndroid Build Coastguard Worker const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)],
815*4bdc9457SAndroid Build Coastguard Worker size_t batch_index,
816*4bdc9457SAndroid Build Coastguard Worker size_t output_y)
817*4bdc9457SAndroid Build Coastguard Worker {
818*4bdc9457SAndroid Build Coastguard Worker const void** indirect_input =
819*4bdc9457SAndroid Build Coastguard Worker (const void**) ((uintptr_t) context->indirect_input + output_y * context->indirect_input_height_stride);
820*4bdc9457SAndroid Build Coastguard Worker const size_t input_offset = context->input_offset + batch_index * context->input_batch_stride;
821*4bdc9457SAndroid Build Coastguard Worker const void* pixelwise_buffer =
822*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->pixelwise_buffer + output_y * context->pixelwise_buffer_height_stride);
823*4bdc9457SAndroid Build Coastguard Worker void* output = (void*) ((uintptr_t) context->output +
824*4bdc9457SAndroid Build Coastguard Worker batch_index * context->output_batch_stride + output_y * context->output_height_stride);
825*4bdc9457SAndroid Build Coastguard Worker
826*4bdc9457SAndroid Build Coastguard Worker void* multipass_buffer = XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
827*4bdc9457SAndroid Build Coastguard Worker
828*4bdc9457SAndroid Build Coastguard Worker context->multipass_ukernel(
829*4bdc9457SAndroid Build Coastguard Worker context->output_width, context->pooling_size, context->channels,
830*4bdc9457SAndroid Build Coastguard Worker indirect_input, input_offset, context->zero, pixelwise_buffer, multipass_buffer, output,
831*4bdc9457SAndroid Build Coastguard Worker context->input_increment, context->output_increment,
832*4bdc9457SAndroid Build Coastguard Worker &context->params);
833*4bdc9457SAndroid Build Coastguard Worker }
834*4bdc9457SAndroid Build Coastguard Worker
// Global average pooling (NWC layout), single-pass variant: reduces the whole
// input of batch element `batch_index` with one call to context->unipass_ukernel.
void xnn_compute_global_average_pooling_nwc_unipass(
    const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const uintptr_t input_address =
      (uintptr_t) context->input + batch_index * context->input_batch_stride;
  const uintptr_t output_address =
      (uintptr_t) context->output + batch_index * context->output_batch_stride;

  context->unipass_ukernel(
      context->input_elements, context->channels,
      (const void*) input_address, context->input_pixel_stride,
      context->zero,
      (void*) output_address,
      &context->params);
}
853*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_global_average_pooling_nwc_multipass(const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS (1)],size_t batch_index)854*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_global_average_pooling_nwc_multipass(
855*4bdc9457SAndroid Build Coastguard Worker const struct global_average_pooling_nwc_context context[restrict XNN_MIN_ELEMENTS(1)],
856*4bdc9457SAndroid Build Coastguard Worker size_t batch_index)
857*4bdc9457SAndroid Build Coastguard Worker {
858*4bdc9457SAndroid Build Coastguard Worker const void* input =
859*4bdc9457SAndroid Build Coastguard Worker (const void*) ((uintptr_t) context->input + batch_index * context->input_batch_stride);
860*4bdc9457SAndroid Build Coastguard Worker void* output =
861*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride);
862*4bdc9457SAndroid Build Coastguard Worker
863*4bdc9457SAndroid Build Coastguard Worker void* multipass_buffer =
864*4bdc9457SAndroid Build Coastguard Worker XNN_SIMD_ALLOCA(context->channels * sizeof(int32_t) + XNN_EXTRA_BYTES * sizeof(int32_t) / sizeof(uint8_t));
865*4bdc9457SAndroid Build Coastguard Worker
866*4bdc9457SAndroid Build Coastguard Worker context->multipass_ukernel(
867*4bdc9457SAndroid Build Coastguard Worker context->input_elements,
868*4bdc9457SAndroid Build Coastguard Worker context->channels,
869*4bdc9457SAndroid Build Coastguard Worker input,
870*4bdc9457SAndroid Build Coastguard Worker context->input_pixel_stride,
871*4bdc9457SAndroid Build Coastguard Worker context->zero,
872*4bdc9457SAndroid Build Coastguard Worker multipass_buffer,
873*4bdc9457SAndroid Build Coastguard Worker output,
874*4bdc9457SAndroid Build Coastguard Worker &context->params);
875*4bdc9457SAndroid Build Coastguard Worker }
876*4bdc9457SAndroid Build Coastguard Worker
// Global average pooling (NCW layout): processes `channels_slice` channels
// starting at `channels_start` of batch element `batch_index` with one call to
// context->ukernel.
void xnn_compute_global_average_pooling_ncw(
    const struct global_average_pooling_ncw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channels_start,
    size_t channels_slice)
{
  const size_t input_displacement =
      batch_index * context->input_batch_stride + channels_start * context->input_channel_stride;
  const size_t output_displacement =
      batch_index * context->output_batch_stride + channels_start * context->output_channel_stride;

  context->ukernel(
      context->input_elements,
      channels_slice,
      (const void*) ((uintptr_t) context->input + input_displacement),
      (void*) ((uintptr_t) context->output + output_displacement),
      &context->params);
}
895*4bdc9457SAndroid Build Coastguard Worker
// Bilinear resize (NHWC): produces `pixel_range` output pixels starting at
// `pixel_start` of batch element `batch_index`.
//
// The indirection buffer holds 4 entries per output pixel. Packed weights are
// indexed by pixel_start << log2_wsize, and the kernel's output increment is the
// pixel stride minus context->scaled_channels.
void xnn_compute_resize_bilinear(
    const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t pixel_start,
    size_t pixel_range)
{
  const size_t output_displacement =
      pixel_start * context->output_pixel_stride + batch_index * context->output_batch_stride;
  void* output = (void*) ((uintptr_t) context->output + output_displacement);
  const void* weights =
      (const void*) ((uintptr_t) context->packed_weights + (pixel_start << context->log2_wsize));
  const size_t output_increment = context->output_pixel_stride - context->scaled_channels;

  context->ukernel(
      pixel_range,
      context->scaled_channels,
      &context->indirect_input[pixel_start * 4],
      context->input_offset + batch_index * context->input_batch_stride,
      weights,
      output,
      output_increment);
}
914*4bdc9457SAndroid Build Coastguard Worker
// Bilinear resize (CHW layout): processes `channel_range` channels starting at
// `channel_start` of batch element `batch_index`. The indirection buffer and
// packed weights are shared by all channels. Per-channel and per-batch
// displacement is carried in the input offset and the output pointer.
void xnn_compute_resize_bilinear_chw(
    const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t channel_start,
    size_t channel_range)
{
  const size_t channel_stride = context->input_channel_stride;
  const size_t input_offset =
      context->input_offset + batch_index * context->input_batch_stride + channel_start * channel_stride;
  const uintptr_t output_address = (uintptr_t) context->output
      + channel_start * context->output_channel_stride
      + batch_index * context->output_batch_stride;

  context->ukernel(
      context->output_pixels,
      channel_range,
      context->indirect_input,
      input_offset,
      context->packed_weights,
      (void*) output_address,
      channel_stride);
}
934*4bdc9457SAndroid Build Coastguard Worker
// PReLU: processes `batch_range` rows starting at `batch_start` in a single
// context->ukernel call. Rows are located through the byte strides x_stride and
// y_stride, which are also forwarded to the kernel.
void xnn_compute_prelu(
    const struct prelu_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_range)
{
  const size_t in_stride = context->x_stride;
  const size_t out_stride = context->y_stride;

  context->ukernel(
      batch_range,
      context->n,
      (const void*) ((uintptr_t) context->x + batch_start * in_stride), in_stride,
      context->w,
      (void*) ((uintptr_t) context->y + batch_start * out_stride), out_stride);
}
947*4bdc9457SAndroid Build Coastguard Worker
xnn_compute_pad_5d(const struct pad_context context[restrict XNN_MIN_ELEMENTS (1)],size_t i,size_t j,size_t k,size_t l,size_t m)948*4bdc9457SAndroid Build Coastguard Worker void xnn_compute_pad_5d(
949*4bdc9457SAndroid Build Coastguard Worker const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)],
950*4bdc9457SAndroid Build Coastguard Worker size_t i, size_t j, size_t k, size_t l, size_t m)
951*4bdc9457SAndroid Build Coastguard Worker {
952*4bdc9457SAndroid Build Coastguard Worker const void* input = (const void*) ((uintptr_t) context->input +
953*4bdc9457SAndroid Build Coastguard Worker i * context->input_stride[4] + j * context->input_stride[3] + k * context->input_stride[2] + l * context->input_stride[1] + m * context->input_stride[0]);
954*4bdc9457SAndroid Build Coastguard Worker void* output = (void*) ((uintptr_t) context->output +
955*4bdc9457SAndroid Build Coastguard Worker i * context->output_stride[4] + j * context->output_stride[3] + k * context->output_stride[2] + l * context->output_stride[1] + m * context->output_stride[0]);
956*4bdc9457SAndroid Build Coastguard Worker
957*4bdc9457SAndroid Build Coastguard Worker const size_t i_padding = context->pre_paddings[5];
958*4bdc9457SAndroid Build Coastguard Worker const size_t j_padding = context->pre_paddings[4];
959*4bdc9457SAndroid Build Coastguard Worker const size_t k_padding = context->pre_paddings[3];
960*4bdc9457SAndroid Build Coastguard Worker const size_t l_padding = context->pre_paddings[2];
961*4bdc9457SAndroid Build Coastguard Worker const size_t m_padding = context->pre_paddings[1];
962*4bdc9457SAndroid Build Coastguard Worker
963*4bdc9457SAndroid Build Coastguard Worker const size_t i_size = context->input_size[5];
964*4bdc9457SAndroid Build Coastguard Worker const size_t j_size = context->input_size[4];
965*4bdc9457SAndroid Build Coastguard Worker const size_t k_size = context->input_size[3];
966*4bdc9457SAndroid Build Coastguard Worker const size_t l_size = context->input_size[2];
967*4bdc9457SAndroid Build Coastguard Worker const size_t m_size = context->input_size[1];
968*4bdc9457SAndroid Build Coastguard Worker
969*4bdc9457SAndroid Build Coastguard Worker if XNN_LIKELY(i - i_padding < i_size && j - j_padding < j_size && k - k_padding < k_size &&
970*4bdc9457SAndroid Build Coastguard Worker l - l_padding < l_size && m - m_padding < m_size)
971*4bdc9457SAndroid Build Coastguard Worker {
972*4bdc9457SAndroid Build Coastguard Worker context->pad_ukernel(
973*4bdc9457SAndroid Build Coastguard Worker 1 /* rows */,
974*4bdc9457SAndroid Build Coastguard Worker context->input_size[0], context->pre_paddings[0], context->post_paddings[0],
975*4bdc9457SAndroid Build Coastguard Worker input, 0 /* input stride */, output, 0 /* output stride */,
976*4bdc9457SAndroid Build Coastguard Worker context->padding_value);
977*4bdc9457SAndroid Build Coastguard Worker } else {
978*4bdc9457SAndroid Build Coastguard Worker context->fill_ukernel(1 /* rows */, context->output_size[0], output, 0 /* output stride */, context->padding_value);
979*4bdc9457SAndroid Build Coastguard Worker }
980*4bdc9457SAndroid Build Coastguard Worker }
981*4bdc9457SAndroid Build Coastguard Worker
// Binary elementwise operation, 1-D outer iteration: index i is applied to
// stride slot 4 of a, b and y. context->elements is passed to the kernel as its
// length argument.
void xnn_compute_elementwise_binary_1d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i)
{
  const size_t a_offset = i * context->a_stride[4];
  const size_t b_offset = i * context->b_stride[4];
  const size_t y_offset = i * context->y_stride[4];

  context->ukernel(
      context->elements,
      (const void*) ((uintptr_t) context->a + a_offset),
      (const void*) ((uintptr_t) context->b + b_offset),
      (void*) ((uintptr_t) context->y + y_offset),
      &context->params);
}
991*4bdc9457SAndroid Build Coastguard Worker
// Binary elementwise operation, 2-D outer iteration: indices i and j are applied
// to stride slots 3 and 4 of a, b and y. context->elements is passed to the
// kernel as its length argument.
void xnn_compute_elementwise_binary_2d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j)
{
  const size_t a_offset = i * context->a_stride[3] + j * context->a_stride[4];
  const size_t b_offset = i * context->b_stride[3] + j * context->b_stride[4];
  const size_t y_offset = i * context->y_stride[3] + j * context->y_stride[4];

  context->ukernel(
      context->elements,
      (const void*) ((uintptr_t) context->a + a_offset),
      (const void*) ((uintptr_t) context->b + b_offset),
      (void*) ((uintptr_t) context->y + y_offset),
      &context->params);
}
1001*4bdc9457SAndroid Build Coastguard Worker
// Binary elementwise operation, 3-D outer iteration: indices i, j, k are applied
// to stride slots 2, 3, 4 of a, b and y. context->elements is passed to the
// kernel as its length argument.
void xnn_compute_elementwise_binary_3d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k)
{
  const size_t coord[3] = { i, j, k };
  size_t a_offset = 0;
  size_t b_offset = 0;
  size_t y_offset = 0;
  for (size_t d = 0; d < 3; d++) {
    a_offset += coord[d] * context->a_stride[d + 2];
    b_offset += coord[d] * context->b_stride[d + 2];
    y_offset += coord[d] * context->y_stride[d + 2];
  }

  context->ukernel(
      context->elements,
      (const void*) ((uintptr_t) context->a + a_offset),
      (const void*) ((uintptr_t) context->b + b_offset),
      (void*) ((uintptr_t) context->y + y_offset),
      &context->params);
}
1014*4bdc9457SAndroid Build Coastguard Worker
// Binary elementwise operation, 4-D outer iteration: indices i, j, k, l are
// applied to stride slots 1..4 of a, b and y. context->elements is passed to the
// kernel as its length argument.
void xnn_compute_elementwise_binary_4d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l)
{
  const size_t coord[4] = { i, j, k, l };
  size_t a_offset = 0;
  size_t b_offset = 0;
  size_t y_offset = 0;
  for (size_t d = 0; d < 4; d++) {
    a_offset += coord[d] * context->a_stride[d + 1];
    b_offset += coord[d] * context->b_stride[d + 1];
    y_offset += coord[d] * context->y_stride[d + 1];
  }

  context->ukernel(
      context->elements,
      (const void*) ((uintptr_t) context->a + a_offset),
      (const void*) ((uintptr_t) context->b + b_offset),
      (void*) ((uintptr_t) context->y + y_offset),
      &context->params);
}
1027*4bdc9457SAndroid Build Coastguard Worker
// Binary elementwise operation, 5-D outer iteration: indices i, j, k, l, m are
// applied to stride slots 0..4 of a, b and y. context->elements is passed to the
// kernel as its length argument.
void xnn_compute_elementwise_binary_5d(
    const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t i, size_t j, size_t k, size_t l, size_t m)
{
  const size_t coord[5] = { i, j, k, l, m };
  size_t a_offset = 0;
  size_t b_offset = 0;
  size_t y_offset = 0;
  for (size_t d = 0; d < 5; d++) {
    a_offset += coord[d] * context->a_stride[d];
    b_offset += coord[d] * context->b_stride[d];
    y_offset += coord[d] * context->y_stride[d];
  }

  context->ukernel(
      context->elements,
      (const void*) ((uintptr_t) context->a + a_offset),
      (const void*) ((uintptr_t) context->b + b_offset),
      (void*) ((uintptr_t) context->y + y_offset),
      &context->params);
}
1040*4bdc9457SAndroid Build Coastguard Worker
// Channel shuffle with a fixed-group-count kernel: shuffles row `index`, located
// through the byte strides x_stride / y_stride.
void xnn_compute_channel_shuffle_fixed(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const uintptr_t src = (uintptr_t) context->x + index * context->x_stride;
  const uintptr_t dst = (uintptr_t) context->y + index * context->y_stride;
  context->fixed_ukernel(context->n, (const void*) src, (void*) dst);
}
1050*4bdc9457SAndroid Build Coastguard Worker
// Channel shuffle with a variable-group-count kernel: shuffles row `index`,
// passing both context->n and context->m to the kernel.
void xnn_compute_channel_shuffle_variable(
    const struct channel_shuffle_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t index)
{
  const uintptr_t src = (uintptr_t) context->x + index * context->x_stride;
  const uintptr_t dst = (uintptr_t) context->y + index * context->y_stride;
  context->variable_ukernel(context->n, context->m, (const void*) src, (void*) dst);
}
1060*4bdc9457SAndroid Build Coastguard Worker
// Table lookup over one strided row: applies table context->t to context->n
// inputs of row `batch_index`.
void xnn_compute_lut_strided(
    const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const size_t src_offset = batch_index * context->x_stride;
  const size_t dst_offset = batch_index * context->y_stride;
  context->ukernel(
      context->n,
      (const void*) ((uintptr_t) context->x + src_offset),
      (void*) ((uintptr_t) context->y + dst_offset),
      context->t);
}
1070*4bdc9457SAndroid Build Coastguard Worker
// Table lookup over a contiguous range: applies context->t to `size` inputs
// starting at `offset`. The same offset is used for x and y.
void xnn_compute_lut_contiguous(
    const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  context->ukernel(
      size,
      (const void*) ((uintptr_t) context->x + offset),
      (void*) ((uintptr_t) context->y + offset),
      context->t);
}
1081*4bdc9457SAndroid Build Coastguard Worker
// Unary elementwise operation over strided rows: calls context->ukernel once per
// row for `batch_range` rows starting at `batch_index`, advancing x and y by
// their byte strides between calls.
//
// An empty range (batch_range == 0) is a no-op.
void xnn_compute_univector_strided(
    const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index,
    size_t batch_range)
{
  const size_t x_stride = context->x_stride;
  const size_t y_stride = context->y_stride;

  const void* x = (const void*) ((uintptr_t) context->x + x_stride * batch_index);
  void* y = (void*) ((uintptr_t) context->y + y_stride * batch_index);
  // Test the counter before each call. A do/while with --batch_range would run
  // the kernel once for batch_range == 0 and then wrap the counter to SIZE_MAX,
  // walking far out of bounds.
  for (; batch_range != 0; batch_range--) {
    context->ukernel(context->n, x, y, &context->params);
    x = (const void*) ((uintptr_t) x + x_stride);
    y = (void*) ((uintptr_t) y + y_stride);
  }
}
1098*4bdc9457SAndroid Build Coastguard Worker
// Unary elementwise operation over a contiguous range of `size` starting at
// byte offset `offset` into x.
//
// x and y element sizes are given separately (log2_xsize / log2_ysize), so the
// x byte offset is converted to an element index and rescaled for y.
void xnn_compute_univector_contiguous(
    const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t offset,
    size_t size)
{
  const size_t element_index = offset >> context->log2_xsize;
  const size_t y_offset = element_index << context->log2_ysize;

  context->ukernel(
      size,
      (const void*) ((uintptr_t) context->x + offset),
      (void*) ((uintptr_t) context->y + y_offset),
      &context->params);
}
1110*4bdc9457SAndroid Build Coastguard Worker
// Computes softmax for row `batch_index` of an 8-bit input using a lookup table.
//
// First the row maximum is found with rmax_ukernel. The lookup table
// context->t (read as uint32_t entries) is then offset by 255 - max, computed
// as max ^ 0xFF. That shifted table is passed to lut_norm_ukernel, which writes
// the output row.
void xnn_compute_u8_softmax(
    const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  const size_t row_size = context->n;
  const uint8_t* input_row = (const uint8_t*) ((uintptr_t) context->x + batch_index * context->x_stride);
  uint8_t* output_row = (uint8_t*) ((uintptr_t) context->y + batch_index * context->y_stride);

  uint8_t row_max = 0;
  context->rmax_ukernel(row_size, input_row, &row_max);

  // For an 8-bit value v, v ^ 0xFF == 255 - v.
  const size_t table_shift = (size_t) (row_max ^ UINT8_MAX);
  const uint32_t* shifted_table = (const uint32_t*) context->t + table_shift;
  context->lut_norm_ukernel(row_size, input_row, shifted_table, output_row);
}
1125*4bdc9457SAndroid Build Coastguard Worker
// Computes floating-point softmax for row `batch_index` in three passes:
//   1. reduce-max over the input row,
//   2. store exp(x - max) into the output row while accumulating its sum,
//   3. take the reciprocal of the sum and scale the output row in place.
// Scalars exchanged with the micro-kernels go through a union that holds either
// a float or a 16-bit half value. NOTE(review): which member is used presumably
// depends on the installed kernels (FP32 vs FP16); confirm at setup.
void xnn_compute_floating_point_softmax(
    const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_index)
{
  union softmax_scalar {
    float as_float;
    uint16_t as_half;
  };

  const size_t row_size = context->n;
  const void* input_row = (const void*) ((uintptr_t) context->x + batch_index * context->x_stride);
  void* output_row = (void*) ((uintptr_t) context->y + batch_index * context->y_stride);

  // Pass 1: maximum element of the row.
  union softmax_scalar row_max;
  context->rmax_ukernel(row_size, input_row, &row_max);

  // Pass 2: output_row = exp(input_row - row_max); exp_sum receives the total.
  union softmax_scalar exp_sum;
  context->raddstoreexpminusmax_ukernel(
      row_size, input_row, &row_max, output_row, &exp_sum, &context->expminus_params);

  // Pass 3: scale the stored exponentials in place by the reciprocal of the sum.
  union softmax_scalar inv_sum;
  context->compute_reciprocal(&exp_sum, &inv_sum);
  context->vmulc_ukernel(row_size, output_row, &inv_sum, output_row, &context->minmax_params);
}
1156*4bdc9457SAndroid Build Coastguard Worker
// Invokes the vmulcaddc micro-kernel on `batch_size` rows, starting at row
// `batch_start`. Row strides are given in bytes by context->x_stride and
// context->y_stride. The kernel also receives the row length context->n and
// the weights context->w.
// NOTE(review): by name, presumably y = x * c + d with per-channel c/d packed
// in w; confirm against the kernel.
void xnn_compute_vmulcaddc(
    const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)],
    size_t batch_start,
    size_t batch_size)
{
  const size_t input_stride = context->x_stride;
  const size_t output_stride = context->y_stride;

  const uintptr_t input_rows = (uintptr_t) context->x + input_stride * batch_start;
  const uintptr_t output_rows = (uintptr_t) context->y + output_stride * batch_start;

  context->ukernel(
      batch_size, context->n,
      (const void*) input_rows, input_stride,
      context->w,
      (void*) output_rows, output_stride,
      &context->params);
}
1176*4bdc9457SAndroid Build Coastguard Worker
1177*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
// Grouped GEMM tile for the heterogeneous (multi-uarch) path.
//
// Computes an mr_block_size x nr_block_size output tile at
// (mr_block_start, nr_block_start) for group `group_index`, calling the
// micro-kernel entry context->ukernel.function[uarch_index].
// Per-group displacement:
//   - A is offset by group_index * k_scaled within each row,
//   - packed weights by group_index * wg_stride,
//   - C by group_index * cg_stride.
void xnn_compute_hmp_grouped_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t k_scaled = context->k_scaled;
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  const uintptr_t a_tile =
      (uintptr_t) context->a + mr_block_start * a_stride + group_index * k_scaled;
  const uintptr_t w_tile =
      (uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->wg_stride;
  // Column offset in C: nr_block_start elements scaled by 2^log2_csize bytes.
  const uintptr_t c_tile =
      (uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize) + group_index * context->cg_stride;

  context->ukernel.function[uarch_index](
      mr_block_size, nr_block_size, k_scaled,
      (const void*) a_tile, a_stride,
      (const void*) w_tile,
      (void*) c_tile, cm_stride, context->cn_stride,
      &context->params);
}
1203*4bdc9457SAndroid Build Coastguard Worker
// Non-grouped GEMM tile for the heterogeneous (multi-uarch) path.
//
// Computes an mr_block_size x nr_block_size output tile at
// (mr_block_start, nr_block_start), calling the micro-kernel entry
// context->ukernel.function[uarch_index]. Parameters are passed as the
// pointer context->fused_params rather than &context->params.
void xnn_compute_hmp_gemm(
    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t a_stride = context->a_stride;
  const size_t cm_stride = context->cm_stride;

  const uintptr_t a_tile = (uintptr_t) context->a + mr_block_start * a_stride;
  const uintptr_t w_tile = (uintptr_t) context->packed_w + nr_block_start * context->w_stride;
  // Column offset in C: nr_block_start elements scaled by 2^log2_csize bytes.
  const uintptr_t c_tile =
      (uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize);

  context->ukernel.function[uarch_index](
      mr_block_size, nr_block_size, context->k_scaled,
      (const void*) a_tile, a_stride,
      (const void*) w_tile,
      (void*) c_tile, cm_stride, context->cn_stride,
      context->fused_params);
}
1227*4bdc9457SAndroid Build Coastguard Worker
// Grouped, batched indirect-GEMM (IGEMM) tile for the heterogeneous path.
//
// Computes an mr_block_size x nr_block_size output tile at
// (mr_block_start, nr_block_start) for batch `batch_index` and group
// `group_index`, calling context->ukernel.function[uarch_index].
// Input rows are reached through the indirection buffer context->indirect_a,
// which holds `ks` pointers per output row. The group/batch displacement of the
// input is not applied here; it is passed to the kernel as the a_offset
// argument.
void xnn_compute_hmp_grouped_batch_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t cm_stride = context->cm_stride;

  const uintptr_t indirection =
      (uintptr_t) context->indirect_a + mr_block_start * context->ks * sizeof(void*);
  const uintptr_t w_tile =
      (uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride;
  const uintptr_t c_tile =
      (uintptr_t) context->c + group_index * context->gc_stride + batch_index * context->bc_stride
      + mr_block_start * cm_stride + (nr_block_start << context->log2_csize);

  context->ukernel.function[uarch_index](
      mr_block_size, nr_block_size,
      context->kc, context->ks_scaled,
      (const void**) indirection,
      (const void*) w_tile,
      (void*) c_tile, cm_stride, context->cn_stride,
      context->a_offset + group_index * context->ga_stride + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
1255*4bdc9457SAndroid Build Coastguard Worker
// Grouped indirect-GEMM (IGEMM) tile for the heterogeneous path.
//
// Computes an mr_block_size x nr_block_size output tile at
// (mr_block_start, nr_block_start) for group `group_index`, calling
// context->ukernel.function[uarch_index]. The indirection buffer holds `ks`
// pointers per output row. The per-group input displacement is passed to the
// kernel as a_offset.
void xnn_compute_hmp_grouped_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t group_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t cm_stride = context->cm_stride;

  const uintptr_t indirection =
      (uintptr_t) context->indirect_a + mr_block_start * context->ks * sizeof(void*);
  const uintptr_t w_tile =
      (uintptr_t) context->packed_w + nr_block_start * context->w_stride + group_index * context->gw_stride;
  const uintptr_t c_tile =
      (uintptr_t) context->c + group_index * context->gc_stride
      + mr_block_start * cm_stride + (nr_block_start << context->log2_csize);

  context->ukernel.function[uarch_index](
      mr_block_size, nr_block_size,
      context->kc, context->ks_scaled,
      (const void**) indirection,
      (const void*) w_tile,
      (void*) c_tile, cm_stride, context->cn_stride,
      context->a_offset + group_index * context->ga_stride,
      context->zero,
      &context->params);
}
1282*4bdc9457SAndroid Build Coastguard Worker
// Batched (non-grouped) indirect-GEMM (IGEMM) tile for the heterogeneous path.
//
// Computes an mr_block_size x nr_block_size output tile at
// (mr_block_start, nr_block_start) for batch `batch_index`, calling
// context->ukernel.function[uarch_index]. The indirection buffer holds `ks`
// pointers per output row. The per-batch input displacement is passed to the
// kernel as a_offset.
void xnn_compute_batch_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t batch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t cm_stride = context->cm_stride;

  const uintptr_t indirection =
      (uintptr_t) context->indirect_a + mr_block_start * context->ks * sizeof(void*);
  const uintptr_t w_tile = (uintptr_t) context->packed_w + nr_block_start * context->w_stride;
  const uintptr_t c_tile =
      (uintptr_t) context->c + batch_index * context->bc_stride
      + mr_block_start * cm_stride + (nr_block_start << context->log2_csize);

  context->ukernel.function[uarch_index](
      mr_block_size, nr_block_size,
      context->kc, context->ks_scaled,
      (const void**) indirection,
      (const void*) w_tile,
      (void*) c_tile, cm_stride, context->cn_stride,
      context->a_offset + batch_index * context->ba_stride,
      context->zero,
      &context->params);
}
1309*4bdc9457SAndroid Build Coastguard Worker
// Indirect-GEMM (IGEMM) tile for the heterogeneous path, with no group or
// batch dimension.
//
// Computes an mr_block_size x nr_block_size output tile at
// (mr_block_start, nr_block_start), calling
// context->ukernel.function[uarch_index]. The indirection buffer holds `ks`
// pointers per output row, and context->a_offset is forwarded unchanged.
void xnn_compute_hmp_igemm(
    const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
    uint32_t uarch_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t mr_block_size,
    size_t nr_block_size)
{
  const size_t cm_stride = context->cm_stride;

  const uintptr_t indirection =
      (uintptr_t) context->indirect_a + mr_block_start * context->ks * sizeof(void*);
  const uintptr_t w_tile = (uintptr_t) context->packed_w + nr_block_start * context->w_stride;
  const uintptr_t c_tile =
      (uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize);

  context->ukernel.function[uarch_index](
      mr_block_size, nr_block_size,
      context->kc, context->ks_scaled,
      (const void**) indirection,
      (const void*) w_tile,
      (void*) c_tile, cm_stride, context->cn_stride,
      context->a_offset,
      context->zero,
      &context->params);
}
1335*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_MAX_UARCH_TYPES > 1
1336*4bdc9457SAndroid Build Coastguard Worker
xnn_run_operator(xnn_operator_t op,pthreadpool_t threadpool)1337*4bdc9457SAndroid Build Coastguard Worker enum xnn_status xnn_run_operator(xnn_operator_t op, pthreadpool_t threadpool)
1338*4bdc9457SAndroid Build Coastguard Worker {
1339*4bdc9457SAndroid Build Coastguard Worker if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
1340*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to run operator: XNNPACK is not initialized");
1341*4bdc9457SAndroid Build Coastguard Worker return xnn_status_uninitialized;
1342*4bdc9457SAndroid Build Coastguard Worker }
1343*4bdc9457SAndroid Build Coastguard Worker switch (op->state) {
1344*4bdc9457SAndroid Build Coastguard Worker case xnn_run_state_invalid:
1345*4bdc9457SAndroid Build Coastguard Worker xnn_log_error("failed to run operator: operator was not successfully setup");
1346*4bdc9457SAndroid Build Coastguard Worker return xnn_status_invalid_state;
1347*4bdc9457SAndroid Build Coastguard Worker case xnn_run_state_ready:
1348*4bdc9457SAndroid Build Coastguard Worker break;
1349*4bdc9457SAndroid Build Coastguard Worker case xnn_run_state_skip:
1350*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1351*4bdc9457SAndroid Build Coastguard Worker }
1352*4bdc9457SAndroid Build Coastguard Worker
1353*4bdc9457SAndroid Build Coastguard Worker uint32_t flags = PTHREADPOOL_FLAG_DISABLE_DENORMALS;
1354*4bdc9457SAndroid Build Coastguard Worker if (op->flags & XNN_FLAG_YIELD_WORKERS) {
1355*4bdc9457SAndroid Build Coastguard Worker flags |= PTHREADPOOL_FLAG_YIELD_WORKERS;
1356*4bdc9457SAndroid Build Coastguard Worker }
1357*4bdc9457SAndroid Build Coastguard Worker switch (op->compute.type) {
1358*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_invalid:
1359*4bdc9457SAndroid Build Coastguard Worker break;
1360*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_1d:
1361*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1362*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_1d(
1363*4bdc9457SAndroid Build Coastguard Worker threadpool,
1364*4bdc9457SAndroid Build Coastguard Worker op->compute.task_1d,
1365*4bdc9457SAndroid Build Coastguard Worker &op->context,
1366*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0],
1367*4bdc9457SAndroid Build Coastguard Worker flags);
1368*4bdc9457SAndroid Build Coastguard Worker break;
1369*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_1d_tile_1d:
1370*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1371*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1372*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_1d_tile_1d(
1373*4bdc9457SAndroid Build Coastguard Worker threadpool,
1374*4bdc9457SAndroid Build Coastguard Worker op->compute.task_1d_tile_1d,
1375*4bdc9457SAndroid Build Coastguard Worker &op->context,
1376*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0],
1377*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0],
1378*4bdc9457SAndroid Build Coastguard Worker flags);
1379*4bdc9457SAndroid Build Coastguard Worker break;
1380*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_2d:
1381*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1382*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1383*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_2d(
1384*4bdc9457SAndroid Build Coastguard Worker threadpool,
1385*4bdc9457SAndroid Build Coastguard Worker op->compute.task_2d,
1386*4bdc9457SAndroid Build Coastguard Worker &op->context,
1387*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1],
1388*4bdc9457SAndroid Build Coastguard Worker flags);
1389*4bdc9457SAndroid Build Coastguard Worker break;
1390*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_2d_tile_1d:
1391*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1392*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1393*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1394*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_2d_tile_1d(
1395*4bdc9457SAndroid Build Coastguard Worker threadpool,
1396*4bdc9457SAndroid Build Coastguard Worker op->compute.task_2d_tile_1d,
1397*4bdc9457SAndroid Build Coastguard Worker &op->context,
1398*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1],
1399*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0],
1400*4bdc9457SAndroid Build Coastguard Worker flags);
1401*4bdc9457SAndroid Build Coastguard Worker break;
1402*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_2d_tile_2d:
1403*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1404*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1405*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1406*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1407*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_2d_tile_2d(
1408*4bdc9457SAndroid Build Coastguard Worker threadpool,
1409*4bdc9457SAndroid Build Coastguard Worker op->compute.task_2d_tile_2d,
1410*4bdc9457SAndroid Build Coastguard Worker &op->context,
1411*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1],
1412*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1413*4bdc9457SAndroid Build Coastguard Worker flags);
1414*4bdc9457SAndroid Build Coastguard Worker break;
1415*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_3d:
1416*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1417*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1418*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1419*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_3d(
1420*4bdc9457SAndroid Build Coastguard Worker threadpool,
1421*4bdc9457SAndroid Build Coastguard Worker op->compute.task_3d,
1422*4bdc9457SAndroid Build Coastguard Worker &op->context,
1423*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2],
1424*4bdc9457SAndroid Build Coastguard Worker flags);
1425*4bdc9457SAndroid Build Coastguard Worker break;
1426*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_3d_tile_2d:
1427*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1428*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1429*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1430*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1431*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1432*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_3d_tile_2d(
1433*4bdc9457SAndroid Build Coastguard Worker threadpool,
1434*4bdc9457SAndroid Build Coastguard Worker op->compute.task_3d_tile_2d,
1435*4bdc9457SAndroid Build Coastguard Worker &op->context,
1436*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2],
1437*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1438*4bdc9457SAndroid Build Coastguard Worker flags);
1439*4bdc9457SAndroid Build Coastguard Worker break;
1440*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_4d:
1441*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1442*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1443*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1444*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[3] != 0);
1445*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_4d(
1446*4bdc9457SAndroid Build Coastguard Worker threadpool,
1447*4bdc9457SAndroid Build Coastguard Worker op->compute.task_4d,
1448*4bdc9457SAndroid Build Coastguard Worker &op->context,
1449*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
1450*4bdc9457SAndroid Build Coastguard Worker flags);
1451*4bdc9457SAndroid Build Coastguard Worker break;
1452*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_4d_tile_2d:
1453*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1454*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1455*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1456*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[3] != 0);
1457*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1458*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1459*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_4d_tile_2d(
1460*4bdc9457SAndroid Build Coastguard Worker threadpool,
1461*4bdc9457SAndroid Build Coastguard Worker op->compute.task_4d_tile_2d,
1462*4bdc9457SAndroid Build Coastguard Worker &op->context,
1463*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
1464*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1465*4bdc9457SAndroid Build Coastguard Worker flags);
1466*4bdc9457SAndroid Build Coastguard Worker break;
1467*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_5d:
1468*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1469*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1470*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1471*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[3] != 0);
1472*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[4] != 0);
1473*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_5d(
1474*4bdc9457SAndroid Build Coastguard Worker threadpool,
1475*4bdc9457SAndroid Build Coastguard Worker op->compute.task_5d,
1476*4bdc9457SAndroid Build Coastguard Worker &op->context,
1477*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
1478*4bdc9457SAndroid Build Coastguard Worker flags);
1479*4bdc9457SAndroid Build Coastguard Worker break;
1480*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_5d_tile_2d:
1481*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1482*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1483*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1484*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[3] != 0);
1485*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[4] != 0);
1486*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1487*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1488*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_5d_tile_2d(
1489*4bdc9457SAndroid Build Coastguard Worker threadpool,
1490*4bdc9457SAndroid Build Coastguard Worker op->compute.task_5d_tile_2d,
1491*4bdc9457SAndroid Build Coastguard Worker &op->context,
1492*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4],
1493*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1494*4bdc9457SAndroid Build Coastguard Worker flags);
1495*4bdc9457SAndroid Build Coastguard Worker break;
1496*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_6d_tile_2d:
1497*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1498*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1499*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1500*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[3] != 0);
1501*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[4] != 0);
1502*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[5] != 0);
1503*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1504*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1505*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_6d_tile_2d(
1506*4bdc9457SAndroid Build Coastguard Worker threadpool,
1507*4bdc9457SAndroid Build Coastguard Worker op->compute.task_6d_tile_2d,
1508*4bdc9457SAndroid Build Coastguard Worker &op->context,
1509*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3], op->compute.range[4], op->compute.range[5],
1510*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1511*4bdc9457SAndroid Build Coastguard Worker flags);
1512*4bdc9457SAndroid Build Coastguard Worker break;
1513*4bdc9457SAndroid Build Coastguard Worker #if XNN_MAX_UARCH_TYPES > 1
1514*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_2d_tile_2d_with_uarch:
1515*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1516*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1517*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1518*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1519*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_2d_tile_2d_with_uarch(
1520*4bdc9457SAndroid Build Coastguard Worker threadpool,
1521*4bdc9457SAndroid Build Coastguard Worker op->compute.task_2d_tile_2d_with_id,
1522*4bdc9457SAndroid Build Coastguard Worker &op->context,
1523*4bdc9457SAndroid Build Coastguard Worker 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
1524*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1],
1525*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1526*4bdc9457SAndroid Build Coastguard Worker flags);
1527*4bdc9457SAndroid Build Coastguard Worker break;
1528*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_3d_tile_2d_with_uarch:
1529*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1530*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1531*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1532*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1533*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1534*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_3d_tile_2d_with_uarch(
1535*4bdc9457SAndroid Build Coastguard Worker threadpool,
1536*4bdc9457SAndroid Build Coastguard Worker op->compute.task_3d_tile_2d_with_id,
1537*4bdc9457SAndroid Build Coastguard Worker &op->context,
1538*4bdc9457SAndroid Build Coastguard Worker 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
1539*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2],
1540*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1541*4bdc9457SAndroid Build Coastguard Worker flags);
1542*4bdc9457SAndroid Build Coastguard Worker break;
1543*4bdc9457SAndroid Build Coastguard Worker case xnn_parallelization_type_4d_tile_2d_with_uarch:
1544*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[0] != 0);
1545*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[1] != 0);
1546*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[2] != 0);
1547*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.range[3] != 0);
1548*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[0] != 0);
1549*4bdc9457SAndroid Build Coastguard Worker assert(op->compute.tile[1] != 0);
1550*4bdc9457SAndroid Build Coastguard Worker pthreadpool_parallelize_4d_tile_2d_with_uarch(
1551*4bdc9457SAndroid Build Coastguard Worker threadpool,
1552*4bdc9457SAndroid Build Coastguard Worker op->compute.task_4d_tile_2d_with_id,
1553*4bdc9457SAndroid Build Coastguard Worker &op->context,
1554*4bdc9457SAndroid Build Coastguard Worker 0 /* default uarch index */, XNN_MAX_UARCH_TYPES - 1,
1555*4bdc9457SAndroid Build Coastguard Worker op->compute.range[0], op->compute.range[1], op->compute.range[2], op->compute.range[3],
1556*4bdc9457SAndroid Build Coastguard Worker op->compute.tile[0], op->compute.tile[1],
1557*4bdc9457SAndroid Build Coastguard Worker flags);
1558*4bdc9457SAndroid Build Coastguard Worker break;
1559*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_MAX_UARCH_TYPES > 1
1560*4bdc9457SAndroid Build Coastguard Worker default:
1561*4bdc9457SAndroid Build Coastguard Worker XNN_UNREACHABLE;
1562*4bdc9457SAndroid Build Coastguard Worker }
1563*4bdc9457SAndroid Build Coastguard Worker return xnn_status_success;
1564*4bdc9457SAndroid Build Coastguard Worker }
1565