1 #pragma once 2 3 #include <c10/core/DeviceType.h> 4 #include <c10/macros/Export.h> 5 #include <c10/util/Exception.h> 6 7 #include <cstddef> 8 #include <cstdint> 9 #include <functional> 10 #include <iosfwd> 11 #include <string> 12 13 namespace c10 { 14 15 /// An index representing a specific device; e.g., the 1 in GPU 1. 16 /// A DeviceIndex is not independently meaningful without knowing 17 /// the DeviceType it is associated; try to use Device rather than 18 /// DeviceIndex directly. 19 using DeviceIndex = int8_t; 20 21 /// Represents a compute device on which a tensor is located. A device is 22 /// uniquely identified by a type, which specifies the type of machine it is 23 /// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the 24 /// specific compute device when there is more than one of a certain type. The 25 /// device index is optional, and in its defaulted state represents (abstractly) 26 /// "the current device". Further, there are two constraints on the value of the 27 /// device index, if one is explicitly stored: 28 /// 1. A negative index represents the current device, a non-negative index 29 /// represents a specific, concrete device, 30 /// 2. When the device type is CPU, the device index must be zero. 31 struct C10_API Device final { 32 using Type = DeviceType; 33 34 /// Constructs a new `Device` from a `DeviceType` and an optional device 35 /// index. 36 /* implicit */ Device(DeviceType type, DeviceIndex index = -1) type_final37 : type_(type), index_(index) { 38 validate(); 39 } 40 41 /// Constructs a `Device` from a string description, for convenience. 42 /// The string supplied must follow the following schema: 43 /// `(cpu|cuda)[:<device-index>]` 44 /// where `cpu` or `cuda` specifies the device type, and 45 /// `:<device-index>` optionally specifies a device index. 46 /* implicit */ Device(const std::string& device_string); 47 48 /// Returns true if the type and index of this `Device` matches that of 49 /// `other`. 50 bool operator==(const Device& other) const noexcept { 51 return this->type_ == other.type_ && this->index_ == other.index_; 52 } 53 54 /// Returns true if the type or index of this `Device` differs from that of 55 /// `other`. 56 bool operator!=(const Device& other) const noexcept { 57 return !(*this == other); 58 } 59 60 /// Sets the device index. set_indexfinal61 void set_index(DeviceIndex index) { 62 index_ = index; 63 } 64 65 /// Returns the type of device this is. typefinal66 DeviceType type() const noexcept { 67 return type_; 68 } 69 70 /// Returns the optional index. indexfinal71 DeviceIndex index() const noexcept { 72 return index_; 73 } 74 75 /// Returns true if the device has a non-default index. has_indexfinal76 bool has_index() const noexcept { 77 return index_ != -1; 78 } 79 80 /// Return true if the device is of CUDA type. is_cudafinal81 bool is_cuda() const noexcept { 82 return type_ == DeviceType::CUDA; 83 } 84 85 /// Return true if the device is of PrivateUse1 type. is_privateuseonefinal86 bool is_privateuseone() const noexcept { 87 return type_ == DeviceType::PrivateUse1; 88 } 89 90 /// Return true if the device is of MPS type. is_mpsfinal91 bool is_mps() const noexcept { 92 return type_ == DeviceType::MPS; 93 } 94 95 /// Return true if the device is of HIP type. is_hipfinal96 bool is_hip() const noexcept { 97 return type_ == DeviceType::HIP; 98 } 99 100 /// Return true if the device is of VE type. is_vefinal101 bool is_ve() const noexcept { 102 return type_ == DeviceType::VE; 103 } 104 105 /// Return true if the device is of XPU type. is_xpufinal106 bool is_xpu() const noexcept { 107 return type_ == DeviceType::XPU; 108 } 109 110 /// Return true if the device is of IPU type. is_ipufinal111 bool is_ipu() const noexcept { 112 return type_ == DeviceType::IPU; 113 } 114 115 /// Return true if the device is of XLA type. is_xlafinal116 bool is_xla() const noexcept { 117 return type_ == DeviceType::XLA; 118 } 119 120 /// Return true if the device is of MTIA type. is_mtiafinal121 bool is_mtia() const noexcept { 122 return type_ == DeviceType::MTIA; 123 } 124 125 /// Return true if the device is of HPU type. is_hpufinal126 bool is_hpu() const noexcept { 127 return type_ == DeviceType::HPU; 128 } 129 130 /// Return true if the device is of Lazy type. is_lazyfinal131 bool is_lazy() const noexcept { 132 return type_ == DeviceType::Lazy; 133 } 134 135 /// Return true if the device is of Vulkan type. is_vulkanfinal136 bool is_vulkan() const noexcept { 137 return type_ == DeviceType::Vulkan; 138 } 139 140 /// Return true if the device is of Metal type. is_metalfinal141 bool is_metal() const noexcept { 142 return type_ == DeviceType::Metal; 143 } 144 145 /// Return true if the device is of MAIA type. is_maiafinal146 bool is_maia() const noexcept { 147 return type_ == DeviceType::MAIA; 148 } 149 150 /// Return true if the device is of META type. is_metafinal151 bool is_meta() const noexcept { 152 return type_ == DeviceType::Meta; 153 } 154 155 /// Return true if the device is of CPU type. is_cpufinal156 bool is_cpu() const noexcept { 157 return type_ == DeviceType::CPU; 158 } 159 160 /// Return true if the device supports arbitrary strides. supports_as_stridedfinal161 bool supports_as_strided() const noexcept { 162 return type_ != DeviceType::IPU && type_ != DeviceType::XLA && 163 type_ != DeviceType::Lazy && type_ != DeviceType::MTIA; 164 } 165 166 /// Same string as returned from operator<<. 167 std::string str() const; 168 169 private: 170 DeviceType type_; 171 DeviceIndex index_ = -1; validatefinal172 void validate() { 173 // Removing these checks in release builds noticeably improves 174 // performance in micro-benchmarks. 175 // This is safe to do, because backends that use the DeviceIndex 176 // have a later check when we actually try to switch to that device. 177 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 178 index_ >= -1, 179 "Device index must be -1 or non-negative, got ", 180 static_cast<int>(index_)); 181 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 182 !is_cpu() || index_ <= 0, 183 "CPU device index must be -1 or zero, got ", 184 static_cast<int>(index_)); 185 } 186 }; 187 188 C10_API std::ostream& operator<<(std::ostream& stream, const Device& device); 189 190 } // namespace c10 191 192 namespace std { 193 template <> 194 struct hash<c10::Device> { 195 size_t operator()(c10::Device d) const noexcept { 196 // Are you here because this static assert failed? Make sure you ensure 197 // that the bitmasking code below is updated accordingly! 198 static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit"); 199 static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit"); 200 // Note [Hazard when concatenating signed integers] 201 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 202 // We must first convert to a same-sized unsigned type, before promoting to 203 // the result type, to prevent sign extension when any of the values is -1. 204 // If sign extension occurs, you'll clobber all of the values in the MSB 205 // half of the resulting integer. 206 // 207 // Technically, by C/C++ integer promotion rules, we only need one of the 208 // uint32_t casts to the result type, but we put in both for explicitness's 209 // sake. 210 uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type())) 211 << 16 | 212 static_cast<uint32_t>(static_cast<uint8_t>(d.index())); 213 return std::hash<uint32_t>{}(bits); 214 } 215 }; 216 } // namespace std 217