# Owner(s): ["oncall: distributed"]

import sys
import unittest

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, init_device_mesh, Shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

HAS_CUDA = torch.cuda.is_available()


class TestFakePG(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def test_all_reduce(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        output = torch.ones(3, 3) * dist.get_rank()
        dist.all_reduce(output)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_allgather(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        input_tensor = torch.ones(3, 3) * dist.get_rank()
        output_tensors = [torch.empty_like(input_tensor) for _ in range(2)]
        dist.all_gather(output_tensors, input_tensor)
        for out_tensor in output_tensors:
            self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_reduce_scatter(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        output_tensor = torch.empty(3, 3)

        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        self.assertEqual(tuple(output_tensor.shape), (3, 3))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_construct_fsdp(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        FSDP(nn.Linear(2, 3, device="cuda"))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fsdp_fake_e2e(self):
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        my_module = nn.Sequential(
            nn.Linear(2, 3, device="cuda"),
            nn.ReLU(),
            nn.Linear(3, 2, device="cuda"),
        )
        sharded_module = FSDP(my_module, use_orig_params=True)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        input = torch.randn(2, 2)
        x = sharded_module(input)
        loss = x.sum()
        loss.backward()
        optim.step()

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fake_pg_tracing(self):
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        default_pg = dist.distributed_c10d._get_default_group()

        def allgather_fn(tensor):
            return funcol.all_gather_tensor(tensor, 0, default_pg)

        gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
        FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))

    def test_broadcast(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=0)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_scatter(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        to_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        dist.scatter(output, to_scatter)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.scatter(output, None, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        output_list = [torch.ones(3, 3) for _ in range(2)]
        input_list = [torch.ones(3, 3) for _ in range(2)]
        dist.all_to_all(output_list, input_list)
        self.assertEqual(len(output_list), 2)
        for output in output_list:
            self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall_base(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        out_tensor = torch.ones(3, 3)
        in_tensor = torch.ones(3, 3)
        output_split = [1, 1]
        input_split = [1, 1]
        dist.all_to_all_single(out_tensor, in_tensor, output_split, input_split)
        self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_send(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        tensor = torch.ones(3, 3)
        dist.send(tensor, 1)
        self.assertEqual(tuple(tensor.shape), (3, 3))

    def test_recv(self):
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        output = torch.ones(3, 3)
        dist.recv(output, 1)
        self.assertEqual(tuple(output.shape), (3, 3))

    @unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP")
    def test_fsdp_tp_fake_e2e(self):
        world_size = 4
        tp_size = 2

        store = dist.HashStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=world_size, store=store
        )

        device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size))
        device_mesh = init_device_mesh(
            "cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
        )

        sequence_parallelize_plan = {
            "net1": ColwiseParallel(input_layouts=Shard(0)),
            "net2": RowwiseParallel(output_layouts=Shard(0)),
        }
        pairwise_parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]:
            my_module = parallelize_module(
                MLPModule(device="cuda"),
                device_mesh["tp"],
                parallel_plan,
            )

            sharded_module = FSDP(
                my_module, use_orig_params=True, device_mesh=device_mesh["dp"]
            )
            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

            for i in range(10):
                dp_rank = dist.get_rank()
                torch.manual_seed(i + dp_rank)
                input = torch.randn(20, 10).cuda(dist.get_rank())
                x = sharded_module(input)
                loss = x.sum()
                loss.backward()
                optim.step()


if __name__ == "__main__":
    run_tests()