1 #pragma once 2 3 // This is directly synchronized with caffe2/proto/caffe2.proto, but 4 // doesn't require me to figure out how to get Protobuf headers into 5 // ATen/core (which would require a lot more build system hacking.) 6 // If you modify me, keep me synchronized with that file. 7 8 #include <c10/macros/Export.h> 9 10 #include <cstddef> 11 #include <cstdint> 12 #include <functional> 13 #include <ostream> 14 #include <string> 15 16 namespace c10 { 17 18 // These contains all device types that also have a BackendComponent 19 // and therefore participate in per-backend functionality dispatch keys. 20 // This is most backends except PrivateUse2 and PrivateUse3 21 #define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \ 22 _(CPU, extra) \ 23 _(CUDA, extra) \ 24 _(HIP, extra) \ 25 _(XLA, extra) \ 26 _(MPS, extra) \ 27 _(IPU, extra) \ 28 _(XPU, extra) \ 29 _(HPU, extra) \ 30 _(VE, extra) \ 31 _(Lazy, extra) \ 32 _(Meta, extra) \ 33 _(MTIA, extra) \ 34 _(PrivateUse1, extra) 35 36 enum class DeviceType : int8_t { 37 CPU = 0, 38 CUDA = 1, // CUDA. 39 MKLDNN = 2, // Reserved for explicit MKLDNN 40 OPENGL = 3, // OpenGL 41 OPENCL = 4, // OpenCL 42 IDEEP = 5, // IDEEP. 43 HIP = 6, // AMD HIP 44 FPGA = 7, // FPGA 45 MAIA = 8, // ONNX Runtime / Microsoft 46 XLA = 9, // XLA / TPU 47 Vulkan = 10, // Vulkan 48 Metal = 11, // Metal 49 XPU = 12, // XPU 50 MPS = 13, // MPS 51 Meta = 14, // Meta (tensors with no data) 52 HPU = 15, // HPU / HABANA 53 VE = 16, // SX-Aurora / NEC 54 Lazy = 17, // Lazy Tensors 55 IPU = 18, // Graphcore IPU 56 MTIA = 19, // Meta training and inference devices 57 PrivateUse1 = 20, // PrivateUse1 device 58 // NB: If you add more devices: 59 // - Change the implementations of DeviceTypeName and isValidDeviceType 60 // in DeviceType.cpp 61 // - Change the number below 62 COMPILE_TIME_MAX_DEVICE_TYPES = 21, 63 }; 64 65 constexpr DeviceType kCPU = DeviceType::CPU; 66 constexpr DeviceType kCUDA = DeviceType::CUDA; 67 constexpr DeviceType kHIP = DeviceType::HIP; 68 constexpr DeviceType kFPGA = DeviceType::FPGA; 69 constexpr DeviceType kMAIA = DeviceType::MAIA; 70 constexpr DeviceType kXLA = DeviceType::XLA; 71 constexpr DeviceType kMPS = DeviceType::MPS; 72 constexpr DeviceType kMeta = DeviceType::Meta; 73 constexpr DeviceType kVulkan = DeviceType::Vulkan; 74 constexpr DeviceType kMetal = DeviceType::Metal; 75 constexpr DeviceType kXPU = DeviceType::XPU; 76 constexpr DeviceType kHPU = DeviceType::HPU; 77 constexpr DeviceType kVE = DeviceType::VE; 78 constexpr DeviceType kLazy = DeviceType::Lazy; 79 constexpr DeviceType kIPU = DeviceType::IPU; 80 constexpr DeviceType kMTIA = DeviceType::MTIA; 81 constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1; 82 83 // define explicit int constant 84 constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = 85 static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); 86 87 static_assert( 88 COMPILE_TIME_MAX_DEVICE_TYPES <= 21, 89 "Hey! You seem to be adding a lot of new DeviceTypes. The intent was " 90 "for this constant to reflect the actual number of DeviceTypes we support " 91 "in PyTorch; it's important that this number is not too large as we " 92 "use this to allocate stack arrays in some places in our code. If you " 93 "are indeed just adding the 20th device type, feel free to change " 94 "the check to 32; but if you are adding some sort of extensible device " 95 "types registration, please be aware that you are affecting code that " 96 "this number is small. Try auditing uses of this constant."); 97 98 C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false); 99 100 C10_API bool isValidDeviceType(DeviceType d); 101 102 C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type); 103 104 C10_API void register_privateuse1_backend(const std::string& backend_name); 105 C10_API std::string get_privateuse1_backend(bool lower_case = true); 106 107 C10_API bool is_privateuse1_backend_registered(); 108 109 } // namespace c10 110 111 namespace std { 112 template <> 113 struct hash<c10::DeviceType> { 114 std::size_t operator()(c10::DeviceType k) const { 115 return std::hash<int>()(static_cast<int>(k)); 116 } 117 }; 118 } // namespace std 119 120 namespace torch { 121 // NOLINTNEXTLINE(misc-unused-using-decls) 122 using c10::DeviceType; 123 } // namespace torch 124