# mypy: ignore-errors

import functools
import itertools
import sys
import unittest
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import (
    Adadelta,
    Adafactor,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    LBFGS,
    NAdam,
    Optimizer,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
    SparseAdam,
)
from torch.optim.lr_scheduler import (
    ConstantLR,
    ExponentialLR,
    LinearLR,
    PolynomialLR,
    ReduceLROnPlateau,
    StepLR,
)
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import (
    _TestParametrizer,
    skipIfMps,
    skipIfTorchDynamo,
    skipIfXpu,
    TEST_WITH_TORCHDYNAMO,
)
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices


class OptimizerInput:
    """Contains args / kwargs to be passed to an optimizer constructor."""

    __slots__ = ["params", "kwargs", "desc"]

    def __init__(
        self,
        params: Union[List[Parameter], List[Tensor], Dict[Any, Any]],
        kwargs: Dict[str, Any],
        desc: str = "",
    ):
        # params can be a list of Tensors OR param_groups OR None
        self.params = params
        self.kwargs = kwargs
        self.desc = desc

    def __repr__(self):
        return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}"


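# A minimal illustration of how an OptimizerInput is meant to be consumed (a sketch for
# documentation only; the helper below is hypothetical and is not used by any test):
def _example_optimizer_input_usage():
    sample = OptimizerInput(
        params=None,  # tests typically supply their own params
        kwargs={"lr": 0.01, "weight_decay": 0.1},
        desc="non-default lr, nonzero weight_decay",
    )
    # The consuming test provides the params and unpacks the sampled kwargs into the ctor.
    model_params = [Parameter(torch.randn(2, 2))]
    return SGD(model_params, **sample.kwargs)

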
class OptimizerErrorEnum(Enum):
    """Enumerates the points at which an error can be raised when testing optimizers."""

    CONSTRUCTION_ERROR = 0
    STEP_ERROR = 1


class ErrorOptimizerInput:
    """
    Wraps an OptimizerInput that will cause the optimizer to raise an error, either at
    construction time or on step() (see error_on). Includes the expected error type and
    a regex for the resulting error message.
    """

    __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]

    def __init__(
        self,
        optimizer_error_input,
        *,
        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
        error_type=RuntimeError,
        error_regex="",
    ):
        self.optimizer_error_input = optimizer_error_input
        self.error_on = error_on
        self.error_type = error_type
        self.error_regex = error_regex


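# A rough sketch of how an ErrorOptimizerInput might be exercised (illustrative only; the
# real consumption lives in the optimizer tests, which also handle inputs whose error_type
# is a warning class rather than an exception):
def _example_check_error_input(testcase, optim_cls, error_input, params):
    # `testcase` is assumed to be a unittest.TestCase instance.
    kwargs = error_input.optimizer_error_input.kwargs
    with testcase.assertRaisesRegex(error_input.error_type, error_input.error_regex):
        optim = optim_cls(params, **kwargs)
        if error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
            optim.step()

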
class OptimizerInfo:
    """Optimizer information to be used in testing."""

    def __init__(
        self,
        optim_cls: Optimizer,  # Class object for the Optimizer under test
        *,
        # Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility
        # to the test using the OptimizerInfo. OptimizerInput.params is likely None.
        # Can optionally take in device to filter out certain unsupported configs
        optim_inputs_func,
        # Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
        # LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
        # We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
        # LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
        # A few optimizers like SGD and Adam will test more LRSchedulers.
        scheduler_inputs=(
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
        # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
        supported_impls: Tuple[str, ...] = ("foreach", "differentiable"),
        # A subset of all flags, signifying which ones were only supported after the
        # original optimizer had already been released. aka impls where we need to check BC.
        not_og_supported_flags: Tuple[str, ...] = (
            "foreach",
            "differentiable",
            "maximize",
            "capturable",
        ),
        # the optim supports passing in sparse gradients as well as dense grads
        supports_sparse: bool = False,
        # the optimizer constructor supports passing in capturable as a kwarg
        has_capturable_arg: bool = False,
        # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
        only_supports_sparse_grads: bool = False,
        # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
        # with especially tuned hyperparameters. These only apply if the optimizer supports
        # sparse parameters or grads.
        metadata_for_sparse=({}, []),
        # the optim supports complex parameters
        supports_complex: bool = True,
        # whether the optimizer.step() function requires a closure to be passed
        step_requires_closure: bool = False,
        # whether the optimizer supports per-param options with parameter groups
        supports_param_groups: bool = True,
        # whether the optimizer supports parameters on multiple devices
        supports_multiple_devices: bool = True,
        skips=(),  # Indicates which tests to skip
        decorators=None,  # Additional decorators to apply to generated tests
        optim_error_inputs_func=None,  # Function to generate optim inputs that error
        supports_fused_on: Tuple[str, ...] = (),
    ):
        self.optim_cls = optim_cls
        self.optim_inputs_func = optim_inputs_func
        self.scheduler_inputs = scheduler_inputs
        self.supported_impls = supported_impls
        self.not_og_supported_flags = not_og_supported_flags
        self.supports_sparse = supports_sparse
        self.has_capturable_arg = has_capturable_arg
        self.metadata_for_sparse = metadata_for_sparse
        self.only_supports_sparse_grads = only_supports_sparse_grads
        self.supports_complex = supports_complex
        self.step_requires_closure = step_requires_closure
        self.supports_param_groups = supports_param_groups
        self.supports_multiple_devices = supports_multiple_devices
        self.decorators = (
            *(decorators if decorators else []),
            *(skips if skips else []),
        )
        self.optim_error_inputs_func = optim_error_inputs_func
        self.supports_fused_on = supports_fused_on

    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
        result = []
        for decorator in self.decorators:
            if isinstance(decorator, DecorateInfo):
                if decorator.is_active(
                    test_class, test_name, device, dtype, param_kwargs
                ):
                    result.extend(decorator.decorators)
            else:
                result.append(decorator)
        return result

    @property
    def name(self):
        return self.optim_cls.__name__


class optims(_TestParametrizer):
    """Decorator for specifying a list of optimizers over which to run a test."""

    def __init__(self, optim_info_iterable, dtypes=None):
        self.optim_info_list = list(optim_info_iterable)

        # Optimizers aren't limited to a single dtype, since their parameters can have
        # different dtypes. We default to torch.float32, but dtypes should be specified
        # through the passed-in parameters.
        self.dtypes = dtypes if dtypes is not None else [torch.float32]

    def _parametrize_test(self, test, generic_cls, device_cls):
        if device_cls is None:
            raise RuntimeError(
                "The @optims decorator is only intended to be used in a device-specific "
                "context; use it with instantiate_device_type_tests() instead of "
                "instantiate_parametrized_tests()"
            )

        for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes):
            # Construct the test name; device / dtype parts are handled outside.
            # See [Note: device and dtype suffix placement]
            test_name = optim_info.name

            # Construct parameter kwargs to pass to the test.
            param_kwargs = {"optim_info": optim_info, "dtype": dtype}

            try:

                @functools.wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                decorator_fn = functools.partial(
                    optim_info.get_decorators,
                    generic_cls.__name__,
                    test.__name__,
                    device_cls.device_type,
                    dtype,
                )

                yield (test_wrapper, test_name, param_kwargs, decorator_fn)
            except Exception as ex:
                # Provide an error message for debugging before rethrowing the exception
                print(
                    f"Failed to instantiate {test_name} for module {optim_info.name}!"
                )
                raise ex


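# Typical usage, sketched as a comment (the test class below is hypothetical; real consumers
# live in files like test_optim.py and are instantiated per device via
# instantiate_device_type_tests):
#
#     class TestMyOptimizers(TestCase):
#         @optims(optim_db, dtypes=[torch.float32])
#         def test_something(self, device, dtype, optim_info):
#             for optim_input in optim_info.optim_inputs_func(device=device):
#                 ...  # build params, then optim_info.optim_cls(params, **optim_input.kwargs)
#
#     instantiate_device_type_tests(TestMyOptimizers, globals())

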
# Helper function for generating error inputs for all optimizers, used below.
def get_error_inputs_for_all_optims(device, dtype):
    if _get_device_type(device) == "cpu":
        sample_param = Parameter(torch.randn(1, device=device, dtype=dtype))
        return [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=sample_param,
                    kwargs={},
                    desc="invalid param type",
                ),
                error_type=TypeError,
                error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[sample_param, sample_param],
                    kwargs={},
                    desc="a param group cannot have duplicate parameters",
                ),
                error_type=UserWarning,
                error_regex=".*a parameter group with duplicate parameters.*",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[{"params": sample_param}, {"params": sample_param}],
                    kwargs={},
                    desc="duplicate parameters should not occur across param groups either",
                ),
                error_type=ValueError,
                error_regex="some parameters appear in more than one parameter group",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=torch.tensor([0.001, 0.001])),
                    desc="Tensor lr must be 1-element",
                ),
                error_type=ValueError,
                error_regex="Tensor lr must be 1-element",
            ),
        ]
    else:
        return []


# ------------------------------------------------------------------------------------------
# NOTE: [optimizer kwarg categories]
# We categorize optimizer kwargs as 3 types:
#  1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to
#     algorithms and thus only show up for certain optimizers. There are many of these, so I
#     do not bother gathering them all and listing them here. The converse to these would be
#     global flags that every optimizer ideally _should_ support. We break global flags into
#     2 further categories and list them all below.
#  2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"]
#     global-friendly flags are global flags that play nicely with all other global flags,
#     i.e., they are functionally independent of one another. This means that any pair of the
#     following flags can be toggled at once (e.g., maximize and weight_decay). Furthermore,
#     any of the following flags theoretically can be enabled with ANY other global flag,
#     including the cliquey ones (e.g., capturable and foreach).
#  3. global-cliquey = ["foreach", "fused", "differentiable"]
#     global-cliquey flags are global flags that do NOT coexist with other cliquey flags,
#     usually because they contradict each other in function. For example, one should not flip
#     both foreach AND fused to True, because they are two different performance optimizations
#     and you can only opt into one of them.
#
# The following optim_inputs_func_* sampling functions only return constructor combinations of
# optimizer-specific and global-friendly flags. This is because we are confident they would mesh
# well with additional kwargs. Conversely, we leave setting the global-cliquey flags to the
# individual tests and fully expect tests to edit OptimizerInput.kwargs (a short sketch of this
# follows below).

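
# A minimal sketch of the division of labor described above (hypothetical helper, kept only as
# documentation and not used by any test): the sampling functions stay cliquey-free, and the
# test itself opts into exactly one global-cliquey flag.
def _example_opt_into_one_cliquey_flag():
    # Start from a cliquey-free sampled config (here, simply the default with empty kwargs).
    optim_input = OptimizerInput(params=None, kwargs={}, desc="default")
    kwargs = dict(optim_input.kwargs, foreach=True)  # the test adds a single cliquey flag
    return Adadelta([Parameter(torch.randn(2, 2))], **kwargs)
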

def optim_inputs_func_adadelta(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "capturable": True},
            desc="capturable with weight decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "capturable": True},
            desc="Tensor lr with capturable",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize, weight_decay",
        ),
        OptimizerInput(
            params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_adadelta(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, rho=1.1),
                    desc="rho should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid rho value: 1.1",
            ),
        ]
    return error_inputs


def optim_inputs_func_adafactor(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "lr": 0.01},
            desc="nonzero weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"beta2_decay": -1.0},
            desc="non-default beta2_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"d": 1.5},
            desc="non-default clipping threshold d",
        ),
    ]


def optim_error_inputs_func_adafactor(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
        complex_param.grad = torch.rand_like(complex_param)
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(eps=(-1e-30, 1e-3)),
                    desc="epsilon1 should be >= 0",
                ),
                error_type=ValueError,
                error_regex="epsilon1 should be >= 0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(d=0.0),
                    desc="invalid d",
                ),
                error_type=ValueError,
                error_regex="Clipping threshold d should be >= 1",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(beta2_decay=0.8),
                    desc="invalid beta2_decay",
                ),
                error_type=ValueError,
                error_regex="beta2_decay should be <= 0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[complex_param],
                    kwargs=dict(),
                    desc="does not support complex parameters",
                ),
                error_type=RuntimeError,
                error_regex="Adafactor does not support complex parameters",
                error_on=OptimizerErrorEnum.STEP_ERROR,
            ),
        ]
    return error_inputs


def optim_inputs_func_adagrad(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
        OptimizerInput(
            params=None,
            kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1},
            desc="initial_accumulator_value",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1},
            desc="lr_decay",
        ),  # TODO: Move out to testing in param_group?
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001)},
            desc="Tensor lr",
        ),
    ]


def optim_error_inputs_func_adagrad(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, lr_decay=-0.5),
                    desc="lr_decay must be bigger than 0",
                ),
                error_type=ValueError,
                error_regex="Invalid lr_decay value: -0.5",
            ),
        ]
    return error_inputs


# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
# with all implementation code paths...
def optim_inputs_func_adam(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
            desc="capturable, amsgrad",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
            desc="Tensor lr with capturable and amsgrad",
        ),
    ]
    mps_supported_configs = [
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr"
        ),
    ]

    total = (
        [
            OptimizerInput(params=None, kwargs={}, desc="default"),
            OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
            OptimizerInput(
                params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
            ),
            OptimizerInput(
                params=None,
                kwargs={"weight_decay": 0.1, "maximize": True},
                desc="maximize",
            ),
            OptimizerInput(
                params=None,
                kwargs={"weight_decay": 0.1, "amsgrad": True},
                desc="amsgrad",
            ),
        ]
        + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
        + (mps_supported_configs if _get_device_type(device) == "mps" else [])
    )
    if dtype in (torch.float16,):
        for input in total:
            """
            A too-small eps will make denom zero for low-precision dtypes:
            denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
            For example,
            >>> a
            tensor([0.], dtype=torch.float16)
            >>> a + 1e-8
            tensor([0.], dtype=torch.float16)
            """
            input.kwargs["eps"] = 0.1
    return total


def optim_error_inputs_func_adam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, weight_decay=-1),
                    desc="weight_decay should be >= 0",
                ),
                error_type=ValueError,
                error_regex="Invalid weight_decay value: -1",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=torch.tensor(0.001), foreach=True),
                    desc="lr as Tensor doesn't work with foreach & not capturable",
                ),
                error_type=ValueError,
                error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
            ),
        ]
    if _get_device_type(device) == "cuda":
        sample_tensor = torch.empty((), device=device, dtype=dtype)
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[sample_tensor],
                    kwargs={"foreach": True, "fused": True},
                    desc="`fused` and `foreach` cannot be `True` together",
                ),
                error_type=RuntimeError,
                error_regex="`fused` and `foreach` cannot be `True` together",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[sample_tensor],
                    kwargs={"fused": True, "differentiable": True},
                    desc="`fused` does not support `differentiable`",
                ),
                error_type=RuntimeError,
                error_regex="`fused` does not support `differentiable`",
            ),
        ]
    return error_inputs


def optim_inputs_func_adamax(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True},
            desc="capturable, maximize, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0, "maximize": True, "capturable": True},
            desc="capturable, maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True},
            desc="capturable, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "weight_decay": 0.9,
                "maximize": False,
                "capturable": True,
            },
            desc="capturable, weight_decay, tensor LR",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize, weight_decay",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_adamax(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
                    desc="beta2 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 1: 1.0",
            ),
        ]
    return error_inputs


def optim_inputs_func_adamw(device, dtype=None):
    return optim_inputs_func_adam(device, dtype)


def optim_error_inputs_func_adamw(device, dtype):
    return optim_error_inputs_func_adam(device, dtype)


def optim_inputs_func_asgd(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"maximize": True, "capturable": True},
            desc="maximize, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "capturable": True},
            desc="weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
            desc="maximize, weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "weight_decay": 0.1,
                "maximize": True,
                "capturable": True,
            },
            desc="maximize, weight_decay, capturable, tensor LR",
        ),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
        OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
        OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize, nonzero weight_decay",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_asgd(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, weight_decay=-0.5),
                    desc="weight_decay should be >= 0",
                ),
                error_type=ValueError,
                error_regex="Invalid weight_decay value: -0.5",
            ),
        ]
    return error_inputs


def optim_inputs_func_lbfgs(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
        ),
        OptimizerInput(
            params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
        ),
        OptimizerInput(
            params=None,
            kwargs={"line_search_fn": "strong_wolfe"},
            desc="strong_wolfe",
        ),
    ]


def optim_error_inputs_func_lbfgs(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    return error_inputs


def optim_inputs_func_nadam(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
            desc="weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.9,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
                "capturable": True,
            },
            desc="decoupled_weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "weight_decay": 0.9,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
                "capturable": True,
            },
            desc="decoupled_weight_decay, capturable",
        ),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
        OptimizerInput(
            params=None,
            kwargs={"momentum_decay": 6e-3},
            desc="non-zero momentum_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.1,
            },
            desc="weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
            desc="weight_decay, momentum_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.1,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
            },
            desc="decoupled_weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_nadam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, momentum_decay=-0.2),
                    desc="momentum_decay should be >= 0",
                ),
                error_type=ValueError,
                error_regex="Invalid momentum_decay value: -0.2",
            ),
        ]
    return error_inputs


# Weird story bro, NAdam and RAdam did not originally have maximize.
def optim_inputs_func_radam(device=None, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={
                "capturable": True,
                "weight_decay": 0.1,
            },
            desc="capturable, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "capturable": True,
                "weight_decay": 0.1,
                "decoupled_weight_decay": True,
            },
            desc="capturable, weight_decay, decoupled_weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "capturable": True,
                "weight_decay": 0.1,
                "decoupled_weight_decay": True,
            },
            desc="capturable, weight_decay, decoupled_weight_decay, tensor LR",
        ),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
        OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
            desc="decoupled_weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_radam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, weight_decay=-1),
                    desc="weight_decay should be >= 0",
                ),
                error_type=ValueError,
                error_regex="Invalid weight_decay value: -1",
            ),
        ]
    return error_inputs


def optim_inputs_func_rmsprop(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
            desc="capturable, maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "capturable": True},
            desc="Tensor lr with capturable",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "maximize": True,
            },
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "centered": True},
            desc="centered",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "maximize": True,
                "weight_decay": 0.1,
            },
            desc="maximize, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
            desc="momentum",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.1,
                "centered": True,
                "momentum": 0.1,
                "maximize": True,
            },
            desc="maximize, centered, weight_decay, w/ momentum",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_rmsprop(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, momentum=-1.0),
                    desc="momentum should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid momentum value: -1.0",
            ),
        ]
    return error_inputs


def optim_inputs_func_rprop(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "capturable": True},
            desc="Tensor lr with capturable",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas"
        ),
        OptimizerInput(
            params=None,
            kwargs={"step_sizes": (2e-6, 100)},
            desc="non-default step_sizes",
        ),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_rprop(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
                    desc="0 < eta1 < 1 < eta2",
                ),
                error_type=ValueError,
                error_regex="Invalid eta values: 1.0, 0.5",
            ),
        ]
    return error_inputs


def optim_inputs_func_sgd(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
        ),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
        ),
        OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "dampening": 0.5},
            desc="dampening",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "weight_decay": 0.1},
            desc="weight_decay w/ momentum",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
            desc="nesterov",
        ),
    ]


def optim_error_inputs_func_sgd(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, momentum=-0.5),
                    desc="momentum should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid momentum value: -0.5",
            ),
        ]
    return error_inputs


def optim_inputs_func_sparseadam(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None, kwargs={"lr": 0.01}, desc="non-default lr"
        ),  # TODO: Move out to testing in param_group?
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
        ),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
    ]


def optim_error_inputs_func_sparseadam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)

    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[
                        torch.zeros(
                            3, layout=torch.sparse_coo, device=device, dtype=dtype
                        )
                    ],
                    kwargs={},
                    desc="dense params required",
                ),
                error_type=ValueError,
                error_regex="SparseAdam requires dense parameter tensors",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[
                        {
                            "params": [
                                torch.zeros(
                                    3,
                                    layout=torch.sparse_coo,
                                    device=device,
                                    dtype=dtype,
                                )
                            ]
                        }
                    ],
                    kwargs={},
                    desc="dense params required in param_groups",
                ),
                error_type=ValueError,
                error_regex="SparseAdam requires dense parameter tensors",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[torch.rand(2, 3, device=device, dtype=torch.complex64)],
                    kwargs={},
                    desc="complex not supported",
                ),
                error_type=ValueError,
                error_regex="SparseAdam does not support complex parameters",
            ),
        ]
    return error_inputs


def _get_device_type(device: Union[str, torch.device]) -> str:
    # Returns the device type as a string, e.g., "cpu" or "cuda"
    if isinstance(device, torch.device):
        device = str(device.type)
    assert isinstance(device, str)
    return device.split(":")[0]


def _get_optim_inputs_including_global_cliquey_kwargs(
    device, dtype, optim_info, skip=()
) -> List[OptimizerInput]:
    """
    Return a list of all configs for a given optimizer as a list of OptimizerInputs,
    including configs that have supported global cliquey kwargs (foreach, fused,
    differentiable) based on optim_info.supported_impls.

    The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
    intentionally do NOT include global cliquey kwargs to give flexibility to tests.
    For example, testing correctness between toggling foreach on and off is now
    trivial. That said, we sometimes want to test for all possible configs on an
    optimizer including all supported flags, so this helper returns all optim inputs.
    """
    assert all(
        x in ["foreach", "fused", "differentiable"] for x in skip
    ), "skip must be a subset of ['foreach', 'fused', 'differentiable']"

    optim_inputs = optim_info.optim_inputs_func(device)

    supported_impls = tuple(
        x
        for x in optim_info.supported_impls
        if x not in skip
        and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
        and (
            _get_device_type(device) in _get_foreach_kernels_supported_devices()
            or x != "foreach"
        )
    )

    all_optim_inputs = []
    for optim_input in optim_inputs:
        # Add the base config where all the flags are False
        base_kwargs = deepcopy(optim_input.kwargs)
        if len(supported_impls) != 0:
            for flag in supported_impls:
                base_kwargs[flag] = False
            all_optim_inputs.append(
                OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc)
            )
        else:
            all_optim_inputs.append(optim_input)
        # Add a config for when each of the global cliquey kwargs is True.
        # Note that, per [optimizer kwarg categories], these kwargs are mutually
        # exclusive, so we do not need to take their Cartesian product.
        for flag in supported_impls:
            new_kwargs = deepcopy(base_kwargs)
            new_kwargs[flag] = True
            all_optim_inputs.append(
                OptimizerInput(
                    params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}"
                )
            )
    return all_optim_inputs

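
# Illustrative helper (hypothetical, not used by any test): list the descriptions produced by
# the helper above for a given OptimizerInfo, e.g. "default", "default & foreach",
# "default & differentiable", ... when supported_impls is ("foreach", "differentiable").
def _example_list_global_cliquey_configs(optim_info, device="cpu", dtype=torch.float32):
    optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    return [optim_input.desc for optim_input in optim_inputs]
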
1237
1238# Database of OptimizerInfo entries in alphabetical order.
1239optim_db: List[OptimizerInfo] = [
1240    OptimizerInfo(
1241        Adadelta,
1242        optim_inputs_func=optim_inputs_func_adadelta,
1243        optim_error_inputs_func=optim_error_inputs_func_adadelta,
1244        supported_impls=("foreach", "differentiable"),
1245        has_capturable_arg=True,
1246        skips=(
1247            DecorateInfo(
1248                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
1249                "TestOptimRenewed",
1250                "test_tensor_lr",
1251                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1252            ),
1253            DecorateInfo(
1254                skipIfTorchDynamo("See #116028"),
1255                "TestOptimRenewed",
1256                "test_set_default_dtype_works_with_foreach",
1257            ),
1258            DecorateInfo(
1259                skipIfTorchDynamo(
1260                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1261                ),
1262                "TestOptimRenewed",
1263                "test_complex_2d",
1264            ),
1265            # Note on tolerances:
1266            # test_correctness_Adadelta_cuda_float32
1267            # Mismatched elements: 10 / 100 (10.0%)
1268            # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
1269            # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
1270            # This is due to floating point ordering error + usage of sqrt
1271            DecorateInfo(
1272                toleranceOverride(
1273                    {
1274                        torch.float32: tol(
1275                            rtol=5.5e-4,
1276                            atol=5e-5,
1277                        )
1278                    }
1279                ),
1280                "CompiledOptimizerParityTests",
1281                "test_correctness",
1282            ),
1283            DecorateInfo(
1284                skipIfTorchDynamo(
1285                    "This test uses mocks, which dynamo does not support"
1286                ),
1287                "TestOptimRenewed",
1288                "test_defaults_changed_to_foreach",
1289            ),
1290        ),
1291    ),
1292    OptimizerInfo(
1293        Adafactor,
1294        optim_inputs_func=optim_inputs_func_adafactor,
1295        optim_error_inputs_func=optim_error_inputs_func_adafactor,
1296        supported_impls=("foreach",),
1297        not_og_supported_flags=("foreach",),
1298        supports_complex=False,
1299        skips=(
1300            DecorateInfo(
1301                unittest.skip("See #133268 regarding dtype being None"),
1302                "CompiledOptimizerParityTests",
1303                "test_correctness",
1304                device_type="cuda",
1305                active_if=lambda kwargs: kwargs.get("use_closure", False),
1306            ),
1307            DecorateInfo(
1308                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1309                "TestOptimRenewed",
1310                "test_can_load_older_state_dict",
1311                device_type="cuda",
1312            ),
1313            DecorateInfo(
1314                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1315                "TestOptimRenewed",
1316                "test_deepcopy_copies_all_public_attrs",
1317                device_type="cuda",
1318            ),
1319            DecorateInfo(
1320                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1321                "TestOptimRenewed",
1322                "test_foreach_large_tensor",
1323            ),
1324            DecorateInfo(
1325                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1326                "TestOptimRenewed",
1327                "test_foreach_matches_forloop",
1328            ),
1329            DecorateInfo(
1330                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1331                "TestOptimRenewed",
1332                "test_load_nontensor_step",
1333                device_type="cuda",
1334            ),
1335            DecorateInfo(
1336                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1337                "TestOptimRenewed",
1338                "test_mixed_device_dtype",
1339            ),
1340            DecorateInfo(
1341                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1342                "TestOptimRenewed",
1343                "test_param_groups_lr",
1344                device_type="cuda",
1345            ),
1346            DecorateInfo(
1347                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1348                "TestOptimRenewed",
1349                "test_param_groups_weight_decay",
1350                device_type="cuda",
1351            ),
1352            DecorateInfo(
1353                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1354                "TestOptimRenewed",
1355                "test_peak_memory_foreach",
1356            ),
1357            DecorateInfo(
1358                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1359                "TestOptimRenewed",
1360                "test_save_load_equality_with_weights_only",
1361                device_type="cuda",
1362            ),
1363            DecorateInfo(
1364                skipIfTorchDynamo("See #116028 regarding copy not supported"),
1365                "TestOptimRenewed",
1366                "test_set_default_dtype_works_with_foreach",
1367            ),
1368            DecorateInfo(
1369                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1370                "TestOptimRenewed",
1371                "test_state_dict_deterministic",
1372                device_type="cuda",
1373            ),
1374            DecorateInfo(
1375                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1376                "TestOptimRenewed",
1377                "test_step_is_noop_for_zero_grads",
1378                device_type="cuda",
1379            ),
1380            DecorateInfo(
1381                unittest.skip("See #133268 regarding dtype being None"),
1382                "CompiledOptimizerParityTests",
1383                "test_correctness",
1384                device_type="xpu",
1385            ),
1386            DecorateInfo(
1387                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1388                "TestOptimRenewed",
1389                "test_can_load_older_state_dict",
1390                device_type="xpu",
1391            ),
1392            DecorateInfo(
1393                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1394                "TestOptimRenewed",
1395                "test_deepcopy_copies_all_public_attrs",
1396                device_type="xpu",
1397            ),
1398            DecorateInfo(
1399                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1400                "TestOptimRenewed",
1401                "test_load_nontensor_step",
1402                device_type="xpu",
1403            ),
1404            DecorateInfo(
1405                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1406                "TestOptimRenewed",
1407                "test_param_groups_lr",
1408                device_type="xpu",
1409            ),
1410            DecorateInfo(
1411                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1412                "TestOptimRenewed",
1413                "test_param_groups_weight_decay",
1414                device_type="xpu",
1415            ),
1416            DecorateInfo(
1417                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1418                "TestOptimRenewed",
1419                "test_save_load_equality_with_weights_only",
1420                device_type="xpu",
1421            ),
1422            DecorateInfo(
1423                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1424                "TestOptimRenewed",
1425                "test_state_dict_deterministic",
1426                device_type="xpu",
1427            ),
1428            DecorateInfo(
1429                skipIfTorchDynamo("See #133268 regarding dtype being None"),
1430                "TestOptimRenewed",
1431                "test_step_is_noop_for_zero_grads",
1432                device_type="xpu",
1433            ),
1434        ),
1435    ),
1436    OptimizerInfo(
1437        Adagrad,
1438        optim_inputs_func=optim_inputs_func_adagrad,
1439        optim_error_inputs_func=optim_error_inputs_func_adagrad,
1440        supported_impls=("foreach", "differentiable", "fused"),
1441        not_og_supported_flags=(
1442            "foreach",
1443            "differentiable",
1444            "fused",
1445            "maximize",
1446            "capturable",
1447        ),
1448        supports_fused_on=("cpu",),
1449        supports_sparse=True,
1450        metadata_for_sparse=(
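            # Optimizer kwargs and LR scheduler constructors used specifically by the sparse tests.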
1451            {"lr": 0.1, "weight_decay": 0, "lr_decay": 0},
1452            [
1453                lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
1454                lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
1455            ],
1456        ),
1457        decorators=(
1458            DecorateInfo(
1459                #  Note on tolerances:
1460                #  The difference comes from the non-fused kernel having more dtype
1461                #  cast operations. We have another test, test_fused_cpu_matches_cuda,
1462                #  to make sure there are no discrepancies between the cuda fused
1463                #  kernel and the cpu fused kernel.
1464                toleranceOverride(
1465                    {
1466                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
1467                        torch.float16: tol(atol=5e-3, rtol=5e-3),
1468                    }
1469                ),
1470                "TestOptimRenewed",
1471                "test_fused_matches_forloop",
1472            ),
1473        ),
1474        skips=(
1475            DecorateInfo(
1476                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
1477                "TestOptimRenewed",
1478                "test_forloop_goes_right_direction",
1479                active_if=lambda kwargs: not kwargs["contiguous"],
1480            ),
1481            DecorateInfo(
1482                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1483                "TestOptimRenewed",
1484                "test_tensor_lr",
1485                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1486            ),
1487            DecorateInfo(
1488                skipIfTorchDynamo("See #116028"),
1489                "TestOptimRenewed",
1490                "test_set_default_dtype_works_with_foreach",
1491            ),
1492            DecorateInfo(
1493                skipIfTorchDynamo(
1494                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1495                ),
1496                "TestOptimRenewed",
1497                "test_complex_2d",
1498            ),
1499            DecorateInfo(
1500                skipIfTorchDynamo(
1501                    "This test uses mocks, which dynamo does not support"
1502                ),
1503                "TestOptimRenewed",
1504                "test_defaults_changed_to_foreach",
1505            ),
1506        ),
1507    ),
1508    OptimizerInfo(
1509        Adam,
1510        optim_inputs_func=optim_inputs_func_adam,
1511        scheduler_inputs=(
1512            [lambda opt: ExponentialLR(opt, gamma=0.9)],
1513            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
1514            [
1515                lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
1516                lambda opt: ExponentialLR(opt, gamma=0.9),
1517            ],
1518            [
1519                lambda opt: ExponentialLR(opt, gamma=0.9),
1520                lambda opt: ReduceLROnPlateau(opt),
1521            ],
1522            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
1523            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
1524            [
1525                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
1526                lambda opt: ReduceLROnPlateau(opt),
1527            ],
1528        ),
1529        optim_error_inputs_func=optim_error_inputs_func_adam,
1530        supported_impls=("foreach", "differentiable", "fused"),
1531        has_capturable_arg=True,
1532        not_og_supported_flags=(
1533            "foreach",
1534            "differentiable",
1535            "fused",
1536            "maximize",
1537            "capturable",
1538        ),
1539        supports_fused_on=("cpu", "cuda", "mps"),
1540        decorators=(
1541            # Expected floating point error between fused and compiled forloop
1542            DecorateInfo(
1543                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
1544                "TestOptimRenewed",
1545                "test_fused_matches_forloop",
1546                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
1547                and kwargs["dtype"] == torch.float64,
1548            ),
1549            DecorateInfo(
1550                #  Note on tolerances:
1551                #  The difference comes from the non-fused kernel having more dtype
1552                #  cast operations. We have another test, test_fused_cpu_matches_cuda,
1553                #  to make sure there are no discrepancies between the cuda fused
1554                #  kernel and the cpu fused kernel.
1555                toleranceOverride(
1556                    {
1557                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
1558                        torch.float16: tol(atol=5e-3, rtol=5e-3),
1559                    }
1560                ),
1561                "TestOptimRenewed",
1562                "test_fused_matches_forloop",
1563            ),
1564            DecorateInfo(
1565                # Note on tolerances:
1566                # Tracking through #127000
1567                toleranceOverride(
1568                    {
1569                        torch.float32: tol(atol=3e-5, rtol=1.3e-06),
1570                    }
1571                ),
1572                "TestCudaOptims",
1573                "test_grad_scaling_autocast_fused_optimizers",
1574            ),
1575        ),
1576        skips=(
1577            DecorateInfo(
1578                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
1579                "TestOptimRenewed",
1580                "test_forloop_goes_right_direction",
1581                active_if=lambda kwargs: not kwargs["contiguous"],
1582            ),
1583            DecorateInfo(
1584                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1585                "TestOptimRenewed",
1586                "test_tensor_lr",
1587                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1588            ),
1589            DecorateInfo(
1590                skipIfTorchDynamo(
1591                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
1592                ),
1593                "TestOptimRenewed",
1594                "test_set_default_dtype_works_with_foreach",
1595            ),
1596            DecorateInfo(
1597                skipIfTorchDynamo(
1598                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1599                ),
1600                "TestOptimRenewed",
1601                "test_complex_2d",
1602            ),
1603            DecorateInfo(
1604                skipIfTorchDynamo(
1605                    "This test uses mocks, which dynamo does not support"
1606                ),
1607                "TestOptimRenewed",
1608                "test_defaults_changed_to_foreach",
1609            ),
1610        ),
1611    ),
1612    OptimizerInfo(
1613        Adamax,
1614        optim_inputs_func=optim_inputs_func_adamax,
1615        optim_error_inputs_func=optim_error_inputs_func_adamax,
1616        supported_impls=("foreach", "differentiable"),
1617        has_capturable_arg=True,
1618        skips=(
1619            DecorateInfo(
1620                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
1621                "TestOptimRenewed",
1622                "test_forloop_goes_right_direction",
1623                active_if=lambda kwargs: not kwargs["contiguous"],
1624            ),
1625            DecorateInfo(
1626                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1627                "TestOptimRenewed",
1628                "test_tensor_lr",
1629                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1630            ),
1631            DecorateInfo(
1632                skipIfTorchDynamo("See #116028"),
1633                "TestOptimRenewed",
1634                "test_set_default_dtype_works_with_foreach",
1635            ),
1636            DecorateInfo(
1637                skipIfTorchDynamo(
1638                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1639                ),
1640                "TestOptimRenewed",
1641                "test_complex_2d",
1642            ),
1643            DecorateInfo(
1644                unittest.skip("Uses too much memory, even for H100, surprisingly."),
1645                "TestOptimRenewed",
1646                "test_foreach_large_tensor",
1647            ),
1648            DecorateInfo(
1649                skipIfTorchDynamo(
1650                    "This test uses mocks, which dynamo does not support"
1651                ),
1652                "TestOptimRenewed",
1653                "test_defaults_changed_to_foreach",
1654            ),
1655        ),
1656    ),
1657    OptimizerInfo(
1658        AdamW,
1659        optim_inputs_func=optim_inputs_func_adamw,
1660        optim_error_inputs_func=optim_error_inputs_func_adamw,
1661        supported_impls=("foreach", "differentiable", "fused"),
1662        not_og_supported_flags=(
1663            "foreach",
1664            "differentiable",
1665            "fused",
1666            "maximize",
1667            "capturable",
1668        ),
1669        supports_fused_on=("cpu", "cuda", "mps"),
1670        has_capturable_arg=True,
1671        decorators=(
1672            # Expected error between compiled forloop and fused optimizers
1673            DecorateInfo(
1674                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
1675                "TestOptimRenewed",
1676                "test_fused_matches_forloop",
1677                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
1678                and kwargs["dtype"] == torch.float64,
1679            ),
1680            DecorateInfo(
1681                toleranceOverride(
1682                    #  Note on tolerances:
1683                    #  The difference comes from the non-fused kernel having more dtype
1684                    #  cast operations. We have another test, test_fused_cpu_matches_cuda,
1685                    #  to make sure there are no discrepancies between the cuda fused
1686                    #  kernel and the cpu fused kernel.
1687                    {
1688                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
1689                        torch.float16: tol(atol=5e-3, rtol=5e-3),
1690                    }
1691                ),
1692                "TestOptimRenewed",
1693                "test_fused_matches_forloop",
1694            ),
1695            # Note on tolerances:
1696            # Tracking through #127000
1697            DecorateInfo(
1698                toleranceOverride(
1699                    {
1700                        torch.float32: tol(
1701                            atol=3e-5,
1702                            rtol=1.3e-06,
1703                        )
1704                    }
1705                ),
1706                "TestCudaOptims",
1707                "test_grad_scaling_autocast_fused_optimizers",
1708            ),
1709        ),
1710        skips=(
1711            DecorateInfo(
1712                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
1713                "TestOptimRenewed",
1714                "test_forloop_goes_right_direction",
1715                active_if=lambda kwargs: not kwargs["contiguous"],
1716            ),
1717            DecorateInfo(
1718                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1719                "TestOptimRenewed",
1720                "test_tensor_lr",
1721                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1722            ),
1723            DecorateInfo(
1724                skipIfTorchDynamo(
1725                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
1726                ),
1727                "TestOptimRenewed",
1728                "test_set_default_dtype_works_with_foreach",
1729            ),
1730            DecorateInfo(
1731                skipIfTorchDynamo(
1732                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1733                ),
1734                "TestOptimRenewed",
1735                "test_complex_2d",
1736            ),
1737            DecorateInfo(
1738                skipIfTorchDynamo(
1739                    "This test uses mocks, which dynamo does not support"
1740                ),
1741                "TestOptimRenewed",
1742                "test_defaults_changed_to_foreach",
1743            ),
1744        ),
1745    ),
1746    OptimizerInfo(
1747        ASGD,
1748        optim_inputs_func=optim_inputs_func_asgd,
1749        optim_error_inputs_func=optim_error_inputs_func_asgd,
1750        supported_impls=("foreach", "differentiable"),
1751        has_capturable_arg=True,
1752        skips=(
1753            DecorateInfo(
1754                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1755                "TestOptimRenewed",
1756                "test_tensor_lr",
1757                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1758            ),
1759            DecorateInfo(
1760                skipIfTorchDynamo(
1761                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
1762                ),
1763                "TestOptimRenewed",
1764                "test_set_default_dtype_works_with_foreach",
1765            ),
1766            DecorateInfo(
1767                skipIfTorchDynamo(
1768                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1769                ),
1770                "TestOptimRenewed",
1771                "test_complex_2d",
1772            ),
1773            DecorateInfo(
1774                toleranceOverride(
1775                    {
1776                        torch.float32: tol(atol=1.5e-5, rtol=1e-5),
1777                    }
1778                ),
1779                "TestOptimRenewed",
1780                "test_step_is_noop_for_zero_grads",
1781            ),
1782            DecorateInfo(
1783                skipIfTorchDynamo(
1784                    "This test uses mocks, which dynamo does not support"
1785                ),
1786                "TestOptimRenewed",
1787                "test_defaults_changed_to_foreach",
1788            ),
1789            DecorateInfo(
1790                unittest.skip(
1791                    "ASGD internally changes the weights even with zero grad"
1792                ),
1793                "TestOptimRenewed",
1794                "test_step_is_noop_for_zero_grads",
1795            ),
1796        ),
1797    ),
1798    OptimizerInfo(
1799        LBFGS,
1800        optim_inputs_func=optim_inputs_func_lbfgs,
1801        optim_error_inputs_func=optim_error_inputs_func_lbfgs,
1802        supported_impls=(),
1803        step_requires_closure=True,
1804        supports_param_groups=False,
1805        supports_multiple_devices=False,
1806        skips=(
1807            # Fails on macOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094
1808            DecorateInfo(
1809                skipIfMps, "TestOptimRenewed", "test_can_load_older_state_dict"
1810            ),
1811            DecorateInfo(
1812                toleranceOverride(
1813                    {
1814                        torch.complex64: tol(
1815                            rtol=4.5e-5,
1816                            atol=5e-5,
1817                        )
1818                    }
1819                ),
1820                "TestOptimRenewed",
1821                "test_complex_2d",
1822            ),
1823            DecorateInfo(
1824                unittest.skip("Does not support param groups"),
1825                "TestOptimRenewed",
1826                "test_param_groups_lr",
1827            ),
1828            DecorateInfo(
1829                unittest.skip("Does not support param groups"),
1830                "TestOptimRenewed",
1831                "test_param_groups_weight_decay",
1832            ),
1833            DecorateInfo(
1834                unittest.skip("LBFGS doesn't support multidevice"),
1835                "TestOptimRenewed",
1836                "test_forloop_goes_right_direction_multigpu",
1837            ),
1838            DecorateInfo(
1839                unittest.skip("Does not support param groups"),
1840                "TestOptimRenewed",
1841                "test_param_group_with_lrscheduler_goes_right_direction",
1842            ),
1843            DecorateInfo(
1844                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1845                "TestOptimRenewed",
1846                "test_tensor_lr",
1847                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1848            ),
1849            # https://github.com/pytorch/pytorch/issues/131398
1850            DecorateInfo(
1851                unittest.expectedFailure,
1852                "CompiledOptimizerParityTests",
1853                "test_correctness",
1854                active_if=lambda kwargs: sys.platform == "darwin"
1855                and kwargs["use_closure"],
1856            ),
1857        ),
1858    ),
1859    OptimizerInfo(
1860        NAdam,
1861        optim_inputs_func=optim_inputs_func_nadam,
1862        optim_error_inputs_func=optim_error_inputs_func_nadam,
1863        supported_impls=("foreach", "differentiable"),
1864        has_capturable_arg=True,
1865        skips=(
1866            DecorateInfo(
1867                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
1868                "TestOptimRenewed",
1869                "test_forloop_goes_right_direction",
1870                active_if=lambda kwargs: not kwargs["contiguous"],
1871            ),
1872            DecorateInfo(
1873                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1874                "TestOptimRenewed",
1875                "test_tensor_lr",
1876                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1877            ),
1878            DecorateInfo(
1879                skipIfTorchDynamo(
1880                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
1881                ),
1882                "TestOptimRenewed",
1883                "test_set_default_dtype_works_with_foreach",
1884            ),
1885            DecorateInfo(
1886                skipIfTorchDynamo(
1887                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1888                ),
1889                "TestOptimRenewed",
1890                "test_complex_2d",
1891            ),
1892            DecorateInfo(
1893                skipIfTorchDynamo(
1894                    "Errors, https://github.com/pytorch/pytorch/issues/117150"
1895                ),
1896                "TestOptimRenewed",
1897                "test_load_nontensor_step",
1898            ),
1899            DecorateInfo(
1900                skipIfTorchDynamo(
1901                    "This test uses mocks, which dynamo does not support"
1902                ),
1903                "TestOptimRenewed",
1904                "test_defaults_changed_to_foreach",
1905            ),
1906        ),
1907    ),
1908    OptimizerInfo(
1909        RAdam,
1910        optim_inputs_func=optim_inputs_func_radam,
1911        optim_error_inputs_func=optim_error_inputs_func_radam,
1912        supported_impls=("foreach", "differentiable"),
1913        has_capturable_arg=True,
1914        skips=(
1915            DecorateInfo(
1916                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1917                "TestOptimRenewed",
1918                "test_tensor_lr",
1919                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1920            ),
1921            DecorateInfo(
1922                skipIfTorchDynamo(
1923                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
1924                ),
1925                "TestOptimRenewed",
1926                "test_set_default_dtype_works_with_foreach",
1927            ),
1928            DecorateInfo(
1929                skipIfTorchDynamo(
1930                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1931                ),
1932                "TestOptimRenewed",
1933                "test_complex_2d",
1934            ),
1935            DecorateInfo(
1936                toleranceOverride(
1937                    {
1938                        # previously atol=1e-7, rtol=1e-7
1939                        torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
1940                    }
1941                ),
1942                "TestOptimRenewed",
1943                "test_foreach_matches_forloop",
1944            ),
1945            DecorateInfo(
1946                skipIfTorchDynamo(
1947                    "This test uses mocks, which dynamo does not support"
1948                ),
1949                "TestOptimRenewed",
1950                "test_defaults_changed_to_foreach",
1951            ),
1952        ),
1953    ),
1954    OptimizerInfo(
1955        RMSprop,
1956        optim_inputs_func=optim_inputs_func_rmsprop,
1957        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
1958        supported_impls=("foreach", "differentiable"),
1959        has_capturable_arg=True,
1960        skips=(
1961            DecorateInfo(
1962                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
1963                "TestOptimRenewed",
1964                "test_forloop_goes_right_direction",
1965                active_if=lambda kwargs: not kwargs["contiguous"],
1966            ),
1967            DecorateInfo(
1968                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
1969                "TestOptimRenewed",
1970                "test_tensor_lr",
1971                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
1972            ),
1973            DecorateInfo(
1974                skipIfTorchDynamo("See #116028"),
1975                "TestOptimRenewed",
1976                "test_set_default_dtype_works_with_foreach",
1977            ),
1978            DecorateInfo(
1979                skipIfTorchDynamo(
1980                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
1981                ),
1982                "TestOptimRenewed",
1983                "test_complex_2d",
1984            ),
1985            DecorateInfo(
1986                toleranceOverride(
1987                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
1988                        torch.float32: tol(atol=5e-04, rtol=0.01),
1989                    }
1990                ),
1991                "TestOptimRenewed",
1992                "test_mixed_device_dtype",
1993                active_if=TEST_WITH_TORCHDYNAMO,
1994            ),
1995            DecorateInfo(
1996                skipIfTorchDynamo(
1997                    "This test uses mocks, which dynamo does not support"
1998                ),
1999                "TestOptimRenewed",
2000                "test_defaults_changed_to_foreach",
2001            ),
2002        ),
2003    ),
2004    OptimizerInfo(
2005        Rprop,
2006        optim_inputs_func=optim_inputs_func_rprop,
2007        optim_error_inputs_func=optim_error_inputs_func_rprop,
2008        supported_impls=("foreach", "differentiable"),
2009        has_capturable_arg=True,
2010        skips=(
2011            DecorateInfo(
2012                skipIfMps,  # Rprop doesn't update for non-contiguous, see #118117
2013                "TestOptimRenewed",
2014                "test_forloop_goes_right_direction",
2015                active_if=lambda kwargs: not kwargs["contiguous"],
2016            ),
2017            DecorateInfo(
2018                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
2019                "TestOptimRenewed",
2020                "test_tensor_lr",
2021                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
2022            ),
2023            DecorateInfo(
2024                skipIfTorchDynamo("See #116028"),
2025                "TestOptimRenewed",
2026                "test_set_default_dtype_works_with_foreach",
2027            ),
2028            DecorateInfo(
2029                skipIfTorchDynamo(
2030                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
2031                ),
2032                "TestOptimRenewed",
2033                "test_complex_2d",
2034            ),
2035            DecorateInfo(
2036                skipIfTorchDynamo(
2037                    "This test uses mocks, which dynamo does not support"
2038                ),
2039                "TestOptimRenewed",
2040                "test_defaults_changed_to_foreach",
2041            ),
2042        ),
2043    ),
2044    OptimizerInfo(
2045        SGD,
2046        optim_inputs_func=optim_inputs_func_sgd,
2047        scheduler_inputs=(
2048            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
2049            [
2050                lambda opt: LinearLR(
2051                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
2052                )
2053            ],
2054            [
2055                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
2056                lambda opt: LinearLR(
2057                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
2058                ),
2059            ],
2060            [
2061                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
2062                lambda opt: ExponentialLR(opt, gamma=0.99),
2063                lambda opt: ReduceLROnPlateau(opt),
2064            ],
2065            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
2066            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
2067            [
2068                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
2069                lambda opt: ReduceLROnPlateau(opt),
2070            ],
2071        ),
2072        optim_error_inputs_func=optim_error_inputs_func_sgd,
2073        supported_impls=("foreach", "differentiable", "fused"),
2074        not_og_supported_flags=(
2075            "foreach",
2076            "differentiable",
2077            "fused",
2078            "maximize",
2079            "capturable",
2080        ),
2081        supports_sparse=True,
2082        metadata_for_sparse=(
2083            {
2084                "lr": 4.8e-3,
2085                "maximize": False,
2086                "momentum": 0,
2087                "nesterov": False,
2088                "weight_decay": 0,
2089            },
2090            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
2091        ),
2092        supports_fused_on=(
2093            "cpu",
2094            "cuda",
2095            "mps",
2096        ),
2097        skips=(
2098            DecorateInfo(
2099                skipIfTorchDynamo("Fails fixed point assertion on 3.8, see #97811"),
2100                "TestOptimRenewed",
2101                "test_tensor_lr",
2102                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
2103            ),
2104            DecorateInfo(
2105                skipIfTorchDynamo(
2106                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
2107                ),
2108                "TestOptimRenewed",
2109                "test_set_default_dtype_works_with_foreach",
2110            ),
2111            DecorateInfo(
2112                skipIfTorchDynamo(
2113                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
2114                ),
2115                "TestOptimRenewed",
2116                "test_complex_2d",
2117            ),
2118            DecorateInfo(
2119                toleranceOverride(
2120                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
2121                        torch.float32: tol(atol=5e-04, rtol=0.007),
2122                    }
2123                ),
2124                "TestOptimRenewed",
2125                "test_mixed_device_dtype",
2126                active_if=TEST_WITH_TORCHDYNAMO,
2127            ),
2128            DecorateInfo(
2129                skipIfTorchDynamo(
2130                    "This test uses mocks, which dynamo does not support"
2131                ),
2132                "TestOptimRenewed",
2133                "test_defaults_changed_to_foreach",
2134            ),
2135        ),
2136    ),
2137    OptimizerInfo(
2138        SparseAdam,
2139        optim_inputs_func=optim_inputs_func_sparseadam,
2140        optim_error_inputs_func=optim_error_inputs_func_sparseadam,
2141        supported_impls=(),
2142        only_supports_sparse_grads=True,
2143        metadata_for_sparse=({"lr": 4e-2}, []),
2144        supports_complex=False,  # Missing complex support, see #118153
2145        skips=(
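            # A DecorateInfo that omits the test name applies to every test in the given
            # class; omitting the class as well applies it to all test classes.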
2146            DecorateInfo(
2147                skipIfMps,  # SparseAdam does not support MPS
2148                "TestOptimRenewed",
2149            ),
2150            DecorateInfo(
2151                skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"),
2152            ),
2153            DecorateInfo(
2154                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2155                "TestOptimRenewed",
2156                "test_param_groups_lr",
2157            ),
2158            DecorateInfo(
2159                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2160                "TestOptimRenewed",
2161                "test_tensor_lr",
2162            ),
2163            DecorateInfo(
2164                unittest.skip(
2165                    "SparseAdam does not support dense gradients, see #116507"
2166                ),
2167                "TestOptimRenewed",
2168                "test_can_load_older_state_dict",
2169            ),
2170            DecorateInfo(
2171                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2172                "TestOptimRenewed",
2173                "test_load_nontensor_step",
2174            ),
2175            DecorateInfo(
2176                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2177                "TestOptimRenewed",
2178                "test_forloop_goes_right_direction",
2179            ),
2180            DecorateInfo(
2181                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2182                "TestOptimRenewed",
2183                "test_forloop_goes_right_direction_multigpu",
2184            ),
2185            DecorateInfo(
2186                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2187                "TestOptimRenewed",
2188                "test_param_group_with_lrscheduler_goes_right_direction",
2189            ),
2190            DecorateInfo(
2191                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2192                "TestOptimRenewed",
2193                "test_state_dict_with_cuda_params",
2194            ),
2195            DecorateInfo(
2196                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
2197                "TestOptimRenewed",
2198                "test_deepcopy_copies_all_public_attrs",
2199            ),
2200        ),
2201    ),
2202]
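
# A minimal sketch (hypothetical, not used by the test suite) of how a test could
# consume an OptimizerInfo entry from the database above: optim_inputs_func yields
# OptimizerInput objects whose kwargs (but not params) are filled in, and the caller
# supplies the parameters to optimize. The helper name and signature are illustrative.
def _example_construct_optimizers(optim_info, params, device):
    optimizers = []
    for optim_input in optim_info.optim_inputs_func(device=device):
        # optim_input.params is typically None here; the caller owns the params.
        optimizers.append(optim_info.optim_cls(params, **optim_input.kwargs))
    return optimizers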
2203
2204
2205class TensorTracker:
2206    """
2207    A utility to track tensor clones in a list, with the expectation of popping them later (in
2208    order) to make fair comparisons between two multi-step computations. The intended use case is
2209    comparing two supposedly equal computations, such as two optimizer steps that each consist of
2210    multiple sub-steps, where numerical deviations could otherwise compound.
2211
2212    The goal is to compare and realign values at every milestone so as to minimize numerical
2213    discrepancies, making it likely that any test failure reflects a real problem.
2214    """
2215
2216    def __init__(self, assert_eq_kwargs=None):
2217        if assert_eq_kwargs is None:
2218            assert_eq_kwargs = {}
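        # Extra kwargs (e.g., atol/rtol overrides) forwarded to testcase.assertEqual in pop_check_set.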
2219        self.assert_eq_kwargs = assert_eq_kwargs
2220        self.tensors = []
2221
2222    def add(self, tensor):
2223        """
2224        Add a clone().detach()'d version of the tensor
2225        Add a clone().detach()'d copy of the tensor, so later in-place updates to the original do not affect the stored value.
2226        self.tensors.append(tensor.clone().detach())
2227
2228    # pops from beginning, like a queue and not a stack!
2229    # pops from the beginning, like a queue rather than a stack!
2230        """
2231        Pop the first element in the tensor tracker, assert equality between the popped tensor and
2232        the input tensor, and then set the input tensor to have the same values as the popped tensor
2233        (with copy_).
2234        """
2235        testcase.assertGreater(len(self.tensors), 0, "no tensors to pop")
2236        ref = self.tensors.pop(0)
2237
2238        testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
2239        testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)
2240
2241        with torch.no_grad():
2242            tensor_to_set.copy_(ref)
2243
2244    def all_popped(self):
2245        return len(self.tensors) == 0
2246
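
# A minimal usage sketch of TensorTracker (hypothetical, not part of the test suite):
# record intermediates from a reference computation, then check and realign a second
# computation against them step by step so rounding differences cannot compound.
# "testcase" is assumed to be a torch.testing._internal.common_utils.TestCase (or any
# TestCase whose assertEqual can compare tensors).
def _example_tensor_tracker_usage(testcase):
    tracker = TensorTracker()

    # Reference computation: store a snapshot after every step.
    ref = torch.ones(3)
    for _ in range(5):
        ref = ref * 1.01 + 0.1
        tracker.add(ref)

    # Second computation: after each step, assert it matches the stored snapshot and
    # copy the snapshot back in so the next step starts from identical values.
    other = torch.ones(3)
    for _ in range(5):
        other = other * 1.01 + 0.1
        tracker.pop_check_set(other, testcase)

    testcase.assertTrue(tracker.all_popped())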