# mypy: allow-untyped-defs
import operator
from typing import Any, Callable, Dict, Tuple, Optional

import torch
import torch.fx
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
    normalize_module,
    normalize_function,
    create_type_hint,
)

from .schema_type_annotation import AnnotateTypesWithSchema


class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module/functional's
    signature and, if `normalize_to_only_use_kwargs` is true, rewritten
    so that every argument is passed by keyword, listed in the
    signature's declared order, with default values populated.
    Positional-only parameters and varargs parameters (*args, **kwargs)
    are not supported.

    If a node has 'type' metadata, it is used to disambiguate
    overloads. Otherwise, an error is raised.

    Example usage:
        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """
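
    # Illustrative sketch (not from the original file): with
    # `normalize_to_only_use_kwargs=True`, a traced call such as
    # `torch.nn.functional.relu(x)` is rewritten to
    # `torch.nn.functional.relu(input=x, inplace=False)`; every argument
    # becomes a keyword argument and defaults are filled in.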

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        super().__init__(module)
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            # Use the argument node's own 'type' metadata (when present) so
            # that overloads can be disambiguated per argument.
            if isinstance(arg, fx.Node):
                return arg.meta["type"] if "type" in arg.meta else None
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(i) for i in arg_types)
        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)
        if n.op != "output":
            # Carry the original node's metadata and type annotation over to
            # the newly created node.
            self.node_map[out] = n
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        else:
            # normalize_function returns None when the target's signature
            # could not be resolved; fall back to the unnormalized call.
            return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        else:
            return super().call_module(target, args, kwargs)


class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Normalizing `torch` binary ops (e.g. torch.add) to the `operator`
       callables (e.g. operator.add) backing the corresponding magic
       methods on Tensor, when it is possible to statically reason that
       the callsite is such a binary invocation.

    Example usage:

        m = torchvision.models.resnet18()

        traced = torch.fx.symbolic_trace(m)

        traced = NormalizeOperators(traced).transform()
    """

    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }
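
    # For example, a `call_function` node whose target is `torch.add` is
    # rewritten to call `operator.add`; FX code generation then renders it
    # via the corresponding magic method (i.e. as `lhs + rhs`).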

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950

        assert callable(target)

        if target in self.binary_magic_method_remap:
            if len(args) != 2:
                return super().call_function(target, args, kwargs)
            lhs, rhs = args

            return super().call_function(
                target=self.binary_magic_method_remap[target],
                args=(lhs, rhs),
                kwargs={},
            )

        return super().call_function(target, args, kwargs)
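

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; the `_demo_*` names are
    # hypothetical, not part of this module). Because this file uses relative
    # imports, run it as `python -m torch.fx.experimental.normalize`. No real
    # tensors are needed: Transformer runs symbolically over proxies.
    def _demo_args(x):
        return torch.nn.functional.relu(x)

    def _demo_ops(a, b):
        return torch.add(a, b)

    gm_args = NormalizeArgs(fx.symbolic_trace(_demo_args)).transform()
    print(gm_args.code)  # relu should now be called with explicit kwargs

    gm_ops = NormalizeOperators(fx.symbolic_trace(_demo_ops)).transform()
    print(gm_ops.code)  # torch.add should now appear as `a + b`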
164