xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/softargmax.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/operator.h>

/*
 * Creates a Soft ArgMax operator for quantized uint8 data.
 *
 * Validates the quantization parameters, allocates the operator, and fills a
 * 256-entry uint32_t lookup table (one entry per possible uint8_t value).
 *
 * channels          - number of channels; must be non-zero.
 * input_scale       - input quantization scale; must be positive and normal
 *                     (isnormal() rejects zero, subnormal, infinite, NaN).
 * output_zero_point - output zero point; only 0 is supported.
 * output_scale      - output quantization scale; only 1/256 (0x1.0p-8f) is
 *                     supported.
 * flags             - not read by this function.
 * softargmax_out    - receives the new operator on success; not written on
 *                     failure.
 *
 * Returns pytorch_qnnp_status_success, or on failure:
 *   pytorch_qnnp_status_uninitialized         - QNNPACK is not initialized;
 *   pytorch_qnnp_status_invalid_parameter     - zero channels or a
 *                                               non-positive/non-normal scale;
 *   pytorch_qnnp_status_unsupported_parameter - output scale != 1/256 or
 *                                               output zero point != 0;
 *   pytorch_qnnp_status_out_of_memory         - an allocation failed.
 * On failure, any partially built operator is passed to
 * pytorch_qnnp_delete_operator.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_softargmax_nc_q8(
    size_t channels,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint32_t flags,
    pytorch_qnnp_operator_t* softargmax_out) {
  pytorch_qnnp_operator_t softargmax_op = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_softargmax_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_invalid_parameter;

  if (channels == 0) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %zu channels: number of channels must be non-zero",
        channels);
    goto error;
  }

  /* The <= 0 test rejects negatives; isnormal() rejects 0, subnormals, inf, NaN. */
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %.7g input scale: scale must be finite and positive",
        input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %.7g output scale: scale must be finite and positive",
        output_scale);
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  /* Only the fixed output quantization (scale 1/256, zero point 0) is implemented. */
  if (output_scale != 0x1.0p-8f) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %.7g output scale: only output scale of 1/256 is supported",
        output_scale);
    goto error;
  }

  if (output_zero_point != 0) {
    pytorch_qnnp_log_error(
        "failed to create Soft ArgMax operator with %" PRIu8
        " output zero point: only output zero point of 0 is supported",
        output_zero_point);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  softargmax_op = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (softargmax_op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  softargmax_op->lookup_table = malloc(256 * sizeof(uint32_t));
  if (softargmax_op->lookup_table == NULL) {
    /* Report the real allocation size (256 uint32_t entries), not 256 bytes. */
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for Soft ArgMax lookup table",
        256 * sizeof(uint32_t));
    goto error;
  }

  uint32_t* lookup_table = softargmax_op->lookup_table;
  /*
   * qscale is the largest possible table entry. Bounding it by
   * UINT32_MAX / channels keeps the sum of `channels` entries within uint32_t.
   * floor() makes qscale an integer, so lrint() below can never round an
   * entry above it. For channels >= 513 the unfloored bound is usually
   * fractional, and lrint(qscale) could exceed it, making channels entries
   * of the maximum wrap past UINT32_MAX. For channels <= 512 the 8388607 cap
   * applies and the table is unchanged.
   * NOTE(review): the 8388607 (2^23 - 1) cap presumably keeps (entry << 8)
   * representable in the normalization ukernel -- confirm against the kernel.
   */
  const double qscale =
      fmin(floor(((double)UINT32_MAX) / (double)channels), 8388607.0);
  for (int32_t i = 0; i < 256; i++) {
    /*
     * Entry i is qscale * exp(-(255 - i) * input_scale). Because (i - 255)
     * is never positive, exp() <= 1 and every entry lies in [0, qscale].
     */
    const double scaled_exp_xi =
        qscale * exp((double)(i - 255) * (double)input_scale);
    lookup_table[(uint32_t)i] = (uint32_t)lrint(scaled_exp_xi);
  }

  softargmax_op->channels = channels;

  softargmax_op->ukernel_type = pytorch_qnnp_ukernel_type_softargmax;
  softargmax_op->format = pytorch_qnnp_format_quint8;

  *softargmax_out = softargmax_op;
  return pytorch_qnnp_status_success;

error:
  pytorch_qnnp_delete_operator(softargmax_op);
  return status;
}
113 
/*
 * Attaches a batch of work to a Soft ArgMax operator.
 *
 * If QNNPACK is not initialized, logs an error, leaves `softargmax`
 * untouched and returns pytorch_qnnp_status_uninitialized.
 *
 * Otherwise, stores batch_size on the operator and returns
 * pytorch_qnnp_status_success:
 *   - for an empty batch (batch_size == 0), nothing else is written;
 *   - for a non-empty batch, the input/output pointers and their strides are
 *     recorded as well (in input_pixel_stride / output_pixel_stride).
 *
 * No validation of the pointers or strides is performed here.
 * NOTE(review): strides are presumably measured in elements -- confirm with
 * the run function.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_softargmax_nc_q8(
    pytorch_qnnp_operator_t softargmax,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (pytorch_qnnp_params.initialized) {
    /* The batch size is recorded for empty and non-empty batches alike. */
    softargmax->batch_size = batch_size;

    /* Buffers only matter when there is at least one row to process. */
    if (batch_size != 0) {
      softargmax->input = input;
      softargmax->input_pixel_stride = input_stride;
      softargmax->output = output;
      softargmax->output_pixel_stride = output_stride;
    }

    return pytorch_qnnp_status_success;
  }

  pytorch_qnnp_log_error(
      "pytorch_qnnp_setup_softargmax_nc_q8 failed because QNNPACK is not properly initialized");
  return pytorch_qnnp_status_uninitialized;
}
140