1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4from typing import ( 5 Any, 6 Callable, 7 Mapping, 8 Protocol, 9 runtime_checkable, 10 Sequence, 11 TYPE_CHECKING, 12) 13 14import torch 15import torch.export as torch_export 16from torch.utils import _pytree as pytree 17 18 19if TYPE_CHECKING: 20 import inspect 21 22# TODO(bowbao): Add diagnostics for IO adapters. 23 24 25@runtime_checkable 26class InputAdaptStep(Protocol): 27 """A protocol that defines a step in the input adapting process. 28 29 The input adapting process is a sequence of steps that are applied to the 30 PyTorch model inputs to transform them into the inputs format expected by the 31 exported ONNX model. Each step takes the PyTorch model inputs as arguments and 32 returns the transformed inputs. 33 34 This serves as a base formalized construct for the transformation done to model 35 input signature by any individual component in the exporter. 36 """ 37 38 def apply( 39 self, 40 model_args: Sequence[Any], 41 model_kwargs: Mapping[str, Any], 42 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 43 ) -> tuple[Sequence[Any], Mapping[str, Any]]: ... 44 45 46class InputAdapter: 47 """A class that adapts the PyTorch model inputs to exported ONNX model inputs format.""" 48 49 def __init__(self, steps: list[InputAdaptStep] | None = None): 50 self._steps = steps or [] 51 52 def append_step(self, step: InputAdaptStep) -> None: 53 """Appends a step to the input adapt steps. 54 55 Args: 56 step: The step to append. 57 """ 58 self._steps.append(step) 59 60 def apply( 61 self, 62 *model_args, 63 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 64 **model_kwargs, 65 ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]: 66 """Converts the PyTorch model inputs to exported ONNX model inputs format. 67 68 Args: 69 model_args: The PyTorch model inputs. 70 model: The PyTorch model. 71 model_kwargs: The PyTorch model keyword inputs. 72 Returns: 73 A sequence of tensors converted from PyTorch model inputs. 74 """ 75 args: Sequence[Any] = model_args 76 kwargs: Mapping[str, Any] = model_kwargs 77 for step in self._steps: 78 args, kwargs = step.apply(args, kwargs, model=model) 79 assert not kwargs 80 return args 81 82 83@runtime_checkable 84class OutputAdaptStep(Protocol): 85 """A protocol that defines a step in the output adapting process. 86 87 The output adapting process is a sequence of steps that are applied to the 88 PyTorch model outputs to transform them into the outputs format produced by the 89 exported ONNX model. Each step takes the PyTorch model outputs as arguments and 90 returns the transformed outputs. 91 92 This serves as a base formalized construct for the transformation done to model 93 output signature by any individual component in the exporter. 94 """ 95 96 def apply( 97 self, 98 model_outputs: Any, 99 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 100 ) -> Any: ... 101 102 103class OutputAdapter: 104 """A class that adapts the PyTorch model outputs to exported ONNX model outputs format.""" 105 106 def __init__(self, steps: list[OutputAdaptStep] | None = None): 107 self._steps = steps or [] 108 109 def append_step(self, step: OutputAdaptStep) -> None: 110 """Appends a step to the output format steps. 111 112 Args: 113 step: The step to append. 114 """ 115 self._steps.append(step) 116 117 def apply( 118 self, 119 model_outputs: Any, 120 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 121 ) -> Sequence[torch.Tensor | int | float | bool | str]: 122 """Converts the PyTorch model outputs to exported ONNX model outputs format. 123 124 Args: 125 model_outputs: The PyTorch model outputs. 126 model: The PyTorch model. 127 128 Returns: 129 PyTorch model outputs in exported ONNX model outputs format. 130 """ 131 for step in self._steps: 132 model_outputs = step.apply(model_outputs, model=model) 133 return model_outputs 134 135 136# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276 137 138 139def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec: 140 _type = list if spec.type == tuple else spec.type 141 return pytree.TreeSpec( 142 _type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs)) 143 ) 144 145 146def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec: 147 if spec.type == list and spec.num_children == 1: 148 return spec.children_specs[0] 149 return spec 150 151 152def _assert_identical_pytree_spec( 153 spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str 154) -> None: 155 """Assert the two `TreeSpec` objects are identical. 156 157 Args: 158 spec1: The first `TreeSpec` object. 159 spec2: The second `TreeSpec` object. 160 error_message: The error message to raise if the two `TreeSpec` objects are not 161 identical. 162 163 Raises: 164 ValueError: If the two `TreeSpec` objects are not identical. 165 """ 166 # TODO(bowbao): Turn this check into diagnostic. Consider warning instead of error. 167 pass_if_any_checks: Sequence[Callable[[], bool]] = [ 168 lambda: spec1 == spec2, 169 # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'. 170 lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2), 171 # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list. 172 lambda: _open_top_level_list_if_single_element(spec1) == spec2, 173 lambda: spec1 == _open_top_level_list_if_single_element(spec2), 174 ] 175 176 if not any(check() for check in pass_if_any_checks): 177 raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.") 178 179 180class BindInputStep(InputAdaptStep): 181 """Bind the input arguments to the model signature.""" 182 183 def __init__(self, model_signature: inspect.Signature): 184 self._model_signature = model_signature 185 186 def apply( 187 self, 188 model_args: Sequence[Any], 189 model_kwargs: Mapping[str, Any], 190 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 191 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 192 """Bind the input arguments to the model signature. 193 194 We hope the input kwargs will be mapped to bound.args after binding. 195 If not, we will raise an error. 196 197 Args: 198 model_args: The model args. 199 model_kwargs: The model kwargs. 200 model: The PyTorch model. 201 202 Returns: 203 A tuple of the model args and kwargs. args is always empty. 204 205 Raises: 206 ValueError: If there are keyword-only arguments left after binding args and 207 kwargs to model signature. 208 """ 209 bound = self._model_signature.bind(*model_args, **model_kwargs) 210 bound.apply_defaults() 211 212 # keyword-only arguments are not handled. 213 # bound.kwargs only contains keyword-only arguments after calling 214 # bind & apply_defaults, so we raise if it's not empty. 215 if bound.kwargs: 216 raise ValueError("Keyword-only arguments are not supported.") 217 return (), bound.arguments 218 219 220class MergeKwargsIntoArgsInputStep(InputAdaptStep): 221 """Merge the input kwargs into the input args.""" 222 223 def apply( 224 self, 225 model_args: Sequence[Any], 226 model_kwargs: Mapping[str, Any], 227 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 228 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 229 """Merge the input kwargs into the input args. 230 231 Args: 232 model_args: The model args. 233 model_kwargs: The model kwargs. 234 model: The PyTorch model. 235 236 Returns: 237 A tuple of the model args and kwargs. kwargs is always empty. 238 """ 239 return tuple(model_args) + tuple(model_kwargs.values()), {} 240 241 242class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep): 243 """Append parameters and buffers to model's positional argument list.""" 244 245 def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None: 246 self.inputs = inputs 247 248 def apply( 249 self, 250 model_args: Sequence[Any], 251 model_kwargs: Mapping[str, Any], 252 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 253 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 254 """Append model's parameters and buffers into its input. 255 256 Args: 257 model_args: The model args. 258 model_kwargs: The model kwargs. 259 model: The PyTorch model. 260 261 Returns: 262 A tuple of the model args + appended inputs and kwargs. 263 """ 264 return (*model_args, *self.inputs), model_kwargs 265 266 267class ConvertComplexToRealRepresentationInputStep(InputAdaptStep): 268 """Convert complex dtype tensors to real representation tensors. 269 270 ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors 271 to real representation tensors (i.e., float dtype tensors with an extra dimension 272 representing the real and imaginary parts of the complex number). 273 274 """ 275 276 def apply( 277 self, 278 model_args: Sequence[Any], 279 model_kwargs: Mapping[str, Any], 280 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 281 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 282 """Convert complex tensors to float tensors. 283 284 Args: 285 model_args: The model args. 286 model_kwargs: The model kwargs. 287 model: The PyTorch model. 288 289 Returns: 290 A tuple of the model args and kwargs. 291 """ 292 return ( 293 tuple( 294 torch.view_as_real(arg.resolve_conj()) 295 if isinstance(arg, torch.Tensor) and arg.is_complex() 296 else arg 297 for arg in model_args 298 ), 299 model_kwargs, 300 ) 301 302 303class RemoveNoneInputStep(InputAdaptStep): 304 """Remove `None` from arguments. 305 306 This adapt step assumes ``model_kwargs`` is empty. It also assumes ``model_args`` 307 is flattened, i.e. it does not check `None` inside nested collections. 308 """ 309 310 def apply( 311 self, 312 model_args: Sequence[Any], 313 model_kwargs: Mapping[str, Any], 314 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 315 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 316 """Remove `None` from arguments. 317 318 Args: 319 model_args: The model args. 320 model_kwargs: The model kwargs. 321 model: The PyTorch model. 322 323 Returns: 324 A tuple of the model args and kwargs. 325 326 Raises: 327 ValueError: If `model_kwargs` is not empty. 328 """ 329 assert not model_kwargs 330 return tuple(arg for arg in model_args if arg is not None), {} 331 332 333class RemoveNonTensorInputStep(InputAdaptStep): 334 """Remove the non-tensor input arguments. 335 336 Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). 337 338 Specifically, it does put the input into graph with an empty node, but consumed by no ones. 339 The concrete value is embedded into the graph as a constant arg of a target node. Meta 340 suggests in this case that one should rewrite the model code to make it tensor if the 341 input value is supposed to change at runtime. We might need to further investigate 342 the feasibility of that suggestion. 343 344 For example, 345 346 def func(x, b=1.0): 347 y = x + b 348 z = y.relu() 349 return (y, z) 350 351 x = torch.randn(1, 1, 2, dtype=torch.float32) 352 gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") 353 354 # class GraphModule(torch.nn.Module): 355 # def forward(self, x, b): 356 # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) 357 # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b 358 # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None 359 360 # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() 361 # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) 362 # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) 363 364 Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as 365 it's ignored in ONNX graph. Thus, we delete the useless input here. 366 367 """ 368 369 def apply( 370 self, 371 model_args: Sequence[Any], 372 model_kwargs: Mapping[str, Any], 373 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 374 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 375 """Remove Constant from arguments. 376 377 Args: 378 model_args: The model args. 379 model_kwargs: The model kwargs. 380 model: The PyTorch model. 381 382 Returns: 383 A tuple of the model args and kwargs. 384 385 Raises: 386 ValueError: If `model_kwargs` is not empty. 387 """ 388 assert not model_kwargs 389 return ( 390 tuple( 391 arg 392 for arg in model_args 393 if not isinstance(arg, (int, float, bool, str)) 394 ), 395 {}, 396 ) 397 398 399class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep): 400 """Flatten nested collection types and return a flat list of elements. 401 402 ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, 403 etc). 404 405 This class stores the `SpecTree` output produced when `adapt` was called the first 406 time. It then validates the `SpecTree` output produced from later `adapt` calls. 407 """ 408 409 _spec: pytree.TreeSpec | None = None 410 411 def apply( 412 self, 413 model_args: Sequence[Any], 414 model_kwargs: Mapping[str, Any], 415 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 416 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 417 """Flatten the model args and kwargs and validate the `SpecTree` output. 418 419 Args: 420 model_args: The model args. 421 model_kwargs: The model kwargs. 422 model: The PyTorch model. 423 424 Returns: 425 A tuple of the flattened model args and kwargs. The kwargs is empty, because 426 they are flattened and merged into the args. 427 428 Raises: 429 ValueError: If the `SpecTree` output produced from the current `model_outputs` 430 is not identical to the `SpecTree` output produced from the first 431 `model_outputs` that was passed to this method. 432 """ 433 flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs)) 434 if self._spec is None: 435 self._spec = spec 436 else: 437 _assert_identical_pytree_spec( 438 self._spec, 439 spec, 440 error_message="Model inputs incompatible with the format that was exported. ", 441 ) 442 return flattened_args, {} 443 444 445class FlattenOutputStep(OutputAdaptStep): 446 """Flatten nested collection types and return a flat list of elements. 447 448 ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, 449 etc). 450 451 NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such 452 that `SpecTree` can be validate for new model outputs. However, this is not possible 453 currently because we never have access to real PyTorch model outputs during export. 454 Only traced outputs may be available, but they are not an accurate reflection of the 455 original PyTorch model outputs format as they are typically in their own unique format, 456 depending on the tracing strategy. 457 """ 458 459 def apply( 460 self, 461 model_outputs: Any, 462 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 463 ) -> Sequence[Any]: 464 """Flatten the model outputs. 465 466 Args: 467 model_outputs: The model outputs to flatten. 468 model: The PyTorch model. 469 470 Returns: 471 A tuple of the flattened model outputs. 472 """ 473 return pytree.tree_leaves(model_outputs) 474 475 476class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep): 477 """Convert complex dtype tensors to real representation tensors. 478 479 ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors 480 to real representation tensors (i.e., float dtype tensors with an extra dimension 481 representing the real and imaginary parts of the complex number). 482 483 """ 484 485 def apply( 486 self, 487 model_outputs: Any, 488 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 489 ) -> Any: 490 """Convert float tensors to complex tensors. 491 492 Args: 493 model_output: The model output. 494 model: The PyTorch model. 495 496 Returns: 497 A tuple of the model output. 498 """ 499 return [ 500 torch.view_as_real(output.resolve_conj()) 501 if isinstance(output, torch.Tensor) and torch.is_complex(output) 502 else output 503 for output in model_outputs 504 ] 505 506 507class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep): 508 """Same as ``FlattenOutputStep``, with additional `TreeSpec` validation. 509 510 This class stores the `SpecTree` output produced when `adapt` was called the first 511 time. It then validates the `SpecTree` output produced from later `adapt` calls. 512 """ 513 514 _spec: pytree.TreeSpec | None = None 515 516 def apply( 517 self, 518 model_outputs: Any, 519 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 520 ) -> Sequence[Any]: 521 """Flatten the model outputs and validate the `SpecTree` output. 522 523 Args: 524 model_outputs: The model outputs to flatten. 525 model: The PyTorch model. 526 527 Returns: 528 flattened_outputs: The flattened model outputs. 529 530 Raises: 531 ValueError: If the `SpecTree` output produced from the current `model_outputs` 532 is not identical to the `SpecTree` output produced from the first 533 `model_outputs` that was passed to this method. 534 """ 535 flattened_outputs, spec = pytree.tree_flatten(model_outputs) 536 if self._spec is None: 537 self._spec = spec 538 else: 539 _assert_identical_pytree_spec( 540 self._spec, 541 spec, 542 error_message="Model outputs incompatible with the format that was exported. ", 543 ) 544 return flattened_outputs 545 546 547class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep): 548 """Prepend model parameters, buffers and constants to the user input. 549 550 :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they 551 must be added to the user input before the model is executed. 552 553 Args: 554 model: The PyTorch model with embedded parameters and buffers. 555 """ 556 557 def apply( 558 self, 559 model_args: Sequence[Any], 560 model_kwargs: Mapping[str, Any], 561 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 562 ) -> tuple[Sequence[Any], Mapping[str, Any]]: 563 """Convert complex tensors to float tensors. 564 565 Args: 566 model_args: The model args. 567 model_kwargs: The model kwargs. 568 model: The PyTorch model. 569 570 Returns: 571 A tuple of the model args and kwargs. 572 """ 573 ordered_params = tuple( 574 model.state_dict[name] # type: ignore[union-attr,index] 575 for name in model.graph_signature.parameters # type: ignore[union-attr] 576 ) 577 non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[union-attr] 578 ordered_buffers = [] 579 for name in model.graph_signature.buffers: # type: ignore[union-attr] 580 if name in non_persistent_buffers: 581 ordered_buffers.append(model.constants[name]) # type: ignore[union-attr] 582 else: 583 ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index] 584 ordered_constant_tensors = tuple( 585 model.constants[fqn] # type: ignore[union-attr,index] 586 for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr] 587 ) 588 589 # NOTE: calling convention is first params, then buffers, then args as user supplied them. 590 # See: torch/_functorch/aot_autograd.py#L1034 591 updated_args = ( 592 *ordered_params, 593 *ordered_buffers, 594 *ordered_constant_tensors, 595 *model_args, 596 ) 597 if model_kwargs: 598 return MergeKwargsIntoArgsInputStep().apply( 599 updated_args, model_kwargs, model=model 600 ) 601 return updated_args, {} 602 603 604class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): 605 """Prepend model's mutated buffers to the user output. 606 607 :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they 608 must be added to the user output after the model is executed. 609 610 Args: 611 model: The PyTorch model with mutated buffers. 612 """ 613 614 def apply( 615 self, 616 model_outputs: Any, 617 model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, 618 ) -> Sequence[Any]: 619 """Flatten the model outputs and validate the `SpecTree` output. 620 621 Args: 622 model_outputs: The model outputs to flatten. 623 model: The PyTorch model. 624 625 Returns: 626 flattened_outputs: The flattened model outputs. 627 """ 628 629 assert isinstance( 630 model, torch_export.ExportedProgram 631 ), "'model' must be torch_export.ExportedProgram" 632 ordered_buffers = tuple( 633 model.state_dict[name] 634 if name in model.state_dict 635 else model.constants[name] 636 for name in model.graph_signature.buffers_to_mutate.values() 637 ) 638 639 # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. 640 updated_outputs = (*ordered_buffers, *model_outputs) 641 return updated_outputs 642