xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Convert.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/VulkanOpaqueTensorImpl.h>
6 #include <ATen/native/vulkan/api/Tensor.h>
7 #include <ATen/native/vulkan/api/Types.h>
8 #include <c10/util/accumulate.h>
9 
10 namespace at {
11 namespace native {
12 namespace vulkan {
13 namespace ops {
14 
15 /**
16  * Determines an appropriate GPU Memory Layout qualifier based on the the
17  * StorageType requested and the c10::MemoryFormat specified.
18  */
get_gpu_memory_layout(const api::StorageType storage_type,const c10::MemoryFormat memory_format)19 inline api::GPUMemoryLayout get_gpu_memory_layout(
20     const api::StorageType storage_type,
21     const c10::MemoryFormat memory_format) {
22   if (storage_type == api::StorageType::BUFFER) {
23     switch (memory_format) {
24       case c10::MemoryFormat::Contiguous:
25         return api::GPUMemoryLayout::TENSOR_WIDTH_PACKED;
26       case c10::MemoryFormat::ChannelsLast:
27         return api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
28       default:
29         VK_THROW("Invalid memory format used to create vTensor!");
30     }
31   }
32   // For texture storage, always return a memory layout that packs the channels
33   // dimension. for now. With the way texture storage currently works, for 2-dim
34   // tensors, a channel dimension is added, as well as 3 channels of zero
35   // padding resulting in a final shape of {4, H, W}. For 1-dim tensors, it is
36   // unsqueezed to size {1, 1, L} and 3 channels of zero padding are added to
37   // produce a final size of {4, 1, L}. This is to ensure that physical texture
38   // positions correspond directly to logical tensor coordinates (so
39   // texelFetch(ivec3(x, y, 0), 0) will correspond to tensor[y, x].
40   //
41   // TODO(ssjia): have 2D and 1D tensors use TENSOR_WIDTH_PACKED by default.
42   return api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
43 }
44 
45 /*
46  * Converts a `c10::ScalarType` to an equivalent
47  * `::at::native::vulkan::api::ScalarType`.
48  */
convert_dtype(const c10::ScalarType dtype)49 static inline api::ScalarType convert_dtype(const c10::ScalarType dtype) {
50 #define DEFINE_CASE(ctype, vkformat, name) \
51   case c10::ScalarType::name:              \
52     return ::at::native::vulkan::api::ScalarType::name;
53 
54   switch (dtype) {
55     VK_FORALL_SCALAR_TYPES(DEFINE_CASE)
56     default:
57       TORCH_CHECK(false, "Not a supported Vulkan ScalarType!");
58   }
59 #undef DEFINE_CASE
60 }
61 
62 /*
63  * Converts an `::at::native::vulkan::api::ScalarType` to an equivalent
64  * `c10::ScalarType`.
65  */
convert_dtype(const api::ScalarType dtype)66 static inline c10::ScalarType convert_dtype(const api::ScalarType dtype) {
67 #define DEFINE_CASE(ctype, vkformat, name)          \
68   case ::at::native::vulkan::api::ScalarType::name: \
69     return c10::ScalarType::name;
70 
71   switch (dtype) {
72     VK_FORALL_SCALAR_TYPES(DEFINE_CASE)
73     default:
74       TORCH_CHECK(false, "Not a supported c10::ScalarType!");
75   }
76 #undef DEFINE_CASE
77 }
78 
// Opaque TensorImpl specialization that carries a vTensor as its handle,
// allowing a Vulkan-backed at::Tensor to wrap GPU-resident data.
79 using vTensorImpl = VulkanOpaqueTensorImpl<vTensor>;
80 
convert(const vTensor & tensor)81 inline Tensor convert(const vTensor& tensor) {
82   return at::detail::make_tensor<vTensorImpl>(
83       DispatchKeySet(DispatchKey::Vulkan),
84       c10::scalarTypeToTypeMeta(convert_dtype(tensor.dtype())),
85       at::Device(at::kVulkan),
86       tensor,
87       tensor.sizes(),
88       tensor.strides());
89 }
90 
convert_quantized(const vTensor & tensor)91 inline Tensor convert_quantized(const vTensor& tensor) {
92   TORCH_CHECK(tensor.is_quantized(), "Not a Quantized Tensor");
93   return at::detail::make_tensor<vTensorImpl>(
94       DispatchKeySet(DispatchKey::Vulkan),
95       c10::scalarTypeToTypeMeta(convert_dtype(tensor.dtype())),
96       at::Device(at::kVulkan),
97       tensor,
98       tensor.sizes(),
99       tensor.strides());
100 }
101 
convert(const Tensor & tensor)102 inline vTensor& convert(const Tensor& tensor) {
103   TORCH_INTERNAL_ASSERT(tensor.is_vulkan(), "Vulkan tensor expected!");
104 
105   vTensorImpl* const impl =
106       static_cast<vTensorImpl*>(tensor.unsafeGetTensorImpl());
107 
108   return impl->unsafe_opaque_handle();
109 }
110 
111 } // namespace ops
112 } // namespace vulkan
113 } // namespace native
114 } // namespace at
115 
116 #endif /* USE_VULKAN_API */
117