# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Follows this specification: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html

# pyre-unsafe

import serializer.tosa_serializer as ts
import torch.fx
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class UnsqueezeVisitor(NodeVisitor):
    """Lower ``aten.unsqueeze_copy.default`` to a TOSA RESHAPE operator."""

    target = "aten.unsqueeze_copy.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit a RESHAPE that inserts a size-1 axis at position ``dim``.

        Args:
            node: The FX node being lowered (not read by this method).
            tosa_graph: Serializer that the RESHAPE operator is appended to.
            inputs: ``inputs[0]`` is the tensor being unsqueezed (its ``shape``
                and ``name`` are read); ``inputs[1].number`` is ``dim``.
            output: Destination tensor; its ``name`` and ``dim_order`` are used.
            is_quant_node: Not read by this method.

        Raises:
            ValueError: If ``dim`` is outside ``[-rank - 1, rank]``, where
                ``rank`` is the rank of ``inputs[0]``.
        """
        dim = inputs[1].number
        shape = inputs[0].shape
        rank = len(shape)

        # The new axis may be placed before any existing axis or after the
        # last one, so there are rank + 1 valid positions (negatives allowed).
        # An explicit raise is used instead of assert so the check survives
        # `python -O`; otherwise list.insert would silently clamp a bad index.
        if not -rank - 1 <= dim <= rank:
            raise ValueError(
                f"unsqueeze dim {dim} out of range [{-rank - 1}, {rank}] "
                f"for input of rank {rank}"
            )
        if dim < 0:
            dim += rank + 1

        new_shape = list(shape)
        new_shape.insert(dim, 1)
        # NOTE(review): tosa_shape is defined elsewhere; presumably it reorders
        # the logical shape into the output's dim_order - confirm.
        new_shape = tosa_shape(new_shape, output.dim_order)

        attr = ts.TosaSerializerAttribute()
        attr.ReshapeAttribute(new_shape)
        tosa_graph.addOperator(
            TosaOp.Op().RESHAPE, [inputs[0].name], [output.name], attr
        )