# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)


class TensorParallelRandomStateTests(DTensorTestBase):
    def get_tensor_slice(self, idx, n, large_tensor):
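        """Return the idx-th of n equal dim-0 shards of ``large_tensor``."""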
        shape = large_tensor.shape
        assert shape[0] % n == 0
        local_shape = [shape[0] // n, shape[1]]

        slice_idx = (
            slice(idx * local_shape[0], (idx + 1) * local_shape[0]),
            slice(local_shape[1]),
        )
        return large_tensor[slice_idx]

    def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc):
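        """Compare this rank's shard of ``gathered_tensors`` against every
        other rank's shard using ``assertFunc``."""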
        for other_rank in range(size):
            if self_rank != other_rank:
                assertFunc(
                    self.get_tensor_slice(self_rank, size, gathered_tensors),
                    self.get_tensor_slice(other_rank, size, gathered_tensors),
                )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_model_init(self):
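        """Verify that with DTensor's parallel RNG enabled, ranks within a TP
        group initialize different local weight shards while corresponding
        shards match across DP ranks, and that disabling it yields the
        opposite (incorrect) behavior."""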
        dp_size = 2
        tp_size = self.world_size // dp_size
        mesh_2d = init_device_mesh(
            self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        dp_rank = dp_mesh.get_coordinate()[0]
        tp_rank = tp_mesh.get_coordinate()[0]
        self.assertEqual(dp_rank, self.rank // tp_size)
        self.assertEqual(tp_rank, self.rank % tp_size)

        for enable_distribute_flag in [False, True]:
            # a local model on meta device
            model = MLPModule(device="meta")
            # the col-wise parallel style shards the weight over tensor dim 0
            model_tp = parallelize_module(
                model,
                tp_mesh,
                {
                    "net1": ColwiseParallel(output_layouts=Replicate()),
                    "net2": ColwiseParallel(output_layouts=Replicate()),
                },
            )
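            # Illustrative shapes (assuming MLPModule's linear layers have
            # out_features divisible by tp_size): ColwiseParallel stores each
            # weight as one dim-0 shard per TP rank, e.g. a (16, 10) net1.weight
            # would become two (8, 10) local shards when tp_size == 2.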
            # in most cases, the random number generator state is set by the
            # data loader in the following way:
            #   - within a tensor parallel group, the RNG is set with the same seed
            #   - across data parallel groups, the RNG is set with different seeds
            torch.cuda.manual_seed(dp_rank)

            # disable/enable parallel RNG feature
            random._rng_tracker.distribute_region_enabled = enable_distribute_flag
            self.assertTrue(model_tp.net1.weight.is_meta)
            # initialize the model's local shard
            model_tp.to_empty(device=self.device_type)
            model_tp.reset_parameters()
            # verify that the weights are initialized in a way that adheres to the DP/TP setup
            for dtensor in [model_tp.net1.weight, model_tp.net2.weight]:
                # check within the TP group
                # the 1d mesh represents the TP group
                _1d_mesh = dtensor.device_mesh
                assert _1d_mesh.ndim == 1
                self.assertEqual(_1d_mesh, tp_mesh)

                tensor_local = dtensor.to_local()

                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=_1d_mesh,
                )
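                # tensor_gather stacks the TP ranks' local shards along dim 0
                # in rank order, so get_tensor_slice(r, tp_size, tensor_gather)
                # recovers TP rank r's shard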
                self.assertEqual(_1d_mesh.get_coordinate()[0], tp_rank)

                # compare local shards within the TP group
                def tp_weights_assert(tensor1, tensor2):
                    if enable_distribute_flag:
                        # each rank within a TP group shall initialize local weights differently
                        self.assertNotEqual(tensor1, tensor2)
                    else:
                        # without the parallel RNG, weight initialization violates the TP setup:
                        # each rank within a TP group has the same initial weights
                        self.assertEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    tp_rank, tp_size, tensor_gather, tp_weights_assert
                )

                # check across TP groups
                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=dp_mesh,
                )

                # compare local shards across TP groups
                def dp_weights_assert(tensor1, tensor2):
                    if enable_distribute_flag:
                        # local weights shall be initialized the same across TP groups
                        self.assertEqual(tensor1, tensor2)
                    else:
                        # without the parallel RNG, weight initialization violates the TP setup:
                        # local weights are initialized differently across TP groups due to different
                        # random seeds set in data loading.
                        self.assertNotEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    dp_rank, dp_size, tensor_gather, dp_weights_assert
                )


if __name__ == "__main__":
    run_tests()