# Owner(s): ["oncall: distributed"]
import copy
import sys

import torch
import torch.nn as nn
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    TEST_GPU_NUM,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


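# These tests exercise the custom Sharder extension point: a ShardingPlan can
# map a submodule path to a Sharder, and shard_module() delegates sharding of
# that submodule to the Sharder's shard() method.
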
# A simple collection of embedding bags.
class CustomEmbeddingBagCollection(nn.Module):
    def __init__(self, num_bags, num_embeddings_per_bag, num_dims):
        super().__init__()
        self.num_bags = num_bags
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        for i in range(num_bags):
            self.embedding_bags[f"embedding_bag_{i}"] = nn.EmbeddingBag(
                num_embeddings_per_bag, num_dims, mode="sum"
            )

    def forward(self, inputs):
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)

# A simple sharded version of the embedding bag collection (EBC).
class CustomShardedEBC(nn.Module):
    def __init__(self, ebc, split_idx, specs):
        super().__init__()
        self.split_idx = split_idx
        row_spec, col_spec = specs

        # create embedding bags based on the spec
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        assert self.split_idx < ebc.num_bags
        # bags before split_idx are sharded row-wise, the rest column-wise
        for i in range(ebc.num_bags):
            bag_key = f"embedding_bag_{i}"
            if i < self.split_idx:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": row_spec}
                    ),
                )
            else:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": col_spec}
                    ),
                )

            self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]


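# A custom Sharder: shard_module() hands it the submodule mapped to it in the
# ShardingPlan and replaces that submodule with whatever shard() returns. This
# sharder row-wise shards the first `split_sharding_idx` bags and column-wise
# shards the rest.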
class CustomSharder(Sharder):
    def __init__(self, devices, split_sharding_idx):
        self.devices = devices
        self.split_sharding_idx = split_sharding_idx
        self.rowwise_spec = ChunkShardingSpec(dim=0, placements=devices)
        self.colwise_spec = ChunkShardingSpec(dim=1, placements=devices)

    def shard(self, ebc: nn.Module) -> nn.Module:
        if not isinstance(ebc, CustomEmbeddingBagCollection):
            raise RuntimeError(
                "The custom sharder only supports CustomEmbeddingBagCollection"
            )

        return CustomShardedEBC(
            ebc, self.split_sharding_idx, (self.rowwise_spec, self.colwise_spec)
        )

class TestCustomSharder(ShardedTensorTestBase):
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder(self):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

            def forward(self, inputs):
                return self.ebc(inputs)

        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        # map the "ebc" submodule path to the custom sharder
        sharding_plan = ShardingPlan(
            plan={
                "ebc": custom_sharder,
            }
        )

        local_model = MyModule().cuda(self.rank)
        sharded_model = copy.deepcopy(local_model)

        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

        # check to make sure the module has already been sharded
        emb_bags = sharded_model.ebc.embedding_bags
        self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
        self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
        self.assertEqual(
            emb_bags["embedding_bag_0"].weight.sharding_spec(),
            custom_sharder.rowwise_spec,
        )
        self.assertEqual(
            emb_bags["embedding_bag_9"].weight.sharding_spec(),
            custom_sharder.colwise_spec,
        )

        # make sure we can run sharded computation and compare outputs
        # with the local model version
        input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
        local_output = local_model(input)
        sharded_output = sharded_model(input)

        self.assertEqual(local_output, sharded_output)

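    # A Sharder must be mapped to a non-empty submodule path, and that path
    # must not overlap with other entries in the same ShardingPlan; both
    # misuses should raise.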
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        # an empty module path is not allowed for a custom sharder
        sharding_plan = ShardingPlan(
            plan={
                "": custom_sharder,
            }
        )

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

        with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test a conflicting sharding plan: a spec nested under the subtree
        # already claimed by the custom sharder
        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            }
        )

        with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)


if __name__ == "__main__":
    run_tests()