xref: /aosp_15_r20/external/pytorch/c10/core/DeviceType.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // This is directly synchronized with caffe2/proto/caffe2.proto, but
4 // doesn't require me to figure out how to get Protobuf headers into
5 // ATen/core (which would require a lot more build system hacking.)
6 // If you modify me, keep me synchronized with that file.
7 
8 #include <c10/macros/Export.h>
9 
10 #include <cstddef>
11 #include <cstdint>
12 #include <functional>
13 #include <ostream>
14 #include <string>
15 
16 namespace c10 {
17 
// This contains all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3.
//
// X-macro: invoke as C10_FORALL_BACKEND_DEVICE_TYPES(MY_MACRO, extra_arg)
// to expand MY_MACRO(<DeviceType>, extra_arg) once per backend device type.
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
  _(CPU, extra)                                   \
  _(CUDA, extra)                                  \
  _(HIP, extra)                                   \
  _(XLA, extra)                                   \
  _(MPS, extra)                                   \
  _(IPU, extra)                                   \
  _(XPU, extra)                                   \
  _(HPU, extra)                                   \
  _(VE, extra)                                    \
  _(Lazy, extra)                                  \
  _(Meta, extra)                                  \
  _(MTIA, extra)                                  \
  _(PrivateUse1, extra)
35 
// NOTE: per the header comment above, this enum is kept synchronized with
// caffe2/proto/caffe2.proto — the numeric values are part of that contract,
// so never renumber or reuse an existing value; only append new entries
// immediately before COMPILE_TIME_MAX_DEVICE_TYPES.
enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1, // CUDA.
  MKLDNN = 2, // Reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  MAIA = 8, // ONNX Runtime / Microsoft
  XLA = 9, // XLA / TPU
  Vulkan = 10, // Vulkan
  Metal = 11, // Metal
  XPU = 12, // XPU
  MPS = 13, // MPS
  Meta = 14, // Meta (tensors with no data)
  HPU = 15, // HPU / HABANA
  VE = 16, // SX-Aurora / NEC
  Lazy = 17, // Lazy Tensors
  IPU = 18, // Graphcore IPU
  MTIA = 19, // Meta training and inference devices
  PrivateUse1 = 20, // PrivateUse1 device
  // NB: If you add more devices:
  //  - Change the implementations of DeviceTypeName and isValidDeviceType
  //    in DeviceType.cpp
  //  - Change the number below
  COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};
64 
// Short-hand constants for the commonly used device types. Note that not
// every enumerator gets one (e.g. MKLDNN/OPENGL/OPENCL/IDEEP do not).
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMAIA = DeviceType::MAIA;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
82 
83 // define explicit int constant
84 constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
85     static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
86 
87 static_assert(
88     COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
89     "Hey!  You seem to be adding a lot of new DeviceTypes.  The intent was "
90     "for this constant to reflect the actual number of DeviceTypes we support "
91     "in PyTorch; it's important that this number is not too large as we "
92     "use this to allocate stack arrays in some places in our code.  If you "
93     "are indeed just adding the 20th device type, feel free to change "
94     "the check to 32; but if you are adding some sort of extensible device "
95     "types registration, please be aware that you are affecting code that "
96     "this number is small.  Try auditing uses of this constant.");
97 
// Returns the name of the device type, e.g. "CUDA", or its lowercase form
// ("cuda") when lower_case is true. Implemented in DeviceType.cpp (see the
// NB comment on the enum above).
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);

// Returns true iff d is a device type PyTorch recognizes. Implemented in
// DeviceType.cpp (see the NB comment on the enum above).
C10_API bool isValidDeviceType(DeviceType d);

// Prints a DeviceType to the stream. NOTE(review): exact formatting (case,
// etc.) is defined in DeviceType.cpp — confirm there before relying on it.
C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);

// Registers a custom name for the PrivateUse1 backend (e.g. a vendor name)
// and queries it back. NOTE(review): registration semantics (one-shot vs.
// re-registration) live in DeviceType.cpp — confirm there.
C10_API void register_privateuse1_backend(const std::string& backend_name);
C10_API std::string get_privateuse1_backend(bool lower_case = true);

// Returns whether a PrivateUse1 backend name has been registered.
C10_API bool is_privateuse1_backend_registered();
108 
109 } // namespace c10
110 
111 namespace std {
112 template <>
113 struct hash<c10::DeviceType> {
114   std::size_t operator()(c10::DeviceType k) const {
115     return std::hash<int>()(static_cast<int>(k));
116   }
117 };
118 } // namespace std
119 
// Re-export DeviceType into the torch:: namespace so user code can write
// torch::DeviceType without depending on the c10 namespace directly.
namespace torch {
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::DeviceType;
} // namespace torch
124