# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Tuple, Optional

import torch
import torch.fx
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
    normalize_module,
    normalize_function,
    create_type_hint,
)

from .schema_type_annotation import AnnotateTypesWithSchema


class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module/functional's
    signature and rewritten to exclusively kwargs in positional order
    if `normalize_to_only_use_kwargs` is true. Also populates default
    values. Does not support positional-only parameters or varargs
    parameters (*args, **kwargs).

    If the nodes have 'type' metadata, it will use it to disambiguate
    overloads. Otherwise, it will throw an error.

    Example usage:
        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        """
        Args:
            module: The graph module whose call sites are normalized.
            normalize_to_only_use_kwargs: Forwarded to `normalize_function`
                and `normalize_module` for every call site.
        """
        super().__init__(module)
        # Maps the value produced for each non-output node back to the
        # original node it was produced from (filled in by `run_node`).
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        """
        Process a single node, passing argument type hints to
        `call_function` for `call_function` nodes.

        For every node other than the output node, the result is recorded
        in `node_map` and the original node's `meta` and `type` are assigned
        onto the result's node.

        Args:
            n: The node of the original graph to process.

        Returns:
            The result of `call_function` for `call_function` nodes,
            otherwise whatever the parent `run_node` returns.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            # A type hint has to describe the argument itself, so it is read
            # from the argument node's own metadata, not from the consuming
            # node `n`.
            if isinstance(arg, fx.Node):
                return arg.meta.get("type")
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple([create_type_hint(i) for i in arg_types])
        # Use the original node's kwargs, mirroring `n.args` above: the values
        # fetched from the environment are no longer `fx.Node`s, so their
        # recorded 'type' metadata could not be looked up from them.
        kwarg_types = {k: get_type(v) for k, v in n.kwargs.items()}
        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)
        if n.op != "output":
            self.node_map[out] = n
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        """
        Emit a `call_function` node with normalized arguments when possible.

        Args:
            target: The callable being invoked.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.
            arg_types: Optional type hints for `args`, forwarded to
                `normalize_function`.
            kwarg_types: Optional type hints for `kwargs`, forwarded to
                `normalize_function`.

        Returns:
            A proxy for the call with the normalized arguments, or the
            result of the parent `call_function` with the original arguments
            when `normalize_function` returns nothing.
        """
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        else:
            return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a `call_module` node with normalized arguments when possible.

        Args:
            target: Qualified name of the submodule being called.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.

        Returns:
            The result of the parent `call_module`, invoked with the
            normalized arguments when `normalize_module` returns them and
            with the original arguments otherwise.
        """
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        else:
            return super().call_module(target, args, kwargs)


class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Rewriting calls to binary `torch` functions (e.g. torch.add) into
       the equivalent `operator` function (e.g. operator.add), using the
       `binary_magic_method_remap` table. A call is only rewritten when it
       has exactly two positional arguments and no keyword arguments;
       every other call is left unchanged.

    Example usage:

        m = torchvision.models.resnet18()

        traced = torch.fx.symbolic_trace(m)

        traced = NormalizeOperators(traced).transform()
    """

    # Maps each supported binary `torch` function to the `operator` function
    # that `call_function` emits in its place.
    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Rewrite a remappable binary `torch` call to its `operator` form.

        Args:
            target: The callable being invoked.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.

        Returns:
            The result of the parent `call_function`, invoked with the
            remapped target when the call qualifies and with the original
            target and arguments otherwise.
        """
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137  # noqa: B950

        assert callable(target)

        # The `operator` functions accept exactly two positional operands and
        # no keywords. A call carrying kwargs (e.g. `alpha=`, `rounding_mode=`,
        # `out=`) cannot be expressed that way, so it must not be rewritten:
        # dropping the kwargs would change what the call computes.
        if target in self.binary_magic_method_remap and len(args) == 2 and not kwargs:
            lhs, rhs = args
            return super().call_function(
                target=self.binary_magic_method_remap[target],
                args=(lhs, rhs),
                kwargs={},
            )

        return super().call_function(target, args, kwargs)