// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>


#ifndef XNN_ENABLE_SPARSE
  #error "XNN_ENABLE_SPARSE not defined"
#endif

enum xnn_status xnn_create_subgraph(
    uint32_t external_value_ids,
    uint32_t flags,
    xnn_subgraph_t* subgraph_out)
{
  struct xnn_subgraph* subgraph = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create subgraph: XNNPACK is not initialized");
    goto error;
  }

  status = xnn_status_out_of_memory;

  subgraph = xnn_allocate_zero_memory(sizeof(struct xnn_subgraph));
  if (subgraph == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph descriptor", sizeof(struct xnn_subgraph));
    goto error;
  }

  subgraph->external_value_ids = external_value_ids;

  subgraph->values = xnn_allocate_zero_memory(external_value_ids * sizeof(struct xnn_value));
  if (subgraph->values == NULL) {
    xnn_log_error("failed to allocate %zu bytes for subgraph values",
      (size_t) external_value_ids * sizeof(struct xnn_value));
    goto error;
  }
  for (size_t i = 0; i < external_value_ids; i++) {
    subgraph->values[i].id = i;
  }
  subgraph->num_values = external_value_ids;
  subgraph->num_reserved_values = external_value_ids;

  *subgraph_out = subgraph;
  return xnn_status_success;

error:
  xnn_delete_subgraph(subgraph);
  return status;
}


struct xnn_value* xnn_subgraph_new_internal_value(xnn_subgraph_t subgraph)
{
  struct xnn_value* values = subgraph->values;
  const size_t size = subgraph->num_values;
  const size_t capacity = subgraph->num_reserved_values;
  if (capacity < size + 1) {
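    // Grow the array geometrically (2x), but add at least 64 and at most 512 extra Values per reallocation.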
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    values = xnn_reallocate_memory(values, new_capacity * sizeof(struct xnn_value));
    if (values == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph values",
        new_capacity * sizeof(struct xnn_value));
      return values;
    }

    memset(values + size, 0, (new_capacity - size) * sizeof(struct xnn_value));
    subgraph->num_reserved_values = new_capacity;
    subgraph->values = values;
  }
  subgraph->num_values = size + 1;
  struct xnn_value* new_value = values + size;
  new_value->id = size;
  return new_value;
}

void xnn_node_clear(struct xnn_node* node) {
  assert(node != NULL);
  memset(node, 0, sizeof(struct xnn_node));
}

void xnn_value_clear(struct xnn_value* value) {
  assert(value != NULL);
  memset(value, 0, sizeof(struct xnn_value));
}

void xnn_value_copy(
  struct xnn_value* dst_value,
  const struct xnn_value* src_value)
{
  // Note: Value ID stays unchanged

  dst_value->type = src_value->type;
  dst_value->datatype = src_value->datatype;
  dst_value->quantization = src_value->quantization;
  dst_value->shape = src_value->shape;
  dst_value->flags = src_value->flags;
  dst_value->data = src_value->data;
  dst_value->producer = src_value->producer;
  dst_value->first_consumer = src_value->first_consumer;
}

struct xnn_node* xnn_subgraph_new_node(xnn_subgraph_t subgraph)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + 1) {
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + 64);
    assert(new_capacity >= size + 1);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return nodes;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + 1;
  struct xnn_node* new_node = nodes + size;
  new_node->id = size;
  return new_node;
}

void xnn_subgraph_add_nodes(xnn_subgraph_t subgraph, size_t num_nodes)
{
  struct xnn_node* nodes = subgraph->nodes;
  const size_t size = subgraph->num_nodes;
  const size_t capacity = subgraph->num_reserved_nodes;

  if (capacity < size + num_nodes) {
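    // Grow the array geometrically (2x) up to 512 extra Nodes, but never by less than 64 Nodes or than the number
    // of Nodes being added.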
    const size_t new_capacity = max(min(capacity * 2, capacity + 512), capacity + max(num_nodes, 64));
    assert(new_capacity >= size + num_nodes);
    nodes = xnn_reallocate_memory(nodes, new_capacity * sizeof(struct xnn_node));
    if (nodes == NULL) {
      xnn_log_error("failed to allocate %zu bytes for subgraph nodes",
        new_capacity * sizeof(struct xnn_node));
      return;
    }

    memset(nodes + size, 0, (new_capacity - size) * sizeof(struct xnn_node));
    subgraph->num_reserved_nodes = new_capacity;
    subgraph->nodes = nodes;
  }
  subgraph->num_nodes = size + num_nodes;
  struct xnn_node* new_nodes = nodes + size;
  for (size_t i = 0; i < num_nodes; i++) {
    new_nodes[i].id = size + i;
  }
}

void xnn_subgraph_analyze_consumers_and_producers(xnn_subgraph_t subgraph)
{
  // Initialize producer/consumer fields to safe defaults.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    value->producer = XNN_INVALID_NODE_ID;
    value->first_consumer = XNN_INVALID_NODE_ID;
    value->num_consumers = 0;
  }

  // Analyze Nodes' inputs and outputs and update Values' producer/consumer fields.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t input_id = node->inputs[i];
      assert(input_id < subgraph->num_values);

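      // Record the first Node that consumes this Value; subsequent consumers only increment the count.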
      if (subgraph->values[input_id].num_consumers++ == 0) {
        assert(subgraph->values[input_id].first_consumer == XNN_INVALID_NODE_ID);
        subgraph->values[input_id].first_consumer = n;
      }
    }

    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t output_id = node->outputs[o];
      assert(output_id < subgraph->num_values);

      assert(subgraph->values[output_id].producer == XNN_INVALID_NODE_ID);
      subgraph->values[output_id].producer = n;
    }
  }

  // Count an extra consumer for Values which are external outputs:
  // their consumers outside of the subgraph also reference them.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (xnn_value_is_external_output(value)) {
      value->num_consumers += 1;
    }
  }
}

#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW      1
#define XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW 2
#define XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC 4
#define XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER 8

uint32_t xnn_check_nchw_compatibility(xnn_subgraph_t subgraph, struct xnn_node* node) {
  if (node->compute_type != xnn_compute_type_fp32) {
    if (node->type != xnn_node_type_invalid) {
      xnn_log_info(
          "Node %s compute type %d is incompatible with sparse inference",
          xnn_node_type_to_string(node->type), node->compute_type);
    }
    return 0;
  }

  switch (node->type) {
    case xnn_node_type_convolution_2d:
      // Supported cases:
      // - 1x1 convolution (no stride, no dilation, no padding, no groups)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side, no groups, 3 input channels)
      if (node->params.convolution_2d.groups != 1) {
        xnn_log_info("Node %s groups (%" PRIu32 ") "
                     "is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type),
                     node->params.convolution_2d.groups);
        return 0;
      }
      if ((node->params.convolution_2d.dilation_height | node->params.convolution_2d.dilation_width) != 1) {
        xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
                     "is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type),
                     node->params.convolution_2d.dilation_height,
                     node->params.convolution_2d.dilation_width);
        return 0;
      }
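      // 1x1 kernel: additionally require zero padding and unit stride.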
      if ((node->params.convolution_2d.kernel_height | node->params.convolution_2d.kernel_width) == 1) {
        if ((node->params.convolution_2d.input_padding_top | node->params.convolution_2d.input_padding_right |
             node->params.convolution_2d.input_padding_bottom | node->params.convolution_2d.input_padding_left) != 0) {
          xnn_log_info("Node %s (1x1 kernel) padding (top=%" PRIu32 ", right=%" PRIu32", bottom=%" PRIu32 ", left=%" PRIu32") "
                       "is incompatible with sparse inference",
                       xnn_node_type_to_string(node->type),
                       node->params.convolution_2d.input_padding_top,
                       node->params.convolution_2d.input_padding_right,
                       node->params.convolution_2d.input_padding_bottom,
                       node->params.convolution_2d.input_padding_left);
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 1) {
          xnn_log_info("Node %s (1x1 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
                       "is incompatible with sparse inference",
                       xnn_node_type_to_string(node->type),
                       node->params.convolution_2d.subsampling_height,
                       node->params.convolution_2d.subsampling_width);
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else if (node->params.convolution_2d.kernel_height == 3 && node->params.convolution_2d.kernel_width == 3) {
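        // 3x3 kernel: require padding 1 on each side, stride 2, and exactly 3 input channels.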
        if (node->params.convolution_2d.input_padding_top != 1 || node->params.convolution_2d.input_padding_right != 1 ||
            node->params.convolution_2d.input_padding_bottom != 1 || node->params.convolution_2d.input_padding_left != 1) {
          xnn_log_info("Node %s (3x3 kernel) padding (top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
                       "is incompatible with sparse inference",
                       xnn_node_type_to_string(node->type),
                       node->params.convolution_2d.input_padding_top,
                       node->params.convolution_2d.input_padding_right,
                       node->params.convolution_2d.input_padding_bottom,
                       node->params.convolution_2d.input_padding_left);
          return 0;
        }
        if ((node->params.convolution_2d.subsampling_height | node->params.convolution_2d.subsampling_width) != 2) {
          xnn_log_info("Node %s (3x3 kernel) subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
                       "is incompatible with sparse inference",
                       xnn_node_type_to_string(node->type),
                       node->params.convolution_2d.subsampling_height,
                       node->params.convolution_2d.subsampling_width);
          return 0;
        }
        if (node->params.convolution_2d.group_input_channels != 3) {
          xnn_log_info("Node %s (3x3 kernel) input channels (%zu) "
                       "is incompatible with sparse inference",
                       xnn_node_type_to_string(node->type),
                       node->params.convolution_2d.group_input_channels);
          return 0;
        }
        return XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW;
      }
      return 0;
    case xnn_node_type_depthwise_convolution_2d:
      // Supported cases:
      // - 3x3 stride-1 convolution (no dilation, padding 1 on each side)
      // - 3x3 stride-2 convolution (no dilation, padding 1 on each side)
      // - 5x5 stride-1 convolution (no dilation, padding 2 on each side)
      // - 5x5 stride-2 convolution (no dilation, padding 2 on each side)
      if ((node->params.depthwise_convolution_2d.dilation_height | node->params.depthwise_convolution_2d.dilation_width) != 1) {
        xnn_log_info("Node %s dilation (height=%" PRIu32 ", width=%" PRIu32 ") "
                     "is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type),
                     node->params.depthwise_convolution_2d.dilation_height,
                     node->params.depthwise_convolution_2d.dilation_width);
        return 0;
      }
      if (node->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
        xnn_log_info("Node %s flags (%" PRIu32 ") has padding incompatible with sparse inference",
                     xnn_node_type_to_string(node->type),
                     node->flags);
        return 0;
      }
      if (node->params.depthwise_convolution_2d.depth_multiplier != 1) {
        xnn_log_info("Node %s depth_multiplier (%" PRIu32 ") is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type),
                     node->params.depthwise_convolution_2d.depth_multiplier);
        return 0;
      }
      if (node->params.depthwise_convolution_2d.subsampling_height != node->params.depthwise_convolution_2d.subsampling_width) {
        xnn_log_info("Node %s subsampling (height=%" PRIu32 ", width=%" PRIu32 ") "
                     "is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type),
                     node->params.depthwise_convolution_2d.subsampling_height,
                     node->params.depthwise_convolution_2d.subsampling_width);
        return 0;
      }
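      // Only stride-1 and stride-2 depthwise convolutions are supported for sparse inference.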
      switch (node->params.depthwise_convolution_2d.subsampling_height) {
        case 1:
        case 2:
          break;
        default:
          xnn_log_info("Node %s subsampling_height (%" PRIu32 ") "
                       "is incompatible with sparse inference",
                       xnn_node_type_to_string(node->type),
                       node->params.depthwise_convolution_2d.subsampling_height);
          return 0;
      }
      if (node->params.depthwise_convolution_2d.kernel_height != node->params.depthwise_convolution_2d.kernel_width) {
        xnn_log_info("Node %s kernel (height=%" PRIu32 ", width=%" PRIu32 ") "
                     "is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type),
                     node->params.depthwise_convolution_2d.kernel_height,
                     node->params.depthwise_convolution_2d.kernel_width);
        return 0;
      }
      switch (node->params.depthwise_convolution_2d.kernel_height) {
        case 3:
          if (node->params.depthwise_convolution_2d.input_padding_top == 1 &&
              node->params.depthwise_convolution_2d.input_padding_right == 1 &&
              node->params.depthwise_convolution_2d.input_padding_bottom == 1 &&
              node->params.depthwise_convolution_2d.input_padding_left == 1) {
            return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
          } else {
            xnn_log_info("Node %s (3x3 kernel) padding "
                         "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
                         "is incompatible with sparse inference",
                         xnn_node_type_to_string(node->type),
                         node->params.depthwise_convolution_2d.input_padding_top,
                         node->params.depthwise_convolution_2d.input_padding_right,
                         node->params.depthwise_convolution_2d.input_padding_bottom,
                         node->params.depthwise_convolution_2d.input_padding_left);
            return 0;
          }
        case 5:
          if (node->params.depthwise_convolution_2d.input_padding_top == 2 &&
              node->params.depthwise_convolution_2d.input_padding_right == 2 &&
              node->params.depthwise_convolution_2d.input_padding_bottom == 2 &&
              node->params.depthwise_convolution_2d.input_padding_left == 2) {
            return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
          } else {
            xnn_log_info("Node %s (5x5 kernel) padding "
                         "(top=%" PRIu32 ", right=%" PRIu32 ", bottom=%" PRIu32 ", left=%" PRIu32 ") "
                         "is incompatible with sparse inference",
                         xnn_node_type_to_string(node->type),
                         node->params.depthwise_convolution_2d.input_padding_top,
                         node->params.depthwise_convolution_2d.input_padding_right,
                         node->params.depthwise_convolution_2d.input_padding_bottom,
                         node->params.depthwise_convolution_2d.input_padding_left);
            return 0;
          }
        default:
          return 0;
      }
    case xnn_node_type_depth_to_space:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_global_average_pooling_2d:
      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
    case xnn_node_type_add2:
    case xnn_node_type_multiply2:
      assert(node->num_inputs == 2);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims != 4 ||
          subgraph->values[node->inputs[1]].shape.num_dims != 4)
      {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type));
        return 0;
      }

      if (subgraph->values[node->inputs[0]].data != NULL) {
        // Check that the first input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[0]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[0]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      if (subgraph->values[node->inputs[1]].data != NULL) {
        // Check that the second input is representable as either a scalar, or a vector
        size_t num_nonunit_dims = 0;
        for (uint32_t i = 0; i < subgraph->values[node->inputs[1]].shape.num_dims; i++) {
          if (subgraph->values[node->inputs[1]].shape.dim[i] != 1) {
            num_nonunit_dims += 1;
          }
        }
        if (num_nonunit_dims > 1) {
          return 0;
        }
      }

      return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
    case xnn_node_type_static_resize_bilinear_2d:
      if (subgraph->values[node->inputs[0]].shape.dim[1] > 1 &&
          subgraph->values[node->inputs[0]].shape.dim[2] > 1) {
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type));
        return 0;
      }
    case xnn_node_type_abs:
    case xnn_node_type_bankers_rounding:
    case xnn_node_type_ceiling:
    case xnn_node_type_clamp:
    case xnn_node_type_elu:
    case xnn_node_type_floor:
    case xnn_node_type_hardswish:
    case xnn_node_type_leaky_relu:
    case xnn_node_type_negate:
    case xnn_node_type_sigmoid:
    case xnn_node_type_square:
      assert(node->num_inputs == 1);
      assert(node->num_outputs == 1);
      if (subgraph->values[node->inputs[0]].shape.num_dims == 4) {
        return XNN_LAYOUT_FLAG_COMPATIBLE_NCHW;
      } else {
        xnn_log_info("Node %s inputs shape is incompatible with sparse inference",
                     xnn_node_type_to_string(node->type));
        return 0;
      }
    default:
      return 0;
  }
}

void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph)
{
  // Convert parts of the subgraph to NCHW for sparse inference
  // Step 1: detect NCHW-compatible Nodes
  // Step 2: detect NCHW-compatible clusters (run connected components graph algorithm)
  // Step 3: check that all NCHW-compatible Values are consumed only by NCHW-compatible Nodes
  // Step 4: switch Values' layout to NCHW
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->layout_flags = xnn_check_nchw_compatibility(subgraph, node);
    xnn_log_debug("Node #%" PRIu32 ": %s (NCHW: %s, NHWC->NCHW: %s, NCHW->NHWC: %s)",
      n, xnn_node_type_to_string(node->type),
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW ? "yes" : "no",
      node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC ? "yes" : "no");
  }

  // Run the Shiloach-Vishkin connected-components algorithm: find all
  // XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC Nodes and make them cluster leaders
  // of their producer Nodes.
  bool update = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    node->cluster_leader = n;
    if (node->layout_flags & XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC) {
      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if (xnn_value_is_external(value)) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // If no NCHW2NHWC-compatible Nodes were found, the graph rewrite cannot
  // happen, so bail out early.
  if (!update) {
    return;
  }
  // Propagate cluster leaders through the graph until no Node's cluster
  // assignment changes anymore.
  while (update) {
    update = false;
    for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
      struct xnn_node* node = &subgraph->nodes[n];
      if (node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) {
        continue;
      }

      if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC)) == 0) {
        continue;
      }

      for (uint32_t i = 0; i < node->num_inputs; i++) {
        const struct xnn_value* value = &subgraph->values[node->inputs[i]];
        if (value->data != NULL) {
          // Static data, skip this input value. Compatibility of this static input with NCHW layout was validated
          // during the initial NCHW compatibility check for the Node.
          continue;
        }
        if (xnn_value_is_external(value)) {
          // External value, invalid cluster
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
          continue;
        }
        const uint32_t producer_id = value->producer;
        assert(producer_id != XNN_INVALID_NODE_ID);
        assert(producer_id < n);
        struct xnn_node* producer_node = &subgraph->nodes[producer_id];
        if ((producer_node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NHWC2NCHW | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) != 0 &&
            (producer_node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) == 0)
        {
          producer_node->layout_flags &= ~XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC;
          if (producer_node->cluster_leader != node->cluster_leader) {
            producer_node->cluster_leader = node->cluster_leader = math_max_u32(producer_node->cluster_leader, node->cluster_leader);
            update = true;
          }
        } else {
          node->layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
        }
      }
    }
  }
  // Propagate XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER flags up to the cluster leaders
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    subgraph->nodes[node->cluster_leader].layout_flags |= node->layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
  }
  // Check that no Value consumed by an NCHW-compatible cluster has an NCHW-incompatible consumer
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      value->num_nchw_compatible_consumers += 1;
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      if (value->num_nchw_compatible_consumers != value->num_consumers) {
        subgraph->nodes[node->cluster_leader].layout_flags |= XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER;
      }
    }
  }
  // Evaluate if it is profitable to run the model as sparse:
  // - Compute the number of parameters and zeroes in 1x1 Convolution weights
  // - Disable sparse rewriting for clusters without 1x1 Convolutions (num_params == 0)
  //   or with less than 2/3rd of zeroes in 1x1 Convolution filters
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if (node->type == xnn_node_type_convolution_2d &&
        max(node->params.convolution_2d.kernel_height, node->params.convolution_2d.kernel_width) == 1)
    {
      assert(node->num_inputs >= 2);

      const struct xnn_value* filter = &subgraph->values[node->inputs[1]];
      assert(filter->data != NULL);
      assert(filter->shape.num_dims == 4);

      const size_t num_params = filter->shape.dim[0] * filter->shape.dim[3];
      subgraph->nodes[node->cluster_leader].num_params += num_params;

      const float* data = (const float*) filter->data;
      size_t num_zeroes = 0;
      for (size_t i = 0; i < num_params; i++) {
        num_zeroes += (size_t) (data[i] == 0.0f);
      }
      xnn_log_debug("1x1 Convolution 2D Node #%" PRIu32 ": %zu / %zu sparsity", n, num_zeroes, num_params);
      subgraph->nodes[node->cluster_leader].num_zeroes += num_zeroes;
    }
  }
  bool use_nchw_layout = false;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if ((subgraph->nodes[node->cluster_leader].layout_flags & XNN_LAYOUT_FLAG_INCOMPATIBLE_CLUSTER) != 0) {
      continue;
    }

    if ((node->layout_flags & (XNN_LAYOUT_FLAG_COMPATIBLE_NCHW2NHWC | XNN_LAYOUT_FLAG_COMPATIBLE_NCHW)) == 0) {
      continue;
    }

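    // Skip the Node if its cluster's 1x1 Convolution weights are not sparse enough: more than 2/3 of them must be zero.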
    if (subgraph->nodes[node->cluster_leader].num_zeroes * 3 <= subgraph->nodes[node->cluster_leader].num_params * 2) {
      xnn_log_info("Node #%" PRIu32 ": sparse inference disabled: 1x1 Convolutions contain %zu / %zu zero weights",
        n, subgraph->nodes[node->cluster_leader].num_zeroes, subgraph->nodes[node->cluster_leader].num_params);
      continue;
    }

    for (uint32_t i = 0; i < node->num_inputs; i++) {
      struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->data != NULL) {
        // Static data, skip this input value because it doesn't have a producer Node.
        continue;
      }
      assert(!xnn_value_is_external(value));
      assert(value->num_nchw_compatible_consumers > 0);
      assert(value->num_nchw_compatible_consumers == value->num_consumers);
      if (value->layout != xnn_layout_type_nchw) {
        value->layout = xnn_layout_type_nchw;
        xnn_log_info("set Value #%"PRIu32" layout to NCHW", node->inputs[i]);
        use_nchw_layout = true;
      }
    }
  }
  if (use_nchw_layout) {
    xnn_log_info("XNNPACK has switched to sparse inference mode!");
  }
}

bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
{
  xnn_log_info("Analyzing subgraph for FP16 compatibility");

  // Convert tensors and operators in the subgraph to FP16
  // 1. Check that all operators in the subgraph are supported in FP16.
  // 2. Indicate values that must be converted to FP16.
  // 3. Replace FP32 Values with FP16 Values as Nodes' inputs/outputs.
  // 4. Insert FP32->FP16 Convert Nodes for external FP32 inputs and FP16->FP32 Convert Nodes for external outputs.

  // Check that all operators in the subgraph are supported in FP16, bail out on any unsupported one.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (node->type == xnn_node_type_invalid) {
      // Node was fused away, skip.
      continue;
    }

    if (node->compute_type != xnn_compute_type_fp32) {
      xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not FP32", n, xnn_node_type_to_string(node->type));
      return false;
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      if (subgraph->values[node->inputs[i]].layout == xnn_layout_type_nchw) {
        xnn_log_warning(
          "FP16 rewrite aborted: input #%" PRIu32 " (Value #%" PRIu32 ") of node #%" PRIu32 " (%s) has NCHW layout",
          i, node->inputs[i], n, xnn_node_type_to_string(node->type));
        return false;
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      if (subgraph->values[node->outputs[o]].layout == xnn_layout_type_nchw) {
        xnn_log_warning(
          "FP16 rewrite aborted: output #%" PRIu32 " (Value #%" PRIu32 ") of node #%" PRIu32 " (%s) has NCHW layout",
          o, node->outputs[o], n, xnn_node_type_to_string(node->type));
        return false;
      }
    }
    switch (node->type) {
      case xnn_node_type_abs:
      case xnn_node_type_add2:
      case xnn_node_type_divide:
      case xnn_node_type_maximum2:
      case xnn_node_type_minimum2:
      case xnn_node_type_multiply2:
      case xnn_node_type_concatenate2:
      case xnn_node_type_concatenate3:
      case xnn_node_type_concatenate4:
      case xnn_node_type_squared_difference:
      case xnn_node_type_subtract:
        for (uint32_t i = 0; i < node->num_inputs; i++) {
          if (subgraph->values[node->inputs[i]].data != NULL) {
            xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) has static input %" PRIu32,
              n, xnn_node_type_to_string(node->type), i);
            return false;
          }
        }
        break;
      case xnn_node_type_average_pooling_2d:
      case xnn_node_type_bankers_rounding:
      case xnn_node_type_ceiling:
      case xnn_node_type_clamp:
      case xnn_node_type_convolution_2d:
      case xnn_node_type_deconvolution_2d:
      case xnn_node_type_depthwise_convolution_2d:
      case xnn_node_type_depth_to_space:
      case xnn_node_type_elu:
      case xnn_node_type_even_split2:
      case xnn_node_type_even_split3:
      case xnn_node_type_even_split4:
      case xnn_node_type_floor:
      case xnn_node_type_fully_connected:
      case xnn_node_type_global_average_pooling_2d:
      case xnn_node_type_hardswish:
      case xnn_node_type_leaky_relu:
      case xnn_node_type_max_pooling_2d:
      case xnn_node_type_negate:
      case xnn_node_type_prelu:
      case xnn_node_type_sigmoid:
      case xnn_node_type_softmax:
      case xnn_node_type_static_constant_pad:
      case xnn_node_type_static_reshape:
      case xnn_node_type_static_resize_bilinear_2d:
      case xnn_node_type_static_transpose:
      case xnn_node_type_square:
      case xnn_node_type_square_root:
        break;
      default:
        xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s) is not supported for FP16 inference",
          n, xnn_node_type_to_string(node->type));
        return false;
    }
  }

  // Annotate Values to be converted to FP16 as FP16-compatible.
  // Note that static weights in [Depthwise] Convolution, Fully Connected, and PReLU Nodes remain FP32,
  // they will be converted to FP16 during weight repacking when the operator is created.
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    switch (node->type) {
      case xnn_node_type_convolution_2d:
      case xnn_node_type_deconvolution_2d:
      case xnn_node_type_depthwise_convolution_2d:
      case xnn_node_type_fully_connected:
      case xnn_node_type_prelu:
        subgraph->values[node->inputs[0]].fp16_compatible = true;
        subgraph->values[node->outputs[0]].fp16_compatible = true;
        break;
      default:
        for (uint32_t i = 0; i < node->num_inputs; i++) {
          subgraph->values[node->inputs[i]].fp16_compatible = true;
        }
        for (uint32_t o = 0; o < node->num_outputs; o++) {
          subgraph->values[node->outputs[o]].fp16_compatible = true;
        }
        break;
    }
  }

  // Replace FP32 Values in Nodes' inputs/outputs with FP16 Values.
  // FP32 Values that are not external inputs or outputs are converted to FP16 in-place,
  // for external inputs and outputs we create same-shaped FP16 Values and use those instead.
  const uint32_t num_original_values = subgraph->num_values;
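  // Re-run producer/consumer analysis so the FP16 Values created below copy up-to-date links from their FP32 counterparts.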
  xnn_subgraph_analyze_consumers_and_producers(subgraph);
  for (uint32_t n = 0; n < num_original_values; n++) {
    struct xnn_value* value = &subgraph->values[n];
    value->fp16_id = XNN_INVALID_VALUE_ID;
    value->fp32_id = XNN_INVALID_VALUE_ID;
    if (value->fp16_compatible) {
      assert(value->data == NULL);
      assert(value->datatype == xnn_datatype_fp32);
      if (xnn_value_is_external(value)) {
        struct xnn_value* fp16_value = xnn_subgraph_new_internal_value(subgraph);

        // Recompute value due to potential reallocation in xnn_subgraph_new_internal_value
        value = &subgraph->values[n];
        xnn_value_copy(fp16_value, value);
        fp16_value->datatype = xnn_datatype_fp16;

        fp16_value->producer = value->producer;
        fp16_value->num_consumers = value->num_consumers;
        fp16_value->first_consumer = value->first_consumer;
        value->producer = XNN_INVALID_NODE_ID;
        value->num_consumers = 0;
        value->first_consumer = XNN_INVALID_NODE_ID;

        // Clear external input/output flags
        fp16_value->flags = 0;
        xnn_log_debug("FP16 rewrite: created FP16 tensor #%" PRIu32 " for FP32 tensor #%" PRIu32, fp16_value->id, n);

        value->fp16_id = fp16_value->id;
        fp16_value->fp32_id = n;
      } else {
        xnn_log_debug("FP16 rewrite: converted FP32 tensor #%" PRIu32 " to FP16", n);
        value->datatype = xnn_datatype_fp16;
      }
    }
  }
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    struct xnn_node* node = &subgraph->nodes[n];
    if (node->type == xnn_node_type_invalid) {
      // Node was fused away, skip.
      continue;
    }

    assert(node->compute_type == xnn_compute_type_fp32);
    node->compute_type = xnn_compute_type_fp16;
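    // The padding value is stored as a bit pattern: reinterpret it as FP32 and convert it to an FP16 bit pattern.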
    if (node->type == xnn_node_type_static_constant_pad) {
      node->params.static_pad.padding_value =
        fp16_ieee_from_fp32_value(uint32_as_float(node->params.static_pad.padding_value));
    }
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const uint32_t fp16_id = subgraph->values[node->inputs[i]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->inputs[i]);
        node->inputs[i] = fp16_id;
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const uint32_t fp16_id = subgraph->values[node->outputs[o]].fp16_id;
      if (fp16_id != XNN_INVALID_VALUE_ID) {
        assert(subgraph->values[fp16_id].fp32_id == node->outputs[o]);
        node->outputs[o] = fp16_id;
      }
    }
  }

  // Count the number of external inputs and outputs which require Convert nodes
  uint32_t num_external_inputs = 0;
  uint32_t num_external_outputs = 0;
  for (uint32_t n = 0; n < subgraph->num_nodes; n++) {
    const struct xnn_node* node = &subgraph->nodes[n];
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n) {
        assert(value->data == NULL);
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        // This Value is not always an external input: due to partitioning it could be an external output of the
        // current subgraph that is simultaneously consumed by the current Node.
        if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
          num_external_inputs += 1;
        }
      }
    }
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        assert(value->datatype == xnn_datatype_fp16);
        assert(subgraph->values[value->fp32_id].datatype == xnn_datatype_fp32);
        assert(xnn_value_is_external_output(&subgraph->values[value->fp32_id]));
        num_external_outputs += 1;
      }
    }
  }
  xnn_log_debug("Discovered %"PRIu32" external inputs and %"PRIu32" external outputs",
    num_external_inputs, num_external_outputs);

  const uint32_t num_original_nodes = subgraph->num_nodes;
  xnn_subgraph_add_nodes(subgraph, num_external_inputs + num_external_outputs);
  struct xnn_node* output_node = subgraph->nodes + subgraph->num_nodes - 1;
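  // Walk the original Nodes in reverse order, packing each Node and any Convert Nodes it requires into the tail of
  // the enlarged Node array.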
  for (uint32_t n = num_original_nodes; n != 0; n--) {
    const struct xnn_node* node = &subgraph->nodes[n - 1];
    // Insert Convert nodes for outputs
    for (uint32_t o = 0; o < node->num_outputs; o++) {
      const struct xnn_value* value = &subgraph->values[node->outputs[o]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID) {
        xnn_log_debug("Inserted FP16->FP32 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32,
          value->id, value->fp32_id);
        const uint32_t output_node_id = output_node->id;
        assert(output_node >= subgraph->nodes);
        xnn_node_clear(output_node);
        output_node->id = output_node_id;
        xnn_init_convert_node(output_node, xnn_compute_type_fp16_to_fp32, value->id, value->fp32_id, 0 /* flags */);
        output_node -= 1;
      }
    }
    // Move the Node to the new location
    if (output_node != node) {
      const uint32_t output_node_id = output_node->id;
      assert(output_node >= subgraph->nodes);
      memcpy(output_node, node, sizeof(struct xnn_node));
      output_node->id = output_node_id;
      output_node -= 1;
    }
    // Insert Convert nodes for inputs
    for (uint32_t i = 0; i < node->num_inputs; i++) {
      const struct xnn_value* value = &subgraph->values[node->inputs[i]];
      if (value->fp32_id != XNN_INVALID_VALUE_ID && value->first_consumer == n - 1) {
        // Only insert a Convert Node if the Value actually is an external input. If the Value is an external output,
        // a Convert Node was already inserted in the output loop above.
        if (xnn_value_is_external_input(&subgraph->values[value->fp32_id])) {
          xnn_log_debug("Inserted FP32->FP16 Convert Node from tensor #%"PRIu32" to tensor #%"PRIu32,
                        value->fp32_id, value->id);
          const uint32_t output_node_id = output_node->id;
          assert(output_node >= subgraph->nodes);
          xnn_node_clear(output_node);
          output_node->id = output_node_id;
          xnn_init_convert_node(output_node, xnn_compute_type_fp32_to_fp16, value->fp32_id, value->id, 0 /* flags */);
          output_node -= 1;
        }
      }
    }
  }

  return true;
}

enum xnn_status xnn_subgraph_fusion(
    xnn_subgraph_t subgraph)
{
  // Fuse Nodes where possible
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->num_consumers == 1) {
      const uint32_t producer_id = value->producer;
      if (producer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(producer_id < subgraph->num_nodes);

      const uint32_t consumer_id = value->first_consumer;
      if (consumer_id == XNN_INVALID_NODE_ID) {
        continue;
      }
      assert(consumer_id < subgraph->num_nodes);

      struct xnn_node* producer = &subgraph->nodes[producer_id];
      assert(producer->type != xnn_node_type_invalid);
      struct xnn_node* consumer = &subgraph->nodes[consumer_id];
      assert(consumer->type != xnn_node_type_invalid);

      // Try to fuse Clamp Node upstream into producer Node
      if (consumer->type == xnn_node_type_clamp) {
        switch (producer->type) {
          case xnn_node_type_add2:
          case xnn_node_type_average_pooling_2d:
          case xnn_node_type_clamp:
          case xnn_node_type_convolution_2d:
          case xnn_node_type_divide:
          case xnn_node_type_deconvolution_2d:
          case xnn_node_type_depthwise_convolution_2d:
          case xnn_node_type_fully_connected:
          case xnn_node_type_multiply2:
          case xnn_node_type_max_pooling_2d:
          case xnn_node_type_subtract:
            xnn_log_info("fuse Clamp Node #%"PRIu32" into upstream Node #%"PRIu32, consumer_id, producer_id);
            assert(producer->num_outputs == 1);
            assert(consumer->num_inputs == 1);
            assert(consumer->num_outputs == 1);

            const uint32_t fused_output_id = consumer->outputs[0];
            assert(fused_output_id < subgraph->num_values);
            subgraph->values[fused_output_id].producer = producer_id;
            producer->outputs[0] = fused_output_id;

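            // The fused activation range is the intersection of the producer's and the Clamp's ranges.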
            producer->activation.output_min =
              math_max_f32(producer->activation.output_min, consumer->activation.output_min);
            producer->activation.output_max =
              math_min_f32(producer->activation.output_max, consumer->activation.output_max);

            xnn_node_clear(consumer);
            xnn_value_clear(value);
            break;
          default:
            break;
        }
      }
      // Try to fuse Constant Pad node downstream into [Depthwise] Convolution 2D Node
      if (producer->type == xnn_node_type_static_constant_pad) {
        assert(producer->num_inputs == 1);
        assert(producer->num_outputs == 1);
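        // Padding can only be fused if it applies to the spatial dimensions of a 4D NHWC tensor, i.e. the batch
        // and channel dimensions (indices 0 and 3) carry zero padding.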
        const bool is_spatial_2d_padding = value->shape.num_dims == 4 &&
          (producer->params.static_pad.pre_paddings[0] | producer->params.static_pad.post_paddings[0] |
           producer->params.static_pad.pre_paddings[3] | producer->params.static_pad.post_paddings[3]) == 0;
        const enum xnn_datatype padding_datatype = subgraph->values[producer->outputs[0]].datatype;
        const uint32_t padding_value = producer->params.static_pad.padding_value;
        const bool is_zero_padding =
          (padding_datatype == xnn_datatype_fp32 && padding_value == 0) ||
          ((padding_datatype == xnn_datatype_qint8 || padding_datatype == xnn_datatype_quint8) &&
          padding_value == (uint32_t) (uint8_t) subgraph->values[producer->outputs[0]].quantization.zero_point);
        switch (consumer->type) {
          case xnn_node_type_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Convolution 2D Node #%"PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.convolution_2d.input_padding_top    += producer->params.static_pad.pre_paddings[1];
              consumer->params.convolution_2d.input_padding_right  += producer->params.static_pad.post_paddings[2];
              consumer->params.convolution_2d.input_padding_bottom += producer->params.static_pad.post_paddings[1];
              consumer->params.convolution_2d.input_padding_left   += producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          case xnn_node_type_depthwise_convolution_2d:
            if (is_spatial_2d_padding && is_zero_padding && !(consumer->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING)) {
              xnn_log_info("fuse Constant Pad Node #%"PRIu32" into Depthwise Convolution 2D Node #%"PRIu32,
                producer_id, consumer_id);
              assert(consumer->num_inputs >= 1);
              assert(consumer->inputs[0] == producer->outputs[0]);

              consumer->params.depthwise_convolution_2d.input_padding_top +=
                producer->params.static_pad.pre_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_right +=
                producer->params.static_pad.post_paddings[2];
              consumer->params.depthwise_convolution_2d.input_padding_bottom +=
                producer->params.static_pad.post_paddings[1];
              consumer->params.depthwise_convolution_2d.input_padding_left +=
                producer->params.static_pad.pre_paddings[2];

              consumer->inputs[0] = producer->inputs[0];

              const uint32_t fused_input_id = producer->inputs[0];
              assert(fused_input_id < subgraph->num_values);
              if (subgraph->values[fused_input_id].first_consumer == producer_id) {
                subgraph->values[fused_input_id].first_consumer = consumer_id;
              }

              xnn_node_clear(producer);
              xnn_value_clear(value);
            }
            break;
          default:
            break;
        }
      }
    }
  }

  return xnn_status_success;
}

enum xnn_status xnn_subgraph_optimize(
  xnn_subgraph_t subgraph,
  uint32_t flags)
{
  xnn_subgraph_analyze_consumers_and_producers(subgraph);

  // Remove unreferenced values.
  for (uint32_t i = 0; i < subgraph->num_values; i++) {
    struct xnn_value* value = &subgraph->values[i];
    if (value->type == xnn_value_type_invalid) {
      continue;
    }

    if (!xnn_value_is_external_input(value) && value->num_consumers == 0) {
      xnn_value_clear(value);
    }
  }

  if (!(flags & XNN_FLAG_NO_OPERATOR_FUSION)) {
    xnn_subgraph_fusion(subgraph);
  }

  #if XNN_ENABLE_SPARSE
    if ((flags & XNN_FLAG_HINT_SPARSE_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_CHW_OPT)) {
      xnn_subgraph_rewrite_for_nchw(subgraph);
    }
  #endif

  if ((flags & XNN_FLAG_FORCE_FP16_INFERENCE) && !(xnn_params.init_flags & XNN_INIT_FLAG_F16)) {
    xnn_log_error("failed to force FP16 inference: hardware supports neither native nor emulated FP16 operators");
    return xnn_status_unsupported_hardware;
  }
  #ifndef XNN_NO_F16_OPERATORS
    const bool try_native_fp16 =
      (flags & XNN_FLAG_HINT_FP16_INFERENCE) && (xnn_params.init_flags & XNN_INIT_FLAG_F16_NATIVE);
    const bool force_fp16 = (flags & XNN_FLAG_FORCE_FP16_INFERENCE);
    if (try_native_fp16 || force_fp16) {
      const bool fp16_rewrite_succeeded = xnn_subgraph_rewrite_for_fp16(subgraph);
      if (force_fp16 && !fp16_rewrite_succeeded) {
        xnn_log_error("failed to force FP16 inference: subgraph is incompatible with FP16 operators");
        return xnn_status_unsupported_parameter;
      }
    }
  #endif  // XNN_NO_F16_OPERATORS

  return xnn_status_success;
}

enum xnn_status xnn_delete_subgraph(
  xnn_subgraph_t subgraph)
{
  if (subgraph != NULL) {
    memset(subgraph->nodes, 0, sizeof(struct xnn_node) * subgraph->num_nodes);
    xnn_release_memory(subgraph->nodes);

    memset(subgraph->values, 0, sizeof(struct xnn_value) * subgraph->num_values);
    xnn_release_memory(subgraph->values);

    memset(subgraph, 0, sizeof(struct xnn_subgraph));
    xnn_release_memory(subgraph);
  }
  return xnn_status_success;
}