# mypy: allow-untyped-defs
import torch
import torch.fx
import inspect
from typing import Any, Dict, Optional, Tuple
from torch.fx.node import Argument, Target
from torch._jit_internal import boolean_dispatched
from torch.fx.operator_schemas import _torchscript_type_to_python_type

from torch.fx import Transformer

class AnnotateTypesWithSchema(Transformer):
    """
    Use Python function signatures to annotate types for `Nodes` within an FX graph.
    This pulls out Python function signatures for:

        1. Standard `torch.nn` Module calls
        2. `torch.nn.functional` calls
        3. Attribute fetches via `get_attr`

    Annotations are only filled in: if a newly emitted node already has a
    (truthy) ``type``, that type is kept.

    Example usage:

        m = torchvision.models.resnet18()

        traced = torch.fx.symbolic_trace(m)

        traced = AnnotateTypesWithSchema(traced).transform()

    """
    def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
                 annotate_modules : bool = True, annotate_get_attrs : bool = True):
        """
        Args:
            module: Module whose graph is re-emitted; passed through to ``Transformer``.
            annotate_functionals: Annotate ``call_function`` nodes targeting
                ``torch.nn.functional`` callables.
            annotate_modules: Annotate ``call_module`` nodes whose submodule class is
                exported from ``torch.nn``.
            annotate_get_attrs: Annotate ``get_attr`` nodes with the TorchScript-inferred
                type of the fetched attribute.
        """
        super().__init__(module)
        self.annotate_functionals = annotate_functionals
        self.annotate_modules = annotate_modules
        self.annotate_get_attrs = annotate_get_attrs

    def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
        """
        Emit a ``call_function`` node. For ``torch.nn.functional`` targets, set the
        node type to the target's Python return annotation (when available).
        """
        python_ret_type = None
        # Not every Target carries ``__module__`` (e.g. a str target); treat such
        # targets as "not a functional" instead of raising AttributeError.
        if self.annotate_functionals and getattr(target, '__module__', None) == 'torch.nn.functional':
            target_for_analysis = target
            if target in boolean_dispatched:
                # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
                # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
                # branches of the dispatch have the same return annotation. If they do, use the `true`
                # branch for analysis. Otherwise, leave this node un-annotated.
                assert not isinstance(target, str)
                dispatched = boolean_dispatched[target]
                if_true, if_false = dispatched['if_true'], dispatched['if_false']
                # TODO: can we emit the union of these? What are the implications on TorchScript
                # compilation?
                if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
                    return super().call_function(target, args, kwargs)
                target_for_analysis = if_true

            python_ret_type = self._extract_python_return_type(target_for_analysis)

        return_proxy = super().call_function(target, args, kwargs)
        # Keep any type the node already carries; only fill in a missing one.
        if not return_proxy.node.type:
            return_proxy.node.type = python_ret_type
        return return_proxy

    def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
        """
        Emit a ``call_module`` node. If the submodule's class is the one exported under
        the same name from ``torch.nn``, annotate the node with that class's
        ``forward`` return annotation.
        """
        python_ret_type = None
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        if self.annotate_modules and hasattr(submod.__class__, '__name__'):
            classname = submod.__class__.__name__
            # Identity check: a user-defined class that merely shares a name with a
            # `torch.nn` class must not be annotated with the stock signature.
            if getattr(torch.nn, classname, None) is submod.__class__:
                python_ret_type = self._extract_python_return_type(submod.forward)
        return_proxy = super().call_module(target, args, kwargs)
        # Keep any type the node already carries; only fill in a missing one.
        if not return_proxy.node.type:
            return_proxy.node.type = python_ret_type
        return return_proxy

    def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
        """
        Emit a ``get_attr`` node and annotate it with the Python type derived from
        TorchScript type inference on the referenced attribute.

        Raises:
            RuntimeError: if ``target`` does not resolve to an attribute of ``self.module``
                (only checked when ``annotate_get_attrs`` is enabled).
        """
        attr_proxy = super().get_attr(target, args, kwargs)

        if self.annotate_get_attrs:
            module_itr = self.module
            assert isinstance(target, str)
            atoms = target.split('.')
            for i, atom in enumerate(atoms):
                if not hasattr(module_itr, atom):
                    # Report the path up to and including the missing atom; slicing with
                    # [:i] would drop it (and print an empty name for top-level attrs).
                    raise RuntimeError(f'Node referenced nonexistent target {".".join(atoms[:i + 1])}!')
                module_itr = getattr(module_itr, atom)

            maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
            if maybe_inferred_ts_type.success():
                python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
                # Keep any type the node already carries; only fill in a missing one.
                if not attr_proxy.node.type:
                    attr_proxy.node.type = python_type

        return attr_proxy

    def _extract_python_return_type(self, target : Target) -> Optional[Any]:
        """
        Given a Python call target, try to extract the Python return annotation
        if it is available, otherwise return None

        Args:

            target (Callable): Python callable to get return annotation for

        Returns:

            Optional[Any]: Return annotation from the `target`, or None if it was
                not available.
        """
        assert callable(target)
        try:
            sig = inspect.signature(target)
        except (ValueError, TypeError):
            # Some callables (e.g. certain builtins) expose no introspectable signature.
            return None

        return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None