xref: /aosp_15_r20/external/pytorch/test/inductor/test_compiled_optimizers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2
3import sys
4import unittest
5import weakref
6from contextlib import ExitStack
7from copy import deepcopy
8from typing import NamedTuple
9
10import torch
11import torch._inductor
12import torch._inductor.cudagraph_trees
13import torch.optim.lr_scheduler
14from torch._inductor import config
15from torch._inductor.test_case import TestCase
16from torch.optim import (
17    Adadelta,
18    Adagrad,
19    Adam,
20    Adamax,
21    AdamW,
22    ASGD,
23    NAdam,
24    RAdam,
25    RMSprop,
26    Rprop,
27    SGD,
28    SparseAdam,
29)
30from torch.optim.lr_scheduler import (
31    ChainedScheduler,
32    ConstantLR,
33    CosineAnnealingLR,
34    CosineAnnealingWarmRestarts,
35    CyclicLR,
36    ExponentialLR,
37    LambdaLR,
38    LinearLR,
39    MultiplicativeLR,
40    MultiStepLR,
41    OneCycleLR,
42    PolynomialLR,
43    ReduceLROnPlateau,
44    StepLR,
45)
46from torch.testing._internal.common_device_type import (
47    instantiate_device_type_tests,
48    skipCUDAIf,
49    skipXPUIf,
50)
51from torch.testing._internal.common_optimizers import (
52    _get_optim_inputs_including_global_cliquey_kwargs,
53    optim_db,
54    optims,
55)
56from torch.testing._internal.common_utils import parametrize
57from torch.testing._internal.inductor_utils import (
58    GPU_TYPE,
59    HAS_CPU,
60    HAS_GPU,
61    has_triton,
62)
63from torch.testing._internal.triton_utils import requires_cuda, requires_gpu
64
65
66# Note: we use atypical values to amplify error
67LR_SCHEDULER_TO_KWARGS = {
68    LambdaLR: {"lr_lambda": lambda x: 10},
69    MultiplicativeLR: {"lr_lambda": lambda x: 10},
70    StepLR: {"step_size": 1, "gamma": 100},
71    MultiStepLR: {"milestones": [1, 2], "gamma": 100},
72    ExponentialLR: {"gamma": 100},
73    CosineAnnealingLR: {"T_max": 7},
74    # These schedulers have memory leaks in eager
75    # https://github.com/pytorch/pytorch/issues/126131
76    # SequentialLR: {"schedulers": None, "milestones": [1, 2]},
77    # ChainedScheduler: {"schedulers": None},
78    CyclicLR: {"base_lr": 0.001, "max_lr": 0.02, "cycle_momentum": False},
79    CosineAnnealingWarmRestarts: {"T_0": 1},
80    OneCycleLR: {
81        "max_lr": 0.02,
82        "cycle_momentum": False,
83        "steps_per_epoch": 1,
84        "epochs": 10,
85    },
86    ConstantLR: {"factor": 0.001},
87    LinearLR: {},
88    ReduceLROnPlateau: {"factor": 0.99, "patience": 1},
89    PolynomialLR: {},
90}
91
92
93def create_scheduler(scheduler, optim):
94    kwargs = LR_SCHEDULER_TO_KWARGS[scheduler]
95    if "schedulers" in kwargs:
96        kwargs["schedulers"] = [
97            create_scheduler(torch.optim.lr_scheduler.ConstantLR, optim)
98            for _ in range(2)
99        ] + [create_scheduler(torch.optim.lr_scheduler.LambdaLR, optim)]
100
101    if scheduler == ChainedScheduler:
102        return scheduler(**kwargs)
103    else:
104        return scheduler(optim, **kwargs)
105
106
# Expected number of inductor-generated kernels for one optimizer:
# ``multitensor`` applies to the foreach implementation and ``singletensor`` to
# the per-parameter implementation.
KernelCounts = NamedTuple(
    "KernelCounts",
    [("multitensor", int), ("singletensor", int)],
)
110
111
# With different settings for certain
# tests you can get different kernel counts
# This maps the test name to the
# expected kernel count
# Keys must match the generated test names exactly, including the trailing
# device suffix (_cpu / _cuda / _xpu). build_opt_kwarg_db consults this table
# first and only falls back to the per-optimizer KERNEL_COUNTS defaults when
# a name is absent.
KERNEL_COUNT_OVERRIDES = {
    "test_rmsprop_foreach_weight_decay_cpu": 12,
    "test_nadam_foreach_weight_decay_momentum_decay_cpu": 20,
    "test_adamw_amsgrad_capturable_foreach_cuda": 3,
    "test_adamw_amsgrad_capturable_foreach_xpu": 3,
    "test_adamw_amsgrad_capturable_cuda": 6,
    "test_adamw_amsgrad_capturable_xpu": 6,
    "test_adamw_tensor_lr_amsgrad_capturable_foreach_cuda": 3,
    "test_adamw_tensor_lr_amsgrad_capturable_foreach_xpu": 3,
    "test_adamw_tensor_lr_amsgrad_capturable_cuda": 6,
    "test_adamw_tensor_lr_amsgrad_capturable_xpu": 6,
    "test_adam_tensor_lr_amsgrad_capturable_cuda": 6,
    "test_adam_tensor_lr_amsgrad_capturable_xpu": 6,
    "test_adam_amsgrad_capturable_cuda": 6,
    "test_adam_amsgrad_capturable_xpu": 6,
    "test_adadelta_tensor_lr_capturable_cuda": 6,
    "test_adadelta_tensor_lr_capturable_xpu": 6,
    "test_rmsprop_tensor_lr_capturable_cuda": 6,
    "test_rmsprop_tensor_lr_capturable_xpu": 6,
    "test_adadelta_tensor_lr_capturable_foreach_cuda": 4,
    "test_adadelta_tensor_lr_capturable_foreach_xpu": 4,
    "test_adadelta_foreach_weight_decay_maximize_cpu": 12,
    "test_adadelta_foreach_rho_weight_decay_cpu": 12,
    "test_adadelta_foreach_weight_decay_cpu": 12,
    "test_sgd_foreach_momentum_weight_decay_cpu": 16,
    "test_sgd_foreach_momentum_nesterov_weight_decay_cpu": 16,
    "test_sgd_momentum_dampening_foreach_cuda": 5,
    "test_sgd_momentum_dampening_foreach_xpu": 5,
    "test_sgd_momentum_foreach_cuda": 5,
    "test_sgd_momentum_foreach_xpu": 5,
    "test_sgd_weight_decay_maximize_cuda": 4,
    "test_sgd_weight_decay_maximize_xpu": 4,
    "test_sgd_weight_decay_maximize_cpu": 4,
    "test_sgd_weight_decay_cpu": 4,
    "test_sgd_weight_decay_cuda": 4,
    "test_sgd_weight_decay_xpu": 4,
    "test_sgd_momentum_weight_decay_foreach_cuda": 2,
    "test_sgd_momentum_weight_decay_foreach_xpu": 2,
    "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
    "test_sgd_momentum_nesterov_weight_decay_foreach_xpu": 2,
    "test_sgd_cuda": 4,
    "test_sgd_cpu": 4,
    "test_sgd_xpu": 4,
    "test_rmsprop_tensor_lr_capturable_foreach_cuda": 4,
    "test_rmsprop_tensor_lr_capturable_foreach_xpu": 4,
    "test_adagrad_initial_accumulator_value_weight_decay_foreach_xpu": 2,
    "test_adagrad_lr_decay_weight_decay_foreach_xpu": 2,
    "test_adagrad_weight_decay_foreach_xpu": 2,
    "test_adagrad_weight_decay_maximize_foreach_xpu": 2,
    "test_adagrad_tensor_lr_cpu": 6,
    "test_adagrad_tensor_lr_cuda": 6,
    "test_adagrad_tensor_lr_xpu": 6,
    "test_adamax_tensor_lr_weight_decay_capturable_cuda": 6,
    "test_adamax_tensor_lr_weight_decay_capturable_xpu": 6,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": 5,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": 8,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_cuda": 4,
    "test_asgd_tensor_lr_weight_decay_maximize_capturable_foreach_xpu": 4,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": 6,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": 9,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_cuda": 3,
    "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_foreach_xpu": 3,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": 6,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": 6,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_foreach_cuda": 3,
    "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_foreach_xpu": 3,
    "test_sgd_tensor_lr_cpu": 2,
    "test_sgd_tensor_lr_cuda": 2,
    "test_sgd_tensor_lr_xpu": 2,
    "test_sgd_tensor_lr_foreach_cuda": 2,
    "test_sgd_tensor_lr_foreach_xpu": 2,
}
188
# Default expected kernel counts per optimizer class. build_opt_kwarg_db uses
# ``multitensor`` when a config runs with foreach=True on GPU_TYPE and
# ``singletensor`` otherwise, unless KERNEL_COUNT_OVERRIDES names the test.
# also tracks currently supported optimizers: build_opt_kwarg_db skips any
# optimizer class that is not a key here.
KERNEL_COUNTS = {
    Adam: KernelCounts(multitensor=2, singletensor=8),
    AdamW: KernelCounts(multitensor=2, singletensor=8),
    NAdam: KernelCounts(multitensor=2, singletensor=8),
    Rprop: KernelCounts(multitensor=2, singletensor=8),
    RMSprop: KernelCounts(multitensor=2, singletensor=8),
    Adadelta: KernelCounts(multitensor=2, singletensor=8),
    Adagrad: KernelCounts(multitensor=2, singletensor=8),
    SGD: KernelCounts(multitensor=1, singletensor=8),
    ASGD: KernelCounts(multitensor=2, singletensor=8),
    RAdam: KernelCounts(multitensor=2, singletensor=8),
    Adamax: KernelCounts(multitensor=2, singletensor=8),
}
203
204
def build_opt_kwarg_db():
    """Enumerate the generated compiled-optimizer test cases.

    For every optimizer in ``optim_db`` that has a ``KERNEL_COUNTS`` entry, and
    for each of ``"cpu"`` and ``GPU_TYPE``, every optimizer input configuration
    (differentiable and fused variants skipped) becomes a test entry. The test
    name encodes the optimizer, the non-False kwargs (lr/betas only when they
    are tensors) and the device. Configurations with a tensor ``lr`` are
    additionally fanned out over every scheduler in ``LR_SCHEDULER_TO_KWARGS``.

    Returns:
        A list of ``(optim_cls, test_name, kwargs, scheduler_cls_or_None)``
        tuples. ``kwargs`` also carries ``device`` and ``kernel_count``.
    """
    entries = []
    for optim_info in optim_db:
        optim_cls = optim_info.optim_cls
        if optim_cls not in KERNEL_COUNTS:
            continue

        for device in ["cpu", GPU_TYPE]:
            input_configs = _get_optim_inputs_including_global_cliquey_kwargs(
                device, None, optim_info, skip=("differentiable", "fused")
            )
            for optim_inputs in input_configs:
                kwargs = dict(optim_inputs.kwargs)
                name_parts = [f"test_{optim_cls.__name__.lower()}"]
                has_tensor_lr = False

                for key, val in kwargs.items():
                    # Every kwarg except lr/betas is named, unless it is a False flag.
                    if key not in ("lr", "betas") and (
                        not isinstance(val, bool) or val
                    ):
                        name_parts.append(key)

                    if key == "lr" and isinstance(val, torch.Tensor):
                        has_tensor_lr = True
                        name_parts.append("tensor_lr")

                    if key == "betas" and isinstance(val[0], torch.Tensor):
                        name_parts.append("tensor_betas")

                name_parts.append(device)
                name = "_".join(name_parts)

                kwargs["device"] = device
                if name in KERNEL_COUNT_OVERRIDES:
                    kwargs["kernel_count"] = KERNEL_COUNT_OVERRIDES[name]
                else:
                    defaults = KERNEL_COUNTS[optim_cls]
                    use_multitensor = (
                        kwargs.get("foreach", False) and device == GPU_TYPE
                    )
                    kwargs["kernel_count"] = (
                        defaults.multitensor
                        if use_multitensor
                        else defaults.singletensor
                    )

                if kwargs["kernel_count"] is None or kwargs.get("fused", False):
                    continue

                if not has_tensor_lr:
                    entries.append((optim_cls, name, kwargs, None))
                    continue

                # Tensor lr: one variant per LR scheduler, suffixed by its name.
                entries.extend(
                    (
                        optim_cls,
                        f"{name}_{scheduler_cls.__name__.lower()}",
                        kwargs,
                        scheduler_cls,
                    )
                    for scheduler_cls in LR_SCHEDULER_TO_KWARGS
                )

    return entries
262
263
# Generated test cases, computed once at import time and turned into test
# methods on CompiledOptimizerTests near the end of this file.
COMPILED_OPT_KWARG_DB = build_opt_kwarg_db()

# Shorthand for the aten op namespace.
aten = torch.ops.aten


# Pull the model-checking helpers from the sibling test_torchinductor module.
# Try a package-relative import first, then a top-level import (presumably for
# running this file directly as a script).
try:
    try:
        from .test_torchinductor import check_model, check_model_gpu
    except ImportError:
        from test_torchinductor import check_model, check_model_gpu
except (unittest.SkipTest, ImportError) as e:
    # Report why the helpers are unavailable. When run as a script, exit
    # cleanly; when imported by a test runner, re-raise so it sees the
    # skip or failure.
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise
279
280
def call_scheduler(scheduler):
    """Advance ``scheduler`` by one step.

    ``ReduceLROnPlateau.step`` needs a metric value, so it receives a constant
    ``1.0``. Every other scheduler is stepped without arguments.
    """
    if not isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step()
        return
    # we won't reduce the metric over two iters anyway
    scheduler.step(1.0)
286
287
288def compile_opt(opt_compiled, closure=None, fullgraph=True):
289    # run the patcher so that step has the expected structure
290    torch._dynamo.eval_frame.TorchPatcher.patch()
291
292    # unwrap step TWICE to avoid a deliberate graph break due to
293    # a limitation of functionalization/no_grad detection
294    # see the [Note on graph break] in optimizer.py
295    # This ignores the outer _use_grad_if_differentiable wrapper
296    # and instead manually disables grad before calling step, which is fine
297    # for now as dynamo does not support differentiable optimizers anyway
298    step_fn = opt_compiled.step.__wrapped__.__wrapped__
299
300    # This ensures we don't receive spam of warnings from LR Scheduler
301    opt_compiled._opt_called = True
302
303    if closure is not None:
304
305        def fn():
306            step_fn(opt_compiled, closure)
307
308    else:
309
310        def fn():
311            step_fn(opt_compiled)
312
313    return torch.compile(fn, backend="inductor", fullgraph=fullgraph)
314
315
def check_optim(
    self,
    optim_cls,
    params_eager,
    params_compiled,
    state_eager,
    state_compiled,
    atol=None,
    rtol=None,
):
    """Assert that an eager and a compiled optimizer produced matching results.

    Compares the parameters pairwise, then the per-parameter optimizer state
    (``state_eager[p_eager]`` against ``state_compiled[p_compiled]``).

    Args:
        self: the TestCase whose ``assertEqual`` performs the comparisons.
        optim_cls: the optimizer class; selects default tolerances.
        params_eager: iterable of eager parameters (e.g. a ``parameters()``
            generator). It is materialized once.
        params_compiled: iterable of compiled parameters, in the same order.
        state_eager: mapping from eager parameter to its optimizer state.
        state_compiled: mapping from compiled parameter to its optimizer state.
        atol: absolute tolerance passed to ``assertEqual``.
        rtol: relative tolerance passed to ``assertEqual``.

    If both ``atol`` and ``rtol`` are ``None``, Adadelta gets loosened
    tolerances (see note below) and other optimizers use ``assertEqual``'s
    defaults. Caller-supplied tolerances are always honored.
    """
    params_eager = list(params_eager)
    params_compiled = list(params_compiled)
    # Note on tolerances:
    # test_correctness_Adadelta_cuda_float32
    # Mismatched elements: 10 / 100 (10.0%)
    # Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
    # Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
    # This is due to floating point ordering error + usage of sqrt
    if atol is None and rtol is None and optim_cls is Adadelta:
        rtol = 5.5e-4
        atol = 5e-5

    self.assertEqual(params_eager, params_compiled, atol=atol, rtol=rtol)

    for p_eager, p_compiled in zip(params_eager, params_compiled):
        self.assertEqual(
            state_eager[p_eager],
            state_compiled[p_compiled],
            atol=atol,
            rtol=rtol,
        )
349
350
def make_test(
    optim_cls,
    closure=None,
    scheduler_cls=None,
    kernel_count=2,
    device="cuda",
    **kwargs,
):
    """Build a test method comparing an eager and a compiled optimizer.

    The returned ``test_fn(self)`` does the following:

    * builds two identical two-layer models;
    * steps an eager optimizer and a ``compile_opt``-compiled optimizer twice,
      optionally stepping an LR scheduler of class ``scheduler_cls`` after
      each step;
    * compares parameters and state with ``check_optim``;
    * on CUDA (except Adagrad/SGD), checks that cudagraphs ran;
    * when ``self.check_kernel_count`` is set, asserts that inductor
      generated exactly ``kernel_count`` kernels.

    Args:
        optim_cls: optimizer class under test.
        closure: optional closure forwarded to the compiled ``step``.
        scheduler_cls: optional LR scheduler class (see ``create_scheduler``).
        kernel_count: expected ``torch._inductor.metrics.generated_kernel_count``.
        device: device to run on. ``GPU_TYPE`` tests are wrapped in
            ``requires_gpu``.
        **kwargs: optimizer constructor kwargs. Each optimizer gets its own
            deep copy, so running the test never mutates them.

    Returns:
        The test function, suitable for attaching to a TestCase.
    """

    def _optimizer_kwargs():
        # Give each optimizer a private deep copy with tensor hyperparameters
        # moved to ``device``. ``.to(device)`` returns the same tensor when it
        # already lives on ``device`` (e.g. CPU). Without the copy, any in-place
        # update of a tensor lr (e.g. by an LR scheduler) would leak into
        # reruns and into sibling tests built from the same kwargs.
        run_kwargs = deepcopy(kwargs)
        if isinstance(run_kwargs.get("lr", None), torch.Tensor):
            run_kwargs["lr"] = run_kwargs["lr"].to(device)

        if "betas" in run_kwargs and isinstance(
            run_kwargs["betas"][0], torch.Tensor
        ):
            run_kwargs["betas"] = (
                run_kwargs["betas"][0].to(device),
                run_kwargs["betas"][1].to(device),
            )
        return run_kwargs

    def test_fn(self):
        with ExitStack() as stack:
            # https://github.com/pytorch/pytorch/issues/118715 for capturable Adagrad support
            # https://github.com/pytorch/pytorch/issues/118018 for capturable SGD support
            run_cudagraphs = device == "cuda" and optim_cls not in (Adagrad, SGD)
            if run_cudagraphs:
                stack.enter_context(config.patch({"triton.cudagraphs": True}))

            kwargs_eager = _optimizer_kwargs()
            kwargs_compiled = _optimizer_kwargs()

            torch._dynamo.reset()
            torch._inductor.metrics.reset()
            input = torch.ones([10, 10], device=device)
            model_eager = torch.nn.Sequential(
                *[torch.nn.Linear(10, 10, device=device) for _ in range(2)]
            )
            model_eager(input).sum().backward()

            input = torch.ones([10, 10], device=device)
            model_compiled = deepcopy(model_eager)
            model_compiled(input).sum().backward()

            opt_eager = optim_cls(model_eager.parameters(), **kwargs_eager)
            opt_compiled = optim_cls(model_compiled.parameters(), **kwargs_compiled)
            compiled_step = compile_opt(opt_compiled, closure=closure)

            if scheduler_cls:
                scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
                scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
                # some schedulers only change after at least an epoch has passed
                scheduler_compiled.last_epoch = 1
                scheduler_eager.last_epoch = 1

            with torch.set_grad_enabled(False):
                for _ in range(2):
                    compiled_step()
                    opt_eager.step()
                    if scheduler_cls:
                        call_scheduler(scheduler_eager)
                        call_scheduler(scheduler_compiled)

            check_optim(
                self,
                optim_cls,
                model_eager.parameters(),
                model_compiled.parameters(),
                opt_eager.state,
                opt_compiled.state,
            )

            if run_cudagraphs:
                self.check_cudagraphs_ran()

            if self.check_kernel_count:
                # currently, we compile the step and the rest of the computation
                # separately because the step is a single element tensor
                # hence, the usual kernel count is 2
                self.assertEqual(
                    torch._inductor.metrics.generated_kernel_count, kernel_count
                )

    if device == GPU_TYPE:
        test_fn = requires_gpu(test_fn)

    return test_fn
440
441
def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
    """Build a GPU test checking when a compiled optimizer step recompiles.

    The returned test runs four compiled steps, then perturbs the optimizer so
    that guards must fail and runs one more step:

    * Adagrad/SGD: change the first param group's lr;
    * Adam: replace the first parameter's ``exp_avg`` state tensor;
    * any other optimizer: clear all state.

    When ``self.check_kernel_count`` is set, it asserts that exactly
    ``2 * kernel_count`` kernels were generated (the original compile plus one
    recompile).

    Args:
        optim_cls: optimizer class under test.
        closure: optional closure forwarded to the compiled ``step``.
        kernel_count: expected kernel count of a single compilation.
        **kwargs: optimizer constructor kwargs.

    Returns:
        The test function, decorated with ``requires_gpu``.
    """

    @requires_gpu
    def test_fn(self):
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        input = torch.ones([10, 10], device=GPU_TYPE)
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, device=GPU_TYPE) for _ in range(2)]
        )
        model(input).sum().backward()

        opt_compiled = optim_cls(model.parameters(), **kwargs)
        # Forward the closure to the compiled step, as make_test does.
        compiled_step = compile_opt(opt_compiled, closure=closure)

        # check no recompile here
        with torch.set_grad_enabled(False):
            for _ in range(4):
                compiled_step()

            # perturb state to force recompile
            # Adagrad doesn't reinitialize state on each step
            # SGD has an empty state
            if optim_cls in (Adagrad, SGD):
                opt_compiled.param_groups[0]["lr"] = 0.02
            elif optim_cls is Adam:  # ensure we are guarding on the data_ptr of states
                first_param_state = opt_compiled.state[
                    opt_compiled.param_groups[0]["params"][0]
                ]
                first_param_state["exp_avg"] = torch.zeros_like(
                    first_param_state["exp_avg"]
                )
            else:
                opt_compiled.state.clear()

            compiled_step()

        if self.check_kernel_count:
            # currently, we compile the step and the rest of the computation
            # separately because the step is a single element tensor
            # hence, the usual kernel count is 2
            # multiply by 2 to account for the recompile
            recompile_multiplier = 2
            self.assertEqual(
                torch._inductor.metrics.generated_kernel_count,
                recompile_multiplier * kernel_count,
            )

    return test_fn
491
492
class CompiledOptimizerParityTests(TestCase):
    """Numerical parity between eager and ``torch.compile``-d optimizer steps.

    Per-device variants are generated by instantiate_device_type_tests near
    the end of this file, with CPU excluded.
    """

    @skipCUDAIf(not has_triton(), "torch.compile with cuda requires triton")
    @skipXPUIf(not has_triton(), "torch.compile with xpu requires triton")
    @optims(optim_db, dtypes=[torch.float32])
    @parametrize("use_closure", [True, False])
    def test_correctness(self, device, dtype, optim_info, use_closure):
        """Compare eager and compiled optimizers over every input config.

        For each optimizer input (differentiable variants skipped) and, when
        ``lr`` is a tensor, for each scheduler in ``LR_SCHEDULER_TO_KWARGS``:
        run ``num_steps`` eager and ``num_steps`` compiled steps (with or
        without a closure), then compare parameters and state via
        ``check_optim``.
        """
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        # Optimizers whose step needs a closure only run the closure variant.
        if optim_info.step_requires_closure and not use_closure:
            return

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            # Tensor lr configs are additionally exercised with every scheduler.
            use_scheduler = isinstance(kwargs.get("lr", None), torch.Tensor)
            scheduler_classes = (
                list(LR_SCHEDULER_TO_KWARGS.keys()) if use_scheduler else [None]
            )

            for scheduler_cls in scheduler_classes:
                torch._dynamo.reset()
                torch._inductor.metrics.reset()
                input = torch.ones([10, 10], device=device)
                model_eager = torch.nn.Sequential(
                    *[torch.nn.Linear(10, 10, device=device) for _ in range(2)]
                )
                model_eager(input).sum().backward()
                model_compiled = deepcopy(model_eager)
                model_compiled(input).sum().backward()

                # SparseAdam only accepts sparse gradients.
                if optim_cls is SparseAdam:
                    for param in model_eager.parameters():
                        param.grad = param.grad.to_sparse()
                    for param in model_compiled.parameters():
                        param.grad = param.grad.to_sparse()

                # Deep-copy kwargs so the two optimizers never share tensor
                # hyperparameters.
                opt_compiled = optim_cls(
                    model_compiled.parameters(), **deepcopy(kwargs)
                )
                opt_eager = optim_cls(model_eager.parameters(), **deepcopy(kwargs))
                if scheduler_cls:
                    scheduler_compiled = create_scheduler(scheduler_cls, opt_compiled)
                    scheduler_eager = create_scheduler(scheduler_cls, opt_eager)
                    # some schedulers only change after at least an epoch has passed
                    scheduler_compiled.last_epoch = 1
                    scheduler_eager.last_epoch = 1

                num_steps = 2
                if use_closure:

                    # The closure is defined inside the compiled function so
                    # dynamo traces it together with the step and scheduler.
                    @torch.compile()
                    def fn():
                        def closure():
                            loss = model_compiled(input).sum()
                            loss.backward()
                            if optim_info.only_supports_sparse_grads:
                                for param in model_compiled.parameters():
                                    param.grad = param.grad.to_sparse()
                            return loss

                        opt_compiled.step(closure)
                        if scheduler_cls:
                            call_scheduler(scheduler_compiled)

                    def closure_eager():
                        loss = model_eager(input).sum()
                        loss.backward()
                        if optim_info.only_supports_sparse_grads:
                            for param in model_eager.parameters():
                                param.grad = param.grad.to_sparse()

                        return loss

                    for _ in range(num_steps):
                        opt_eager.step(closure_eager)
                        if scheduler_cls:
                            call_scheduler(scheduler_eager)
                else:

                    @torch.compile()
                    def fn():
                        opt_compiled.step()
                        if scheduler_cls:
                            call_scheduler(scheduler_compiled)

                    for _ in range(num_steps):
                        opt_eager.step()
                        if scheduler_cls:
                            call_scheduler(scheduler_eager)

                # Eager steps have already run above; now run the compiled ones.
                for _ in range(num_steps):
                    fn()

                check_optim(
                    self,
                    optim_cls,
                    model_eager.parameters(),
                    model_compiled.parameters(),
                    opt_eager.state,
                    opt_compiled.state,
                )
597
598
class CompiledOptimizerTests(TestCase):
    """Kernel-count, recompilation and regression tests for compiled optimizers.

    Besides the hand-written tests below, one test method per
    ``COMPILED_OPT_KWARG_DB`` entry is attached to this class with ``setattr``
    after its definition.
    """

    check_model_gpu = check_model_gpu
    check_model_cpu = check_model
    # Read by the generated tests to decide whether to assert on
    # torch._inductor.metrics.generated_kernel_count.
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._dynamo.reset()
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()
        torch._inductor.metrics.reset()

    def check_cudagraphs_ran(self):
        """Assert that a cudagraph tree manager exists for device 0.

        Also asserts that its next graph id is 1 (presumably meaning exactly
        one graph was recorded; TODO confirm against cudagraph_trees).
        """
        # We run the zeroth device currently
        manager = torch._inductor.cudagraph_trees.get_container(0).tree_manager
        self.assertIsNotNone(manager)
        self.assertEqual(manager.new_graph_id().id, 1)

    test_adam_recompile = make_recompile_test(Adam, lr=0.01)
    test_adamw_recompile = make_recompile_test(AdamW, lr=0.01)
    test_adamax_recompile = make_recompile_test(Adamax, lr=0.01)
    test_nadam_recompile = make_recompile_test(NAdam, lr=0.01)
    test_rprop_recompile = make_recompile_test(Rprop, lr=0.01, kernel_count=2)
    test_rmsprop_recompile = make_recompile_test(RMSprop, lr=0.01)
    test_adadelta_recompile = make_recompile_test(Adadelta, lr=0.01)
    test_adagrad_recompile = make_recompile_test(Adagrad, lr=0.01)
    test_asgd_recompile_default = make_recompile_test(ASGD, lr=0.01)
    test_asgd_recompile_single = make_recompile_test(
        ASGD, kernel_count=8, lr=0.01, foreach=False
    )
    test_asgd_recompile_foreach = make_recompile_test(ASGD, lr=0.01, foreach=True)
    test_sgd_recompile_single = make_recompile_test(
        SGD, kernel_count=4, lr=0.01, foreach=False
    )
    test_sgd_recompile_foreach = make_recompile_test(
        SGD, kernel_count=1, lr=0.01, foreach=True
    )

    @requires_gpu
    def test_static_address_finalizer(self):
        """A parameter stepped by a compiled Adam is freed once unreachable.

        The cyclic garbage collector is disabled during the check, so the
        parameter can only be released by reference counting.
        """
        import gc

        gc_was_enabled = gc.isenabled()
        gc.disable()
        try:
            p_ref = None

            def fn():
                nonlocal p_ref
                mod = torch.nn.Linear(10, 10, device=GPU_TYPE, bias=False)
                for p in mod.parameters():
                    p.grad = torch.rand_like(p)

                opt = torch.optim.Adam(mod.parameters(), lr=0.1)

                def fn():
                    opt.step()

                with torch.set_grad_enabled(False):
                    step_fn_compiled = torch.compile(fn)
                    step_fn_compiled()
                p_ref = weakref.ref(p)
                self.assertTrue(p_ref() is not None)

            fn()

            self.assertTrue(p_ref() is None)
        finally:
            # Restore the collector even when an assertion fails, so a failure
            # here cannot leave GC disabled for the rest of the test process.
            if gc_was_enabled:
                gc.enable()

    def test_guard_on_none_grads(self):
        """Adam's ``step`` counter matches eager when some iterations have no grads."""

        def training_loop():
            input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).reshape(3, 2)

            model = torch.nn.Sequential(
                torch.nn.Linear(2, 3),
                torch.nn.Sigmoid(),
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid(),
            )

            params = list(model.parameters())
            optimizer = torch.optim.Adam(params)
            step_list = []

            for i in range(6):
                optimizer.zero_grad()
                # Test that step behaves as expected (a no-op) when grads are set to None
                if i != 3:
                    output = model(input)
                    loss = output.sum()
                    loss.backward()

                optimizer.step()
                step_list.append(optimizer.state[params[0]]["step"])

            return step_list

        compiled_training_loop = torch._dynamo.optimize("eager")(training_loop)
        actual_steps = compiled_training_loop()
        expected_steps = training_loop()
        self.assertEqual(actual_steps, expected_steps)

    # Basic shampoo test to verify we support compiling the various ops without error
    @requires_gpu
    def test_basic_shampoo(self):
        # NOTE(review): these tensors are created on the default device, not
        # GPU_TYPE, despite @requires_gpu; confirm whether that is intended.
        param_buf = torch.rand((1024, 128))
        param_buf_c = param_buf.clone().detach()

        params_c = [param_buf_c[0:512, :].t(), param_buf_c[512:, :].t()]
        params = [param_buf[0:512, :].t(), param_buf[512:, :].t()]

        for p, p_c in zip(params, params_c):
            p.grad = torch.rand_like(p)
            p_c.grad = p.grad.clone().detach()

        # note this skips the root inverse because this has a lot of internal dependencies
        # we also don't compile it regardless
        @torch.no_grad()
        def shampoo_functional_basic(params):
            step = 1
            weight_decay = 0.1
            grads = [p.grad for p in params]
            beta1 = 0.9
            beta2 = 1.0
            epsilon = 1e-10
            preconditioners = [torch.zeros_like(p) for p in params]
            lr = 0.01

            # pt2 region 1
            # weight decay
            torch._foreach_add_(grads, params, alpha=weight_decay)

            # update preconditioners
            torch._foreach_addcmul_(preconditioners, grads, grads, value=1.0)

            torch._foreach_mul_(grads, beta1)
            torch._foreach_add_(
                grads,
                grads,
                alpha=1 - beta1,
            )
            bias_correction1 = 1.0 - beta1**step
            grad_list = torch._foreach_div(grads, bias_correction1)

            # pt2 region 2
            # precondition (with shampoo branch), with no grafting
            bias_correction2 = 1.0 - beta2**step
            bias_corrected_preconditioner_list = torch._foreach_div(
                preconditioners, bias_correction2
            )
            torch._foreach_sqrt_(bias_corrected_preconditioner_list)
            torch._foreach_add_(bias_corrected_preconditioner_list, epsilon)
            search_directions = torch._foreach_div(
                grad_list, bias_corrected_preconditioner_list
            )

            torch._foreach_add_(
                search_directions,
                params,
                alpha=weight_decay,
            )

            torch._foreach_mul_(search_directions, -lr)
            # pt2 region 3 update params
            torch._foreach_add_(params, search_directions)

            return params, preconditioners, grads

        compiled_fn = torch.compile(shampoo_functional_basic)

        self.assertEqual(compiled_fn(params_c), shampoo_functional_basic(params))

    @requires_gpu
    def test_closure_graph_break(self):
        """A step whose closure sets ``.grad`` matches eager under dynamo."""
        param = torch.rand(
            2, 3, dtype=torch.float32, device=GPU_TYPE, requires_grad=True
        )
        param_c = param.clone().detach().requires_grad_(True)

        def closure():
            param.grad = torch.ones_like(param) * 2
            return param.grad

        def closure_c():
            param_c.grad = torch.ones_like(param_c) * 2
            return param_c.grad

        optimizer = torch.optim.AdamW([param])
        optimizer_c = torch.optim.AdamW([param_c])

        def loop(opt, c):
            opt.step(c)

        compiled_loop = torch._dynamo.optimize("eager")(loop)

        compiled_loop(optimizer, closure)
        loop(optimizer_c, closure_c)

        self.assertEqual(param, param_c)

    def test_get_value_on_static_address(self):
        """Compiled ``_get_value`` returns a static-address tensor unchanged."""
        from torch._dynamo.decorators import mark_static_address
        from torch.optim.optimizer import _get_value

        compiled = torch.compile(_get_value)

        x = torch.ones(2, 2)
        mark_static_address(x)

        ret_val = compiled(x)

        self.assertEqual(ret_val, x)

    # compile a large foreach op and verify
    # that the time taken is within an expected range
    @requires_gpu
    def test_compile_time_smoketest(self):
        import time

        xs = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)]
        ys = [torch.ones(2, 2, device=GPU_TYPE) for _ in range(100)]

        @torch.compile
        def fn(xs, ys):
            return torch._foreach_add(xs, ys)

        # The first call includes compilation, which is what is being timed.
        start = time.perf_counter()
        fn(xs, ys)
        end = time.perf_counter()

        self.assertLess(end - start, 90)

    @requires_cuda
    def test_S429861(self):
        # Just verify we can compile this function without error
        try:
            from . import s429861_repro
        except ImportError:
            import s429861_repro

        forward = s429861_repro.forward

        import torch._dynamo
        import torch._inductor
        from torch._dynamo.debug_utils import aot_graph_input_parser
        from torch._inductor.utils import fresh_inductor_cache

        with fresh_inductor_cache():
            kwargs = aot_graph_input_parser(forward)
            torch.compile(forward)(**kwargs)
850
851
# Attach one generated test method per COMPILED_OPT_KWARG_DB entry, named by
# the entry. Each entry's kwargs (including ``device`` and ``kernel_count``)
# are forwarded to make_test.
for optim_cls, name, kwargs, scheduler_cls in COMPILED_OPT_KWARG_DB:
    setattr(
        CompiledOptimizerTests,
        name,
        make_test(optim_cls, scheduler_cls=scheduler_cls, **kwargs),
    )

# Generate per-device variants of the parity tests (XPU allowed, CPU excluded).
instantiate_device_type_tests(
    CompiledOptimizerParityTests, globals(), allow_xpu=True, except_for="cpu"
)

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Only run when inductor_utils reports a usable CPU or GPU backend.
    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")
868