# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from model_registry import MLPModule, ModelWithParamAlias

import torch
from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


d_hid = 512
microbatch_size = 16

torch.manual_seed(0)


# Basic example
class ExampleCode(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param1)  # multi-use param
        skip_connection = x
        x = x + y
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)  # multi-use param
        x = self.lin1(x)
        pipe_split()
        x = torch.relu(x)
        x = x + skip_connection
        x = torch.mm(x, self.mm_param2)
        pipe_split()
        x = self.lin2(x)
        x = torch.relu(x)
        return x


class MultiMLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mlp0 = MLPModule(d_hid)
        self.mlp1 = MLPModule(d_hid)
        self.mlp2 = MLPModule(d_hid)
        self.mlp3 = MLPModule(d_hid)

    def forward(self, x, y):
        x = self.mlp0(x)
        pipe_split()
        x = self.mlp1(x)
        pipe_split()
        x = self.mlp2(x)
        pipe_split()
        x = self.mlp3(x)
        return x - y


# Expected number of stages each model should be split into
# (n pipe_split() calls yield n + 1 stages)
EXPECTED_N_STAGES = {
    ExampleCode: 4,
    MultiMLP: 4,
    ModelWithParamAlias: 2,
}

# Currently, we don't enforce full set equality on the FQNs between the original
# and pipelined models, because in the multi-use param case, PP will deduplicate
# the FQNs from the state_dict.
# TODO: enable the full set-equality check once FQN deduplication is accounted for.
CHECK_FQN_SET_EQUALITY = False


class PipeTests(TestCase):
    @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
    def test_model_split(self, ModelClass):
        mod = ModelClass()
        x = torch.randn(microbatch_size, d_hid)
        y = torch.randn(microbatch_size, d_hid)

        # Trace the model and split it into stages at the pipe_split() markers,
        # using (x, y) as the example microbatch inputs
        pipe = pipeline(
            mod,
            mb_args=(x, y),
        )

        assert (
            pipe.num_stages == EXPECTED_N_STAGES[ModelClass]
        ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"

        # The pipelined model should be numerically equivalent to the original
        ref_out = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref_out)
        print(f"equivalence test passed: {torch.sum(out)} ref {torch.sum(ref_out)}")

        # Check FQNs (qualified names): every stage's state_dict keys must be a
        # subset of the original model's keys.
        # state_dict keys include both parameters and persistent buffers.
        old_names = set(mod.state_dict().keys())
        new_names = set()
        for idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(idx)
            stage_fqns = set(stage_mod.state_dict().keys())
            assert stage_fqns.issubset(old_names)
            new_names.update(stage_fqns)

        if CHECK_FQN_SET_EQUALITY:
            assert (
                old_names == new_names
            ), f"""
            old names {old_names}
            new names {new_names}
            """
        print("Qualname check passed")


instantiate_parametrized_tests(PipeTests)

if __name__ == "__main__":
    run_tests()