1# mypy: ignore-errors 2 3import torch 4import torch.fx 5import traceback 6 7from torch._dispatch.python import enable_python_dispatcher 8from torch.fx.node import Node, map_aggregate 9from typing import Any, Tuple, NamedTuple, Optional, Dict 10from torch.fx._compatibility import compatibility 11from torch._guards import detect_fake_mode 12from torch._subclasses.meta_utils import is_sparse_any 13 14__all__ = ['TensorMetadata', 'ShapeProp'] 15 16@compatibility(is_backward_compatible=True) 17class TensorMetadata(NamedTuple): 18 # TensorMetadata is a structure containing pertinent information 19 # about a tensor within a PyTorch program. 20 21 # General Tensor metadata 22 shape : torch.Size 23 dtype : torch.dtype 24 requires_grad : bool 25 stride : Tuple[int, ...] 26 memory_format : Optional[torch.memory_format] 27 28 # Quantization metadata 29 is_quantized : bool 30 qparams: Dict[str, Any] 31 32def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata: 33 """ 34 Extract a TensorMetadata NamedTuple describing `result`. 35 """ 36 shape = result.shape 37 dtype = result.dtype 38 requires_grad = result.requires_grad 39 stride = result.stride() if not is_sparse_any(result) else None 40 41 memory_format = None 42 43 if include_contiguity and not is_sparse_any(result): 44 memory_formats = { 45 torch.contiguous_format, 46 torch.channels_last, 47 torch.channels_last_3d, 48 } 49 for query_format in memory_formats: 50 if result.is_contiguous(memory_format=query_format): 51 memory_format = query_format 52 break 53 54 is_quantized = result.is_quantized 55 qparams: Dict[str, Any] = {} 56 if is_quantized: 57 qscheme = result.qscheme() 58 qparams["qscheme"] = qscheme 59 if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: 60 qparams["scale"] = result.q_scale() # type: ignore[assignment] 61 qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] 62 elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: 63 # In this branch, scale and zero_point are expected to be tensors, 64 # we store the values as immutable_list in TensorMetadata for 65 # easier serialization downstream 66 qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] 67 qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] 68 qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] 69 70 return TensorMetadata( 71 shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) 72 73@compatibility(is_backward_compatible=True) 74class ShapeProp(torch.fx.Interpreter): 75 """ 76 Execute an FX graph Node-by-Node and 77 record the shape and type of the result 78 into the corresponding node. 79 80 Example: 81 In this example, we record the shape 82 and data type of a module given 83 an example input ``torch.randn(50, D_in)``. 84 We print the name, shape and dtype of each node. 85 86 class TwoLayerNet(torch.nn.Module): 87 def __init__(self, D_in, H, D_out): 88 super().__init__() 89 self.linear1 = torch.nn.Linear(D_in, H) 90 self.linear2 = torch.nn.Linear(H, D_out) 91 def forward(self, x): 92 h_relu = self.linear1(x).clamp(min=0) 93 y_pred = self.linear2(h_relu) 94 return y_pred 95 N, D_in, H, D_out = 64, 1000, 100, 10 96 x = torch.randn(N, D_in) 97 y = torch.randn(N, D_out) 98 model = TwoLayerNet(D_in, H, D_out) 99 gm = torch.fx.symbolic_trace(model) 100 sample_input = torch.randn(50, D_in) 101 ShapeProp(gm).propagate(sample_input) 102 103 for node in gm.graph.nodes: 104 print(node.name, node.meta['tensor_meta'].dtype, 105 node.meta['tensor_meta'].shape) 106 107 The output of this code is: 108 109 x torch.float32 torch.Size([50, 1000]) 110 linear1 torch.float32 torch.Size([50, 100]) 111 clamp_1 torch.float32 torch.Size([50, 100]) 112 linear2 torch.float32 torch.Size([50, 10]) 113 output torch.float32 torch.Size([50, 10]) 114 115 Args: 116 module (GraphModule): The module to be executed 117 fake_mode (FakeTensorMode): A fake mode for copying the gm 118 119 """ 120 def __init__(self, gm, fake_mode=None): 121 super().__init__(gm) 122 if fake_mode is None: 123 fake_mode = detect_fake_mode() 124 if fake_mode is not None: 125 from torch._dynamo.utils import deepcopy_to_fake_tensor 126 # Note: 127 # We need fake execution cause the inputs are fake, however, we cannot fakify the module 128 # - because we need to write to the tensor_meta of the real module. So we fakify to 129 # produce a result (L131 below), to extract tensor meta, and then keep going. 130 # 131 # If we were to fakify, we would write to the wrong node, and then downstream fusion 132 # would be missing the tensor_meta. 133 # 134 # See torch/_inductor/overrides.py for where this is called upstream of fusion. 135 self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode) 136 self.fake_mode = fake_mode 137 else: 138 self.fake_module = None 139 self.fake_mode = None 140 141 self.real_module = self.module 142 143 def run_node(self, n : Node) -> Any: 144 try: 145 if self.fake_module is not None: 146 # Hacky swap. Alternatively, we could do this with overriding 147 # call_module and get_attr. 148 self.module = self.fake_module 149 try: 150 if self.fake_mode is not None: 151 with self.fake_mode, enable_python_dispatcher(): 152 result = super().run_node(n) 153 else: 154 result = super().run_node(n) 155 finally: 156 self.module = self.real_module 157 except Exception as e: 158 traceback.print_exc() 159 raise RuntimeError( 160 f"ShapeProp error for: node={n.format_node()} with " 161 f"meta={n.meta}" 162 ) from e 163 164 found_tensor = False 165 166 def extract_tensor_meta(obj): 167 if isinstance(obj, torch.Tensor): 168 nonlocal found_tensor 169 found_tensor = True 170 return _extract_tensor_metadata(obj) 171 else: 172 return obj 173 174 meta = map_aggregate(result, extract_tensor_meta) 175 if found_tensor: 176 n.meta['tensor_meta'] = meta 177 178 n.meta['type'] = type(result) 179 return result 180 181 def propagate(self, *args): 182 """ 183 Run `module` via interpretation and return the result and 184 record the shape and type of each node. 185 186 Args: 187 *args (Tensor): the sample input. 188 189 Returns: 190 Any: The value returned from executing the Module 191 """ 192 if self.fake_mode is not None: 193 fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args] 194 else: 195 fake_args = args 196 return super().run(*fake_args) 197