# mypy: allow-untyped-defs
import inspect
import math
import operator
from collections.abc import Iterable
from typing import Any, Dict, final, List, Tuple, Type, TYPE_CHECKING

import torch
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.export.graph_signature import (
    CustomObjArgument,
    InputKind,
    SymIntArgument,
    TensorArgument,
    TokenArgument,
)
from torch.fx import GraphModule

if TYPE_CHECKING:
    from torch.export.exported_program import ExportedProgram


class SpecViolationError(Exception):
    pass


def is_functional(op: OpOverload) -> bool:
    return not op._schema.is_mutable


def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    # TODO(angelayi): remove this in favor of _check_val
    return _check_val(node)


def _check_val(node: torch.fx.Node) -> None:
    from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt

    def _check_correct_val(val):
        if val is None:
            return True
        elif isinstance(val, (int, bool, str, float)):
            return True
        elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
            return True
        elif isinstance(val, (FakeTensor, torch.Tensor)):  # TODO(zhxchen17) Remove Tensor.
            return True
        elif isinstance(val, (SymInt, SymFloat, SymBool)):
            return True
        elif isinstance(val, CustomObjArgument):
            return True
        elif isinstance(val, Iterable):
            return all(_check_correct_val(x) for x in val)
        return False

    def _no_returns(op):
        if not isinstance(op, OpOverload):
            return False
        return len(op._schema.returns) == 0

    if "val" not in node.meta:
        if node.op == "call_function" and _no_returns(node.target):
            return
        raise SpecViolationError(f"Node.meta {node.name} is missing val field.")

    val = node.meta["val"]
    if not _check_correct_val(val):
        raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")


def _check_torch_fn(node: torch.fx.Node) -> None:
    torch_fn = node.meta.get("torch_fn")
    if torch_fn is None:
        raise SpecViolationError(f"Unable to find torch_fn metadata for node {node.name}")
    # torch_fn must be a (str, str) tuple; note the `not` wraps the whole
    # conjunction (the original condition could never fire for tuples).
    if not (
        isinstance(torch_fn, tuple)
        and isinstance(torch_fn[0], str)
        and isinstance(torch_fn[1], str)
    ):
        raise SpecViolationError(f"Node.meta {node.name} has invalid torch_fn field {torch_fn}")


class _VerifierMeta(type):
    _registry: Dict[str, Type['Verifier']] = {}

    def __new__(metacls, name, bases, attrs):
        if bases:
            if "check" in attrs or "_check_graph_module" in attrs:
                raise SyntaxError("Overriding method check or _check_graph_module is not allowed.")
            assert "dialect" in attrs and attrs["dialect"] != "ATEN"
        else:
            assert "check" in attrs
            assert "_check_graph_module" in attrs
            assert attrs["dialect"] == "ATEN"

        assert isinstance(attrs["dialect"], str)
        ret = type.__new__(metacls, name, bases, attrs)
        metacls._registry[attrs["dialect"]] = ret  # type: ignore[assignment]
        return ret


def getattr_recursive(obj: Any, target: str) -> Any:
    target_atoms = target.split('.')
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
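

# Illustrative sketch, not part of the verifier API: getattr_recursive resolves
# dotted get_attr targets such as "sub.weight" one attribute at a time, and a
# nonexistent target is reported with the longest prefix that did resolve. The
# module classes below are hypothetical and exist only for this example.
def _example_getattr_recursive() -> None:
    class Sub(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.zeros(2))

    class Top(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sub = Sub()

    top = Top()
    # Resolves nested attributes exactly the way torch.fx spells targets.
    assert getattr_recursive(top, "sub.weight") is top.sub.weight
    try:
        getattr_recursive(top, "sub.missing")
    except RuntimeError:
        pass  # nonexistent targets raise, naming the resolved prefix "sub"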


class Verifier(metaclass=_VerifierMeta):
    dialect = "ATEN"

    def allowed_builtin_ops(self) -> List:
        return [
            operator.getitem,
            operator.add,
            operator.mul,
            operator.sub,
            operator.truediv,
            operator.ge,
            operator.le,
            operator.gt,
            operator.lt,
            operator.eq,
            operator.ne,
            operator.floordiv,
            operator.mod,
            operator.and_,
            operator.or_,
            operator.not_,
            operator.pow,
            operator.neg,
            operator.abs,
            math.ceil,
            math.floor,
            math.trunc,
        ]

    def allowed_op_types(self) -> Tuple[Type[Any], ...]:
        return (OpOverload, HigherOrderOperator)

    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
        return (torch.fx.GraphModule,)

    def check_valid_op(self, op):
        pass

    def check_additional(self, gm: GraphModule) -> None:
        """
        Additional checks that are specific to some dialects.
        """

    @final
    def check(self, ep: "ExportedProgram") -> None:
        self._check_graph_module(ep.graph_module)
        _verify_exported_program_module_call_graph(ep)
        _verify_exported_program_signature(ep)

    @final
    def _check_graph_module(self, gm: torch.fx.GraphModule) -> None:
        def _allowed_getattr_types() -> Tuple[Type[Any], ...]:
            ret = self.allowed_getattr_types()
            assert not any(t is object for t in ret)
            return ret

        def _check_valid_op(op) -> None:
            def _allowed_builtin_ops() -> List:
                ret = self.allowed_builtin_ops()
                assert all(inspect.isbuiltin(op) for op in ret)
                return ret

            def _allowed_op_types() -> Tuple[Type[Any], ...]:
                ret = self.allowed_op_types()
                assert not any(t is object for t in ret)
                return ret

            # TODO Remove this allowlist.
            _allowed_torch_functions = (
                torch.autograd.grad_mode.set_grad_enabled,
                torch.sym_int,
                torch.sym_float,
                torch.sym_ite,
                torch.sym_max,
                torch.sym_min,
                torch.sym_not,
                torch.sym_sqrt,
                # TODO (tmanlaibaatar)
                # Predispatch export is able to contain autograd ops.
                # These will be modeled as HOO later
                torch._C._set_grad_enabled,
            )

            if not isinstance(op, _allowed_op_types()):
                if op not in _allowed_builtin_ops() and op not in _allowed_torch_functions:
                    raise SpecViolationError(
                        f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n"
                        f"Valid builtin ops: {_allowed_builtin_ops()}\n"
                        f"Valid torch functions: {_allowed_torch_functions}"
                    )

            if isinstance(op, OpOverload):
                # All ops must be functional.
                # TODO (tmanlaibaatar) more proper way is needed here
                if self.dialect != "TRAINING" and not is_functional(op):
                    raise SpecViolationError(
                        f"operator '{op}' is not functional"
                    )
            self.check_valid_op(op)

        for mod in gm.modules():
            if not isinstance(mod, torch.fx.GraphModule):
                continue

            mod.graph.lint()
            for node in mod.graph.nodes:
                # TODO(T140410192): should have fake tensor for all dialects
                if node.op in {"call_module", "call_method"}:
                    raise SpecViolationError(
                        f"call_module and call_method are not valid: got target '{node.target}'"
                    )

                elif node.op == "call_function":
                    _check_val(node)

                    _check_valid_op(node.target)

                elif node.op == "get_attr":
                    if not isinstance(node.target, str):
                        raise SpecViolationError(
                            f"Expected get_attr target to be string, but got {type(node.target)}"
                        )

                    attr = getattr_recursive(mod, node.target)
                    if isinstance(attr, torch.nn.Module):
                        def _is_type(name, ty):
                            return isinstance(getattr(attr, name, None), ty)

                        if type(attr).__name__ == "LoweredBackendModule":
                            if (
                                _is_type("backend_id", str)
                                and _is_type("processed_bytes", bytes)
                                and _is_type("compile_specs", list)
                                and hasattr(attr, "original_module")
                            ):
                                continue
                            else:
                                backend_id = getattr(attr, "backend_id", None)
                                processed_bytes = getattr(attr, "processed_bytes", None)
                                compile_specs = getattr(attr, "compile_specs", None)
                                raise SpecViolationError(
                                    f"Invalid get_attr type {type(attr)}. \n"
                                    f"LoweredBackendModule fields: "
                                    f"backend_id(str) : {type(backend_id)}, "
                                    f"processed_bytes(bytes) : {type(processed_bytes)}, "
                                    f"compile_specs(list) : {type(compile_specs)}"
                                )

                    if not isinstance(attr, _allowed_getattr_types()):
                        raise SpecViolationError(
                            f"Invalid get_attr type {type(attr)}. \n"
                            f"Valid get_attr types: {_allowed_getattr_types()}"
                        )

                elif node.op == "placeholder":
                    _check_val(node)
                # TODO(zhxchen17)
                # elif node.op == "output":
                #     _check_flattened_outputs()

        self.check_additional(gm)
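

# Hedged sketch of the extension contract enforced by _VerifierMeta: a subclass
# only needs a distinct `dialect` string to be registered for load_verifier,
# and it may override check_valid_op / check_additional but never the @final
# check / _check_graph_module. The "EDGE_DEMO" dialect below is hypothetical.
# Defining the class populates the registry as a side effect, which is why this
# sketch keeps the definition inside a function instead of at module scope.
def _example_custom_dialect_verifier() -> Type[Verifier]:
    class _DemoVerifier(Verifier):
        dialect = "EDGE_DEMO"  # hypothetical dialect name

        def check_valid_op(self, op) -> None:
            # Subclasses can tighten the per-op policy here; the graph walk
            # itself stays in the @final _check_graph_module.
            if isinstance(op, OpOverload) and not is_functional(op):
                raise SpecViolationError(f"Mutable operator {op} is not allowed")

    return _DemoVerifier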


class TrainingIRVerifier(Verifier):
    dialect = "TRAINING"


def _verify_exported_program_module_call_graph(exported_program) -> None:
    module_call_graph = exported_program.module_call_graph
    nodes = {
        node.name for node in exported_program.graph.nodes
    }
    for entry in module_call_graph:
        if entry.signature is not None:
            for arg in entry.signature.inputs:
                if arg.name and arg.name not in nodes:
                    raise SpecViolationError(
                        f"Input {arg.name} does not exist in the graph."
                    )
            for arg in entry.signature.outputs:
                if arg.name and arg.name not in nodes:
                    raise SpecViolationError(
                        f"Output {arg.name} does not exist in the graph."
                    )
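

# Hedged sketch: module_call_graph entries only carry a signature when the
# corresponding submodule's call signature was preserved at export time, which
# is exactly when the name check above has anything to validate. The modules
# are hypothetical; assumes the preserve_module_call_signature kwarg of
# torch.export.export.
def _example_module_call_graph() -> None:
    class Sub(torch.nn.Module):
        def forward(self, x):
            return x * 2

    class Top(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sub = Sub()

        def forward(self, x):
            return self.sub(x) + 1

    ep = torch.export.export(
        Top(), (torch.randn(3),), preserve_module_call_signature=("sub",)
    )
    for entry in ep.module_call_graph:
        if entry.signature is not None:
            # These are exactly the names checked against ep.graph above.
            print(entry.fqn, [a.name for a in entry.signature.inputs])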


def _verify_exported_program_signature(exported_program) -> None:
    # Check ExportedProgram signature matches
    gs = exported_program.graph_signature

    # Check every node in the signature exists in the graph
    input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]

    if len(input_node_names) != len(gs.input_specs):
        raise SpecViolationError(
            f"Number of graph inputs ({len(input_node_names)}) "
            f"does not match number of inputs in the graph signature ({len(gs.input_specs)})"
        )

    for input_spec, node in zip(gs.input_specs, input_node_names):
        if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
            if input_spec.arg.name != node:
                raise SpecViolationError(
                    f"Input spec name {input_spec.arg.name} does not match node name {node}"
                )

        if input_spec.kind == InputKind.USER_INPUT:
            continue

        elif input_spec.kind == InputKind.PARAMETER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            param = input_spec.target
            if param not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Parameter {param} is not in the state dict."
                )

            if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
                raise SpecViolationError(
                    f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
                )

        elif input_spec.kind == InputKind.BUFFER:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            buffer = input_spec.target
            if input_spec.persistent is None:
                raise SpecViolationError(
                    f"Buffer {buffer} is missing a persistence flag"
                )

            if input_spec.persistent is True and buffer not in exported_program.state_dict:
                raise SpecViolationError(
                    f"Buffer {buffer} is not in the state dict."
                )

            if input_spec.persistent is False and buffer in exported_program.state_dict:
                raise SpecViolationError(
                    f"Non-persistent buffer {buffer} is in the state dict, it should not be."
                )
        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            if not isinstance(input_spec.arg, TensorArgument):
                raise SpecViolationError(
                    f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            tensor_const = input_spec.target
            if tensor_const not in exported_program.constants:
                raise SpecViolationError(
                    f"Constant tensor {tensor_const} is not in the constants dictionary."
                )
        elif input_spec.kind == InputKind.CUSTOM_OBJ:
            if not isinstance(input_spec.arg, CustomObjArgument):
                raise SpecViolationError(
                    f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead."
                )
            if input_spec.target is None:
                raise SpecViolationError(
                    f"InputSpec for {input_spec.name} has no target."
                )

            custom_obj = input_spec.target
            if custom_obj not in exported_program.constants:
                raise SpecViolationError(
                    f"Custom object {custom_obj} is not in the constants dictionary."
                )
        elif input_spec.kind == InputKind.TOKEN:
            if not isinstance(input_spec.arg, TokenArgument):
                raise SpecViolationError(
                    f"Token {input_spec.name} is not a token argument. Found {input_spec.arg} instead."
                )
        else:
            raise SpecViolationError(
                f"Unknown InputKind {input_spec.kind}."
            )

    # Check outputs
    output_node = list(exported_program.graph.nodes)[-1]
    assert output_node.op == "output"
    output_nodes = [
        arg.name if isinstance(arg, torch.fx.Node) else arg
        for arg in output_node.args[0]
    ]

    if len(output_nodes) != len(gs.output_specs):
        raise SpecViolationError(
            f"Number of output nodes {len(output_nodes)} is different "
            "than the number of outputs specified by the graph signature: \n"
            f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
            f"Number of user outputs: {len(gs.user_outputs)}. \n"
        )

    num_tokens = len(gs.output_tokens)
    end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate) + num_tokens
    mutate_nodes: List[str] = output_nodes[num_tokens:end]
    user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]

    for mutation_node in mutate_nodes:
        if mutation_node in gs.buffers_to_mutate:
            if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
                raise SpecViolationError(
                    f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
                    f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                    f"Buffer nodes available: {gs.buffers} \n"
                )
        elif mutation_node in gs.user_inputs_to_mutate:
            if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
                raise SpecViolationError(
                    f"User input output {mutation_node} does not point to a user input that exists. \n"
                    f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
                    f"User input nodes available: {gs.user_inputs} \n"
                )
        else:
            raise SpecViolationError(
                f"Mutation node {mutation_node} is neither a buffer nor a user input. "
                f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
            )

    for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
        if user_output_node != user_output_name:
            raise SpecViolationError(
                f"User output {user_output_node} is not in the correct "
                "order or is not found in the "
                f"exported program's user_output list: {gs.user_outputs}. "
            )
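

# Hedged sketch of the pairing the function above enforces: placeholders and
# input_specs are zipped in order, so the n-th placeholder must carry the n-th
# spec's argument name. The module here is hypothetical, for illustration only.
def _example_signature_pairing() -> None:
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(3))

        def forward(self, x):
            return x * self.weight

    ep = torch.export.export(M(), (torch.randn(3),))
    placeholders = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
    for name, spec in zip(placeholders, ep.graph_signature.input_specs):
        # e.g. the parameter placeholder precedes the user input, and its spec
        # carries kind=InputKind.PARAMETER with target "weight".
        assert spec.arg.name == name
        print(name, spec.kind, spec.target)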


def load_verifier(dialect: str) -> Type[Verifier]:
    if dialect == "ATEN" or dialect == "":
        return _VerifierMeta._registry.get(dialect, Verifier)
    return _VerifierMeta._registry[dialect]
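

# Hedged end-to-end sketch: fetch the verifier class for a dialect from the
# registry and run the full check. A freshly exported ATEN-dialect program is
# expected to pass; any violation surfaces as SpecViolationError. The module
# being exported is hypothetical.
def _example_load_and_check() -> None:
    class M(torch.nn.Module):
        def forward(self, x):
            return torch.sin(x) + 1

    ep = torch.export.export(M(), (torch.randn(3),))
    verifier_cls = load_verifier("ATEN")  # "" also falls back to Verifier
    verifier_cls().check(ep)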