# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import ClassVar, Dict

import torch
from executorch.exir.pass_base import ExportPass, map_args, NodeMetadata, ProxyValue
from torch import SymBool, SymFloat, SymInt
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils._pytree import PyTree


class RemoveMixedTypeOperators(ExportPass):
    """Make the operands of selected elementwise operators share one dtype.

    For every operator listed in ``promotion_type_allow_list`` that is called
    with more than one positional argument, the pass computes the dtype the
    operator's inputs promote to (via ``elementwise_dtypes``). It then inserts
    an ``aten._to_copy`` cast in front of every tensor-typed operand whose
    dtype differs from it. All other calls are forwarded unchanged.
    """

    # Operators rewritten by this pass, mapped to the type-promotion rule used
    # to compute their common dtype. Built once at class level instead of on
    # every call; subclasses may override it. Treat it as read-only.
    promotion_type_allow_list: ClassVar[
        Dict[torch._ops.OpOverload, ELEMENTWISE_TYPE_PROMOTION_KIND]
    ] = {
        torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    }

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta: NodeMetadata):  # noqa: C901
        """Cast ``op``'s tensor operands to their promoted dtype, then call it.

        Args:
            op: The operator being called.
            args: Positional arguments. These may be ProxyValues, raw tensors,
                or raw Python constants.
            kwargs: Keyword arguments.
            meta: Node metadata, reused for any inserted ``_to_copy`` cast.

        Returns:
            The result of ``ExportPass.call_operator`` on the (possibly cast)
            arguments.
        """
        if len(args) <= 1:
            # Unary operators cannot have mixed-type operands.
            return super().call_operator(op, args, kwargs, meta)

        promotion_kind = self.promotion_type_allow_list.get(op)
        if promotion_kind is None:
            # Not in the allow list: leave the call untouched.
            return super().call_operator(op, args, kwargs, meta)

        # Collect one dtype-carrying value per positional argument. These are
        # used for type information only and never enter the graph. Only
        # positional arguments contribute to the promoted dtype.
        dtype_operands = []
        for arg in args:
            if isinstance(arg, ProxyValue):
                if arg.is_tensor():
                    dtype_operands.append(arg.to_tensor())
                elif isinstance(arg.data, (SymFloat, SymInt, SymBool)):
                    dtype_operands.append(torch.tensor(arg.data))
                else:
                    dtype_operands.append(arg.data)
            else:
                # Raw tensors can show up after the scalar_to_tensor pass has
                # turned a scalar into a tensor. Raw Python scalars (graph
                # literals) have no `.data` attribute; elementwise_dtypes
                # accepts them directly.
                dtype_operands.append(arg)

        # elementwise_dtypes returns (computation_dtype, result_dtype). Index 1
        # is the result dtype, i.e. what the operator would produce.
        promote_dtype: torch.dtype = elementwise_dtypes(
            *dtype_operands,
            type_promotion_kind=promotion_kind,
        )[1]

        def try_coerce(value: PyTree, arg: torch.Argument) -> PyTree:
            """Return ``value`` cast to ``promote_dtype`` if it is a tensor
            bound to a tensor-typed schema slot; otherwise return it as is."""
            if not isinstance(arg.type, torch.TensorType):
                return value

            if isinstance(value, ProxyValue):
                # Non-tensor proxies (e.g. symbolic scalars) are left for the
                # operator to promote itself.
                if not value.is_tensor():
                    return value
                if value.to_tensor().dtype == promote_dtype:
                    return value
            elif isinstance(value, torch.Tensor):
                if value.dtype == promote_dtype:
                    return value
            else:
                # Raw Python scalars cannot be passed to _to_copy. Like
                # non-tensor proxies, they are promoted by the operator itself.
                return value

            return self.call_operator(
                torch.ops.aten._to_copy.default,
                (value,),
                {"dtype": promote_dtype},
                meta,
            )

        args, kwargs = map_args(op, try_coerce, args, kwargs)

        return super().call_operator(op, args, kwargs, meta)