1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <pytorch_qnnpack.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/operator.h>
18 #include <qnnpack/params.h>
19 #include <qnnpack/requantization.h>
20
/*
 * Creates an element-wise addition operator on uint8 quantized data
 * (operator format pytorch_qnnp_format_quint8) over `channels` channels.
 *
 * Parameters:
 *   channels        number of channels per batch element; must be non-zero.
 *   a_zero_point    zero point of input A.
 *   a_scale         scale of input A; must be positive and a normal float.
 *   b_zero_point    zero point of input B.
 *   b_scale         scale of input B; must be positive and a normal float.
 *   sum_zero_point  zero point of the output.
 *   sum_scale       scale of the output; must be positive and a normal float.
 *   sum_min         lower bound of the output range; must be < sum_max.
 *   sum_max         upper bound of the output range.
 *   flags           currently unused by this function.
 *   add_out         receives the new operator on success; not written on
 *                   failure.
 *
 * Returns:
 *   pytorch_qnnp_status_success               operator created.
 *   pytorch_qnnp_status_uninitialized         QNNPACK not initialized.
 *   pytorch_qnnp_status_invalid_parameter     zero channels, bad scale, or
 *                                             sum_min >= sum_max.
 *   pytorch_qnnp_status_unsupported_parameter a_scale/sum_scale or
 *                                             b_scale/sum_scale outside
 *                                             [2**-14, 2**8).
 *   pytorch_qnnp_status_out_of_memory         operator allocation failed.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_add_nc_q8(
    size_t channels,
    uint8_t a_zero_point,
    float a_scale,
    uint8_t b_zero_point,
    float b_scale,
    uint8_t sum_zero_point,
    float sum_scale,
    uint8_t sum_min,
    uint8_t sum_max,
    uint32_t flags,
    pytorch_qnnp_operator_t* add_out) {
  pytorch_qnnp_operator_t add_op = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  /* No flags are consumed by this operator yet. */
  (void)flags;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_add_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_invalid_parameter;

  if (channels == 0) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %zu channels: number of channels must be non-zero",
        channels);
    goto error;
  }

  /* isnormal() also rejects NaN, infinities, zero and subnormals. */
  if (a_scale <= 0.0f || !isnormal(a_scale)) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g A scale: scale must be finite and positive",
        a_scale);
    goto error;
  }

  if (b_scale <= 0.0f || !isnormal(b_scale)) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g B scale: scale must be finite and positive",
        b_scale);
    goto error;
  }

  if (sum_scale <= 0.0f || !isnormal(sum_scale)) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g output scale: scale must be finite and positive",
        sum_scale);
    goto error;
  }

  if (sum_min >= sum_max) {
    pytorch_qnnp_log_error(
        "failed to create add operator with [%" PRIu8 ", %" PRIu8
        "] output range: range min must be below range max",
        sum_min,
        sum_max);
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  /*
   * Input-to-output scale ratios are restricted to [2**-14, 2**8); presumably
   * this is the range pytorch_qnnp_compute_add_quantization_params can
   * represent (TODO confirm against requantization.h).
   */
  const float a_output_scale = a_scale / sum_scale;
  if (a_output_scale < 0x1.0p-14f || a_output_scale >= 0x1.0p+8f) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
        a_output_scale);
    goto error;
  }

  const float b_output_scale = b_scale / sum_scale;
  if (b_output_scale < 0x1.0p-14f || b_output_scale >= 0x1.0p+8f) {
    pytorch_qnnp_log_error(
        "failed to create add operator with %.7g B-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range",
        b_output_scale);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  /* Zero-initialized so that pytorch_qnnp_delete_operator is safe on it. */
  add_op = calloc(1, sizeof(*add_op));
  if (add_op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(*add_op));
    goto error;
  }

  add_op->channels = channels;
  /* Reuse the ratios validated above rather than recomputing them. */
  add_op->add_quantization_params =
      pytorch_qnnp_compute_add_quantization_params(
          a_zero_point,
          b_zero_point,
          sum_zero_point,
          a_output_scale,
          b_output_scale,
          sum_min,
          sum_max);

  add_op->ukernel_type = pytorch_qnnp_ukernel_type_add;
  add_op->format = pytorch_qnnp_format_quint8;

  *add_out = add_op;
  return pytorch_qnnp_status_success;

error:
  /* add_op is NULL unless allocation succeeded. */
  pytorch_qnnp_delete_operator(add_op);
  return status;
}
130
/*
 * Binds the batch size and the A, B and sum buffers (each with its stride)
 * to an add operator, storing them in the operator's input, input2 and
 * output fields.
 *
 * When batch_size is 0 only the batch size is recorded; the previously bound
 * buffers and strides are left as they were.
 *
 * Returns pytorch_qnnp_status_uninitialized if QNNPACK has not been
 * initialized, otherwise pytorch_qnnp_status_success.
 * NOTE(review): add_op is dereferenced without a NULL or operator-type
 * check; callers presumably pass an operator from
 * pytorch_qnnp_create_add_nc_q8 — confirm.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_add_nc_q8(
    pytorch_qnnp_operator_t add_op,
    size_t batch_size,
    const uint8_t* a,
    size_t a_stride,
    const uint8_t* b,
    size_t b_stride,
    uint8_t* sum,
    size_t sum_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_add_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  /* The batch size is always recorded; an empty batch binds no buffers. */
  add_op->batch_size = batch_size;

  if (batch_size != 0) {
    add_op->output = sum;
    add_op->output_pixel_stride = sum_stride;

    add_op->input2 = b;
    add_op->input2_pixel_stride = b_stride;

    add_op->input = a;
    add_op->input_pixel_stride = a_stride;
  }

  return pytorch_qnnp_status_success;
}
161