1 #pragma once
2
3 #ifdef USE_VULKAN_API
4
5 #include <ATen/native/vulkan/VulkanOpaqueTensorImpl.h>
6 #include <ATen/native/vulkan/api/Tensor.h>
7 #include <ATen/native/vulkan/api/Types.h>
8 #include <c10/util/accumulate.h>
9
10 namespace at {
11 namespace native {
12 namespace vulkan {
13 namespace ops {
14
15 /**
 * Determines an appropriate GPU Memory Layout qualifier based on the
 * StorageType requested and the c10::MemoryFormat specified.
18 */
get_gpu_memory_layout(const api::StorageType storage_type,const c10::MemoryFormat memory_format)19 inline api::GPUMemoryLayout get_gpu_memory_layout(
20 const api::StorageType storage_type,
21 const c10::MemoryFormat memory_format) {
22 if (storage_type == api::StorageType::BUFFER) {
23 switch (memory_format) {
24 case c10::MemoryFormat::Contiguous:
25 return api::GPUMemoryLayout::TENSOR_WIDTH_PACKED;
26 case c10::MemoryFormat::ChannelsLast:
27 return api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
28 default:
29 VK_THROW("Invalid memory format used to create vTensor!");
30 }
31 }
32 // For texture storage, always return a memory layout that packs the channels
33 // dimension. for now. With the way texture storage currently works, for 2-dim
34 // tensors, a channel dimension is added, as well as 3 channels of zero
35 // padding resulting in a final shape of {4, H, W}. For 1-dim tensors, it is
36 // unsqueezed to size {1, 1, L} and 3 channels of zero padding are added to
37 // produce a final size of {4, 1, L}. This is to ensure that physical texture
38 // positions correspond directly to logical tensor coordinates (so
39 // texelFetch(ivec3(x, y, 0), 0) will correspond to tensor[y, x].
40 //
41 // TODO(ssjia): have 2D and 1D tensors use TENSOR_WIDTH_PACKED by default.
42 return api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
43 }
44
45 /*
46 * Converts a `c10::ScalarType` to an equivalent
47 * `::at::native::vulkan::api::ScalarType`.
48 */
convert_dtype(const c10::ScalarType dtype)49 static inline api::ScalarType convert_dtype(const c10::ScalarType dtype) {
50 #define DEFINE_CASE(ctype, vkformat, name) \
51 case c10::ScalarType::name: \
52 return ::at::native::vulkan::api::ScalarType::name;
53
54 switch (dtype) {
55 VK_FORALL_SCALAR_TYPES(DEFINE_CASE)
56 default:
57 TORCH_CHECK(false, "Not a supported Vulkan ScalarType!");
58 }
59 #undef DEFINE_CASE
60 }
61
62 /*
63 * Converts an `::at::native::vulkan::api::ScalarType` to an equivalent
64 * `c10::ScalarType`.
65 */
convert_dtype(const api::ScalarType dtype)66 static inline c10::ScalarType convert_dtype(const api::ScalarType dtype) {
67 #define DEFINE_CASE(ctype, vkformat, name) \
68 case ::at::native::vulkan::api::ScalarType::name: \
69 return c10::ScalarType::name;
70
71 switch (dtype) {
72 VK_FORALL_SCALAR_TYPES(DEFINE_CASE)
73 default:
74 TORCH_CHECK(false, "Not a supported c10::ScalarType!");
75 }
76 #undef DEFINE_CASE
77 }
78
// TensorImpl type used for Vulkan tensors in this file: a
// VulkanOpaqueTensorImpl whose opaque handle is a vTensor. The convert()
// helpers below create and unwrap at::Tensors through this type.
using vTensorImpl = VulkanOpaqueTensorImpl<vTensor>;
80
convert(const vTensor & tensor)81 inline Tensor convert(const vTensor& tensor) {
82 return at::detail::make_tensor<vTensorImpl>(
83 DispatchKeySet(DispatchKey::Vulkan),
84 c10::scalarTypeToTypeMeta(convert_dtype(tensor.dtype())),
85 at::Device(at::kVulkan),
86 tensor,
87 tensor.sizes(),
88 tensor.strides());
89 }
90
convert_quantized(const vTensor & tensor)91 inline Tensor convert_quantized(const vTensor& tensor) {
92 TORCH_CHECK(tensor.is_quantized(), "Not a Quantized Tensor");
93 return at::detail::make_tensor<vTensorImpl>(
94 DispatchKeySet(DispatchKey::Vulkan),
95 c10::scalarTypeToTypeMeta(convert_dtype(tensor.dtype())),
96 at::Device(at::kVulkan),
97 tensor,
98 tensor.sizes(),
99 tensor.strides());
100 }
101
convert(const Tensor & tensor)102 inline vTensor& convert(const Tensor& tensor) {
103 TORCH_INTERNAL_ASSERT(tensor.is_vulkan(), "Vulkan tensor expected!");
104
105 vTensorImpl* const impl =
106 static_cast<vTensorImpl*>(tensor.unsafeGetTensorImpl());
107
108 return impl->unsafe_opaque_handle();
109 }
110
111 } // namespace ops
112 } // namespace vulkan
113 } // namespace native
114 } // namespace at
115
116 #endif /* USE_VULKAN_API */
117