# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

from executorch.backends.arm.tosa_mapping import extract_tensor_meta
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ConvertExpandCopyToRepeatPass(ExportPass):
    """
    Replace expand_copy with repeat.

    expand_copy is a repeat that can only repeat singleton dimensions: input
    dims of size 1, plus any new leading dims that expand prepends. Every
    expand_copy node is therefore rewritten as an equivalent repeat node.
    """

    expand_copy = exir_ops.edge.aten.expand_copy.default
    repeat = exir_ops.edge.aten.repeat.default

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite ``expand_copy(x, sizes)`` as ``repeat(x, multiples)``.

        Any other op is forwarded unchanged to ``ExportPass.call_operator``.

        Args:
            op: The operator being traced.
            args: For expand_copy, ``args[0]`` is the input value and
                ``args[1]`` the list of target sizes passed to expand.
            kwargs: Forwarded unchanged.
            meta: Metadata of the node being rewritten; forwarded unchanged
                so the repeat node keeps the expand node's metadata.

        Returns:
            The result of tracing the (possibly rewritten) operator.
        """
        if op != self.expand_copy:
            return super().call_operator(op, args, kwargs, meta)

        # Use the *input* tensor's shape. `meta` describes the expand_copy
        # node itself, so its value is the already-expanded output; deriving
        # multiples from it turned every multiple into 1 (a no-op repeat).
        input_shape = args[0].data.shape
        expand_sizes = cast(list[int], args[1])
        expanded_rank = len(expand_sizes)

        # The input shape as seen by expand: front-padded with ones for the
        # new leading dimensions that expand introduces.
        padding = expanded_rank - len(input_shape)
        extended_shape = [1] * padding + list(input_shape)

        # Only singleton dims are actually repeated, and their multiple is the
        # requested size. All other dims keep their size (multiple 1). An
        # expand size of -1 means "keep this dimension as is", which also
        # maps to 1; repeat does not accept negative multiples.
        multiples = [
            size if size != -1 and dim == 1 else 1
            for size, dim in zip(expand_sizes, extended_shape)
        ]
        return super().call_operator(
            op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
        )