xref: /aosp_15_r20/external/executorch/backends/xnnpack/operators/op_squeeze.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import cast, Dict
8
9import torch
10from executorch.backends.xnnpack.operators.node_visitor import (
11    NodeVisitor,
12    register_node_visitor,
13)
14from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
15    XNNGraph,
16    XNNStaticReshape,
17    XNode,
18)
19from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
20
21
@register_node_visitor
class SqueezeVisitor(NodeVisitor):
    """Serialize ``aten.squeeze_copy.dim`` as an XNNPACK static reshape.

    Only squeezing the last dimension (``dim == -1``) is accepted. The node is
    lowered to an ``XNNStaticReshape`` whose target shape is read from the
    output node's ``meta["val"]``.
    """

    target = "aten.squeeze_copy.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNStaticReshape`` implementing ``node`` to ``xnn_graph``.

        Args:
            node: The ``aten.squeeze_copy.dim`` FX node; ``node.args[1]`` is the
                dimension being squeezed.
            xnn_graph: Graph under construction; one ``XNode`` is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: Mapping from FX nodes to serialized value ids; it is
                expected to hold ids for ``node`` and its input once
                ``define_nodes_tensor_inputs_outputs`` has run.
            debug_handle: Debug handle attached to the serialized node.

        Raises:
            Via ``check_or_raise``: if the squeezed dimension is not ``-1``, if
            the output node has no ``meta["val"]``, or if the output shape has
            more than one symbolic dimension.
        """
        check_or_raise(
            cast(int, node.args[1]) == -1,
            "XNNPACK currently only supports squeezing in last dimension",
        )

        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
        input_node = get_input_node(node, 0)
        input_id = vals_to_ids[input_node]
        output_id = vals_to_ids[node]

        # The reshape target is the *output* node's shape, so that is the
        # metadata whose presence must be verified before it is read.
        check_or_raise(
            "val" in node.meta,
            "Missing val in tensor metadata for output when serializing XNNStaticReshape node",
        )
        output_shape = node.meta["val"].shape

        # Symbolic (dynamic) dimensions are serialized as 0.
        # NOTE(review): presumably 0 asks XNNPACK to infer that dimension at
        # runtime, which is why at most one may be dynamic — confirm.
        new_shape = [0 if isinstance(d, torch.SymInt) else d for d in output_shape]
        num_dynamic_dims = sum(isinstance(d, torch.SymInt) for d in output_shape)
        check_or_raise(
            num_dynamic_dims <= 1,
            "XNNPACK reshape only supports 1 dynamic dimension, but the squeeze "
            f"output shape has {num_dynamic_dims} dynamic dimensions.",
        )

        ser_node = XNode(
            xnode_union=XNNStaticReshape(
                num_dims=len(new_shape),
                new_shape=new_shape,
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
82
83
@register_node_visitor
class UnsqueezeVisitor(NodeVisitor):
    """Serialize ``aten.unsqueeze_copy.default`` as an XNNPACK static reshape.

    Only unsqueezing at the last position (``dim == -1``) is accepted. The node
    is lowered to an ``XNNStaticReshape`` whose target shape is read from the
    output node's ``meta["val"]``.
    """

    target = "aten.unsqueeze_copy.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNStaticReshape`` implementing ``node`` to ``xnn_graph``.

        Args:
            node: The ``aten.unsqueeze_copy.default`` FX node; ``node.args[1]``
                is the position of the inserted dimension.
            xnn_graph: Graph under construction; one ``XNode`` is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: Mapping from FX nodes to serialized value ids; it is
                expected to hold ids for ``node`` and its input once
                ``define_nodes_tensor_inputs_outputs`` has run.
            debug_handle: Debug handle attached to the serialized node.

        Raises:
            Via ``check_or_raise``: if the unsqueeze position is not ``-1``, if
            the output node has no ``meta["val"]``, or if the output shape has
            more than one symbolic dimension.
        """
        check_or_raise(
            cast(int, node.args[1]) == -1,
            "XNNPACK currently only supports unsqueezing in last dimension",
        )

        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)
        input_node = get_input_node(node, 0)
        input_id = vals_to_ids[input_node]
        output_id = vals_to_ids[node]

        # The reshape target is the *output* node's shape, so that is the
        # metadata whose presence must be verified before it is read.
        check_or_raise(
            "val" in node.meta,
            "Missing val in tensor metadata for output when serializing XNNStaticReshape node",
        )
        output_shape = node.meta["val"].shape

        # Symbolic (dynamic) dimensions are serialized as 0.
        # NOTE(review): presumably 0 asks XNNPACK to infer that dimension at
        # runtime, which is why at most one may be dynamic — confirm.
        new_shape = [0 if isinstance(d, torch.SymInt) else d for d in output_shape]
        num_dynamic_dims = sum(isinstance(d, torch.SymInt) for d in output_shape)
        check_or_raise(
            num_dynamic_dims <= 1,
            "XNNPACK reshape only supports 1 dynamic dimension, but the unsqueeze "
            f"output shape has {num_dynamic_dims} dynamic dimensions.",
        )

        ser_node = XNode(
            xnode_union=XNNStaticReshape(
                num_dims=len(new_shape),
                new_shape=new_shape,
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)
144