xref: /aosp_15_r20/external/executorch/exir/wrap.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker"""
10*523fa7a6SAndroid Build Coastguard WorkerHelper functions for constructing a "leaf function" in FX graph. A "leaf
function" will be preserved as a call node in the graph instead of
12*523fa7a6SAndroid Build Coastguard Workerbeing traced through.
13*523fa7a6SAndroid Build Coastguard Worker"""
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerimport torch
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tracer import PythonTensor, unwrap_functional
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker# pyre-fixme[21]: Could not find module `torch._C._functorch`.
19*523fa7a6SAndroid Build Coastguard Workerfrom torch._C._functorch import (  # @manual=//caffe2/functorch:functorch"
20*523fa7a6SAndroid Build Coastguard Worker    is_functionaltensor,
21*523fa7a6SAndroid Build Coastguard Worker)
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Workerfrom torch._functorch.eager_transforms import _assert_wrapped_functional  # pyre-ignore
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker
def update_with_proxy(t: torch.Tensor, proxy: torch.fx.Proxy) -> torch.Tensor:
    """Attach ``proxy`` to the PythonTensor underlying ``t`` and return ``t``.

    Args:
        t: A tensor that may be wrapped by a functional tensor. The inner
            tensor is obtained via ``unwrap_functional``.
        proxy: The FX proxy to record on the inner PythonTensor.

    Returns:
        The same ``t`` object that was passed in. The proxy update is applied
        in place to its inner tensor.

    Raises:
        AssertionError: If the unwrapped tensor is not a ``PythonTensor``
            (only when assertions are enabled).
    """
    inner = unwrap_functional(t)
    # Invariant check; it also narrows the type of ``inner`` for the type checker.
    assert isinstance(inner, PythonTensor)
    inner.update_proxy(proxy)

    # Plain (non-functional) tensors need no further checking.
    if not is_functionaltensor(t):  # type: ignore
        return t

    # NOTE(review): presumably verifies that the functional wrapper ``t`` still
    # wraps ``inner``; the exact semantics live in functorch — confirm there.
    _assert_wrapped_functional(inner, t)
    return t
33