xref: /aosp_15_r20/external/executorch/exir/wrap.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9"""
10Helper functions for constructing a "leaf function" in FX graph. A "leaf
11function" will be preserved as a call node in the graph instead of
12being traced through.
13"""
14
15import torch
16from executorch.exir.tracer import PythonTensor, unwrap_functional
17
18# pyre-fixme[21]: Could not find module `torch._C._functorch`.
19from torch._C._functorch import (  # @manual=//caffe2/functorch:functorch"
20    is_functionaltensor,
21)
22
23from torch._functorch.eager_transforms import _assert_wrapped_functional  # pyre-ignore
24
25
def update_with_proxy(t: torch.Tensor, proxy: torch.fx.Proxy) -> torch.Tensor:
    """Point the tracer tensor behind ``t`` at ``proxy`` and hand back ``t``.

    ``t`` is first passed through ``unwrap_functional``. The result must be a
    ``PythonTensor``, and its ``update_proxy`` method is called with ``proxy``.
    If ``t`` is itself a functional tensor, ``_assert_wrapped_functional`` is
    then run on the (unwrapped, wrapped) pair.

    Args:
        t: The tensor to update. It may be a functional-tensor wrapper, which
            is detected with ``is_functionaltensor``.
        proxy: The FX proxy forwarded to ``PythonTensor.update_proxy``.

    Returns:
        The same object passed in as ``t`` (possibly still wrapped), not the
        unwrapped ``PythonTensor``.

    Raises:
        AssertionError: If ``unwrap_functional(t)`` is not a ``PythonTensor``.
            Note that this check is skipped when Python runs with ``-O``.
    """
    inner = unwrap_functional(t)
    assert isinstance(inner, PythonTensor)
    inner.update_proxy(proxy)

    # Only a functionalized ``t`` has a wrapper to re-check. Presumably
    # ``_assert_wrapped_functional`` confirms the wrapper still refers to
    # ``inner`` after the update — TODO confirm against torch._functorch.
    wrapped_is_functional = is_functionaltensor(t)  # type: ignore
    if wrapped_is_functional:
        _assert_wrapped_functional(inner, t)

    return t
33