# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Union

import torch
from executorch.exir.pass_base import ExportPass, map_args, NodeMetadata, ProxyValue
from torch import SymBool, SymFloat, SymInt
from torch.utils._pytree import PyTree


class SymToTensorPass(ExportPass):
    """
    Make symbolic scalar arguments match tensor-typed schema slots.

    The dispatcher implicitly converts SymInt/SymFloats to tensors, but that
    implicit conversion does not always comply with the operator's schema,
    which ExecuTorch heavily relies on. This pass therefore inserts an
    explicit torch.ops.aten.scalar_tensor.default call in front of each such
    symbolic value, so the argument handed to the operator agrees with the
    schema.
    """

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta: NodeMetadata):
        """
        Wrap symbolic scalars bound to tensor-typed arguments, then delegate.

        Every (value, schema argument) pair produced by ``map_args`` is
        inspected; a non-tensor ``ProxyValue`` holding a SymInt, SymFloat or
        SymBool in a slot whose schema type is ``torch.TensorType`` is
        replaced by the result of a ``scalar_tensor`` call. All other values
        are passed through unchanged.

        Returns:
            Whatever the parent class's ``call_operator`` returns for ``op``
            with the rewritten ``args`` and ``kwargs``.
        """
        # Exact types eligible for wrapping (matched on type(), not isinstance).
        symbolic_types = (SymInt, SymFloat, SymBool)

        # Checked in order; the first isinstance match decides the dtype.
        dtype_by_sym_type = (
            (SymInt, torch.int32),
            (SymFloat, torch.float32),
            (SymBool, torch.bool),
        )

        # pyre-ignore
        def needs_wrapping(value, arg) -> bool:
            # Only non-tensor proxies are candidates.
            if not isinstance(value, ProxyValue) or value.is_tensor():
                return False
            # The schema slot itself has to be tensor-typed.
            if not isinstance(arg.type, torch.TensorType):
                return False
            return type(value.data) in symbolic_types

        def dtype_of(symbol: Union[SymInt, SymFloat, SymBool]) -> torch.dtype:
            for sym_type, dtype in dtype_by_sym_type:
                if isinstance(symbol, sym_type):
                    return dtype
            raise AssertionError(f"Unsupported data type: {type(symbol)}")

        def wrap_if_symbolic(value: PyTree, arg: torch.Argument) -> PyTree:
            if not needs_wrapping(value, arg):
                return value
            # Emitted via self.call_operator, reusing the current node's meta.
            return self.call_operator(
                torch.ops.aten.scalar_tensor.default,
                (value,),
                {"dtype": dtype_of(value.data)},
                meta,
            )

        new_args, new_kwargs = map_args(op, wrap_if_symbolic, args, kwargs)

        return super().call_operator(op, new_args, new_kwargs, meta)