xref: /aosp_15_r20/external/pytorch/c10/core/Device.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/DeviceType.h>
4 #include <c10/macros/Export.h>
5 #include <c10/util/Exception.h>
6 
7 #include <cstddef>
8 #include <cstdint>
9 #include <functional>
10 #include <iosfwd>
11 #include <string>
12 
13 namespace c10 {
14 
/// An index representing a specific device; e.g., the 1 in GPU 1.
/// A DeviceIndex is not independently meaningful without knowing
/// the DeviceType it is associated with; prefer using Device rather
/// than DeviceIndex directly.
using DeviceIndex = int8_t;
20 
21 /// Represents a compute device on which a tensor is located. A device is
22 /// uniquely identified by a type, which specifies the type of machine it is
23 /// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the
24 /// specific compute device when there is more than one of a certain type. The
25 /// device index is optional, and in its defaulted state represents (abstractly)
26 /// "the current device". Further, there are two constraints on the value of the
27 /// device index, if one is explicitly stored:
28 /// 1. A negative index represents the current device, a non-negative index
29 /// represents a specific, concrete device,
30 /// 2. When the device type is CPU, the device index must be zero.
31 struct C10_API Device final {
32   using Type = DeviceType;
33 
34   /// Constructs a new `Device` from a `DeviceType` and an optional device
35   /// index.
36   /* implicit */ Device(DeviceType type, DeviceIndex index = -1)
type_final37       : type_(type), index_(index) {
38     validate();
39   }
40 
41   /// Constructs a `Device` from a string description, for convenience.
42   /// The string supplied must follow the following schema:
43   /// `(cpu|cuda)[:<device-index>]`
44   /// where `cpu` or `cuda` specifies the device type, and
45   /// `:<device-index>` optionally specifies a device index.
46   /* implicit */ Device(const std::string& device_string);
47 
48   /// Returns true if the type and index of this `Device` matches that of
49   /// `other`.
50   bool operator==(const Device& other) const noexcept {
51     return this->type_ == other.type_ && this->index_ == other.index_;
52   }
53 
54   /// Returns true if the type or index of this `Device` differs from that of
55   /// `other`.
56   bool operator!=(const Device& other) const noexcept {
57     return !(*this == other);
58   }
59 
60   /// Sets the device index.
set_indexfinal61   void set_index(DeviceIndex index) {
62     index_ = index;
63   }
64 
65   /// Returns the type of device this is.
typefinal66   DeviceType type() const noexcept {
67     return type_;
68   }
69 
70   /// Returns the optional index.
indexfinal71   DeviceIndex index() const noexcept {
72     return index_;
73   }
74 
75   /// Returns true if the device has a non-default index.
has_indexfinal76   bool has_index() const noexcept {
77     return index_ != -1;
78   }
79 
80   /// Return true if the device is of CUDA type.
is_cudafinal81   bool is_cuda() const noexcept {
82     return type_ == DeviceType::CUDA;
83   }
84 
85   /// Return true if the device is of PrivateUse1 type.
is_privateuseonefinal86   bool is_privateuseone() const noexcept {
87     return type_ == DeviceType::PrivateUse1;
88   }
89 
90   /// Return true if the device is of MPS type.
is_mpsfinal91   bool is_mps() const noexcept {
92     return type_ == DeviceType::MPS;
93   }
94 
95   /// Return true if the device is of HIP type.
is_hipfinal96   bool is_hip() const noexcept {
97     return type_ == DeviceType::HIP;
98   }
99 
100   /// Return true if the device is of VE type.
is_vefinal101   bool is_ve() const noexcept {
102     return type_ == DeviceType::VE;
103   }
104 
105   /// Return true if the device is of XPU type.
is_xpufinal106   bool is_xpu() const noexcept {
107     return type_ == DeviceType::XPU;
108   }
109 
110   /// Return true if the device is of IPU type.
is_ipufinal111   bool is_ipu() const noexcept {
112     return type_ == DeviceType::IPU;
113   }
114 
115   /// Return true if the device is of XLA type.
is_xlafinal116   bool is_xla() const noexcept {
117     return type_ == DeviceType::XLA;
118   }
119 
120   /// Return true if the device is of MTIA type.
is_mtiafinal121   bool is_mtia() const noexcept {
122     return type_ == DeviceType::MTIA;
123   }
124 
125   /// Return true if the device is of HPU type.
is_hpufinal126   bool is_hpu() const noexcept {
127     return type_ == DeviceType::HPU;
128   }
129 
130   /// Return true if the device is of Lazy type.
is_lazyfinal131   bool is_lazy() const noexcept {
132     return type_ == DeviceType::Lazy;
133   }
134 
135   /// Return true if the device is of Vulkan type.
is_vulkanfinal136   bool is_vulkan() const noexcept {
137     return type_ == DeviceType::Vulkan;
138   }
139 
140   /// Return true if the device is of Metal type.
is_metalfinal141   bool is_metal() const noexcept {
142     return type_ == DeviceType::Metal;
143   }
144 
145   /// Return true if the device is of MAIA type.
is_maiafinal146   bool is_maia() const noexcept {
147     return type_ == DeviceType::MAIA;
148   }
149 
150   /// Return true if the device is of META type.
is_metafinal151   bool is_meta() const noexcept {
152     return type_ == DeviceType::Meta;
153   }
154 
155   /// Return true if the device is of CPU type.
is_cpufinal156   bool is_cpu() const noexcept {
157     return type_ == DeviceType::CPU;
158   }
159 
160   /// Return true if the device supports arbitrary strides.
supports_as_stridedfinal161   bool supports_as_strided() const noexcept {
162     return type_ != DeviceType::IPU && type_ != DeviceType::XLA &&
163         type_ != DeviceType::Lazy && type_ != DeviceType::MTIA;
164   }
165 
166   /// Same string as returned from operator<<.
167   std::string str() const;
168 
169  private:
170   DeviceType type_;
171   DeviceIndex index_ = -1;
validatefinal172   void validate() {
173     // Removing these checks in release builds noticeably improves
174     // performance in micro-benchmarks.
175     // This is safe to do, because backends that use the DeviceIndex
176     // have a later check when we actually try to switch to that device.
177     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
178         index_ >= -1,
179         "Device index must be -1 or non-negative, got ",
180         static_cast<int>(index_));
181     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
182         !is_cpu() || index_ <= 0,
183         "CPU device index must be -1 or zero, got ",
184         static_cast<int>(index_));
185   }
186 };
187 
188 C10_API std::ostream& operator<<(std::ostream& stream, const Device& device);
189 
190 } // namespace c10
191 
192 namespace std {
193 template <>
194 struct hash<c10::Device> {
195   size_t operator()(c10::Device d) const noexcept {
196     // Are you here because this static assert failed?  Make sure you ensure
197     // that the bitmasking code below is updated accordingly!
198     static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
199     static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
200     // Note [Hazard when concatenating signed integers]
201     // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
202     // We must first convert to a same-sized unsigned type, before promoting to
203     // the result type, to prevent sign extension when any of the values is -1.
204     // If sign extension occurs, you'll clobber all of the values in the MSB
205     // half of the resulting integer.
206     //
207     // Technically, by C/C++ integer promotion rules, we only need one of the
208     // uint32_t casts to the result type, but we put in both for explicitness's
209     // sake.
210     uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
211             << 16 |
212         static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
213     return std::hash<uint32_t>{}(bits);
214   }
215 };
216 } // namespace std
217