1# Copyright (c) Meta Platforms, Inc. and affiliates 2# Owner(s): ["oncall: distributed"] 3import torch 4from torch.distributed.pipelining import pipe_split, pipeline 5from torch.testing._internal.common_utils import run_tests, TestCase 6 7 8# Building block for model 9class Block(torch.nn.Module): 10 def __init__(self) -> None: 11 super().__init__() 12 self.conv = torch.nn.Conv2d( 13 in_channels=16, out_channels=16, kernel_size=3, padding=1 14 ) 15 self.lin0 = torch.nn.Linear(256, 256) 16 self.relu = torch.nn.ReLU() 17 self.lin1 = torch.nn.Linear(256, 256) 18 19 def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor: 20 x = self.conv(x) 21 x = self.lin0(x) 22 pipe_split() 23 x.add_(constant) 24 x = self.lin1(x) 25 return self.relu(x) 26 27 28# Full model 29class M(torch.nn.Module): 30 def __init__(self) -> None: 31 super().__init__() 32 self.block0 = Block() 33 self.block1 = Block() 34 35 def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor: 36 x = self.block0(x, constant=constant) 37 pipe_split() 38 x = self.block1(x, constant=constant) 39 return x 40 41 42class UnflattenTests(TestCase): 43 def test_unflatten(self): 44 x = torch.randn(1, 16, 256, 256) 45 constant = torch.ones(1, 16, 256, 256) 46 47 mod = M() 48 49 pipe = pipeline( 50 mod, 51 (x,), 52 {"constant": constant}, 53 ) 54 55 assert pipe.num_stages == 4 56 orig_state_dict = mod.state_dict() 57 58 # Check qualnames 59 for stage_idx in range(pipe.num_stages): 60 stage_mod = pipe.get_stage_module(stage_idx) 61 for param_name, param in stage_mod.named_parameters(): 62 assert ( 63 param_name in orig_state_dict 64 ), f"{param_name} not in original state dict" 65 print("Param qualname test passed") 66 67 # Check equivalence 68 ref = mod(x, constant) 69 out = pipe(x, constant)[0] 70 torch.testing.assert_close(out, ref) 71 print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") 72 73 74if __name__ == "__main__": 75 run_tests() 76