# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)


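# Each test below builds a plain single-device MLPModule and a DTensor-sharded
# copy of it (created with distribute_module and the shard/input/output
# functions defined in this file), runs a couple of optimizer steps on both,
# and checks that the resulting parameters stay numerically equal.
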
# shard function to do full sharding on all parameters of a module
def shard_fn(name, module, device_mesh):
    if isinstance(module, nn.Linear):
        for name, param in module.named_parameters():
            dist_param = torch.nn.Parameter(
                distribute_tensor(param, device_mesh, [Shard(0)])
            )
            # make sure the partial sum gets cleared after backward()
            dist_param.register_hook(
                lambda grad: grad.redistribute(placements=[Shard(0)])
            )
            module.register_parameter(name, dist_param)


# prepare input
def input_fn(mod, inputs, device_mesh):
    # shard the input tensor along dim 0
    dist_inp = distribute_tensor(inputs[0], device_mesh, [Shard(0)])
    return dist_inp


# prepare the output to be a local torch.Tensor
def output_fn(mod, outputs, device_mesh):
    assert isinstance(outputs, DTensor)
    return outputs.redistribute(placements=[Replicate()] * device_mesh.ndim).to_local()


class TestDTensorOptimizer(DTensorTestBase):
    def _assert_optimizer(
        self,
        mesh,
        model,
        optim,
        dist_model,
        dist_optim,
        inputs,
        *,
        rtol: float = 1.3e-6,
        atol: float = 1e-5,
    ):
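        """Run two optimizer steps on both models and check that the single-device
        model and the DTensor-sharded model end up with numerically matching
        parameters."""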
        for iter_idx in range(2):
            # run forward/backward/optim for the original model
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            out = model(inputs)
            loss = out.sum()
            loss.backward()
            optim.step()

            # run forward/backward/optim for the distributed model
            dist_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            dist_out = dist_model(inputs)
            dist_loss = dist_out.sum()
            dist_loss.backward()
            dist_optim.step()

            # check that the optimizers update the parameters with the same numerics
            for p1, p2 in zip(model.parameters(), dist_model.parameters()):
                p2 = p2.full_tensor()
                # Default 'rtol' and 'atol' for :attr:`~torch.float32` are ``1.3e-6`` and ``1e-5``
                self.assertEqual(p1, p2, atol=atol, rtol=rtol)

    def test_optimizer_foreach_supported_types_include_DTensor(self):
        from torch.optim.optimizer import _foreach_supported_types

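        # DTensor needs to be listed here for the optimizers to default to the
        # foreach implementation when stepping DTensor parameters.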
        self.assertTrue(DTensor in _foreach_supported_types)

    @with_comms
    def test_adam_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # lr as a Tensor is not supported for capturable=False and foreach=True
        adam_float_lr_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05},
            {"lr": 0.1, "weight_decay": 0.05, "amsgrad": True},
            {
                "lr": 0.1,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
            },
        ]
        fused_adam_float_lr_configs = [
            {"lr": 0.1, "fused": True},
            {"lr": 0.1, "weight_decay": 0.05, "amsgrad": True, "fused": True},
            {
                "lr": 0.1,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
                "fused": True,
            },
        ]
        # lr can be a Tensor or a float when fused=True for the Adam optimizer
        fused_adam_tensor_lr_configs = [
            {**config, "lr": torch.tensor(0.1)}
            for config in fused_adam_float_lr_configs
        ]
        fused_adam_tensor_lr_configs.extend(
            [
                {**config, "lr": torch.tensor([0.1])}
                for config in fused_adam_float_lr_configs
            ]
        )
        adam_configs = [
            *adam_float_lr_configs,
            *fused_adam_float_lr_configs,
            *fused_adam_tensor_lr_configs,
        ]

        for config in adam_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.Adam(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.Adam(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_adamw_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        # lr as a Tensor is not supported for capturable=False and foreach=True
        adamw_float_lr_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "amsgrad": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
            },
        ]
        fused_adamw_float_lr_configs = [
            {"lr": 0.1, "weight_decay": 0.05, "fused": True},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "amsgrad": True,
                "fused": True,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "maximize": True,
                "amsgrad": True,
                "fused": True,
            },
        ]
        # lr can be a Tensor or a float when fused=True for the AdamW optimizer
        fused_adamw_tensor_lr_configs = [
            {**config, "lr": torch.tensor(0.1)}
            for config in fused_adamw_float_lr_configs
        ]
        fused_adamw_tensor_lr_configs.extend(
            [
                {**config, "lr": torch.tensor([0.1])}
                for config in fused_adamw_float_lr_configs
            ]
        )
        adamw_configs = [
            *adamw_float_lr_configs,
            *fused_adamw_float_lr_configs,
            *fused_adamw_tensor_lr_configs,
        ]

        for config in adamw_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.AdamW(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.AdamW(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_sgd_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        sgd_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "momentum": 0.05, "foreach": False},
            {"lr": 0.1, "momentum": 0.05},
            {"lr": 0.1, "momentum": 0.06, "dampening": 0.07},
            {
                "lr": 0.1,
                "momentum": 0.08,
                "weight_decay": 0.05,
                "nesterov": True,
                "maximize": True,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "momentum": 0.08,
                "weight_decay": 0.05,
                "nesterov": True,
                "maximize": True,
            },
        ]

        for config in sgd_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.SGD(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.SGD(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_adagrad_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        adagrad_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "lr_decay": 0.05, "foreach": False},
            {"lr": 0.1, "lr_decay": 0.02, "weight_decay": 0.05, "foreach": False},
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
                "eps": 1e-6,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
                "eps": 1e-6,
                "maximize": True,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "lr_decay": 0.02,
                "weight_decay": 0.05,
                "initial_accumulator_value": 0.03,
                "eps": 1e-6,
                "maximize": True,
            },
        ]

        for config in adagrad_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.Adagrad(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.Adagrad(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_RMSprop_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        RMSprop_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "alpha": 0.85, "foreach": False},
            {"lr": 0.1, "alpha": 0.88, "eps": 1e-6, "foreach": False},
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
                "centered": True,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
                "centered": True,
                "maximize": True,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "alpha": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "momentum": 0.9,
                "centered": True,
                "maximize": True,
            },
        ]

        for config in RMSprop_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.RMSprop(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.RMSprop(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_adadelta_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        adadelta_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "rho": 0.85, "foreach": False},
            {"lr": 0.1, "rho": 0.88, "eps": 1e-5, "foreach": False},
            {
                "lr": 0.1,
                "rho": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "rho": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
            },
            {
                "lr": 0.1,
                "rho": 0.88,
                "eps": 1e-6,
                "weight_decay": 0.05,
                "maximize": True,
            },
        ]

        for config in adadelta_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.Adadelta(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.Adadelta(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_nadam_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        nadam_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "decoupled_weight_decay": True,
            },
        ]

        for config in nadam_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.NAdam(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.NAdam(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_radam_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        radam_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "weight_decay": 0.05, "foreach": False},
            {
                "lr": 0.1,
                "weight_decay": 0.05,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "decoupled_weight_decay": True,
            },
        ]

        for config in radam_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.RAdam(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.RAdam(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_adamax_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        adamax_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "betas": (0.6, 0.66), "foreach": False},
            {"lr": 0.1, "betas": (0.6, 0.66), "eps": 1e-6, "foreach": False},
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
            },
            {
                "lr": 0.1,
                "betas": (0.6, 0.66),
                "eps": 1e-6,
                "weight_decay": 0.05,
                "maximize": True,
            },
        ]

        for config in adamax_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.Adamax(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.Adamax(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)
            self._assert_optimizer(mesh, mod, opt, dist_mod, dist_opt, inp)

    @with_comms
    def test_asgd_1d_sharding(self):
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        asgd_configs = [
            {"lr": 0.1, "foreach": False},
            {"lr": 0.1, "lambd": 0.001, "foreach": False},
            {"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "foreach": False},
            {"lr": 0.1, "lambd": 0.001, "alpha": 0.85, "t0": 1e5, "foreach": False},
            {
                "lr": 0.1,
                "lambd": 0.001,
                "alpha": 0.85,
                "t0": 1e5,
                "weight_decay": 0.05,
                "foreach": False,
            },
            {
                "lr": 0.1,
                "lambd": 0.001,
                "alpha": 0.85,
                "t0": 1e5,
                "weight_decay": 0.05,
                "foreach": True,
            },
            {
                "lr": 0.1,
                "lambd": 0.001,
                "alpha": 0.85,
                "t0": 1e5,
                "weight_decay": 0.05,
                "foreach": True,
                "maximize": True,
            },
        ]

        for config in asgd_configs:
            mod = MLPModule(self.device_type)
            opt = torch.optim.ASGD(mod.parameters(), **config)

            dist_mod = distribute_module(
                deepcopy(mod), mesh, shard_fn, input_fn, output_fn
            )
            dist_opt = torch.optim.ASGD(dist_mod.parameters(), **config)

            # use ones to make sure the single-machine model has the same input
            # on different ranks
            inp = torch.ones(8, 10, device=self.device_type)

            # TODO: We want to keep a unit test for the ASGD optimizer for the time being,
            # but we need to look into why ASGD needs higher atol and rtol when comparing
            # model parameters.
            # Default 'rtol' and 'atol' for :attr:`~torch.float32` are ``1.3e-6`` and ``1e-5``
            # Pointer here: https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L65
            self._assert_optimizer(
                mesh, mod, opt, dist_mod, dist_opt, inp, atol=1.3e-5, rtol=1e-4
            )


if __name__ == "__main__":
    run_tests()