# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Helper functions for constructing a "leaf function" in FX graph. A "leaf
function" will be preserved as a call node in the graph instead of
being traced through.
"""

import torch
from executorch.exir.tracer import PythonTensor, unwrap_functional

# pyre-fixme[21]: Could not find module `torch._C._functorch`.
from torch._C._functorch import (  # @manual=//caffe2/functorch:functorch"
    is_functionaltensor,
)

from torch._functorch.eager_transforms import _assert_wrapped_functional  # pyre-ignore


def update_with_proxy(t: torch.Tensor, proxy: torch.fx.Proxy) -> torch.Tensor:
    """Record ``proxy`` on the ``PythonTensor`` underlying ``t`` and return ``t``.

    ``t`` may be a functional-tensor wrapper. ``unwrap_functional`` is used to
    reach the inner tensor, which is required to be a ``PythonTensor``. The
    proxy is attached to that inner tensor via ``PythonTensor.update_proxy``.

    Args:
        t: The (possibly functional-wrapped) tensor whose underlying
            ``PythonTensor`` should be updated.
        proxy: The FX proxy to record on the underlying ``PythonTensor``.

    Returns:
        ``t`` itself (the same object that was passed in, not the unwrapped
        tensor).

    Raises:
        AssertionError: If the unwrapped tensor is not a ``PythonTensor``.
            When ``t`` is a functional tensor, ``_assert_wrapped_functional``
            is also invoked; its failure behavior is defined by functorch.
    """
    unwrapped = unwrap_functional(t)
    # Internal invariant: the tracer always places a PythonTensor underneath.
    # The message makes a violation diagnosable instead of a bare assert.
    assert isinstance(unwrapped, PythonTensor), (
        f"expected PythonTensor after unwrapping, got {type(unwrapped).__name__}"
    )
    unwrapped.update_proxy(proxy)
    if is_functionaltensor(t):  # type: ignore
        # Sanity-check that ``t`` is a functional wrapper around ``unwrapped``.
        _assert_wrapped_functional(unwrapped, t)
    return t