1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9from __future__ import annotations 10 11try: # noqa: C901 12 from torch._higher_order_ops.executorch_call_delegate import ( 13 executorch_call_delegate as executorch_call_delegate, 14 get_lowered_module_name as get_lowered_module_name, 15 is_lowered_module as is_lowered_module, 16 ) 17 18except ImportError: 19 20 # TODO: Delete this code once pytorch pin advances 21 22 from typing import Any, cast 23 24 import torch 25 import torch.utils._pytree as pytree 26 from torch._ops import HigherOrderOperator 27 from torch._subclasses.fake_tensor import FakeTensorMode 28 from torch.fx.experimental.proxy_tensor import ( 29 disable_proxy_modes_tracing, 30 get_proxy_slot, 31 ProxyTorchDispatchMode, 32 track_tensor_tree, 33 ) 34 from torch.utils._pytree import tree_flatten 35 36 executorch_call_delegate = HigherOrderOperator("executorch_call_delegate") 37 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) 38 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot) 39 executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView) 40 executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) 41 42 LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule" 43 44 # pyre-ignore 45 def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): 46 # pyre-ignore 47 def _unwrap_proxy(e): 48 if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): 49 return e 50 return get_proxy_slot( 51 cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy 52 ) 53 54 if not is_lowered_module(lowered_module): 55 raise ValueError( 56 "executorch_call_delegate()'s first argument must be a LoweredBackendModule" 57 ) 58 59 with disable_proxy_modes_tracing(): 60 out = call_delegate_cpu(lowered_module, *args) 61 62 get_lowered_module_name(proxy_mode.tracer.root, lowered_module) 63 64 node_args = (lowered_module, *args) 65 proxy_args = pytree.tree_map(_unwrap_proxy, node_args) 66 out_proxy = proxy_mode.tracer.create_proxy( 67 "call_function", 68 func_overload, 69 proxy_args, 70 {}, 71 name="executorch_call_delegate", 72 ) 73 return track_tensor_tree( 74 out, out_proxy, constant=None, tracer=proxy_mode.tracer 75 ) 76 77 @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) 78 # pyre-ignore 79 def call_delegate_cpu(lowered_module, *args): 80 # FX creates this immutable_dict/list concept. Get rid of this. 81 map_types = { 82 torch.fx.immutable_collections.immutable_dict: dict, 83 torch.fx.immutable_collections.immutable_list: list, 84 } 85 new_args = pytree.tree_map_only( 86 tuple(map_types.keys()), 87 lambda a: map_types[type(a)](a), 88 args, 89 lambda a: isinstance(a, tuple(map_types.keys())), 90 ) 91 return lowered_module.original_module.module()(*new_args) 92 93 @executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd) 94 # pyre-ignore 95 def call_delegate_autograd(lowered_module, *args): 96 # TODO: support autograd 97 flat_operands, _ = tree_flatten([lowered_module, *args]) 98 requires_grad = any( 99 f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) 100 ) 101 102 with torch._C._ExcludeDispatchKeyGuard( 103 torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) 104 ): 105 res = executorch_call_delegate(lowered_module, *args) 106 107 if requires_grad: 108 # Create aliases of the output that has requires_grad=True. We need 109 # at least one of the inputs to err_fn to require grad so that the 110 # output will have a grad_fn. 111 112 # pyre-ignore 113 def fake_requires_grad(var): 114 if var is not None: 115 var = var.detach() 116 if torch.is_floating_point(var) or torch.is_complex(var): 117 var.requires_grad = True 118 return var 119 120 return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res) 121 122 return res 123 124 @executorch_call_delegate.py_impl(ProxyTorchDispatchMode) 125 # pyre-ignore 126 def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args): 127 res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args) 128 return res 129 130 @executorch_call_delegate.py_impl(FakeTensorMode) 131 # pyre-ignore 132 def call_delegate_fake_tensor_mode(mode, lowered_module, *args): 133 with mode: 134 return call_delegate_cpu(lowered_module, *args) 135 136 @executorch_call_delegate.py_functionalize_impl 137 # pyre-ignore 138 def call_delegate_functionalize(ctx, lowered_module, *args): 139 unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) 140 with ctx.redispatch_to_next(): 141 res = executorch_call_delegate(lowered_module, *unwrapped_args) 142 return ctx.wrap_tensors(res) 143 144 # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre 145 def is_lowered_module(obj: Any) -> bool: 146 """ 147 This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import. 148 """ 149 return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE 150 151 def get_lowered_module_name( 152 root: torch.nn.Module, 153 # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. 154 lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa 155 ) -> str: 156 """ 157 Adds the given lowered_module into the given root module and returns the 158 name of the module added. 159 """ 160 # Find a qualifying name for the lowered submodule 161 qualname = None 162 i = 0 163 while True: 164 qualname = f"lowered_module_{i}" 165 if not hasattr(root, qualname): 166 break 167 i += 1 168 assert qualname is not None 169 170 root.add_module(qualname, lowered_module) 171 return qualname 172