1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11
12 #include <xnnpack.h>
13 #include <xnnpack/log.h>
14 #include <xnnpack/operator.h>
15 #include <xnnpack/params.h>
16 #include <xnnpack/subgraph.h>
17 #include <xnnpack/subgraph-validation.h>
18
19
create_copy_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_copy_operator(
21 const struct xnn_node* node,
22 const struct xnn_value* values,
23 size_t num_values,
24 struct xnn_operator_data* opdata,
25 const struct xnn_caches* caches)
26 {
27 assert(node->num_inputs == 1);
28 const uint32_t input_id = node->inputs[0];
29 assert(input_id != XNN_INVALID_VALUE_ID);
30 assert(input_id < num_values);
31
32 assert(node->num_outputs == 1);
33 const uint32_t output_id = node->outputs[0];
34 assert(output_id != XNN_INVALID_VALUE_ID);
35 assert(output_id < num_values);
36
37 enum xnn_status status;
38 switch (node->compute_type) {
39 #ifndef XNN_NO_F16_OPERATORS
40 case xnn_compute_type_fp16:
41 status = xnn_create_copy_nc_x16(
42 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
43 node->flags,
44 &opdata->operator_objects[0]);
45 break;
46 #endif // !defined(XNN_NO_F16_OPERATORS)
47 case xnn_compute_type_fp32:
48 status = xnn_create_copy_nc_x32(
49 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
50 node->flags,
51 &opdata->operator_objects[0]);
52 break;
53 #ifndef XNN_NO_QS8_OPERATORS
54 case xnn_compute_type_qs8:
55 #endif // !defined(XNN_NO_QS8_OPERATORS)
56 #ifndef XNN_NO_QU8_OPERATORS
57 case xnn_compute_type_qu8:
58 #endif // !defined(XNN_NO_QU8_OPERATORS)
59 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
60 status = xnn_create_copy_nc_x8(
61 1 /* channels */, 1 /* input stride */, 1 /* output stride */,
62 node->flags,
63 &opdata->operator_objects[0]);
64 break;
65 #endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
66 default:
67 XNN_UNREACHABLE;
68 }
69 if (status == xnn_status_success) {
70 opdata->batch_size = xnn_shape_multiply_all_dims(&values[input_id].shape);
71 opdata->inputs[0] = input_id;
72 opdata->outputs[0] = output_id;
73 }
74 return status;
75 }
76
// Binds the input and output buffers of a Static Reshape node to the copy
// operator previously created by create_copy_operator.
//
// The operator's type (x8/x16/x32 copy) selects which setup routine is
// called. Each routine receives the batch size recorded at creation time,
// the input/output blob data pointers, and the thread pool. The setup
// routine's status is returned as-is.
static enum xnn_status setup_copy_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t in_id = opdata->inputs[0];
  assert(in_id != XNN_INVALID_VALUE_ID);
  assert(in_id < num_blobs);

  const uint32_t out_id = opdata->outputs[0];
  assert(out_id != XNN_INVALID_VALUE_ID);
  assert(out_id < num_blobs);

  const void* src = blobs[in_id].data;
  assert(src != NULL);

  void* dst = blobs[out_id].data;
  assert(dst != NULL);

  switch (opdata->operator_objects[0]->type) {
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8:
      return xnn_setup_copy_nc_x8(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_copy_nc_x16:
      return xnn_setup_copy_nc_x16(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_copy_nc_x32:
      return xnn_setup_copy_nc_x32(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
    default:
      XNN_UNREACHABLE;
  }
}
132
// Defines a Static Reshape node in `subgraph`.
//
// The node copies the dense tensor `input_id` into the dense tensor
// `output_id`. It records `new_shape[0..num_dims)` in the node parameters.
// The data is moved by an element-wise copy operator (see
// create_copy_operator / setup_copy_operator).
//
// Parameters:
//   subgraph  - subgraph to which the new node is appended.
//   num_dims  - number of entries in `new_shape`; at most XNN_MAX_TENSOR_DIMS.
//   new_shape - dimensions to store in the node. Only read when
//               num_dims != 0. NOTE(review): this function does not check
//               new_shape against the output value's shape.
//   input_id  - ID of a dense fp32/qint8/quint8 input value.
//   output_id - ID of a dense output value with the same datatype, the same
//               total element count and (for quantized types) the same
//               zero point and scale as the input.
//   flags     - stored verbatim in the node and passed to operator creation.
//
// Returns:
//   xnn_status_success on success;
//   the status of the failing xnn_subgraph_check_* helper if a generic check
//     fails (uninitialized library, bad IDs, non-dense values, datatype
//     mismatch);
//   xnn_status_invalid_parameter for an unsupported datatype, an element
//     count mismatch, or mismatching quantization parameters;
//   xnn_status_unsupported_parameter if num_dims > XNN_MAX_TENSOR_DIMS;
//   xnn_status_out_of_memory if the node cannot be allocated.
enum xnn_status xnn_define_static_reshape(
  xnn_subgraph_t subgraph,
  size_t num_dims,
  const size_t* new_shape,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_static_reshape)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_input_node_id(xnn_node_type_static_reshape, input_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_static_reshape, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_static_reshape, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_static_reshape, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  // A reshape never adds or drops elements.
  const size_t num_input_elements = xnn_shape_multiply_all_dims(&input_value->shape);
  const size_t num_output_elements = xnn_shape_multiply_all_dims(&output_value->shape);

  if (num_input_elements != num_output_elements) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": number of input elements, %zu, does not match number of output elements %zu",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id, num_input_elements,
      num_output_elements);
    return xnn_status_invalid_parameter;
  }

  // The compute type (and therefore the copy width) follows the output datatype.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches(xnn_node_type_static_reshape, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
  // The data is copied bit-for-bit without requantization, so both
  // quantization parameters must be identical on input and output.
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

  // new_shape is copied into a fixed-capacity array in the node parameters.
  if (num_dims > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to define %s operator with %zu-dimensional output shape: at most %zu dimensions are supported",
      xnn_node_type_to_string(xnn_node_type_static_reshape), num_dims, (size_t) XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.static_reshape.new_shape.num_dims = num_dims;
  if (num_dims != 0) {
    // memcpy requires valid pointers even when copying zero bytes, and a
    // caller defining a 0-dimensional shape might pass new_shape == NULL.
    memcpy(node->params.static_reshape.new_shape.dim, new_shape, num_dims * sizeof(size_t));
  }

  node->type = xnn_node_type_static_reshape;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_copy_operator;
  node->setup = setup_copy_operator;

  return xnn_status_success;
}
274