# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""Base optimizer."""
import functools
import warnings
from collections import defaultdict, OrderedDict
from copy import deepcopy
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    overload,
    Set,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import ParamSpec, Self, TypeAlias

import torch
import torch.utils.hooks as hooks
from torch._utils import is_compiling
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
    _group_tensors_by_device_and_dtype,
    Indices,
    TensorListList,
)
from torch.utils.hooks import RemovableHandle


Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
DeviceDict = Dict[Optional[torch.device], torch.Tensor]


GlobalOptimizerPreHook: TypeAlias = Callable[
    ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]
]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

__all__ = [
    "Optimizer",
    "register_optimizer_step_pre_hook",
    "register_optimizer_step_post_hook",
]
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]


class _RequiredParameter:
    """Singleton class representing a required parameter for an Optimizer."""

    def __repr__(self) -> str:
        return "<required parameter>"


required = _RequiredParameter()


def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        import torch._dynamo

        prev_grad = torch.is_grad_enabled()
        try:
            # Note on graph break below:
            # we need to graph break to ensure that aot respects the no_grad annotation.
            # This is important for perf because without this, functionalization will generate an epilogue
            # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result,
            # inductor will allocate for every parameter in the model, which is horrible.
            # With this, aot correctly sees that this is an inference graph, and functionalization will generate
            # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that
            # step is in place and is able to avoid the extra allocation.
            # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter
            # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
            # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
            # see https://github.com/pytorch/pytorch/issues/104053
            torch.set_grad_enabled(self.defaults["differentiable"])
            torch._dynamo.graph_break()
            ret = func(self, *args, **kwargs)
        finally:
            torch._dynamo.graph_break()
            torch.set_grad_enabled(prev_grad)
        return ret

    functools.update_wrapper(_use_grad, func)
    return _use_grad


def _get_value(x):
    # item is significantly faster than a cpu tensor in eager mode
    if not torch.jit.is_scripting() and is_compiling():
        return x
    else:
        return x.item() if isinstance(x, torch.Tensor) else x


def _stack_if_compiling(x):
    if not torch.jit.is_scripting() and is_compiling():
        return torch.stack(x)
    else:
        return x


def _disable_dynamo_if_unsupported(single_tensor_fn=None):
    # workaround for torchscript BC
    # it requires all called functions to be in the
    # global environment at the site at which the
    # maybe_fallback closure is created
    if single_tensor_fn:
        globals()[single_tensor_fn.__name__] = single_tensor_fn

    def wrapper(func):
        import inspect

        disabled_func = torch._disable_dynamo(func)
        ps = inspect.signature(func).parameters
        has_state_steps = True
        try:
            state_steps_ind = list(ps.keys()).index("state_steps")
        except ValueError:
            has_state_steps = False

        # Today, there are cases where we stack state steps
        # and pass them as the value arg of foreach ops.
        # Having state steps on cuda as the value arg is not supported in eager,
        # but this only occurs in the rare case that the user explicitly deletes
        # the capturable flag. If capturable=True, this is not a problem.
        @functools.wraps(func)
        def maybe_fallback(*args, **kwargs):
            if is_compiling() and (
                not kwargs.get("capturable", False)
                and has_state_steps
                and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
                or (
                    "state_steps" in kwargs
                    and kwargs["state_steps"]
                    and kwargs["state_steps"][0].is_cuda
                )
            ):
                return disabled_func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return maybe_fallback

    return wrapper
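

# Illustrative sketch (not part of this module): `_disable_dynamo_if_unsupported`
# is intended to wrap the functional, multi-tensor optimizer entry points that live
# in the individual optimizer modules, passing their single-tensor implementation
# so torchscript can still resolve it. The names below are placeholders, roughly:
#
#     @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_sgd)
#     def sgd(params, grads, *, capturable=False, **kwargs):
#         ...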


# For any optimizer with a faster implementation, we attempt to default to the
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
# that the tensors' dtypes must be floating point. Neither alternative supports
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
def _default_to_fused_or_foreach(
    params: List[torch.Tensor], differentiable: bool, use_fused: bool = False
) -> Tuple[bool, bool]:
    if torch.jit.is_scripting() or differentiable:
        return False, False

    fused_supported_devices = _get_fused_kernels_supported_devices()
    foreach_supported_devices = _get_foreach_kernels_supported_devices()
    fused = use_fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in fused_supported_devices
            and torch.is_floating_point(p)
        )
        for p in params
    )
    foreach = not fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in foreach_supported_devices
        )
        for p in params
    )
    return fused, foreach
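

# Illustrative sketch (an assumption about call sites, which live in the individual
# optimizer modules): when the user leaves ``foreach``/``fused`` as ``None``,
# implementations consult the helper above, roughly:
#
#     if fused is None and foreach is None:
#         fused, foreach = _default_to_fused_or_foreach(
#             params, differentiable=differentiable, use_fused=False
#         )
#
# with ``use_fused=False`` keeping the fused path opt-in while it bakes in.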


def _device_dtype_check_for_fused(
    p: torch.Tensor, cuda_unsupported: bool = False
) -> None:
    fused_supported_devices = _get_fused_kernels_supported_devices()
    if cuda_unsupported:
        fused_supported_devices.remove("cuda")
    if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
        raise RuntimeError(
            "`fused=True` requires all the params to be floating point Tensors of "
            f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
        )


def _view_as_real(params, *state_and_grads):
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(params[i])
            for s in state_and_grads:
                s[i] = torch.view_as_real(s[i])


def _get_scalar_dtype(is_fused=None):
    if is_fused:
        return torch.float32
    return (
        torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
    )


def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
    r"""Return the device type list that supports capturable optimizer."""
    capturable_supported_devices = ["cuda", "xpu", "hpu"]
    if not torch.jit.is_scripting():
        capturable_supported_devices.append(torch._C._get_privateuse1_backend_name())
    if supports_xla:
        capturable_supported_devices.append("xla")
    return capturable_supported_devices


# Common doc strings among optimizers
_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
        is used. If unspecified by the user (so foreach is None), we will try to use
        foreach over the for-loop implementation on CUDA, since it is usually
        significantly more performant. Note that the foreach implementation uses
        ~ sizeof(params) more peak memory than the for-loop version due to the intermediates
        being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
        parameters through the optimizer at a time or switch this flag to False (default: None)"""

_fused_doc = r"""fused (bool, optional): whether the fused implementation is used.
        Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
        are supported. (default: None)

    .. note:: The foreach and fused implementations are typically faster than the for-loop,
              single-tensor implementation, with fused being theoretically fastest with both
              vertical and horizontal fusion. As such, if the user has not specified either
              flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
              implementation when the tensors are all on CUDA. Why not fused? Since the fused
              implementation is relatively new, we want to give it sufficient bake-in time.
              To specify fused, pass True for fused. To force running the for-loop
              implementation, pass False for either foreach or fused. """

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
        capture in a CUDA graph. Passing True can impair ungraphed performance,
        so if you don't intend to graph capture this instance, leave it False
        (default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
        occur through the optimizer step in training. Otherwise, the step()
        function runs in a torch.no_grad() context. Setting to True can impair
        performance, so leave it False if you don't intend to run autograd
        through this instance (default: False)"""

_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
        params, instead of minimizing (default: False)"""


def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
    r"""Register a pre hook common to all optimizers.

    The hook should have the following signature::

        hook(optimizer, args, kwargs) -> None or modified args and kwargs

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_pre_hooks)
    _global_optimizer_pre_hooks[handle.id] = hook
    return handle


def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
    r"""Register a post hook common to all optimizers.

    The hook should have the following signature::

        hook(optimizer, args, kwargs) -> None

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_post_hooks)
    _global_optimizer_post_hooks[handle.id] = hook
    return handle
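

# Usage sketch (illustrative only; the hook body is made up): a global post hook
# receives ``(optimizer, args, kwargs)`` after every ``step()`` of every optimizer,
# per the signatures documented above.
#
#     def _log_every_step(optimizer, args, kwargs):
#         print(f"{type(optimizer).__name__} finished a step")
#
#     handle = register_optimizer_step_post_hook(_log_every_step)
#     ...  # training; the hook fires on each optimizer.step()
#     handle.remove()  # detach the hook when done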


ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]

_P = ParamSpec("_P")
R = TypeVar("R")
T = TypeVar("T")


class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]]  # type: ignore[misc]
    OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None]  # type: ignore[misc]

    _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
    _optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
    _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'

    def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:  # noqa: D107
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        self._optimizer_state_dict_pre_hooks = OrderedDict()
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()

        self._patch_step_function()

        if isinstance(params, torch.Tensor):
            raise TypeError(
                "params argument given to the optimizer should be "
                "an iterable of Tensors or dicts, but got " + torch.typename(params)
            )

        self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: List[Dict[str, Any]] = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for param_group in param_groups:
            self.add_param_group(cast(dict, param_group))

        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
        # which I don't think exists
        # https://github.com/pytorch/pytorch/issues/72948
        self._warned_capturable_if_run_uncaptured = True

    def __getstate__(self) -> Dict[str, Any]:  # noqa: D105
        return {
            "defaults": self.defaults,
            "state": self.state,
            "param_groups": self.param_groups,
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:  # noqa: D105
        self.__dict__.update(state)
        if "_optimizer_step_pre_hooks" not in self.__dict__:
            self._optimizer_step_pre_hooks = OrderedDict()
        if "_optimizer_step_post_hooks" not in self.__dict__:
            self._optimizer_step_post_hooks = OrderedDict()
        if "_optimizer_state_dict_pre_hooks" not in self.__dict__:
            self._optimizer_state_dict_pre_hooks = OrderedDict()
        if "_optimizer_state_dict_post_hooks" not in self.__dict__:
            self._optimizer_state_dict_post_hooks = OrderedDict()
        if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__:
            self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        if "_optimizer_load_state_dict_post_hooks" not in self.__dict__:
            self._optimizer_load_state_dict_post_hooks = OrderedDict()
        self._patch_step_function()  # To support multiprocessing pickle/unpickle
        self.defaults.setdefault("differentiable", False)

    def __repr__(self) -> str:  # noqa: D105
        format_string = self.__class__.__name__ + " ("
        for i, group in enumerate(self.param_groups):
            format_string += "\n"
            format_string += f"Parameter Group {i}\n"
            for key in sorted(group.keys()):
                if key != "params":
                    format_string += f"    {key}: {group[key]}\n"
        format_string += ")"
        return format_string

    # Currently needed by Adam and AdamW
    def _cuda_graph_capture_health_check(self) -> None:
        # Note [torch.compile x capturable]
        # If we are compiling, we try to take the capturable path automatically by
        # setting the flag to True during tracing. Due to this, we skip all the checks
        # normally required for determining whether we can use CUDA graphs and
        # shunt the responsibility to torch.inductor. This saves time during tracing
        # since the checks are slow, without sacrificing UX, since inductor will warn
        # later if CUDA graphs cannot be enabled, e.g.,
        # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
        # Thus, when compiling, inductor will determine if cudagraphs
        # can be enabled based on whether there is input mutation or CPU tensors.
        if (
            not is_compiling()
            and torch.backends.cuda.is_built()
            and torch.cuda.is_available()
        ):
            capturing = torch.cuda.is_current_stream_capturing()

            if capturing and not all(
                group["capturable"] for group in self.param_groups
            ):
                raise RuntimeError(
                    "Attempting CUDA graph capture of step() for an instance of "
                    + self.__class__.__name__
                    + " but param_groups' capturable is False."
                )

            if (
                (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
                and all(group["capturable"] for group in self.param_groups)
                and (not capturing)
            ):
                warnings.warn(
                    "This instance was constructed with capturable=True or some of the param_groups came with capturable=True, "
                    "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
                    "instance, capturable=True can impair performance, and you should set capturable=False."
                )
                self._warned_capturable_if_run_uncaptured = True

    def _optimizer_step_code(self) -> None:
        """Entry point for `torch.profiler`.

        When python tracing is enabled the profiler will hook into this
        function at the CPython level to inspect the optimizer's parameters and
        param groups. It is called after `step()` since many optimizers
        lazily initialize state.

        This is a workaround due to lack of a proper step hook on the optimizer,
        and will be removed once such a hook exists.
        """

    @staticmethod
    def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:  # noqa: D102
        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
            self, *_ = args
            self = cast(Optimizer, self)
            profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
            with torch.autograd.profiler.record_function(profile_name):
                # call optimizer step pre hooks
                for pre_hook in chain(
                    _global_optimizer_pre_hooks.values(),
                    self._optimizer_step_pre_hooks.values(),
                ):
                    result = pre_hook(self, args, kwargs)
                    if result is not None:
                        if isinstance(result, tuple) and len(result) == 2:
                            args, kwargs = result  # type: ignore[assignment]
                        else:
                            raise RuntimeError(
                                f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                            )

                out = func(*args, **kwargs)
                self._optimizer_step_code()

                # call optimizer step post hooks
                for post_hook in chain(
                    self._optimizer_step_post_hooks.values(),
                    _global_optimizer_post_hooks.values(),
                ):
                    post_hook(self, args, kwargs)

                return out

        return wrapper

    @staticmethod
    def _group_tensors_by_device_and_dtype(
        tensorlistlist: TensorListList,
        with_indices: bool = False,
    ) -> Union[
        Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
        Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
    ]:
        """Group a list of lists of tensors by device and dtype.

        Skips this step if we are compiling since this will occur during inductor lowering.
        """
        if is_compiling():
            return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
        else:
            return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]

    def _patch_step_function(self) -> None:
        self._zero_grad_profile_name = (
            f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
        )
        hooked = getattr(self.__class__.step, "hooked", None)
        if not hooked:
            self.__class__.step = self.profile_hook_step(self.__class__.step)  # type: ignore[assignment]
            self.__class__.step.hooked = True  # type: ignore[attr-defined]

    def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
        r"""Register an optimizer step pre hook which will be called before optimizer step.

        It should have the following signature::

            hook(optimizer, args, kwargs) -> None or modified args and kwargs

        The ``optimizer`` argument is the optimizer instance being used. If
        args and kwargs are modified by the pre-hook, then the transformed
        values are returned as a tuple containing the new_args and new_kwargs.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
        self._optimizer_step_pre_hooks[handle.id] = hook
        return handle

    def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
        r"""Register an optimizer step post hook which will be called after optimizer step.

        It should have the following signature::

            hook(optimizer, args, kwargs) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
        self._optimizer_step_post_hooks[handle.id] = hook
        return handle
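
    # Usage sketch (illustrative; the hooks below are examples, not library code):
    # per-instance step hooks follow the signatures documented above.
    #
    #     opt = torch.optim.SGD(model.parameters(), lr=0.1)
    #
    #     def clip_before_step(optimizer, args, kwargs):
    #         torch.nn.utils.clip_grad_norm_(
    #             (p for g in optimizer.param_groups for p in g["params"]), 1.0
    #         )
    #
    #     pre_handle = opt.register_step_pre_hook(clip_before_step)
    #     post_handle = opt.register_step_post_hook(lambda opt, a, kw: None)
    #     ...
    #     pre_handle.remove()
    #     post_handle.remove()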

    def register_state_dict_pre_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:  # noqa: D101
        r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called.

        It should have the following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.
        The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
        The registered hook can be used to perform pre-processing before the ``state_dict``
        call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
        self._optimizer_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle

    def register_state_dict_post_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called.

        It should have the following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The hook will be called with arguments ``self`` and ``state_dict`` after generating
        a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
        return a new one. The registered hook can be used to perform post-processing
        on the ``state_dict`` before it is returned.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
        self._optimizer_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
        return handle
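
    # Usage sketch (illustrative; the key name "my_metadata" is made up): a state
    # dict post-hook can attach extra entries to the returned dict, following the
    # ``hook(optimizer, state_dict) -> state_dict or None`` signature documented above.
    #
    #     def attach_metadata(optimizer, state_dict):
    #         state_dict["my_metadata"] = {"num_groups": len(optimizer.param_groups)}
    #         return state_dict
    #
    #     opt.register_state_dict_post_hook(attach_metadata)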

    @torch._disable_dynamo
    def state_dict(self) -> StateDict:
        r"""Return the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * ``state``: a Dict holding current optimization state. Its content
            differs between optimizer classes, but some common characteristics
            hold. For example, state is saved per parameter, and the parameter
            itself is NOT saved. ``state`` is a Dictionary mapping parameter ids
            to a Dict with state corresponding to each parameter.
        * ``param_groups``: a List containing all parameter groups where each
            parameter group is a Dict. Each parameter group contains metadata
            specific to the optimizer, such as learning rate and weight decay,
            as well as a List of parameter IDs of the parameters in the group.

        NOTE: The parameter IDs may look like indices but they are just IDs
        associating state with param_group. When loading from a state_dict,
        the optimizer will zip the param_group ``params`` (int IDs) and the
        optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
        match state WITHOUT additional verification.

        A returned state dict might look something like:

        .. code-block:: text

            {
                'state': {
                    0: {'momentum_buffer': tensor(...), ...},
                    1: {'momentum_buffer': tensor(...), ...},
                    2: {'momentum_buffer': tensor(...), ...},
                    3: {'momentum_buffer': tensor(...), ...}
                },
                'param_groups': [
                    {
                        'lr': 0.01,
                        'weight_decay': 0,
                        ...
                        'params': [0]
                    },
                    {
                        'lr': 0.001,
                        'weight_decay': 0.5,
                        ...
                        'params': [1, 2, 3]
                    }
                ]
            }

        """
        for pre_hook in self._optimizer_state_dict_pre_hooks.values():
            pre_hook(self)

        # Save order indices instead of Tensors
        param_mappings: Dict[int, int] = {}
        start_index = 0

        def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update(
                {
                    id(p): i
                    for i, p in enumerate(group["params"], start_index)
                    if id(p) not in param_mappings
                }
            )
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {
            (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }

        state_dict = {
            "state": packed_state,
            "param_groups": param_groups,
        }

        for post_hook in self._optimizer_state_dict_post_hooks.values():
            hook_result = post_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result
        return state_dict

    @staticmethod
    def _process_value_according_to_param_policy(
        param: torch.Tensor,
        value: torch.Tensor,
        param_id: int,
        param_groups: List[Dict[Any, Any]],
        key: Hashable = None,
    ) -> torch.Tensor:
        # Floating-point types are a bit special here. They are the only ones
        # that are assumed to always match the type of params.
        # Make sure state['step'] is not cast: https://github.com/pytorch/pytorch/issues/74424
        # UNLESS fused or capturable, see note [special device hosting for step]
        fused = False
        capturable = False
        assert param_groups is not None
        for pg in param_groups:
            if param_id in pg["params"]:
                fused = pg["fused"] if "fused" in pg else False
                capturable = pg["capturable"] if "capturable" in pg else False
                break
        if key == "step":
            if capturable or fused:
                return value.to(dtype=torch.float32, device=param.device)
            else:
                return value
        else:
            if param.is_floating_point():
                return value.to(dtype=param.dtype, device=param.device)
            else:
                return value.to(device=param.device)

    def register_load_state_dict_pre_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:  # noqa: D205 D400
        r"""Register a load_state_dict pre-hook which will be called before
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The ``optimizer`` argument is the optimizer instance being used and the
        ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
        passed in to ``load_state_dict``. The hook may modify the state_dict inplace
        or optionally return a new one. If a state_dict is returned, it will be
        loaded into the optimizer in place of the one passed in.

        The hook will be called with arguments ``self`` and ``state_dict`` before
        calling ``load_state_dict`` on ``self``. The registered hook can be used to
        perform pre-processing before the ``load_state_dict`` call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
        self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle

    def register_load_state_dict_post_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:  # noqa: D205 D400
        r"""Register a load_state_dict post-hook which will be called after
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        The hook will be called with argument ``self`` after calling
        ``load_state_dict`` on ``self``. The registered hook can be used to
        perform post-processing after ``load_state_dict`` has loaded the
        ``state_dict``.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
        self._optimizer_load_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

    @torch._disable_dynamo
    def load_state_dict(self, state_dict: StateDict) -> None:
        r"""Load the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # shallow copy, to be consistent with module API
        state_dict = state_dict.copy()

        for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
            hook_result = pre_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result

        # Validate the state_dict
        groups = self.param_groups

        # Deepcopy as we write into saved_groups later to update state
        saved_groups = deepcopy(state_dict["param_groups"])

        if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group "
                "that doesn't match the size of optimizer's group"
            )

        # Update the state
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        )

        def _cast(param, value, param_id=None, param_groups=None, key=None):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                return Optimizer._process_value_according_to_param_policy(
                    param, value, param_id, param_groups, key
                )
            elif isinstance(value, dict):
                return {
                    k: _cast(
                        param, v, param_id=param_id, param_groups=param_groups, key=k
                    )
                    for k, v in value.items()
                }
            elif isinstance(value, Iterable):
                return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)  # type: ignore[call-arg]
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = _cast(
                    param, v, param_id=k, param_groups=state_dict["param_groups"]
                )
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(
            group: Dict[str, Any], new_group: Dict[str, Any]
        ) -> Dict[str, Any]:
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({"state": state, "param_groups": param_groups})

        for post_hook in self._optimizer_load_state_dict_post_hooks.values():
            post_hook(self)
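
    # Usage sketch (illustrative; the file name is arbitrary): checkpointing an
    # optimizer is a state_dict()/load_state_dict() round trip.
    #
    #     torch.save(opt.state_dict(), "opt.pt")
    #     ...
    #     opt = torch.optim.SGD(model.parameters(), lr=0.1)  # rebuild first
    #     opt.load_state_dict(torch.load("opt.pt"))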

    @torch._disable_dynamo
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Reset the gradients of all optimized :class:`torch.Tensor` s.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        foreach = self.defaults.get("foreach", False) or self.defaults.get(
            "fused", False
        )

        if not hasattr(self, "_zero_grad_profile_name"):
            self._patch_step_function()

        per_device_and_dtype_grads: Optional[
            DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]
        ]
        if foreach:
            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        else:
            per_device_and_dtype_grads = None

        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            if p.grad.grad_fn is not None:
                                p.grad.detach_()
                            else:
                                p.grad.requires_grad_(False)
                            if not foreach or p.grad.is_sparse:
                                p.grad.zero_()
                            else:
                                assert per_device_and_dtype_grads is not None
                                per_device_and_dtype_grads[p.grad.device][
                                    p.grad.dtype
                                ].append(p.grad)
            if foreach:
                assert per_device_and_dtype_grads is not None
                for per_dtype_grads in per_device_and_dtype_grads.values():
                    for grads in per_dtype_grads.values():
                        torch._foreach_zero_(grads)

    @overload
    def step(self, closure: None = ...) -> None:
        ...

    @overload
    def step(self, closure: Callable[[], float]) -> float:
        ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        r"""Perform a single optimization step to update parameters.

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        raise NotImplementedError
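
    # Usage sketch (illustrative; ``model``, ``loss_fn``, and ``data`` are assumed
    # to exist): the canonical interplay of ``zero_grad`` and ``step`` in a loop.
    #
    #     opt = torch.optim.SGD(model.parameters(), lr=0.1)
    #     for input, target in data:
    #         opt.zero_grad()          # set_to_none=True by default
    #         loss = loss_fn(model(input), target)
    #         loss.backward()
    #         opt.step()               # concrete subclasses implement step()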

    @torch._disable_dynamo
    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine-tuning a pre-trained network, as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
        """
        if not isinstance(param_group, dict):
            raise TypeError(f"param_group must be a dict, but got {type(param_group)}")

        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            param_group["params"] = [params]
        elif isinstance(params, set):
            raise TypeError(
                "optimizer parameters need to be organized in ordered collections, but "
                "the ordering of tensors in sets will change between runs. Please use a list instead."
            )
        else:
            param_group["params"] = list(params)

        for param in param_group["params"]:
            if not isinstance(param, torch.Tensor):
                raise TypeError(
                    "optimizer can only optimize Tensors, "
                    "but one of the params is " + torch.typename(param)
                )
            if not self.defaults.get("differentiable", None) and not (
                param.is_leaf or param.retains_grad
            ):
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError(
                    f"parameter group didn't specify a value of required optimization parameter {name}"
                )
            else:
                param_group.setdefault(name, default)

        params = param_group["params"]
        if len(params) != len(set(params)):
            warnings.warn(
                "optimizer contains a parameter group with duplicate parameters; "
                "in future, this will cause an error; "
                "see github.com/pytorch/pytorch/issues/40967 for more information",
                stacklevel=3,
            )

        param_set: Set[torch.Tensor] = set()
        for group in self.param_groups:
            param_set.update(set(group["params"]))

        if not param_set.isdisjoint(set(param_group["params"])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)
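

# Usage sketch (illustrative; ``model.head`` is a made-up attribute): unfreezing a
# layer mid-training by handing its parameters to the existing optimizer.
#
#     opt = torch.optim.SGD(model.base.parameters(), lr=1e-2)
#     ...
#     for p in model.head.parameters():
#         p.requires_grad_(True)
#     opt.add_param_group({"params": model.head.parameters(), "lr": 1e-3})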