# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


c10d_functional = torch.ops.c10d_functional
c10d_ops = torch.ops.c10d


class TestCommMode(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def setUp(self):
        super().setUp()
        self.world_size = 2
        # Use the fake backend so collective calls can be traced
        # without launching multiple processes.
        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=1, world_size=self.world_size, store=store
        )
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.world_pg = dist.distributed_c10d._get_default_group()

    def checksAssert(self, comm_mode, key, expected_value, expected_total_value):
        # Asserts both the total collective count and the count for `key`.
        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), expected_total_value)
        self.assertEqual(comm_counts[key], expected_value)

        return

    def test_comm_mode(self):
        world_pg = self.world_pg

        class WrapperModel(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.model = MLPModule(device=device)

            def forward(self, x):
                x = funcol.all_gather_tensor(x, 0, world_pg)
                x = funcol.reduce_scatter_tensor(x, "sum", 0, world_pg)
                out = self.model(x)
                return funcol.all_reduce(out, "sum", world_pg)

        model = WrapperModel(self.device_type)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(torch.randn(20, 10, device=self.device_type))

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 3)
        self.assertEqual(comm_counts[c10d_functional.all_reduce], 1)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)

    def test_comm_mode_coalesced(self):
        world_pg = self.world_pg

        class WrapperModelCoalesced(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.model = MLPModule(device=device)

            def forward(self, x):
                x = funcol.all_gather_tensor(x, 0, world_pg)
                x = funcol.reduce_scatter_tensor(x, "sum", 0, world_pg)
                out = self.model(x)
                return funcol.all_reduce_coalesced([out], "sum", world_pg)

        model = WrapperModelCoalesced(self.device_type)

        comm_mode = CommDebugMode()
        with comm_mode:
            model(torch.randn(20, 10, device=self.device_type))

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 3)
        self.assertEqual(comm_counts[c10d_functional.all_reduce_coalesced], 1)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 1)

    def test_comm_mode_with_dtensor(self):
        world_pg = self.world_pg
        mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        def f(x, y):
            return torch.mm(x, y)

        comm_mode = CommDebugMode()
        x = torch.randn(4, 8, requires_grad=True)
        y = torch.randn(4, 32, requires_grad=True)
        # Both operands are sharded on dim 0; the sharded mm is expected to
        # trigger exactly one all-gather and no other collectives.
        x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)

        with comm_mode:
            f(x_dtensor, y_dtensor)

        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(comm_counts[c10d_functional.all_reduce], 0)
        self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
        self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)

    @requires_nccl()
    def test_comm_mode_with_c10d(self):
        if not torch.cuda.is_available():
            return

        world_pg = self.world_pg

        inp = torch.rand(2, 8, 16).cuda()
        all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)

        comm_mode = CommDebugMode()

        # tests c10d all_reduce tracing
        with comm_mode:
            dist.all_reduce(inp)

        self.checksAssert(comm_mode, c10d_ops.allreduce_, 1, 1)

        # tests c10d all_gather_into_tensor tracing
        with comm_mode:
            dist.all_gather_into_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops._allgather_base_, 1, 1)

        # tests c10d reduce_scatter tracing
        with comm_mode:
            dist.reduce_scatter_tensor(inp, all_gather_out)

        self.checksAssert(comm_mode, c10d_ops._reduce_scatter_base_, 1, 1)

        # tests c10d broadcast tracing
        with comm_mode:
            dist.broadcast(inp, 0)

        self.checksAssert(comm_mode, c10d_ops.broadcast_, 1, 1)

        # tests c10d gather tracing
        with comm_mode:
            dist.gather(inp, None, 0)

        self.checksAssert(comm_mode, c10d_ops.gather_, 1, 1)

        # tests c10d reduce tracing
        with comm_mode:
            dist.reduce(inp, 0)

        self.checksAssert(comm_mode, c10d_ops.reduce_, 1, 1)

        # tests c10d scatter tracing
        with comm_mode:
            dist.scatter(inp, None, 0)

        self.checksAssert(comm_mode, c10d_ops.scatter_, 1, 1)

        # tests c10d all_gather tracing
        output_list = []

        with comm_mode:
            dist.all_gather(output_list, inp, None)

        self.checksAssert(comm_mode, c10d_ops.allgather_, 1, 1)

        # tests c10d allgather_coalesced_ tracing
        output_list = []

        with comm_mode:
            dist.all_gather_coalesced(output_list, [inp], None)

        self.checksAssert(comm_mode, c10d_ops.allgather_coalesced_, 1, 1)

        # tests c10d allgather_into_tensor_coalesced_ tracing
        with comm_mode, dist._coalescing_manager():
            dist.all_gather_into_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops.allgather_into_tensor_coalesced_, 1, 1)

        # tests c10d allreduce_coalesced tracing
        with comm_mode:
            dist.all_reduce_coalesced(inp)

        self.checksAssert(comm_mode, c10d_ops.allreduce_coalesced_, 1, 1)

        # tests c10d reduce_scatter_ tracing
        with comm_mode:
            dist.reduce_scatter(all_gather_out, [inp])

        self.checksAssert(comm_mode, c10d_ops.reduce_scatter_, 1, 1)

        # tests c10d reduce_scatter_tensor_coalesced tracing
        with comm_mode, dist._coalescing_manager():
            dist.reduce_scatter_tensor(all_gather_out, inp)

        self.checksAssert(comm_mode, c10d_ops.reduce_scatter_tensor_coalesced_, 1, 1)

        # tests c10d alltoall_ tracing
        with comm_mode:
            dist.all_to_all([inp], [inp])

        self.checksAssert(comm_mode, c10d_ops.alltoall_, 1, 1)

        # tests c10d alltoall_base_ tracing
        with comm_mode:
            dist.all_to_all_single(inp, inp)

        self.checksAssert(comm_mode, c10d_ops.alltoall_base_, 1, 1)


if __name__ == "__main__":
    run_tests()