1 #pragma once 2 3 #include <c10/core/Backend.h> 4 #include <c10/core/ScalarType.h> 5 #include <c10/core/Layout.h> 6 #include <c10/core/TensorOptions.h> 7 #include <c10/core/Storage.h> 8 #include <ATen/core/DeprecatedTypePropertiesRegistry.h> 9 #include <ATen/core/Generator.h> 10 11 12 namespace at { 13 14 class Tensor; 15 16 // This class specifies a Backend and a ScalarType. Currently, it primarily 17 // serves as a replacement return value for Tensor::type(). Previously, 18 // Tensor::type() returned Type&, but we are changing Type to not be 19 // dtype-specific. 20 class TORCH_API DeprecatedTypeProperties { 21 public: DeprecatedTypeProperties(Backend backend,ScalarType scalar_type)22 DeprecatedTypeProperties(Backend backend, ScalarType scalar_type) 23 : backend_(backend), scalar_type_(scalar_type) {} 24 backend()25 Backend backend() const { 26 return backend_; 27 } 28 layout()29 Layout layout() const { 30 return layout_from_backend(backend_); 31 } 32 is_sparse()33 bool is_sparse() const { 34 return layout_from_backend(backend()) == kSparse; 35 } 36 is_sparse_csr()37 bool is_sparse_csr() const { 38 return layout_from_backend(backend()) == kSparseCsr; 39 } 40 device_type()41 c10::DeviceType device_type() const { 42 return backendToDeviceType(backend_); 43 } 44 is_cuda()45 bool is_cuda() const { 46 return backendToDeviceType(backend_) == kCUDA; 47 } 48 scalarType()49 ScalarType scalarType() const { 50 return scalar_type_; 51 } 52 typeMeta()53 caffe2::TypeMeta typeMeta() const { 54 return scalarTypeToTypeMeta(scalar_type_); 55 } 56 57 bool operator==(const DeprecatedTypeProperties& other) const { 58 return backend_ == other.backend() && scalar_type_ == other.scalarType(); 59 } 60 61 bool operator!=(const DeprecatedTypeProperties& other) const { 62 return !(*this == other); 63 } 64 toString()65 std::string toString() const { 66 std::string base_str; 67 if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) { 68 base_str = "UndefinedType"; 69 } else { 70 base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type"; 71 } 72 return base_str; 73 } 74 toBackend(Backend b)75 DeprecatedTypeProperties & toBackend(Backend b) const { 76 return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( 77 b, scalar_type_); 78 } 79 toScalarType(ScalarType s)80 DeprecatedTypeProperties & toScalarType(ScalarType s) const { 81 return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( 82 backend_, s); 83 } 84 cpu()85 DeprecatedTypeProperties & cpu() const { 86 return toBackend(Backend::CPU); 87 } 88 cuda()89 DeprecatedTypeProperties & cuda() const { 90 return toBackend(Backend::CUDA); 91 } 92 hip()93 DeprecatedTypeProperties & hip() const { 94 return toBackend(Backend::HIP); 95 } 96 privateUser1()97 DeprecatedTypeProperties & privateUser1() const { 98 return toBackend(Backend::PrivateUse1); 99 } 100 101 /// Constructs the `TensorOptions` from a type and a `device_index`. 102 TensorOptions options(int16_t device_index = -1) const { 103 return TensorOptions().dtype(typeMeta()) 104 .device(device_type(), static_cast<c10::DeviceIndex>(device_index)) 105 .layout(layout()); 106 } 107 108 /// Constructs the `TensorOptions` from a type and a Device. Asserts that 109 /// the device type matches the device type of the type. options(std::optional<Device> device_opt)110 TensorOptions options(std::optional<Device> device_opt) const { 111 if (!device_opt.has_value()) { 112 return options(-1); 113 } else { 114 Device device = device_opt.value(); 115 AT_ASSERT(device.type() == device_type()); 116 return options(device.index()); 117 } 118 } 119 TensorOptions()120 operator TensorOptions() const { 121 return options(); 122 } 123 id()124 int64_t id() const { 125 return static_cast<int64_t>(backend()) * 126 static_cast<int64_t>(ScalarType::NumOptions) + 127 static_cast<int64_t>(scalarType()); 128 } 129 130 Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const; 131 Storage unsafeStorageFromTH(void * th_pointer, bool retain) const; 132 Tensor copy(const Tensor & src, bool non_blocking=false, std::optional<Device> to_device={}) const; 133 134 private: 135 Backend backend_; 136 ScalarType scalar_type_; 137 }; 138 139 } // namespace at 140