1 #pragma once
2
3 #ifdef USE_VULKAN_API
4
5 #include <ATen/native/vulkan/ops/Common.h>
6
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 namespace ops {
11
// The functions below are declared here only; their definitions are not part
// of this header. Descriptions are inferred from names and signatures and
// should be confirmed against the definitions.

// Takes a read-only CPU-side Tensor and a mutable vTensor.
// Presumably copies the tensor's data into the Vulkan tensor — TODO confirm.
12 void transfer_cpu_to_vulkan(const Tensor&, vTensor&);
13
// Takes a mutable vTensor and a mutable Tensor.
// Presumably the reverse direction of transfer_cpu_to_vulkan — TODO confirm.
14 void transfer_vulkan_to_cpu(vTensor&, Tensor&);
15
// Takes a read-only `src` Tensor and a mutable `dst` vTensor.
// NOTE(review): how "pack" differs from "transfer" is not visible here.
16 void pack_cpu_to_vulkan(const Tensor& src, vTensor& dst);
17
// Takes a mutable `src` vTensor and a mutable `dst` Tensor.
// Presumably the reverse direction of pack_cpu_to_vulkan — TODO confirm.
18 void pack_vulkan_to_cpu(vTensor& src, Tensor& dst);
19
// Takes a mutable `dst` and a read-only `src`, and returns a Tensor reference.
// Presumably copies `src` into `dst` and returns `dst` — TODO confirm.
20 Tensor& copy_(Tensor& dst, const Tensor& src);
21
// Returns a vTensor by value, built from `src`.
// `storage_type` defaults to api::StorageType::TEXTURE_3D when omitted.
22 vTensor to_vulkan(
23 at::Tensor& src,
24 const api::StorageType storage_type = api::StorageType::TEXTURE_3D);
25
// Returns an at::Tensor by value, built from the Vulkan tensor `v_src`.
26 at::Tensor from_vulkan(vTensor& v_src);
27
28 //
29 // Utility functions for memcpy
30 //
31
32 template <typename T>
memcpy_to_mapping_impl(const Tensor & src,api::MemoryMap & dst_mapping)33 void memcpy_to_mapping_impl(const Tensor& src, api::MemoryMap& dst_mapping) {
34 T* data_ptr = dst_mapping.template data<T>();
35 memcpy(
36 data_ptr,
37 src.const_data_ptr<T>(),
38 std::min(src.nbytes(), dst_mapping.nbytes()));
39 }
40
41 template <typename T>
memcpy_from_mapping_impl(api::MemoryMap & src_mapping,Tensor & dst)42 void memcpy_from_mapping_impl(api::MemoryMap& src_mapping, Tensor& dst) {
43 T* data_ptr = src_mapping.template data<T>();
44 memcpy(
45 dst.mutable_data_ptr<T>(),
46 data_ptr,
47 std::min(src_mapping.nbytes(), dst.nbytes()));
48 }
49
memcpy_from_mapping_bool(api::MemoryMap & src_mapping,Tensor & dst)50 inline void memcpy_from_mapping_bool(api::MemoryMap& src_mapping, Tensor& dst) {
51 uint8_t* src_ptr = src_mapping.template data<uint8_t>();
52 bool* dst_ptr = dst.mutable_data_ptr<bool>();
53 for (int i = 0; (unsigned)i < std::min(src_mapping.nbytes(), dst.nbytes());
54 ++i) {
55 dst_ptr[i] = static_cast<bool>(src_ptr[i]);
56 }
57 }
58
memcpy_to_mapping_uint8(const Tensor & src,api::MemoryMap & dst_mapping)59 inline void memcpy_to_mapping_uint8(
60 const Tensor& src,
61 api::MemoryMap& dst_mapping) {
62 bool* src_ptr = src.mutable_data_ptr<bool>();
63 uint8_t* dst_ptr = dst_mapping.template data<uint8_t>();
64 for (int i = 0; (unsigned)i < std::min(dst_mapping.nbytes(), src.nbytes());
65 ++i) {
66 dst_ptr[i] = static_cast<uint8_t>(src_ptr[i]);
67 }
68 }
69
// Declaration only; the definition is not part of this header.
// Same parameter list as memcpy_to_mapping_impl<T> and memcpy_to_mapping_uint8.
// Presumably selects one of them from the dtype of `src` — TODO confirm.
70 void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping);
71
// Declaration only; the definition is not part of this header.
// Same parameter list as memcpy_from_mapping_impl<T> and
// memcpy_from_mapping_bool. Presumably selects one of them from the dtype of
// `dst` — TODO confirm.
72 void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst);
73
74 } // namespace ops
75 } // namespace vulkan
76 } // namespace native
77 } // namespace at
78
79 #endif /* USE_VULKAN_API */
80