xref: /aosp_15_r20/external/XNNPACK/src/subgraph/softmax.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/subgraph.h>
16 #include <xnnpack/subgraph-validation.h>
17 
18 
create_softmax_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)19 static enum xnn_status create_softmax_operator(
20   const struct xnn_node* node,
21   const struct xnn_value* values,
22   size_t num_values,
23   struct xnn_operator_data* opdata,
24   const struct xnn_caches* caches)
25 {
26   assert(node->num_inputs == 1);
27   const uint32_t input_id = node->inputs[0];
28   assert(input_id != XNN_INVALID_VALUE_ID);
29   assert(input_id < num_values);
30 
31   assert(node->num_outputs == 1);
32   const uint32_t output_id = node->outputs[0];
33   assert(output_id != XNN_INVALID_VALUE_ID);
34   assert(output_id < num_values);
35 
36   const size_t num_input_dims = values[input_id].shape.num_dims;
37   assert(num_input_dims > 0);
38   const size_t channel_dim = values[input_id].shape.dim[num_input_dims - 1];
39 
40   enum xnn_status status;
41   switch (node->compute_type) {
42     case xnn_datatype_fp32:
43       status = xnn_create_softmax_nc_f32(
44         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
45         node->flags,
46         &opdata->operator_objects[0]);
47       break;
48 #ifndef XNN_NO_F16_OPERATORS
49     case xnn_datatype_fp16:
50       status = xnn_create_softmax_nc_f16(
51         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
52         node->flags,
53         &opdata->operator_objects[0]);
54       break;
55 #endif  // !defined(XNN_NO_F16_OPERATORS)
56     default:
57       XNN_UNREACHABLE;
58   }
59   if (status == xnn_status_success) {
60     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
61     opdata->inputs[0] = input_id;
62     opdata->outputs[0] = output_id;
63   }
64   return status;
65 }
66 
// Binds runtime buffers to the Softmax operator built by create_softmax_operator.
//
// opdata     - Holds the operator object, batch size and the input/output
//              Value IDs recorded when the operator was created.
// blobs      - Runtime blobs, indexed by Value ID. Both the input and the
//              output data pointers are expected to be non-NULL.
// num_blobs  - Number of entries in blobs.
// threadpool - Forwarded unchanged to the xnn_setup_softmax_nc_* call.
//
// Returns whatever the type-specific setup function returns.
static enum xnn_status setup_softmax_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t in_id = opdata->inputs[0];
  assert(in_id != XNN_INVALID_VALUE_ID);
  assert(in_id < num_blobs);

  const uint32_t out_id = opdata->outputs[0];
  assert(out_id != XNN_INVALID_VALUE_ID);
  assert(out_id < num_blobs);

  const void* src = blobs[in_id].data;
  assert(src != NULL);

  void* dst = blobs[out_id].data;
  assert(dst != NULL);

  // The operator type was fixed at creation time; pick the matching setup entry point.
  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_softmax_nc_f32:
      return xnn_setup_softmax_nc_f32(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_softmax_nc_f16:
      return xnn_setup_softmax_nc_f16(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
110 
// Appends a Softmax node to a subgraph after validating its input and output.
//
// subgraph  - Subgraph that receives the new node.
// input_id  - ID of a dense FP32 Value with at least one dimension.
// output_id - ID of a dense FP32 Value that receives the result.
// flags     - Stored on the node; create_softmax_operator passes them to the
//             operator constructor.
//
// Returns xnn_status_success when the node was added;
// xnn_status_invalid_parameter for a 0-D or non-FP32 input, or a non-FP32
// output; xnn_status_out_of_memory if no node could be allocated; otherwise
// the status of the first failing xnn_subgraph_check_* helper.
enum xnn_status xnn_define_softmax(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_softmax);
  if (status != xnn_status_success) {
    return status;
  }

  // --- Input Value: valid ID, dense, at least 1-D, FP32. ---
  status = xnn_subgraph_check_input_node_id(xnn_node_type_softmax, input_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input = subgraph->values + input_id;
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_softmax, input_id, input);
  if (status != xnn_status_success) {
    return status;
  }

  // The channel count is later read from the last input dimension, so a 0-D
  // input is rejected here.
  if (input->shape.num_dims < 1) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": number of dimensions must be at least 1",
      xnn_node_type_to_string(xnn_node_type_softmax), input_id);
    return xnn_status_invalid_parameter;
  }

  if (input->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_softmax), input_id,
      xnn_datatype_to_string(input->datatype), input->datatype);
    return xnn_status_invalid_parameter;
  }

  // --- Output Value: valid ID, dense, FP32. ---
  status = xnn_subgraph_check_output_node_id(xnn_node_type_softmax, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output = subgraph->values + output_id;
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_softmax, output_id, output);
  if (status != xnn_status_success) {
    return status;
  }

  if (output->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_softmax), output_id,
      xnn_datatype_to_string(output->datatype), output->datatype);
    return xnn_status_invalid_parameter;
  }

  // --- Record the node. ---
  struct xnn_node* const node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_softmax;
  node->compute_type = xnn_compute_type_fp32;
  node->flags = flags;

  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;

  node->create = create_softmax_operator;
  node->setup = setup_softmax_operator;

  return xnn_status_success;
}
191