xref: /aosp_15_r20/external/XNNPACK/src/subgraph/divide.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <string.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/subgraph.h>
16 #include <xnnpack/subgraph-validation.h>
17 
18 
create_divide_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)19 static enum xnn_status create_divide_operator(
20   const struct xnn_node* node,
21   const struct xnn_value* values,
22   size_t num_values,
23   struct xnn_operator_data* opdata,
24   const struct xnn_caches* caches)
25 {
26   assert(node->compute_type == xnn_compute_type_fp32);
27 
28   assert(node->num_inputs == 2);
29   const uint32_t input1_id = node->inputs[0];
30   assert(input1_id != XNN_INVALID_VALUE_ID);
31   assert(input1_id < num_values);
32   const uint32_t input2_id = node->inputs[1];
33   assert(input2_id != XNN_INVALID_VALUE_ID);
34   assert(input2_id < num_values);
35 
36   assert(node->num_outputs == 1);
37   const uint32_t output_id = node->outputs[0];
38   assert(output_id != XNN_INVALID_VALUE_ID);
39   assert(output_id < num_values);
40 
41   enum xnn_status status;
42   switch (node->compute_type) {
43 #ifndef XNN_NO_F16_OPERATORS
44     case xnn_compute_type_fp16:
45       status = xnn_create_divide_nd_f16(
46         node->activation.output_min,
47         node->activation.output_max,
48         node->flags,
49         &opdata->operator_objects[0]);
50       break;
51 #endif  // !defined(XNN_NO_F16_OPERATORS)
52     case xnn_compute_type_fp32:
53       status = xnn_create_divide_nd_f32(
54         node->activation.output_min,
55         node->activation.output_max,
56         node->flags,
57         &opdata->operator_objects[0]);
58       break;
59     default:
60       XNN_UNREACHABLE;
61   }
62   if (status == xnn_status_success) {
63     opdata->shape1.num_dims = values[input1_id].shape.num_dims;
64     opdata->shape2.num_dims = values[input2_id].shape.num_dims;
65     if (values[output_id].layout == xnn_layout_type_nchw) {
66       assert(values[input1_id].layout == xnn_layout_type_nchw);
67       assert(values[input2_id].layout == xnn_layout_type_nchw);
68       opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
69       opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
70       if (values[input1_id].shape.num_dims > 2) {
71         memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
72       }
73       opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
74       opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
75       if (values[input1_id].shape.num_dims > 2) {
76         memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
77       }
78     } else {
79       assert(values[output_id].layout == xnn_layout_type_nhwc);
80       assert(values[input1_id].layout == xnn_layout_type_nhwc);
81       assert(values[input2_id].layout == xnn_layout_type_nhwc);
82       memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
83       memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
84     }
85     opdata->inputs[0] = input1_id;
86     opdata->inputs[1] = input2_id;
87     opdata->outputs[0] = output_id;
88   }
89   return status;
90 }
91 
/*
 * Looks up the data pointers of the two input blobs and the output blob whose
 * IDs were recorded in `opdata`. It then calls the xnn_setup_divide_nd_*
 * function that matches the type of opdata->operator_objects[0], passing the
 * shapes cached by create_divide_operator.
 *
 * Returns the status of that setup call.
 */
static enum xnn_status setup_divide_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t lhs_id = opdata->inputs[0];
  const uint32_t rhs_id = opdata->inputs[1];
  const uint32_t out_id = opdata->outputs[0];

  // Every referenced blob ID must be valid and in range.
  assert(lhs_id != XNN_INVALID_VALUE_ID);
  assert(lhs_id < num_blobs);
  assert(rhs_id != XNN_INVALID_VALUE_ID);
  assert(rhs_id < num_blobs);
  assert(out_id != XNN_INVALID_VALUE_ID);
  assert(out_id < num_blobs);

  // Blobs are expected to be backed by memory by the time setup runs.
  const void* lhs_data = blobs[lhs_id].data;
  assert(lhs_data != NULL);
  const void* rhs_data = blobs[rhs_id].data;
  assert(rhs_data != NULL);
  void* out_data = blobs[out_id].data;
  assert(out_data != NULL);

  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_divide_nd_f32:
      return xnn_setup_divide_nd_f32(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->shape2.num_dims, opdata->shape2.dim,
        lhs_data, rhs_data, out_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_divide_nd_f16:
      return xnn_setup_divide_nd_f16(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->shape2.num_dims, opdata->shape2.dim,
        lhs_data, rhs_data, out_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
147 
/*
 * Appends a Divide node to `subgraph`. input1_id is the node's first input,
 * input2_id its second, and output_id its output.
 *
 * Validation performed before the node is created:
 *   - XNNPACK is initialized;
 *   - [output_min, output_max] passes xnn_subgraph_check_output_min_max;
 *   - each input and the output ID refers to a dense Value of datatype fp32.
 *
 * On success the node is configured with the fp32 compute type, the given
 * activation bounds and flags, and create_divide_operator /
 * setup_divide_operator as its callbacks.
 *
 * Returns xnn_status_success on success. It returns
 * xnn_status_invalid_parameter for a non-fp32 Value and
 * xnn_status_out_of_memory if no node can be allocated. Otherwise it returns
 * the failing status from an xnn_subgraph_check_* helper.
 */
enum xnn_status xnn_define_divide(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_divide);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_divide, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  // First input: valid ID, dense Value, fp32 datatype.
  status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_divide, input1_id, subgraph->num_values, 1);
  if (status != xnn_status_success) {
    return status;
  }
  const struct xnn_value* first_input = &subgraph->values[input1_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_divide, input1_id, first_input, 1);
  if (status != xnn_status_success) {
    return status;
  }
  if (first_input->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_divide), input1_id,
      xnn_datatype_to_string(first_input->datatype), first_input->datatype);
    return xnn_status_invalid_parameter;
  }

  // Second input: valid ID, dense Value, fp32 datatype.
  status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_divide, input2_id, subgraph->num_values, 2);
  if (status != xnn_status_success) {
    return status;
  }
  const struct xnn_value* second_input = &subgraph->values[input2_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_divide, input2_id, second_input, 2);
  if (status != xnn_status_success) {
    return status;
  }
  if (second_input->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_divide), input2_id,
      xnn_datatype_to_string(second_input->datatype), second_input->datatype);
    return xnn_status_invalid_parameter;
  }

  // Output: valid ID, dense Value, fp32 datatype.
  status = xnn_subgraph_check_output_node_id(xnn_node_type_divide, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }
  const struct xnn_value* result = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_divide, output_id, result);
  if (status != xnn_status_success) {
    return status;
  }
  if (result->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_divide), output_id,
      xnn_datatype_to_string(result->datatype), result->datatype);
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_divide;
  node->compute_type = xnn_compute_type_fp32;

  node->num_inputs = 2;
  node->inputs[0] = input1_id;
  node->inputs[1] = input2_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;

  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->flags = flags;

  node->create = create_divide_operator;
  node->setup = setup_divide_operator;

  return xnn_status_success;
}
254