xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/backend/backend_device.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <memory>
4 #include <ostream>
5 #include <string>
6 
7 #include <ATen/Tensor.h>
8 #include <c10/macros/Export.h>
9 #include <c10/util/Deprecated.h>
10 #include <optional>
11 
// Forward declaration of c10::Device, which the conversion helpers and
// GetBackendDevice overloads declared below refer to.
namespace c10 { struct Device; } // namespace c10
15 
16 namespace torch {
17 namespace lazy {
18 
19 // Backend should extend it and define their own supported hardware types.
20 struct TORCH_API BackendDeviceType {
21   int8_t type{(int8_t)at::kCPU};
22   // Note: previous default value was '0', which actually maps to at::kCPU, at
23   // least now it is explicit, we may want to make default/undefined semantics
24   // more clear though
BackendDeviceTypeBackendDeviceType25   BackendDeviceType() : type((int8_t)at::kCPU) {}
BackendDeviceTypeBackendDeviceType26   BackendDeviceType(int8_t type) : type(type) {}
27 
28   virtual ~BackendDeviceType() = default;
toStringBackendDeviceType29   virtual std::string toString() const {
30     return "Unknown";
31   }
32 };
33 
34 class TORCH_API BackendDevice {
35  public:
36   // The default constructor will set both the device type and ordinal
37   // to backend specific defaults.
38   BackendDevice();
39   BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal);
40 
41   int8_t type() const;
ordinal()42   int64_t ordinal() const {
43     return ordinal_;
44   }
45 
46   bool operator==(const BackendDevice& other) const {
47     return compare(other) == 0;
48   }
49   bool operator!=(const BackendDevice& other) const {
50     return compare(other) != 0;
51   }
52   bool operator<(const BackendDevice& rhs) const {
53     return compare(rhs) < 0;
54   }
55 
56   std::string toString() const;
57 
58  private:
59   int compare(const BackendDevice& rhs) const;
60 
61   // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied.
62   std::shared_ptr<BackendDeviceType> type_;
63   int64_t ordinal_;
64 };
65 
66 TORCH_API std::ostream& operator<<(
67     std::ostream& os,
68     const BackendDevice& device);
69 
70 // Helpers for converting a c10::Device to BackendDevice and vice versa.
71 TORCH_API BackendDevice atenDeviceToBackendDevice(const c10::Device& device);
72 TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device);
73 
74 // Tries to extract the backend device out of the lazy tensor. Returns nullopt
75 // if the input is not a lazy tensor.
76 TORCH_API std::optional<BackendDevice> GetBackendDevice(
77     const at::ITensorListRef tensors);
78 TORCH_API std::optional<BackendDevice> GetBackendDevice(
79     const at::TensorList tensors);
80 TORCH_API std::optional<BackendDevice> GetBackendDevice(
81     const at::Tensor& tensor);
82 TORCH_API std::optional<BackendDevice> GetBackendDevice(
83     const std::optional<c10::Device>& device);
84 
85 // For variadic template.
86 TORCH_API std::optional<BackendDevice> GetBackendDevice();
87 
88 template <typename T, typename... Args>
GetBackendDevice(const T & tensor,const Args &...forward_tensors)89 std::optional<BackendDevice> GetBackendDevice(
90     const T& tensor,
91     const Args&... forward_tensors) {
92   auto optional_device = GetBackendDevice(tensor);
93   if (optional_device) {
94     return optional_device;
95   }
96   return GetBackendDevice(forward_tensors...);
97 }
98 
99 } // namespace lazy
100 } // namespace torch
101