xref: /aosp_15_r20/external/executorch/exir/passes/scalar_to_tensor_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import torch
10from executorch.exir.pass_base import ExportPass, map_args
11
12
class ScalarToTensorPass(ExportPass):
    """Replace Python scalar operands of tensor-typed parameters with tensors.

    Every operator call that reaches ``call_operator`` has its positional and
    keyword arguments run through ``map_args`` with a per-argument callback.
    A runtime value that is a Python ``float``, ``int`` or ``bool`` whose
    paired argument descriptor has a ``torch.TensorType`` type is replaced by
    ``torch.tensor(value)``. All other values pass through untouched, and
    the rewritten arguments are handed to the parent ``call_operator``.
    """

    # pyre-ignore
    def call_operator(self, op, args, kwargs, meta):
        """Coerce scalar arguments of ``op`` to tensors, then delegate.

        Args:
            op: The operator being called; forwarded unchanged to ``map_args``
                and to the parent implementation.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.
            meta: Node metadata, forwarded unchanged to the parent.

        Returns:
            Whatever ``ExportPass.call_operator`` returns for the coerced
            arguments.
        """

        # pyre-ignore
        def try_coerce(value, arg):
            # ``arg`` is presumably the schema argument that ``map_args`` pairs
            # with ``value`` (only its ``.type`` is consulted here).
            wants_tensor = isinstance(value, (float, int, bool)) and isinstance(
                arg.type, torch.TensorType
            )
            if not wants_tensor:
                return value
            # Create a genuine constant tensor. With the python dispatcher
            # active, torch.tensor() would otherwise run under the inputs'
            # fake_tensor_mode and yield a FakeTensor/ProxyTensor instead.
            with torch.utils._python_dispatch._disable_current_modes():
                return torch.tensor(value)

        new_args, new_kwargs = map_args(op, try_coerce, args, kwargs)
        return super().call_operator(op, new_args, new_kwargs, meta)
33