xref: /aosp_15_r20/external/pytorch/c10/core/DeviceType.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <atomic>
#include <cctype>
#include <mutex>
#include <string>
6 
7 namespace c10 {
8 
DeviceTypeName(DeviceType d,bool lower_case)9 std::string DeviceTypeName(DeviceType d, bool lower_case) {
10   switch (d) {
11     // I considered instead using ctype::tolower to lower-case the strings
12     // on the fly, but this seemed a bit much.
13     case DeviceType::CPU:
14       return lower_case ? "cpu" : "CPU";
15     case DeviceType::CUDA:
16       return lower_case ? "cuda" : "CUDA";
17     case DeviceType::OPENGL:
18       return lower_case ? "opengl" : "OPENGL";
19     case DeviceType::OPENCL:
20       return lower_case ? "opencl" : "OPENCL";
21     case DeviceType::MKLDNN:
22       return lower_case ? "mkldnn" : "MKLDNN";
23     case DeviceType::IDEEP:
24       return lower_case ? "ideep" : "IDEEP";
25     case DeviceType::HIP:
26       return lower_case ? "hip" : "HIP";
27     case DeviceType::VE:
28       return lower_case ? "ve" : "VE";
29     case DeviceType::FPGA:
30       return lower_case ? "fpga" : "FPGA";
31     case DeviceType::MAIA:
32       return lower_case ? "maia" : "MAIA";
33     case DeviceType::XLA:
34       return lower_case ? "xla" : "XLA";
35     case DeviceType::Lazy:
36       return lower_case ? "lazy" : "LAZY";
37     case DeviceType::MPS:
38       return lower_case ? "mps" : "MPS";
39     case DeviceType::Vulkan:
40       return lower_case ? "vulkan" : "VULKAN";
41     case DeviceType::Metal:
42       return lower_case ? "metal" : "METAL";
43     case DeviceType::XPU:
44       return lower_case ? "xpu" : "XPU";
45     case DeviceType::Meta:
46       return lower_case ? "meta" : "META";
47     case DeviceType::HPU:
48       return lower_case ? "hpu" : "HPU";
49     case DeviceType::IPU:
50       return lower_case ? "ipu" : "IPU";
51     case DeviceType::MTIA:
52       return lower_case ? "mtia" : "MTIA";
53     case DeviceType::PrivateUse1:
54       return get_privateuse1_backend(/*lower_case=*/lower_case);
55     default:
56       TORCH_CHECK(
57           false,
58           "Unknown device: ",
59           static_cast<int16_t>(d),
60           ". If you have recently updated the caffe2.proto file to add a new "
61           "device type, did you forget to update the DeviceTypeName() "
62           "function to reflect such recent changes?");
63       // The below code won't run but is needed to suppress some compiler
64       // warnings.
65       return "";
66   }
67 }
68 
69 // NB: Per the C++ standard (e.g.,
70 // https://stackoverflow.com/questions/18195312/what-happens-if-you-static-cast-invalid-value-to-enum-class)
71 // as long as you cast from the same underlying type, it is always valid to cast
72 // into an enum class (even if the value would be invalid by the enum.)  Thus,
73 // the caller is allowed to cast a possibly invalid int16_t to DeviceType and
74 // then pass it to this function.  (I considered making this function take an
75 // int16_t directly, but that just seemed weird.)
isValidDeviceType(DeviceType d)76 bool isValidDeviceType(DeviceType d) {
77   switch (d) {
78     case DeviceType::CPU:
79     case DeviceType::CUDA:
80     case DeviceType::OPENGL:
81     case DeviceType::OPENCL:
82     case DeviceType::MKLDNN:
83     case DeviceType::IDEEP:
84     case DeviceType::HIP:
85     case DeviceType::VE:
86     case DeviceType::FPGA:
87     case DeviceType::MAIA:
88     case DeviceType::XLA:
89     case DeviceType::Lazy:
90     case DeviceType::MPS:
91     case DeviceType::Vulkan:
92     case DeviceType::Metal:
93     case DeviceType::XPU:
94     case DeviceType::Meta:
95     case DeviceType::HPU:
96     case DeviceType::IPU:
97     case DeviceType::MTIA:
98     case DeviceType::PrivateUse1:
99       return true;
100     default:
101       return false;
102   }
103 }
104 
operator <<(std::ostream & stream,DeviceType type)105 std::ostream& operator<<(std::ostream& stream, DeviceType type) {
106   stream << DeviceTypeName(type, /* lower case */ true);
107   return stream;
108 }
109 
// We use both a mutex and an atomic here because:
// (1) The mutex is needed during writing:
//     We need to check the current value first, and potentially error, before
//     setting the new value, without anyone else racing in between. It's also
//     totally fine for this to be slow, since it happens exactly once at
//     import time.
// (2) The atomic is needed during reading:
//     Whenever a user prints a privateuse1 device name, they need to read this
//     variable. Although unlikely, we'd have a data race if one thread were
//     setting this variable at the same time that another thread is printing
//     the device name. We could re-use the same mutex, but reading the atomic
//     is much faster.
static std::atomic<bool> privateuse1_backend_name_set{false};
static std::string privateuse1_backend_name;
static std::mutex privateuse1_lock;

// Returns the name of the PrivateUse1 backend: the registered name once the
// "name set" flag is observed as true, otherwise the placeholder
// "privateuseone". The result is lower-cased when `lower_case` is true and
// upper-cased otherwise, using std::tolower / std::toupper (i.e. the current
// C locale).
std::string get_privateuse1_backend(bool lower_case) {
  // Applying the same atomic read memory ordering logic as in Note [Memory
  // ordering on Python interpreter tag]. Note that this acquire load only
  // synchronizes with the writer if the flag is published with a release (or
  // stronger) store.
  const bool name_registered =
      privateuse1_backend_name_set.load(std::memory_order_acquire);
  // The writer's invariant is that privateuse1_backend_name is not written
  // once the flag is set, which is what makes this lock-free copy valid.
  std::string backend_name =
      name_registered ? privateuse1_backend_name : "privateuseone";
  // The <cctype> case functions require an argument representable as
  // unsigned char (or EOF). A plain `char` holding a byte >= 0x80 (e.g. from a
  // UTF-8 backend name) is negative on signed-char platforms, which is
  // undefined behavior, so route every byte through unsigned char first.
  std::transform(
      backend_name.begin(),
      backend_name.end(),
      backend_name.begin(),
      [lower_case](unsigned char ch) {
        return static_cast<char>(
            lower_case ? std::tolower(ch) : std::toupper(ch));
      });
  return backend_name;
}
140 
register_privateuse1_backend(const std::string & backend_name)141 void register_privateuse1_backend(const std::string& backend_name) {
142   std::lock_guard<std::mutex> guard(privateuse1_lock);
143   TORCH_CHECK(
144       !privateuse1_backend_name_set.load() ||
145           privateuse1_backend_name == backend_name,
146       "torch.register_privateuse1_backend() has already been set! Current backend: ",
147       privateuse1_backend_name);
148 
149   static const std::array<std::string, 6> types = {
150       "cpu", "cuda", "hip", "mps", "xpu", "mtia"};
151   TORCH_CHECK(
152       std::find(types.begin(), types.end(), backend_name) == types.end(),
153       "Cannot register privateuse1 backend with in-tree device name: ",
154       backend_name);
155 
156   privateuse1_backend_name = backend_name;
157   // Invariant: once this flag is set, privateuse1_backend_name is NEVER written
158   // to.
159   privateuse1_backend_name_set.store(true, std::memory_order_relaxed);
160 }
161 
is_privateuse1_backend_registered()162 bool is_privateuse1_backend_registered() {
163   return privateuse1_backend_name_set.load(std::memory_order_acquire);
164 }
165 
166 } // namespace c10
167