1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerfrom __future__ import annotations 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Workertry: # noqa: C901 12*523fa7a6SAndroid Build Coastguard Worker from torch._higher_order_ops.executorch_call_delegate import ( 13*523fa7a6SAndroid Build Coastguard Worker executorch_call_delegate as executorch_call_delegate, 14*523fa7a6SAndroid Build Coastguard Worker get_lowered_module_name as get_lowered_module_name, 15*523fa7a6SAndroid Build Coastguard Worker is_lowered_module as is_lowered_module, 16*523fa7a6SAndroid Build Coastguard Worker ) 17*523fa7a6SAndroid Build Coastguard Worker 18*523fa7a6SAndroid Build Coastguard Workerexcept ImportError: 19*523fa7a6SAndroid Build Coastguard Worker 20*523fa7a6SAndroid Build Coastguard Worker # TODO: Delete this code once pytorch pin advances 21*523fa7a6SAndroid Build Coastguard Worker 22*523fa7a6SAndroid Build Coastguard Worker from typing import Any, cast 23*523fa7a6SAndroid Build Coastguard Worker 24*523fa7a6SAndroid Build Coastguard Worker import torch 25*523fa7a6SAndroid Build Coastguard Worker import torch.utils._pytree as pytree 26*523fa7a6SAndroid Build Coastguard Worker from torch._ops import HigherOrderOperator 27*523fa7a6SAndroid Build Coastguard Worker from torch._subclasses.fake_tensor import FakeTensorMode 28*523fa7a6SAndroid Build Coastguard Worker from torch.fx.experimental.proxy_tensor import ( 29*523fa7a6SAndroid Build Coastguard Worker disable_proxy_modes_tracing, 30*523fa7a6SAndroid Build Coastguard Worker get_proxy_slot, 31*523fa7a6SAndroid Build Coastguard Worker ProxyTorchDispatchMode, 32*523fa7a6SAndroid Build Coastguard Worker track_tensor_tree, 33*523fa7a6SAndroid Build Coastguard Worker ) 34*523fa7a6SAndroid Build Coastguard Worker from torch.utils._pytree import tree_flatten 35*523fa7a6SAndroid Build Coastguard Worker 36*523fa7a6SAndroid Build Coastguard Worker executorch_call_delegate = HigherOrderOperator("executorch_call_delegate") 37*523fa7a6SAndroid Build Coastguard Worker executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) 38*523fa7a6SAndroid Build Coastguard Worker executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot) 39*523fa7a6SAndroid Build Coastguard Worker executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView) 40*523fa7a6SAndroid Build Coastguard Worker executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) 41*523fa7a6SAndroid Build Coastguard Worker 42*523fa7a6SAndroid Build Coastguard Worker LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule" 43*523fa7a6SAndroid Build Coastguard Worker 44*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 45*523fa7a6SAndroid Build Coastguard Worker def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): 46*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 47*523fa7a6SAndroid Build Coastguard Worker def _unwrap_proxy(e): 48*523fa7a6SAndroid Build Coastguard Worker if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): 49*523fa7a6SAndroid Build Coastguard Worker return e 50*523fa7a6SAndroid Build Coastguard Worker return get_proxy_slot( 51*523fa7a6SAndroid Build Coastguard Worker cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy 52*523fa7a6SAndroid Build Coastguard Worker ) 53*523fa7a6SAndroid Build Coastguard Worker 54*523fa7a6SAndroid Build Coastguard Worker if not is_lowered_module(lowered_module): 55*523fa7a6SAndroid Build Coastguard Worker raise ValueError( 56*523fa7a6SAndroid Build Coastguard Worker "executorch_call_delegate()'s first argument must be a LoweredBackendModule" 57*523fa7a6SAndroid Build Coastguard Worker ) 58*523fa7a6SAndroid Build Coastguard Worker 59*523fa7a6SAndroid Build Coastguard Worker with disable_proxy_modes_tracing(): 60*523fa7a6SAndroid Build Coastguard Worker out = call_delegate_cpu(lowered_module, *args) 61*523fa7a6SAndroid Build Coastguard Worker 62*523fa7a6SAndroid Build Coastguard Worker get_lowered_module_name(proxy_mode.tracer.root, lowered_module) 63*523fa7a6SAndroid Build Coastguard Worker 64*523fa7a6SAndroid Build Coastguard Worker node_args = (lowered_module, *args) 65*523fa7a6SAndroid Build Coastguard Worker proxy_args = pytree.tree_map(_unwrap_proxy, node_args) 66*523fa7a6SAndroid Build Coastguard Worker out_proxy = proxy_mode.tracer.create_proxy( 67*523fa7a6SAndroid Build Coastguard Worker "call_function", 68*523fa7a6SAndroid Build Coastguard Worker func_overload, 69*523fa7a6SAndroid Build Coastguard Worker proxy_args, 70*523fa7a6SAndroid Build Coastguard Worker {}, 71*523fa7a6SAndroid Build Coastguard Worker name="executorch_call_delegate", 72*523fa7a6SAndroid Build Coastguard Worker ) 73*523fa7a6SAndroid Build Coastguard Worker return track_tensor_tree( 74*523fa7a6SAndroid Build Coastguard Worker out, out_proxy, constant=None, tracer=proxy_mode.tracer 75*523fa7a6SAndroid Build Coastguard Worker ) 76*523fa7a6SAndroid Build Coastguard Worker 77*523fa7a6SAndroid Build Coastguard Worker @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) 78*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 79*523fa7a6SAndroid Build Coastguard Worker def call_delegate_cpu(lowered_module, *args): 80*523fa7a6SAndroid Build Coastguard Worker # FX creates this immutable_dict/list concept. Get rid of this. 81*523fa7a6SAndroid Build Coastguard Worker map_types = { 82*523fa7a6SAndroid Build Coastguard Worker torch.fx.immutable_collections.immutable_dict: dict, 83*523fa7a6SAndroid Build Coastguard Worker torch.fx.immutable_collections.immutable_list: list, 84*523fa7a6SAndroid Build Coastguard Worker } 85*523fa7a6SAndroid Build Coastguard Worker new_args = pytree.tree_map_only( 86*523fa7a6SAndroid Build Coastguard Worker tuple(map_types.keys()), 87*523fa7a6SAndroid Build Coastguard Worker lambda a: map_types[type(a)](a), 88*523fa7a6SAndroid Build Coastguard Worker args, 89*523fa7a6SAndroid Build Coastguard Worker lambda a: isinstance(a, tuple(map_types.keys())), 90*523fa7a6SAndroid Build Coastguard Worker ) 91*523fa7a6SAndroid Build Coastguard Worker return lowered_module.original_module.module()(*new_args) 92*523fa7a6SAndroid Build Coastguard Worker 93*523fa7a6SAndroid Build Coastguard Worker @executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd) 94*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 95*523fa7a6SAndroid Build Coastguard Worker def call_delegate_autograd(lowered_module, *args): 96*523fa7a6SAndroid Build Coastguard Worker # TODO: support autograd 97*523fa7a6SAndroid Build Coastguard Worker flat_operands, _ = tree_flatten([lowered_module, *args]) 98*523fa7a6SAndroid Build Coastguard Worker requires_grad = any( 99*523fa7a6SAndroid Build Coastguard Worker f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) 100*523fa7a6SAndroid Build Coastguard Worker ) 101*523fa7a6SAndroid Build Coastguard Worker 102*523fa7a6SAndroid Build Coastguard Worker with torch._C._ExcludeDispatchKeyGuard( 103*523fa7a6SAndroid Build Coastguard Worker torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) 104*523fa7a6SAndroid Build Coastguard Worker ): 105*523fa7a6SAndroid Build Coastguard Worker res = executorch_call_delegate(lowered_module, *args) 106*523fa7a6SAndroid Build Coastguard Worker 107*523fa7a6SAndroid Build Coastguard Worker if requires_grad: 108*523fa7a6SAndroid Build Coastguard Worker # Create aliases of the output that has requires_grad=True. We need 109*523fa7a6SAndroid Build Coastguard Worker # at least one of the inputs to err_fn to require grad so that the 110*523fa7a6SAndroid Build Coastguard Worker # output will have a grad_fn. 111*523fa7a6SAndroid Build Coastguard Worker 112*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 113*523fa7a6SAndroid Build Coastguard Worker def fake_requires_grad(var): 114*523fa7a6SAndroid Build Coastguard Worker if var is not None: 115*523fa7a6SAndroid Build Coastguard Worker var = var.detach() 116*523fa7a6SAndroid Build Coastguard Worker if torch.is_floating_point(var) or torch.is_complex(var): 117*523fa7a6SAndroid Build Coastguard Worker var.requires_grad = True 118*523fa7a6SAndroid Build Coastguard Worker return var 119*523fa7a6SAndroid Build Coastguard Worker 120*523fa7a6SAndroid Build Coastguard Worker return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res) 121*523fa7a6SAndroid Build Coastguard Worker 122*523fa7a6SAndroid Build Coastguard Worker return res 123*523fa7a6SAndroid Build Coastguard Worker 124*523fa7a6SAndroid Build Coastguard Worker @executorch_call_delegate.py_impl(ProxyTorchDispatchMode) 125*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 126*523fa7a6SAndroid Build Coastguard Worker def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args): 127*523fa7a6SAndroid Build Coastguard Worker res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args) 128*523fa7a6SAndroid Build Coastguard Worker return res 129*523fa7a6SAndroid Build Coastguard Worker 130*523fa7a6SAndroid Build Coastguard Worker @executorch_call_delegate.py_impl(FakeTensorMode) 131*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 132*523fa7a6SAndroid Build Coastguard Worker def call_delegate_fake_tensor_mode(mode, lowered_module, *args): 133*523fa7a6SAndroid Build Coastguard Worker with mode: 134*523fa7a6SAndroid Build Coastguard Worker return call_delegate_cpu(lowered_module, *args) 135*523fa7a6SAndroid Build Coastguard Worker 136*523fa7a6SAndroid Build Coastguard Worker @executorch_call_delegate.py_functionalize_impl 137*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore 138*523fa7a6SAndroid Build Coastguard Worker def call_delegate_functionalize(ctx, lowered_module, *args): 139*523fa7a6SAndroid Build Coastguard Worker unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) 140*523fa7a6SAndroid Build Coastguard Worker with ctx.redispatch_to_next(): 141*523fa7a6SAndroid Build Coastguard Worker res = executorch_call_delegate(lowered_module, *unwrapped_args) 142*523fa7a6SAndroid Build Coastguard Worker return ctx.wrap_tensors(res) 143*523fa7a6SAndroid Build Coastguard Worker 144*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre 145*523fa7a6SAndroid Build Coastguard Worker def is_lowered_module(obj: Any) -> bool: 146*523fa7a6SAndroid Build Coastguard Worker """ 147*523fa7a6SAndroid Build Coastguard Worker This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import. 148*523fa7a6SAndroid Build Coastguard Worker """ 149*523fa7a6SAndroid Build Coastguard Worker return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE 150*523fa7a6SAndroid Build Coastguard Worker 151*523fa7a6SAndroid Build Coastguard Worker def get_lowered_module_name( 152*523fa7a6SAndroid Build Coastguard Worker root: torch.nn.Module, 153*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. 154*523fa7a6SAndroid Build Coastguard Worker lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa 155*523fa7a6SAndroid Build Coastguard Worker ) -> str: 156*523fa7a6SAndroid Build Coastguard Worker """ 157*523fa7a6SAndroid Build Coastguard Worker Adds the given lowered_module into the given root module and returns the 158*523fa7a6SAndroid Build Coastguard Worker name of the module added. 159*523fa7a6SAndroid Build Coastguard Worker """ 160*523fa7a6SAndroid Build Coastguard Worker # Find a qualifying name for the lowered submodule 161*523fa7a6SAndroid Build Coastguard Worker qualname = None 162*523fa7a6SAndroid Build Coastguard Worker i = 0 163*523fa7a6SAndroid Build Coastguard Worker while True: 164*523fa7a6SAndroid Build Coastguard Worker qualname = f"lowered_module_{i}" 165*523fa7a6SAndroid Build Coastguard Worker if not hasattr(root, qualname): 166*523fa7a6SAndroid Build Coastguard Worker break 167*523fa7a6SAndroid Build Coastguard Worker i += 1 168*523fa7a6SAndroid Build Coastguard Worker assert qualname is not None 169*523fa7a6SAndroid Build Coastguard Worker 170*523fa7a6SAndroid Build Coastguard Worker root.add_module(qualname, lowered_module) 171*523fa7a6SAndroid Build Coastguard Worker return qualname 172