# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNStaticReshape,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node


def _static_reshape_dims(node: torch.fx.Node) -> List[int]:
    """Build the ``new_shape`` list for an XNNStaticReshape from ``node``'s output.

    Concrete dimensions of ``node.meta["val"].shape`` are copied unchanged.
    Symbolic dimensions (``torch.SymInt``) are written as ``0``.
    NOTE(review): ``0`` presumably asks XNNPACK to infer that dimension at
    runtime; confirm against the XNNPACK static-reshape definition.

    Raises:
        Whatever ``check_or_raise`` raises when ``node`` has no ``"val"``
        metadata or its output shape has more than one symbolic dimension.
    """
    # The shape is read from the *output* node's metadata, so that is the
    # metadata that must be present.
    check_or_raise(
        "val" in node.meta,
        "Missing val in tensor metadata for output when serializing XNNStaticReshape node",
    )

    new_shape: List[int] = []
    num_dynamic_dims = 0
    for dim in node.meta["val"].shape:
        if isinstance(dim, torch.SymInt):
            num_dynamic_dims += 1
            new_shape.append(0)
        else:
            new_shape.append(dim)

    check_or_raise(
        num_dynamic_dims <= 1,
        f"XNNPACK reshape only supports 1 dynamic dimension, but the output "
        f"of node '{node.name}' has {num_dynamic_dims} dynamic dimensions.",
    )
    return new_shape


def _serialize_as_static_reshape(
    visitor: NodeVisitor,
    node: torch.fx.Node,
    xnn_graph: XNNGraph,
    vals_to_ids: Dict[torch.fx.Node, int],
    debug_handle: int,
) -> None:
    """Lower ``node`` to one XNNStaticReshape XNode appended to ``xnn_graph``.

    Squeeze and unsqueeze change only a tensor's shape, so both are emitted
    as a reshape from the node's first input to the node's output shape.

    Args:
        visitor: The visitor performing the lowering. It is used to define
            the tensor values for ``node``'s inputs and output.
        node: The squeeze/unsqueeze node being serialized.
        xnn_graph: Graph that receives the new XNode.
        vals_to_ids: Mapping from fx nodes to serialized value ids.
        debug_handle: Debug handle recorded on the emitted XNode.
    """
    visitor.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
    input_node = get_input_node(node, 0)

    input_id = vals_to_ids[input_node]
    output_id = vals_to_ids[node]

    new_shape = _static_reshape_dims(node)

    ser_node = XNode(
        xnode_union=XNNStaticReshape(
            num_dims=len(new_shape),
            new_shape=new_shape,
            input_id=input_id,
            output_id=output_id,
            flags=0,
        ),
        debug_handle=debug_handle,
    )
    xnn_graph.xnodes.append(ser_node)


@register_node_visitor
class SqueezeVisitor(NodeVisitor):
    """Serialize ``aten.squeeze_copy.dim`` as an XNNPACK static reshape.

    Only squeezing the last dimension (``dim == -1``) is accepted.
    """

    target = "aten.squeeze_copy.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        # Only the literal -1 is accepted; a positive index naming the last
        # dimension is rejected too.
        check_or_raise(
            cast(int, node.args[1]) == -1,
            "XNNPACK currently only supports squeezing in last dimension",
        )
        _serialize_as_static_reshape(
            self, node, xnn_graph, vals_to_ids, debug_handle
        )


@register_node_visitor
class UnsqueezeVisitor(NodeVisitor):
    """Serialize ``aten.unsqueeze_copy.default`` as an XNNPACK static reshape.

    Only unsqueezing in the last dimension (``dim == -1``) is accepted.
    """

    target = "aten.unsqueeze_copy.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        # Only the literal -1 is accepted; a positive index naming the last
        # dimension is rejected too.
        check_or_raise(
            cast(int, node.args[1]) == -1,
            "XNNPACK currently only supports unsqueezing in last dimension",
        )
        _serialize_as_static_reshape(
            self, node, xnn_graph, vals_to_ids, debug_handle
        )