xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/add.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/operator.h>
#include <qnnpack/params.h>
#include <qnnpack/requantization.h>
20 
/*
 * Creates an elementwise quantized (Q8) addition operator:
 *   sum = requantize(a + b) per channel, clamped to [sum_min, sum_max].
 *
 * Validates all quantization parameters, allocates the operator structure,
 * and precomputes the requantization parameters. On success stores the new
 * operator in *add_out and returns pytorch_qnnp_status_success; on failure
 * returns an error status and leaves *add_out untouched.
 *
 * Ownership: the caller must release the returned operator with
 * pytorch_qnnp_delete_operator.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_add_nc_q8(
    size_t channels,
    uint8_t a_zero_point,
    float a_scale,
    uint8_t b_zero_point,
    float b_scale,
    uint8_t sum_zero_point,
    float sum_scale,
    uint8_t sum_min,
    uint8_t sum_max,
    uint32_t flags,
    pytorch_qnnp_operator_t* add_out) {
  pytorch_qnnp_operator_t add_op = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_add_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_invalid_parameter;

  if (channels == 0) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %zu channels: number of channels must be non-zero",
        channels);
    goto error;
  }

  /* isnormal() also rejects subnormals, NaN, and infinity. */
  if (a_scale <= 0.0f || !isnormal(a_scale)) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g A scale: scale must be finite and positive",
        a_scale);
    goto error;
  }

  if (b_scale <= 0.0f || !isnormal(b_scale)) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g B scale: scale must be finite and positive",
        b_scale);
    goto error;
  }

  if (sum_scale <= 0.0f || !isnormal(sum_scale)) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g output scale: scale must be finite and positive",
        sum_scale);
    goto error;
  }

  if (sum_min >= sum_max) {
    pytorch_qnnp_log_error(
        "failed to create add operator with [%" PRIu8 ", %" PRIu8
        "] output range: range min must be below range max",
        sum_min,
        sum_max);
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  /* The fixed-point requantization kernels only support scale ratios in
   * the [2**-14, 2**8) range. */
  const float a_output_scale = a_scale / sum_scale;
  if (a_output_scale < 0x1.0p-14f || a_output_scale >= 0x1.0p+8f) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
        a_output_scale);
    goto error;
  }

  const float b_output_scale = b_scale / sum_scale;
  if (b_output_scale < 0x1.0p-14f || b_output_scale >= 0x1.0p+8f) {
    /* Fixed copy-paste bug: this message previously said "A-to-output". */
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g B-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
        b_output_scale);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  /* calloc zero-initializes the operator, so fields not set below are 0. */
  add_op = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (add_op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  add_op->channels = channels;
  add_op->add_quantization_params =
      pytorch_qnnp_compute_add_quantization_params(
          a_zero_point,
          b_zero_point,
          sum_zero_point,
          a_scale / sum_scale,
          b_scale / sum_scale,
          sum_min,
          sum_max);

  add_op->ukernel_type = pytorch_qnnp_ukernel_type_add;
  add_op->format = pytorch_qnnp_format_quint8;

  *add_out = add_op;
  return pytorch_qnnp_status_success;

error:
  /* pytorch_qnnp_delete_operator(NULL) is a safe no-op. */
  pytorch_qnnp_delete_operator(add_op);
  return status;
}
130 
/*
 * Binds input/output buffers and strides to a previously created Q8 add
 * operator for a given batch size. Performs no computation; the actual
 * kernel runs when the operator is executed.
 *
 * A batch size of zero is valid and turns execution into a no-op.
 * The caller retains ownership of the a, b, and sum buffers.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_add_nc_q8(
    pytorch_qnnp_operator_t add_op,
    size_t batch_size,
    const uint8_t* a,
    size_t a_stride,
    const uint8_t* b,
    size_t b_stride,
    uint8_t* sum,
    size_t sum_stride) {
  /* Guard: library must have been initialized via pytorch_qnnp_initialize. */
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_add_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  /* An empty batch needs no buffer setup; record it and succeed early. */
  if (batch_size == 0) {
    add_op->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  add_op->batch_size = batch_size;

  /* First addend. */
  add_op->input = a;
  add_op->input_pixel_stride = a_stride;

  /* Second addend. */
  add_op->input2 = b;
  add_op->input2_pixel_stride = b_stride;

  /* Destination. */
  add_op->output = sum;
  add_op->output_pixel_stride = sum_stride;

  return pytorch_qnnp_status_success;
}
161