xref: /aosp_15_r20/external/XNNPACK/src/subgraph/prelu.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/subgraph.h>
16 #include <xnnpack/subgraph-validation.h>
17 
18 
create_prelu_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)19 static enum xnn_status create_prelu_operator(
20   const struct xnn_node* node,
21   const struct xnn_value* values,
22   size_t num_values,
23   struct xnn_operator_data* opdata,
24   const struct xnn_caches* caches)
25 {
26   assert(node->num_inputs == 2);
27   const uint32_t input_id = node->inputs[0];
28   assert(input_id != XNN_INVALID_VALUE_ID);
29   assert(input_id < num_values);
30   const uint32_t slope_id = node->inputs[1];
31   assert(slope_id != XNN_INVALID_VALUE_ID);
32   assert(slope_id < num_values);
33 
34   assert(node->num_outputs == 1);
35   const uint32_t output_id = node->outputs[0];
36   assert(output_id != XNN_INVALID_VALUE_ID);
37   assert(output_id < num_values);
38 
39   const size_t num_input_dims = values[input_id].shape.num_dims;
40   const size_t channel_dim = num_input_dims == 0 ? 1 : values[input_id].shape.dim[num_input_dims - 1];
41 
42   enum xnn_status status;
43   switch (node->compute_type) {
44 #ifndef XNN_NO_F16_OPERATORS
45     case xnn_compute_type_fp16:
46       status = xnn_create_prelu_nc_f16(
47         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
48         values[slope_id].data /* negative slope */,
49         node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
50         caches,
51         &opdata->operator_objects[0]);
52       break;
53 #endif  // XNN_NO_F16_OPERATORS
54     case xnn_compute_type_fp32:
55       status = xnn_create_prelu_nc_f32(
56         channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
57         values[slope_id].data /* negative slope */,
58         node->flags,
59         caches,
60         &opdata->operator_objects[0]);
61       break;
62     default:
63       XNN_UNREACHABLE;
64   }
65   if (status == xnn_status_success) {
66     opdata->batch_size = xnn_shape_multiply_non_channel_dims(&values[input_id].shape);
67     opdata->inputs[0] = input_id;
68     opdata->outputs[0] = output_id;
69   }
70   return status;
71 }
72 
// Binds the input and output buffers to a previously created PReLU operator.
//
// opdata     - operator data holding the operator object, the batch size and
//              the input/output Value IDs.
// blobs      - runtime blobs indexed by Value ID; supplies the data pointers.
// num_blobs  - number of entries in `blobs` (only used for assertions).
// threadpool - forwarded unchanged to the operator setup call.
//
// Returns the status reported by the type-specific setup function.
static enum xnn_status setup_prelu_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* src = blobs[input_id].data;
  assert(src != NULL);

  void* dst = blobs[output_id].data;
  assert(dst != NULL);

  // Dispatch on the type of the operator that was actually created.
  enum xnn_status result;
  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_prelu_nc_f16:
      result = xnn_setup_prelu_nc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        src,
        dst,
        threadpool);
      break;
#endif  // XNN_NO_F16_OPERATORS
    case xnn_operator_type_prelu_nc_f32:
      result = xnn_setup_prelu_nc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        src,
        dst,
        threadpool);
      break;
    default:
      XNN_UNREACHABLE;
  }
  return result;
}
116 
// Adds a PReLU node to a subgraph.
//
// subgraph  - subgraph that owns the Values and receives the new node.
// input_id  - Value ID of the input; must be a dense fp32 tensor.
// slope_id  - Value ID of the slope; must be a dense fp32 tensor whose data is
//             already attached (static), because the data pointer is handed to
//             the operator constructor when the operator is created.
// output_id - Value ID of the output; must be a dense fp32 tensor.
// flags     - stored on the node and forwarded to the operator constructor.
//
// Returns xnn_status_success when the node was added,
// xnn_status_invalid_parameter when the slope Value fails validation or any
// Value has an unsupported datatype, xnn_status_out_of_memory when a node
// could not be allocated, or the status of a failed validation helper.
enum xnn_status xnn_define_prelu(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t slope_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_prelu)) != xnn_status_success) {
    return status;
  }

  // Input validation.
  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_prelu, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_prelu, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_prelu), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Slope validation.
  if (slope_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with slope ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_prelu), slope_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* slope_value = &subgraph->values[slope_id];
  if (slope_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with slope ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_prelu), slope_id, slope_value->type);
    return xnn_status_invalid_parameter;
  }

  // The slope data pointer is passed to the operator constructor at creation
  // time, so a slope without attached data must be rejected here rather than
  // reaching the constructor as a NULL pointer.
  if (slope_value->data == NULL) {
    xnn_log_error(
      "failed to define %s operator with slope ID #%" PRIu32 ": non-static Value",
      xnn_node_type_to_string(xnn_node_type_prelu), slope_id);
    return xnn_status_invalid_parameter;
  }

  switch (slope_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with slope ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_prelu), slope_id,
        xnn_datatype_to_string(slope_value->datatype), slope_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Output validation.
  status = xnn_subgraph_check_output_node_id(xnn_node_type_prelu, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_prelu, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_prelu), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // All Values validated: allocate the node and fill it in.
  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_prelu;
  node->compute_type = xnn_compute_type_fp32;
  node->num_inputs = 2;
  node->inputs[0] = input_id;
  node->inputs[1] = slope_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_prelu_operator;
  node->setup = setup_prelu_operator;

  return xnn_status_success;
}
218