1# Owner(s): ["oncall: distributed"]
2
3import sys
4
5import torch
6from torch import distributed as dist
7from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8from torch.nn import Linear
9from torch.optim import SGD
10from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
11from torch.testing._internal.common_fsdp import FSDPTest
12from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


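# The Linear(3, 3, bias=False) model used below has 9 weight elements, which do
# not divide evenly across ranks (e.g., 2 ranks), so wrapping it with FSDP
# exercises the uneven-shard path where the flattened parameter is padded so
# that every rank holds an equal-sized chunk.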
class TestUnevenParamShard(FSDPTest):
    def _get_ref_results(self, model, input, my_lr):
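        """Compute the expected forward output and updated weight without FSDP.

        Each rank's forward output is its own input row multiplied by the
        weight. The test's loss is ``out.sum()``; for ``y = x @ W^T`` the
        gradient of that loss w.r.t. ``W`` repeats ``x`` along every output
        row, and gradients are averaged across ranks, so the expected update
        subtracts ``lr`` times the rank-averaged input row from each row of
        the weight.
        """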
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().to(self.rank)
            v = torch.Tensor(input[self.rank]).to(self.rank)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(input[: self.world_size]).to(self.rank)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
            ref_weight_out = weight - grad.T * my_lr

        return ref_forward_output_my_rank, ref_weight_out

    @skip_if_lt_x_gpu(2)
    def test_one_iteration(self):
        """Test FSDP with an uneven division of parameter shards."""
        model = Linear(3, 3, bias=False)
        input = torch.rand(8, 3)
        my_lr = 0.1

        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
            model, input, my_lr
        )

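        # Move the model to this rank's GPU and wrap it with FSDP, which
        # flattens the weight and shards it across ranks (unevenly here,
        # e.g. 9 elements split over 2 ranks).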
        model.to(self.rank)
        model = FSDP(model)
        optim = SGD(model.parameters(), lr=my_lr)
        self.assertTrue(len(input) >= self.world_size)
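        # Run one forward/backward/step on this rank's own row of the input;
        # the reference computation assumes every rank generated the same
        # ``input`` tensor.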
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()

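        # summon_full_params gathers the full, unsharded weight on each rank
        # so the post-step value can be compared against the local reference.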
        with model.summon_full_params(model):
            weight_out = model.module.weight.T.clone()
            self.assertEqual(ref_forward_output_my_rank, out)
            self.assertEqual(ref_weight_out, weight_out)


if __name__ == "__main__":
    run_tests()