1# Owner(s): ["module: inductor"] 2 3import sys 4import unittest 5import weakref 6from contextlib import ExitStack 7from copy import deepcopy 8from typing import NamedTuple 9 10import torch 11import torch._inductor 12import torch._inductor.cudagraph_trees 13import torch.optim.lr_scheduler 14from torch._inductor import config 15from torch._inductor.test_case import TestCase 16from torch.optim import ( 17 Adadelta, 18 Adagrad, 19 Adam, 20 Adamax, 21 AdamW, 22 ASGD, 23 NAdam, 24 RAdam, 25 RMSprop, 26 Rprop, 27 SGD, 28 SparseAdam, 29) 30from torch.optim.lr_scheduler import ( 31 ChainedScheduler, 32 ConstantLR, 33 CosineAnnealingLR, 34 CosineAnnealingWarmRestarts, 35 CyclicLR, 36 ExponentialLR, 37 LambdaLR, 38 LinearLR, 39 MultiplicativeLR, 40 MultiStepLR, 41 OneCycleLR, 42 PolynomialLR, 43 ReduceLROnPlateau, 44 StepLR, 45) 46from torch.testing._internal.common_device_type import ( 47 instantiate_device_type_tests, 48 skipCUDAIf, 49 skipXPUIf, 50) 51from torch.testing._internal.common_optimizers import ( 52 _get_optim_inputs_including_global_cliquey_kwargs, 53 optim_db, 54 optims, 55) 56from torch.testing._internal.common_utils import parametrize 57from torch.testing._internal.inductor_utils import ( 58 GPU_TYPE, 59 HAS_CPU, 60 HAS_GPU, 61 has_triton, 62) 63from torch.testing._internal.triton_utils import requires_cuda, requires_gpu 64 65 66# Note: we use atypical values to amplify error 67LR_SCHEDULER_TO_KWARGS = { 68 LambdaLR: {"lr_lambda": lambda x: 10}, 69 MultiplicativeLR: {"lr_lambda": lambda x: 10}, 70 StepLR: {"step_size": 1, "gamma": 100}, 71 MultiStepLR: {"milestones": [1, 2], "gamma": 100}, 72 ExponentialLR: {"gamma": 100}, 73 CosineAnnealingLR: {"T_max": 7}, 74 # These schedulers have memory leaks in eager 75 # https://github.com/pytorch/pytorch/issues/126131 76 # SequentialLR: {"schedulers": None, "milestones": [1, 2]}, 77 # ChainedScheduler: {"schedulers": None}, 78 CyclicLR: {"base_lr": 0.001, "max_lr": 0.02, "cycle_momentum": False}, 79 CosineAnnealingWarmRestarts: {"T_0": 1}, 80 OneCycleLR: { 81 "max_lr": 0.02, 82 "cycle_momentum": False, 83 "steps_per_epoch": 1, 84 "epochs": 10, 85 }, 86 ConstantLR: {"factor": 0.001}, 87 LinearLR: {}, 88 ReduceLROnPlateau: {"factor": 0.99, "patience": 1}, 89 PolynomialLR: {}, 90} 91 92 93def create_scheduler(scheduler, optim): 94 kwargs = LR_SCHEDULER_TO_KWARGS[scheduler] 95 if "schedulers" in kwargs: 96 kwargs["schedulers"] = [ 97 create_scheduler(torch.optim.lr_scheduler.ConstantLR, optim) 98 for _ in range(2) 99 ] + [create_scheduler(torch.optim.lr_scheduler.LambdaLR, optim)] 100 101 if scheduler == ChainedScheduler: 102 return scheduler(**kwargs) 103 else: 104 return scheduler(optim, **kwargs) 105 106 107class KernelCounts(NamedTuple): 108 multitensor: int 109 singletensor: int 110 111 112# With different settings for certain 113# tests you can get different kernel counts 114# This maps the test name to the 115# expected kernel count 116KERNEL_COUNT_OVERRIDES = { 117 "test_rmsprop_foreach_weight_decay_cpu": 12, 118 "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20, 119 "test_adamw_amsgrad_capturable_foreach_cuda": 3, 120 "test_adamw_amsgrad_capturable_foreach_xpu": 3, 121 "test_adamw_amsgrad_capturable_cuda": 6, 122 "test_adamw_amsgrad_capturable_xpu": 6, 123 "test_adamw_tensor_lr_amsgrad_capturable_foreach_cuda": 3, 124 "test_adamw_tensor_lr_amsgrad_capturable_foreach_xpu": 3, 125 "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6, 126 "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6, 127 "test_adam_tensor_lr_amsgrad_capturable_cuda": 6, 128 "test_adam_tensor_lr_amsgrad_capturable_xpu": 6, 129 "test_adam_amsgrad_capturable_cuda": 6, 130 "test_adam_amsgrad_capturable_xpu": 6, 131 "test_adadelta_tensor_lr_capturable_cuda": 6, 132 "test_adadelta_tensor_lr_capturable_xpu": 6, 133 "test_rmsprop_tensor_lr_capturable_cuda": 6, 134 "test_rmsprop_tensor_lr_capturable_xpu": 6, 135 "test_adadelta_tensor_lr_capturable_foreach_cuda": 4, 136 "test_adadelta_tensor_lr_capturable_foreach_xpu": 4, 137 "test_adadelta_foreach_weight_decay_maximize_cpu": 12, 138 "test_adadelta_foreach_rho_weight_decay_cpu": 12, 139 "test_adadelta_foreach_weight_decay_cpu": 12, 140 "test_sgd_foreach_momentum_weight_decay_cpu": 16, 141 "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16, 142 "test_sgd_momentum_dampening_foreach_cuda": 5, 143 "test_sgd_momentum_dampening_foreach_xpu": 5, 144 "test_sgd_momentum_foreach_cuda": 5, 145 "test_sgd_momentum_foreach_xpu": 5, 146 "test_sgd_weight_decay_maximize_cuda": 4, 147 "test_sgd_weight_decay_maximize_xpu": 4, 148 "test_sgd_weight_decay_maximize_cpu": 4, 149 "test_sgd_weight_decay_cpu": 4, 150 "test_sgd_weight_decay_cuda": 4, 151 "test_sgd_weight_decay_xpu": 4, 152 "test_sgd_momentum_weight_decay_foreach_cuda": 2, 153 "test_sgd_momentum_weight_decay_foreach_xpu": 2, 154 "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, 155 "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2, 156 "test_sgd_cuda": 4, 157 "test_sgd_cpu": 4, 158 "test_sgd_xpu": 4, 159 "test_rmsprop_tensor_lr_capturable_foreach_cuda": 4, 160 "test_rmsprop_tensor_lr_capturable_foreach_xpu": 4, 161 "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2, 162 "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2, 163 "test_adagrad_weight_decay_foreach_xpu": 2, 164 "test_adagrad_weight_decay_maximize_foreach_xpu": 2, 165 "test_adagrad_tensor_lr_cpu": 6, 166 "test_adagrad_tensor_lr_cuda": 6, 167 "test_adagrad_tensor_lr_xpu": 6, 168 "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6, 169 "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6, 170 "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5, 171 "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8, 172 "test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_cuda": 4, 173 "test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_xpu": 4, 174 "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6, 175 "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9, 176 "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_cuda": 3, 177 "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_xpu": 3, 178 "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6, 179 "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6, 180 "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_foreach_cuda": 3, 181 "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_foreach_xpu": 3, 182 "test_sgd_tensor_lr_cpu": 2, 183 "test_sgd_tensor_lr_cuda": 2, 184 "test_sgd_tensor_lr_xpu": 2, 185 "test_sgd_tensor_lr_foreach_cuda": 2, 186 "test_sgd_tensor_lr_foreach_xpu": 2, 187} 188 189# also tracks currently supported optimizers 190KERNEL_COUNTS = { 191 Adam: KernelCounts(multitensor=2, singletensor=8), 192 AdamW: KernelCounts(multitensor=2, singletensor=8), 193 NAdam: KernelCounts(multitensor=2, singletensor=8), 194 Rprop: KernelCounts(multitensor=2, singletensor=8), 195 RMSprop: KernelCounts(multitensor=2, singletensor=8), 196 Adadelta: KernelCounts(multitensor=2, singletensor=8), 197 Adagrad: KernelCounts(multitensor=2, singletensor=8), 198 SGD: KernelCounts(multitensor=1, singletensor=8), 199 ASGD: KernelCounts(multitensor=2, singletensor=8), 200 RAdam: KernelCounts(multitensor=2, singletensor=8), 201 Adamax: KernelCounts(multitensor=2, singletensor=8), 202} 203 204 205def build_opt_kwarg_db(): 206 compiled_opt_db = [] 207 for optim_info in optim_db: 208 if optim_info.optim_cls not in KERNEL_COUNTS: 209 continue 210 211 for device in ["cpu", GPU_TYPE]: 212 for optim_inputs in _get_optim_inputs_including_global_cliquey_kwargs( 213 device, None, optim_info, skip=("differentiable", "fused") 214 ): 215 kwargs = dict(optim_inputs.kwargs) 216 name = f"test_{optim_info.optim_cls.__name__.lower()}" 217 218 has_tensor_lr = False 219 for key, val in kwargs.items(): 220 if (not key == "lr" and not key == "betas") and ( 221 not isinstance(val, bool) or (isinstance(val, bool) and val) 222 ): 223 name += "_" + key 224 225 if key == "lr" and isinstance(kwargs["lr"], torch.Tensor): 226 has_tensor_lr = True 227 name += "_tensor_lr" 228 229 if key == "betas" and isinstance(kwargs["betas"][0], torch.Tensor): 230 name += "_tensor_betas" 231 232 name += f"_{device}" 233 234 kwargs["device"] = device 235 if name in KERNEL_COUNT_OVERRIDES: 236 kwargs["kernel_count"] = KERNEL_COUNT_OVERRIDES[name] 237 else: 238 kwargs["kernel_count"] = ( 239 KERNEL_COUNTS[optim_info.optim_cls].multitensor 240 if kwargs.get("foreach", False) and device == GPU_TYPE 241 else KERNEL_COUNTS[optim_info.optim_cls].singletensor 242 ) 243 244 if kwargs["kernel_count"] is None or kwargs.get("fused", False): 245 continue 246 247 if has_tensor_lr: 248 for scheduler_cls in LR_SCHEDULER_TO_KWARGS.keys(): 249 name_w_scheduler = name + f"_{scheduler_cls.__name__.lower()}" 250 compiled_opt_db.append( 251 ( 252 optim_info.optim_cls, 253 name_w_scheduler, 254 kwargs, 255 scheduler_cls, 256 ) 257 ) 258 else: 259 compiled_opt_db.append((optim_info.optim_cls, name, kwargs, None)) 260 261 return compiled_opt_db 262 263 264COMPILED_OPT_KWARG_DB = build_opt_kwarg_db() 265 266aten = torch.ops.aten 267 268 269try: 270 try: 271 from .test_torchinductor import check_model, check_model_gpu 272 except ImportError: 273 from test_torchinductor import check_model, check_model_gpu 274except (unittest.SkipTest, ImportError) as e: 275 sys.stderr.write(f"{type(e)}: {e}\n") 276 if __name__ == "__main__": 277 sys.exit(0) 278 raise 279 280 281def call_scheduler(scheduler): 282 if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 283 scheduler.step(1.0) # we won't reduce the metric over two iters anyway 284 else: 285 scheduler.step() 286 287 288def compile_opt(opt_compiled, closure=None, fullgraph=True): 289 # run the patcher so that step has the expected structure 290 torch._dynamo.eval_frame.TorchPatcher.patch() 291 292 # unwrap step TWICE to avoid a deliberate graph break due to 293 # a limitation of functionalization/no_grad detection 294 # see the [Note on graph break] in optimizer.py 295 # This ignores the outer _use_grad_if_differentiable wrapper 296 # and instead manually disables grad before calling step, which is fine 297 # for now as dynamo does not support differentiable optimizers anyway 298 step_fn = opt_compiled.step.__wrapped__.__wrapped__ 299 300 # This ensures we don't receive spam of warnings from LR Scheduler 301 opt_compiled._opt_called = True 302 303 if closure is not None: 304 305 def fn(): 306 step_fn(opt_compiled, closure) 307 308 else: 309 310 def fn(): 311 step_fn(opt_compiled) 312 313 return torch.compile(fn, backend="inductor", fullgraph=fullgraph) 314 315 316def check_optim( 317 self, 318 optim_cls, 319 params_eager, 320 params_compiled, 321 state_eager, 322 state_compiled, 323 atol=None, 324 rtol=None, 325): 326 params_eager = list(params_eager) 327 params_compiled = list(params_compiled) 328 # Note on tolerances: 329 # test_correctness_Adadelta_cuda_float32 330 # Mismatched elements: 10 / 100 (10.0%) 331 # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed) 332 # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed) 333 # This is due to floating point ordering error + usage of sqrt 334 rtol = None 335 atol = None 336 if optim_cls is Adadelta: 337 rtol = 5.5e-4 338 atol = 5e-5 339 340 self.assertEqual(list(params_eager), list(params_compiled), atol=atol, rtol=rtol) 341 342 for p_eager, p_compiled in zip(params_eager, params_compiled): 343 self.assertEqual( 344 state_eager[p_eager], 345 state_compiled[p_compiled], 346 atol=atol, 347 rtol=rtol, 348 ) 349 350 351def make_test( 352 optim_cls, 353 closure=None, 354 scheduler_cls=None, 355 kernel_count=2, 356 device="cuda", 357 **kwargs, 358): 359 def test_fn(self): 360 stack = ExitStack() 361 try: 362 # https://github.com/pytorch/pytorch/issues/118715 for capturable Adagrad support 363 # https://github.com/pytorch/pytorch/issues/118018 for capturable SGD support 364 run_cudagraphs = device == "cuda" and optim_cls not in (Adagrad, SGD) 365 if run_cudagraphs: 366 stack.enter_context(config.patch({"triton.cudagraphs": True})) 367 368 kwargs_compiled = deepcopy(kwargs) 369 if isinstance(kwargs.get("lr", None), torch.Tensor): 370 kwargs["lr"] = kwargs["lr"].to(device) 371 kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device) 372 373 if "betas" in kwargs and isinstance(kwargs["betas"][0], torch.Tensor): 374 kwargs["betas"] = ( 375 kwargs["betas"][0].to(device), 376 kwargs["betas"][1].to(device), 377 ) 378 kwargs_compiled["betas"] = ( 379 kwargs_compiled["betas"][0].to(device), 380 kwargs_compiled["betas"][1].to(device), 381 ) 382 383 torch._dynamo.reset() 384 torch._inductor.metrics.reset() 385 input = torch.ones([10, 10], device=device) 386 model_eager = torch.nn.Sequential( 387 *[torch.nn.Linear(10, 10, device=device) for _ in range(2)] 388 ) 389 model_eager(input).sum().backward() 390 391 input = torch.ones([10, 10], device=device) 392 model_compiled = deepcopy(model_eager) 393 model_compiled(input).sum().backward() 394 395 opt_eager = optim_cls(model_eager.parameters(), **kwargs) 396 opt_compiled = optim_cls(model_compiled.parameters(), **kwargs_compiled) 397 compiled_step = compile_opt(opt_compiled, closure=closure) 398 399 if scheduler_cls: 400 scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled) 401 scheduler_eager = create_scheduler(scheduler_cls, opt_eager) 402 # some schedulers only change after at least an epoch has passed 403 scheduler_compiled.last_epoch = 1 404 scheduler_eager.last_epoch = 1 405 406 with torch.set_grad_enabled(False): 407 for i in range(2): 408 compiled_step() 409 opt_eager.step() 410 if scheduler_cls: 411 call_scheduler(scheduler_eager) 412 call_scheduler(scheduler_compiled) 413 414 check_optim( 415 self, 416 optim_cls, 417 model_eager.parameters(), 418 model_compiled.parameters(), 419 opt_eager.state, 420 opt_compiled.state, 421 ) 422 423 if run_cudagraphs: 424 self.check_cudagraphs_ran() 425 426 if self.check_kernel_count: 427 # currently, we compile the step and the rest of the computation 428 # separately because the step is a single element tensor 429 # hence, the usual kernel count is 2 430 self.assertEqual( 431 torch._inductor.metrics.generated_kernel_count, kernel_count 432 ) 433 finally: 434 stack.close() 435 436 if device == GPU_TYPE: 437 test_fn = requires_gpu(test_fn) 438 439 return test_fn 440 441 442def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs): 443 @requires_gpu 444 def test_fn(self): 445 torch._dynamo.reset() 446 torch._inductor.metrics.reset() 447 input = torch.ones([10, 10], device=GPU_TYPE) 448 model = torch.nn.Sequential( 449 *[torch.nn.Linear(10, 10, device=GPU_TYPE) for _ in range(2)] 450 ) 451 model(input).sum().backward() 452 453 opt_compiled = optim_cls(model.parameters(), **kwargs) 454 compiled_step = compile_opt(opt_compiled) 455 456 # check no recompile here 457 with torch.set_grad_enabled(False): 458 for _ in range(4): 459 compiled_step() 460 461 # perturb state to force recompile 462 # Adagrad doesn't reinitialize state on each step 463 # SGD has an empty state 464 if optim_cls in (Adagrad, SGD): 465 opt_compiled.param_groups[0]["lr"] = 0.02 466 elif optim_cls is Adam: # ensure we are guarding on the data_ptr of states 467 state_tensor = opt_compiled.state[ 468 opt_compiled.param_groups[0]["params"][0] 469 ]["exp_avg"] 470 opt_compiled.state[opt_compiled.param_groups[0]["params"][0]][ 471 "exp_avg" 472 ] = torch.zeros_like(state_tensor) 473 else: 474 opt_compiled.state.clear() 475 476 compiled_step() 477 478 if self.check_kernel_count: 479 # currently, we compile the step and the rest of the computation 480 # separately because the step is a single element tensor 481 # hence, the usual kernel count is 2 482 # multiply by 2 to account for the recompile 483 multiplier = 2 484 485 self.assertEqual( 486 torch._inductor.metrics.generated_kernel_count, 487 multiplier * kernel_count, 488 ) 489 490 return test_fn 491 492 493class CompiledOptimizerParityTests(TestCase): 494 @skipCUDAIf(not has_triton(), "torch.compile with cuda requires triton") 495 @skipXPUIf(not has_triton(), "torch.compile with xpu requires triton") 496 @optims(optim_db, dtypes=[torch.float32]) 497 @parametrize("use_closure", [True, False]) 498 def test_correctness(self, device, dtype, optim_info, use_closure): 499 optim_cls = optim_info.optim_cls 500 all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 501 device, dtype, optim_info, skip=("differentiable",) 502 ) 503 504 if optim_info.step_requires_closure and not use_closure: 505 return 506 507 for optim_input in all_optim_inputs: 508 kwargs = optim_input.kwargs 509 510 use_scheduler = isinstance(kwargs.get("lr", None), torch.Tensor) 511 scheduler_classes = ( 512 list(LR_SCHEDULER_TO_KWARGS.keys()) if use_scheduler else [None] 513 ) 514 515 for scheduler_cls in scheduler_classes: 516 torch._dynamo.reset() 517 torch._inductor.metrics.reset() 518 input = torch.ones([10, 10], device=device) 519 model_eager = torch.nn.Sequential( 520 *[torch.nn.Linear(10, 10, device=device) for _ in range(2)] 521 ) 522 model_eager(input).sum().backward() 523 model_compiled = deepcopy(model_eager) 524 model_compiled(input).sum().backward() 525 526 if optim_cls is SparseAdam: 527 for param in model_eager.parameters(): 528 param.grad = param.grad.to_sparse() 529 for param in model_compiled.parameters(): 530 param.grad = param.grad.to_sparse() 531 532 opt_compiled = optim_cls( 533 model_compiled.parameters(), **deepcopy(kwargs) 534 ) 535 opt_eager = optim_cls(model_eager.parameters(), **deepcopy(kwargs)) 536 if scheduler_cls: 537 scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled) 538 scheduler_eager = create_scheduler(scheduler_cls, opt_eager) 539 # some schedulers only change after at least an epoch has passed 540 scheduler_compiled.last_epoch = 1 541 scheduler_eager.last_epoch = 1 542 543 num_steps = 2 544 if use_closure: 545 546 @torch.compile() 547 def fn(): 548 def closure(): 549 loss = model_compiled(input).sum() 550 loss.backward() 551 if optim_info.only_supports_sparse_grads: 552 for param in model_compiled.parameters(): 553 param.grad = param.grad.to_sparse() 554 return loss 555 556 opt_compiled.step(closure) 557 if scheduler_cls: 558 call_scheduler(scheduler_compiled) 559 560 def closure_eager(): 561 loss = model_eager(input).sum() 562 loss.backward() 563 if optim_info.only_supports_sparse_grads: 564 for param in model_eager.parameters(): 565 param.grad = param.grad.to_sparse() 566 567 return loss 568 569 for _ in range(num_steps): 570 opt_eager.step(closure_eager) 571 if scheduler_cls: 572 call_scheduler(scheduler_eager) 573 else: 574 575 @torch.compile() 576 def fn(): 577 opt_compiled.step() 578 if scheduler_cls: 579 call_scheduler(scheduler_compiled) 580 581 for _ in range(num_steps): 582 opt_eager.step() 583 if scheduler_cls: 584 call_scheduler(scheduler_eager) 585 586 for _ in range(num_steps): 587 fn() 588 589 check_optim( 590 self, 591 optim_cls, 592 model_eager.parameters(), 593 model_compiled.parameters(), 594 opt_eager.state, 595 opt_compiled.state, 596 ) 597 598 599class CompiledOptimizerTests(TestCase): 600 check_model_gpu = check_model_gpu 601 check_model_cpu = check_model 602 check_kernel_count = True 603 604 def setUp(self): 605 super().setUp() 606 torch._dynamo.reset() 607 torch._inductor.metrics.reset() 608 609 def tearDown(self): 610 super().tearDown() 611 torch._dynamo.reset() 612 torch._inductor.metrics.reset() 613 614 def check_cudagraphs_ran(self): 615 # We run the zeroth device currently 616 manager = torch._inductor.cudagraph_trees.get_container(0).tree_manager 617 self.assertIsNotNone(manager) 618 self.assertEqual(manager.new_graph_id().id, 1) 619 620 test_adam_recompile = make_recompile_test(Adam, lr=0.01) 621 test_adamw_recompile = make_recompile_test(AdamW, lr=0.01) 622 test_adamax_recompile = make_recompile_test(Adamax, lr=0.01) 623 test_nadam_recompile = make_recompile_test(NAdam, lr=0.01) 624 test_rprop_recompile = make_recompile_test(Rprop, lr=0.01, kernel_count=2) 625 test_rmsprop_recompile = make_recompile_test(RMSprop, lr=0.01) 626 test_adadelta_recompile = make_recompile_test(Adadelta, lr=0.01) 627 test_adagrad_recompile = make_recompile_test(Adagrad, lr=0.01) 628 test_asgd_recompile_default = make_recompile_test(ASGD, lr=0.01) 629 test_asgd_recompile_single = make_recompile_test( 630 ASGD, kernel_count=8, lr=0.01, foreach=False 631 ) 632 test_asgd_recompile_foreach = make_recompile_test(ASGD, lr=0.01, foreach=True) 633 test_sgd_recompile_single = make_recompile_test( 634 SGD, kernel_count=4, lr=0.01, foreach=False 635 ) 636 test_sgd_recompile_foreach = make_recompile_test( 637 SGD, kernel_count=1, lr=0.01, foreach=True 638 ) 639 640 @requires_gpu 641 def test_static_address_finalizer(self): 642 import gc 643 644 gc.disable() 645 p_ref = None 646 647 def fn(): 648 nonlocal p_ref 649 mod = torch.nn.Linear(10, 10, device=GPU_TYPE, bias=False) 650 for p in mod.parameters(): 651 p.grad = torch.rand_like(p) 652 653 opt = torch.optim.Adam(mod.parameters(), lr=0.1) 654 655 def fn(): 656 opt.step() 657 658 with torch.set_grad_enabled(False): 659 step_fn_compiled = torch.compile(fn) 660 step_fn_compiled() 661 p_ref = weakref.ref(p) 662 self.assertTrue(p_ref() is not None) 663 664 fn() 665 666 self.assertTrue(p_ref() is None) 667 gc.enable() 668 669 def test_guard_on_none_grads(self): 670 def training_loop(): 671 input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).reshape(3, 2) 672 673 model = torch.nn.Sequential( 674 torch.nn.Linear(2, 3), 675 torch.nn.Sigmoid(), 676 torch.nn.Linear(3, 1), 677 torch.nn.Sigmoid(), 678 ) 679 680 params = list(model.parameters()) 681 optimizer = torch.optim.Adam(params) 682 step_list = [] 683 684 for i in range(6): 685 optimizer.zero_grad() 686 # Test that step behaves as expected (a no-op) when grads are set to None 687 if i != 3: 688 output = model(input) 689 loss = output.sum() 690 loss.backward() 691 692 optimizer.step() 693 step_list.append(optimizer.state[params[0]]["step"]) 694 695 return step_list 696 697 compiled_training_loop = torch._dynamo.optimize("eager")(training_loop) 698 actual_steps = compiled_training_loop() 699 expected_steps = training_loop() 700 self.assertEqual(actual_steps, expected_steps) 701 702 # Basic shampoo test to verify we support compiling the various ops without error 703 @requires_gpu 704 def test_basic_shampoo(self): 705 param_buf = torch.rand((1024, 128)) 706 param_buf_c = param_buf.clone().detach() 707 708 params_c = [param_buf_c[0:512, :].t(), param_buf_c[512:, :].t()] 709 params = [param_buf[0:512, :].t(), param_buf[512:, :].t()] 710 711 for p, p_c in zip(params, params_c): 712 p.grad = torch.rand_like(p) 713 p_c.grad = p.grad.clone().detach() 714 715 # note this skips the root inverse because this has a lot of internal dependencies 716 # we also don't compile it regardless 717 @torch.no_grad() 718 def shampoo_functional_basic(params): 719 step = 1 720 weight_decay = 0.1 721 grads = [p.grad for p in params] 722 beta1 = 0.9 723 beta2 = 1.0 724 epsilon = 1e-10 725 preconditioners = [torch.zeros_like(p) for p in params] 726 lr = 0.01 727 728 # pt2 region 1 729 # weight decay 730 torch._foreach_add_(grads, params, alpha=weight_decay) 731 732 # update preconditioners 733 torch._foreach_addcmul_(preconditioners, grads, grads, value=1.0) 734 735 torch._foreach_mul_(grads, beta1) 736 torch._foreach_add_( 737 grads, 738 grads, 739 alpha=1 - beta1, 740 ) 741 bias_correction1 = 1.0 - beta1**step 742 grad_list = torch._foreach_div(grads, bias_correction1) 743 744 # pt2 region 2 745 # precondition (with shampoo branch), with no grafting 746 bias_correction2 = 1.0 - beta2**step 747 bias_corrected_preconditioner_list = torch._foreach_div( 748 preconditioners, bias_correction2 749 ) 750 torch._foreach_sqrt_(bias_corrected_preconditioner_list) 751 torch._foreach_add_(bias_corrected_preconditioner_list, epsilon) 752 search_directions = torch._foreach_div( 753 grad_list, bias_corrected_preconditioner_list 754 ) 755 756 torch._foreach_add_( 757 search_directions, 758 params, 759 alpha=weight_decay, 760 ) 761 762 torch._foreach_mul_(search_directions, -lr) 763 # pt2 region 3 update params 764 torch._foreach_add_(params, search_directions) 765 766 return params, preconditioners, grads 767 768 compiled_fn = torch.compile(shampoo_functional_basic) 769 770 self.assertEqual(compiled_fn(params_c), shampoo_functional_basic(params)) 771 772 @requires_gpu 773 def test_closure_graph_break(self): 774 param = torch.rand( 775 2, 3, dtype=torch.float32, device=GPU_TYPE, requires_grad=True 776 ) 777 param_c = param.clone().detach().requires_grad_(True) 778 779 def closure(): 780 param.grad = torch.ones_like(param) * 2 781 return param.grad 782 783 def closure_c(): 784 param_c.grad = torch.ones_like(param_c) * 2 785 return param_c.grad 786 787 optimizer = torch.optim.AdamW([param]) 788 optimizer_c = torch.optim.AdamW([param_c]) 789 790 def loop(opt, c): 791 opt.step(c) 792 793 compiled_loop = torch._dynamo.optimize("eager")(loop) 794 795 compiled_loop(optimizer, closure) 796 loop(optimizer_c, closure_c) 797 798 self.assertEqual(param, param_c) 799 800 def test_get_value_on_static_address(self): 801 from torch._dynamo.decorators import mark_static_address 802 from torch.optim.optimizer import _get_value 803 804 compiled = torch.compile(_get_value) 805 806 x = torch.ones(2, 2) 807 mark_static_address(x) 808 809 ret_val = compiled(x) 810 811 self.assertEqual(ret_val, x) 812 813 # compile a large foreach op and verify 814 # that the time taken is within an expected range 815 @requires_gpu 816 def test_compile_time_smoketest(self): 817 import time 818 819 xs = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)] 820 ys = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)] 821 822 @torch.compile 823 def fn(xs, ys): 824 return torch._foreach_add(xs, ys) 825 826 start = time.perf_counter() 827 fn(xs, ys) 828 end = time.perf_counter() 829 830 self.assertLess(end - start, 90) 831 832 @requires_cuda 833 def test_S429861(self): 834 # Just verify we can compile this function without error 835 try: 836 from . import s429861_repro 837 except ImportError: 838 import s429861_repro 839 840 forward = s429861_repro.forward 841 842 import torch._dynamo 843 import torch._inductor 844 from torch._dynamo.debug_utils import aot_graph_input_parser 845 from torch._inductor.utils import fresh_inductor_cache 846 847 with fresh_inductor_cache(): 848 kwargs = aot_graph_input_parser(forward) 849 torch.compile(forward)(**kwargs) 850 851 852for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB: 853 setattr( 854 CompiledOptimizerTests, 855 name, 856 make_test(optim_cls, scheduler_cls=scheduler_cls, **kwargs), 857 ) 858 859instantiate_device_type_tests( 860 CompiledOptimizerParityTests, globals(), allow_xpu=True, except_for="cpu" 861) 862 863if __name__ == "__main__": 864 from torch._inductor.test_case import run_tests 865 866 if HAS_CPU or HAS_GPU: 867 run_tests(needs="filelock") 868