xref: /aosp_15_r20/external/pytorch/torch/fx/passes/shape_prop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import torch
4import torch.fx
5import traceback
6
7from torch._dispatch.python import enable_python_dispatcher
8from torch.fx.node import Node, map_aggregate
9from typing import Any, Tuple, NamedTuple, Optional, Dict
10from torch.fx._compatibility import compatibility
11from torch._guards import detect_fake_mode
12from torch._subclasses.meta_utils import is_sparse_any
13
# Names exported by a star-import of this module.
__all__ = ['TensorMetadata', 'ShapeProp']
15
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    """
    Immutable record of the pertinent information about a tensor
    within a PyTorch program.

    The fields fall into two groups: general tensor metadata (shape,
    dtype, requires_grad, stride, memory_format) and quantization
    metadata (is_quantized, qparams).
    """

    # --- General tensor metadata ---
    # Size of the tensor.
    shape: torch.Size
    # Element type of the tensor.
    dtype: torch.dtype
    # Whether the tensor requires gradient.
    requires_grad: bool
    # Per-dimension strides. Note that `_extract_tensor_metadata` stores
    # None here for sparse tensors.
    stride: Tuple[int, ...]
    # Memory format the tensor is contiguous in, or None if undetermined.
    memory_format: Optional[torch.memory_format]

    # --- Quantization metadata ---
    # True when the tensor is quantized.
    is_quantized: bool
    # Quantization parameters; `_extract_tensor_metadata` fills in
    # "qscheme" and, depending on the scheme, "scale", "zero_point"
    # and "axis". Empty for non-quantized tensors.
    qparams: Dict[str, Any]
31
def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.

    Args:
        result: The tensor to describe.
        include_contiguity: When True (the default), probe `result` for the
            memory format it is contiguous in. When False, `memory_format`
            is left as None.

    Returns:
        TensorMetadata: shape, dtype, requires_grad, stride, memory_format
        and quantization information of `result`. `stride` and
        `memory_format` are None for sparse tensors (as reported by
        `is_sparse_any`).
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad

    # Sparse tensors are not queried for strides or contiguity.
    is_sparse = is_sparse_any(result)
    stride = None if is_sparse else result.stride()

    memory_format = None

    if include_contiguity and not is_sparse:
        # Ordered tuple rather than a set: a tensor can be contiguous in more
        # than one format at once (e.g. a 4-D tensor with a single channel is
        # both contiguous_format and channels_last). Iterating a set would
        # make the recorded format depend on hash ordering; here the first
        # match wins, with contiguous_format preferred.
        memory_formats = (
            torch.contiguous_format,
            torch.channels_last,
            torch.channels_last_3d,
        )
        for query_format in memory_formats:
            if result.is_contiguous(memory_format=query_format):
                memory_format = query_format
                break

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            # Per-tensor schemes: a single scale and zero point.
            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # In this branch, scale and zero_point are tensors; they are
            # converted to plain Python lists (via `.tolist()`) for easier
            # serialization downstream.
            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
72
73@compatibility(is_backward_compatible=True)
74class ShapeProp(torch.fx.Interpreter):
75    """
76    Execute an FX graph Node-by-Node and
77    record the shape and type of the result
78    into the corresponding node.
79
80    Example:
81         In this example, we record the shape
82         and data type of a module given
83         an example input ``torch.randn(50, D_in)``.
84         We print the name, shape and dtype of each node.
85
86        class TwoLayerNet(torch.nn.Module):
87            def __init__(self, D_in, H, D_out):
88                super().__init__()
89                self.linear1 = torch.nn.Linear(D_in, H)
90                self.linear2 = torch.nn.Linear(H, D_out)
91            def forward(self, x):
92                h_relu = self.linear1(x).clamp(min=0)
93                y_pred = self.linear2(h_relu)
94                return y_pred
95        N, D_in, H, D_out = 64, 1000, 100, 10
96        x = torch.randn(N, D_in)
97        y = torch.randn(N, D_out)
98        model = TwoLayerNet(D_in, H, D_out)
99        gm = torch.fx.symbolic_trace(model)
100        sample_input = torch.randn(50, D_in)
101        ShapeProp(gm).propagate(sample_input)
102
103        for node in gm.graph.nodes:
104            print(node.name, node.meta['tensor_meta'].dtype,
105                node.meta['tensor_meta'].shape)
106
107        The output of this code is:
108
109        x torch.float32 torch.Size([50, 1000])
110        linear1 torch.float32 torch.Size([50, 100])
111        clamp_1 torch.float32 torch.Size([50, 100])
112        linear2 torch.float32 torch.Size([50, 10])
113        output torch.float32 torch.Size([50, 10])
114
115    Args:
116         module (GraphModule): The module to be executed
117         fake_mode (FakeTensorMode): A fake mode for copying the gm
118
119    """
120    def __init__(self, gm, fake_mode=None):
121        super().__init__(gm)
122        if fake_mode is None:
123            fake_mode = detect_fake_mode()
124        if fake_mode is not None:
125            from torch._dynamo.utils import deepcopy_to_fake_tensor
126            # Note:
127            # We need fake execution cause the inputs are fake, however, we cannot fakify the module
128            # - because we need to write to the tensor_meta of the real module. So we fakify to
129            # produce a result (L131 below), to extract tensor meta, and then keep going.
130            #
131            # If we were to fakify, we would write to the wrong node, and then downstream fusion
132            # would be missing the tensor_meta.
133            #
134            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
135            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
136            self.fake_mode = fake_mode
137        else:
138            self.fake_module = None
139            self.fake_mode = None
140
141        self.real_module = self.module
142
143    def run_node(self, n : Node) -> Any:
144        try:
145            if self.fake_module is not None:
146                # Hacky swap. Alternatively, we could do this with overriding
147                # call_module and get_attr.
148                self.module = self.fake_module
149            try:
150                if self.fake_mode is not None:
151                    with self.fake_mode, enable_python_dispatcher():
152                        result = super().run_node(n)
153                else:
154                    result = super().run_node(n)
155            finally:
156                self.module = self.real_module
157        except Exception as e:
158            traceback.print_exc()
159            raise RuntimeError(
160                f"ShapeProp error for: node={n.format_node()} with "
161                f"meta={n.meta}"
162            ) from e
163
164        found_tensor = False
165
166        def extract_tensor_meta(obj):
167            if isinstance(obj, torch.Tensor):
168                nonlocal found_tensor
169                found_tensor = True
170                return _extract_tensor_metadata(obj)
171            else:
172                return obj
173
174        meta = map_aggregate(result, extract_tensor_meta)
175        if found_tensor:
176            n.meta['tensor_meta'] = meta
177
178        n.meta['type'] = type(result)
179        return result
180
181    def propagate(self, *args):
182        """
183        Run `module` via interpretation and return the result and
184        record the shape and type of each node.
185
186        Args:
187            *args (Tensor): the sample input.
188
189        Returns:
190            Any: The value returned from executing the Module
191        """
192        if self.fake_mode is not None:
193            fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
194        else:
195            fake_args = args
196        return super().run(*fake_args)
197