xref: /aosp_15_r20/external/pytorch/torch/distributed/remote_device.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Optional, Union
3
4import torch
5
6
class _remote_device:
    """
    Represents a device on a remote worker.

    Args:
        remote_device (str or torch.device): Represents a device on a remote worker.
            The string format should be one of the following:

                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                   In addition, the device field can be optional and the default value is "cpu".
                2. "rank:<rank>/<device>", where <rank> is the rank of the
                   process and device can be parsed as torch.device type.
                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
                3. <workername> and <rank> are optional and formats like "cpu"
                    and "cuda:1", just represent local devices.

    Raises:
        TypeError: if ``remote_device`` is neither a ``str`` nor a ``torch.device``.
        ValueError: if the string has more than one "/", an empty worker
            name, or a malformed "rank:<rank>" prefix.

    Any exception raised by ``torch.device`` while validating the device field
    is propagated unchanged.
    """

    def __init__(self, remote_device: Union[str, torch.device]):
        PARSE_ERROR = (
            f"Could not parse remote_device: {remote_device}. The valid format is "
            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
        )
        # At most one of _worker_name / _rank is set once parsing succeeds;
        # both stay None for a purely local device such as "cpu".
        self._worker_name: Optional[str] = None
        self._rank: Optional[int] = None
        self._device: Optional[Union[str, int, torch.device]] = None

        if isinstance(remote_device, torch.device):
            self._device = remote_device
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # A lone field is ambiguous: treat it as a device if torch can
                # parse it as one, otherwise as a worker name on the default
                # "cpu" device.
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
                raise ValueError(PARSE_ERROR)
        else:
            raise TypeError(f"Invalid type for remote_device: {type(remote_device)}")

        # Do some basic sanity check (no empty string), e.g. "/cpu".
        if self._worker_name is not None and not self._worker_name:
            raise ValueError(PARSE_ERROR)

        # Validate the device. This runs before the rank check below, so an
        # invalid device field is reported by torch.device itself.
        self._device = torch.device(self._device)

        # Check for rank based format.
        if self._worker_name is not None:
            name_parts = self._worker_name.split(":")
            if len(name_parts) == 2:
                # rank:<rank>/device format, extract rank
                prefix, rank_text = name_parts
                # isdecimal() (not isdigit()) matches what int() can convert:
                # isdigit() also accepts characters such as superscript digits,
                # for which int() raises its own, less helpful ValueError.
                if prefix == "rank" and rank_text.isdecimal():
                    self._rank = int(rank_text)
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
            elif len(name_parts) > 2:
                # Worker names may not contain more than one ":".
                raise ValueError(PARSE_ERROR)

    @staticmethod
    def _is_valid_local_device(device):
        """Return ``True`` if ``torch.device(device)`` succeeds, ``False`` if it raises."""
        try:
            torch.device(device)
            return True
        except Exception:
            # Deliberately broad: any failure just means "not a device string".
            return False

    def worker_name(self) -> Optional[str]:
        """Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
        return self._worker_name

    def rank(self) -> Optional[int]:
        """
        Returns the rank of remote worker representing the remote device.
        Returns ``None`` if no rank is available.
        """
        return self._rank

    def device(self) -> torch.device:
        """Return the local device on the remote worker."""
        return self._device  # type: ignore[return-value]

    def __repr__(self):
        """Return the textual form: "<workername>/<device>", "rank:<rank>/<device>" or "<device>"."""
        if self._device is not None:
            if self._worker_name is not None:
                return f"{self._worker_name}/{self._device}"
            elif self._rank is not None:
                return f"rank:{self._rank}/{self._device}"
            else:
                return str(self._device)
        else:
            # Defensive branch: __init__ always assigns a torch.device (or
            # raises), so this is not reached for normally constructed objects.
            if self._worker_name is not None:
                return f"{self._worker_name}"
            elif self._rank is not None:
                return f"{self._rank}"
            else:
                raise RuntimeError("Invalid state!")

    def __eq__(self, other):
        """Two remote devices are equal when worker name, device and rank all match."""
        return isinstance(other, _remote_device) and (
            self._worker_name == other._worker_name
            and self._device == other._device
            and self._rank == other._rank
        )

    def __hash__(self):
        """Hash over the same three fields that ``__eq__`` compares."""
        return hash((self._worker_name, self._device, self._rank))
121