xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_unsqueeze.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5#
6#  Follows this specification: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
7
8# pyre-unsafe
9
10import serializer.tosa_serializer as ts
11import torch.fx
12from executorch.backends.arm.operators.node_visitor import (
13    NodeVisitor,
14    register_node_visitor,
15)
16from executorch.backends.arm.tosa_mapping import TosaArg
17from executorch.backends.arm.tosa_utils import tosa_shape
18from serializer.tosa_serializer import TosaOp
19
20
@register_node_visitor
class UnsqueezeVisitor(NodeVisitor):
    """Lower ``aten.unsqueeze_copy.default`` to a TOSA RESHAPE.

    Unsqueeze only inserts a new size-1 axis, so the node is emitted as a
    single RESHAPE of the input tensor to the unsqueezed shape.
    """

    target = "aten.unsqueeze_copy.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a RESHAPE that inserts a size-1 axis into ``inputs[0]``.

        Args:
            node: The FX node being lowered; not read by this visitor.
            tosa_graph: Serializer that receives the RESHAPE operator.
            inputs: ``inputs[0]`` is the tensor to unsqueeze and
                ``inputs[1].number`` is the position of the new axis
                (negative values count from the end of the result).
            output: Destination tensor; its ``dim_order`` is passed to
                ``tosa_shape`` to permute the reshape target.
            is_quant_node: Not read by this visitor.
        """
        source = inputs[0]
        axis = inputs[1].number
        in_dims = source.shape
        in_rank = len(in_dims)

        # The result has rank + 1 dims, so valid positions are
        # [-(rank + 1), rank], the same range torch.unsqueeze accepts.
        assert -in_rank - 1 <= axis < in_rank + 1
        if axis < 0:
            # Map a negative position onto the rank + 1 dims of the result.
            axis += in_rank + 1

        dims = list(in_dims)
        unsqueezed = dims[:axis] + [1] + dims[axis:]
        reshape_target = tosa_shape(unsqueezed, output.dim_order)

        reshape_attr = ts.TosaSerializerAttribute()
        reshape_attr.ReshapeAttribute(reshape_target)
        tosa_graph.addOperator(
            TosaOp.Op().RESHAPE,
            [source.name],
            [output.name],
            reshape_attr,
        )
53        )
54