# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.pass_base import ExportPass, map_args


class ScalarToTensorPass(ExportPass):
    """Export pass that turns Python scalars into tensor constants.

    A scalar (``float``, ``int`` or ``bool``) argument is replaced by
    ``torch.tensor(scalar)`` whenever the type of the corresponding operator
    argument is a ``torch.TensorType``. All other arguments are forwarded
    untouched.
    """

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        """Coerce eligible scalar arguments, then defer to the parent pass.

        Args:
            op: Operator being visited.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.
            meta: Node metadata, forwarded unchanged to the parent class.

        Returns:
            Whatever ``ExportPass.call_operator`` returns for the rewritten
            arguments.
        """

        # pyre-ignore
        def _scalar_to_tensor(value, arg):
            # The result must be a real constant tensor, not a FakeTensor or
            # ProxyTensor. With the python dispatcher enabled, the inputs'
            # fake_tensor_mode would otherwise intercept torch.tensor() and
            # hand back a fake tensor, so every active mode is switched off
            # for the duration of the conversion.
            with torch.utils._python_dispatch._disable_current_modes():
                is_scalar = isinstance(value, (float, int, bool))
                # ``arg.type`` is only inspected for scalars (short-circuit).
                if is_scalar and isinstance(arg.type, torch.TensorType):
                    return torch.tensor(value)
                return value

        # NOTE(review): map_args is presumably calling the callback with each
        # argument value and its matching schema argument - confirm in
        # executorch.exir.pass_base.
        coerced_args, coerced_kwargs = map_args(op, _scalar_to_tensor, args, kwargs)
        return super().call_operator(op, coerced_args, coerced_kwargs, meta)