xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/schema_type_annotation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3import torch.fx
4import inspect
5from typing import Any, Dict, Optional, Tuple
6from torch.fx.node import Argument, Target
7from torch._jit_internal import boolean_dispatched
8from torch.fx.operator_schemas import _torchscript_type_to_python_type
9
10from torch.fx import Transformer
11
class AnnotateTypesWithSchema(Transformer):
    """
    Use Python function signatures to annotate types for `Nodes` within an FX graph.
    This pulls out Python function signatures for:

        1. Standard `torch.nn` Module calls
        2. `torch.nn.functional` calls
        3. Attribute fetches via `get_attr`

    A type that is already present on a node is never overwritten; the
    inferred type is only used to fill in a missing one.

    Example usage:

        m = torchvision.models.resnet18()

        traced = torch.fx.symbolic_trace(m)

        traced = AnnotateTypesWithSchema(traced).transform()

    Args:

        module (torch.nn.Module): Module handed to the `Transformer` base class.
        annotate_functionals (bool): Annotate `torch.nn.functional` calls.
        annotate_modules (bool): Annotate calls to standard `torch.nn` modules.
        annotate_get_attrs (bool): Annotate `get_attr` nodes.
    """
    def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True,
                 annotate_modules : bool = True, annotate_get_attrs : bool = True):
        super().__init__(module)
        self.annotate_functionals = annotate_functionals
        self.annotate_modules = annotate_modules
        self.annotate_get_attrs = annotate_get_attrs

    def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
        """
        Emit the `call_function` node and, for `torch.nn.functional` targets,
        fill in its type from the target's Python return annotation.

        Args:

            target (Target): The call target of the node being transformed
            args (Tuple): Positional arguments of the call
            kwargs (Dict): Keyword arguments of the call

        Returns:

            The proxy produced by the base `Transformer`, with `node.type`
            populated when it was empty and a return annotation was found.
        """
        python_ret_type = None
        # Use `getattr` so that callables lacking a `__module__` attribute are
        # simply left un-annotated instead of raising `AttributeError`.
        if self.annotate_functionals and getattr(target, '__module__', None) == 'torch.nn.functional':
            target_for_analysis = target
            if target in boolean_dispatched:
                # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
                # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
                # branches of the dispatch have exactly the same return annotation. If they do, use
                # the `true` branch signature for analysis. Otherwise, leave this node un-annotated
                assert not isinstance(target, str)
                dispatched = boolean_dispatched[target]
                if_true, if_false = dispatched['if_true'], dispatched['if_false']
                # TODO: can we emit the union of these? What are the implications on TorchScript
                # compilation?
                if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation:
                    return super().call_function(target, args, kwargs)
                target_for_analysis = if_true

            python_ret_type = self._extract_python_return_type(target_for_analysis)

        return_proxy = super().call_function(target, args, kwargs)
        # Keep an existing type; only fall back to the inferred one.
        return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
        return return_proxy

    def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
        """
        Emit the `call_module` node and, when the submodule's class is the one
        exposed under the same name in `torch.nn`, fill in its type from the
        return annotation of the submodule's `forward`.

        Args:

            target (Target): Qualified name of the submodule (must be a `str`)
            args (Tuple): Positional arguments of the call
            kwargs (Dict): Keyword arguments of the call

        Returns:

            The proxy produced by the base `Transformer`, with `node.type`
            populated when it was empty and a return annotation was found.
        """
        python_ret_type = None
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        if self.annotate_modules and hasattr(submod.__class__, '__name__'):
            classname = submod.__class__.__name__
            # Only trust classes that are exactly the ones published in
            # `torch.nn`; a user class that merely shares the name is skipped.
            if getattr(torch.nn, classname, None) is submod.__class__:
                python_ret_type = self._extract_python_return_type(submod.forward)
        return_proxy = super().call_module(target, args, kwargs)
        # Keep an existing type; only fall back to the inferred one.
        return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type
        return return_proxy

    def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]):
        """
        Emit the `get_attr` node and, when TorchScript type inference succeeds
        on the referenced attribute value, fill in the node's type.

        Args:

            target (Target): Dotted attribute path on `self.module` (must be a `str`)
            args (Tuple): Positional arguments of the node
            kwargs (Dict): Keyword arguments of the node

        Returns:

            The proxy produced by the base `Transformer`, with `node.type`
            populated when it was empty and a type could be inferred.

        Raises:

            RuntimeError: If some component of `target` cannot be resolved
                while walking from `self.module` (only checked when
                `annotate_get_attrs` is enabled).
        """
        attr_proxy = super().get_attr(target, args, kwargs)

        if self.annotate_get_attrs:
            module_itr = self.module
            assert isinstance(target, str)
            atoms = target.split('.')
            for i, atom in enumerate(atoms):
                if not hasattr(module_itr, atom):
                    # Include the atom that failed to resolve (`i + 1`) so the
                    # message names the path that is actually missing.
                    raise RuntimeError(f'Node referenced nonexistent target {".".join(atoms[:i + 1])}!')
                module_itr = getattr(module_itr, atom)

            maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr)
            if maybe_inferred_ts_type.success():
                python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type())
                # Keep an existing type; only fall back to the inferred one.
                attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type

        return attr_proxy

    def _extract_python_return_type(self, target : Target) -> Optional[Any]:
        """
        Given a Python call target, try to extract the Python return annotation
        if it is available, otherwise return None

        Args:

            target (Callable): Python callable to get return annotation for

        Returns:

            Optional[Any]: Return annotation from the `target`, or None if it was
                not available.
        """
        assert callable(target)
        try:
            sig = inspect.signature(target)
        except (ValueError, TypeError):
            # `inspect.signature` cannot introspect every callable; treat
            # those as having no annotation.
            return None

        return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None
113