#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <cctype>
#include <mutex>
#include <string>
6
7 namespace c10 {
8
DeviceTypeName(DeviceType d,bool lower_case)9 std::string DeviceTypeName(DeviceType d, bool lower_case) {
10 switch (d) {
11 // I considered instead using ctype::tolower to lower-case the strings
12 // on the fly, but this seemed a bit much.
13 case DeviceType::CPU:
14 return lower_case ? "cpu" : "CPU";
15 case DeviceType::CUDA:
16 return lower_case ? "cuda" : "CUDA";
17 case DeviceType::OPENGL:
18 return lower_case ? "opengl" : "OPENGL";
19 case DeviceType::OPENCL:
20 return lower_case ? "opencl" : "OPENCL";
21 case DeviceType::MKLDNN:
22 return lower_case ? "mkldnn" : "MKLDNN";
23 case DeviceType::IDEEP:
24 return lower_case ? "ideep" : "IDEEP";
25 case DeviceType::HIP:
26 return lower_case ? "hip" : "HIP";
27 case DeviceType::VE:
28 return lower_case ? "ve" : "VE";
29 case DeviceType::FPGA:
30 return lower_case ? "fpga" : "FPGA";
31 case DeviceType::MAIA:
32 return lower_case ? "maia" : "MAIA";
33 case DeviceType::XLA:
34 return lower_case ? "xla" : "XLA";
35 case DeviceType::Lazy:
36 return lower_case ? "lazy" : "LAZY";
37 case DeviceType::MPS:
38 return lower_case ? "mps" : "MPS";
39 case DeviceType::Vulkan:
40 return lower_case ? "vulkan" : "VULKAN";
41 case DeviceType::Metal:
42 return lower_case ? "metal" : "METAL";
43 case DeviceType::XPU:
44 return lower_case ? "xpu" : "XPU";
45 case DeviceType::Meta:
46 return lower_case ? "meta" : "META";
47 case DeviceType::HPU:
48 return lower_case ? "hpu" : "HPU";
49 case DeviceType::IPU:
50 return lower_case ? "ipu" : "IPU";
51 case DeviceType::MTIA:
52 return lower_case ? "mtia" : "MTIA";
53 case DeviceType::PrivateUse1:
54 return get_privateuse1_backend(/*lower_case=*/lower_case);
55 default:
56 TORCH_CHECK(
57 false,
58 "Unknown device: ",
59 static_cast<int16_t>(d),
60 ". If you have recently updated the caffe2.proto file to add a new "
61 "device type, did you forget to update the DeviceTypeName() "
62 "function to reflect such recent changes?");
63 // The below code won't run but is needed to suppress some compiler
64 // warnings.
65 return "";
66 }
67 }
68
69 // NB: Per the C++ standard (e.g.,
70 // https://stackoverflow.com/questions/18195312/what-happens-if-you-static-cast-invalid-value-to-enum-class)
71 // as long as you cast from the same underlying type, it is always valid to cast
72 // into an enum class (even if the value would be invalid by the enum.) Thus,
73 // the caller is allowed to cast a possibly invalid int16_t to DeviceType and
74 // then pass it to this function. (I considered making this function take an
75 // int16_t directly, but that just seemed weird.)
isValidDeviceType(DeviceType d)76 bool isValidDeviceType(DeviceType d) {
77 switch (d) {
78 case DeviceType::CPU:
79 case DeviceType::CUDA:
80 case DeviceType::OPENGL:
81 case DeviceType::OPENCL:
82 case DeviceType::MKLDNN:
83 case DeviceType::IDEEP:
84 case DeviceType::HIP:
85 case DeviceType::VE:
86 case DeviceType::FPGA:
87 case DeviceType::MAIA:
88 case DeviceType::XLA:
89 case DeviceType::Lazy:
90 case DeviceType::MPS:
91 case DeviceType::Vulkan:
92 case DeviceType::Metal:
93 case DeviceType::XPU:
94 case DeviceType::Meta:
95 case DeviceType::HPU:
96 case DeviceType::IPU:
97 case DeviceType::MTIA:
98 case DeviceType::PrivateUse1:
99 return true;
100 default:
101 return false;
102 }
103 }
104
operator <<(std::ostream & stream,DeviceType type)105 std::ostream& operator<<(std::ostream& stream, DeviceType type) {
106 stream << DeviceTypeName(type, /* lower case */ true);
107 return stream;
108 }
109
// We use both a mutex and an atomic here because:
// (1) Mutex is needed during writing:
//   We need to first check the value and potentially error,
//   before setting the value (without anyone else racing in the middle).
//   It's also totally fine for this to be slow, since it happens exactly once
//   at import time.
// (2) Atomic is needed during reading:
//   Whenever a user prints a privateuse1 device name, they need to read this
//   variable. Although unlikely, we would have a data race if one thread were
//   setting this variable while another thread is printing the device name.
//   We could re-use the same mutex, but reading the atomic is much faster.
// Becomes true once a privateuse1 backend name has been registered.
static std::atomic<bool> privateuse1_backend_name_set{false};
// The registered backend name; meaningful only while the flag above is set.
static std::string privateuse1_backend_name;
// Serializes registration attempts.
static std::mutex privateuse1_lock;
125
get_privateuse1_backend(bool lower_case)126 std::string get_privateuse1_backend(bool lower_case) {
127 // Applying the same atomic read memory ordering logic as in Note [Memory
128 // ordering on Python interpreter tag].
129 auto name_registered =
130 privateuse1_backend_name_set.load(std::memory_order_acquire);
131 // Guaranteed that if the flag is set, then privateuse1_backend_name has been
132 // set, and will never be written to.
133 auto backend_name =
134 name_registered ? privateuse1_backend_name : "privateuseone";
135 auto op_case = lower_case ? ::tolower : ::toupper;
136 std::transform(
137 backend_name.begin(), backend_name.end(), backend_name.begin(), op_case);
138 return backend_name;
139 }
140
register_privateuse1_backend(const std::string & backend_name)141 void register_privateuse1_backend(const std::string& backend_name) {
142 std::lock_guard<std::mutex> guard(privateuse1_lock);
143 TORCH_CHECK(
144 !privateuse1_backend_name_set.load() ||
145 privateuse1_backend_name == backend_name,
146 "torch.register_privateuse1_backend() has already been set! Current backend: ",
147 privateuse1_backend_name);
148
149 static const std::array<std::string, 6> types = {
150 "cpu", "cuda", "hip", "mps", "xpu", "mtia"};
151 TORCH_CHECK(
152 std::find(types.begin(), types.end(), backend_name) == types.end(),
153 "Cannot register privateuse1 backend with in-tree device name: ",
154 backend_name);
155
156 privateuse1_backend_name = backend_name;
157 // Invariant: once this flag is set, privateuse1_backend_name is NEVER written
158 // to.
159 privateuse1_backend_name_set.store(true, std::memory_order_relaxed);
160 }
161
is_privateuse1_backend_registered()162 bool is_privateuse1_backend_registered() {
163 return privateuse1_backend_name_set.load(std::memory_order_acquire);
164 }
165
166 } // namespace c10
167