xref: /aosp_15_r20/external/pytorch/test/distributed/pipelining/test_microbatch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Meta Platforms, Inc. and affiliates
2# Owner(s): ["oncall: distributed"]
3from model_registry import ModelWithKwargs
4
5import torch
6from torch.distributed.pipelining import pipeline
7from torch.distributed.pipelining.microbatch import (
8    merge_chunks,
9    split_args_kwargs_into_chunks,
10    TensorChunkSpec,
11)
12from torch.testing._internal.common_utils import run_tests, TestCase
13
14
# Hidden (feature) dimension: used as the trailing size of every tensor the tests create.
d_hid = 512
# Seed the global RNG so the torch.randn inputs are reproducible from run to run.
torch.manual_seed(0)
17
18
class MicrobatchTests(TestCase):
    """Tests for the microbatch split/merge helpers of torch.distributed.pipelining."""

    def test_split_and_merge(self):
        """Split args/kwargs into two chunks along dim 0 and merge them back.

        Checks the number of chunks, the per-chunk shapes, and that merging
        the chunks reproduces the original positional and keyword inputs.
        """
        x0 = torch.randn(128, d_hid)
        x1 = torch.randn(256, d_hid)
        x2 = torch.randn(512, d_hid)

        args = (x0, x1, x2)
        kwargs = {"x0": x0, "x1": x1, "x2": x2}

        # Default chunking: dim 0
        arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, 2)

        # unittest assertions are used instead of bare ``assert`` so the checks
        # still run under ``python -O`` and report both values on failure.
        self.assertEqual(len(arg_chunks), 2)
        self.assertEqual(len(kwarg_chunks), 2)
        self.assertEqual(arg_chunks[0][0].shape, torch.Size([64, d_hid]))
        self.assertEqual(arg_chunks[1][0].shape, torch.Size([64, d_hid]))
        self.assertEqual(arg_chunks[0][1].shape, torch.Size([128, d_hid]))
        self.assertEqual(arg_chunks[0][2].shape, torch.Size([256, d_hid]))
        self.assertEqual(kwarg_chunks[0]["x0"].shape, torch.Size([64, d_hid]))
        self.assertEqual(kwarg_chunks[0]["x1"].shape, torch.Size([128, d_hid]))
        self.assertEqual(kwarg_chunks[1]["x2"].shape, torch.Size([256, d_hid]))

        # Merge chunks back together
        merged_args = merge_chunks(
            arg_chunks,
            (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)),
        )
        torch.testing.assert_close(merged_args, args)

        merged_kwargs = merge_chunks(
            kwarg_chunks,
            {
                "x0": TensorChunkSpec(0),
                "x1": TensorChunkSpec(0),
                "x2": TensorChunkSpec(0),
            },
        )
        torch.testing.assert_close(merged_kwargs, kwargs)
        print("Microbatch test passed")

    def test_chunk_spec(self):
        """Split inputs with explicit chunk specs and compare pipe vs. module output.

        Builds chunk specs with ``TensorChunkSpec.from_tuple`` /
        ``TensorChunkSpec.from_dict``, splits ``x`` and ``y`` into
        ``num_chunks`` pieces, constructs a pipe from the first microbatch,
        and checks that calling the pipe on the full inputs matches calling
        the module directly.
        """
        mod = ModelWithKwargs()
        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE

        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)

        num_chunks = 4

        # Chunk the single positional arg and the "y" kwarg along dim 0.
        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            (x,),
            {"y": y},
            num_chunks,
            args_chunk_spec,
            kwargs_chunk_spec,
        )

        # Only the first microbatch is handed to pipeline() as example input.
        pipe = pipeline(
            mod,
            mb_args=args_split[0],
            mb_kwargs=kwargs_split[0],
        )

        ref = mod(x, y)
        # The pipe call result is indexed; element 0 is compared to the reference.
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
88
89
if __name__ == "__main__":
    # When executed as a script, hand control to the run_tests entry point
    # imported from torch.testing._internal.common_utils.
    run_tests()
92