1# Copyright (c) Meta Platforms, Inc. and affiliates 2# Owner(s): ["oncall: distributed"] 3import torch 4from torch.distributed.pipelining import pipeline, SplitPoint 5from torch.testing._internal.common_utils import run_tests, TestCase 6 7 8d_hid = 16 9n_layers = 8 10microbatch_size = 4 11 12 13class MLPModule(torch.nn.Module): 14 def __init__(self, d_hid): 15 super().__init__() 16 self.net1 = torch.nn.Linear(d_hid, d_hid) 17 self.relu = torch.nn.ReLU() 18 self.net2 = torch.nn.Linear(d_hid, d_hid) 19 20 def forward(self, x): 21 x = self.net1(x) 22 x = self.relu(x) 23 x = self.net2(x) 24 return x 25 26 27class TransformerLike(torch.nn.Module): 28 def __init__(self) -> None: 29 super().__init__() 30 self.layers = torch.nn.Sequential(*[MLPModule(d_hid) for _ in range(n_layers)]) 31 32 def forward(self, x: torch.Tensor) -> torch.Tensor: 33 return self.layers(x) 34 35 36class TransformerTests(TestCase): 37 def test_ir(self): 38 transformer = TransformerLike() 39 x = torch.randn(microbatch_size, d_hid) 40 41 # Split into 2 stages 42 num_stages = 2 43 split_spec = {f"layers.{n_layers // num_stages}": SplitPoint.BEGINNING} 44 45 pipe = pipeline( 46 transformer, 47 (x,), 48 split_spec=split_spec, 49 ) 50 assert pipe.num_stages == num_stages, f"{pipe.num_stages=}, expect {num_stages}" 51 52 def get_layers(module): 53 layers = [name for name, _ in module.layers.named_children()] 54 return layers 55 56 # Collect all layers in pipe 57 layers = [] 58 for stage_idx in range(pipe.num_stages): 59 stage_mod = pipe.get_stage_module(stage_idx) 60 layers += get_layers(stage_mod) 61 62 # Check layer completeness 63 orig_layers = get_layers(transformer) 64 assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}" 65 print("Layers matched!") 66 67 # Check equivalence 68 ref = transformer(x) 69 out = pipe(x)[0] 70 torch.testing.assert_close(out, ref) 71 print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") 72 73 74if __name__ == "__main__": 75 run_tests() 76