# mypy: ignore-errors

from collections import namedtuple
from copy import deepcopy
from itertools import combinations

import torch
from torch.fx.operator_schemas import normalize_function
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


# Named Tuples used within SchemaCheckMode
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])

# Simplified naming for C++ classes
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo

# This TorchDispatchMode subclass is used to verify op schemas.
# It currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs


# These two functions live here to avoid a numpy dependency in
# testing/_internal/common_utils.py.


def is_iterable_of_tensors(iterable):
    # Tensor itself is iterable, so check for it first
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        if len(iterable) == 0:
            return False
        for t in iter(iterable):
            if not isinstance(t, torch.Tensor):
                return False
    except TypeError:
        return False
    return True


def clone_inputs(args):
    inputs = []

    for arg in args:
        if isinstance(arg, torch.Tensor):
            inputs.append(arg.detach().clone())
        elif is_iterable_of_tensors(arg):
            inputs.append([t.detach().clone() for t in arg])
        else:
            inputs.append(arg)

    return inputs


class SchemaCheckMode(TorchDispatchMode):
    def __init__(self) -> None:
        # Information recorded for testing purposes. For example:
        #  - incorrect schemas
        #  - overly conservative schemas
        self.ops = []
        self.mutated = []
        self.aliasing = []

    def reset_cache(self):
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        print(*self.ops, sep=",")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def bitwise_equal(lhs, rhs):
            if lhs.is_quantized:
                # TODO: This is only OK if quantized tensors can't hold NaN;
                # it is unclear whether that is actually true.
                return torch.equal(lhs, rhs)
            else:
                return torch.allclose(lhs, rhs, equal_nan=True)

        def has_mutated(before, after, md):
            are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
            if (
                are_tensors
                and before.layout != torch.sparse_csr
                and after.layout != torch.sparse_csr
            ):
                return not (
                    before.size() == after.size()
                    and bitwise_equal(before, after)
                    and md[0] == after.stride()
                    and md[1] == after._typed_storage()._cdata
                )
            return False

        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                else:
                    raise

        def standardize_name(name):
            return name if name != "self" else "input"

        def unwrap(e):
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e

        def parse_metadata(e):
            if isinstance(e, torch.Tensor):
                if not type(e) == torch.Tensor:
                    try:
                        current = e.elem
                        return (
                            deepcopy(current.stride()),
                            current._typed_storage()._cdata,
                        )
                    except AttributeError:
                        return None
                # Sparse CSR tensors do not have strides or storage
                elif e.layout != torch.sparse_csr:
                    return (deepcopy(e.stride()), e._typed_storage()._cdata)
            return None

        self.ops.append(func._schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True
        ).kwargs

        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
        }
        cloned_metadata = {
            name: [
                parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
            ]
            for name in pre_arguments
        }

        out = func(*args, **kwargs)
        arguments = {
            name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
        }
        tuple_out = out if isinstance(out, tuple) else (out,)
        tuple_out = tree_map(unwrap, tuple_out)

        schema_info = SchemaInfo(func._schema)
        schema_info.add_argument_values(pre_arguments)

        # Process arguments with outputs
        for i in range(len(func._schema.arguments)):
            arg = func._schema.arguments[i]
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                before = cloned_arguments.get(name)
                md = cloned_metadata.get(name)
                after = arguments.get(name)
                for j in range(len(tuple_out)):
                    # aten::_unsafe_view is intended to have incorrect aliasing notation (hence "unsafe")
                    unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
                    if (
                        has_aliased(tuple_out[j], after)
                        and func._schema.name not in unsafe_ops
                    ):
                        if not schema_info.may_contain_alias(
                            SchemaArgument(SchemaArgType.output, j),
                            SchemaArgument(SchemaArgType.input, i),
                        ):
                            raise RuntimeError(
                                f"Argument {name} is not defined to alias output but was aliasing"
                            )
                        else:
                            self.aliasing.append(
                                Aliasing(func._schema.name, name, f"output_{j}")
                            )
                    if after is tuple_out[j] and isinstance(after, torch.Tensor):
                        # Only mutable ops (e.g. add_, add.out) are allowed to directly return inputs.
                        if not schema_info.is_mutable(
                            SchemaArgument(SchemaArgType.input, i)
                        ) and func not in [
                            torch.ops.aten.lift.default,
                            torch.ops.aten.lift_fresh.default,
                        ]:
                            raise RuntimeError(
                                f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{str(j)}] is {name}`"""
                            )
                if any(
                    has_mutated(a, b, c)
                    for a, b, c in zip(
                        pytree.tree_leaves(before), pytree.tree_leaves(after), md
                    )
                ):
                    if not schema_info.is_mutable(
                        SchemaArgument(SchemaArgType.input, i)
                    ):
                        raise RuntimeError(
                            f"Argument {name} is not defined as mutable but was mutated"
                        )
                    else:
                        self.mutated.append(Mutation(func._schema.name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                    SchemaArgument(SchemaArgType.output, i),
                    SchemaArgument(SchemaArgType.output, j),
                ):
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")

        return out
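

# A minimal usage sketch (illustrative only; the ops below are arbitrary
# example choices, and the exact recorded entries depend on how each op
# dispatches). TorchDispatchMode provides __enter__/__exit__, so the mode can
# be used as a context manager: ops run under it are appended to `ops`, and
# mutation/aliasing that the schema declares is recorded in `mutated` and
# `aliasing`, while undeclared mutation or aliasing raises a RuntimeError.
if __name__ == "__main__":
    x = torch.rand(3)
    with SchemaCheckMode() as mode:
        x.add_(1.0)  # in-place op: add_'s schema declares `self` as mutable
        y = x.view(3, 1)  # view's schema declares that the output aliases `self`
    mode.display_ops()  # e.g. aten::add_,aten::view
    print(mode.mutated)  # e.g. [Mutation(op_name='aten::add_', arg_name='input')]
    print(mode.aliasing)  # e.g. [Aliasing(op_name='aten::view', arg_name='input', output_number='output_0')]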