xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/convert_binary_op_with_scalar.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6from typing import Dict, Tuple
7
8import torch
9from executorch.exir.pass_base import ExportPass
10from torch._export.pass_base import Argument
11from torch._export.pass_infra.node_metadata import NodeMetadata
12from torch._export.pass_infra.proxy_value import ProxyValue
13
14
class ConvertBinaryOpsWithScalar(ExportPass):
    """
    Rewrite binary ops that take a scalar operand into their tensor-operand
    overloads, e.g. ``aten.add.Scalar`` -> ``aten.add.Tensor``.

    This is needed because ``torch.ops.aten.xxx.Scalar`` ops do not generate
    a placeholder node for the scalar operand after ``to_edge``. The rewrite
    is driven by ``binary_ops_with_scalar``. Ops not listed there are passed
    through unchanged.
    """

    # Maps each ``.Scalar`` overload to the ``.Tensor`` overload that replaces it.
    binary_ops_with_scalar = {
        torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
        torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
        torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
        torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
    }

    # ``.Scalar`` overloads whose schema takes ``alpha`` as a third positional
    # argument, while the matching ``.Tensor`` overload declares ``alpha``
    # keyword-only. Any positional ``alpha`` must therefore be moved to kwargs.
    _ops_with_positional_alpha = {
        torch.ops.aten.add.Scalar,
        torch.ops.aten.sub.Scalar,
    }

    def __init__(self):
        super().__init__()

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Re-emit ``op``, substituting its ``.Tensor`` overload when ``op`` is
        listed in ``binary_ops_with_scalar``.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call. Never mutated; a new
                dict is built when arguments must be adjusted.
            meta: Node metadata forwarded unchanged to the base pass.

        Returns:
            The value produced by ``ExportPass.call_operator`` for the
            (possibly replaced) operator.
        """
        target = self.binary_ops_with_scalar.get(op)
        if target is None:
            return super().call_operator(op, args, kwargs, meta)

        if op in self._ops_with_positional_alpha and len(args) > 2:
            # add.Scalar/sub.Scalar accept ``alpha`` positionally, but the
            # .Tensor overloads only accept it as a keyword. Forwarding it
            # positionally would not match the target schema.
            kwargs = {**kwargs, "alpha": args[2]}
            args = tuple(args[:2])

        return super().call_operator(target, args, kwargs, meta)
42