1 #pragma once 2 3 #include <memory> 4 #include <ostream> 5 #include <string> 6 7 #include <ATen/Tensor.h> 8 #include <c10/macros/Export.h> 9 #include <c10/util/Deprecated.h> 10 #include <optional> 11 12 namespace c10 { 13 struct Device; 14 } 15 16 namespace torch { 17 namespace lazy { 18 19 // Backend should extend it and define their own supported hardware types. 20 struct TORCH_API BackendDeviceType { 21 int8_t type{(int8_t)at::kCPU}; 22 // Note: previous default value was '0', which actually maps to at::kCPU, at 23 // least now it is explicit, we may want to make default/undefined semantics 24 // more clear though BackendDeviceTypeBackendDeviceType25 BackendDeviceType() : type((int8_t)at::kCPU) {} BackendDeviceTypeBackendDeviceType26 BackendDeviceType(int8_t type) : type(type) {} 27 28 virtual ~BackendDeviceType() = default; toStringBackendDeviceType29 virtual std::string toString() const { 30 return "Unknown"; 31 } 32 }; 33 34 class TORCH_API BackendDevice { 35 public: 36 // The default constructor will set both the device type and ordinal 37 // to backend specific defaults. 38 BackendDevice(); 39 BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal); 40 41 int8_t type() const; ordinal()42 int64_t ordinal() const { 43 return ordinal_; 44 } 45 46 bool operator==(const BackendDevice& other) const { 47 return compare(other) == 0; 48 } 49 bool operator!=(const BackendDevice& other) const { 50 return compare(other) != 0; 51 } 52 bool operator<(const BackendDevice& rhs) const { 53 return compare(rhs) < 0; 54 } 55 56 std::string toString() const; 57 58 private: 59 int compare(const BackendDevice& rhs) const; 60 61 // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied. 62 std::shared_ptr<BackendDeviceType> type_; 63 int64_t ordinal_; 64 }; 65 66 TORCH_API std::ostream& operator<<( 67 std::ostream& os, 68 const BackendDevice& device); 69 70 // Helpers for converting a c10::Device to BackendDevice and vice versa. 71 TORCH_API BackendDevice atenDeviceToBackendDevice(const c10::Device& device); 72 TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device); 73 74 // Tries to extract the backend device out of the lazy tensor. Returns nullopt 75 // if the input is not a lazy tensor. 76 TORCH_API std::optional<BackendDevice> GetBackendDevice( 77 const at::ITensorListRef tensors); 78 TORCH_API std::optional<BackendDevice> GetBackendDevice( 79 const at::TensorList tensors); 80 TORCH_API std::optional<BackendDevice> GetBackendDevice( 81 const at::Tensor& tensor); 82 TORCH_API std::optional<BackendDevice> GetBackendDevice( 83 const std::optional<c10::Device>& device); 84 85 // For variadic template. 86 TORCH_API std::optional<BackendDevice> GetBackendDevice(); 87 88 template <typename T, typename... Args> GetBackendDevice(const T & tensor,const Args &...forward_tensors)89std::optional<BackendDevice> GetBackendDevice( 90 const T& tensor, 91 const Args&... forward_tensors) { 92 auto optional_device = GetBackendDevice(tensor); 93 if (optional_device) { 94 return optional_device; 95 } 96 return GetBackendDevice(forward_tensors...); 97 } 98 99 } // namespace lazy 100 } // namespace torch 101