1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2020 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <math.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
8*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
9*4bdc9457SAndroid Build Coastguard Worker #include <stdlib.h>
10*4bdc9457SAndroid Build Coastguard Worker
11*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
12*4bdc9457SAndroid Build Coastguard Worker
13*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
15*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/log.h>
16*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/params.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
19*4bdc9457SAndroid Build Coastguard Worker
20*4bdc9457SAndroid Build Coastguard Worker
21*4bdc9457SAndroid Build Coastguard Worker #ifndef XNN_ENABLE_SPARSE
22*4bdc9457SAndroid Build Coastguard Worker #error "XNN_ENABLE_SPARSE not defined"
23*4bdc9457SAndroid Build Coastguard Worker #endif
24*4bdc9457SAndroid Build Coastguard Worker
// Creates an empty subgraph with `external_value_ids` pre-reserved Value slots.
//
// Slots [0, external_value_ids) are zero-filled and each slot's `id` is set to its
// index; num_values and num_reserved_values both start at external_value_ids.
//
// @param external_value_ids  number of Value IDs reserved up front.
// @param flags               not referenced by this function.
// @param subgraph_out        receives the new subgraph on success; not written on failure.
// @return xnn_status_success on success;
//         xnn_status_uninitialized if XNNPACK has not been initialized;
//         xnn_status_out_of_memory if an allocation fails or the Value array's byte
//         size would not fit in size_t.
enum xnn_status xnn_create_subgraph(
  uint32_t external_value_ids,
  uint32_t flags,
  xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  // With a 32-bit size_t the byte count below can wrap around. The result would be a
  // too-small allocation that the initialization loop then overruns, so reject such
  // requests up front.
  if ((size_t) external_value_ids > SIZE_MAX / sizeof(struct xnn_value)) {
    xnn_log_error("failed to allocate subgraph values: %zu values exceed addressable memory",
      (size_t) external_value_ids);
    goto error;
  }

  subgraph->values = xnn_allocate_zero_memory((size_t) external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values",
      (size_t) external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  // Releases whatever was allocated so far (subgraph may still be NULL here).
  xnn_delete_subgraph(subgraph);
  return status;
}
67*4bdc9457SAndroid Build Coastguard Worker
68*4bdc9457SAndroid Build Coastguard Worker
// Appends one Value slot to subgraph->values and returns a pointer to it.
//
// When the array is full it is grown to
// max(min(2 * capacity, capacity + 512), capacity + 64) slots, and all newly reserved
// slots are zero-filled. The returned Value's `id` is set to its index in the array.
//
// Returns NULL, with the subgraph left unchanged, if the array cannot be grown
// (allocation failure, or a byte size that does not fit in size_t).
// Growing reallocates subgraph->values, so Value pointers obtained before this call
// must not be used afterwards.
struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    // Guard the byte-size multiplication against wrap-around on 32-bit targets.
    if (new_capacity > SIZE_MAX / sizeof(struct xnn_value)) {
      xnn_log_error("failed to allocate subgraph values: %zu values exceed addressable memory", new_capacity);
      return NULL;
    }
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      // Report the size that actually failed, not the previous capacity.
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return NULL;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}
93*4bdc9457SAndroid Build Coastguard Worker
xnn_node_clear(struct xnn_node * node)94*4bdc9457SAndroid Build Coastguard Worker void xnn_node_clear(struct xnn_node* node) {
95*4bdc9457SAndroid Build Coastguard Worker assert(node != NULL);
96*4bdc9457SAndroid Build Coastguard Worker memset(node, 0, sizeof(struct xnn_node));
97*4bdc9457SAndroid Build Coastguard Worker }
98*4bdc9457SAndroid Build Coastguard Worker
xnn_value_clear(struct xnn_value * value)99*4bdc9457SAndroid Build Coastguard Worker void xnn_value_clear(struct xnn_value* value) {
100*4bdc9457SAndroid Build Coastguard Worker assert(value != NULL);
101*4bdc9457SAndroid Build Coastguard Worker memset(value, 0, sizeof(struct xnn_value));
102*4bdc9457SAndroid Build Coastguard Worker }
103*4bdc9457SAndroid Build Coastguard Worker
xnn_value_copy(struct xnn_value * dst_value,const struct xnn_value * src_value)104*4bdc9457SAndroid Build Coastguard Worker void xnn_value_copy(
105*4bdc9457SAndroid Build Coastguard Worker struct xnn_value* dst_value,
106*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* src_value)
107*4bdc9457SAndroid Build Coastguard Worker {
108*4bdc9457SAndroid Build Coastguard Worker // Note: Value ID stays unchanged
109*4bdc9457SAndroid Build Coastguard Worker
110*4bdc9457SAndroid Build Coastguard Worker dst_value->type = src_value->type;
111*4bdc9457SAndroid Build Coastguard Worker dst_value->datatype = src_value->datatype;
112*4bdc9457SAndroid Build Coastguard Worker dst_value->quantization = src_value->quantization;
113*4bdc9457SAndroid Build Coastguard Worker dst_value->shape = src_value->shape;
114*4bdc9457SAndroid Build Coastguard Worker dst_value->flags = src_value->flags;
115*4bdc9457SAndroid Build Coastguard Worker dst_value->data = src_value->data;
116*4bdc9457SAndroid Build Coastguard Worker dst_value->producer = src_value->producer;
117*4bdc9457SAndroid Build Coastguard Worker dst_value->first_consumer = src_value->first_consumer;
118*4bdc9457SAndroid Build Coastguard Worker }
119*4bdc9457SAndroid Build Coastguard Worker
// Appends one Node slot to subgraph->nodes and returns a pointer to it.
//
// When the array is full it is grown to
// max(min(2 * capacity, capacity + 512), capacity + 64) slots, and all newly reserved
// slots are zero-filled. The returned Node's `id` is set to its index in the array.
//
// Returns NULL, with the subgraph left unchanged, if the array cannot be grown
// (allocation failure, or a byte size that does not fit in size_t).
// Growing reallocates subgraph->nodes, so Node pointers obtained before this call
// must not be used afterwards.
struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    // Guard the byte-size multiplication against wrap-around on 32-bit targets.
    if (new_capacity > SIZE_MAX / sizeof(struct xnn_node)) {
      xnn_log_error("failed to allocate subgraph nodes: %zu nodes exceed addressable memory", new_capacity);
      return NULL;
    }
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      // Report the size that actually failed, not the previous capacity.
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return NULL;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}
145*4bdc9457SAndroid Build Coastguard Worker
// Appends `num_nodes` Node slots to subgraph->nodes and sets each new slot's `id`
// to its index.
//
// When the array lacks room it is grown to
// max(min(2 * capacity, capacity + 512), capacity + max(num_nodes, 64)) slots, and all
// newly reserved slots are zero-filled.
//
// On failure (allocation failure, or a request whose byte size does not fit in size_t)
// an error is logged and the subgraph is left unchanged. Because this function returns
// void, callers cannot see the failure directly.
// NOTE(review): callers may want to compare subgraph->num_nodes before and after the call.
// Growing reallocates subgraph->nodes, so Node pointers obtained before this call
// must not be used afterwards.
void xnn_subgraph_add_nodes(xnn_subgraph_t subgraph, size_t num_nodes)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;
  const size_t max_nodes = SIZE_MAX / sizeof(struct xnn_node);

  // `size` already fits in an allocated array, so `max_nodes - size` cannot wrap.
  // This check keeps `size + num_nodes` below from overflowing.
  if (num_nodes > max_nodes - size) {
    xnn_log_error("failed to allocate subgraph nodes: %zu + %zu nodes exceed addressable memory",
      size, num_nodes);
    return;
  }

  if (capacity < size + num_nodes) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + max(num_nodes, 64));
    assert(new_capacity >= size + num_nodes);
    // Guard the byte-size multiplication against wrap-around on 32-bit targets.
    if (new_capacity > max_nodes) {
      xnn_log_error("failed to allocate subgraph nodes: %zu nodes exceed addressable memory", new_capacity);
      return;
    }
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      // Report the size that actually failed, not the previous capacity.
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + num_nodes;
  struct xnn_node* new_nodes = nodes + size;
  for (size_t i = 0; i < num_nodes; i++) {
    new_nodes[i].id = size + i;
  }
}
172*4bdc9457SAndroid Build Coastguard Worker
// Recomputes the producer/consumer bookkeeping for every Value in the subgraph.
//
// After this call, for each Value:
//   - producer is the ID of the Node that lists it as an output, or XNN_INVALID_NODE_ID;
//   - first_consumer is the lowest-ID Node that lists it as an input, or
//     XNN_INVALID_NODE_ID;
//   - num_consumers counts every input slot that references it (a Node listing the
//     same Value twice counts twice), plus one if xnn_value_is_external_output() holds.
// No Values are removed. Debug builds assert that every referenced Value ID is in range
// and that no Value has more than one producer.
void xnn_subgraph_analyze_consumers_and_producers(xnn_subgraph_t subgraph)
{
  // Pass 1: reset every Value to "no producer, no consumers".
  for (uint32_t value_id = 0; value_id < subgraph->num_values; value_id++) {
    struct xnn_value* v = subgraph->values + value_id;
    v->num_consumers = 0;
    v->first_consumer = XNN_INVALID_NODE_ID;
    v->producer = XNN_INVALID_NODE_ID;
  }

  // Pass 2: visit Nodes in ID order, so the first Node seen reading a Value is
  // recorded as that Value's first consumer.
  for (uint32_t node_id = 0; node_id < subgraph->num_nodes; node_id++) {
    const struct xnn_node* node = subgraph->nodes + node_id;

    for (uint32_t k = 0; k < node->num_inputs; k++) {
      const uint32_t input_id = node->inputs[k];
      assert(input_id < subgraph->num_values);
      struct xnn_value* input = subgraph->values + input_id;
      if (input->num_consumers == 0) {
        assert(input->first_consumer == XNN_INVALID_NODE_ID);
        input->first_consumer = node_id;
      }
      input->num_consumers += 1;
    }

    for (uint32_t k = 0; k < node->num_outputs; k++) {
      const uint32_t output_id = node->outputs[k];
      assert(output_id < subgraph->num_values);
      struct xnn_value* output = subgraph->values + output_id;
      assert(output->producer == XNN_INVALID_NODE_ID);
      output->producer = node_id;
    }
  }

  // Pass 3: external outputs get one extra consumer (presumably the caller that reads
  // the result — NOTE(review): confirm).
  for (uint32_t value_id = 0; value_id < subgraph->num_values; value_id++) {
    struct xnn_value* v = subgraph->values + value_id;
    if (xnn_value_is_external_output(v)) {
      v->num_consumers += 1;
    }
  }
}
215*4bdc9457SAndroid Build Coastguard Worker
// Bit flags returned by xnn_check_nchw_compatibility; several may be OR-ed together.
// NOTE(review): the meanings below are inferred from the flag names and the node types
// that return them — confirm against the layout-planning code.
//   COMPATIBLE_NCHW      - node can run with NCHW input and output.
//   COMPATIBLE_NHWC2NCHW - node can read NHWC input and produce NCHW output.
//   COMPATIBLE_NCHW2NHWC - node can read NCHW input and produce NHWC output.
//   INCOMPATIBLE_CLUSTER - not returned by any code visible here.
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW      (1 << 0)
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW (1 << 1)
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC (1 << 2)
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER (1 << 3)
220*4bdc9457SAndroid Build Coastguard Worker
xnn_check_nchw_compatibility(xnn_subgraph_t subgraph,struct xnn_node * node)221*4bdc9457SAndroid Build Coastguard Worker uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
222*4bdc9457SAndroid Build Coastguard Worker if (node->compute_type != xnn_compute_type_fp32) {
223*4bdc9457SAndroid Build Coastguard Worker if (node->type != xnn_node_type_invalid) {
224*4bdc9457SAndroid Build Coastguard Worker xnn_log_info(
225*4bdc9457SAndroid Build Coastguard Worker "Node %s compute type %d is incompatible with sparse inference",
226*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type), node->compute_type);
227*4bdc9457SAndroid Build Coastguard Worker }
228*4bdc9457SAndroid Build Coastguard Worker return 0;
229*4bdc9457SAndroid Build Coastguard Worker }
230*4bdc9457SAndroid Build Coastguard Worker
231*4bdc9457SAndroid Build Coastguard Worker switch (node->type) {
232*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_convolution_2d:
233*4bdc9457SAndroid Build Coastguard Worker // Supported cases:
234*4bdc9457SAndroid Build Coastguard Worker // - 1x1 convolution (no stride, no dilation, no padding, no groups)
235*4bdc9457SAndroid Build Coastguard Worker // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
236*4bdc9457SAndroid Build Coastguard Worker if (node->params.convolution_2d.groups != 1) {
237*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s groups (%" PRIu32 ") "
238*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
239*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
240*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.groups);
241*4bdc9457SAndroid Build Coastguard Worker return 0;
242*4bdc9457SAndroid Build Coastguard Worker }
243*4bdc9457SAndroid Build Coastguard Worker if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
244*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
245*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
246*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
247*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.dilation_height,
248*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.dilation_width);
249*4bdc9457SAndroid Build Coastguard Worker return 0;
250*4bdc9457SAndroid Build Coastguard Worker }
251*4bdc9457SAndroid Build Coastguard Worker if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
252*4bdc9457SAndroid Build Coastguard Worker if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
253*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0) {
254*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s (1x1 kernel) padding (top=%" PRIu32 ", right=%" PRIu32", bottom=%" PRIu32 ", left=%" PRIu32") "
255*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
256*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
257*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_top,
258*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_right,
259*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_bottom,
260*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_left);
261*4bdc9457SAndroid Build Coastguard Worker return 0;
262*4bdc9457SAndroid Build Coastguard Worker }
263*4bdc9457SAndroid Build Coastguard Worker if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
264*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s (1x1 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
265*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
266*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
267*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.subsampling_height,
268*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.subsampling_width);
269*4bdc9457SAndroid Build Coastguard Worker return 0;
270*4bdc9457SAndroid Build Coastguard Worker }
271*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
272*4bdc9457SAndroid Build Coastguard Worker } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
273*4bdc9457SAndroid Build Coastguard Worker if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
274*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1) {
275*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s (3x3 kernel) padding (top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
276*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
277*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
278*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_top,
279*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_right,
280*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_bottom,
281*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.input_padding_left);
282*4bdc9457SAndroid Build Coastguard Worker return 0;
283*4bdc9457SAndroid Build Coastguard Worker }
284*4bdc9457SAndroid Build Coastguard Worker if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
285*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s (3x3 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
286*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
287*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
288*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.subsampling_height,
289*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.subsampling_width);
290*4bdc9457SAndroid Build Coastguard Worker return 0;
291*4bdc9457SAndroid Build Coastguard Worker }
292*4bdc9457SAndroid Build Coastguard Worker if (node->params.convolution_2d.group_input_channels != 3) {
293*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s (3x3 kernel) input channels (%zu) "
294*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
295*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
296*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.group_input_channels);
297*4bdc9457SAndroid Build Coastguard Worker return 0;
298*4bdc9457SAndroid Build Coastguard Worker }
299*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
300*4bdc9457SAndroid Build Coastguard Worker }
301*4bdc9457SAndroid Build Coastguard Worker return 0;
302*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_depthwise_convolution_2d:
303*4bdc9457SAndroid Build Coastguard Worker // Supported cases:
304*4bdc9457SAndroid Build Coastguard Worker // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
305*4bdc9457SAndroid Build Coastguard Worker // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
306*4bdc9457SAndroid Build Coastguard Worker // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
307*4bdc9457SAndroid Build Coastguard Worker // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
308*4bdc9457SAndroid Build Coastguard Worker if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
309*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
310*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
311*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
312*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.dilation_height,
313*4bdc9457SAndroid Build Coastguard Worker node->params.convolution_2d.dilation_width);
314*4bdc9457SAndroid Build Coastguard Worker return 0;
315*4bdc9457SAndroid Build Coastguard Worker }
316*4bdc9457SAndroid Build Coastguard Worker if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
317*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s flags (%" PRIu32 ") has padding incompatible with sparse inference",
318*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
319*4bdc9457SAndroid Build Coastguard Worker node->flags);
320*4bdc9457SAndroid Build Coastguard Worker return 0;
321*4bdc9457SAndroid Build Coastguard Worker }
322*4bdc9457SAndroid Build Coastguard Worker if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
323*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s depth_multiplier (%" PRIu32 ") is incompatible with sparse inference",
324*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
325*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.depth_multiplier);
326*4bdc9457SAndroid Build Coastguard Worker return 0;
327*4bdc9457SAndroid Build Coastguard Worker }
328*4bdc9457SAndroid Build Coastguard Worker if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
329*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
330*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
331*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
332*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.subsampling_height,
333*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.subsampling_width);
334*4bdc9457SAndroid Build Coastguard Worker return 0;
335*4bdc9457SAndroid Build Coastguard Worker }
336*4bdc9457SAndroid Build Coastguard Worker switch (node->params.depthwise_convolution_2d.subsampling_height) {
337*4bdc9457SAndroid Build Coastguard Worker case 1:
338*4bdc9457SAndroid Build Coastguard Worker case 2:
339*4bdc9457SAndroid Build Coastguard Worker break;
340*4bdc9457SAndroid Build Coastguard Worker default:
341*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s subsampling_height (%" PRIu32 ") "
342*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
343*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
344*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.subsampling_height);
345*4bdc9457SAndroid Build Coastguard Worker return 0;
346*4bdc9457SAndroid Build Coastguard Worker }
347*4bdc9457SAndroid Build Coastguard Worker if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
348*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s kernel (height=%" PRIu32 ", width=%" PRIu32 ") "
349*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
350*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
351*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.kernel_height,
352*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.kernel_width);
353*4bdc9457SAndroid Build Coastguard Worker return 0;
354*4bdc9457SAndroid Build Coastguard Worker }
355*4bdc9457SAndroid Build Coastguard Worker switch (node->params.depthwise_convolution_2d.kernel_height) {
356*4bdc9457SAndroid Build Coastguard Worker case 3:
357*4bdc9457SAndroid Build Coastguard Worker if (node->params.depthwise_convolution_2d.input_padding_top == 1 &&
358*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_right == 1 &&
359*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
360*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_left == 1) {
361*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
362*4bdc9457SAndroid Build Coastguard Worker } else {
363*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s (3x3 kernel) padding "
364*4bdc9457SAndroid Build Coastguard Worker "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
365*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
366*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
367*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_top,
368*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_right,
369*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_bottom,
370*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_left);
371*4bdc9457SAndroid Build Coastguard Worker return 0;
372*4bdc9457SAndroid Build Coastguard Worker }
373*4bdc9457SAndroid Build Coastguard Worker case 5:
374*4bdc9457SAndroid Build Coastguard Worker if (node->params.depthwise_convolution_2d.input_padding_top == 2 &&
375*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_right == 2 &&
376*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
377*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_left == 2) {
378*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
379*4bdc9457SAndroid Build Coastguard Worker } else {
380*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s (5x5 kernel) padding "
381*4bdc9457SAndroid Build Coastguard Worker "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
382*4bdc9457SAndroid Build Coastguard Worker "is incompatible with sparse inference",
383*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type),
384*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_top,
385*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_right,
386*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_bottom,
387*4bdc9457SAndroid Build Coastguard Worker node->params.depthwise_convolution_2d.input_padding_left);
388*4bdc9457SAndroid Build Coastguard Worker return 0;
389*4bdc9457SAndroid Build Coastguard Worker }
390*4bdc9457SAndroid Build Coastguard Worker default:
391*4bdc9457SAndroid Build Coastguard Worker return 0;
392*4bdc9457SAndroid Build Coastguard Worker }
393*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_depth_to_space:
394*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
395*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_global_average_pooling_2d:
396*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
397*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_add2:
398*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_multiply2:
399*4bdc9457SAndroid Build Coastguard Worker assert(node->num_inputs == 2);
400*4bdc9457SAndroid Build Coastguard Worker assert(node->num_outputs == 1);
401*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
402*4bdc9457SAndroid Build Coastguard Worker subgraph->values[node->inputs[1]].shape.num_dims != 4)
403*4bdc9457SAndroid Build Coastguard Worker {
404*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
405*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type));
406*4bdc9457SAndroid Build Coastguard Worker return 0;
407*4bdc9457SAndroid Build Coastguard Worker }
408*4bdc9457SAndroid Build Coastguard Worker
409*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[0]].data != NULL) {
410*4bdc9457SAndroid Build Coastguard Worker // Check that the first input is representable as either a scalar, or a vector
411*4bdc9457SAndroid Build Coastguard Worker size_t num_nonunit_dims = 0;
412*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
413*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
414*4bdc9457SAndroid Build Coastguard Worker num_nonunit_dims += 1;
415*4bdc9457SAndroid Build Coastguard Worker }
416*4bdc9457SAndroid Build Coastguard Worker }
417*4bdc9457SAndroid Build Coastguard Worker if (num_nonunit_dims > 1) {
418*4bdc9457SAndroid Build Coastguard Worker return 0;
419*4bdc9457SAndroid Build Coastguard Worker }
420*4bdc9457SAndroid Build Coastguard Worker }
421*4bdc9457SAndroid Build Coastguard Worker
422*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[1]].data != NULL) {
423*4bdc9457SAndroid Build Coastguard Worker // Check that the second input is representable as either a scalar, or a vector
424*4bdc9457SAndroid Build Coastguard Worker size_t num_nonunit_dims = 0;
425*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
426*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
427*4bdc9457SAndroid Build Coastguard Worker num_nonunit_dims += 1;
428*4bdc9457SAndroid Build Coastguard Worker }
429*4bdc9457SAndroid Build Coastguard Worker }
430*4bdc9457SAndroid Build Coastguard Worker if (num_nonunit_dims > 1) {
431*4bdc9457SAndroid Build Coastguard Worker return 0;
432*4bdc9457SAndroid Build Coastguard Worker }
433*4bdc9457SAndroid Build Coastguard Worker }
434*4bdc9457SAndroid Build Coastguard Worker
435*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
436*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_static_resize_bilinear_2d:
437*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
438*4bdc9457SAndroid Build Coastguard Worker subgraph->values[node->inputs[0]].shape.dim[2] > 1) {
439*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
440*4bdc9457SAndroid Build Coastguard Worker } else {
441*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
442*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type));
443*4bdc9457SAndroid Build Coastguard Worker return 0;
444*4bdc9457SAndroid Build Coastguard Worker }
445*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_abs:
446*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_bankers_rounding:
447*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_ceiling:
448*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_clamp:
449*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_elu:
450*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_floor:
451*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_hardswish:
452*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_leaky_relu:
453*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_negate:
454*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_sigmoid:
455*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_square:
456*4bdc9457SAndroid Build Coastguard Worker assert(node->num_inputs == 1);
457*4bdc9457SAndroid Build Coastguard Worker assert(node->num_outputs == 1);
458*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[0]].shape.num_dims == 4) {
459*4bdc9457SAndroid Build Coastguard Worker return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
460*4bdc9457SAndroid Build Coastguard Worker } else {
461*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
462*4bdc9457SAndroid Build Coastguard Worker xnn_node_type_to_string(node->type));
463*4bdc9457SAndroid Build Coastguard Worker return 0;
464*4bdc9457SAndroid Build Coastguard Worker }
465*4bdc9457SAndroid Build Coastguard Worker default:
466*4bdc9457SAndroid Build Coastguard Worker return false;
467*4bdc9457SAndroid Build Coastguard Worker }
468*4bdc9457SAndroid Build Coastguard Worker }
469*4bdc9457SAndroid Build Coastguard Worker
// Links Node #node_id into the cluster of every Node that produces one of its dynamic inputs.
//
// One relaxation step of the connected-components (cluster) detection performed by
// xnn_subgraph_rewrite_for_nchw(). For every input of Node #node_id:
// - static inputs (non-NULL data) are skipped: they have no producer Node, and their NCHW
//   compatibility was already validated by xnn_check_nchw_compatibility();
// - external inputs mark Node #node_id with XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
// - if the producer is NCHW- or NHWC->NCHW-compatible and not already flagged with
//   XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER, the producer's XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC flag
//   is cleared and both Nodes adopt the larger of their two cluster leader ids;
// - otherwise Node #node_id is marked with XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER.
//
// Returns true if any cluster leader id changed (i.e. another relaxation pass may be needed).
static bool xnn_link_nchw_cluster_producers(xnn_subgraph_t subgraph, uint32_t node_id)
{
  struct xnn_node* node = &subgraph->nodes[node_id];
  bool leader_changed = false;
  for (uint32_t i = 0; i < node->num_inputs; i++) {
    const struct xnn_value* value = &subgraph->values[node->inputs[i]];
    if (value->data != NULL) {
      // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
      // during the initial NCHW compatibility check for the Node.
      continue;
    }
    if (xnn_value_is_external(value)) {
      // External value, invalid cluster
      node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      continue;
    }
    const uint32_t producer_id = value->producer;
    assert(producer_id != XNN_INVALID_NODE_ID);
    assert(producer_id < node_id);
    struct xnn_node* producer_node = &subgraph->nodes[producer_id];
    if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
        (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
    {
      producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
      if (producer_node->cluster_leader != node->cluster_leader) {
        // Merge the two clusters under the larger leader id.
        producer_node->cluster_leader = node->cluster_leader =
          math_max_u32(producer_node->cluster_leader, node->cluster_leader);
        leader_changed = true;
      }
    } else {
      node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
    }
  }
  return leader_changed;
}

// Returns true if the Node is NCHW- or NCHW->NHWC-compatible and its cluster leader is not flagged
// with XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER. The leader's flags are read at call time, so a cluster
// invalidated earlier in the same pass is observed immediately.
static bool xnn_node_in_nchw_cluster(xnn_subgraph_t subgraph, const struct xnn_node* node)
{
  if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
    return false;
  }
  return (node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0;
}

// Converts parts of the subgraph to NCHW layout for sparse inference.
//
// Step 1: detect NCHW-compatible Nodes (xnn_check_nchw_compatibility).
// Step 2: detect NCHW-compatible clusters by running a connected-components algorithm over
//         producer->consumer edges between compatible Nodes.
// Step 3: invalidate clusters whose Values are also consumed by Nodes outside NCHW-compatible clusters.
// Step 4: invalidate clusters without 1x1 Convolutions, or whose 1x1 Convolution weights contain
//         at most 2/3 zeroes.
// Step 5: switch the layout of the dynamic inputs of Nodes in the remaining clusters to NCHW.
void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
{
  // Step 1: per-Node NCHW compatibility.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
      n, xnn_node_type_to_string(node->type),
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
  }

  // Step 2: Shiloach-Vishkin-style connected components. Every Node starts as its own cluster
  // leader; Nodes flagged XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC are linked with their producers first.
  bool update = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      if (xnn_link_nchw_cluster_producers(subgraph, n)) {
        update = true;
      }
    }
  }
  // No cluster leader changed, i.e. no NCHW->NHWC-compatible Node was merged with an
  // NCHW-compatible producer, so the graph rewriting practically cannot happen.
  if (!update) {
    return;
  }
  // Propagate cluster leaders to the other compatible Nodes until no leader changes anymore.
  while (update) {
    update = false;
    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
      struct xnn_node* node = &subgraph->nodes[n];
      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
        continue;
      }
      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
        continue;
      }
      if (xnn_link_nchw_cluster_producers(subgraph, n)) {
        update = true;
      }
    }
  }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }

  // Step 3: check that all Values consumed by NCHW-compatible clusters don't have NCHW-incompatible
  // consumers. First count the NCHW-compatible consumers of every dynamic Value...
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (!xnn_node_in_nchw_cluster(subgraph, node)) {
      continue;
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      value->num_nchw_compatible_consumers += 1;
    }
  }
  // ...then invalidate the cluster of any Node whose input is also consumed elsewhere.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (!xnn_node_in_nchw_cluster(subgraph, node)) {
      continue;
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }

  // Step 4: evaluate if it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights, accumulated per cluster leader
  // - Disable sparse rewriting for clusters without 1x1 Convolutions (num_params == 0)
  //   or with at most 2/3rd of zeroes in 1x1 Convolution filters
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

      // 1x1 kernel: the filter holds dim[0] * dim[3] weights.
      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      // Weights are read as FP32 here.
      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }

  // Step 5: switch the dynamic inputs of Nodes in profitable NCHW-compatible clusters to NCHW layout.
  bool use_nchw_layout = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (!xnn_node_in_nchw_cluster(subgraph, node)) {
      continue;
    }

    const struct xnn_node* leader = &subgraph->nodes[node->cluster_leader];
    // Sparse inference requires strictly more than 2/3 zero weights across the cluster's 1x1 Convolutions.
    if (leader->num_zeroes * 3 <= leader->num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, leader->num_zeroes, leader->num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
        use_nchw_layout = true;
      }
    }
  }
  if (use_nchw_layout) {
    xnn_log_info("XNNPACK has switched to sparse inference mode!");
  }
}
691*4bdc9457SAndroid Build Coastguard Worker
xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)692*4bdc9457SAndroid Build Coastguard Worker bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
693*4bdc9457SAndroid Build Coastguard Worker {
694*4bdc9457SAndroid Build Coastguard Worker xnn_log_info("Analyzing subgraph for FP16 compatibility");
695*4bdc9457SAndroid Build Coastguard Worker
696*4bdc9457SAndroid Build Coastguard Worker // Convert tensors and operators in the subgraph to FP16
697*4bdc9457SAndroid Build Coastguard Worker // 1. Check that all operators in the subgraph are supported in FP16.
698*4bdc9457SAndroid Build Coastguard Worker // 2. Indicate values that must be converted to FP16.
699*4bdc9457SAndroid Build Coastguard Worker // 3. Replace FP32 Values with FP16 Values as Nodes' inputs/outputs.
700*4bdc9457SAndroid Build Coastguard Worker // 4. Insert FP32->FP16 Convert Nodes for external FP32 inputs and FP16->FP32 Convert Nodes for external outputs.
701*4bdc9457SAndroid Build Coastguard Worker
702*4bdc9457SAndroid Build Coastguard Worker // Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one.
703*4bdc9457SAndroid Build Coastguard Worker for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
704*4bdc9457SAndroid Build Coastguard Worker struct xnn_node* node = &subgraph->nodes[n];
705*4bdc9457SAndroid Build Coastguard Worker if (node->type == xnn_node_type_invalid) {
706*4bdc9457SAndroid Build Coastguard Worker // Node was fused away, skip.
707*4bdc9457SAndroid Build Coastguard Worker continue;
708*4bdc9457SAndroid Build Coastguard Worker }
709*4bdc9457SAndroid Build Coastguard Worker
710*4bdc9457SAndroid Build Coastguard Worker if (node->compute_type != xnn_compute_type_fp32) {
711*4bdc9457SAndroid Build Coastguard Worker xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type));
712*4bdc9457SAndroid Build Coastguard Worker return false;
713*4bdc9457SAndroid Build Coastguard Worker }
714*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_inputs; i++) {
715*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[i]].layout == xnn_layout_type_nchw) {
716*4bdc9457SAndroid Build Coastguard Worker xnn_log_warning(
717*4bdc9457SAndroid Build Coastguard Worker "FP16 rewrite aborted: input #%" PRIu32 " (Value #%" PRIu32 ") of node #%" PRIu32 " (%s) has NCHW layout",
718*4bdc9457SAndroid Build Coastguard Worker i, node->inputs[i], n, xnn_node_type_to_string(node->type));
719*4bdc9457SAndroid Build Coastguard Worker return false;
720*4bdc9457SAndroid Build Coastguard Worker }
721*4bdc9457SAndroid Build Coastguard Worker }
722*4bdc9457SAndroid Build Coastguard Worker for (uint32_t o = 0; o < node->num_outputs; o++) {
723*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->outputs[o]].layout == xnn_layout_type_nchw) {
724*4bdc9457SAndroid Build Coastguard Worker xnn_log_warning(
725*4bdc9457SAndroid Build Coastguard Worker "FP16 rewrite aborted: output #%" PRIu32 " (Value #%" PRIu32 ") of node #%" PRIu32 " (%s) has NCHW layout",
726*4bdc9457SAndroid Build Coastguard Worker o, node->outputs[o], n, xnn_node_type_to_string(node->type));
727*4bdc9457SAndroid Build Coastguard Worker return false;
728*4bdc9457SAndroid Build Coastguard Worker }
729*4bdc9457SAndroid Build Coastguard Worker }
730*4bdc9457SAndroid Build Coastguard Worker switch (node->type) {
731*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_abs:
732*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_add2:
733*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_divide:
734*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_maximum2:
735*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_minimum2:
736*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_multiply2:
737*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_concatenate2:
738*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_concatenate3:
739*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_concatenate4:
740*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_squared_difference:
741*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_subtract:
742*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_inputs; i++) {
743*4bdc9457SAndroid Build Coastguard Worker if (subgraph->values[node->inputs[i]].data != NULL) {
744*4bdc9457SAndroid Build Coastguard Worker xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) has static input %" PRIu32,
745*4bdc9457SAndroid Build Coastguard Worker n, xnn_node_type_to_string(node->type), i);
746*4bdc9457SAndroid Build Coastguard Worker return false;
747*4bdc9457SAndroid Build Coastguard Worker }
748*4bdc9457SAndroid Build Coastguard Worker }
749*4bdc9457SAndroid Build Coastguard Worker break;
750*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_average_pooling_2d:
751*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_bankers_rounding:
752*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_ceiling:
753*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_clamp:
754*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_convolution_2d:
755*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_deconvolution_2d:
756*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_depthwise_convolution_2d:
757*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_depth_to_space:
758*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_elu:
759*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_even_split2:
760*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_even_split3:
761*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_even_split4:
762*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_floor:
763*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_fully_connected:
764*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_global_average_pooling_2d:
765*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_hardswish:
766*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_leaky_relu:
767*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_max_pooling_2d:
768*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_negate:
769*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_prelu:
770*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_sigmoid:
771*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_softmax:
772*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_static_constant_pad:
773*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_static_reshape:
774*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_static_resize_bilinear_2d:
775*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_static_transpose:
776*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_square:
777*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_square_root:
778*4bdc9457SAndroid Build Coastguard Worker break;
779*4bdc9457SAndroid Build Coastguard Worker default:
780*4bdc9457SAndroid Build Coastguard Worker xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference",
781*4bdc9457SAndroid Build Coastguard Worker n, xnn_node_type_to_string(node->type));
782*4bdc9457SAndroid Build Coastguard Worker return false;
783*4bdc9457SAndroid Build Coastguard Worker }
784*4bdc9457SAndroid Build Coastguard Worker }
785*4bdc9457SAndroid Build Coastguard Worker
786*4bdc9457SAndroid Build Coastguard Worker // Annotate Values to be converted to FP16 as FP16-compatible.
787*4bdc9457SAndroid Build Coastguard Worker // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32,
788*4bdc9457SAndroid Build Coastguard Worker // they will be converted to FP16 during weight repacking when the operator is created.
789*4bdc9457SAndroid Build Coastguard Worker for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
790*4bdc9457SAndroid Build Coastguard Worker struct xnn_node* node = &subgraph->nodes[n];
791*4bdc9457SAndroid Build Coastguard Worker switch (node->type) {
792*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_convolution_2d:
793*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_deconvolution_2d:
794*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_depthwise_convolution_2d:
795*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_fully_connected:
796*4bdc9457SAndroid Build Coastguard Worker case xnn_node_type_prelu:
797*4bdc9457SAndroid Build Coastguard Worker subgraph->values[node->inputs[0]].fp16_compatible = true;
798*4bdc9457SAndroid Build Coastguard Worker subgraph->values[node->outputs[0]].fp16_compatible = true;
799*4bdc9457SAndroid Build Coastguard Worker break;
800*4bdc9457SAndroid Build Coastguard Worker default:
801*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_inputs; i++) {
802*4bdc9457SAndroid Build Coastguard Worker subgraph->values[node->inputs[i]].fp16_compatible = true;
803*4bdc9457SAndroid Build Coastguard Worker }
804*4bdc9457SAndroid Build Coastguard Worker for (uint32_t o = 0; o < node->num_outputs; o++) {
805*4bdc9457SAndroid Build Coastguard Worker subgraph->values[node->outputs[o]].fp16_compatible = true;
806*4bdc9457SAndroid Build Coastguard Worker }
807*4bdc9457SAndroid Build Coastguard Worker break;
808*4bdc9457SAndroid Build Coastguard Worker }
809*4bdc9457SAndroid Build Coastguard Worker }
810*4bdc9457SAndroid Build Coastguard Worker
811*4bdc9457SAndroid Build Coastguard Worker // Replace FP32 Values in Nodes' inputs/outputs with FP16 Values.
812*4bdc9457SAndroid Build Coastguard Worker // FP32 Values that are not external inputs or outputs are converted to FP16 in-place,
813*4bdc9457SAndroid Build Coastguard Worker // for external inputs and outputs we create same-shaped FP16 Values and use those instead.
814*4bdc9457SAndroid Build Coastguard Worker const uint32_t num_original_values = subgraph->num_values;
815*4bdc9457SAndroid Build Coastguard Worker xnn_subgraph_analyze_consumers_and_producers(subgraph);
816*4bdc9457SAndroid Build Coastguard Worker for (uint32_t n = 0; n < num_original_values; n++) {
817*4bdc9457SAndroid Build Coastguard Worker struct xnn_value* value = &subgraph->values[n];
818*4bdc9457SAndroid Build Coastguard Worker value->fp16_id = XNN_INVALID_VALUE_ID;
819*4bdc9457SAndroid Build Coastguard Worker value->fp32_id = XNN_INVALID_VALUE_ID;
820*4bdc9457SAndroid Build Coastguard Worker if (value->fp16_compatible) {
821*4bdc9457SAndroid Build Coastguard Worker assert(value->data == NULL);
822*4bdc9457SAndroid Build Coastguard Worker assert(value->datatype == xnn_datatype_fp32);
823*4bdc9457SAndroid Build Coastguard Worker if (xnn_value_is_external(value)) {
824*4bdc9457SAndroid Build Coastguard Worker struct xnn_value* fp16_value = xnn_subgraph_new_internal_value(subgraph);
825*4bdc9457SAndroid Build Coastguard Worker
826*4bdc9457SAndroid Build Coastguard Worker // Recompute value due to potential reallocation in xnn_subgraph_new_internal_value
827*4bdc9457SAndroid Build Coastguard Worker value = &subgraph->values[n];
828*4bdc9457SAndroid Build Coastguard Worker xnn_value_copy(fp16_value, value);
829*4bdc9457SAndroid Build Coastguard Worker fp16_value->datatype = xnn_datatype_fp16;
830*4bdc9457SAndroid Build Coastguard Worker
831*4bdc9457SAndroid Build Coastguard Worker fp16_value->producer = value->producer;
832*4bdc9457SAndroid Build Coastguard Worker fp16_value->num_consumers = value->num_consumers;
833*4bdc9457SAndroid Build Coastguard Worker fp16_value->first_consumer = value->first_consumer;
834*4bdc9457SAndroid Build Coastguard Worker value->producer = XNN_INVALID_NODE_ID;
835*4bdc9457SAndroid Build Coastguard Worker value->num_consumers = 0;
836*4bdc9457SAndroid Build Coastguard Worker value->first_consumer = XNN_INVALID_NODE_ID;
837*4bdc9457SAndroid Build Coastguard Worker
838*4bdc9457SAndroid Build Coastguard Worker // Clear external input/output flags
839*4bdc9457SAndroid Build Coastguard Worker fp16_value->flags = 0;
840*4bdc9457SAndroid Build Coastguard Worker xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n);
841*4bdc9457SAndroid Build Coastguard Worker
842*4bdc9457SAndroid Build Coastguard Worker value->fp16_id = fp16_value->id;
843*4bdc9457SAndroid Build Coastguard Worker fp16_value->fp32_id = n;
844*4bdc9457SAndroid Build Coastguard Worker } else {
845*4bdc9457SAndroid Build Coastguard Worker xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n);
846*4bdc9457SAndroid Build Coastguard Worker value->datatype = xnn_datatype_fp16;
847*4bdc9457SAndroid Build Coastguard Worker }
848*4bdc9457SAndroid Build Coastguard Worker }
849*4bdc9457SAndroid Build Coastguard Worker }
850*4bdc9457SAndroid Build Coastguard Worker for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
851*4bdc9457SAndroid Build Coastguard Worker struct xnn_node* node = &subgraph->nodes[n];
852*4bdc9457SAndroid Build Coastguard Worker if (node->type == xnn_node_type_invalid) {
853*4bdc9457SAndroid Build Coastguard Worker // Node was fused away, skip.
854*4bdc9457SAndroid Build Coastguard Worker continue;
855*4bdc9457SAndroid Build Coastguard Worker }
856*4bdc9457SAndroid Build Coastguard Worker
857*4bdc9457SAndroid Build Coastguard Worker assert(node->compute_type == xnn_compute_type_fp32);
858*4bdc9457SAndroid Build Coastguard Worker node->compute_type = xnn_compute_type_fp16;
859*4bdc9457SAndroid Build Coastguard Worker if (node->type == xnn_node_type_static_constant_pad) {
860*4bdc9457SAndroid Build Coastguard Worker node->params.static_pad.padding_value =
861*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(uint32_as_float(node->params.static_pad.padding_value));
862*4bdc9457SAndroid Build Coastguard Worker }
863*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_inputs; i++) {
864*4bdc9457SAndroid Build Coastguard Worker const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id;
865*4bdc9457SAndroid Build Coastguard Worker if (fp16_id != XNN_INVALID_VALUE_ID) {
866*4bdc9457SAndroid Build Coastguard Worker assert(subgraph->values[fp16_id].fp32_id == node->inputs[i]);
867*4bdc9457SAndroid Build Coastguard Worker node->inputs[i] = fp16_id;
868*4bdc9457SAndroid Build Coastguard Worker }
869*4bdc9457SAndroid Build Coastguard Worker }
870*4bdc9457SAndroid Build Coastguard Worker for (uint32_t o = 0; o < node->num_outputs; o++) {
871*4bdc9457SAndroid Build Coastguard Worker const uint32_t fp16_id = subgraph->values[node->outputs[o]].fp16_id;
872*4bdc9457SAndroid Build Coastguard Worker if (fp16_id != XNN_INVALID_VALUE_ID) {
873*4bdc9457SAndroid Build Coastguard Worker assert(subgraph->values[fp16_id].fp32_id == node->outputs[o]);
874*4bdc9457SAndroid Build Coastguard Worker node->outputs[o] = fp16_id;
875*4bdc9457SAndroid Build Coastguard Worker }
876*4bdc9457SAndroid Build Coastguard Worker }
877*4bdc9457SAndroid Build Coastguard Worker }
878*4bdc9457SAndroid Build Coastguard Worker
879*4bdc9457SAndroid Build Coastguard Worker // Count the number of external inputs and outputs which require Convert nodes
880*4bdc9457SAndroid Build Coastguard Worker uint32_t num_external_inputs = 0;
881*4bdc9457SAndroid Build Coastguard Worker uint32_t num_external_outputs = 0;
882*4bdc9457SAndroid Build Coastguard Worker for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
883*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node* node = &subgraph->nodes[n];
884*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_inputs; i++) {
885*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* value = &subgraph->values[node->inputs[i]];
886*4bdc9457SAndroid Build Coastguard Worker if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n) {
887*4bdc9457SAndroid Build Coastguard Worker assert(value->data == NULL);
888*4bdc9457SAndroid Build Coastguard Worker assert(value->datatype == xnn_datatype_fp16);
889*4bdc9457SAndroid Build Coastguard Worker assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
890*4bdc9457SAndroid Build Coastguard Worker // This value isn't always an external input, it could be an external output of the current subgraph (due to
891*4bdc9457SAndroid Build Coastguard Worker // partition), and be simultaneously consumed by the current node.
892*4bdc9457SAndroid Build Coastguard Worker if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
893*4bdc9457SAndroid Build Coastguard Worker num_external_inputs += 1;
894*4bdc9457SAndroid Build Coastguard Worker }
895*4bdc9457SAndroid Build Coastguard Worker }
896*4bdc9457SAndroid Build Coastguard Worker }
897*4bdc9457SAndroid Build Coastguard Worker for (uint32_t o = 0; o < node->num_outputs; o++) {
898*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* value = &subgraph->values[node->outputs[o]];
899*4bdc9457SAndroid Build Coastguard Worker if (value->fp32_id != XNN_INVALID_VALUE_ID) {
900*4bdc9457SAndroid Build Coastguard Worker assert(value->datatype == xnn_datatype_fp16);
901*4bdc9457SAndroid Build Coastguard Worker assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
902*4bdc9457SAndroid Build Coastguard Worker assert(xnn_value_is_external_output(&subgraph->values[value->fp32_id]));
903*4bdc9457SAndroid Build Coastguard Worker num_external_outputs += 1;
904*4bdc9457SAndroid Build Coastguard Worker }
905*4bdc9457SAndroid Build Coastguard Worker }
906*4bdc9457SAndroid Build Coastguard Worker }
907*4bdc9457SAndroid Build Coastguard Worker xnn_log_debug("Discovered %"PRIu32" external inputs and %"PRIu32" external outputs",
908*4bdc9457SAndroid Build Coastguard Worker num_external_inputs, num_external_outputs);
909*4bdc9457SAndroid Build Coastguard Worker
910*4bdc9457SAndroid Build Coastguard Worker const uint32_t num_original_nodes = subgraph->num_nodes;
911*4bdc9457SAndroid Build Coastguard Worker xnn_subgraph_add_nodes(subgraph, num_external_inputs + num_external_outputs);
912*4bdc9457SAndroid Build Coastguard Worker struct xnn_node* output_node = subgraph->nodes + subgraph->num_nodes - 1;
913*4bdc9457SAndroid Build Coastguard Worker for (uint32_t n = num_original_nodes; n != 0; n--) {
914*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node* node = &subgraph->nodes[n - 1];
915*4bdc9457SAndroid Build Coastguard Worker // Insert Convert nodes for outputs
916*4bdc9457SAndroid Build Coastguard Worker for (uint32_t o = 0; o < node->num_outputs; o++) {
917*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* value = &subgraph->values[node->outputs[o]];
918*4bdc9457SAndroid Build Coastguard Worker if (value->fp32_id != XNN_INVALID_VALUE_ID) {
919*4bdc9457SAndroid Build Coastguard Worker xnn_log_debug("Inserted FP16->FP32 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32,
920*4bdc9457SAndroid Build Coastguard Worker value->id, value->fp32_id);
921*4bdc9457SAndroid Build Coastguard Worker const uint32_t output_node_id = output_node->id;
922*4bdc9457SAndroid Build Coastguard Worker assert(output_node >= subgraph->nodes);
923*4bdc9457SAndroid Build Coastguard Worker xnn_node_clear(output_node);
924*4bdc9457SAndroid Build Coastguard Worker output_node->id = output_node_id;
925*4bdc9457SAndroid Build Coastguard Worker xnn_init_convert_node(output_node, xnn_compute_type_fp16_to_fp32, value->id, value->fp32_id, 0 /* flags */);
926*4bdc9457SAndroid Build Coastguard Worker output_node -= 1;
927*4bdc9457SAndroid Build Coastguard Worker }
928*4bdc9457SAndroid Build Coastguard Worker }
929*4bdc9457SAndroid Build Coastguard Worker // Move the Node to the new location
930*4bdc9457SAndroid Build Coastguard Worker if (output_node != node) {
931*4bdc9457SAndroid Build Coastguard Worker const uint32_t output_node_id = output_node->id;
932*4bdc9457SAndroid Build Coastguard Worker assert(output_node >= subgraph->nodes);
933*4bdc9457SAndroid Build Coastguard Worker memcpy(output_node, node, sizeof(struct xnn_node));
934*4bdc9457SAndroid Build Coastguard Worker output_node->id = output_node_id;
935*4bdc9457SAndroid Build Coastguard Worker output_node -= 1;
936*4bdc9457SAndroid Build Coastguard Worker }
937*4bdc9457SAndroid Build Coastguard Worker // Insert Convert nodes for inputs
938*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_inputs; i++) {
939*4bdc9457SAndroid Build Coastguard Worker const struct xnn_value* value = &subgraph->values[node->inputs[i]];
940*4bdc9457SAndroid Build Coastguard Worker if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n - 1) {
941*4bdc9457SAndroid Build Coastguard Worker // Only insert convert nodes if the value actually is an external input. This value could be an external output,
942*4bdc9457SAndroid Build Coastguard Worker // if that's the case, we have already inserted a convert node in loop above for outputs.
943*4bdc9457SAndroid Build Coastguard Worker if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
944*4bdc9457SAndroid Build Coastguard Worker xnn_log_debug("Inserted FP32->FP16 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32,
945*4bdc9457SAndroid Build Coastguard Worker value->fp32_id, value->id);
946*4bdc9457SAndroid Build Coastguard Worker const uint32_t output_node_id = output_node->id;
947*4bdc9457SAndroid Build Coastguard Worker assert(output_node >= subgraph->nodes);
948*4bdc9457SAndroid Build Coastguard Worker xnn_node_clear(output_node);
949*4bdc9457SAndroid Build Coastguard Worker output_node->id = output_node_id;
950*4bdc9457SAndroid Build Coastguard Worker xnn_init_convert_node(output_node, xnn_compute_type_fp32_to_fp16, value->fp32_id, value->id, 0 /* flags */);
951*4bdc9457SAndroid Build Coastguard Worker output_node -= 1;
952*4bdc9457SAndroid Build Coastguard Worker }
953*4bdc9457SAndroid Build Coastguard Worker }
954*4bdc9457SAndroid Build Coastguard Worker }
955*4bdc9457SAndroid Build Coastguard Worker }
956*4bdc9457SAndroid Build Coastguard Worker
957*4bdc9457SAndroid Build Coastguard Worker return true;
958*4bdc9457SAndroid Build Coastguard Worker }
959*4bdc9457SAndroid Build Coastguard Worker
// Fuses pairs of adjacent Nodes connected by a single-consumer Value into one Node.
//
// For every Value that has exactly one consumer and a known producer, two rewrites are tried:
//   1. Clamp fusion: if the consumer is a Clamp Node and the producer is one of Add2,
//      Average Pooling 2D, Clamp, Convolution 2D, Deconvolution 2D, Depthwise Convolution 2D,
//      Divide, Fully Connected, Multiply2, Max Pooling 2D or Subtract, the producer takes over the
//      Clamp's output Value. Its [output_min, output_max] range becomes the intersection of both
//      Nodes' ranges.
//   2. Constant Pad fusion: this applies when all of the following hold:
//      - the producer is a Static Constant Pad Node on a 4D Value;
//      - it pads only dimensions 1 and 2;
//      - the pad value is zero (all-zero bits for FP32, the zero point for QINT8/QUINT8);
//      - the consumer is a [Depthwise] Convolution 2D Node without
//        XNN_FLAG_TENSORFLOW_SAME_PADDING.
//      The pad amounts are then added to the convolution's implicit padding, and the convolution
//      reads the Pad's input directly.
//
// The fused-away Node and the intermediate Value are reset with xnn_node_clear / xnn_value_clear.
// The Node and Value arrays are not compacted, so the ids of all remaining Nodes and Values stay
// valid.
//
// Expects each Value's producer / num_consumers / first_consumer fields to be current.
// xnn_subgraph_optimize runs xnn_subgraph_analyze_consumers_and_producers before calling this.
//
// Returns xnn_status_success unconditionally.
enum xnn_status xnn_subgraph_fusion(
    xnn_subgraph_t subgraph)
{
  // Fuse Nodes where possible
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    // Only a Value read by exactly one Node can be removed without affecting other readers.
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        // No producing Node: nothing upstream to fuse with.
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%"PRIu32" into upstream Node #%"PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            // The producer now writes the Clamp's output Value directly.
            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

            // Applying two clamps in sequence equals one clamp to the intersection of their ranges.
            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
        // Only padding of the two middle (spatial) dimensions of a 4D tensor can be expressed as
        // implicit convolution padding.
        const bool is_spatial_2d_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0;
        // padding_value holds the raw pad bits. Convolutions pad implicitly with zero for FP32 and
        // with the zero point for quantized types, so only those pad values are fusible.
        const enum xnn_datatype padding_datatype = subgraph->values[producer->outputs[0]].datatype;
        const uint32_t padding_value = producer->params.static_pad.padding_value;
        const bool is_zero_padding =
          (padding_datatype == xnn_datatype_fp32 && padding_value == 0) ||
          ((padding_datatype == xnn_datatype_qint8 || padding_datatype == xnn_datatype_quint8) &&
          padding_value == (uint32_t) (uint8_t) subgraph->values[producer->outputs[0]].quantization.zero_point);
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              // The Pad Node is the producer and the Convolution is the consumer.
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.convolution_2d.input_padding_top += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left += producer->params.static_pad.pre_paddings[2];

              // Bypass the Pad Node: the Convolution reads the Pad's input directly.
              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              // The Pad Node is the producer and the Depthwise Convolution is the consumer.
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              // Bypass the Pad Node: the Depthwise Convolution reads the Pad's input directly.
              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }
    }
  }

  return xnn_status_success;
}
1095*4bdc9457SAndroid Build Coastguard Worker
// Runs the subgraph-level optimization passes selected by `flags`.
//
// Passes, in order:
//   1. Recompute producer/consumer information for all Values.
//   2. Clear Values that are neither external inputs nor consumed by any Node.
//   3. Node fusion (xnn_subgraph_fusion), unless XNN_FLAG_NO_OPERATOR_FUSION is set.
//   4. NCHW rewrite, when built with XNN_ENABLE_SPARSE, when XNN_FLAG_HINT_SPARSE_INFERENCE is
//      set, and when the CHW-optimized operators are available.
//   5. FP16 rewrite, when FP16 is forced (XNN_FLAG_FORCE_FP16_INFERENCE), or when it is hinted
//      (XNN_FLAG_HINT_FP16_INFERENCE) and native FP16 operators are available. A hinted rewrite
//      that fails leaves the subgraph in FP32.
//
// Returns:
//   xnn_status_success on success.
//   xnn_status_unsupported_hardware if FP16 inference is forced but FP16 operators are
//     unavailable. This is checked before any pass mutates the subgraph.
//   xnn_status_unsupported_parameter if forced FP16 inference cannot be applied to the subgraph.
//   Any non-success status reported by xnn_subgraph_fusion.
enum xnn_status xnn_subgraph_optimize(
    xnn_subgraph_t subgraph,
    uint32_t flags)
{
  // This check depends only on flags and hardware, so fail before any pass rewrites the caller's subgraph.
  if ((flags & XNN_FLAG_FORCE_FP16_INFERENCE) && !(xnn_params.init_flags & XNN_INIT_FLAG_F16)) {
    xnn_log_error("failed to force FP16 inference: hardware supports neither native nor emulated FP16 operators");
    return xnn_status_unsupported_hardware;
  }

  xnn_subgraph_analyze_consumers_and_producers(subgraph);

  // Remove unreferenced values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if (!xnn_value_is_external_input(value) && value->num_consumers == 0) {
      xnn_value_clear(value);
    }
  }

  if (!(flags & XNN_FLAG_NO_OPERATOR_FUSION)) {
    const enum xnn_status fusion_status = xnn_subgraph_fusion(subgraph);
    if (fusion_status != xnn_status_success) {
      return fusion_status;
    }
  }

#if XNN_ENABLE_SPARSE
  if ((flags & XNN_FLAG_HINT_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
    xnn_subgraph_rewrite_for_nchw(subgraph);
  }
#endif

#ifndef XNN_NO_F16_OPERATORS
  // A hint only switches to FP16 when native FP16 is available; forcing also accepts emulated FP16.
  const bool try_native_fp16 =
    (flags & XNN_FLAG_HINT_FP16_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_F16_NATIVE);
  const bool force_fp16 = (flags & XNN_FLAG_FORCE_FP16_INFERENCE);
  if (try_native_fp16 || force_fp16) {
    const bool fp16_rewrite_succeeded = xnn_subgraph_rewrite_for_fp16(subgraph);
    // A failed hinted rewrite silently keeps FP32; only a forced rewrite turns failure into an error.
    if (force_fp16 && !fp16_rewrite_succeeded) {
      xnn_log_error("failed to force FP16 inference: subgraph is incompatible with FP16 operators");
      return xnn_status_unsupported_parameter;
    }
  }
#endif  // XNN_NO_F16_OPERATORS

  return xnn_status_success;
}
1144*4bdc9457SAndroid Build Coastguard Worker
// Destroys a subgraph, releasing its Node array, its Value array and the subgraph object itself.
// Each buffer is zeroed before it is released.
//
// Passing NULL is allowed and does nothing. Always returns xnn_status_success.
enum xnn_status xnn_delete_subgraph(
    xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    // memset() requires a valid pointer even when the byte count is zero (C11 7.24.1p2), so only
    // scrub arrays that are actually allocated. The Node/Value arrays may still be NULL here:
    // presumably when nothing was ever added, or on a partially-failed creation path (TODO confirm).
    if (subgraph->nodes != NULL) {
      memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
    }
    xnn_release_memory(subgraph->nodes);

    if (subgraph->values != NULL) {
      memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
    }
    xnn_release_memory(subgraph->values);

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}
1160