xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Copy.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_VULKAN_API
4 
5 #include <ATen/native/vulkan/ops/Common.h>
6 
7 namespace at {
8 namespace native {
9 namespace vulkan {
10 namespace ops {
11 
12 void transfer_cpu_to_vulkan(const Tensor&, vTensor&);
13 
14 void transfer_vulkan_to_cpu(vTensor&, Tensor&);
15 
16 void pack_cpu_to_vulkan(const Tensor& src, vTensor& dst);
17 
18 void pack_vulkan_to_cpu(vTensor& src, Tensor& dst);
19 
20 Tensor& copy_(Tensor& dst, const Tensor& src);
21 
22 vTensor to_vulkan(
23     at::Tensor& src,
24     const api::StorageType storage_type = api::StorageType::TEXTURE_3D);
25 
26 at::Tensor from_vulkan(vTensor& v_src);
27 
28 //
29 // Utility functions for memcpy
30 //
31 
32 template <typename T>
memcpy_to_mapping_impl(const Tensor & src,api::MemoryMap & dst_mapping)33 void memcpy_to_mapping_impl(const Tensor& src, api::MemoryMap& dst_mapping) {
34   T* data_ptr = dst_mapping.template data<T>();
35   memcpy(
36       data_ptr,
37       src.const_data_ptr<T>(),
38       std::min(src.nbytes(), dst_mapping.nbytes()));
39 }
40 
41 template <typename T>
memcpy_from_mapping_impl(api::MemoryMap & src_mapping,Tensor & dst)42 void memcpy_from_mapping_impl(api::MemoryMap& src_mapping, Tensor& dst) {
43   T* data_ptr = src_mapping.template data<T>();
44   memcpy(
45       dst.mutable_data_ptr<T>(),
46       data_ptr,
47       std::min(src_mapping.nbytes(), dst.nbytes()));
48 }
49 
memcpy_from_mapping_bool(api::MemoryMap & src_mapping,Tensor & dst)50 inline void memcpy_from_mapping_bool(api::MemoryMap& src_mapping, Tensor& dst) {
51   uint8_t* src_ptr = src_mapping.template data<uint8_t>();
52   bool* dst_ptr = dst.mutable_data_ptr<bool>();
53   for (int i = 0; (unsigned)i < std::min(src_mapping.nbytes(), dst.nbytes());
54        ++i) {
55     dst_ptr[i] = static_cast<bool>(src_ptr[i]);
56   }
57 }
58 
memcpy_to_mapping_uint8(const Tensor & src,api::MemoryMap & dst_mapping)59 inline void memcpy_to_mapping_uint8(
60     const Tensor& src,
61     api::MemoryMap& dst_mapping) {
62   bool* src_ptr = src.mutable_data_ptr<bool>();
63   uint8_t* dst_ptr = dst_mapping.template data<uint8_t>();
64   for (int i = 0; (unsigned)i < std::min(dst_mapping.nbytes(), src.nbytes());
65        ++i) {
66     dst_ptr[i] = static_cast<uint8_t>(src_ptr[i]);
67   }
68 }
69 
70 void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping);
71 
72 void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst);
73 
74 } // namespace ops
75 } // namespace vulkan
76 } // namespace native
77 } // namespace at
78 
79 #endif /* USE_VULKAN_API */
80