# Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] import torch from torch.distributed.pipelining import pipe_split, pipeline from torch.testing._internal.common_utils import run_tests, TestCase # Building block for model class Block(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d( in_channels=16, out_channels=16, kernel_size=3, padding=1 ) self.lin0 = torch.nn.Linear(256, 256) self.relu = torch.nn.ReLU() self.lin1 = torch.nn.Linear(256, 256) def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor: x = self.conv(x) x = self.lin0(x) pipe_split() x.add_(constant) x = self.lin1(x) return self.relu(x) # Full model class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.block0 = Block() self.block1 = Block() def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor: x = self.block0(x, constant=constant) pipe_split() x = self.block1(x, constant=constant) return x class UnflattenTests(TestCase): def test_unflatten(self): x = torch.randn(1, 16, 256, 256) constant = torch.ones(1, 16, 256, 256) mod = M() pipe = pipeline( mod, (x,), {"constant": constant}, ) assert pipe.num_stages == 4 orig_state_dict = mod.state_dict() # Check qualnames for stage_idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(stage_idx) for param_name, param in stage_mod.named_parameters(): assert ( param_name in orig_state_dict ), f"{param_name} not in original state dict" print("Param qualname test passed") # Check equivalence ref = mod(x, constant) out = pipe(x, constant)[0] torch.testing.assert_close(out, ref) print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}") if __name__ == "__main__": run_tests()