1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <string.h>
10
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/subgraph.h>
16 #include <xnnpack/subgraph-validation.h>
17
18
create_divide_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)19 static enum xnn_status create_divide_operator(
20 const struct xnn_node* node,
21 const struct xnn_value* values,
22 size_t num_values,
23 struct xnn_operator_data* opdata,
24 const struct xnn_caches* caches)
25 {
26 assert(node->compute_type == xnn_compute_type_fp32);
27
28 assert(node->num_inputs == 2);
29 const uint32_t input1_id = node->inputs[0];
30 assert(input1_id != XNN_INVALID_VALUE_ID);
31 assert(input1_id < num_values);
32 const uint32_t input2_id = node->inputs[1];
33 assert(input2_id != XNN_INVALID_VALUE_ID);
34 assert(input2_id < num_values);
35
36 assert(node->num_outputs == 1);
37 const uint32_t output_id = node->outputs[0];
38 assert(output_id != XNN_INVALID_VALUE_ID);
39 assert(output_id < num_values);
40
41 enum xnn_status status;
42 switch (node->compute_type) {
43 #ifndef XNN_NO_F16_OPERATORS
44 case xnn_compute_type_fp16:
45 status = xnn_create_divide_nd_f16(
46 node->activation.output_min,
47 node->activation.output_max,
48 node->flags,
49 &opdata->operator_objects[0]);
50 break;
51 #endif // !defined(XNN_NO_F16_OPERATORS)
52 case xnn_compute_type_fp32:
53 status = xnn_create_divide_nd_f32(
54 node->activation.output_min,
55 node->activation.output_max,
56 node->flags,
57 &opdata->operator_objects[0]);
58 break;
59 default:
60 XNN_UNREACHABLE;
61 }
62 if (status == xnn_status_success) {
63 opdata->shape1.num_dims = values[input1_id].shape.num_dims;
64 opdata->shape2.num_dims = values[input2_id].shape.num_dims;
65 if (values[output_id].layout == xnn_layout_type_nchw) {
66 assert(values[input1_id].layout == xnn_layout_type_nchw);
67 assert(values[input2_id].layout == xnn_layout_type_nchw);
68 opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
69 opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
70 if (values[input1_id].shape.num_dims > 2) {
71 memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
72 }
73 opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
74 opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
75 if (values[input1_id].shape.num_dims > 2) {
76 memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
77 }
78 } else {
79 assert(values[output_id].layout == xnn_layout_type_nhwc);
80 assert(values[input1_id].layout == xnn_layout_type_nhwc);
81 assert(values[input2_id].layout == xnn_layout_type_nhwc);
82 memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
83 memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
84 }
85 opdata->inputs[0] = input1_id;
86 opdata->inputs[1] = input2_id;
87 opdata->outputs[0] = output_id;
88 }
89 return status;
90 }
91
// Looks up the data pointers of both inputs and the output in blobs and passes
// them, together with the shapes stored in opdata, to the
// xnn_setup_divide_nd_* function matching the type of operator_objects[0].
//
// opdata     - operator data holding the operator object, shape1/shape2 and
//              the input/output value IDs.
// blobs      - array of blobs, indexed by value ID.
// num_blobs  - number of entries in blobs (only used in assertions).
// threadpool - forwarded unchanged to the setup call.
//
// Returns the status reported by the setup call.
static enum xnn_status setup_divide_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t first_id = opdata->inputs[0];
  assert(first_id != XNN_INVALID_VALUE_ID);
  assert(first_id < num_blobs);

  const uint32_t second_id = opdata->inputs[1];
  assert(second_id != XNN_INVALID_VALUE_ID);
  assert(second_id < num_blobs);

  const uint32_t result_id = opdata->outputs[0];
  assert(result_id != XNN_INVALID_VALUE_ID);
  assert(result_id < num_blobs);

  // Every tensor involved is asserted to have a data pointer.
  const void* first_data = blobs[first_id].data;
  assert(first_data != NULL);

  const void* second_data = blobs[second_id].data;
  assert(second_data != NULL);

  void* result_data = blobs[result_id].data;
  assert(result_data != NULL);

  // Dispatch on the type of the operator object.
  enum xnn_status status;
  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_divide_nd_f16:
      status = xnn_setup_divide_nd_f16(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->shape2.num_dims, opdata->shape2.dim,
        first_data, second_data, result_data,
        threadpool);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_divide_nd_f32:
      status = xnn_setup_divide_nd_f32(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->shape2.num_dims, opdata->shape2.dim,
        first_data, second_data, result_data,
        threadpool);
      break;
    default:
      XNN_UNREACHABLE;
  }
  return status;
}
147
// Validates the arguments and appends a Divide node to the subgraph.
//
// subgraph    - subgraph that receives the new node.
// output_min  - stored in the node's activation.output_min.
// output_max  - stored in the node's activation.output_max.
// input1_id   - value ID of the first input; must refer to a dense fp32 value.
// input2_id   - value ID of the second input; must refer to a dense fp32 value.
// output_id   - value ID of the output; must refer to a dense fp32 value.
// flags       - stored in the node unchanged.
//
// Returns xnn_status_success once the node is filled in. Otherwise returns the
// status of the first failing xnn_subgraph_check_* call,
// xnn_status_invalid_parameter for a non-fp32 value, or
// xnn_status_out_of_memory when xnn_subgraph_new_node returns NULL.
enum xnn_status xnn_define_divide(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_divide);
  if (status != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_divide, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  // First input: valid ID, dense value, fp32 datatype.
  status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_divide, input1_id, subgraph->num_values, 1);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* first_value = &subgraph->values[input1_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_divide, input1_id, first_value, 1);
  if (status != xnn_status_success) {
    return status;
  }

  if (first_value->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_divide), input1_id,
      xnn_datatype_to_string(first_value->datatype), first_value->datatype);
    return xnn_status_invalid_parameter;
  }

  // Second input: same three checks.
  status = xnn_subgraph_check_nth_input_node_id(xnn_node_type_divide, input2_id, subgraph->num_values, 2);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* second_value = &subgraph->values[input2_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_divide, input2_id, second_value, 2);
  if (status != xnn_status_success) {
    return status;
  }

  if (second_value->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_divide), input2_id,
      xnn_datatype_to_string(second_value->datatype), second_value->datatype);
    return xnn_status_invalid_parameter;
  }

  // Output: valid ID, dense value, fp32 datatype.
  status = xnn_subgraph_check_output_node_id(xnn_node_type_divide, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* result_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_divide, output_id, result_value);
  if (status != xnn_status_success) {
    return status;
  }

  if (result_value->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_divide), output_id,
      xnn_datatype_to_string(result_value->datatype), result_value->datatype);
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* divide_node = xnn_subgraph_new_node(subgraph);
  if (divide_node == NULL) {
    return xnn_status_out_of_memory;
  }

  // The node is always defined with the fp32 compute type here.
  divide_node->type = xnn_node_type_divide;
  divide_node->compute_type = xnn_compute_type_fp32;
  divide_node->activation.output_min = output_min;
  divide_node->activation.output_max = output_max;
  divide_node->flags = flags;

  divide_node->num_inputs = 2;
  divide_node->inputs[0] = input1_id;
  divide_node->inputs[1] = input2_id;
  divide_node->num_outputs = 1;
  divide_node->outputs[0] = output_id;

  divide_node->create = create_divide_operator;
  divide_node->setup = setup_divide_operator;

  return xnn_status_success;
}
254