// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/DeprecatedTypeProperties.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Layout.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/Storage.h>
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
#include <ATen/core/Generator.h>

#include <cstdint>
#include <limits>
#include <optional>
#include <string>

12 namespace at {
13 
14 class Tensor;
15 
16 // This class specifies a Backend and a ScalarType. Currently, it primarily
17 // serves as a replacement return value for Tensor::type(). Previously,
18 // Tensor::type() returned Type&, but we are changing Type to not be
19 // dtype-specific.
20 class TORCH_API DeprecatedTypeProperties {
21  public:
DeprecatedTypeProperties(Backend backend,ScalarType scalar_type)22   DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
23     : backend_(backend), scalar_type_(scalar_type) {}
24 
backend()25   Backend backend() const {
26     return backend_;
27   }
28 
layout()29   Layout layout() const {
30     return layout_from_backend(backend_);
31   }
32 
is_sparse()33   bool is_sparse() const {
34     return layout_from_backend(backend()) == kSparse;
35   }
36 
is_sparse_csr()37   bool is_sparse_csr() const {
38     return layout_from_backend(backend()) == kSparseCsr;
39   }
40 
device_type()41   c10::DeviceType device_type() const {
42     return backendToDeviceType(backend_);
43   }
44 
is_cuda()45   bool is_cuda() const {
46     return backendToDeviceType(backend_) == kCUDA;
47   }
48 
scalarType()49   ScalarType scalarType() const {
50     return scalar_type_;
51   }
52 
typeMeta()53   caffe2::TypeMeta typeMeta() const {
54     return scalarTypeToTypeMeta(scalar_type_);
55   }
56 
57   bool operator==(const DeprecatedTypeProperties& other) const {
58     return backend_ == other.backend() && scalar_type_ == other.scalarType();
59   }
60 
61   bool operator!=(const DeprecatedTypeProperties& other) const {
62     return !(*this == other);
63   }
64 
toString()65   std::string toString() const {
66     std::string base_str;
67     if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) {
68       base_str = "UndefinedType";
69     } else {
70       base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type";
71     }
72     return base_str;
73   }
74 
toBackend(Backend b)75   DeprecatedTypeProperties & toBackend(Backend b) const {
76     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
77         b, scalar_type_);
78   }
79 
toScalarType(ScalarType s)80   DeprecatedTypeProperties & toScalarType(ScalarType s) const {
81     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
82         backend_, s);
83   }
84 
cpu()85   DeprecatedTypeProperties & cpu() const {
86     return toBackend(Backend::CPU);
87   }
88 
cuda()89   DeprecatedTypeProperties & cuda() const {
90     return toBackend(Backend::CUDA);
91   }
92 
hip()93   DeprecatedTypeProperties & hip() const {
94     return toBackend(Backend::HIP);
95   }
96 
privateUser1()97   DeprecatedTypeProperties & privateUser1() const {
98     return toBackend(Backend::PrivateUse1);
99   }
100 
101   /// Constructs the `TensorOptions` from a type and a `device_index`.
102   TensorOptions options(int16_t device_index = -1) const {
103     return TensorOptions().dtype(typeMeta())
104                           .device(device_type(), static_cast<c10::DeviceIndex>(device_index))
105                           .layout(layout());
106   }
107 
108   /// Constructs the `TensorOptions` from a type and a Device.  Asserts that
109   /// the device type matches the device type of the type.
options(std::optional<Device> device_opt)110   TensorOptions options(std::optional<Device> device_opt) const {
111     if (!device_opt.has_value()) {
112       return options(-1);
113     } else {
114       Device device = device_opt.value();
115       AT_ASSERT(device.type() == device_type());
116       return options(device.index());
117     }
118   }
119 
TensorOptions()120   operator TensorOptions() const {
121     return options();
122   }
123 
id()124   int64_t id() const {
125     return static_cast<int64_t>(backend()) *
126         static_cast<int64_t>(ScalarType::NumOptions) +
127         static_cast<int64_t>(scalarType());
128   }
129 
130   Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const;
131   Storage unsafeStorageFromTH(void * th_pointer, bool retain) const;
132   Tensor copy(const Tensor & src, bool non_blocking=false, std::optional<Device> to_device={}) const;
133 
134  private:
135   Backend backend_;
136   ScalarType scalar_type_;
137 };
138 
139 }  // namespace at
140