# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from functools import partial

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import (
    distribute_tensor,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed._tensor.experimental import local_map
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


funcol_py = torch.ops.c10d_functional


row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
replicate = [Replicate()]  # replicate placements on 1-d mesh


def equal_allgather_forward(device_mesh, X, Y):
    eq = torch.tensor([torch.equal(X, Y)], device=X.device)
    eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh)
    return torch.all(eq_gather).item()


def mm_all_gather_forward(device_mesh, A, B):
    local_mm_result = torch.mm(A, B)
    return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait()


def mm_forward(A, B):  # no device mesh needed since no collective is used
    return torch.mm(A, B)


def mm_allreduce_forward(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


@partial(
    local_map,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
def mm_allreduce_forward_decorated(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


def mul_forward(X, scalar):  # no device mesh needed since no collective is used
    return torch.mul(X, scalar)


class TestLocalMap(DTensorTestBase):
    @property
    def world_size(self):
        return 2

    # simple correctness check
    @with_comms
    def test_local_map_correctness(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, col_wise
        )  # column-wise sharded X tensor
        W_dt = distribute_tensor(
            W, device_mesh, row_wise
        )  # row-wise sharded W tensor

        # Test 1: use the function returned from calling local_map,
        # i.e. the function wrapped with DTensor/Tensor conversion.
        # mm_allreduce_forward is a function that applies to Tensors with a manual
        # collective; local_mm_allreduce_forward is the function that does the same
        # but applies to DTensors' `_local_tensor`.
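        # `in_placements` carries one entry per positional argument of
        # mm_allreduce_forward: `None` for `device_mesh` (not a Tensor, passed
        # through as-is), `col_wise` for A, and `row_wise` for B.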
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

        # a single collective: the all-reduce inside mm_allreduce_forward
        self.assertEqual(comm_mode.get_total_counts(), 1)
        # check output placements
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        # check output value
        self.assertEqual(Y_dt.to_local(), Y)

        # Test 2: use the local_map decorator
        with comm_mode:
            Y_dt = mm_allreduce_forward_decorated(device_mesh, X_dt, W_dt)

        # a single collective: the all-reduce inside the decorated function
        self.assertEqual(comm_mode.get_total_counts(), 1)
        # check output placements
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        # check output value
        self.assertEqual(Y_dt.to_local(), Y)

    # check for `out_placements`
    @with_comms
    def test_local_map_out_placements(self):
        # Test 1: wrap the output into a DTensor w/ `out_placements`
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # X.equal(Y)
        X = torch.randn(8, 8, device=self.device_type, requires_grad=False)
        Y = torch.randn(8, 8, device=self.device_type, requires_grad=False)
        X_dt = distribute_tensor(X, device_mesh, row_wise)
        Y_dt = distribute_tensor(Y, device_mesh, row_wise)
        local_equal_allgather_forward = local_map(
            equal_allgather_forward,
            out_placements=None,
        )
        with comm_mode:
            equal_dt = local_equal_allgather_forward(device_mesh, X_dt, Y_dt)  # a bool

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertTrue(not equal_dt)
        self.assertTrue(not (X.equal(Y)))

        # Test 2: directly return the output if no argument is a DTensor
        # matmul in DDP
        X = torch.randn(
            4 // self.world_size, 4, device=self.device_type, requires_grad=False
        )
        W = torch.randn(4, 4, device=self.device_type, requires_grad=False)
        local_mm_all_gather_forward = local_map(
            mm_all_gather_forward,
            out_placements=row_wise,
            in_placements=(None, row_wise, replicate),
        )
        with comm_mode:
            Y = local_mm_all_gather_forward(device_mesh, X, W)

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(
            comm_mode.get_comm_counts()[funcol_py.all_gather_into_tensor], 1
        )
        X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait()
        Y_replicate = torch.mm(X_replicate, W)
        self.assertEqual(Y, Y_replicate)  # Y is a torch.Tensor

    # check for `in_placements` handling
    @with_comms
    def test_local_map_in_placements(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, row_wise
        )  # row-wise sharded X tensor
        W_dt = distribute_tensor(W, device_mesh, replicate)  # replicated W tensor

        # Test 1: explicitly pass `in_placements`
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            in_placements=(row_wise, replicate),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_forward(X_dt, W_dt)

        # no communication should occur in this case
        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 2: `in_placements=None`
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_forward(X_dt, W_dt)

        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 3: `None` placements for a non-Tensor input argument
        # Y = X * 2.0
        local_mul_forward = local_map(
            mul_forward,
            in_placements=(row_wise, None),
            out_placements=row_wise,
            device_mesh=device_mesh,
        )
        Y = torch.mul(X, 2.0)
        with comm_mode:
            Y_dt = local_mul_forward(X_dt, 2.0)

        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 4: `None` placements for a Tensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=None,
            in_placements=(None, None),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())

        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(
            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
            torch.mm(X, W),
        )

        # Test 5: some placements for a Tensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=None,
            in_placements=(replicate, row_wise),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())

        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(
            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
            torch.mm(X, W),
        )

        # Test 6: expect error - `None` placements for a DTensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            in_placements=(row_wise, None),
            device_mesh=device_mesh,
        )
        with self.assertRaisesRegex(AssertionError, "expects placements"):
            Y_dt = local_mm_forward(X_dt, W_dt)

    # check for `redistribute_inputs` handling
    @with_comms
    def test_local_map_redistribute(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, row_wise
        )  # row-wise sharded X tensor which will be redistributed
        W_dt = distribute_tensor(
            W, device_mesh, col_wise
        )  # column-wise sharded W tensor which will be redistributed

        # Test 1: allow input redistribution
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
            redistribute_inputs=True,
        )
        with comm_mode:
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

        # 2 communications for input redistribution and 1 for the all-reduce
        # inside mm_allreduce_forward
        self.assertEqual(comm_mode.get_total_counts(), 3)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        self.assertEqual(Y_dt.to_local(), Y)

        # Test 2: no input redistribution is allowed
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
            redistribute_inputs=False,
        )
        with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"):
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)


if __name__ == "__main__":
    run_tests()