# xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
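"""Shared test utilities for ShardedTensor tests: a multi-process test base
class that initializes a process group (and optionally RPC) on each rank, and
a ``with_comms`` decorator that wraps individual test methods with that
setup/teardown."""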

import sys
from functools import wraps, partial

import torch
import torch.distributed as dist
from torch.distributed import rpc
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS,
    tp_transports,
)

TEST_GPU_NUM = 4


class ShardedTensorTestBase(MultiProcessTestCase):
    @property
    def world_size(self):
        return TEST_GPU_NUM

    def init_pg(self, backend="nccl"):
        if backend not in ["nccl", "gloo", "mpi"]:
            raise RuntimeError(f"Backend {backend} not supported!")

        dist.init_process_group(
            backend=backend,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # Bind this rank to its GPU so NCCL collectives use the right device.
        if backend == "nccl":
            torch.cuda.set_device(self.rank)

    def init_rpc(self):
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions(_transports=tp_transports())
        rpc_backend_options.init_method = f"file://{self.file_name}"
        # Set up device maps between this worker and every peer so CUDA
        # tensors can be sent directly over RPC.
        for rank in range(self.world_size):
            rpc_backend_options.set_device_map(
                f"worker{rank}", {rank: self.rank, self.rank: rank}
            )

        rpc.init_rpc(
            name=f"worker{self.rank}",
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=rpc_backend_options,
        )

    def init_comms(self, init_rpc=True, backend="nccl"):
        if init_rpc:
            self.init_rpc()
        self.init_pg(backend=backend)

    def destroy_comms(self, destroy_rpc=True):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()

        if destroy_rpc:
            rpc.shutdown()
        dist.destroy_process_group()

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def assert_sharded_tensor_equal(self, st1, st2):
        # Compare local shards pairwise (tensor and metadata), then the global
        # metadata, sharding spec, and remote shard counts.
        st1_local_shards = st1.local_shards()
        st2_local_shards = st2.local_shards()
        self.assertEqual(len(st1_local_shards), len(st2_local_shards))
        for i, st1_local_shard in enumerate(st1_local_shards):
            self.assertEqual(st1_local_shard.tensor, st2_local_shards[i].tensor)
            self.assertEqual(st1_local_shard.metadata, st2_local_shards[i].metadata)

        self.assertEqual(st1.metadata(), st2.metadata())
        self.assertEqual(st1.sharding_spec(), st2.sharding_spec())
        self.assertEqual(len(st1.remote_shards()), len(st2.remote_shards()))


# Decorator to initialize comms (process group + optionally RPC) around a test
# method; with the NCCL backend the test is skipped when fewer GPUs than
# world_size are available.
def with_comms(func=None, init_rpc=True, backend="nccl"):
    # Support both @with_comms and @with_comms(init_rpc=..., backend=...).
    if func is None:
        return partial(
            with_comms,
            init_rpc=init_rpc,
            backend=backend,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.init_comms(init_rpc=init_rpc, backend=backend)
        func(self, *args, **kwargs)
        self.destroy_comms(destroy_rpc=init_rpc)
    return wrapper

99