1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <pytorch_qnnpack.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/operator.h>
18
/*
 * Creates a Soft ArgMax (softmax) operator for quantized uint8 data in
 * NC layout.
 *
 * channels          - number of channels per batch element; must be non-zero.
 * input_scale       - quantization scale of the input; must be finite and
 *                     positive.
 * output_zero_point - quantization zero point of the output; only 0 is
 *                     supported.
 * output_scale      - quantization scale of the output; only 1/256 is
 *                     supported (so the uint8 output covers [0, 1)).
 * flags             - reserved; currently unused.
 * softargmax_out    - receives the created operator on success.
 *
 * Returns pytorch_qnnp_status_success on success; on failure returns an
 * error status and leaves *softargmax_out untouched. The caller owns the
 * returned operator and must release it with pytorch_qnnp_delete_operator.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_softargmax_nc_q8(
    size_t channels,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint32_t flags,
    pytorch_qnnp_operator_t* softargmax_out) {
  pytorch_qnnp_operator_t softargmax_op = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_softargmax_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_invalid_parameter;

  if (channels == 0) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %zu channels: number of channels must be non-zero",
        channels);
    goto error;
  }

  /* isnormal() also rejects NaN, infinities, and denormals. */
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %.7g input scale: scale must be finite and positive",
        input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %.7g output scale: scale must be finite and positive",
        output_scale);
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  /* The kernel assumes the output maps [0, 1) onto [0, 256), i.e. an
   * output scale of exactly 2^-8 with zero point 0. */
  if (output_scale != 0x1.0p-8f) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %.7g output scale: only output scale of 1/256 is supported",
        output_scale);
    goto error;
  }

  if (output_zero_point != 0) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %" PRIu8
        " output zero point: only output zero point of 0 is supported",
        output_zero_point);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  softargmax_op = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (softargmax_op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  softargmax_op->lookup_table = malloc(256 * sizeof(uint32_t));
  if (softargmax_op->lookup_table == NULL) {
    /* Report the actual allocation size (256 entries of uint32_t),
     * not 256 bytes. */
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for Soft ArgMax lookup table",
        256 * sizeof(uint32_t));
    goto error;
  }

  /* Precompute lookup_table[i] = round(qscale * exp((i - 255) * input_scale))
   * for every possible uint8 input value i. The kernel subtracts the
   * per-row maximum before lookup, so the exponent argument is <= 0 and
   * each entry is <= qscale. qscale is capped so that the sum of `channels`
   * entries cannot overflow uint32_t, and also at 2^23 - 1 so entries stay
   * exactly representable downstream. */
  uint32_t* lookup_table = softargmax_op->lookup_table;
  const double qscale =
      fmin(((double)UINT32_MAX) / (double)channels, 8388607.0);
  for (int32_t i = 0; i < 256; i++) {
    const double scaled_exp_xi =
        qscale * exp((double)(i - 255) * (double)input_scale);
    lookup_table[(uint32_t)i] = (uint32_t)lrint(scaled_exp_xi);
  }

  softargmax_op->channels = channels;

  softargmax_op->ukernel_type = pytorch_qnnp_ukernel_type_softargmax;
  softargmax_op->format = pytorch_qnnp_format_quint8;

  *softargmax_out = softargmax_op;
  return pytorch_qnnp_status_success;

error:
  /* Safe on the partially-constructed operator (and on NULL). */
  pytorch_qnnp_delete_operator(softargmax_op);
  return status;
}
113
/*
 * Binds per-run parameters (batch size, input/output pointers and their
 * per-pixel strides) to a previously created Soft ArgMax operator.
 *
 * A batch_size of 0 is valid: the operator is marked as having no work to
 * do and the data pointers/strides are left unchanged.
 *
 * Returns pytorch_qnnp_status_success, or
 * pytorch_qnnp_status_uninitialized if QNNPACK has not been initialized.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_softargmax_nc_q8(
    pytorch_qnnp_operator_t softargmax,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_softargmax_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  softargmax->batch_size = batch_size;
  if (batch_size != 0) {
    /* Only record the data layout when there is actual work to run. */
    softargmax->input = input;
    softargmax->input_pixel_stride = input_stride;
    softargmax->output = output;
    softargmax->output_pixel_stride = output_stride;
  }

  return pytorch_qnnp_status_success;
}
140