xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/hardsigmoid.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/operator.h>
18 
/*
 * Creates a quantized (quint8) Hardsigmoid operator.
 *
 * The operator is implemented as a 256-entry lookup table: for every
 * possible quantized input byte we dequantize, apply
 * hardsigmoid(x) = clip(x + 3, 0, 6) / 6, requantize, and clamp to
 * [output_min, output_max].
 *
 * Constraints (checked below): output_scale must be exactly 1/256 and
 * output_zero_point must be 0, matching the requantization the LUT bakes in.
 *
 * On success writes the new operator to *hardsigmoid_out and returns
 * pytorch_qnnp_status_success; on failure returns an error status and
 * leaves *hardsigmoid_out untouched. The caller owns the returned operator
 * and must release it with pytorch_qnnp_delete_operator.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_hardsigmoid_nc_q8(
    size_t channels,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    pytorch_qnnp_operator_t* hardsigmoid_out) {
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;
  pytorch_qnnp_operator_t op = NULL;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_hardsigmoid_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  /* Parameter validation: each failed check logs and bails via the shared
   * cleanup path (pytorch_qnnp_delete_operator is a safe no-op on NULL). */
  status = pytorch_qnnp_status_invalid_parameter;

  if (channels == 0) {
    pytorch_qnnp_log_error(
        "failed to create Hardsigmoid operator with %zu channels: number of channels must be non-zero",
        channels);
    goto error;
  }

  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    pytorch_qnnp_log_error(
        "failed to create Hardsigmoid operator with %.7g input scale: scale must be finite and positive",
        input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    pytorch_qnnp_log_error(
        "failed to create Hardsigmoid operator with %.7g output scale: scale must be finite and positive",
        output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    pytorch_qnnp_log_error(
        "failed to create Hardsigmoid operator with [%" PRIu8 ", %" PRIu8
        "] output range: range min must be below range max",
        output_min,
        output_max);
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  /* The LUT requantization assumes the [0, 1] output range maps exactly onto
   * the full quint8 range, i.e. scale 1/256 with zero point 0. */
  if (output_scale != 0x1.0p-8f) {
    pytorch_qnnp_log_error(
        "failed to create Hardsigmoid operator with %.7g output scale: only output scale of 1/256 is supported",
        output_scale);
    goto error;
  }

  if (output_zero_point != 0) {
    pytorch_qnnp_log_error(
        "failed to create Hardsigmoid operator with %" PRIu8
        " output zero point: only output zero point of 0 is supported",
        output_zero_point);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  op = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  op->lookup_table = malloc(256 * sizeof(uint8_t));
  if (op->lookup_table == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate 256 bytes for Hardsigmoid lookup table");
    goto error;
  }

  /* Precompute hardsigmoid for every possible quantized input byte. */
  uint8_t* lookup_table = op->lookup_table;
  const float scaled_min = (float)(int32_t)output_min;
  const float scaled_max = (float)(int32_t)output_max;
  const float inv_output_scale = 1.0f / output_scale;
  for (int32_t i = 0; i < 256; i++) {
    /* Dequantize the input byte. */
    const float x =
        input_scale * (float)(i - (int32_t)(uint32_t)input_zero_point);
    /* hardsigmoid: clip(x + 3, 0, 6) / 6, spelled out with branches. */
    float y = x + 3.0f;
    if (y < 0.0f) {
      y = 0.0f;
    }
    if (y > 6.0f) {
      y = 6.0f;
    }
    y = y / 6.0f;
    /* Requantize and clamp to the caller-requested output range. */
    float scaled_y = inv_output_scale * y + output_zero_point;
    scaled_y = scaled_y < scaled_min ? scaled_min : scaled_y;
    scaled_y = scaled_y > scaled_max ? scaled_max : scaled_y;
    lookup_table[(uint32_t)i] = (uint8_t)lrintf(scaled_y);
  }

  op->channels = channels;

  op->ukernel_type = pytorch_qnnp_ukernel_type_lut;
  op->format = pytorch_qnnp_format_quint8;

  *hardsigmoid_out = op;
  return pytorch_qnnp_status_success;

error:
  pytorch_qnnp_delete_operator(op);
  return status;
}
138 
/*
 * Binds input/output buffers and a batch size to a previously created
 * Hardsigmoid operator, preparing it for execution.
 *
 * Strides are in elements per "pixel" (per row of channels). A batch size
 * of zero is accepted and simply records an empty workload; the buffer
 * pointers are not touched in that case.
 *
 * Returns pytorch_qnnp_status_success, or
 * pytorch_qnnp_status_uninitialized if QNNPACK has not been initialized.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_hardsigmoid_nc_q8(
    pytorch_qnnp_operator_t hardsigmoid,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_hardsigmoid_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  /* Empty batch: record it and skip buffer bookkeeping entirely. */
  if (batch_size == 0) {
    hardsigmoid->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  hardsigmoid->input = input;
  hardsigmoid->input_pixel_stride = input_stride;
  hardsigmoid->output = output;
  hardsigmoid->output_pixel_stride = output_stride;
  hardsigmoid->batch_size = batch_size;

  return pytorch_qnnp_status_success;
}
165