xref: /aosp_15_r20/external/XNNPACK/src/subgraph/static-reshape.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11 
12 #include <xnnpack.h>
13 #include <xnnpack/log.h>
14 #include <xnnpack/operator.h>
15 #include <xnnpack/params.h>
16 #include <xnnpack/subgraph.h>
17 #include <xnnpack/subgraph-validation.h>
18 
19 
create_copy_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_copy_operator(
21   const struct xnn_node* node,
22   const struct xnn_value* values,
23   size_t num_values,
24   struct xnn_operator_data* opdata,
25   const struct xnn_caches* caches)
26 {
27   assert(node->num_inputs == 1);
28   const uint32_t input_id = node->inputs[0];
29   assert(input_id != XNN_INVALID_VALUE_ID);
30   assert(input_id < num_values);
31 
32   assert(node->num_outputs == 1);
33   const uint32_t output_id = node->outputs[0];
34   assert(output_id != XNN_INVALID_VALUE_ID);
35   assert(output_id < num_values);
36 
37   enum xnn_status status;
38   switch (node->compute_type) {
39 #ifndef XNN_NO_F16_OPERATORS
40     case xnn_compute_type_fp16:
41       status = xnn_create_copy_nc_x16(
42         1 /* channels */, 1 /* input stride */, 1 /* output stride */,
43         node->flags,
44         &opdata->operator_objects[0]);
45       break;
46 #endif  // !defined(XNN_NO_F16_OPERATORS)
47     case xnn_compute_type_fp32:
48       status = xnn_create_copy_nc_x32(
49         1 /* channels */, 1 /* input stride */, 1 /* output stride */,
50         node->flags,
51         &opdata->operator_objects[0]);
52       break;
53 #ifndef XNN_NO_QS8_OPERATORS
54     case xnn_compute_type_qs8:
55 #endif  // !defined(XNN_NO_QS8_OPERATORS)
56 #ifndef XNN_NO_QU8_OPERATORS
57     case xnn_compute_type_qu8:
58 #endif  // !defined(XNN_NO_QU8_OPERATORS)
59 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
60       status = xnn_create_copy_nc_x8(
61         1 /* channels */, 1 /* input stride */, 1 /* output stride */,
62         node->flags,
63         &opdata->operator_objects[0]);
64       break;
65 #endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
66     default:
67       XNN_UNREACHABLE;
68   }
69   if (status == xnn_status_success) {
70     opdata->batch_size = xnn_shape_multiply_all_dims(&values[input_id].shape);
71     opdata->inputs[0] = input_id;
72     opdata->outputs[0] = output_id;
73   }
74   return status;
75 }
76 
// Binds the Copy operator stored in opdata->operator_objects[0] to the
// current input and output buffers.
//
// The data pointers are read from the blobs indexed by opdata->inputs[0] and
// opdata->outputs[0]. The function then dispatches on the operator's concrete
// type to xnn_setup_copy_nc_x32 / _x8 / _x16, passing opdata->batch_size as
// the number of rows. The setup function's status is returned unchanged.
static enum xnn_status setup_copy_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* src = blobs[input_id].data;
  assert(src != NULL);

  void* dst = blobs[output_id].data;
  assert(dst != NULL);

  struct xnn_operator* copy_op = opdata->operator_objects[0];
  switch (copy_op->type) {
    case xnn_operator_type_copy_nc_x32:
      return xnn_setup_copy_nc_x32(copy_op, opdata->batch_size, src, dst, threadpool);
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_copy_nc_x8:
      return xnn_setup_copy_nc_x8(copy_op, opdata->batch_size, src, dst, threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_copy_nc_x16:
      return xnn_setup_copy_nc_x16(copy_op, opdata->batch_size, src, dst, threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
132 
// Defines a Static Reshape node in `subgraph` that maps Value `input_id` onto
// Value `output_id`.
//
// Parameters:
//   subgraph  - subgraph the new node is appended to.
//   num_dims  - number of entries in `new_shape`; must not exceed
//               XNN_MAX_TENSOR_DIMS.
//   new_shape - output dimensions, copied into the node's parameters. May be
//               NULL only when `num_dims` is 0. It is stored as-is and is not
//               checked against the output Value's shape here.
//   input_id  - ID of a dense fp32, qint8, or quint8 Value in `subgraph`.
//   output_id - ID of a dense Value with:
//                 - the same datatype as the input,
//                 - the same total element count as the input,
//                 - for quantized datatypes, the same zero point and scale.
//   flags     - stored on the node.
//
// Returns:
//   - xnn_status_success on success.
//   - The failing status of any xnn_subgraph_check_* helper.
//   - xnn_status_invalid_parameter for:
//       - an unsupported datatype,
//       - an element-count mismatch,
//       - a zero-point or scale mismatch,
//       - a NULL `new_shape` with non-zero `num_dims`.
//   - xnn_status_unsupported_parameter when `num_dims` is too large.
//   - xnn_status_out_of_memory when the node cannot be allocated.
//
// The node's create/setup callbacks are the Copy-operator helpers defined in
// this file.
enum xnn_status xnn_define_static_reshape(
  xnn_subgraph_t subgraph,
  size_t num_dims,
  const size_t* new_shape,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_static_reshape)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_input_node_id(xnn_node_type_static_reshape, input_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_static_reshape, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_static_reshape, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_static_reshape, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  // A reshape is only a reinterpretation, so both tensors must hold the same
  // number of elements.
  const size_t num_input_elements = xnn_shape_multiply_all_dims(&input_value->shape);
  const size_t num_output_elements = xnn_shape_multiply_all_dims(&output_value->shape);
  if (num_input_elements != num_output_elements) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
      ": number of input elements, %zu, does not match number of output elements %zu",
      xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id, num_input_elements,
      num_output_elements);
    return xnn_status_invalid_parameter;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches(xnn_node_type_static_reshape, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
  // The data is copied verbatim, so quantized input and output must share
  // identical quantization parameters.
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_reshape), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

  if (num_dims > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to define %s operator with %zu-dimensional output shape: at most %zu dimensions are supported",
      xnn_node_type_to_string(xnn_node_type_static_reshape), num_dims, (size_t) XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  // A non-empty shape must actually be provided; it is read by memcpy below.
  if (num_dims != 0 && new_shape == NULL) {
    xnn_log_error(
      "failed to define %s operator with %zu-dimensional output shape: new shape pointer is NULL",
      xnn_node_type_to_string(xnn_node_type_static_reshape), num_dims);
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->params.static_reshape.new_shape.num_dims = num_dims;
  // memcpy requires valid pointers even for a zero-byte copy, so skip it for
  // a rank-0 shape where `new_shape` may be NULL.
  if (num_dims != 0) {
    memcpy(&node->params.static_reshape.new_shape.dim, new_shape, num_dims * sizeof(size_t));
  }

  node->type = xnn_node_type_static_reshape;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_copy_operator;
  node->setup = setup_copy_operator;

  return xnn_status_success;
}
274