# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    with_comms,
)


class TensorParallelRandomStateTests(DTensorTestBase):
    def get_tensor_slice(self, idx, n, large_tensor):
        # return the idx-th of n equal slices of `large_tensor` along dim 0
        shape = large_tensor.shape
        assert shape[0] % n == 0
        local_shape = [shape[0] // n, shape[1]]

        slice_idx = [
            slice(idx * local_shape[0], (idx + 1) * local_shape[0]),
            slice(local_shape[1]),
        ]
        return large_tensor[slice_idx]

    def check_gathered_tensors(self, self_rank, size, gathered_tensors, assertFunc):
        # compare this rank's slice of the gathered tensor against every other
        # rank's slice using the provided assertion function
        for other_rank in range(size):
            if self_rank != other_rank:
                assertFunc(
                    self.get_tensor_slice(self_rank, size, gathered_tensors),
                    self.get_tensor_slice(other_rank, size, gathered_tensors),
                )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_model_init(self):
        # build a 2D (dp, tp) device mesh and check that parameter initialization
        # respects the DP/TP setup when the parallel RNG tracker is enabled
        dp_size = 2
        tp_size = self.world_size // dp_size
        mesh_2d = init_device_mesh(
            self.device_type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        dp_rank = dp_mesh.get_coordinate()[0]
        tp_rank = tp_mesh.get_coordinate()[0]
        self.assertEqual(dp_rank, self.rank // tp_size)
        self.assertEqual(tp_rank, self.rank % tp_size)

        for enable_distribute_flag in [False, True]:
            # a local model on meta device
            model = MLPModule(device="meta")
            # the col-wise parallel style shards the weight over tensor dim 0
            model_tp = parallelize_module(
                model,
                tp_mesh,
                {
                    "net1": ColwiseParallel(output_layouts=Replicate()),
                    "net2": ColwiseParallel(output_layouts=Replicate()),
                },
            )
            # in most cases, the random number generator state is set by the data loader
            # in the following way:
            #   - within a tensor parallel group, the RNG is set with the same seed
            #   - across data parallel groups, the RNG is set with different seeds
            torch.cuda.manual_seed(dp_rank)

            # disable/enable the parallel RNG feature
            random._rng_tracker.distribute_region_enabled = enable_distribute_flag
            self.assertTrue(model_tp.net1.weight.is_meta)
            # initialize the model's local shard
            model_tp.to_empty(device=self.device_type)
            model_tp.reset_parameters()
            # check that the weight initialization adheres to the DP/TP setup
            for dtensor in [model_tp.net1.weight, model_tp.net2.weight]:
                # check within the TP group
                # the 1d mesh represents the TP group
                _1d_mesh = dtensor.device_mesh
                assert _1d_mesh.ndim == 1
                self.assertEqual(_1d_mesh, tp_mesh)

                tensor_local = dtensor.to_local()

                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=_1d_mesh,
                )
                self.assertEqual(_1d_mesh.get_coordinate()[0], tp_rank)

                # compare local shards within the TP group
                def tp_weights_assert(tensor1, tensor2):
                    if enable_distribute_flag:
                        # each rank within a TP group shall initialize local weights differently
                        self.assertNotEqual(tensor1, tensor2)
                    else:
                        # without the parallel RNG, weight initialization violates the TP setup:
                        # each rank within a TP group has the same initial weights
                        self.assertEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    tp_rank, tp_size, tensor_gather, tp_weights_assert
                )

                # check across TP groups
                # all-gather local shards
                tensor_gather = funcol.all_gather_tensor(
                    tensor_local,
                    gather_dim=0,
                    group=dp_mesh,
                )

                # compare local shards across TP groups
                def dp_weights_assert(tensor1, tensor2):
                    if enable_distribute_flag:
                        # local weights shall be initialized the same across TP groups
                        self.assertEqual(tensor1, tensor2)
                    else:
                        # without the parallel RNG, weight initialization violates the TP setup:
                        # local weights are initialized differently across TP groups due to the
                        # different random seeds set in data loading
                        self.assertNotEqual(tensor1, tensor2)

                self.check_gathered_tensors(
                    dp_rank, dp_size, tensor_gather, dp_weights_assert
                )


if __name__ == "__main__":
    run_tests()