xref: /aosp_15_r20/external/executorch/exir/passes/remove_mixed_type_operators.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import torch
10from executorch.exir.pass_base import ExportPass, map_args, NodeMetadata, ProxyValue
11from torch import SymBool, SymFloat, SymInt
12from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
13from torch.utils._pytree import PyTree
14
15
class RemoveMixedTypeOperators(ExportPass):
    """Insert explicit dtype casts so allow-listed binary ops see uniform dtypes.

    For every operator listed in ``_PROMOTION_TYPE_ALLOW_LIST``, the pass
    computes the result dtype that PyTorch's elementwise type promotion
    (``elementwise_dtypes``) yields for the call's positional operands. It
    then wraps each Tensor-typed operand whose dtype differs in an
    ``aten._to_copy`` node targeting that dtype. Operators outside the allow
    list, and calls with fewer than two positional arguments, are forwarded
    unchanged.
    """

    # Operators rewritten by this pass, mapped to the promotion rule used to
    # pick the common dtype. Built once at class creation rather than on every
    # call_operator invocation; subclasses may override it to extend coverage.
    # NOTE(review): torch's reference true division uses INT_TO_FLOAT; with
    # DEFAULT, integer operands of div.Tensor are unified to a common integer
    # dtype rather than cast to float. Confirm this is intended.
    _PROMOTION_TYPE_ALLOW_LIST: dict[
        torch._ops.OpOverload, ELEMENTWISE_TYPE_PROMOTION_KIND
    ] = {
        torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    }

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta: NodeMetadata):  # noqa: C901
        """Cast operands of an allow-listed op to their promoted common dtype.

        Args:
            op: The operator being called.
            args: Positional arguments. Only these are used to compute the
                promoted dtype.
            kwargs: Keyword arguments. They are mapped through the same cast
                helper as ``args``.
            meta: Node metadata, reused for any inserted ``_to_copy`` nodes.

        Returns:
            The result of the base ``call_operator`` for ``op`` with the
            (possibly cast) arguments.
        """
        if len(args) <= 1:
            # Unary operators cannot mix operand dtypes.
            return super().call_operator(op, args, kwargs, meta)

        promotion_kind = self._PROMOTION_TYPE_ALLOW_LIST.get(op)
        if promotion_kind is None:
            # Not in the allow list: leave the call untouched.
            return super().call_operator(op, args, kwargs, meta)

        # Gather dtype-bearing stand-ins for each positional operand. They are
        # only inspected for type information and never executed.
        operands = []
        for arg in args:
            if isinstance(arg, ProxyValue):
                if arg.is_tensor():
                    operands.append(arg.to_tensor())
                elif isinstance(arg.data, (SymFloat, SymInt, SymBool)):
                    # Symbolic scalars become 0-dim tensors so that
                    # elementwise_dtypes can read a dtype from them.
                    operands.append(torch.tensor(arg.data))
                else:
                    operands.append(arg.data)
            else:
                # Either a raw torch.Tensor (e.g. after a scalar-to-tensor
                # pass) or a plain Python number literal. Numbers have no
                # `.data` attribute; elementwise_dtypes accepts them directly.
                operands.append(arg)

        # Element [1] of elementwise_dtypes is the promoted result dtype.
        promote_dtype: torch.dtype = elementwise_dtypes(
            *operands,
            type_promotion_kind=promotion_kind,
        )[1]

        def try_coerce(value: PyTree, arg: torch.Argument) -> PyTree:
            # Only schema slots typed as Tensor are candidates for a cast.
            if not isinstance(arg.type, torch.TensorType):
                return value

            if isinstance(value, ProxyValue):
                if not value.is_tensor() or value.to_tensor().dtype == promote_dtype:
                    return value
            elif isinstance(value, torch.Tensor):
                if value.dtype == promote_dtype:
                    return value
            else:
                # A Python number cannot be passed to _to_copy; keep it as-is
                # and let the operator handle the scalar operand itself.
                return value

            return self.call_operator(
                torch.ops.aten._to_copy.default,
                (value,),
                {"dtype": promote_dtype},
                meta,
            )

        args, kwargs = map_args(op, try_coerce, args, kwargs)

        return super().call_operator(op, args, kwargs, meta)
87