1# Owner(s): ["oncall: distributed"] 2from collections import defaultdict 3from typing import Dict 4 5import torch 6from torch.distributed._tensor.experimental._tp_transform import ( 7 tensor_parallel_transformation, 8) 9from torch.distributed.tensor.parallel.style import ( 10 ColwiseParallel, 11 ParallelStyle, 12 RowwiseParallel, 13) 14from torch.testing._internal.common_utils import run_tests 15from torch.testing._internal.distributed._tensor.common_dtensor import ( 16 DTensorTestBase, 17 with_comms, 18) 19 20 21class MLPListModule(torch.nn.Module): 22 """ 23 A dummy model with list of MLPs. 24 """ 25 26 def __init__(self, num_mlps=3, bias=True): 27 super().__init__() 28 self.mlps = torch.nn.ModuleList() 29 for _ in range(num_mlps): 30 self.mlps.append( 31 torch.nn.Sequential( 32 torch.nn.Linear(6, 18), 33 torch.nn.ReLU(), 34 torch.nn.Linear(18, 6, bias=bias), 35 ) 36 ) 37 38 def forward(self, x: torch.Tensor) -> torch.Tensor: 39 x = torch.chunk(x, 2, dim=1)[0] 40 for mlp in self.mlps: 41 x = mlp(x) 42 return x + torch.ones_like(x) 43 44 45class DummyModel(torch.nn.Module): 46 def __init__(self) -> None: 47 super().__init__() 48 self.fc = torch.nn.Linear(3, 5) 49 self.bn = torch.nn.BatchNorm1d(5) 50 51 def forward(self, x): 52 return self.bn(self.fc(x)) 53 54 55class TensorParallelTest(DTensorTestBase): 56 def setUp(self) -> None: 57 super().setUp() 58 59 def assert_has_c10d_ops( 60 self, gm: torch.fx.GraphModule, expected_ops_count: Dict[str, int] 61 ) -> None: 62 actual_ops_count: Dict[str, int] = defaultdict(int) 63 for node in gm.graph.nodes: 64 if node.op == "call_function": 65 if "c10d_functional" in str(node.target): 66 actual_ops_count[str(node.target)] += 1 67 self.assertDictEqual(expected_ops_count, actual_ops_count) 68 69 @with_comms 70 def test_tp_transform_with_uncovered_op(self): 71 model = DummyModel().to(device=self.device_type) 72 inputs = (torch.randn(7, 3, requires_grad=False).to(device=self.device_type),) 73 with torch.no_grad(): 74 res = model(*inputs) 75 exported_program = torch.export.export( 76 model, 77 inputs, 78 ).run_decompositions() 79 tp_exported_program = tensor_parallel_transformation( 80 exported_program, 81 self.rank, 82 self.world_size, 83 self.device_type, 84 {"fc": ColwiseParallel}, 85 ) 86 tp_model = tp_exported_program.module() 87 with torch.no_grad(): 88 tp_res = tp_model(*inputs) 89 self.assertEqual(res, tp_res) 90 # Expect all_gather to be inserted to distributed sharded fc resutls 91 self.assert_has_c10d_ops( 92 tp_exported_program.graph_module, 93 { 94 "_c10d_functional.all_gather_into_tensor.default": 1, 95 "_c10d_functional.wait_tensor.default": 1, 96 }, 97 ) 98 99 @with_comms 100 def test_tp_transform_e2e(self): 101 torch.manual_seed(0) 102 model = MLPListModule(2).to(device=self.device_type) 103 inputs = (torch.randn((10, 12)).to(device=self.device_type),) 104 parallel_strategies: Dict[str, ParallelStyle] = { 105 "mlps.0.0": ColwiseParallel, 106 "mlps.0.2": RowwiseParallel, 107 "mlps.1.0": ColwiseParallel, 108 "mlps.1.2": RowwiseParallel, 109 } 110 111 with torch.inference_mode(): 112 res = model(*inputs) 113 exported_program = torch.export.export( 114 model, 115 inputs, 116 ).run_decompositions() 117 tp_exported_program = tensor_parallel_transformation( 118 exported_program, 119 self.rank, 120 self.world_size, 121 self.device_type, 122 parallel_strategies, 123 ) 124 tp_model = tp_exported_program.module() 125 with torch.inference_mode(): 126 tp_res = tp_model(*inputs) 127 self.assertEqual(res, tp_res) 128 # Expect all_reduce to be inserted at the end of each MLP 129 self.assert_has_c10d_ops( 130 tp_exported_program.graph_module, 131 { 132 "_c10d_functional.all_reduce.default": 2, 133 "_c10d_functional.wait_tensor.default": 2, 134 }, 135 ) 136 137 @with_comms 138 def test_tp_transform_no_bias(self): 139 torch.manual_seed(0) 140 model = MLPListModule(1, bias=False).to(device=self.device_type) 141 inputs = (torch.randn((10, 12)).to(device=self.device_type),) 142 parallel_strategies: Dict[str, ParallelStyle] = { 143 "mlps.0.0": ColwiseParallel, 144 "mlps.0.2": RowwiseParallel, 145 } 146 147 with torch.inference_mode(): 148 res = model(*inputs) 149 exported_program = torch.export.export( 150 model, 151 inputs, 152 ).run_decompositions() 153 tp_exported_program = tensor_parallel_transformation( 154 exported_program, 155 self.rank, 156 self.world_size, 157 self.device_type, 158 parallel_strategies, 159 ) 160 tp_model = tp_exported_program.module() 161 with torch.inference_mode(): 162 tp_res = tp_model(*inputs) 163 self.assertEqual(res, tp_res) 164 self.assert_has_c10d_ops( 165 tp_exported_program.graph_module, 166 { 167 "_c10d_functional.all_reduce.default": 1, 168 "_c10d_functional.wait_tensor.default": 1, 169 }, 170 ) 171 172 173if __name__ == "__main__": 174 run_tests() 175