# mypy: ignore-errors

import functools
import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode


aten = torch._ops.ops.aten


def outputs_alias_inputs(outputs, inputs):
    """Return True if any tensor output shares a storage with any tensor input."""
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )


def outputs_are_inputs(outputs, inputs):
    """Return True if any tensor output is (by object identity) one of the tensor inputs."""
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))


def output_alias_each_other(outputs):
    """Return True if any two tensor outputs share a storage."""
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False


def is_sdpa_error(func, idx, e):
    """Return True for known, tolerated device-metadata mismatches on specific
    outputs of the scaled-dot-product-attention ops; callers skip these."""
    if (
        (
            func is aten._scaled_dot_product_flash_attention.default
            or func is aten._flash_attention_forward.default
        )
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    if (
        (
            func is aten._scaled_dot_product_efficient_attention.default
            or func is aten._efficient_attention_forward.default
        )
        and idx in (2, 3)
        and "Devices" in repr(e)
    ):
        return True
    if (
        func is aten._scaled_dot_product_cudnn_attention.default
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    return False


class CrossRefFakeMode(TorchDispatchMode):
    """TorchDispatchMode that runs each op on both concrete tensors and FakeTensors,
    then asserts that the two agree on the number of outputs, aliasing relationships,
    requires_grad, storage offsets, and other tensor metadata."""

    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        super().__init__()
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        fake_r = None

        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and torch.Tag.data_dependent_output not in func.tags
        ):
            # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # TODO: enable_python_dispatcher() here
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor,
                        functools.partial(fake_mode.from_tensor, static_shapes=True),
                        (args, kwargs),
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )
        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat = pytree.tree_leaves(r)
            f_flat = pytree.tree_leaves(fake_r)
            assert len(f_flat) == len(
                r_flat
            ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"

            if self.check_aliasing:
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"

                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"

                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert r_output_alias_each_other == f_output_alias_each_other, (
                    f"{context} mismatch in outputs_alias_each_other check "
                    f"{f_output_alias_each_other} != {r_output_alias_each_other}"
                )

            for idx, (r_out, fake_out) in enumerate(
                zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
            ):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"{context} mismatched number of tensor outputs"
                if r_is_ten:
                    assert r_out.requires_grad == fake_out.requires_grad, (
                        f"{context} mismatched requires_grad-ness of outputs. "
                        f"This usually means that you have added autograd support "
                        f"for your operator at a dispatch key other than Autograd, "
                        f"which will lead to problems"
                    )
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"{context} mismatched storage offset"

                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out,
                            fake_out,
                            check_strides=self.check_strides,
                            allow_rhs_unbacked=True,
                        )
                    except Exception as e:
                        if is_sdpa_error(func, idx, e):
                            continue
                        error_message = (
                            f"{context} mismatched tensor metadata: {e}"
                            if len(r_flat) == 1
                            else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                        )
                        raise RuntimeError(error_message) from e
        return r
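

# A minimal usage sketch (illustrative only, not part of the library surface):
# CrossRefFakeMode is a TorchDispatchMode, so entering it as a context manager
# cross-checks every ATen op dispatched inside the block against FakeTensorMode.
# The tensors and ops below are arbitrary examples; any divergence in metadata,
# aliasing, or requires_grad raises an AssertionError or RuntimeError.
if __name__ == "__main__":
    with CrossRefFakeMode(check_strides=True, check_aliasing=True):
        x = torch.randn(4, 4)
        y = (x @ x.t()).relu()  # each op runs on real tensors and on FakeTensors
        z = y.view(16)  # view ops exercise the aliasing checks
    print("cross-ref checks passed:", y.shape, z.shape)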