// xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/backend/backend_device.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/lazy/backend/backend_device.h>

#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <optional>

namespace torch {
namespace lazy {

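// Default device: query the registered backend for its default device type
// and ordinal.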
BackendDevice::BackendDevice()
    : type_(getBackend()->GetDefaultDeviceType()),
      ordinal_(getBackend()->GetDefaultDeviceOrdinal()) {}

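// Construct a device from an explicit backend device type and ordinal.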
BackendDevice::BackendDevice(
    std::shared_ptr<BackendDeviceType>&& type,
    int64_t ordinal)
    : type_(std::move(type)), ordinal_(ordinal) {}

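// Numeric device type code; type_ must already be set.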
int8_t BackendDevice::type() const {
  TORCH_INTERNAL_ASSERT(type_);
  return type_->type;
}

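// Human-readable form: the device type string immediately followed by the
// ordinal.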
std::string BackendDevice::toString() const {
  TORCH_INTERNAL_ASSERT(type_);
  return c10::str(type_->toString(), ordinal_);
}

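// Three-way comparison: order by device type first, then by ordinal.
// Returns -1, 0, or +1.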
int BackendDevice::compare(const BackendDevice& rhs) const {
  if (type() != rhs.type()) {
    return type() < rhs.type() ? -1 : +1;
  }
  return ordinal_ < rhs.ordinal_ ? -1 : (ordinal_ > rhs.ordinal_ ? +1 : 0);
}

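// Stream insertion delegates to toString().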
std::ostream& operator<<(std::ostream& os, const BackendDevice& device) {
  os << device.toString();
  return os;
}

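// Map an ATen lazy device to a BackendDevice. Only at::kLazy devices are
// accepted; a missing device index falls back to the backend's default
// ordinal.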
BackendDevice atenDeviceToBackendDevice(const c10::Device& device) {
  TORCH_CHECK(device.type() == at::kLazy, device);
  int64_t ordinal = device.has_index()
      ? device.index()
      : getBackend()->GetDefaultDeviceOrdinal();
  return BackendDevice(getBackend()->GetDefaultDeviceType(), ordinal);
}

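// Inverse mapping: rebuild an ATen lazy device from the ordinal alone; the
// backend device type is dropped.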
// TODO(whc) refactor this: we need to support a non-1:1 mapping for torch/XLA.
c10::Device backendDeviceToAtenDevice(const BackendDevice& device) {
  return c10::Device(at::kLazy, device.ordinal());
}

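// Return the device of the first lazy tensor found in the list, if any.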
std::optional<BackendDevice> GetBackendDevice(at::ITensorListRef tensors) {
  for (auto& tensor : tensors) {
    if (auto lt = TryGetLtcTensor(tensor)) {
      return lt->GetDevice();
    }
  }
  return std::nullopt;
}

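// TensorList overload: forwards to the ITensorListRef version.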
std::optional<BackendDevice> GetBackendDevice(at::TensorList tensors) {
  return GetBackendDevice(at::ITensorListRef(tensors));
}

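// Single-tensor overload: the tensor's device if it is a lazy tensor.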
std::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor) {
  if (auto lt = TryGetLtcTensor(tensor)) {
    return lt->GetDevice();
  }
  return std::nullopt;
}

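// Optional-device overload: convert the ATen device when one is present.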
std::optional<BackendDevice> GetBackendDevice(
    const std::optional<c10::Device>& device) {
  if (device) {
    return std::make_optional(atenDeviceToBackendDevice(*device));
  }
  return std::nullopt;
}

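// Nullary overload: trivially no device. Acts as the base case for callers
// that fold GetBackendDevice over an argument pack.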
std::optional<BackendDevice> GetBackendDevice() {
  return std::nullopt;
}

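// A minimal usage sketch (not part of this file; it assumes a lazy backend
// has already been registered, e.g. the TorchScript backend):
//
//   c10::Device aten_dev(at::kLazy, /*index=*/0);
//   torch::lazy::BackendDevice dev =
//       torch::lazy::atenDeviceToBackendDevice(aten_dev);
//   std::cout << dev << '\n'; // type string + ordinal, e.g. "CPU0"
//   c10::Device round_trip = torch::lazy::backendDeviceToAtenDevice(dev);
//   TORCH_CHECK(round_trip.index() == 0);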
} // namespace lazy
} // namespace torch