1# mypy: allow-untyped-defs 2import dataclasses 3from enum import auto, Enum 4from typing import Collection, Dict, List, Mapping, Optional, Set, TYPE_CHECKING, Union 5 6from torch._library.fake_class_registry import FakeScriptObject 7 8 9if TYPE_CHECKING: 10 import torch 11 from torch._functorch._aot_autograd.schemas import GraphSignature 12 13__all__ = [ 14 "ConstantArgument", 15 "CustomObjArgument", 16 "ExportBackwardSignature", 17 "ExportGraphSignature", 18 "InputKind", 19 "InputSpec", 20 "OutputKind", 21 "OutputSpec", 22 "SymIntArgument", 23 "TensorArgument", 24] 25 26 27@dataclasses.dataclass 28class TensorArgument: 29 name: str 30 31 32@dataclasses.dataclass 33class TokenArgument: 34 name: str 35 36 37@dataclasses.dataclass 38class SymIntArgument: 39 name: str 40 41 42@dataclasses.dataclass 43class CustomObjArgument: 44 name: str 45 class_fqn: str 46 fake_val: Optional[FakeScriptObject] = None 47 48 49@dataclasses.dataclass 50class ConstantArgument: 51 name: str 52 value: Union[int, float, bool, str, None] 53 54 55ArgumentSpec = Union[ 56 TensorArgument, 57 SymIntArgument, 58 ConstantArgument, 59 CustomObjArgument, 60 TokenArgument, 61] 62 63 64class InputKind(Enum): 65 USER_INPUT = auto() 66 PARAMETER = auto() 67 BUFFER = auto() 68 CONSTANT_TENSOR = auto() 69 CUSTOM_OBJ = auto() 70 TOKEN = auto() 71 72 73@dataclasses.dataclass 74class InputSpec: 75 kind: InputKind 76 arg: ArgumentSpec 77 target: Optional[str] 78 persistent: Optional[bool] = None 79 80 def __post_init__(self): 81 if self.kind == InputKind.BUFFER: 82 assert ( 83 self.persistent is not None 84 ), "Failed to specify persistent flag on BUFFER." 85 assert isinstance( 86 self.arg, 87 ( 88 TensorArgument, 89 SymIntArgument, 90 ConstantArgument, 91 CustomObjArgument, 92 TokenArgument, 93 ), 94 ), f"got {type(self.arg)}" 95 96 97class OutputKind(Enum): 98 USER_OUTPUT = auto() 99 LOSS_OUTPUT = auto() 100 BUFFER_MUTATION = auto() 101 GRADIENT_TO_PARAMETER = auto() 102 GRADIENT_TO_USER_INPUT = auto() 103 USER_INPUT_MUTATION = auto() 104 TOKEN = auto() 105 106 107@dataclasses.dataclass 108class OutputSpec: 109 kind: OutputKind 110 arg: ArgumentSpec 111 target: Optional[str] 112 113 def __post_init__(self): 114 assert isinstance( 115 self.arg, 116 ( 117 TensorArgument, 118 SymIntArgument, 119 ConstantArgument, 120 TokenArgument, 121 CustomObjArgument, 122 ), 123 ), self.arg 124 125 126@dataclasses.dataclass 127class ExportBackwardSignature: 128 gradients_to_parameters: Dict[str, str] 129 gradients_to_user_inputs: Dict[str, str] 130 loss_output: str 131 132 133@dataclasses.dataclass 134class ExportGraphSignature: 135 """ 136 :class:`ExportGraphSignature` models the input/output signature of Export Graph, 137 which is a fx.Graph with stronger invariants gurantees. 138 139 Export Graph is functional and does not access "states" like parameters 140 or buffers within the graph via ``getattr`` nodes. Instead, :func:`export` 141 gurantees that parameters, buffers, and constant tensors are lifted out of 142 the graph as inputs. Similarly, any mutations to buffers are not included 143 in the graph either, instead the updated values of mutated buffers are 144 modeled as additional outputs of Export Graph. 145 146 The ordering of all inputs and outputs are:: 147 148 Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] 149 Outputs = [*mutated_inputs, *flattened_user_outputs] 150 151 e.g. If following module is exported:: 152 153 class CustomModule(nn.Module): 154 def __init__(self) -> None: 155 super(CustomModule, self).__init__() 156 157 # Define a parameter 158 self.my_parameter = nn.Parameter(torch.tensor(2.0)) 159 160 # Define two buffers 161 self.register_buffer('my_buffer1', torch.tensor(3.0)) 162 self.register_buffer('my_buffer2', torch.tensor(4.0)) 163 164 def forward(self, x1, x2): 165 # Use the parameter, buffers, and both inputs in the forward method 166 output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 167 168 # Mutate one of the buffers (e.g., increment it by 1) 169 self.my_buffer2.add_(1.0) # In-place addition 170 171 return output 172 173 Resulting Graph would be:: 174 175 graph(): 176 %arg0_1 := placeholder[target=arg0_1] 177 %arg1_1 := placeholder[target=arg1_1] 178 %arg2_1 := placeholder[target=arg2_1] 179 %arg3_1 := placeholder[target=arg3_1] 180 %arg4_1 := placeholder[target=arg4_1] 181 %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) 182 %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) 183 %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) 184 %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) 185 %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) 186 return (add_tensor_2, add_tensor_1) 187 188 Resulting ExportGraphSignature would be:: 189 190 ExportGraphSignature( 191 input_specs=[ 192 InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), 193 InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), 194 InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), 195 InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), 196 InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) 197 ], 198 output_specs=[ 199 OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), 200 OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) 201 ] 202 ) 203 """ 204 205 input_specs: List[InputSpec] 206 output_specs: List[OutputSpec] 207 208 # A list of parameters uniquely identified by mangled fully qualified name 209 @property 210 def parameters(self) -> Collection[str]: 211 return tuple( 212 s.target 213 for s in self.input_specs 214 if s.kind == InputKind.PARAMETER 215 if isinstance(s.target, str) 216 ) 217 218 # A list of buffers uniquely identified by mangled fully qualified name 219 @property 220 def buffers(self) -> Collection[str]: 221 return tuple( 222 s.target 223 for s in self.input_specs 224 if s.kind == InputKind.BUFFER 225 if isinstance(s.target, str) 226 ) 227 228 @property 229 def non_persistent_buffers(self) -> Collection[str]: 230 return tuple( 231 s.target 232 for s in self.input_specs 233 if s.kind == InputKind.BUFFER 234 if s.persistent is False 235 if isinstance(s.target, str) 236 ) 237 238 # A list of lifted constant tensors 239 @property 240 def lifted_tensor_constants(self) -> Collection[str]: 241 return tuple( 242 s.target 243 for s in self.input_specs 244 if s.kind == InputKind.CONSTANT_TENSOR 245 if isinstance(s.target, str) 246 ) 247 248 @property 249 def lifted_custom_objs(self) -> Collection[str]: 250 return tuple( 251 s.target 252 for s in self.input_specs 253 if s.kind == InputKind.CUSTOM_OBJ 254 if isinstance(s.target, str) 255 ) 256 257 # Graph node names of pytree-flattened inputs of original program 258 @property 259 def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]: 260 user_inputs: List[Union[int, float, bool, None, str]] = [] 261 for s in self.input_specs: 262 if s.kind != InputKind.USER_INPUT: 263 continue 264 265 if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)): 266 user_inputs.append(s.arg.name) 267 elif isinstance(s.arg, ConstantArgument): 268 user_inputs.append(s.arg.value) 269 else: 270 raise RuntimeError(f"{s.arg} is not a valid user inputs") 271 return tuple(user_inputs) 272 273 # Graph node names of pytree-flattened outputs of original program 274 @property 275 def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]: 276 user_outputs: List[Union[int, float, bool, None, str]] = [] 277 for s in self.output_specs: 278 if s.kind != OutputKind.USER_OUTPUT: 279 continue 280 281 if isinstance(s.arg, (TensorArgument, SymIntArgument)): 282 user_outputs.append(s.arg.name) 283 elif isinstance(s.arg, ConstantArgument): 284 user_outputs.append(s.arg.value) 285 elif isinstance(s.arg, CustomObjArgument): 286 user_outputs.append(s.arg.name) 287 else: 288 raise RuntimeError(f"{s.arg} is not a valid user output") 289 return tuple(user_outputs) 290 291 # A dictionary mapping graph input node names to parameters. If a graph input 292 # name is found in this dictionary, it is guranteed to be a lifted parameter. 293 @property 294 def inputs_to_parameters(self) -> Mapping[str, str]: 295 return _immutable_dict( 296 (s.arg.name, s.target) 297 for s in self.input_specs 298 if s.kind == InputKind.PARAMETER 299 and isinstance(s.arg, TensorArgument) 300 and isinstance(s.target, str) 301 ) 302 303 # A dictionary mapping graph input node names to buffers. If a graph input 304 # name is found in this dictionary, it is guranteed to be a lifted buffer. 305 @property 306 def inputs_to_buffers(self) -> Mapping[str, str]: 307 return _immutable_dict( 308 (s.arg.name, s.target) # type: ignore[union-attr, misc] 309 for s in self.input_specs 310 if s.kind == InputKind.BUFFER 311 and isinstance(s.arg, TensorArgument) 312 and isinstance(s.target, str) 313 ) 314 315 # A dictionary mapping graph output node names to buffers that are mutated in the 316 # original program. Buffers that are not mutated will not be found in this dictionary. 317 @property 318 def buffers_to_mutate(self) -> Mapping[str, str]: 319 return _immutable_dict( 320 (s.arg.name, s.target) 321 for s in self.output_specs 322 if s.kind == OutputKind.BUFFER_MUTATION 323 and isinstance(s.arg, TensorArgument) 324 and isinstance(s.target, str) 325 ) 326 327 @property 328 def user_inputs_to_mutate(self) -> Mapping[str, str]: 329 return _immutable_dict( 330 (s.arg.name, s.target) 331 for s in self.output_specs 332 if s.kind == OutputKind.USER_INPUT_MUTATION 333 and isinstance(s.arg, TensorArgument) 334 and isinstance(s.target, str) 335 ) 336 337 # A dictionary mapping graph input node names to lifted tensor constants. 338 @property 339 def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: 340 return _immutable_dict( 341 (s.arg.name, s.target) 342 for s in self.input_specs 343 if s.kind == InputKind.CONSTANT_TENSOR 344 and isinstance(s.arg, TensorArgument) 345 and isinstance(s.target, str) 346 ) 347 348 @property 349 def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]: 350 return _immutable_dict( 351 (s.arg.name, s.target) 352 for s in self.input_specs 353 if s.kind == InputKind.CUSTOM_OBJ 354 and isinstance(s.arg, CustomObjArgument) 355 and isinstance(s.target, str) 356 ) 357 358 @property 359 def backward_signature(self) -> Optional[ExportBackwardSignature]: 360 loss_output = None 361 gradients_to_parameters: Dict[str, str] = {} 362 gradients_to_user_inputs: Dict[str, str] = {} 363 for spec in self.output_specs: 364 if spec.kind == OutputKind.LOSS_OUTPUT: 365 assert loss_output is None 366 assert isinstance(spec.arg, TensorArgument) 367 loss_output = spec.arg.name 368 elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER: 369 assert isinstance(spec.target, str) 370 assert isinstance(spec.arg, TensorArgument) 371 gradients_to_parameters[spec.arg.name] = spec.target 372 elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT: 373 assert isinstance(spec.target, str) 374 assert isinstance(spec.arg, TensorArgument) 375 gradients_to_user_inputs[spec.arg.name] = spec.target 376 377 if loss_output is None: 378 return None 379 380 return ExportBackwardSignature( 381 loss_output=loss_output, 382 gradients_to_parameters=gradients_to_parameters, 383 gradients_to_user_inputs=gradients_to_user_inputs, 384 ) 385 386 # Map from assertion dependency token index to assertion dep token output 387 # name in output. The shape of output after aot_autograd will be like: 388 # (updated_inputs, user_outputs, dep_token). 389 @property 390 def assertion_dep_token(self) -> Optional[Mapping[int, str]]: 391 return None 392 393 @property 394 def input_tokens(self) -> Collection[str]: 395 input_tokens = [] 396 for s in self.input_specs: 397 if s.kind == InputKind.TOKEN: 398 assert isinstance(s.arg, TokenArgument) 399 input_tokens.append(s.arg.name) 400 return tuple(input_tokens) 401 402 @property 403 def output_tokens(self) -> Collection[str]: 404 output_tokens = [] 405 for s in self.output_specs: 406 if s.kind == OutputKind.TOKEN: 407 assert isinstance(s.arg, TokenArgument) 408 output_tokens.append(s.arg.name) 409 return tuple(output_tokens) 410 411 def __post_init__(self) -> None: 412 assertion_dep_token = self.assertion_dep_token 413 if assertion_dep_token is None: 414 return 415 assert len(assertion_dep_token) == 1 416 assertion_dep_token_index = next(iter(assertion_dep_token.keys())) 417 assert ( 418 len(self.user_outputs) + len(self.buffers_to_mutate) 419 == assertion_dep_token_index 420 ) 421 422 def replace_all_uses(self, old: str, new: str): 423 """ 424 Replace all uses of the old name with new name in the signature. 425 """ 426 assert isinstance(old, str) 427 assert isinstance(new, str) 428 arg_types = (TensorArgument, SymIntArgument, CustomObjArgument, TokenArgument) 429 for o in self.output_specs: 430 if isinstance(o.arg, arg_types): 431 if o.arg.name == old: 432 o.arg.name = new 433 for i in self.input_specs: 434 if isinstance(i.arg, arg_types): 435 if i.arg.name == old: 436 i.arg.name = new 437 438 def get_replace_hook(self): 439 def _(old, new, user): 440 if user.op in ("output", "input"): 441 self.replace_all_uses(old.name, new) 442 443 return _ 444 445 446def _immutable_dict(items): 447 """ 448 Creates a mapping where items cannot be added, deleted, or updated. 449 NOTE: The immutability is shallow (like tuple is an immutable collection). 450 """ 451 from types import MappingProxyType 452 453 return MappingProxyType(dict(items)) 454 455 456def _make_argument_spec(node, token_names) -> ArgumentSpec: 457 from torch import ScriptObject, SymInt 458 from torch._library.fake_class_registry import FakeScriptObject 459 from torch._subclasses.fake_tensor import FakeTensor 460 461 if isinstance(node, (int, bool, float, type(None), str)): 462 # For const outputs we just directly return this 463 return ConstantArgument(name="", value=node) 464 465 assert ( 466 "val" in node.meta 467 ), f"{node} is not a constant or a node with a 'val' metadata field" 468 val = node.meta["val"] 469 if node.name in token_names: 470 return TokenArgument(name=node.name) 471 elif isinstance(val, FakeTensor): 472 return TensorArgument(name=node.name) 473 elif isinstance(val, SymInt): 474 return SymIntArgument(name=node.name) 475 elif isinstance(val, ScriptObject): 476 return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined] 477 elif isinstance(val, FakeScriptObject): 478 return CustomObjArgument( 479 name=node.name, class_fqn=val.script_class_name, fake_val=val 480 ) 481 elif isinstance(val, (int, bool, str, float, type(None))): 482 return ConstantArgument(name=node.name, value=val) 483 else: 484 raise AssertionError( 485 f"Encountered an unsupported object of type {type(val)} " 486 f"while writing the metadata for exported program" 487 ) 488 489 490def _convert_to_export_graph_signature( 491 graph_signature: "GraphSignature", 492 gm: "torch.fx.GraphModule", 493 non_persistent_buffers: Set[str], 494) -> "ExportGraphSignature": 495 from torch.utils import _pytree as pytree 496 497 is_joint = graph_signature.backward_signature is not None 498 499 # unpack objects 500 user_inputs = set(graph_signature.user_inputs) 501 inputs_to_parameters = graph_signature.inputs_to_parameters 502 inputs_to_buffers = graph_signature.inputs_to_buffers 503 user_outputs = set(graph_signature.user_outputs) 504 buffer_mutations = graph_signature.buffers_to_mutate 505 user_input_mutations = graph_signature.user_inputs_to_mutate 506 grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr] 507 grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr] 508 loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr] 509 input_tokens = graph_signature.input_tokens 510 output_tokens = graph_signature.output_tokens 511 512 inputs = [ 513 _make_argument_spec(node, input_tokens) 514 for node in gm.graph.nodes 515 if node.op == "placeholder" 516 ] 517 outputs = [ 518 _make_argument_spec(node, output_tokens) 519 for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args) 520 ] 521 522 def to_input_spec(inp: ArgumentSpec) -> InputSpec: 523 if isinstance(inp, TokenArgument): 524 return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None) 525 526 if not isinstance(inp, TensorArgument): 527 return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) 528 name = inp.name 529 if name in user_inputs: 530 return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) 531 elif name in inputs_to_parameters: 532 return InputSpec( 533 kind=InputKind.PARAMETER, 534 arg=inp, 535 target=inputs_to_parameters[name], # type: ignore[index] 536 ) 537 elif name in inputs_to_buffers: 538 return InputSpec( 539 kind=InputKind.BUFFER, 540 arg=inp, 541 target=inputs_to_buffers[name], # type: ignore[index] 542 persistent=(inputs_to_buffers[name] not in non_persistent_buffers), # type: ignore[index] 543 ) 544 else: 545 raise AssertionError(f"Unknown tensor input kind: {name}") 546 547 def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: 548 if isinstance(o, TokenArgument): 549 return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None) 550 551 if not isinstance(o, TensorArgument): 552 return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) 553 name = o.name 554 if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): 555 if name in buffer_mutations: 556 return OutputSpec( 557 kind=OutputKind.BUFFER_MUTATION, 558 arg=o, 559 target=buffer_mutations[name], # type: ignore[index] 560 ) 561 elif name in user_input_mutations: 562 return OutputSpec( 563 kind=OutputKind.USER_INPUT_MUTATION, 564 arg=o, 565 target=user_input_mutations[name], # type: ignore[index] 566 ) 567 else: 568 raise AssertionError(f"Unknown tensor mutation kind: {name}") 569 else: 570 if name in user_outputs: 571 return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) 572 573 elif name in grad_params: 574 return OutputSpec( 575 kind=OutputKind.GRADIENT_TO_PARAMETER, 576 arg=o, 577 target=grad_params[name], 578 ) 579 elif name in grad_user_inputs: 580 return OutputSpec( 581 kind=OutputKind.GRADIENT_TO_USER_INPUT, 582 arg=o, 583 target=grad_user_inputs[name], 584 ) 585 elif name == loss_output: 586 return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) 587 588 else: 589 raise AssertionError(f"Unknown tensor output kind: {name}") 590 591 input_specs = [to_input_spec(inp) for inp in inputs] 592 output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)] 593 return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs) 594