xref: /aosp_15_r20/external/XNNPACK/src/subgraph/static-transpose.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <string.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/subgraph.h>
16 #include <xnnpack/subgraph-validation.h>
17 
create_transpose_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)18 static enum xnn_status create_transpose_operator(
19   const struct xnn_node* node,
20   const struct xnn_value* values,
21   size_t num_values,
22   struct xnn_operator_data* opdata,
23   const struct xnn_caches* caches)
24 {
25   assert(node->num_inputs == 1);
26   const uint32_t input_id = node->inputs[0];
27   assert(input_id != XNN_INVALID_VALUE_ID);
28   assert(input_id < num_values);
29 
30   assert(node->num_outputs == 1);
31   const uint32_t output_id = node->outputs[0];
32   assert(output_id != XNN_INVALID_VALUE_ID);
33   assert(output_id < num_values);
34 
35   enum xnn_status status;
36   switch (node->compute_type) {
37     case xnn_compute_type_fp32:
38       status = xnn_create_transpose_nd_x32(node->flags, &opdata->operator_objects[0]);
39       break;
40 #ifndef XNN_NO_F16_OPERATORS
41     case xnn_compute_type_fp16:
42       status = xnn_create_transpose_nd_x16(node->flags, &opdata->operator_objects[0]);
43       break;
44 #endif
45 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
46     case xnn_compute_type_qs8:
47     case xnn_compute_type_qu8:
48       status = xnn_create_transpose_nd_x8(node->flags, &opdata->operator_objects[0]);
49       break;
50 #endif
51     default:
52       XNN_UNREACHABLE;
53   }
54 
55   if (status == xnn_status_success) {
56     opdata->inputs[0] = input_id;
57     opdata->outputs[0] = output_id;
58     opdata->shape1.num_dims = node->params.transpose.num_dims;
59     opdata->shape2.num_dims = node->params.transpose.num_dims;
60     memcpy(opdata->shape1.dim, values[input_id].shape.dim, opdata->shape1.num_dims * sizeof(size_t));
61     memcpy(opdata->shape2.dim, node->params.transpose.perm, opdata->shape2.num_dims * sizeof(size_t));
62   }
63 
64   return status;
65 }
66 
// Binds the input and output buffers of a static-transpose node to the
// operator created by create_transpose_operator.
//
// The input shape (opdata->shape1) and the permutation (opdata->shape2) were
// recorded at creation time. The xnn_setup_transpose_nd_* variant is chosen
// from the operator's type, which encodes the element width (8/16/32 bits).
// Returns the status of that setup call.
static enum xnn_status setup_transpose_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t src_id = opdata->inputs[0];
  assert(src_id != XNN_INVALID_VALUE_ID);
  assert(src_id < num_blobs);

  const uint32_t dst_id = opdata->outputs[0];
  assert(dst_id != XNN_INVALID_VALUE_ID);
  assert(dst_id < num_blobs);

  const void* src = blobs[src_id].data;
  assert(src != NULL);

  void* dst = blobs[dst_id].data;
  assert(dst != NULL);

  // Parameters captured by create_transpose_operator; identical for every
  // element width.
  xnn_operator_t op = opdata->operator_objects[0];
  const size_t rank = opdata->shape1.num_dims;
  const size_t* src_shape = opdata->shape1.dim;
  const size_t* perm = opdata->shape2.dim;

  enum xnn_status result;
  switch (op->type) {
    case xnn_operator_type_transpose_nd_x32:
      result = xnn_setup_transpose_nd_x32(op, src, dst, rank, src_shape, perm, threadpool);
      break;
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_transpose_nd_x16:
      result = xnn_setup_transpose_nd_x16(op, src, dst, rank, src_shape, perm, threadpool);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_transpose_nd_x8:
      result = xnn_setup_transpose_nd_x8(op, src, dst, rank, src_shape, perm, threadpool);
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  return result;
}
134 
xnn_define_static_transpose(xnn_subgraph_t subgraph,size_t num_dims,const size_t * perm,uint32_t input_id,uint32_t output_id,uint32_t flags)135 enum xnn_status xnn_define_static_transpose(
136   xnn_subgraph_t subgraph,
137   size_t num_dims,
138   const size_t* perm,
139   uint32_t input_id,
140   uint32_t output_id,
141   uint32_t flags)
142 {
143   enum xnn_status status;
144   if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_static_transpose)) != xnn_status_success) {
145     return status;
146   }
147 
148   if (num_dims == 0) {
149     xnn_log_error(
150       "failed to create %s operator with %zu num_dims: num_dims must be non-zero",
151       xnn_node_type_to_string(xnn_node_type_static_transpose), num_dims);
152     return xnn_status_invalid_parameter;
153   }
154 
155   if (num_dims > XNN_MAX_TENSOR_DIMS) {
156     xnn_log_error(
157       "failed to create %s operator with %zu num_dims: num_dims must be <= %d",
158       xnn_node_type_to_string(xnn_node_type_static_transpose), num_dims, XNN_MAX_TENSOR_DIMS);
159     return xnn_status_invalid_parameter;
160   }
161 
162   for (size_t i = 0; i < num_dims; ++i) {
163     if (perm[i] >= num_dims) {
164       xnn_log_error(
165           "failed to create %s operator with %zu perm and %zu num_dims: 0 <= perm < num_dims",
166           xnn_node_type_to_string(xnn_node_type_static_transpose), perm[i], num_dims);
167       return xnn_status_invalid_parameter;
168     }
169   }
170 
171   for (size_t i = 0; i < num_dims - 1; ++i) {
172     for (size_t j = i + 1; j < num_dims; ++j) {
173       if (perm[i] == perm[j]) {
174         xnn_log_error(
175             "failed to create %s operator with duplicate entries in perm",
176             xnn_node_type_to_string(xnn_node_type_static_transpose));
177         return xnn_status_invalid_parameter;
178       }
179     }
180   }
181 
182   if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_static_transpose, input_id, subgraph->num_values)) !=
183       xnn_status_success) {
184     return status;
185   }
186 
187   const struct xnn_value* input_value = &subgraph->values[input_id];
188   status = xnn_subgraph_check_input_type_dense(xnn_node_type_static_transpose, input_id, input_value);
189   if (status != xnn_status_success) {
190     return status;
191   }
192 
193   status = xnn_subgraph_check_output_node_id(xnn_node_type_static_transpose, output_id, subgraph->num_values);
194   if (status != xnn_status_success) {
195     return status;
196   }
197 
198   const struct xnn_value* output_value = &subgraph->values[output_id];
199   status = xnn_subgraph_check_output_type_dense(xnn_node_type_static_transpose, output_id, output_value);
200   if (status != xnn_status_success) {
201     return status;
202   }
203 
204   enum xnn_compute_type compute_type = xnn_compute_type_invalid;
205   switch (output_value->datatype) {
206     case xnn_datatype_fp32:
207       compute_type = xnn_compute_type_fp32;
208       break;
209 #ifndef XNN_NO_QS8_OPERATORS
210     case xnn_datatype_qint8:
211       compute_type = xnn_compute_type_qs8;
212       break;
213 #endif  // !defined(XNN_NO_QS8_OPERATORS)
214 #ifndef XNN_NO_QU8_OPERATORS
215     case xnn_datatype_quint8:
216       compute_type = xnn_compute_type_qu8;
217       break;
218 #endif  // !defined(XNN_NO_QU8_OPERATORS)
219     default:
220       xnn_log_error(
221         "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
222         xnn_node_type_to_string(xnn_node_type_static_transpose), output_id,
223         xnn_datatype_to_string(output_value->datatype), output_value->datatype);
224       return xnn_status_invalid_parameter;
225   }
226 
227   switch (input_value->datatype) {
228     case xnn_datatype_fp32:
229 #ifndef XNN_NO_QS8_OPERATORS
230     case xnn_datatype_qint8:
231 #endif  // !defined(XNN_NO_QS8_OPERATORS)
232 #ifndef XNN_NO_QU8_OPERATORS
233     case xnn_datatype_quint8:
234 #endif  // !defined(XNN_NO_QU8_OPERATORS)
235       break;
236     default:
237       xnn_log_error(
238         "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
239         xnn_node_type_to_string(xnn_node_type_static_transpose), input_id,
240         xnn_datatype_to_string(input_value->datatype), input_value->datatype);
241       return xnn_status_invalid_parameter;
242   }
243 
244   status = xnn_subgraph_check_datatype_matches(
245     xnn_node_type_static_transpose, input_id, input_value, output_id, output_value);
246   if (status != xnn_status_success) {
247     return status;
248   }
249 
250   struct xnn_node* node = xnn_subgraph_new_node(subgraph);
251   if (node == NULL) {
252     return xnn_status_out_of_memory;
253   }
254 
255   node->compute_type = compute_type;
256   node->inputs[0] = input_id;
257   node->flags = flags;
258   node->num_inputs = 1;
259   node->num_outputs = 1;
260   node->outputs[0] = output_id;
261   node->type = xnn_node_type_static_transpose;
262 
263   node->params.transpose.num_dims = num_dims;
264   node->create = create_transpose_operator;
265   node->setup = setup_transpose_operator;
266 
267   memcpy(node->params.transpose.perm, perm, num_dims * sizeof(size_t));
268 
269   return xnn_status_success;
270 }
271