# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
Helper functions for constructing a "leaf function" in an FX graph. A "leaf
function" will be preserved as a call node in the graph instead of
being traced through.
"""

import torch
from executorch.exir.tracer import PythonTensor, unwrap_functional

# pyre-fixme[21]: Could not find module `torch._C._functorch`.
from torch._C._functorch import (  # @manual=//caffe2/functorch:functorch"
    is_functionaltensor,
)

from torch._functorch.eager_transforms import _assert_wrapped_functional  # pyre-ignore


def update_with_proxy(t: torch.Tensor, proxy: torch.fx.Proxy) -> torch.Tensor:
    """Attach ``proxy`` to the PythonTensor underlying ``t`` and return ``t``.

    ``t`` is first passed through ``unwrap_functional``. The result must be a
    ``PythonTensor``, and its ``update_proxy`` method is called with
    ``proxy``. If ``t`` itself is a functional tensor, functorch's private
    ``_assert_wrapped_functional`` helper is then invoked on the
    (unwrapped, wrapper) pair as a consistency check.

    Args:
        t: The tensor whose underlying PythonTensor should be re-pointed.
        proxy: The FX proxy to record on that PythonTensor.

    Returns:
        The same object ``t`` that was passed in (not the unwrapped tensor).

    Raises:
        AssertionError: If the unwrapped tensor is not a ``PythonTensor``
            (only when assertions are enabled), or if the functorch
            consistency check fails.
    """
    inner = unwrap_functional(t)
    # NOTE(review): presumably every tensor reaching here was produced by the
    # ExecuTorch tracer, so the unwrapped value is a PythonTensor — confirm.
    assert isinstance(inner, PythonTensor)
    inner.update_proxy(proxy)

    # Plain (non-functional) tensors need no wrapper consistency check.
    if not is_functionaltensor(t):  # type: ignore
        return t

    # Presumably verifies that the functional wrapper ``t`` still wraps
    # ``inner`` after the proxy update — private functorch API; confirm.
    _assert_wrapped_functional(inner, t)
    return t