# mypy: ignore-errors

import functools
import itertools
import sys
import unittest
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import (
    Adadelta,
    Adafactor,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    LBFGS,
    NAdam,
    Optimizer,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
    SparseAdam,
)
from torch.optim.lr_scheduler import (
    ConstantLR,
    ExponentialLR,
    LinearLR,
    PolynomialLR,
    ReduceLROnPlateau,
    StepLR,
)
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import (
    _TestParametrizer,
    skipIfMps,
    skipIfTorchDynamo,
    skipIfXpu,
    TEST_WITH_TORCHDYNAMO,
)
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices


class OptimizerInput:
    """Contains args / kwargs to be passed to an optimizer constructor."""

    __slots__ = ["params", "kwargs", "desc"]

    def __init__(
        self,
        params: Union[List[Parameter], List[Tensor], Dict[Any, Any]],
        kwargs: Dict[str, Any],
        desc: str = "",
    ):
        # params can be a list of Tensors OR param_groups OR None
        self.params = params
        self.kwargs = kwargs
        self.desc = desc

    def __repr__(self):
        return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}"


class OptimizerErrorEnum(Enum):
    """Enumerates when an error is raised when testing optimizers."""

    CONSTRUCTION_ERROR = 0
    STEP_ERROR = 1


class ErrorOptimizerInput:
    """
    An OptimizerInput that will cause the optimizer to throw an error when constructed.
    Includes the type and string of the resulting error.
    """

    __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]

    def __init__(
        self,
        optimizer_error_input,
        *,
        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
        error_type=RuntimeError,
        error_regex="",
    ):
        self.optimizer_error_input = optimizer_error_input
        self.error_on = error_on
        self.error_type = error_type
        self.error_regex = error_regex


class OptimizerInfo:
    """Optimizer information to be used in testing."""

    def __init__(
        self,
        optim_cls: Optimizer,  # Class object for the Optimizer under test
        *,
        # Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility
        # to the test using the OptimizerInfo. OptimizerInput.params is likely None.
        # Can optionally take in device to filter out certain unsupported configs
        optim_inputs_func,
        # Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
        # LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
        # We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
        # LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
        # A few optimizers like SGD and Adam will test more LRSchedulers.
        scheduler_inputs=(
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
        # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
        supported_impls: Tuple[str, ...] = ("foreach", "differentiable"),
        # A subset of all flags, signifying which ones were only supported after the
        # original optimizer had already been released. aka impls where we need to check BC.
        not_og_supported_flags: Tuple[str, ...] = (
            "foreach",
            "differentiable",
            "maximize",
            "capturable",
        ),
        # the optim supports passing in sparse gradients as well as dense grads
        supports_sparse: bool = False,
        # the optimizer constructor supports passing in capturable as a kwarg
        has_capturable_arg: bool = False,
        # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
        only_supports_sparse_grads: bool = False,
        # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
        # with especially tuned hyperparameters. These only apply if the optimizer supports
        # sparse parameters or grads.
        metadata_for_sparse=({}, []),
        # the optim supports complex parameters
        supports_complex: bool = True,
        # whether the optimizer.step() function requires a closure to be passed
        step_requires_closure: bool = False,
        # whether the optimizer supports per-param options with parameter groups
        supports_param_groups: bool = True,
        # whether the optimizer supports parameters on multiple devices
        supports_multiple_devices: bool = True,
        skips=(),  # Indicates which tests to skip
        decorators=None,  # Additional decorators to apply to generated tests
        optim_error_inputs_func=None,  # Function to generate optim inputs that error
        supports_fused_on: Tuple[str, ...] = (),
    ):
        self.optim_cls = optim_cls
        self.optim_inputs_func = optim_inputs_func
        self.scheduler_inputs = scheduler_inputs
        self.supported_impls = supported_impls
        self.not_og_supported_flags = not_og_supported_flags
        self.supports_sparse = supports_sparse
        self.has_capturable_arg = has_capturable_arg
        self.metadata_for_sparse = metadata_for_sparse
        self.only_supports_sparse_grads = only_supports_sparse_grads
        self.supports_complex = supports_complex
        self.step_requires_closure = step_requires_closure
        self.supports_param_groups = supports_param_groups
        self.supports_multiple_devices = supports_multiple_devices
        self.decorators = (
            *(decorators if decorators else []),
            *(skips if skips else []),
        )
        self.optim_error_inputs_func = optim_error_inputs_func
        self.supports_fused_on = supports_fused_on

    def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
        result = []
        for decorator in self.decorators:
            if isinstance(decorator, DecorateInfo):
                if decorator.is_active(
                    test_class, test_name, device, dtype, param_kwargs
                ):
                    result.extend(decorator.decorators)
            else:
                result.append(decorator)
        return result

    @property
    def name(self):
        return self.optim_cls.__name__


class optims(_TestParametrizer):
    """Decorator for specifying a list of optimizers over which to run a test."""

    def __init__(self, optim_info_iterable, dtypes=None):
        self.optim_info_list = list(optim_info_iterable)

        # optimizers aren't limited to be one dtype as parameters can have different dtypes
        # We default to torch.float32, but dtypes should be specified through passed in
        # parameters.
        self.dtypes = dtypes if dtypes is not None else [torch.float32]

    def _parametrize_test(self, test, generic_cls, device_cls):
        if device_cls is None:
            raise RuntimeError(
                "The @optims decorator is only intended to be used in a device-specific "
                "context; use it with instantiate_device_type_tests() instead of "
                "instantiate_parametrized_tests()"
            )

        for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes):
            # Construct the test name; device / dtype parts are handled outside.
            # See [Note: device and dtype suffix placement]
            test_name = optim_info.name

            # Construct parameter kwargs to pass to the test.
            param_kwargs = {"optim_info": optim_info, "dtype": dtype}

            try:

                @functools.wraps(test)
                def test_wrapper(*args, **kwargs):
                    return test(*args, **kwargs)

                decorator_fn = functools.partial(
                    optim_info.get_decorators,
                    generic_cls.__name__,
                    test.__name__,
                    device_cls.device_type,
                    dtype,
                )

                yield (test_wrapper, test_name, param_kwargs, decorator_fn)
            except Exception as ex:
                # Provides an error message for debugging before rethrowing the exception
                print(
                    f"Failed to instantiate {test_name} for module {optim_info.name}!"
                )
                raise ex


# Helper function for generating error inputs for all optimizers, used below.
def get_error_inputs_for_all_optims(device, dtype):
    if _get_device_type(device) == "cpu":
        sample_param = Parameter(torch.randn(1, device=device, dtype=dtype))
        return [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=sample_param,
                    kwargs={},
                    desc="invalid param type",
                ),
                error_type=TypeError,
                error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[sample_param, sample_param],
                    kwargs={},
                    desc="a param group cannot have duplicate parameters",
                ),
                error_type=UserWarning,
                error_regex=".*a parameter group with duplicate parameters.*",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[{"params": sample_param}, {"params": sample_param}],
                    kwargs={},
                    desc="duplicate parameters should not occur across param groups either",
                ),
                error_type=ValueError,
                error_regex="some parameters appear in more than one parameter group",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=torch.tensor([0.001, 0.001])),
                    desc="Tensor lr must be 1-element",
                ),
                error_type=ValueError,
                error_regex="Tensor lr must be 1-element",
            ),
        ]
    else:
        return []


# ------------------------------------------------------------------------------------------
# NOTE: [optimizer kwarg categories]
# We categorize optimizer kwargs as 3 types:
# 1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to
#    algorithms and thus only show up for certain optimizers. There are many of these, so I
#    do not bother gathering them all and listing them here. The converse to these would be
#    global flags that every optimizer ideally _should_ support. We break global flags into
#    2 further categories and list them all below.
# 2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"]
#    global-friendly flags are global flags who play nicely with all other global flags,
#    i.e., they do not conflict with one another in function. This means that any pair of the
#    following flags can be toggled at once (e.g., maximize and weight_decay). Furthermore,
#    any of the following flags theoretically can be enabled with ANY other global flag,
#    including the cliquey ones (e.g., capturable and foreach).
# 3. global-cliquey = ["foreach", "fused", "differentiable"]
#    global-cliquey flags are global flags that do NOT coexist with other cliquey flags,
#    usually because they contradict each other in function. For example, one should not flip
#    both foreach AND fused to True, because they are two differing performance optimizations
#    in which you can only opt into one.
#
# The following optim_inputs_func_* sampling functions only return constructor combinations of
# optimizer-specific and global-friendly flags. This is because we are confident they would mesh
# well with additional kwargs. On the flip side of the same coin, we reserve setting the
# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.


def optim_inputs_func_adadelta(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "capturable": True},
            desc="capturable with weight decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "capturable": True},
            desc="Tensor lr with capturable",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize, weight_decay",
        ),
        OptimizerInput(
            params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_adadelta(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, rho=1.1),
                    desc="rho should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid rho value: 1.1",
            ),
        ]
    return error_inputs


def optim_inputs_func_adafactor(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "lr": 0.01},
            desc="nonzero weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"beta2_decay": -1.0},
            desc="non-default beta2_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"d": 1.5},
            desc="non-default clipping threshold d",
        ),
    ]


def optim_error_inputs_func_adafactor(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
        complex_param.grad = torch.rand_like(complex_param)
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(eps=(-1e-30, 1e-3)),
                    desc="epsilon1 should be >= 0",
                ),
                error_type=ValueError,
                error_regex="epsilon1 should be >= 0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(d=0.0),
                    desc="invalid d",
                ),
                error_type=ValueError,
                error_regex="Clipping threshold d should be >= 1",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(beta2_decay=0.8),
                    desc="invalid beta2_decay",
                ),
                error_type=ValueError,
                error_regex="beta2_decay should be <= 0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[complex_param],
                    kwargs=dict(),
                    desc="does not support complex parameters",
                ),
                error_type=RuntimeError,
                error_regex="Adafactor does not support complex parameters",
                error_on=OptimizerErrorEnum.STEP_ERROR,
            ),
        ]
    return error_inputs


def optim_inputs_func_adagrad(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
        OptimizerInput(
            params=None,
            kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1},
            desc="initial_accumulator_value",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1},
            desc="lr_decay",
        ),  # TODO: Move out to testing in param_group?
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001)},
            desc="Tensor lr",
        ),
    ]


def optim_error_inputs_func_adagrad(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, lr_decay=-0.5),
                    desc="lr_decay must be bigger than 0",
                ),
                error_type=ValueError,
                error_regex="Invalid lr_decay value: -0.5",
            ),
        ]
    return error_inputs


# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
# with all implementation code paths...
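# Illustrative sketch only (the `_example_with_tensor_lr` helper below is hypothetical and not
# used by the sampling functions or tests in this file): one way to act on the TODO above would
# be to derive a tensor-LR variant from any existing config, so both a float lr and a Tensor lr
# can flow through every implementation code path a test chooses to exercise.
def _example_with_tensor_lr(optim_input: OptimizerInput) -> OptimizerInput:
    # Copy the kwargs and wrap the (possibly defaulted) lr into a 0-dim Tensor.
    new_kwargs = deepcopy(optim_input.kwargs)
    new_kwargs["lr"] = torch.tensor(float(new_kwargs.get("lr", 1e-3)))
    return OptimizerInput(
        params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & tensor LR"
    )

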
def optim_inputs_func_adam(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
            desc="capturable, amsgrad",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
            desc="Tensor lr with capturable and amsgrad",
        ),
    ]
    mps_supported_configs = [
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr"
        ),
    ]

    total = (
        [
            OptimizerInput(params=None, kwargs={}, desc="default"),
            OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
            OptimizerInput(
                params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
            ),
            OptimizerInput(
                params=None,
                kwargs={"weight_decay": 0.1, "maximize": True},
                desc="maximize",
            ),
            OptimizerInput(
                params=None,
                kwargs={"weight_decay": 0.1, "amsgrad": True},
                desc="amsgrad",
            ),
        ]
        + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
        + (mps_supported_configs if _get_device_type(device) == "mps" else [])
    )
    if dtype in (torch.float16,):
        for input in total:
            """
            Too small an eps will make denom zero for low precision dtypes
            denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
            For example,
            >>> a
            tensor([0.], dtype=torch.float16)
            >>> a + 1e-8
            tensor([0.], dtype=torch.float16)
            """
            input.kwargs["eps"] = 0.1
    return total


def optim_error_inputs_func_adam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, weight_decay=-1),
                    desc="weight_decay should be > 0",
                ),
                error_type=ValueError,
                error_regex="Invalid weight_decay value: -1",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=torch.tensor(0.001), foreach=True),
                    desc="lr as Tensor doesn't work with foreach & not capturable",
                ),
                error_type=ValueError,
                error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
            ),
        ]
    if _get_device_type(device) == "cuda":
        sample_tensor = torch.empty((), device=device, dtype=dtype)
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[sample_tensor],
                    kwargs={"foreach": True, "fused": True},
                    desc="`fused` and `foreach` cannot be `True` together",
                ),
                error_type=RuntimeError,
                error_regex="`fused` and `foreach` cannot be `True` together",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[sample_tensor],
                    kwargs={"fused": True, "differentiable": True},
                    desc="`fused` does not support `differentiable`",
                ),
                error_type=RuntimeError,
                error_regex="`fused` does not support `differentiable`",
            ),
        ]
    return error_inputs


def optim_inputs_func_adamax(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True},
            desc="capturable, maximize, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0, "maximize": True, "capturable": True},
            desc="capturable, maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True},
            desc="capturable, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "weight_decay": 0.9,
                "maximize": False,
                "capturable": True,
            },
            desc="capturable, weight_decay, tensor LR",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize, weight_decay",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_adamax(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
                    desc="beta2 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 1: 1.0",
            ),
        ]
    return error_inputs


def optim_inputs_func_adamw(device, dtype=None):
    return optim_inputs_func_adam(device, dtype)


def optim_error_inputs_func_adamw(device, dtype):
    return optim_error_inputs_func_adam(device, dtype)


def optim_inputs_func_asgd(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"maximize": True, "capturable": True},
            desc="maximize, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "capturable": True},
            desc="weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
            desc="maximize, weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "weight_decay": 0.1,
                "maximize": True,
                "capturable": True,
            },
            desc="maximize, weight_decay, capturable, tensor LR",
        ),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
        OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
        OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize, nonzero weight_decay",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_asgd(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, weight_decay=-0.5),
                    desc="weight_decay should be > 0",
                ),
                error_type=ValueError,
                error_regex="Invalid weight_decay value: -0.5",
            ),
        ]
    return error_inputs


def optim_inputs_func_lbfgs(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
        ),
        OptimizerInput(
            params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
        ),
        OptimizerInput(
            params=None,
            kwargs={"line_search_fn": "strong_wolfe"},
            desc="strong_wolfe",
        ),
    ]


def optim_error_inputs_func_lbfgs(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    return error_inputs


def optim_inputs_func_nadam(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
            desc="weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.9,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
                "capturable": True,
            },
            desc="decoupled_weight_decay, capturable",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "weight_decay": 0.9,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
                "capturable": True,
            },
            desc="decoupled_weight_decay, capturable",
        ),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
        OptimizerInput(
            params=None,
            kwargs={"momentum_decay": 6e-3},
            desc="non-zero momentum_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.1,
            },
            desc="weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
            desc="weight_decay, momentum_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.1,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
            },
            desc="decoupled_weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_nadam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, momentum_decay=-0.2),
                    desc="momentum_decay should be > 0",
                ),
                error_type=ValueError,
                error_regex="Invalid momentum_decay value: -0.2",
            ),
        ]
    return error_inputs


# Weird story bro, NAdam and RAdam do not have maximize.
def optim_inputs_func_radam(device=None, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={
                "capturable": True,
                "weight_decay": 0.1,
            },
            desc="capturable, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "capturable": True,
                "weight_decay": 0.1,
                "decoupled_weight_decay": True,
            },
            desc="capturable, weight_decay, decoupled_weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "lr": torch.tensor(0.001),
                "capturable": True,
                "weight_decay": 0.1,
                "decoupled_weight_decay": True,
            },
            desc="capturable, weight_decay, decoupled_weight_decay, tensor LR",
        ),
    ]
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
        OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
            desc="decoupled_weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_radam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, weight_decay=-1),
                    desc="weight_decay should be > 0",
                ),
                error_type=ValueError,
                error_regex="Invalid weight_decay value: -1",
            ),
        ]
    return error_inputs


def optim_inputs_func_rmsprop(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
            desc="capturable, maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "capturable": True},
            desc="Tensor lr with capturable",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "maximize": True,
            },
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "centered": True},
            desc="centered",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "maximize": True,
                "weight_decay": 0.1,
            },
            desc="maximize, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
            desc="momentum",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.1,
                "centered": True,
                "momentum": 0.1,
                "maximize": True,
            },
            desc="maximize, centered, weight_decay, w/ momentum",
        ),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_rmsprop(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, momentum=-1.0),
                    desc="momentum should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid momentum value: -1.0",
            ),
        ]
    return error_inputs


def optim_inputs_func_rprop(device, dtype=None):
    cuda_supported_configs = [
        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
        OptimizerInput(
            params=None,
            kwargs={"lr": torch.tensor(0.001), "capturable": True},
            desc="Tensor lr with capturable",
        ),
    ]

    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas"
        ),
        OptimizerInput(
            params=None,
            kwargs={"step_sizes": (2e-6, 100)},
            desc="non-default step_sizes",
        ),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
    ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])


def optim_error_inputs_func_rprop(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
                    desc="0 < eta1 < 1 < eta2",
                ),
                error_type=ValueError,
                error_regex="Invalid eta values: 1.0, 0.5",
            ),
        ]
    return error_inputs


def optim_inputs_func_sgd(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
        ),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
        ),
        OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "dampening": 0.5},
            desc="dampening",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "weight_decay": 0.1},
            desc="weight_decay w/ momentum",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
            desc="nesterov",
        ),
    ]


def optim_error_inputs_func_sgd(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, momentum=-0.5),
                    desc="momentum should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid momentum value: -0.5",
            ),
        ]
    return error_inputs


def optim_inputs_func_sparseadam(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None, kwargs={"lr": 0.01}, desc="non-default lr"
        ),  # TODO: Move out to testing in param_group?
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
        ),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
    ]


def optim_error_inputs_func_sparseadam(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)

    if _get_device_type(device) == "cpu":
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
                    desc="beta1 should be between 0 and 1",
                ),
                error_type=ValueError,
                error_regex="Invalid beta parameter at index 0: 1.0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[
                        torch.zeros(
                            3, layout=torch.sparse_coo, device=device, dtype=dtype
                        )
                    ],
                    kwargs={},
                    desc="dense params required",
                ),
                error_type=ValueError,
                error_regex="SparseAdam requires dense parameter tensors",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[
                        {
                            "params": [
                                torch.zeros(
                                    3,
                                    layout=torch.sparse_coo,
                                    device=device,
                                    dtype=dtype,
                                )
                            ]
                        }
                    ],
                    kwargs={},
                    desc="dense params required in param_groups",
                ),
                error_type=ValueError,
                error_regex="SparseAdam requires dense parameter tensors",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[torch.rand(2, 3, device=device, dtype=torch.complex64)],
                    kwargs={},
                    desc="complex not supported",
                ),
                error_type=ValueError,
                error_regex="SparseAdam does not support complex parameters",
            ),
        ]
    return error_inputs


def _get_device_type(device: Union[str, torch.device]) -> str:
    # Returns the device type as a string, e.g., "cpu" or "cuda"
    if isinstance(device, torch.device):
        device = str(device.type)
    assert isinstance(device, str)
    return device.split(":")[0]


def _get_optim_inputs_including_global_cliquey_kwargs(
    device, dtype, optim_info, skip=()
) -> List[OptimizerInput]:
    """
    Return a list of all configs for a given optimizer as a list of OptimizerInputs,
    including configs that have supported global cliquey kwargs (foreach, fused,
    differentiable) based on optim_info.supported_impls.

    The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
    intentionally do NOT include global cliquey kwargs to give flexibility to tests.
    For example, testing correctness between toggling foreach on and off is now
    trivial. That said, we sometimes want to test for all possible configs on an
    optimizer including all supported flags, so this helper returns all optim inputs.
1194 """ 1195 assert all( 1196 x in ["foreach", "fused", "differentiable"] for x in skip 1197 ), "skip must be a subset of ['foreach', 'fused', 'differentiable']" 1198 1199 optim_inputs = optim_info.optim_inputs_func(device) 1200 1201 supported_impls = tuple( 1202 x 1203 for x in optim_info.supported_impls 1204 if x not in skip 1205 and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused") 1206 and ( 1207 _get_device_type(device) in _get_foreach_kernels_supported_devices() 1208 or x != "foreach" 1209 ) 1210 ) 1211 1212 all_optim_inputs = [] 1213 for optim_input in optim_inputs: 1214 # Add the base config where all the flags are False 1215 base_kwargs = deepcopy(optim_input.kwargs) 1216 if len(supported_impls) != 0: 1217 for flag in supported_impls: 1218 base_kwargs[flag] = False 1219 all_optim_inputs.append( 1220 OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc) 1221 ) 1222 else: 1223 all_optim_inputs.append(optim_input) 1224 # Add a config for when each of the global cliquey kwargs is True 1225 # Note that in [optimizer kwarg categories], these kwargs are mutually 1226 # exclusive, so we do not need to product them together. 1227 for flag in supported_impls: 1228 new_kwargs = deepcopy(base_kwargs) 1229 new_kwargs[flag] = True 1230 all_optim_inputs.append( 1231 OptimizerInput( 1232 params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}" 1233 ) 1234 ) 1235 return all_optim_inputs 1236 1237 1238# Database of OptimizerInfo entries in alphabetical order. 1239optim_db: List[OptimizerInfo] = [ 1240 OptimizerInfo( 1241 Adadelta, 1242 optim_inputs_func=optim_inputs_func_adadelta, 1243 optim_error_inputs_func=optim_error_inputs_func_adadelta, 1244 supported_impls=("foreach", "differentiable"), 1245 has_capturable_arg=True, 1246 skips=( 1247 DecorateInfo( 1248 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1249 "TestOptimRenewed", 1250 "test_tensor_lr", 1251 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1252 ), 1253 DecorateInfo( 1254 skipIfTorchDynamo("See #116028"), 1255 "TestOptimRenewed", 1256 "test_set_default_dtype_works_with_foreach", 1257 ), 1258 DecorateInfo( 1259 skipIfTorchDynamo( 1260 "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" 1261 ), 1262 "TestOptimRenewed", 1263 "test_complex_2d", 1264 ), 1265 # Note on tolerances: 1266 # test_correctness_Adadelta_cuda_float32 1267 # Mismatched elements: 10 / 100 (10.0%) 1268 # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed) 1269 # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed) 1270 # This is due to floating point ordering error + usage of sqrt 1271 DecorateInfo( 1272 toleranceOverride( 1273 { 1274 torch.float32: tol( 1275 rtol=5.5e-4, 1276 atol=5e-5, 1277 ) 1278 } 1279 ), 1280 "CompiledOptimizerParityTests", 1281 "test_correctness", 1282 ), 1283 DecorateInfo( 1284 skipIfTorchDynamo( 1285 "This test uses mocks, which dynamo does not support" 1286 ), 1287 "TestOptimRenewed", 1288 "test_defaults_changed_to_foreach", 1289 ), 1290 ), 1291 ), 1292 OptimizerInfo( 1293 Adafactor, 1294 optim_inputs_func=optim_inputs_func_adafactor, 1295 optim_error_inputs_func=optim_error_inputs_func_adafactor, 1296 supported_impls=("foreach",), 1297 not_og_supported_flags=("foreach",), 1298 supports_complex=False, 1299 skips=( 1300 DecorateInfo( 1301 unittest.skip("See #133268 regarding dtype being None"), 1302 "CompiledOptimizerParityTests", 1303 
"test_correctness", 1304 device_type="cuda", 1305 active_if=lambda kwargs: kwargs.get("use_closure", False), 1306 ), 1307 DecorateInfo( 1308 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1309 "TestOptimRenewed", 1310 "test_can_load_older_state_dict", 1311 device_type="cuda", 1312 ), 1313 DecorateInfo( 1314 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1315 "TestOptimRenewed", 1316 "test_deepcopy_copies_all_public_attrs", 1317 device_type="cuda", 1318 ), 1319 DecorateInfo( 1320 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1321 "TestOptimRenewed", 1322 "test_foreach_large_tensor", 1323 ), 1324 DecorateInfo( 1325 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1326 "TestOptimRenewed", 1327 "test_foreach_matches_forloop", 1328 ), 1329 DecorateInfo( 1330 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1331 "TestOptimRenewed", 1332 "test_load_nontensor_step", 1333 device_type="cuda", 1334 ), 1335 DecorateInfo( 1336 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1337 "TestOptimRenewed", 1338 "test_mixed_device_dtype", 1339 ), 1340 DecorateInfo( 1341 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1342 "TestOptimRenewed", 1343 "test_param_groups_lr", 1344 device_type="cuda", 1345 ), 1346 DecorateInfo( 1347 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1348 "TestOptimRenewed", 1349 "test_param_groups_weight_decay", 1350 device_type="cuda", 1351 ), 1352 DecorateInfo( 1353 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1354 "TestOptimRenewed", 1355 "test_peak_memory_foreach", 1356 ), 1357 DecorateInfo( 1358 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1359 "TestOptimRenewed", 1360 "test_save_load_equality_with_weights_only", 1361 device_type="cuda", 1362 ), 1363 DecorateInfo( 1364 skipIfTorchDynamo("See #116028 regarding copy not supported"), 1365 "TestOptimRenewed", 1366 "test_set_default_dtype_works_with_foreach", 1367 ), 1368 DecorateInfo( 1369 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1370 "TestOptimRenewed", 1371 "test_state_dict_deterministic", 1372 device_type="cuda", 1373 ), 1374 DecorateInfo( 1375 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1376 "TestOptimRenewed", 1377 "test_step_is_noop_for_zero_grads", 1378 device_type="cuda", 1379 ), 1380 DecorateInfo( 1381 unittest.skip("See #133268 regarding dtype being None"), 1382 "CompiledOptimizerParityTests", 1383 "test_correctness", 1384 device_type="xpu", 1385 ), 1386 DecorateInfo( 1387 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1388 "TestOptimRenewed", 1389 "test_can_load_older_state_dict", 1390 device_type="xpu", 1391 ), 1392 DecorateInfo( 1393 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1394 "TestOptimRenewed", 1395 "test_deepcopy_copies_all_public_attrs", 1396 device_type="xpu", 1397 ), 1398 DecorateInfo( 1399 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1400 "TestOptimRenewed", 1401 "test_load_nontensor_step", 1402 device_type="xpu", 1403 ), 1404 DecorateInfo( 1405 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1406 "TestOptimRenewed", 1407 "test_param_groups_lr", 1408 device_type="xpu", 1409 ), 1410 DecorateInfo( 1411 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1412 "TestOptimRenewed", 1413 "test_param_groups_weight_decay", 1414 device_type="xpu", 1415 ), 1416 DecorateInfo( 1417 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1418 "TestOptimRenewed", 1419 
"test_save_load_equality_with_weights_only", 1420 device_type="xpu", 1421 ), 1422 DecorateInfo( 1423 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1424 "TestOptimRenewed", 1425 "test_state_dict_deterministic", 1426 device_type="xpu", 1427 ), 1428 DecorateInfo( 1429 skipIfTorchDynamo("See #133268 regarding dtype being None"), 1430 "TestOptimRenewed", 1431 "test_step_is_noop_for_zero_grads", 1432 device_type="xpu", 1433 ), 1434 ), 1435 ), 1436 OptimizerInfo( 1437 Adagrad, 1438 optim_inputs_func=optim_inputs_func_adagrad, 1439 optim_error_inputs_func=optim_error_inputs_func_adagrad, 1440 supported_impls=("foreach", "differentiable", "fused"), 1441 not_og_supported_flags=( 1442 "foreach", 1443 "differentiable", 1444 "fused", 1445 "maximize", 1446 "capturable", 1447 ), 1448 supports_fused_on=("cpu",), 1449 supports_sparse=True, 1450 metadata_for_sparse=( 1451 {"lr": 0.1, "weight_decay": 0, "lr_decay": 0}, 1452 [ 1453 lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500), 1454 lambda opt: ReduceLROnPlateau(opt, threshold=1e-4), 1455 ], 1456 ), 1457 decorators=( 1458 DecorateInfo( 1459 # Note on tolerances: 1460 # difference comes from the fact that the non fused kernel have 1461 # more dtype cast operations. We have another test test_fused_cpu_matches_cuda 1462 # to make sure there is no discrepancies between cuda fused kernel 1463 # and cpu fused kernel 1464 toleranceOverride( 1465 { 1466 torch.bfloat16: tol(atol=5e-3, rtol=5e-3), 1467 torch.float16: tol(atol=5e-3, rtol=5e-3), 1468 } 1469 ), 1470 "TestOptimRenewed", 1471 "test_fused_matches_forloop", 1472 ), 1473 ), 1474 skips=( 1475 DecorateInfo( 1476 skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 1477 "TestOptimRenewed", 1478 "test_forloop_goes_right_direction", 1479 active_if=lambda kwargs: not kwargs["contiguous"], 1480 ), 1481 DecorateInfo( 1482 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1483 "TestOptimRenewed", 1484 "test_tensor_lr", 1485 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1486 ), 1487 DecorateInfo( 1488 skipIfTorchDynamo("See #116028"), 1489 "TestOptimRenewed", 1490 "test_set_default_dtype_works_with_foreach", 1491 ), 1492 DecorateInfo( 1493 skipIfTorchDynamo( 1494 "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" 1495 ), 1496 "TestOptimRenewed", 1497 "test_complex_2d", 1498 ), 1499 DecorateInfo( 1500 skipIfTorchDynamo( 1501 "This test uses mocks, which dynamo does not support" 1502 ), 1503 "TestOptimRenewed", 1504 "test_defaults_changed_to_foreach", 1505 ), 1506 ), 1507 ), 1508 OptimizerInfo( 1509 Adam, 1510 optim_inputs_func=optim_inputs_func_adam, 1511 scheduler_inputs=( 1512 [lambda opt: ExponentialLR(opt, gamma=0.9)], 1513 [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)], 1514 [ 1515 lambda opt: ConstantLR(opt, factor=0.4, total_iters=4), 1516 lambda opt: ExponentialLR(opt, gamma=0.9), 1517 ], 1518 [ 1519 lambda opt: ExponentialLR(opt, gamma=0.9), 1520 lambda opt: ReduceLROnPlateau(opt), 1521 ], 1522 [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], 1523 [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], 1524 [ 1525 lambda opt: StepLR(opt, gamma=0.9, step_size=10), 1526 lambda opt: ReduceLROnPlateau(opt), 1527 ], 1528 ), 1529 optim_error_inputs_func=optim_error_inputs_func_adam, 1530 supported_impls=("foreach", "differentiable", "fused"), 1531 has_capturable_arg=True, 1532 not_og_supported_flags=( 1533 "foreach", 1534 "differentiable", 1535 "fused", 1536 "maximize", 1537 
"capturable", 1538 ), 1539 supports_fused_on=("cpu", "cuda", "mps"), 1540 decorators=( 1541 # Expected floating point error between fused and compiled forloop 1542 DecorateInfo( 1543 toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}), 1544 "TestOptimRenewed", 1545 "test_fused_matches_forloop", 1546 active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO 1547 and kwargs["dtype"] == torch.float64, 1548 ), 1549 DecorateInfo( 1550 # Note on tolerances: 1551 # difference comes from the fact that the non fused kernel have 1552 # more dtype cast operations. We have another test test_fused_cpu_matches_cuda 1553 # to make sure there is no discrepancies between cuda fused kernel 1554 # and cpu fused kernel 1555 toleranceOverride( 1556 { 1557 torch.bfloat16: tol(atol=5e-3, rtol=5e-3), 1558 torch.float16: tol(atol=5e-3, rtol=5e-3), 1559 } 1560 ), 1561 "TestOptimRenewed", 1562 "test_fused_matches_forloop", 1563 ), 1564 DecorateInfo( 1565 # Note on tolerances: 1566 # Tracking through #127000 1567 toleranceOverride( 1568 { 1569 torch.float32: tol(atol=3e-5, rtol=1.3e-06), 1570 } 1571 ), 1572 "TestCudaOptims", 1573 "test_grad_scaling_autocast_fused_optimizers", 1574 ), 1575 ), 1576 skips=( 1577 DecorateInfo( 1578 skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 1579 "TestOptimRenewed", 1580 "test_forloop_goes_right_direction", 1581 active_if=lambda kwargs: not kwargs["contiguous"], 1582 ), 1583 DecorateInfo( 1584 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1585 "TestOptimRenewed", 1586 "test_tensor_lr", 1587 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1588 ), 1589 DecorateInfo( 1590 skipIfTorchDynamo( 1591 "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" 1592 ), 1593 "TestOptimRenewed", 1594 "test_set_default_dtype_works_with_foreach", 1595 ), 1596 DecorateInfo( 1597 skipIfTorchDynamo( 1598 "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" 1599 ), 1600 "TestOptimRenewed", 1601 "test_complex_2d", 1602 ), 1603 DecorateInfo( 1604 skipIfTorchDynamo( 1605 "This test uses mocks, which dynamo does not support" 1606 ), 1607 "TestOptimRenewed", 1608 "test_defaults_changed_to_foreach", 1609 ), 1610 ), 1611 ), 1612 OptimizerInfo( 1613 Adamax, 1614 optim_inputs_func=optim_inputs_func_adamax, 1615 optim_error_inputs_func=optim_error_inputs_func_adamax, 1616 supported_impls=("foreach", "differentiable"), 1617 has_capturable_arg=True, 1618 skips=( 1619 DecorateInfo( 1620 skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 1621 "TestOptimRenewed", 1622 "test_forloop_goes_right_direction", 1623 active_if=lambda kwargs: not kwargs["contiguous"], 1624 ), 1625 DecorateInfo( 1626 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1627 "TestOptimRenewed", 1628 "test_tensor_lr", 1629 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1630 ), 1631 DecorateInfo( 1632 skipIfTorchDynamo("See #116028"), 1633 "TestOptimRenewed", 1634 "test_set_default_dtype_works_with_foreach", 1635 ), 1636 DecorateInfo( 1637 skipIfTorchDynamo( 1638 "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" 1639 ), 1640 "TestOptimRenewed", 1641 "test_complex_2d", 1642 ), 1643 DecorateInfo( 1644 unittest.skip("Uses too much memory, even for H100, surprisingly."), 1645 "TestOptimRenewed", 1646 "test_foreach_large_tensor", 1647 ), 1648 DecorateInfo( 1649 skipIfTorchDynamo( 1650 "This test uses mocks, which dynamo does not support" 1651 ), 1652 
"TestOptimRenewed", 1653 "test_defaults_changed_to_foreach", 1654 ), 1655 ), 1656 ), 1657 OptimizerInfo( 1658 AdamW, 1659 optim_inputs_func=optim_inputs_func_adamw, 1660 optim_error_inputs_func=optim_error_inputs_func_adamw, 1661 supported_impls=("foreach", "differentiable", "fused"), 1662 not_og_supported_flags=( 1663 "foreach", 1664 "differentiable", 1665 "fused", 1666 "maximize", 1667 "capturable", 1668 ), 1669 supports_fused_on=("cpu", "cuda", "mps"), 1670 has_capturable_arg=True, 1671 decorators=( 1672 # Expected error between compiled forloop and fused optimizers 1673 DecorateInfo( 1674 toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}), 1675 "TestOptimRenewed", 1676 "test_fused_matches_forloop", 1677 active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO 1678 and kwargs["dtype"] == torch.float64, 1679 ), 1680 DecorateInfo( 1681 toleranceOverride( 1682 # Note on tolerances: 1683 # difference comes from the fact that the non fused kernel have 1684 # more dtype cast operations. We have another test test_fused_cpu_matches_cuda 1685 # to make sure there is no discrepancies between cuda fused kernel 1686 # and cpu fused kernel 1687 { 1688 torch.bfloat16: tol(atol=5e-3, rtol=5e-3), 1689 torch.float16: tol(atol=5e-3, rtol=5e-3), 1690 } 1691 ), 1692 "TestOptimRenewed", 1693 "test_fused_matches_forloop", 1694 ), 1695 # Note on tolerances: 1696 # Tracking through #127000 1697 DecorateInfo( 1698 toleranceOverride( 1699 { 1700 torch.float32: tol( 1701 atol=3e-5, 1702 rtol=1.3e-06, 1703 ) 1704 } 1705 ), 1706 "TestCudaOptims", 1707 "test_grad_scaling_autocast_fused_optimizers", 1708 ), 1709 ), 1710 skips=( 1711 DecorateInfo( 1712 skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 1713 "TestOptimRenewed", 1714 "test_forloop_goes_right_direction", 1715 active_if=lambda kwargs: not kwargs["contiguous"], 1716 ), 1717 DecorateInfo( 1718 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1719 "TestOptimRenewed", 1720 "test_tensor_lr", 1721 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1722 ), 1723 DecorateInfo( 1724 skipIfTorchDynamo( 1725 "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" 1726 ), 1727 "TestOptimRenewed", 1728 "test_set_default_dtype_works_with_foreach", 1729 ), 1730 DecorateInfo( 1731 skipIfTorchDynamo( 1732 "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" 1733 ), 1734 "TestOptimRenewed", 1735 "test_complex_2d", 1736 ), 1737 DecorateInfo( 1738 skipIfTorchDynamo( 1739 "This test uses mocks, which dynamo does not support" 1740 ), 1741 "TestOptimRenewed", 1742 "test_defaults_changed_to_foreach", 1743 ), 1744 ), 1745 ), 1746 OptimizerInfo( 1747 ASGD, 1748 optim_inputs_func=optim_inputs_func_asgd, 1749 optim_error_inputs_func=optim_error_inputs_func_asgd, 1750 supported_impls=("foreach", "differentiable"), 1751 has_capturable_arg=True, 1752 skips=( 1753 DecorateInfo( 1754 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1755 "TestOptimRenewed", 1756 "test_tensor_lr", 1757 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1758 ), 1759 DecorateInfo( 1760 skipIfTorchDynamo( 1761 "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" 1762 ), 1763 "TestOptimRenewed", 1764 "test_set_default_dtype_works_with_foreach", 1765 ), 1766 DecorateInfo( 1767 skipIfTorchDynamo( 1768 "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" 1769 ), 1770 "TestOptimRenewed", 1771 
"test_complex_2d", 1772 ), 1773 DecorateInfo( 1774 toleranceOverride( 1775 { 1776 torch.float32: tol(atol=1.5e-5, rtol=1e-5), 1777 } 1778 ), 1779 "TestOptimRenewed", 1780 "test_step_is_noop_for_zero_grads", 1781 ), 1782 DecorateInfo( 1783 skipIfTorchDynamo( 1784 "This test uses mocks, which dynamo does not support" 1785 ), 1786 "TestOptimRenewed", 1787 "test_defaults_changed_to_foreach", 1788 ), 1789 DecorateInfo( 1790 unittest.skip( 1791 "ASGD internally changes the weights even with zero grad" 1792 ), 1793 "TestOptimRenewed", 1794 "test_step_is_noop_for_zero_grads", 1795 ), 1796 ), 1797 ), 1798 OptimizerInfo( 1799 LBFGS, 1800 optim_inputs_func=optim_inputs_func_lbfgs, 1801 optim_error_inputs_func=optim_error_inputs_func_lbfgs, 1802 supported_impls=(), 1803 step_requires_closure=True, 1804 supports_param_groups=False, 1805 supports_multiple_devices=False, 1806 skips=( 1807 # Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094 1808 DecorateInfo( 1809 skipIfMps, "TestOptimRenewed", "test_can_load_older_state_dict" 1810 ), 1811 DecorateInfo( 1812 toleranceOverride( 1813 { 1814 torch.complex64: tol( 1815 rtol=4.5e-5, 1816 atol=5e-5, 1817 ) 1818 } 1819 ), 1820 "TestOptimRenewed", 1821 "test_complex_2d", 1822 ), 1823 DecorateInfo( 1824 unittest.skip("Does not support param groups"), 1825 "TestOptimRenewed", 1826 "test_param_groups_lr", 1827 ), 1828 DecorateInfo( 1829 unittest.skip("Does not support param groups"), 1830 "TestOptimRenewed", 1831 "test_param_groups_weight_decay", 1832 ), 1833 DecorateInfo( 1834 unittest.skip("LBFGS doesn't support multidevice"), 1835 "TestOptimRenewed", 1836 "test_forloop_goes_right_direction_multigpu", 1837 ), 1838 DecorateInfo( 1839 unittest.skip("Does not support param groups"), 1840 "TestOptimRenewed", 1841 "test_param_group_with_lrscheduler_goes_right_direction", 1842 ), 1843 DecorateInfo( 1844 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1845 "TestOptimRenewed", 1846 "test_tensor_lr", 1847 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1848 ), 1849 # https://github.com/pytorch/pytorch/issues/131398 1850 DecorateInfo( 1851 unittest.expectedFailure, 1852 "CompiledOptimizerParityTests", 1853 "test_correctness", 1854 active_if=lambda kwargs: sys.platform == "darwin" 1855 and kwargs["use_closure"], 1856 ), 1857 ), 1858 ), 1859 OptimizerInfo( 1860 NAdam, 1861 optim_inputs_func=optim_inputs_func_nadam, 1862 optim_error_inputs_func=optim_error_inputs_func_nadam, 1863 supported_impls=("foreach", "differentiable"), 1864 has_capturable_arg=True, 1865 skips=( 1866 DecorateInfo( 1867 skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 1868 "TestOptimRenewed", 1869 "test_forloop_goes_right_direction", 1870 active_if=lambda kwargs: not kwargs["contiguous"], 1871 ), 1872 DecorateInfo( 1873 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), 1874 "TestOptimRenewed", 1875 "test_tensor_lr", 1876 active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7), 1877 ), 1878 DecorateInfo( 1879 skipIfTorchDynamo( 1880 "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028" 1881 ), 1882 "TestOptimRenewed", 1883 "test_set_default_dtype_works_with_foreach", 1884 ), 1885 DecorateInfo( 1886 skipIfTorchDynamo( 1887 "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184" 1888 ), 1889 "TestOptimRenewed", 1890 "test_complex_2d", 1891 ), 1892 DecorateInfo( 1893 skipIfTorchDynamo( 1894 "Errors, 
                ),
                "TestOptimRenewed",
                "test_load_nontensor_step",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        RAdam,
        optim_inputs_func=optim_inputs_func_radam,
        optim_error_inputs_func=optim_error_inputs_func_radam,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        # previously atol=1e-7, rtol=1e-7
                        torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
                    }
                ),
                "TestOptimRenewed",
                "test_foreach_matches_forloop",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        RMSprop,
        optim_inputs_func=optim_inputs_func_rmsprop,
        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
                "TestOptimRenewed",
                "test_forloop_goes_right_direction",
                active_if=lambda kwargs: not kwargs["contiguous"],
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo("See #116028"),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                toleranceOverride(
                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
                        torch.float32: tol(atol=5e-04, rtol=0.01),
                    }
                ),
                "TestOptimRenewed",
                "test_mixed_device_dtype",
                active_if=TEST_WITH_TORCHDYNAMO,
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        Rprop,
        optim_inputs_func=optim_inputs_func_rprop,
        optim_error_inputs_func=optim_error_inputs_func_rprop,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfMps,  # Rprop doesn't update for non-contiguous, see #118117
                "TestOptimRenewed",
                "test_forloop_goes_right_direction",
                active_if=lambda kwargs: not kwargs["contiguous"],
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo("See #116028"),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        SGD,
        optim_inputs_func=optim_inputs_func_sgd,
        scheduler_inputs=(
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
            [
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
                )
            ],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
                ),
            ],
            [
                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
                lambda opt: ExponentialLR(opt, gamma=0.99),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        optim_error_inputs_func=optim_error_inputs_func_sgd,
        supported_impls=("foreach", "differentiable", "fused"),
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_sparse=True,
        metadata_for_sparse=(
            {
                "lr": 4.8e-3,
                "maximize": False,
                "momentum": 0,
                "nesterov": False,
                "weight_decay": 0,
            },
            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
        ),
        supports_fused_on=(
            "cpu",
            "cuda",
            "mps",
        ),
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                toleranceOverride(
                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
                        torch.float32: tol(atol=5e-04, rtol=0.007),
                    }
                ),
                "TestOptimRenewed",
                "test_mixed_device_dtype",
                active_if=TEST_WITH_TORCHDYNAMO,
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        SparseAdam,
        optim_inputs_func=optim_inputs_func_sparseadam,
        optim_error_inputs_func=optim_error_inputs_func_sparseadam,
        supported_impls=(),
        only_supports_sparse_grads=True,
        metadata_for_sparse=({"lr": 4e-2}, []),
        supports_complex=False,  # Missing complex support, see #118153
        skips=(
            DecorateInfo(
                skipIfMps,  # SparseAdam does not support MPS
                "TestOptimRenewed",
            ),
            DecorateInfo(
                skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"),
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_param_groups_lr",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_tensor_lr",
            ),
            DecorateInfo(
                unittest.skip(
                    "SparseAdam does not support dense gradients, see #116507"
                ),
                "TestOptimRenewed",
                "test_can_load_older_state_dict",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_load_nontensor_step",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_forloop_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
]


class TensorTracker:
    """
    A utility to track tensor clones in a list, with the expectation of popping them later (in
    order) to make fair comparisons between two multi-step computations. The intended use case is
    comparing two supposedly equal computations, such as optimizer steps that each consist of
    multiple sub-steps, where numerical deviation could otherwise compound.

    The goal is to compare and align the numbers at every milestone so as to minimize numerical
    discrepancies; then, when a test does fail, it is likely a real problem.
    """

    def __init__(self, assert_eq_kwargs=None):
        if assert_eq_kwargs is None:
            assert_eq_kwargs = {}
        self.assert_eq_kwargs = assert_eq_kwargs
        self.tensors = []

    def add(self, tensor):
        """
        Add a clone().detach()'d version of the tensor
        """
        self.tensors.append(tensor.clone().detach())

    # pops from beginning, like a queue and not a stack!
    def pop_check_set(self, tensor_to_set, testcase):
        """
        Pop the first element in the tensor tracker, assert equality between the popped tensor and
        the input tensor, and then set the input tensor to have the same values as the popped tensor
        (with copy_).
2234 """ 2235 testcase.assertGreater(len(self.tensors), 0, "no tensors to pop") 2236 ref = self.tensors.pop(0) 2237 2238 testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}") 2239 testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs) 2240 2241 with torch.no_grad(): 2242 tensor_to_set.copy_(ref) 2243 2244 def all_popped(self): 2245 return len(self.tensors) == 0 2246