xref: /aosp_15_r20/external/pytorch/torch/optim/optimizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3"""Base optimizer."""
4import functools
5import warnings
6from collections import defaultdict, OrderedDict
7from copy import deepcopy
8from itertools import chain
9from typing import (
10    Any,
11    Callable,
12    cast,
13    DefaultDict,
14    Dict,
15    Hashable,
16    Iterable,
17    List,
18    Optional,
19    overload,
20    Set,
21    Tuple,
22    TypeVar,
23    Union,
24)
25from typing_extensions import ParamSpec, Self, TypeAlias
26
27import torch
28import torch.utils.hooks as hooks
29from torch._utils import is_compiling
30from torch.utils._foreach_utils import (
31    _get_foreach_kernels_supported_devices,
32    _get_fused_kernels_supported_devices,
33    _group_tensors_by_device_and_dtype,
34    Indices,
35    TensorListList,
36)
37from torch.utils.hooks import RemovableHandle
38
39
40Args: TypeAlias = Tuple[Any, ...]
41Kwargs: TypeAlias = Dict[str, Any]
42StateDict: TypeAlias = Dict[str, Any]
43DeviceDict = Dict[Optional[torch.device], torch.Tensor]
44
45
46GlobalOptimizerPreHook: TypeAlias = Callable[
47    ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]
48]
49GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]
50
51__all__ = [
52    "Optimizer",
53    "register_optimizer_step_pre_hook",
54    "register_optimizer_step_post_hook",
55]
56_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
57_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
58_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
59
60
61class _RequiredParameter:
62    """Singleton class representing a required parameter for an Optimizer."""
63
64    def __repr__(self) -> str:
65        return "<required parameter>"
66
67
68required = _RequiredParameter()
69
70
71def _use_grad_for_differentiable(func):
72    def _use_grad(self, *args, **kwargs):
73        import torch._dynamo
74
75        prev_grad = torch.is_grad_enabled()
76        try:
77            # Note on graph break below:
78            # we need to graph break to ensure that aot respects the no_grad annotation.
79            # This is important for perf because without this, functionalization will generate an epilogue
80            # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result,
81            # inductor will allocate for every parameter in the model, which is horrible.
82            # With this, aot correctly sees that this is an inference graph, and functionalization will generate
83            # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that
84            # step is in place and is able to avoid the extra allocation.
85            # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter
86            # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
87            # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
88            # see https://github.com/pytorch/pytorch/issues/104053
89            torch.set_grad_enabled(self.defaults["differentiable"])
90            torch._dynamo.graph_break()
91            ret = func(self, *args, **kwargs)
92        finally:
93            torch._dynamo.graph_break()
94            torch.set_grad_enabled(prev_grad)
95        return ret
96
97    functools.update_wrapper(_use_grad, func)
98    return _use_grad
99
100
101def _get_value(x):
102    # item is significantly faster than a cpu tensor in eager mode
103    if not torch.jit.is_scripting() and is_compiling():
104        return x
105    else:
106        return x.item() if isinstance(x, torch.Tensor) else x
107
108
109def _stack_if_compiling(x):
110    if not torch.jit.is_scripting() and is_compiling():
111        return torch.stack(x)
112    else:
113        return x
114
115
116def _disable_dynamo_if_unsupported(single_tensor_fn=None):
117    # workaround for torchscript BC
118    # it requires all called functions to be in the
119    # global environment at the site at which the
120    # maybe_fallback closure is created
121    if single_tensor_fn:
122        globals()[single_tensor_fn.__name__] = single_tensor_fn
123
124    def wrapper(func):
125        import inspect
126
127        disabled_func = torch._disable_dynamo(func)
128        ps = inspect.signature(func).parameters
129        has_state_steps = True
130        try:
131            state_steps_ind = list(ps.keys()).index("state_steps")
132        except ValueError:
133            has_state_steps = False
134
135        # Today, there are cases where we stack state steps
136        # and pass them as the value arg of foreach ops.
137        # Having state steps on cuda as the value arg is not supported in eager,
138        # but this only occurs in the rare case that the user explicitly deletes
139        # the capturable flag. If capturable=True, this is not a problem.
140        @functools.wraps(func)
141        def maybe_fallback(*args, **kwargs):
142            if is_compiling() and (
143                not kwargs.get("capturable", False)
144                and has_state_steps
145                and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
146                or (
147                    "state_steps" in kwargs
148                    and kwargs["state_steps"]
149                    and kwargs["state_steps"][0].is_cuda
150                )
151            ):
152                return disabled_func(*args, **kwargs)
153            else:
154                return func(*args, **kwargs)
155
156        return maybe_fallback
157
158    return wrapper
159
160
161# For any optimizer with a faster implementation, we attempt to default to the
162# fastest + stablest whenever possible. For foreach, the requirements are to have
163# native params all on CUDA. For fused, there's currently the additional requirement
164# that the tensors' dtypes must be floating point. Neither alternative supports
165# torch.jit.script nor differentiable, so we fall back to the single tensor
166# implementation in those cases.
167def _default_to_fused_or_foreach(
168    params: List[torch.Tensor], differentiable: bool, use_fused: bool = False
169) -> Tuple[bool, bool]:
170    if torch.jit.is_scripting() or differentiable:
171        return False, False
172
173    fused_supported_devices = _get_fused_kernels_supported_devices()
174    foreach_supported_devices = _get_foreach_kernels_supported_devices()
175    fused = use_fused and all(
176        p is None
177        or (
178            type(p) in _foreach_supported_types
179            and p.device.type in fused_supported_devices
180            and torch.is_floating_point(p)
181        )
182        for p in params
183    )
184    foreach = not fused and all(
185        p is None
186        or (
187            type(p) in _foreach_supported_types
188            and p.device.type in foreach_supported_devices
189        )
190        for p in params
191    )
192    return fused, foreach
193
194
195def _device_dtype_check_for_fused(
196    p: torch.Tensor, cuda_unsupported: bool = False
197) -> None:
198    fused_supported_devices = _get_fused_kernels_supported_devices()
199    if cuda_unsupported:
200        fused_supported_devices.remove("cuda")
201    if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
202        raise RuntimeError(
203            "`fused=True` requires all the params to be floating point Tensors of "
204            f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
205        )
206
207
208def _view_as_real(params, *state_and_grads):
209    for i, p in enumerate(params):
210        if torch.is_complex(p):
211            params[i] = torch.view_as_real(params[i])
212            for s in state_and_grads:
213                s[i] = torch.view_as_real(s[i])
214
215
216def _get_scalar_dtype(is_fused=None):
217    if is_fused:
218        return torch.float32
219    return (
220        torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
221    )
222
223
224def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
225    r"""Return the device type list that supports capturable optimizer."""
226    capturable_supported_devices = ["cuda", "xpu", "hpu"]
227    if not torch.jit.is_scripting():
228        capturable_supported_devices.append(torch._C._get_privateuse1_backend_name())
229    if supports_xla:
230        capturable_supported_devices.append("xla")
231    return capturable_supported_devices
232
233
234# Common doc strings among optimizers
235_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
236            is used. If unspecified by the user (so foreach is None), we will try to use
237            foreach over the for-loop implementation on CUDA, since it is usually
238            significantly more performant. Note that the foreach implementation uses
239            ~ sizeof(params) more peak memory than the for-loop version due to the intermediates
240            being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
241            parameters through the optimizer at a time or switch this flag to False (default: None)"""
242
243_fused_doc = r"""fused (bool, optional): whether the fused implementation is used.
244            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
245            are supported. (default: None)
246
247    .. note:: The foreach and fused implementations are typically faster than the for-loop,
248              single-tensor implementation, with fused being theoretically fastest with both
249              vertical and horizontal fusion. As such, if the user has not specified either
250              flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
251              implementation when the tensors are all on CUDA. Why not fused? Since the fused
252              implementation is relatively new, we want to give it sufficient bake-in time.
253              To specify fused, pass True for fused. To force running the for-loop
254              implementation, pass False for either foreach or fused. """
255
256_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
257            capture in a CUDA graph. Passing True can impair ungraphed performance,
258            so if you don't intend to graph capture this instance, leave it False
259            (default: False)"""
260
261_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
262            occur through the optimizer step in training. Otherwise, the step()
263            function runs in a torch.no_grad() context. Setting to True can impair
264            performance, so leave it False if you don't intend to run autograd
265            through this instance (default: False)"""
266
267_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
268            params, instead of minimizing (default: False)"""
269
270
271def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
272    r"""Register a pre hook common to all optimizers.
273
274    The hook should have the following signature::
275
276        hook(optimizer, args, kwargs) -> None or modified args and kwargs
277
278    Args:
279        hook (Callable): A user defined hook which is registered on all optimizers.
280
281    Returns:
282        :class:`torch.utils.hooks.RemovableHandle`:
283            a handle that can be used to remove the added hook by calling
284            ``handle.remove()``
285    """
286    handle = hooks.RemovableHandle(_global_optimizer_pre_hooks)
287    _global_optimizer_pre_hooks[handle.id] = hook
288    return handle
289
290
291def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
292    r"""Register a post hook common to all optimizers.
293
294    The hook should have the following signature::
295
296        hook(optimizer, args, kwargs) -> None
297
298    Args:
299        hook (Callable): A user defined hook which is registered on all optimizers.
300
301    Returns:
302        :class:`torch.utils.hooks.RemovableHandle`:
303            a handle that can be used to remove the added hook by calling
304            ``handle.remove()``
305    """
306    handle = hooks.RemovableHandle(_global_optimizer_post_hooks)
307    _global_optimizer_post_hooks[handle.id] = hook
308    return handle
309
310
311ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
312
313_P = ParamSpec("_P")
314R = TypeVar("R")
315T = TypeVar("T")
316
317
318class Optimizer:
319    r"""Base class for all optimizers.
320
321    .. warning::
322        Parameters need to be specified as collections that have a deterministic
323        ordering that is consistent between runs. Examples of objects that don't
324        satisfy those properties are sets and iterators over values of dictionaries.
325
326    Args:
327        params (iterable): an iterable of :class:`torch.Tensor` s or
328            :class:`dict` s. Specifies what Tensors should be optimized.
329        defaults: (dict): a dict containing default values of optimization
330            options (used when a parameter group doesn't specify them).
331    """
332
333    OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]]  # type: ignore[misc]
334    OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None]  # type: ignore[misc]
335
336    _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
337    _optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
338    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
339    _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
340    _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
341    _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
342
343    def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:  # noqa: D107
344        torch._C._log_api_usage_once("python.optimizer")
345        self.defaults = defaults
346        self._optimizer_step_pre_hooks = OrderedDict()
347        self._optimizer_step_post_hooks = OrderedDict()
348        self._optimizer_state_dict_pre_hooks = OrderedDict()
349        self._optimizer_state_dict_post_hooks = OrderedDict()
350        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
351        self._optimizer_load_state_dict_post_hooks = OrderedDict()
352
353        self._patch_step_function()
354
355        if isinstance(params, torch.Tensor):
356            raise TypeError(
357                "params argument given to the optimizer should be "
358                "an iterable of Tensors or dicts, but got " + torch.typename(params)
359            )
360
361        self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
362        self.param_groups: List[Dict[str, Any]] = []
363
364        param_groups = list(params)
365        if len(param_groups) == 0:
366            raise ValueError("optimizer got an empty parameter list")
367        if not isinstance(param_groups[0], dict):
368            param_groups = [{"params": param_groups}]
369
370        for param_group in param_groups:
371            self.add_param_group(cast(dict, param_group))
372
373        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
374        # which I don't think exists
375        # https://github.com/pytorch/pytorch/issues/72948
376        self._warned_capturable_if_run_uncaptured = True
377
378    def __getstate__(self) -> Dict[str, Any]:  # noqa: D105
379        return {
380            "defaults": self.defaults,
381            "state": self.state,
382            "param_groups": self.param_groups,
383        }
384
385    def __setstate__(self, state: Dict[str, Any]) -> None:  # noqa: D105
386        self.__dict__.update(state)
387        if "_optimizer_step_pre_hooks" not in self.__dict__:
388            self._optimizer_step_pre_hooks = OrderedDict()
389        if "_optimizer_step_post_hooks" not in self.__dict__:
390            self._optimizer_step_post_hooks = OrderedDict()
391        if "_optimizer_state_dict_pre_hooks" not in self.__dict__:
392            self._optimizer_state_dict_pre_hooks = OrderedDict()
393        if "_optimizer_state_dict_post_hooks" not in self.__dict__:
394            self._optimizer_state_dict_post_hooks = OrderedDict()
395        if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__:
396            self._optimizer_load_state_dict_pre_hooks = OrderedDict()
397        if "_optimizer_load_state_dict_post_hooks" not in self.__dict__:
398            self._optimizer_load_state_dict_post_hooks = OrderedDict()
399        self._patch_step_function()  # To support multiprocessing pickle/unpickle
400        self.defaults.setdefault("differentiable", False)
401
402    def __repr__(self) -> str:  # noqa: D105
403        format_string = self.__class__.__name__ + " ("
404        for i, group in enumerate(self.param_groups):
405            format_string += "\n"
406            format_string += f"Parameter Group {i}\n"
407            for key in sorted(group.keys()):
408                if key != "params":
409                    format_string += f"    {key}: {group[key]}\n"
410        format_string += ")"
411        return format_string
412
413    # Currently needed by Adam and AdamW
414    def _cuda_graph_capture_health_check(self) -> None:
415        # Note [torch.compile x capturable]
416        # If we are compiling, we try to take the capturable path automatically by
417        # setting the flag to True during tracing. Due to this, we skip all the checks
418        # normally required for determining whether we can use CUDA graphs and
419        # shunt the responsibility to torch.inductor. This saves time during tracing
420        # since the checks are slow without sacrificing UX since inductor will warn
421        # later if CUDA graphs cannot be enabled, e.g.,
422        # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
423        # Thus, when compiling, inductor will determine if cudagraphs
424        # can be enabled based on whether there is input mutation or CPU tensors.
425        if (
426            not is_compiling()
427            and torch.backends.cuda.is_built()
428            and torch.cuda.is_available()
429        ):
430            capturing = torch.cuda.is_current_stream_capturing()
431
432            if capturing and not all(
433                group["capturable"] for group in self.param_groups
434            ):
435                raise RuntimeError(
436                    "Attempting CUDA graph capture of step() for an instance of "
437                    + self.__class__.__name__
438                    + " but param_groups' capturable is False."
439                )
440
441            if (
442                (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
443                and all(group["capturable"] for group in self.param_groups)
444                and (not capturing)
445            ):
446                warnings.warn(
447                    "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, "
448                    "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
449                    "instance, capturable=True can impair performance, and you should set capturable=False."
450                )
451                self._warned_capturable_if_run_uncaptured = True
452
453    def _optimizer_step_code(self) -> None:
454        """Entry point for `torch.profile.profiler`.
455
456        When python tracing is enabled the profiler will hook into this
457        function at the CPython level to inspect the optimizer's parameters and
458        param groups. It is called it after `step()` since many optimizers
459        lazily initialize state.
460
461        This is a workaround due to lack of a proper step hook on the optimizer,
462        and will be removed if it exists.
463        """
464
465    @staticmethod
466    def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:  # noqa: D102
467        @functools.wraps(func)
468        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
469            self, *_ = args
470            self = cast(Optimizer, self)
471            profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
472            with torch.autograd.profiler.record_function(profile_name):
473                # call optimizer step pre hooks
474                for pre_hook in chain(
475                    _global_optimizer_pre_hooks.values(),
476                    self._optimizer_step_pre_hooks.values(),
477                ):
478                    result = pre_hook(self, args, kwargs)
479                    if result is not None:
480                        if isinstance(result, tuple) and len(result) == 2:
481                            args, kwargs = result  # type: ignore[assignment]
482                        else:
483                            raise RuntimeError(
484                                f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
485                            )
486
487                out = func(*args, **kwargs)
488                self._optimizer_step_code()
489
490                # call optimizer step post hooks
491                for post_hook in chain(
492                    self._optimizer_step_post_hooks.values(),
493                    _global_optimizer_post_hooks.values(),
494                ):
495                    post_hook(self, args, kwargs)
496
497                return out
498
499        return wrapper
500
501    @staticmethod
502    def _group_tensors_by_device_and_dtype(
503        tensorlistlist: TensorListList,
504        with_indices: bool = False,
505    ) -> Union[
506        Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
507        Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
508    ]:
509        """Group a list of lists of tensors by device and dtype.
510
511        Skips this step if we are compiling since this will occur during inductor lowering.
512        """
513        if is_compiling():
514            return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
515        else:
516            return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]
517
518    def _patch_step_function(self) -> None:
519        self._zero_grad_profile_name = (
520            f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
521        )
522        hooked = getattr(self.__class__.step, "hooked", None)
523        if not hooked:
524            self.__class__.step = self.profile_hook_step(self.__class__.step)  # type: ignore[assignment]
525            self.__class__.step.hooked = True  # type: ignore[attr-defined]
526
527    def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
528        r"""Register an optimizer step pre hook which will be called before optimizer step.
529
530        It should have the following signature::
531
532            hook(optimizer, args, kwargs) -> None or modified args and kwargs
533
534        The ``optimizer`` argument is the optimizer instance being used. If
535        args and kwargs are modified by the pre-hook, then the transformed
536        values are returned as a tuple containing the new_args and new_kwargs.
537
538        Args:
539            hook (Callable): The user defined hook to be registered.
540
541        Returns:
542            :class:`torch.utils.hooks.RemovableHandle`:
543                a handle that can be used to remove the added hook by calling
544                ``handle.remove()``
545        """
546        handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
547        self._optimizer_step_pre_hooks[handle.id] = hook
548        return handle
549
550    def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
551        r"""Register an optimizer step post hook which will be called after optimizer step.
552
553        It should have the following signature::
554
555            hook(optimizer, args, kwargs) -> None
556
557        The ``optimizer`` argument is the optimizer instance being used.
558
559        Args:
560            hook (Callable): The user defined hook to be registered.
561
562        Returns:
563            :class:`torch.utils.hooks.RemovableHandle`:
564                a handle that can be used to remove the added hook by calling
565                ``handle.remove()``
566        """
567        handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
568        self._optimizer_step_post_hooks[handle.id] = hook
569        return handle
570
571    def register_state_dict_pre_hook(
572        self, hook: Callable[["Optimizer"], None], prepend: bool = False
573    ) -> RemovableHandle:  # noqa: D101
574        r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called.
575
576        It should have the following signature::
577
578            hook(optimizer) -> None
579
580        The ``optimizer`` argument is the optimizer instance being used.
581        The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
582        The registered hook can be used to perform pre-processing before the ``state_dict``
583        call is made.
584
585        Args:
586            hook (Callable): The user defined hook to be registered.
587            prepend (bool): If True, the provided pre ``hook`` will be fired before
588                all the already registered pre-hooks on ``state_dict``. Otherwise,
589                the provided ``hook`` will be fired after all the already registered
590                pre-hooks. (default: False)
591
592        Returns:
593            :class:`torch.utils.hooks.RemoveableHandle`:
594                a handle that can be used to remove the added hook by calling
595                ``handle.remove()``
596        """
597        handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
598        self._optimizer_state_dict_pre_hooks[handle.id] = hook
599        if prepend:
600            self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
601        return handle
602
603    def register_state_dict_post_hook(
604        self,
605        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
606        prepend: bool = False,
607    ) -> RemovableHandle:
608        r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called.
609
610        It should have the following signature::
611
612            hook(optimizer, state_dict) -> state_dict or None
613
614        The hook will be called with arguments ``self`` and ``state_dict`` after generating
615        a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
616        return a new one. The registered hook can be used to perform post-processing
617        on the ``state_dict`` before it is returned.
618
619        Args:
620            hook (Callable): The user defined hook to be registered.
621            prepend (bool): If True, the provided post ``hook`` will be fired before
622                all the already registered post-hooks on ``state_dict``. Otherwise,
623                the provided ``hook`` will be fired after all the already registered
624                post-hooks. (default: False)
625
626        Returns:
627            :class:`torch.utils.hooks.RemoveableHandle`:
628                a handle that can be used to remove the added hook by calling
629                ``handle.remove()``
630        """
631        handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
632        self._optimizer_state_dict_post_hooks[handle.id] = hook
633        if prepend:
634            self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
635        return handle
636
637    @torch._disable_dynamo
638    def state_dict(self) -> StateDict:
639        r"""Return the state of the optimizer as a :class:`dict`.
640
641        It contains two entries:
642
643        * ``state``: a Dict holding current optimization state. Its content
644            differs between optimizer classes, but some common characteristics
645            hold. For example, state is saved per parameter, and the parameter
646            itself is NOT saved. ``state`` is a Dictionary mapping parameter ids
647            to a Dict with state corresponding to each parameter.
648        * ``param_groups``: a List containing all parameter groups where each
649            parameter group is a Dict. Each parameter group contains metadata
650            specific to the optimizer, such as learning rate and weight decay,
651            as well as a List of parameter IDs of the parameters in the group.
652
653        NOTE: The parameter IDs may look like indices but they are just IDs
654        associating state with param_group. When loading from a state_dict,
655        the optimizer will zip the param_group ``params`` (int IDs) and the
656        optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
657        match state WITHOUT additional verification.
658
659        A returned state dict might look something like:
660
661        .. code-block:: text
662
663            {
664                'state': {
665                    0: {'momentum_buffer': tensor(...), ...},
666                    1: {'momentum_buffer': tensor(...), ...},
667                    2: {'momentum_buffer': tensor(...), ...},
668                    3: {'momentum_buffer': tensor(...), ...}
669                },
670                'param_groups': [
671                    {
672                        'lr': 0.01,
673                        'weight_decay': 0,
674                        ...
675                        'params': [0]
676                    },
677                    {
678                        'lr': 0.001,
679                        'weight_decay': 0.5,
680                        ...
681                        'params': [1, 2, 3]
682                    }
683                ]
684            }
685
686        """
687        for pre_hook in self._optimizer_state_dict_pre_hooks.values():
688            pre_hook(self)
689
690        # Save order indices instead of Tensors
691        param_mappings: Dict[int, int] = {}
692        start_index = 0
693
694        def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
695            nonlocal start_index
696            packed = {k: v for k, v in group.items() if k != "params"}
697            param_mappings.update(
698                {
699                    id(p): i
700                    for i, p in enumerate(group["params"], start_index)
701                    if id(p) not in param_mappings
702                }
703            )
704            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
705            start_index += len(packed["params"])
706            return packed
707
708        param_groups = [pack_group(g) for g in self.param_groups]
709        # Remap state to use order indices as keys
710        packed_state = {
711            (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
712            for k, v in self.state.items()
713        }
714
715        state_dict = {
716            "state": packed_state,
717            "param_groups": param_groups,
718        }
719
720        for post_hook in self._optimizer_state_dict_post_hooks.values():
721            hook_result = post_hook(self, state_dict)
722            if hook_result is not None:
723                state_dict = hook_result
724        return state_dict
725
726    @staticmethod
727    def _process_value_according_to_param_policy(
728        param: torch.Tensor,
729        value: torch.Tensor,
730        param_id: int,
731        param_groups: List[Dict[Any, Any]],
732        key: Hashable = None,
733    ) -> torch.Tensor:
734        # Floating-point types are a bit special here. They are the only ones
735        # that are assumed to always match the type of params.
736        # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
737        # UNLESS fused or capturable, see note [special device hosting for step]
738        fused = False
739        capturable = False
740        assert param_groups is not None
741        for pg in param_groups:
742            if param_id in pg["params"]:
743                fused = pg["fused"] if "fused" in pg else False
744                capturable = pg["capturable"] if "capturable" in pg else False
745                break
746        if key == "step":
747            if capturable or fused:
748                return value.to(dtype=torch.float32, device=param.device)
749            else:
750                return value
751        else:
752            if param.is_floating_point():
753                return value.to(dtype=param.dtype, device=param.device)
754            else:
755                return value.to(device=param.device)
756
757    def register_load_state_dict_pre_hook(
758        self,
759        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
760        prepend: bool = False,
761    ) -> RemovableHandle:  # noqa: D205 D400
762        r"""Register a load_state_dict pre-hook which will be called before
763        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
764        following signature::
765
766            hook(optimizer, state_dict) -> state_dict or None
767
768        The ``optimizer`` argument is the optimizer instance being used and the
769        ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
770        passed in to ``load_state_dict``. The hook may modify the state_dict inplace
771        or optionally return a new one. If a state_dict is returned, it will be used
772        to be loaded into the optimizer.
773
774        The hook will be called with argument ``self`` and ``state_dict`` before
775        calling ``load_state_dict`` on ``self``. The registered hook can be used to
776        perform pre-processing before the ``load_state_dict`` call is made.
777
778        Args:
779            hook (Callable): The user defined hook to be registered.
780            prepend (bool): If True, the provided pre ``hook`` will be fired before
781                all the already registered pre-hooks on ``load_state_dict``. Otherwise,
782                the provided ``hook`` will be fired after all the already registered
783                pre-hooks. (default: False)
784
785        Returns:
786            :class:`torch.utils.hooks.RemoveableHandle`:
787                a handle that can be used to remove the added hook by calling
788                ``handle.remove()``
789        """
790        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
791        self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
792        if prepend:
793            self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)
794        return handle
795
796    def register_load_state_dict_post_hook(
797        self, hook: Callable[["Optimizer"], None], prepend: bool = False
798    ) -> RemovableHandle:  # noqa: D205 D400
799        r"""Register a load_state_dict post-hook which will be called after
800        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
801        following signature::
802
803            hook(optimizer) -> None
804
805        The ``optimizer`` argument is the optimizer instance being used.
806
807        The hook will be called with argument ``self`` after calling
808        ``load_state_dict`` on ``self``. The registered hook can be used to
809        perform post-processing after ``load_state_dict`` has loaded the
810        ``state_dict``.
811
812        Args:
813            hook (Callable): The user defined hook to be registered.
814            prepend (bool): If True, the provided post ``hook`` will be fired before
815                all the already registered post-hooks on ``load_state_dict``. Otherwise,
816                the provided ``hook`` will be fired after all the already registered
817                post-hooks. (default: False)
818
819        Returns:
820            :class:`torch.utils.hooks.RemoveableHandle`:
821                a handle that can be used to remove the added hook by calling
822                ``handle.remove()``
823        """
824        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
825        self._optimizer_load_state_dict_post_hooks[handle.id] = hook
826        if prepend:
827            self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
828        return handle
829
830    @torch._disable_dynamo
831    def load_state_dict(self, state_dict: StateDict) -> None:
832        r"""Load the optimizer state.
833
834        Args:
835            state_dict (dict): optimizer state. Should be an object returned
836                from a call to :meth:`state_dict`.
837        """
838        # shallow copy, to be consistent with module API
839        state_dict = state_dict.copy()
840
841        for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
842            hook_result = pre_hook(self, state_dict)
843            if hook_result is not None:
844                state_dict = hook_result
845
846        # Validate the state_dict
847        groups = self.param_groups
848
849        # Deepcopy as we write into saved_groups later to update state
850        saved_groups = deepcopy(state_dict["param_groups"])
851
852        if len(groups) != len(saved_groups):
853            raise ValueError(
854                "loaded state dict has a different number of " "parameter groups"
855            )
856        param_lens = (len(g["params"]) for g in groups)
857        saved_lens = (len(g["params"]) for g in saved_groups)
858        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
859            raise ValueError(
860                "loaded state dict contains a parameter group "
861                "that doesn't match the size of optimizer's group"
862            )
863
864        # Update the state
865        id_map = dict(
866            zip(
867                chain.from_iterable(g["params"] for g in saved_groups),
868                chain.from_iterable(g["params"] for g in groups),
869            )
870        )
871
872        def _cast(param, value, param_id=None, param_groups=None, key=None):
873            r"""Make a deep copy of value, casting all tensors to device of param."""
874            if isinstance(value, torch.Tensor):
875                return Optimizer._process_value_according_to_param_policy(
876                    param, value, param_id, param_groups, key
877                )
878            elif isinstance(value, dict):
879                return {
880                    k: _cast(
881                        param, v, param_id=param_id, param_groups=param_groups, key=k
882                    )
883                    for k, v in value.items()
884                }
885            elif isinstance(value, Iterable):
886                return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)  # type: ignore[call-arg]
887            else:
888                return value
889
890        # Copy state assigned to params (and cast tensors to appropriate types).
891        # State that is not assigned to params is copied as is (needed for
892        # backward compatibility).
893        state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
894        for k, v in state_dict["state"].items():
895            if k in id_map:
896                param = id_map[k]
897                state[param] = _cast(
898                    param, v, param_id=k, param_groups=state_dict["param_groups"]
899                )
900            else:
901                state[k] = v
902
903        # Update parameter groups, setting their 'params' value
904        def update_group(
905            group: Dict[str, Any], new_group: Dict[str, Any]
906        ) -> Dict[str, Any]:
907            new_group["params"] = group["params"]
908            return new_group
909
910        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
911        self.__setstate__({"state": state, "param_groups": param_groups})
912
913        for post_hook in self._optimizer_load_state_dict_post_hooks.values():
914            post_hook(self)
915
916    @torch._disable_dynamo
917    def zero_grad(self, set_to_none: bool = True) -> None:
918        r"""Reset the gradients of all optimized :class:`torch.Tensor` s.
919
920        Args:
921            set_to_none (bool): instead of setting to zero, set the grads to None.
922                This will in general have lower memory footprint, and can modestly improve performance.
923                However, it changes certain behaviors. For example:
924                1. When the user tries to access a gradient and perform manual ops on it,
925                a None attribute or a Tensor full of 0s will behave differently.
926                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
927                are guaranteed to be None for params that did not receive a gradient.
928                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
929                (in one case it does the step with a gradient of 0 and in the other it skips
930                the step altogether).
931        """
932        foreach = self.defaults.get("foreach", False) or self.defaults.get(
933            "fused", False
934        )
935
936        if not hasattr(self, "_zero_grad_profile_name"):
937            self._patch_step_function()
938
939        per_device_and_dtype_grads: Optional[
940            DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]
941        ]
942        if foreach:
943            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
944        else:
945            per_device_and_dtype_grads = None
946
947        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
948            for group in self.param_groups:
949                for p in group["params"]:
950                    if p.grad is not None:
951                        if set_to_none:
952                            p.grad = None
953                        else:
954                            if p.grad.grad_fn is not None:
955                                p.grad.detach_()
956                            else:
957                                p.grad.requires_grad_(False)
958                            if not foreach or p.grad.is_sparse:
959                                p.grad.zero_()
960                            else:
961                                assert per_device_and_dtype_grads is not None
962                                per_device_and_dtype_grads[p.grad.device][
963                                    p.grad.dtype
964                                ].append(p.grad)
965            if foreach:
966                assert per_device_and_dtype_grads is not None
967                for per_dtype_grads in per_device_and_dtype_grads.values():
968                    for grads in per_dtype_grads.values():
969                        torch._foreach_zero_(grads)
970
971    @overload
972    def step(self, closure: None = ...) -> None:
973        ...
974
975    @overload
976    def step(self, closure: Callable[[], float]) -> float:
977        ...
978
979    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
980        r"""Perform a single optimization step to update parameter.
981
982        Args:
983            closure (Callable): A closure that reevaluates the model and
984                returns the loss. Optional for most optimizers.
985
986        .. note::
987            Unless otherwise specified, this function should not modify the
988            ``.grad`` field of the parameters.
989        """
990        raise NotImplementedError
991
992    @torch._disable_dynamo
993    def add_param_group(self, param_group: Dict[str, Any]) -> None:
994        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
995
996        This can be useful when fine tuning a pre-trained network as frozen layers can be made
997        trainable and added to the :class:`Optimizer` as training progresses.
998
999        Args:
1000            param_group (dict): Specifies what Tensors should be optimized along with group
1001                specific optimization options.
1002        """
1003        if not isinstance(param_group, dict):
1004            raise TypeError(f"param_group must be a dict, but got {type(param_group)}")
1005
1006        params = param_group["params"]
1007        if isinstance(params, torch.Tensor):
1008            param_group["params"] = [params]
1009        elif isinstance(params, set):
1010            raise TypeError(
1011                "optimizer parameters need to be organized in ordered collections, but "
1012                "the ordering of tensors in sets will change between runs. Please use a list instead."
1013            )
1014        else:
1015            param_group["params"] = list(params)
1016
1017        for param in param_group["params"]:
1018            if not isinstance(param, torch.Tensor):
1019                raise TypeError(
1020                    "optimizer can only optimize Tensors, "
1021                    "but one of the params is " + torch.typename(param)
1022                )
1023            if not self.defaults.get("differentiable", None) and not (
1024                param.is_leaf or param.retains_grad
1025            ):
1026                raise ValueError("can't optimize a non-leaf Tensor")
1027
1028        for name, default in self.defaults.items():
1029            if default is required and name not in param_group:
1030                raise ValueError(
1031                    f"parameter group didn't specify a value of required optimization parameter {name}"
1032                )
1033            else:
1034                param_group.setdefault(name, default)
1035
1036        params = param_group["params"]
1037        if len(params) != len(set(params)):
1038            warnings.warn(
1039                "optimizer contains a parameter group with duplicate parameters; "
1040                "in future, this will cause an error; "
1041                "see github.com/pytorch/pytorch/issues/40967 for more information",
1042                stacklevel=3,
1043            )
1044
1045        param_set: Set[torch.Tensor] = set()
1046        for group in self.param_groups:
1047            param_set.update(set(group["params"]))
1048
1049        if not param_set.isdisjoint(set(param_group["params"])):
1050            raise ValueError("some parameters appear in more than one parameter group")
1051
1052        self.param_groups.append(param_group)
1053