xref: /aosp_15_r20/external/pytorch/test/distributed/pipelining/test_pipe.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Meta Platforms, Inc. and affiliates
2# Owner(s): ["oncall: distributed"]
3from model_registry import MLPModule, ModelWithParamAlias
4
5import torch
6from torch.distributed.pipelining import pipe_split, pipeline
7from torch.testing._internal.common_utils import (
8    instantiate_parametrized_tests,
9    parametrize,
10    run_tests,
11    TestCase,
12)
13
14
# Hidden dimension used by every test model and the random input tensors.
d_hid = 512
# Number of samples in one micro-batch fed to the pipeline.
microbatch_size = 16

# Fix the RNG seed so parameter initialization and inputs are reproducible.
torch.manual_seed(0)
19
20
# Basic example
class ExampleCode(torch.nn.Module):
    """Toy model with a skip connection and a twice-used parameter,
    manually partitioned into four pipeline stages via pipe_split()."""

    def __init__(self) -> None:
        super().__init__()
        # Creation order is kept stable: it determines RNG consumption under
        # the module-level manual_seed.
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        # Stage 0: first use of the shared parameter; stash the residual.
        x = torch.mm(x, self.mm_param1)  # multi-use param
        residual = x
        x = torch.relu(x + y)
        pipe_split()
        # Stage 1: second use of the same parameter.
        x = torch.mm(x, self.mm_param1)  # multi-use param
        x = self.lin1(x)
        pipe_split()
        # Stage 2: re-join the skip connection.
        x = torch.relu(x)
        x = x + residual
        x = torch.mm(x, self.mm_param2)
        pipe_split()
        # Stage 3: final projection.
        x = self.lin2(x)
        return torch.relu(x)
46
47
class MultiMLP(torch.nn.Module):
    """Four sequential MLP blocks, intended to become one pipeline stage each."""

    def __init__(self) -> None:
        super().__init__()
        self.mlp0 = MLPModule(d_hid)
        self.mlp1 = MLPModule(d_hid)
        self.mlp2 = MLPModule(d_hid)
        self.mlp3 = MLPModule(d_hid)

    def forward(self, x, y):
        # Put a split point in front of every block after the first.  The loop
        # unrolls during tracing, so the op/split sequence matches the
        # hand-written chain mlp0 | mlp1 | mlp2 | mlp3.
        for idx, blk in enumerate((self.mlp0, self.mlp1, self.mlp2, self.mlp3)):
            if idx:
                pipe_split()
            x = blk(x)
        return x - y
65
66
# Expected number of pipeline stages pipeline() should produce per model.
# For the models defined in this file that is (number of pipe_split() calls)
# + 1; ModelWithParamAlias comes from model_registry (expected 2 stages).
EXPECTED_N_STAGES = {
    ExampleCode: 4,
    MultiMLP: 4,
    ModelWithParamAlias: 2,
}

# Currently, we don't enforce full set equality on the FQNs between the original
# and pipelined models, because in the multi-use param case, PP will deduplicate
# the FQNs from the state_dict.
# TODO: flip to True once multi-use parameter FQN deduplication is reconciled.
CHECK_FQN_SET_EQUALITY = False
78
79
class PipeTests(TestCase):
    @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
    def test_model_split(self, ModelClass):
        """Split ModelClass with pipeline() and check three properties:
        expected stage count, numerical equivalence with the eager model,
        and that every stage's state_dict FQNs come from the original model.

        Note: bare `assert` statements were replaced with unittest assertions
        so the checks survive `python -O` and produce informative failures.
        """
        mod = ModelClass()
        x = torch.randn(microbatch_size, d_hid)
        y = torch.randn(microbatch_size, d_hid)

        pipe = pipeline(
            mod,
            mb_args=(x, y),
        )

        self.assertEqual(
            pipe.num_stages,
            EXPECTED_N_STAGES[ModelClass],
            msg=f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}",
        )

        # The pipelined model must match the eager reference numerically.
        ref_out = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref_out)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}")

        # Check qualname preservation.
        # state_dict.keys include both parameters and persistent buffers.
        old_names = set(mod.state_dict().keys())
        new_names = set()
        for idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(idx)
            stage_fqns = set(stage_mod.state_dict().keys())
            # Every stage FQN must originate from the unsplit model.
            self.assertTrue(
                stage_fqns.issubset(old_names),
                msg=f"stage {idx} has unexpected FQNs: {stage_fqns - old_names}",
            )
            new_names.update(stage_fqns)

        if CHECK_FQN_SET_EQUALITY:
            self.assertEqual(
                old_names,
                new_names,
                msg=f"""
            old names {old_names}
            new names {new_names}
            """,
            )
        print("Qualname check passed")
119
120
# Expand the @parametrize decorators into one concrete test per ModelClass.
instantiate_parametrized_tests(PipeTests)

if __name__ == "__main__":
    run_tests()
125