1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2021 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <math.h>
8*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
9*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
10*4bdc9457SAndroid Build Coastguard Worker #include <stdlib.h>
11*4bdc9457SAndroid Build Coastguard Worker
12*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
13*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/operator.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
16*4bdc9457SAndroid Build Coastguard Worker
17*4bdc9457SAndroid Build Coastguard Worker
// Signature of the scalar function used to populate a lookup table: it maps a
// dequantized input value `x` to a dequantized output value. `params` is an
// opaque pointer supplied by the operator's creator and passed through as-is.
typedef float (*xnn_lut_init_fn)(float x, const void* params);
19*4bdc9457SAndroid Build Coastguard Worker
// Creates an element-wise operator whose 8-bit output is produced through a
// 256-entry lookup table.
//
// The table is filled once here. Each of the 256 quantized inputs, starting at
// `input_min`, is:
//   1. dequantized with `input_zero_point` / `input_scale`;
//   2. mapped through `init_fn(x, init_params)`;
//   3. requantized with `output_zero_point` / `output_scale`;
//   4. clamped to [output_min, output_max].
//
// On success, stores the operator in `*lut_elementwise_op_out` and returns
// xnn_status_success. On failure, releases any partially built operator with
// xnn_delete_operator() and returns xnn_status_uninitialized,
// xnn_status_invalid_parameter, or xnn_status_out_of_memory.
static enum xnn_status create_lut_elementwise_nc(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int32_t input_zero_point,
    float input_scale,
    int32_t input_min,
    long output_zero_point,
    float output_scale,
    long output_min,
    long output_max,
    uint32_t flags,
    xnn_lut_init_fn init_fn,
    const void* init_params,
    enum xnn_operator_type operator_type,
    xnn_operator_t* lut_elementwise_op_out)
{
  xnn_operator_t lut_elementwise_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  if (channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), channels);
    goto error;
  }

  if (input_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_stride, channels);
    goto error;
  }

  if (output_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output element stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_stride, channels);
    goto error;
  }

  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g input scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(operator_type), input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: scale must be finite, normalized, and positive",
      xnn_operator_type_to_string(operator_type), output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%ld, %ld] output range: range min must be below range max",
      xnn_operator_type_to_string(operator_type), output_min, output_max);
    goto error;
  }

  status = xnn_status_out_of_memory;

  lut_elementwise_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (lut_elementwise_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  lut_elementwise_op->lookup_table = xnn_allocate_simd_memory(256 * sizeof(uint8_t));
  if (lut_elementwise_op->lookup_table == NULL) {
    xnn_log_error(
      "failed to allocate 256 bytes for %s operator lookup table",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  uint8_t* lookup_table = lut_elementwise_op->lookup_table;
  const float inv_output_scale = 1.0f / output_scale;
  // Clamp bounds relative to the output zero point. For the 8-bit callers in
  // this file these are small integers, so they are exactly representable as
  // floats.
  const float scaled_min = (float) (output_min - output_zero_point);
  const float scaled_max = (float) (output_max - output_zero_point);
  for (int32_t i = input_min; i < input_min + 256; i++) {
    const float dequantized_input = (i - input_zero_point) * input_scale;
    const float dequantized_output = init_fn(dequantized_input, init_params);
    // Clamp BEFORE rounding. lrintf() has an unspecified result when the value
    // does not fit in a long (e.g. large or infinite ELU outputs, or a 32-bit
    // long), and adding the zero point to such a result could overflow.
    // Because the bounds are integers, clamping first yields the same result
    // as rounding first and clamping afterwards. fmaxf() also maps NaN to the
    // lower bound.
    float scaled_output = dequantized_output * inv_output_scale;
    scaled_output = fminf(fmaxf(scaled_output, scaled_min), scaled_max);
    const long quantized_output = lrintf(scaled_output) + output_zero_point;
    lookup_table[(uint8_t) i] = (uint8_t) quantized_output;
  }

  lut_elementwise_op->channels = channels;
  lut_elementwise_op->input_pixel_stride = input_stride;
  lut_elementwise_op->output_pixel_stride = output_stride;

  lut_elementwise_op->type = operator_type;
  lut_elementwise_op->flags = flags;

  lut_elementwise_op->state = xnn_run_state_invalid;

  *lut_elementwise_op_out = lut_elementwise_op;
  return xnn_status_success;

error:
  xnn_delete_operator(lut_elementwise_op);
  return status;
}
137*4bdc9457SAndroid Build Coastguard Worker
calculate_elu(float x,const float * alpha_ptr)138*4bdc9457SAndroid Build Coastguard Worker static float calculate_elu(float x, const float* alpha_ptr) {
139*4bdc9457SAndroid Build Coastguard Worker const float alpha = *alpha_ptr;
140*4bdc9457SAndroid Build Coastguard Worker return signbit(x) ? alpha * expm1f(x) : x;
141*4bdc9457SAndroid Build Coastguard Worker }
142*4bdc9457SAndroid Build Coastguard Worker
// Creates a signed 8-bit ELU operator backed by a lookup table.
//
// `alpha` must be a positive, normal float. All remaining validation and
// table construction is delegated to create_lut_elementwise_nc(), which is
// given a pointer to `alpha` as the init-function parameter.
enum xnn_status xnn_create_elu_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    float alpha,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* elu_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_elu_nc_qs8;

  // Rejects zero, negatives, NaN, infinities and subnormals.
  if (!(alpha > 0.0f && isnormal(alpha))) {
    xnn_log_error(
      "failed to create %s operator with %.7g alpha parameter: alpha must be finite, normalized, and positive",
      xnn_operator_type_to_string(op_type), alpha);
    return xnn_status_invalid_parameter;
  }

  // Signed inputs: the table covers [INT8_MIN, INT8_MIN + 255].
  const int32_t table_input_min = INT8_MIN;
  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    (int32_t) input_zero_point, input_scale, table_input_min,
    (long) output_zero_point, output_scale,
    (long) output_min, (long) output_max,
    flags,
    (xnn_lut_init_fn) &calculate_elu, &alpha,
    op_type, elu_op_out);
}
173*4bdc9457SAndroid Build Coastguard Worker
calculate_sigmoid(float x,const void * params)174*4bdc9457SAndroid Build Coastguard Worker static float calculate_sigmoid(float x, const void* params) {
175*4bdc9457SAndroid Build Coastguard Worker return signbit(x) ? 1.0f / (1.0f + expf(-x)) : 1.0f - 1.0f / (1.0f + expf(x));
176*4bdc9457SAndroid Build Coastguard Worker }
177*4bdc9457SAndroid Build Coastguard Worker
// Creates a signed 8-bit sigmoid operator backed by a lookup table.
//
// Only the canonical output quantization is supported: output scale 1/256
// with output zero point -128. Other values return
// xnn_status_unsupported_parameter. Remaining validation is done by
// create_lut_elementwise_nc().
enum xnn_status xnn_create_sigmoid_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  if (output_scale != 0x1.0p-8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/256 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != -128) {
    // output_zero_point is signed: print it with PRId8, not PRIu8.
    xnn_log_error(
      "failed to create %s operator with %" PRId8 " output zero point: only output zero point of -128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_sigmoid_nc_qs8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    (int32_t) input_zero_point, input_scale, INT8_MIN,
    (long) output_zero_point, output_scale,
    (long) output_min, (long) output_max,
    flags,
    (xnn_lut_init_fn) &calculate_sigmoid, NULL,
    xnn_operator_type_sigmoid_nc_qs8, sigmoid_op_out);
}
214*4bdc9457SAndroid Build Coastguard Worker
// Creates an unsigned 8-bit sigmoid operator backed by a lookup table.
//
// Requires output scale 1/256 and output zero point 0; anything else returns
// xnn_status_unsupported_parameter. The table covers inputs [0, 255].
enum xnn_status xnn_create_sigmoid_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* sigmoid_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_sigmoid_nc_qu8;

  if (output_scale != 0x1.0p-8f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/256 is supported",
      xnn_operator_type_to_string(op_type), output_scale);
    return xnn_status_unsupported_parameter;
  }
  if (output_zero_point != 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu8 " output zero point: only output zero point of 0 is supported",
      xnn_operator_type_to_string(op_type), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Widen through unsigned types so the 8-bit values are zero-extended.
  const int32_t widened_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long widened_output_zero_point = (long) (unsigned long) output_zero_point;
  const long widened_output_min = (long) (unsigned long) output_min;
  const long widened_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    widened_input_zero_point, input_scale, 0 /* input min */,
    widened_output_zero_point, output_scale,
    widened_output_min, widened_output_max,
    flags,
    (xnn_lut_init_fn) &calculate_sigmoid, NULL,
    op_type, sigmoid_op_out);
}
251*4bdc9457SAndroid Build Coastguard Worker
calculate_tanh(float x,const void * params)252*4bdc9457SAndroid Build Coastguard Worker static float calculate_tanh(float x, const void* params) {
253*4bdc9457SAndroid Build Coastguard Worker return tanhf(x);
254*4bdc9457SAndroid Build Coastguard Worker }
255*4bdc9457SAndroid Build Coastguard Worker
// Creates a signed 8-bit tanh operator backed by a lookup table.
//
// Only the canonical output quantization is supported: output scale 1/128
// with output zero point 0. Other values return
// xnn_status_unsupported_parameter. Remaining validation is done by
// create_lut_elementwise_nc().
enum xnn_status xnn_create_tanh_nc_qs8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* tanh_op_out)
{
  if (output_scale != 0x1.0p-7f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/128 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qs8), output_scale);
    return xnn_status_unsupported_parameter;
  }

  if (output_zero_point != 0) {
    // output_zero_point is signed: print it with PRId8, not PRIu8.
    xnn_log_error(
      "failed to create %s operator with %" PRId8 " output zero point: only output zero point of 0 is supported",
      xnn_operator_type_to_string(xnn_operator_type_tanh_nc_qs8), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    (int32_t) input_zero_point, input_scale, INT8_MIN,
    (long) output_zero_point, output_scale,
    (long) output_min, (long) output_max,
    flags,
    (xnn_lut_init_fn) &calculate_tanh, NULL,
    xnn_operator_type_tanh_nc_qs8, tanh_op_out);
}
292*4bdc9457SAndroid Build Coastguard Worker
// Creates an unsigned 8-bit tanh operator backed by a lookup table.
//
// Requires output scale 1/128 and output zero point 128; anything else returns
// xnn_status_unsupported_parameter. The table covers inputs [0, 255].
enum xnn_status xnn_create_tanh_nc_qu8(
    size_t channels,
    size_t input_stride,
    size_t output_stride,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* tanh_op_out)
{
  const enum xnn_operator_type op_type = xnn_operator_type_tanh_nc_qu8;

  if (output_scale != 0x1.0p-7f) {
    xnn_log_error(
      "failed to create %s operator with %.7g output scale: only output scale of 1/128 is supported",
      xnn_operator_type_to_string(op_type), output_scale);
    return xnn_status_unsupported_parameter;
  }
  if (output_zero_point != 128) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu8 " output zero point: only output zero point of 128 is supported",
      xnn_operator_type_to_string(op_type), output_zero_point);
    return xnn_status_unsupported_parameter;
  }

  // Widen through unsigned types so the 8-bit values are zero-extended.
  const int32_t widened_input_zero_point = (int32_t) (uint32_t) input_zero_point;
  const long widened_output_zero_point = (long) (unsigned long) output_zero_point;
  const long widened_output_min = (long) (unsigned long) output_min;
  const long widened_output_max = (long) (unsigned long) output_max;

  return create_lut_elementwise_nc(
    channels, input_stride, output_stride,
    widened_input_zero_point, input_scale, 0 /* input min */,
    widened_output_zero_point, output_scale,
    widened_output_min, widened_output_max,
    flags,
    (xnn_lut_init_fn) &calculate_tanh, NULL,
    op_type, tanh_op_out);
}
329*4bdc9457SAndroid Build Coastguard Worker
// Prepares a lookup-table operator to run on `batch_size` rows of `input`,
// writing to `output`.
//
// The operator type must equal `expected_operator_type` and XNNPACK must be
// initialized. A zero batch marks the operator to be skipped.
// - If there is more than one row and either stored stride differs from
//   `channels`, the work is parallelized over the `batch_size` rows using the
//   strided context.
// - Otherwise the batch is treated as one flat range of
//   `batch_size * channels` bytes, processed in 1024-byte tiles.
static enum xnn_status setup_lut_elementwise_nc(
    xnn_operator_t lut_elementwise_op,
    enum xnn_operator_type expected_operator_type,
    size_t batch_size,
    const void* input,
    void* output)
{
  const enum xnn_operator_type actual_type = lut_elementwise_op->type;
  if (actual_type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(actual_type));
    return xnn_status_invalid_parameter;
  }
  lut_elementwise_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error(
      "failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(expected_operator_type));
    return xnn_status_uninitialized;
  }

  if (batch_size == 0) {
    lut_elementwise_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  const size_t channels = lut_elementwise_op->channels;
  const size_t x_stride = lut_elementwise_op->input_pixel_stride;
  const size_t y_stride = lut_elementwise_op->output_pixel_stride;
  const int rows_are_strided =
    batch_size != 1 && (x_stride != channels || y_stride != channels);

  if (rows_are_strided) {
    // Rows have gaps between them: parallelize over rows.
    lut_elementwise_op->context.lut_strided = (struct lut_strided_context) {
      .n = channels,
      .x = input,
      .x_stride = x_stride * sizeof(uint8_t),
      .t = lut_elementwise_op->lookup_table,
      .y = output,
      .y_stride = y_stride * sizeof(uint8_t),
      .ukernel = xnn_params.x8.lut,
    };
    lut_elementwise_op->compute.type = xnn_parallelization_type_1d;
    lut_elementwise_op->compute.task_1d = (pthreadpool_task_1d_t) xnn_compute_lut_strided;
    lut_elementwise_op->compute.range[0] = batch_size;
    lut_elementwise_op->compute.tile[0] = 0;
  } else {
    // Data is densely packed (or there is only one row): process it as a
    // single contiguous byte range.
    const size_t tile_bytes = 1024;
    lut_elementwise_op->context.lut_contiguous = (struct lut_contiguous_context) {
      .x = input,
      .x_stride = x_stride * sizeof(uint8_t),
      .t = lut_elementwise_op->lookup_table,
      .y = output,
      .y_stride = y_stride * sizeof(uint8_t),
      .ukernel = xnn_params.x8.lut,
    };
    lut_elementwise_op->compute.type = xnn_parallelization_type_1d_tile_1d;
    lut_elementwise_op->compute.task_1d_tile_1d = (pthreadpool_task_1d_tile_1d_t) xnn_compute_lut_contiguous;
    lut_elementwise_op->compute.range[0] = batch_size * channels * sizeof(uint8_t);
    lut_elementwise_op->compute.tile[0] = tile_bytes;
  }

  lut_elementwise_op->state = xnn_run_state_ready;
  return xnn_status_success;
}
393*4bdc9457SAndroid Build Coastguard Worker
// Binds `batch_size`, `input` and `output` to an operator created by
// xnn_create_elu_nc_qs8(). Setup is delegated to setup_lut_elementwise_nc(),
// which rejects operators of any other type. `threadpool` is not referenced
// during setup.
//
// The parameter was previously named `sigmoid_op` (a copy-paste leftover); it
// is renamed to `elu_op` to match the operator it expects. C parameter names
// are not part of the ABI, so callers are unaffected.
enum xnn_status xnn_setup_elu_nc_qs8(
    xnn_operator_t elu_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  return setup_lut_elementwise_nc(
    elu_op, xnn_operator_type_elu_nc_qs8,
    batch_size, input, output);
}
405*4bdc9457SAndroid Build Coastguard Worker
// Attaches a batch of signed 8-bit quantized rows (`input` -> `output`) to a
// QS8 sigmoid operator.
//
// This is a thin wrapper: it tags the request with
// xnn_operator_type_sigmoid_nc_qs8 and hands everything to the shared
// setup_lut_elementwise_nc helper, returning whatever status that helper
// reports. The `threadpool` argument is not used here.
enum xnn_status xnn_setup_sigmoid_nc_qs8(
  xnn_operator_t sigmoid_op,
  size_t batch_size,
  const int8_t* input,
  int8_t* output,
  pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_sigmoid_nc_qs8;
  const enum xnn_status status =
    setup_lut_elementwise_nc(sigmoid_op, expected_type, batch_size, input, output);
  return status;
}
417*4bdc9457SAndroid Build Coastguard Worker
// Attaches a batch of unsigned 8-bit quantized rows (`input` -> `output`) to a
// QU8 sigmoid operator.
//
// All of the work happens in setup_lut_elementwise_nc; this wrapper only
// supplies the operator type xnn_operator_type_sigmoid_nc_qu8 and forwards the
// resulting status. The `threadpool` argument is not used here.
enum xnn_status xnn_setup_sigmoid_nc_qu8(
  xnn_operator_t sigmoid_op,
  size_t batch_size,
  const uint8_t* input,
  uint8_t* output,
  pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_sigmoid_nc_qu8;
  const enum xnn_status status =
    setup_lut_elementwise_nc(sigmoid_op, expected_type, batch_size, input, output);
  return status;
}
429*4bdc9457SAndroid Build Coastguard Worker
// Attaches a batch of signed 8-bit quantized rows (`input` -> `output`) to a
// QS8 tanh operator.
//
// The shared setup_lut_elementwise_nc helper does the actual configuration;
// this wrapper supplies xnn_operator_type_tanh_nc_qs8 as the operator type and
// returns the helper's status. The `threadpool` argument is not used here.
enum xnn_status xnn_setup_tanh_nc_qs8(
  xnn_operator_t tanh_op,
  size_t batch_size,
  const int8_t* input,
  int8_t* output,
  pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_tanh_nc_qs8;
  const enum xnn_status status =
    setup_lut_elementwise_nc(tanh_op, expected_type, batch_size, input, output);
  return status;
}
441*4bdc9457SAndroid Build Coastguard Worker
// Attaches a batch of unsigned 8-bit quantized rows (`input` -> `output`) to a
// QU8 tanh operator.
//
// Delegates to setup_lut_elementwise_nc with the operator type
// xnn_operator_type_tanh_nc_qu8 and passes its status straight back to the
// caller. The `threadpool` argument is not used here.
enum xnn_status xnn_setup_tanh_nc_qu8(
  xnn_operator_t tanh_op,
  size_t batch_size,
  const uint8_t* input,
  uint8_t* output,
  pthreadpool_t threadpool)
{
  const enum xnn_operator_type expected_type = xnn_operator_type_tanh_nc_qu8;
  const enum xnn_status status =
    setup_lut_elementwise_nc(tanh_op, expected_type, batch_size, input, output);
  return status;
}
453