# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
import torch.nn as nn


@dataclass
class TracingConfig:
    """
    This represents a symbolic tracing configuration.

    Args:
        tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to
            use for symbolic tracing. The default value is the native
            :class:`torch.fx.Tracer` constructed with default arguments.
            However, the user may want to pass a different value such as the
            ``HFTracer`` for models in the HuggingFace Transformers_ library.
            .. _Transformers: https://huggingface.co/docs/transformers/index
        concrete_args (Optional[Dict[str, Any]]): Concrete arguments that
            should not be treated as ``torch.fx.Proxy`` when tracing the
            module ``forward()``. Passing ``concrete_args`` allows partially
            specializing the forward, e.g. to remove control flow or data
            structures. This ``concrete_args`` here is the same argument used
            in :meth:`~torch.fx.Tracer.trace`.
    """

    tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer)
    concrete_args: Optional[Dict[str, Any]] = None


class _ParamUsageInfo(NamedTuple):
    """
    This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record
    execution information. The ``dict`` maps modules to a list of these
    ``_ParamUsageInfo`` instances, where each instance represents a group of
    parameters used together.

    Specifically, for each module key in the ``dict``, each instance of this
    class represents either:
    (1) the module and some sublist of its ``named_parameters()`` used
    together in execution (see ``_patched_create_proxy()``), or
    (2) a submodule and all of ``submodule.named_parameters()`` (see
    ``_patched_call_module()``).

    Type (1) corresponds to directly using parameters in ops without calling
    ``forward()``, and type (2) corresponds to calling ``forward()``. The
    mapped-to lists in the ``dict`` follow the execution order.
    """

    module: nn.Module
    named_params: List[Tuple[str, nn.Parameter]]


class _ExecutionInfo:
    """
    This represents the execution order information from the forward pass.

    Attributes:
        curr_module (nn.Module): Current module being traced.
        module_forward_order (List[nn.Module]): The modules in (pre-)forward
            order, i.e. the order in which their ``forward()`` methods are
            called. Each call to a module's ``forward()`` corresponds to one
            element in the list.
        module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]):
            Maps a module to a list of module execution infos. See
            :class:`_ParamUsageInfo` for details.
        param_forward_order (List[nn.Parameter]): The parameters in forward
            execution order, where only a parameter's first participation is
            included.
        visited_params (Set[nn.Parameter]): The parameters visited so far
            during the trace. This is only used during tracing for fast
            membership check. Invariant: The parameters in
            ``param_forward_order`` are exactly those in ``visited_params``.
    """

    def __init__(self, root_module: nn.Module) -> None:
        self.curr_module: nn.Module = root_module
        self.module_forward_order: List[nn.Module] = [root_module]
        self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = {
            root_module: []
        }
        self.param_forward_order: List[nn.Parameter] = []
        self.visited_params: Set[nn.Parameter] = set()


class _ExecOrderTracer:
    """
    Records module and parameter execution order during symbolic tracing by
    temporarily overriding a tracer's ``call_module`` and ``create_proxy``.

    Attributes:
        exec_info (Optional[_ExecutionInfo]): ``None`` until
            :meth:`patch_tracer` is entered. Afterwards, it holds the
            execution information recorded for the most recent call to
            :meth:`patch_tracer`.
    """

    def __init__(self) -> None:
        self.exec_info: Optional[_ExecutionInfo] = None

    @contextmanager
    def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module):
        """
        Context manager that patches ``tracer`` so that tracing
        ``root_module`` inside the ``with`` body records execution
        information into a fresh ``self.exec_info``.

        The original ``tracer.call_module`` and ``tracer.create_proxy`` are
        reassigned on exit, including when the body raises.

        Args:
            tracer (torch.fx.Tracer): Tracer whose ``call_module`` and
                ``create_proxy`` are overridden for the duration of the
                context.
            root_module (nn.Module): Root module to be traced. Its
                ``named_parameters()`` are used to recognize parameters by
                FQN.
        """
        exec_info = _ExecutionInfo(root_module)
        self.exec_info = exec_info
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        # Build everything that may fail *before* touching the tracer so that
        # an error here cannot leave the tracer half-patched.
        fqn_to_param = dict(root_module.named_parameters())
        patched_call_module = functools.partial(
            self._patched_call_module, orig_call_module, exec_info
        )
        patched_create_proxy = functools.partial(
            self._patched_create_proxy,
            orig_create_proxy,
            exec_info,
            fqn_to_param,
        )
        try:
            tracer.call_module = patched_call_module  # type: ignore[method-assign]
            tracer.create_proxy = patched_create_proxy  # type: ignore[method-assign]
            yield
        finally:
            tracer.call_module = orig_call_module  # type: ignore[method-assign]
            tracer.create_proxy = orig_create_proxy  # type: ignore[method-assign]

    def _patched_call_module(
        self,
        call_module: Callable,
        exec_info: _ExecutionInfo,
        # Below are the expected arguments to `call_module()`
        module: nn.Module,
        forward: Callable,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Overrides ``call_module`` to save execution information to
        ``exec_info``. Note that ``call_module`` is called during symbolic
        tracing for each non-root module.

        Args:
            call_module (Callable): Original ``call_module`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            module (nn.Module): Module corresponding to this ``call_module``.
            forward (Callable): ``forward()`` method of ``module`` to be called
                for this ``call_module``.
            args (Tuple[Any, ...]): Positional arguments for ``forward``.
            kwargs (Dict[str, Any]): Keyword arguments for ``forward``.

        Returns:
            Same return value as ``call_module``.
        """
        exec_info.module_forward_order.append(module)
        named_params = list(module.named_parameters())
        parent_module = exec_info.curr_module
        if named_params:
            assert (
                parent_module in exec_info.module_to_param_usage_infos
            ), "The current module should have already been processed by a patched `call_module`"
            exec_info.module_to_param_usage_infos[parent_module].append(
                _ParamUsageInfo(module, named_params)
            )
        exec_info.curr_module = module
        # Each call starts from an empty usage list for `module`
        exec_info.module_to_param_usage_infos[module] = []
        try:
            return call_module(module, forward, args, kwargs)
        finally:
            # Restore the parent even if the wrapped call raises so that any
            # later recording is not attributed to `module`
            exec_info.curr_module = parent_module

    @staticmethod
    def _record_first_use(exec_info: _ExecutionInfo, params: List[nn.Parameter]) -> None:
        """Appends each not-yet-visited parameter to the forward order."""
        for param in params:
            if param not in exec_info.visited_params:
                exec_info.visited_params.add(param)
                exec_info.param_forward_order.append(param)

    def _patched_create_proxy(
        self,
        create_proxy: Callable,
        exec_info: _ExecutionInfo,
        fqn_to_param: Dict[str, nn.Parameter],
        # Below are the expected arguments to `create_proxy()`
        kind: str,
        target: torch.fx.node.Target,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
        proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None,
    ) -> torch.fx.Proxy:
        """
        Overrides ``create_proxy`` to save execution information to
        ``exec_info``. Note that ``create_proxy`` is called during symbolic
        tracing for each leaf function/method/module.

        Args:
            create_proxy (Callable): Original ``create_proxy`` to override.
            exec_info (_ExecutionInfo): Used to record execution information.
            fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the
                root module's ``named_parameters()`` with FQN as key and
                parameter as value.
            kind (str): Kind of the target method ('call_function',
                'call_method', 'get_attr', 'call_module', 'placeholder', or
                'output'). See :class:`torch.fx.Graph` for details. This is
                passed to ``create_proxy``.
            target (torch.fx.node.Target): Contains the string name of the
                function/method/module. This is passed to ``create_proxy``.
            args (Tuple[Any, ...]): Positional arguments for the function/
                method/module. This is passed to ``create_proxy``.
            kwargs (Dict[str, Any]): Keyword arguments for the function/method/
                module. This is passed to ``create_proxy``.
            name (Optional[str]): An optional string name for the ``Node``
                created in ``create_proxy``. This is passed to
                ``create_proxy``.
            type_expr (Optional[Any]): An optional type annotation representing
                the Python type that the output of the node has. This is passed
                to ``create_proxy``.
            proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]):
                An alternative proxy constructor used in ``create_proxy``. This
                is passed to ``create_proxy``.

        Returns:
            torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object.
        """
        proxy = create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )
        curr_module = exec_info.curr_module
        named_params: List[Tuple[str, nn.Parameter]] = []
        if kind in ("call_function", "call_method"):
            # Parameters used directly in an op appear as positional proxy
            # arguments whose node target is the parameter's FQN
            for arg in args if args is not None else ():
                if (
                    isinstance(arg, torch.fx.Proxy)
                    and arg.node.target in fqn_to_param
                ):
                    param = fqn_to_param[arg.node.target]  # type: ignore[index]
                    named_params.append((arg.node.target, param))  # type: ignore[arg-type]
        elif kind == "call_module":
            # The module being called is `curr_module`, which was set by the
            # patched `call_module`, so all of its parameters are recorded
            named_params = list(curr_module.named_parameters())
        else:
            return proxy
        if named_params:
            exec_info.module_to_param_usage_infos[curr_module].append(
                _ParamUsageInfo(curr_module, named_params)
            )
            self._record_first_use(exec_info, [param for _, param in named_params])
        return proxy