xref: /aosp_15_r20/external/pytorch/test/distributed/pipelining/test_transformer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
3import torch
4from torch.distributed.pipelining import pipeline, SplitPoint
5from torch.testing._internal.common_utils import run_tests, TestCase
6
7
# Hidden width of every linear layer in the model.
d_hid = 16
# Number of MLPModule blocks stacked in TransformerLike.
n_layers = 8
# Leading (batch) dimension of the example input used for tracing/testing.
microbatch_size = 4
11
12
class MLPModule(torch.nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) that preserves width.

    Args:
        d_hid: feature dimension of both the input and the output.
    """

    def __init__(self, d_hid):
        super().__init__()
        # Both linear layers keep the hidden width unchanged, so blocks
        # of this module can be stacked freely.
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        # Apply the two linear maps with a ReLU in between.
        return self.net2(self.relu(self.net1(x)))
25
26
class TransformerLike(torch.nn.Module):
    """Stack of ``n_layers`` MLP blocks standing in for transformer layers.

    The blocks live under ``self.layers`` so split specs can address them
    by qualified name (e.g. ``"layers.4"``).
    """

    def __init__(self) -> None:
        super().__init__()
        blocks = [MLPModule(d_hid) for _ in range(n_layers)]
        self.layers = torch.nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the input through every MLP block in order.
        return self.layers(x)
34
35
class TransformerTests(TestCase):
    """Checks pipeline splitting IR on a transformer-like stack of MLPs."""

    def test_ir(self):
        """Split the model into 2 stages; verify that the stages together
        contain exactly the original layers and that the pipelined forward
        matches the unsplit model numerically."""
        transformer = TransformerLike()
        x = torch.randn(microbatch_size, d_hid)

        # Split into 2 stages
        num_stages = 2
        # SplitPoint.BEGINNING cuts right before layer n_layers // 2,
        # giving each stage half of the MLP blocks.
        split_spec = {f"layers.{n_layers // num_stages}": SplitPoint.BEGINNING}

        pipe = pipeline(
            transformer,
            (x,),
            split_spec=split_spec,
        )
        # Use unittest assertions rather than bare `assert`: bare asserts are
        # stripped under `python -O` and give weaker failure messages.
        self.assertEqual(
            pipe.num_stages, num_stages, f"{pipe.num_stages=}, expect {num_stages}"
        )

        def get_layers(module):
            # Names of the child modules under `layers` ("0", "1", ...).
            layers = [name for name, _ in module.layers.named_children()]
            return layers

        # Collect all layers in pipe
        layers = []
        for stage_idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(stage_idx)
            layers += get_layers(stage_mod)

        # Check layer completeness: every original layer must appear in
        # exactly one stage, none duplicated or dropped.
        orig_layers = get_layers(transformer)
        self.assertEqual(sorted(layers), sorted(orig_layers), f"{layers} != {orig_layers}")
        print("Layers matched!")

        # Check equivalence of the split pipeline against the original model.
        ref = transformer(x)
        out = pipe(x)[0]
        torch.testing.assert_close(out, ref)
        print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
72
73
74if __name__ == "__main__":
75    run_tests()
76