// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
17*4bdc9457SAndroid Build Coastguard Worker
18*4bdc9457SAndroid Build Coastguard Worker
// Defines a dense (non-quantized) tensor Value in `subgraph`.
//
// Parameters:
//   subgraph    - the subgraph that receives the Value.
//   datatype    - element type; only xnn_datatype_fp32 and xnn_datatype_fp16 are accepted.
//   num_dims    - number of dimensions; must not exceed XNN_MAX_TENSOR_DIMS.
//   dims        - num_dims dimension sizes, copied into the Value's shape. Not read when
//                 num_dims is 0.
//   data        - stored in the Value as-is; only the pointer is recorded, nothing is copied.
//   external_id - XNN_INVALID_VALUE_ID to allocate a new internal Value, otherwise the index
//                 of a reserved external Value (must be < subgraph->external_value_ids).
//   flags       - stored in the Value unchanged.
//   id_out      - on success, receives the ID of the defined Value.
//
// Returns xnn_status_success on success, xnn_status_uninitialized if XNNPACK is not
// initialized, xnn_status_invalid_parameter for an out-of-range external ID,
// xnn_status_unsupported_parameter for too many dimensions or an unsupported datatype, and
// xnn_status_out_of_memory if a new internal Value cannot be allocated. All parameter checks
// run before any Value is touched, and *id_out is written only on success.
enum xnn_status xnn_define_tensor_value(
  xnn_subgraph_t subgraph,
  enum xnn_datatype datatype,
  size_t num_dims,
  const size_t* dims,
  const void* data,
  uint32_t external_id,
  uint32_t flags,
  uint32_t* id_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create Dense Tensor value: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (external_id != XNN_INVALID_VALUE_ID && external_id >= subgraph->external_value_ids) {
    xnn_log_error(
      "failed to create Dense Tensor value: "
      "external ID %" PRIu32 " exceeds the number of reserved external IDs in subgraph (%" PRIu32 ")",
      external_id, subgraph->external_value_ids);
    return xnn_status_invalid_parameter;
  }

  if (num_dims > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error("failed to create Dense Tensor value: num of dimensions exceeds XNNPACK limit (%d)",
      XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  switch (datatype) {
    case xnn_datatype_fp32:
    case xnn_datatype_fp16:
      break;
    default:
      xnn_log_error("failed to create Dense Tensor value: unsupported datatype %s (%d)",
        xnn_datatype_to_string(datatype), datatype);
      return xnn_status_unsupported_parameter;
  }

  // XNN_INVALID_VALUE_ID is a sentinel, not an index: only offset subgraph->values once
  // external_id is known to be a real external slot. Offsetting by the sentinel may form an
  // out-of-bounds pointer, which is undefined behavior even if it is never dereferenced.
  struct xnn_value* value;
  if (external_id == XNN_INVALID_VALUE_ID) {
    value = xnn_subgraph_new_internal_value(subgraph);
    if (value == NULL) {
      return xnn_status_out_of_memory;
    }
  } else {
    value = subgraph->values + external_id;
  }
  value->type = xnn_value_type_dense_tensor;
  value->datatype = datatype;
  value->shape.num_dims = num_dims;
  if (num_dims != 0) {
    // memcpy requires a valid source pointer even for a zero-byte copy, and callers defining a
    // 0-dimensional tensor may legitimately pass dims == NULL.
    memcpy(value->shape.dim, dims, num_dims * sizeof(size_t));
  }
  value->flags = flags;
  value->data = data;

  *id_out = value->id;
  return xnn_status_success;
}
75*4bdc9457SAndroid Build Coastguard Worker
// Defines a dense, per-tensor quantized tensor Value in `subgraph`.
//
// Parameters:
//   subgraph    - the subgraph that receives the Value.
//   datatype    - one of xnn_datatype_qint8, xnn_datatype_quint8, xnn_datatype_qint32.
//   zero_point  - quantization zero point; must lie in [-128, 127] for qint8, in [0, 255] for
//                 quint8, and must be exactly 0 for qint32.
//   scale       - quantization scale; must be positive and a normal (finite, non-subnormal)
//                 float.
//   num_dims    - number of dimensions; must not exceed XNN_MAX_TENSOR_DIMS.
//   dims        - num_dims dimension sizes, copied into the Value's shape. Not read when
//                 num_dims is 0.
//   data        - stored in the Value as-is; only the pointer is recorded, nothing is copied.
//   external_id - XNN_INVALID_VALUE_ID to allocate a new internal Value, otherwise the index
//                 of a reserved external Value (must be < subgraph->external_value_ids).
//   flags       - stored in the Value unchanged.
//   id_out      - on success, receives the ID of the defined Value.
//
// Returns xnn_status_success on success, xnn_status_uninitialized if XNNPACK is not
// initialized, xnn_status_invalid_parameter for a bad external ID / zero point / scale,
// xnn_status_unsupported_parameter for too many dimensions or an unsupported datatype, and
// xnn_status_out_of_memory if a new internal Value cannot be allocated. All parameter checks
// run before any Value is touched, and *id_out is written only on success.
enum xnn_status xnn_define_quantized_tensor_value(
  xnn_subgraph_t subgraph,
  enum xnn_datatype datatype,
  int32_t zero_point,
  float scale,
  size_t num_dims,
  const size_t* dims,
  const void* data,
  uint32_t external_id,
  uint32_t flags,
  uint32_t* id_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create Quantized Dense Tensor value: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (external_id != XNN_INVALID_VALUE_ID && external_id >= subgraph->external_value_ids) {
    xnn_log_error(
      "failed to create Quantized Dense Tensor value: "
      "external ID %" PRIu32 " exceeds the number of reserved external IDs in subgraph (%" PRIu32 ")",
      external_id, subgraph->external_value_ids);
    return xnn_status_invalid_parameter;
  }

  if (num_dims > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to create Quantized Dense Tensor value: num of dimensions exceeds XNNPACK limit (%d)",
      XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  // Zero-point ranges are checked with explicit bounds rather than a narrowing round-trip cast:
  // converting an out-of-range value to a signed type is implementation-defined.
  switch (datatype) {
    case xnn_datatype_qint8:
      if (zero_point < INT8_MIN || zero_point > INT8_MAX) {
        xnn_log_error(
          "failed to create Quantized Dense Tensor value: invalid zero point %" PRId32" outside the [-128, 127] range",
          zero_point);
        return xnn_status_invalid_parameter;
      }
      break;
    case xnn_datatype_quint8:
      if (zero_point < 0 || zero_point > UINT8_MAX) {
        xnn_log_error(
          "failed to create Quantized Dense Tensor value: invalid zero point %" PRId32" outside the [0, 255] range",
          zero_point);
        return xnn_status_invalid_parameter;
      }
      break;
    case xnn_datatype_qint32:
      if (zero_point != 0) {
        xnn_log_error(
          "failed to create Quantized Dense Tensor value: invalid non-zero zero point %" PRId32,
          zero_point);
        return xnn_status_invalid_parameter;
      }
      break;
    default:
      xnn_log_error("failed to create Quantized Dense Tensor value: unsupported datatype %s (%d)",
        xnn_datatype_to_string(datatype), datatype);
      return xnn_status_unsupported_parameter;
  }

  // isnormal() rejects zero, subnormals, infinities and NaN; the explicit sign test rejects
  // negative normals.
  if (scale <= 0.0f || !isnormal(scale)) {
    xnn_log_error(
      "failed to create Quantized Dense Tensor value with %.7g scale: scale must be finite, normalized, and positive",
      scale);
    return xnn_status_invalid_parameter;
  }

  // XNN_INVALID_VALUE_ID is a sentinel, not an index: only offset subgraph->values once
  // external_id is known to be a real external slot. Offsetting by the sentinel may form an
  // out-of-bounds pointer, which is undefined behavior even if it is never dereferenced.
  struct xnn_value* value;
  if (external_id == XNN_INVALID_VALUE_ID) {
    value = xnn_subgraph_new_internal_value(subgraph);
    if (value == NULL) {
      return xnn_status_out_of_memory;
    }
  } else {
    value = subgraph->values + external_id;
  }
  value->type = xnn_value_type_dense_tensor;
  value->datatype = datatype;
  value->quantization.zero_point = zero_point;
  value->quantization.scale = scale;
  value->shape.num_dims = num_dims;
  if (num_dims != 0) {
    // memcpy requires a valid source pointer even for a zero-byte copy, and callers defining a
    // 0-dimensional tensor may legitimately pass dims == NULL.
    memcpy(value->shape.dim, dims, num_dims * sizeof(size_t));
  }
  value->flags = flags;
  value->data = data;

  *id_out = value->id;
  return xnn_status_success;
}
165*4bdc9457SAndroid Build Coastguard Worker
// Defines a dense tensor Value in `subgraph` that is quantized with one scale per channel
// along dimension `channel_dim` and a zero point of 0.
//
// Parameters:
//   subgraph    - the subgraph that receives the Value.
//   datatype    - xnn_datatype_qcint8 or xnn_datatype_qcint32.
//   scale       - per-channel scales, expected to hold dims[channel_dim] entries; each must be
//                 positive and a normal (finite, non-subnormal) float. Only the pointer is
//                 stored in the Value (the array is not copied), so it must outlive the Value's
//                 use.
//   num_dims    - number of dimensions; must be in [1, XNN_MAX_TENSOR_DIMS].
//   channel_dim - index of the channel dimension; must be < num_dims.
//   dims        - num_dims dimension sizes, copied into the Value's shape.
//   data        - stored in the Value as-is; only the pointer is recorded, nothing is copied.
//   external_id - XNN_INVALID_VALUE_ID to allocate a new internal Value, otherwise the index
//                 of a reserved external Value (must be < subgraph->external_value_ids).
//   flags       - stored in the Value unchanged.
//   id_out      - on success, receives the ID of the defined Value.
//
// Returns xnn_status_success on success, xnn_status_uninitialized if XNNPACK is not
// initialized, xnn_status_invalid_parameter for a bad external ID / dimension count /
// channel_dim / scale, xnn_status_unsupported_parameter for too many dimensions or an
// unsupported datatype, and xnn_status_out_of_memory if a new internal Value cannot be
// allocated. All parameter checks run before any Value is touched, and *id_out is written
// only on success.
enum xnn_status xnn_define_channelwise_quantized_tensor_value(
  xnn_subgraph_t subgraph,
  enum xnn_datatype datatype,
  const float* scale,
  size_t num_dims,
  size_t channel_dim,
  const size_t* dims,
  const void* data,
  uint32_t external_id,
  uint32_t flags,
  uint32_t* id_out)
{
  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create Channelwise Quantized Dense Tensor value: XNNPACK is not initialized");
    return xnn_status_uninitialized;
  }

  if (external_id != XNN_INVALID_VALUE_ID && external_id >= subgraph->external_value_ids) {
    xnn_log_error(
      "failed to create Channelwise Quantized Dense Tensor value: "
      "external ID %" PRIu32 " exceeds the number of reserved external IDs in subgraph (%" PRIu32 ")",
      external_id, subgraph->external_value_ids);
    return xnn_status_invalid_parameter;
  }

  if (num_dims == 0) {
    xnn_log_error(
      "failed to create Channelwise Quantized Dense Tensor value: no channel dimension exists");
    return xnn_status_invalid_parameter;
  }

  if (num_dims > XNN_MAX_TENSOR_DIMS) {
    xnn_log_error(
      "failed to create Channelwise Quantized Dense Tensor value: num of dimensions exceeds XNNPACK limit (%d)",
      XNN_MAX_TENSOR_DIMS);
    return xnn_status_unsupported_parameter;
  }

  if (channel_dim >= num_dims) {
    xnn_log_error(
      "failed to create Channelwise Quantized Dense Tensor value: "
      "channel dimension index %zu is out of range for %zu-dimensional tensor",
      channel_dim, num_dims);
    return xnn_status_invalid_parameter;
  }

  switch (datatype) {
    case xnn_datatype_qcint8:
    case xnn_datatype_qcint32:
      break;
    default:
      xnn_log_error("failed to create Channelwise Quantized Dense Tensor value: unsupported datatype %s (%d)",
        xnn_datatype_to_string(datatype), datatype);
      return xnn_status_unsupported_parameter;
  }

  // The number of scales is the extent of the channel dimension, not of dimension 0: using
  // dims[0] validated the wrong number of entries whenever channel_dim != 0 (and could read
  // past the end of `scale` when dims[0] > dims[channel_dim]).
  const size_t channels = dims[channel_dim];
  for (size_t channel = 0; channel < channels; channel++) {
    // isnormal() rejects zero, subnormals, infinities and NaN; the sign test rejects negatives.
    if (scale[channel] <= 0.0f || !isnormal(scale[channel])) {
      xnn_log_error(
        "failed to create Channelwise Quantized Dense Tensor value with %.7g scale in channel #%zu: "
        "scale must be finite, normalized, and positive",
        scale[channel], channel);
      return xnn_status_invalid_parameter;
    }
  }

  // XNN_INVALID_VALUE_ID is a sentinel, not an index: only offset subgraph->values once
  // external_id is known to be a real external slot. Offsetting by the sentinel may form an
  // out-of-bounds pointer, which is undefined behavior even if it is never dereferenced.
  struct xnn_value* value;
  if (external_id == XNN_INVALID_VALUE_ID) {
    value = xnn_subgraph_new_internal_value(subgraph);
    if (value == NULL) {
      return xnn_status_out_of_memory;
    }
  } else {
    value = subgraph->values + external_id;
  }
  value->type = xnn_value_type_dense_tensor;
  value->datatype = datatype;
  value->quantization.zero_point = 0;
  value->quantization.channelwise_scale = scale;
  value->quantization.channel_dimension = channel_dim;
  value->shape.num_dims = num_dims;
  memcpy(value->shape.dim, dims, num_dims * sizeof(size_t));  // num_dims >= 1 checked above
  value->flags = flags;
  value->data = data;

  *id_out = value->id;
  return xnn_status_success;
}
253*4bdc9457SAndroid Build Coastguard Worker
// Returns the product of every dimension in `shape` (its total element count), or 1 when
// shape->num_dims is 0. The product is accumulated in size_t without an overflow check.
size_t xnn_shape_multiply_all_dims(
  const struct xnn_shape shape[restrict XNN_MIN_ELEMENTS(1)])
{
  size_t product = 1;
  const size_t* next_dim = shape->dim;
  for (size_t remaining = shape->num_dims; remaining != 0; remaining--) {
    product *= *next_dim++;
  }
  return product;
}
263*4bdc9457SAndroid Build Coastguard Worker
// Returns the product of the leading dimensions of `shape`, i.e. all dimensions except the
// trailing `num_nonbatch_dims` ones. Returns 1 when num_nonbatch_dims >= shape->num_dims.
// The product is accumulated in size_t without an overflow check.
size_t xnn_shape_multiply_batch_dims(
  const struct xnn_shape shape[restrict XNN_MIN_ELEMENTS(1)],
  size_t num_nonbatch_dims)
{
  if (num_nonbatch_dims >= shape->num_dims) {
    return 1;
  }

  const size_t num_batch_dims = shape->num_dims - num_nonbatch_dims;
  size_t product = 1;
  for (size_t d = 0; d < num_batch_dims; d++) {
    product *= shape->dim[d];
  }
  return product;
}
274*4bdc9457SAndroid Build Coastguard Worker
// Returns the product of all dimensions of `shape` except the last one, or 1 when the shape
// has fewer than two dimensions. The product is accumulated in size_t without an overflow
// check.
size_t xnn_shape_multiply_non_channel_dims(
  const struct xnn_shape shape[restrict XNN_MIN_ELEMENTS(1)])
{
  size_t product = 1;
  if (shape->num_dims > 1) {
    const size_t last_dim = shape->num_dims - 1;
    for (size_t d = 0; d < last_dim; d++) {
      product *= shape->dim[d];
    }
  }
  return product;
}
284*4bdc9457SAndroid Build Coastguard Worker
// Returns the size in bytes of dense tensor Value `value_id` in `subgraph`: the per-element
// byte size of its datatype times the product of all of its dimensions.
// When assertions are enabled, checks that value_id is in range, that the Value is a dense
// tensor, and that its datatype is not xnn_datatype_invalid.
size_t xnn_tensor_get_size(
  xnn_subgraph_t subgraph,
  uint32_t value_id)
{
  assert(value_id < subgraph->num_values);

  const struct xnn_value* tensor = &subgraph->values[value_id];
  assert(tensor->type == xnn_value_type_dense_tensor);
  assert(tensor->datatype != xnn_datatype_invalid);

  // Bytes per element. There is intentionally no default label, so the compiler can warn
  // (e.g. -Wswitch) about datatypes added to the enum but not sized here; such a datatype
  // would leave element_size at 0.
  size_t element_size = 0;
  switch (tensor->datatype) {
    case xnn_datatype_qint8:
    case xnn_datatype_quint8:
    case xnn_datatype_qcint8:
      element_size = 1;
      break;
    case xnn_datatype_fp16:
      element_size = 2;
      break;
    case xnn_datatype_fp32:
    case xnn_datatype_qint32:
    case xnn_datatype_qcint32:
      element_size = 4;
      break;
    case xnn_datatype_invalid:
      XNN_UNREACHABLE;
  }

  return element_size * xnn_shape_multiply_all_dims(&tensor->shape);
}
318