// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Copy.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ATen.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <ATen/vulkan/Context.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

//
// Utility functions for memcpy
//

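// memcpy_to_mapping() and its counterpart below copy raw tensor data between a
// CPU tensor and a mapped Vulkan staging buffer, dispatching on dtype to the
// templated *_impl helpers declared in Copy.h. As a rough, hedged sketch of
// what each helper presumably reduces to (the actual definitions live in
// Copy.h, not here, and the accessor names are assumptions):
//
//   template <typename T>
//   void memcpy_to_mapping_impl(const Tensor& src, api::MemoryMap& dst) {
//     // Bounded raw copy from the tensor's storage into the mapped memory.
//     memcpy(dst.template data<T>(), src.const_data_ptr<T>(),
//            std::min(src.nbytes(), dst.nbytes()));
//   }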
void memcpy_to_mapping(const Tensor& src, api::MemoryMap& dst_mapping) {
  if (src.dtype() == at::kFloat) {
    memcpy_to_mapping_impl<float>(src, dst_mapping);
  } else if (src.dtype() == at::kHalf) {
    memcpy_to_mapping_impl<c10::Half>(src, dst_mapping);
  } else if (src.dtype() == c10::kQUInt8) {
    memcpy_to_mapping_impl<c10::quint8>(src, dst_mapping);
  } else if (src.dtype() == c10::kQInt8) {
    memcpy_to_mapping_impl<c10::qint8>(src, dst_mapping);
  } else if (src.dtype() == c10::kQInt32) {
    memcpy_to_mapping_impl<c10::qint32>(src, dst_mapping);
  } else if (src.dtype() == c10::kBool) {
    memcpy_to_mapping_uint8(src, dst_mapping);
  } else {
    TORCH_CHECK(
        false,
        "Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
        " c10::kBool, at::kHalf, or at::kFloat but got ",
        src.dtype());
  }
}

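// Mirror of memcpy_to_mapping(): reads from the mapped staging buffer into the
// destination CPU tensor, again dispatching on the destination dtype.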
void memcpy_from_mapping(api::MemoryMap& src_mapping, Tensor& dst) {
  if (dst.dtype() == at::kFloat) {
    memcpy_from_mapping_impl<float>(src_mapping, dst);
  } else if (dst.dtype() == at::kHalf) {
    memcpy_from_mapping_impl<c10::Half>(src_mapping, dst);
  } else if (dst.dtype() == c10::kQUInt8) {
    memcpy_from_mapping_impl<c10::quint8>(src_mapping, dst);
  } else if (dst.dtype() == c10::kQInt8) {
    memcpy_from_mapping_impl<c10::qint8>(src_mapping, dst);
  } else if (dst.dtype() == c10::kQInt32) {
    memcpy_from_mapping_impl<c10::qint32>(src_mapping, dst);
  } else if (dst.dtype() == c10::kBool) {
    memcpy_from_mapping_bool(src_mapping, dst);
  } else {
    TORCH_CHECK(
        false,
        "Invalid Data Type: expected c10::kQInt32, c10::kQInt8, c10::kQUInt8,",
        " c10::kBool, at::kHalf, or at::kFloat but got ",
        dst.dtype());
  }
}

//
// CPU <-> GPU copy implementations (these functions use Transfer commands)
//

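// Uploads a CPU tensor into a Vulkan image texture using a transfer (copy)
// command: the data is repacked to NC4HW layout, written into a host-visible
// staging buffer, and then copied from the buffer into the texture on the GPU.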
void transfer_cpu_to_vulkan(const Tensor& src, vTensor& v_dst) {
  api::Context* const context = api::context();

  // Convert to dtype corresponding to the image format of the texture to
  // ensure that byte alignment is consistent when copying. In some cases
  // a 16 bit format will be used for at::kFloat.
  Tensor src_nc4hw =
      utils::nchw_to_nc4hw(src).to(convert_dtype(v_dst.texture_dtype()));

  api::StorageBuffer staging(context, v_dst.texture_dtype(), v_dst.gpu_numel());
  // Copy data into the staging buffer
  {
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);
    mapping.invalidate();

    memcpy_to_mapping(src_nc4hw, mapping);
  }

  api::PipelineBarrier pipeline_barrier{};
  utils::copy_buffer_to_vtensor(staging.buffer(), v_dst, pipeline_barrier);
}

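// Downloads a Vulkan image texture into a CPU tensor using a transfer command:
// the texture is copied into a staging buffer, the submission is fenced and
// flushed, and the NC4HW staging data is converted back to NCHW on the CPU.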
void transfer_vulkan_to_cpu(vTensor& v_src, Tensor& dst) {
  api::Context* const context = api::context();

  // Temporary tensor to receive copied NC4HW data
  at::Tensor dst_tmp = utils::create_staging_tensor(v_src);

  api::StorageBuffer staging(context, v_src.texture_dtype(), v_src.gpu_numel());

  api::VulkanFence fence = context->fences().get_fence();

  {
    // Refer to the comment in submit_compute_job. When syncing with the GPU,
    // the context must not allow other threads to record dispatches into it
    // between calling vkQueueSubmit and flushing the context. Therefore,
    // cmd_mutex_ must be manually managed by the calling thread.
    std::unique_lock<std::mutex> context_lock(context->dispatch_lock());

    api::PipelineBarrier pipeline_barrier{};
    utils::copy_vtensor_to_buffer(
        v_src, staging.buffer(), pipeline_barrier, fence.get_submit_handle());

    fence.wait();

    context->flush();
    // cmd_mutex_ will be released when exiting this scope.
  }

  // Copy data from buffer back to CPU tensor.
  {
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::READ);
    mapping.invalidate();

    memcpy_from_mapping(mapping, dst_tmp);
  }

  context->fences().return_fence(fence);

  dst = utils::nc4hw_to_nchw(dst_tmp, v_src.sizes())
            .to(convert_dtype(v_src.dtype()));
}

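// Device-to-device copy between two Vulkan image textures, recorded as a
// single image-to-image transfer through the shared api::Context.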
static void transfer_vulkan_to_vulkan(vTensor& src, vTensor& dst) {
  api::Context* const context = api::context();

  api::PipelineBarrier pipeline_barrier{};

  context->submit_copy<api::VulkanImage, api::VulkanImage>(
      // pipeline barrier
      pipeline_barrier,
      // images
      src.image(pipeline_barrier, api::PipelineStage::TRANSFER),
      dst.image(
          pipeline_barrier,
          api::PipelineStage::TRANSFER,
          api::MemoryAccessType::WRITE),
      // copy details
      src.extents(),
      {0u, 0u, 0u},
      {0u, 0u, 0u},
      // fence handle
      VK_NULL_HANDLE);
}

//
// CPU <-> GPU copy implementations (these functions use compute shaders)
//

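// Uploads a CPU tensor to a Vulkan texture using the nchw_to_image compute
// shader: the source data is staged as 32 bit floats and the shader performs
// the NCHW -> image layout conversion on the GPU.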
void pack_cpu_to_vulkan(const Tensor& src, vTensor& dst) {
  api::Context* const context = api::context();

  // Ensure that src is contiguous in its memory format
  Tensor src_contig = src.contiguous(src.suggest_memory_format());

  // Note that the float data type has been enforced for the storage buffer
  // below. The reason for this is that the nchw_to_image and image_to_nchw
  // shaders which perform the transfer to/from an image texture expect a buffer
  // of floats as input. GLSL/Vulkan does not natively support 16 bit arithmetic
  // types, so for now storage buffers created for compute shaders must define
  // floats as their base data type.
  api::StorageBuffer staging(context, api::kFloat, dst.gpu_numel());
  {
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

    // If the dtype() of src is at::kHalf, then first convert it to 32 bit
    // float. This is required since the nchw_to_image shader uses a float
    // buffer as input (note that at::kFloat is used to create the StorageBuffer
    // above).
    if (src.dtype() == at::kHalf) {
      memcpy_to_mapping(src_contig.to(at::kFloat), mapping);
    } else {
      memcpy_to_mapping(src_contig, mapping);
    }
  }
  utils::pack_staging_to_vtensor(staging.buffer(), dst);
}

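// Downloads a Vulkan texture into a CPU tensor using the image_to_nchw compute
// shader, staging the result as 32 bit floats before copying it into the
// destination tensor (with a float round trip when dst is at::kHalf).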
void pack_vulkan_to_cpu(vTensor& src, Tensor& dst) {
  TORCH_CHECK(
      !src.is_quantized(),
      "Copy of vulkan quantized tensors to cpu is currently disabled!");
  api::Context* const context = api::context();

  // Refer to the comment in pack_cpu_to_vulkan for why at::kFloat is specified
  // for the storage buffer below.
  api::StorageBuffer staging(context, api::kFloat, src.gpu_numel());

  api::VulkanFence fence = context->fences().get_fence();

  {
    // Refer to the comment in submit_compute_job. When syncing with the GPU,
    // the context must not allow other threads to record dispatches into it
    // between calling vkQueueSubmit and flushing the context. Therefore,
    // cmd_mutex_ must be manually managed by the calling thread.
    std::unique_lock<std::mutex> context_lock(context->dispatch_lock());

    bool submitted_to_gpu = utils::pack_vtensor_to_staging(
        src, staging.buffer(), fence.get_submit_handle());

    // Only wait on the fence if work was actually submitted to the GPU.
    // Otherwise, it will hang indefinitely.
    if (submitted_to_gpu) {
      fence.wait();
    }

    context->flush();
    // cmd_mutex_ will be released when exiting this scope.
  }

  // Copy data from buffer back to CPU tensor.
  {
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::READ);
    mapping.invalidate();

    // If the dtype() of dst is at::kHalf, then copy the data into a float
    // version of it first, similar to pack_cpu_to_vulkan().
    if (dst.dtype() == at::kHalf) {
      Tensor dst_float = dst.to(at::kFloat);
      memcpy_from_mapping(mapping, dst_float);
      dst = dst_float.to(at::kHalf);
    } else {
      memcpy_from_mapping(mapping, dst);
    }
  }

  context->fences().return_fence(fence);
}

//
// Copy op implementations
//

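// Entry point for copies where at least one tensor lives on the Vulkan device;
// routes to the Vulkan -> Vulkan, CPU -> Vulkan, or Vulkan -> CPU path based
// on the device types of dst and src. As an illustrative sketch (not part of
// this file), user code typically reaches these paths through the device
// conversion helpers, e.g.:
//
//   at::Tensor cpu_tensor = at::rand({1, 3, 8, 8});
//   at::Tensor vk_tensor = cpu_tensor.vulkan();  // CPU -> Vulkan copy
//   at::Tensor roundtrip = vk_tensor.cpu();      // Vulkan -> CPU copy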
Tensor& copy_(Tensor& dst, const Tensor& src) {
  // Check that sizes are equal
  TORCH_CHECK(
      dst.sizes() == src.sizes(), "Vulkan copy_: Tensor sizes are mismatched!");

  // X -> Vulkan
  if (at::kVulkan == dst.device().type()) {
    vTensor& v_self = convert(dst);

    // Vulkan -> Vulkan
    if (at::kVulkan == src.device().type()) {
      vTensor& v_src = convert(src);
      transfer_vulkan_to_vulkan(v_src, v_self);
    }
    // CPU -> Vulkan
    else {
      pack_cpu_to_vulkan(src, v_self);
    }
  }
  // Vulkan -> X
  else if (at::kVulkan == src.device().type()) {
    vTensor& v_src = convert(src);

    // Vulkan -> CPU
    if (dst.device().is_cpu()) {
      pack_vulkan_to_cpu(v_src, dst);
    } else {
      TORCH_CHECK(false, "Unsupported!");
    }
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Invalid code path taken! Either the source or the destination tensor "
        "was expected to be a Vulkan tensor! Incorrect dispatch?");
  }

  return dst;
}

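// Constructs a vTensor with the requested storage type and a matching GPU
// memory layout, then uploads the contents of the CPU tensor into it via the
// compute shader path.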
vTensor to_vulkan(at::Tensor& src, const api::StorageType storage_type) {
  TORCH_CHECK(
      src.device().type() == at::kCPU,
      "Vulkan to_vulkan(): input tensor must be a CPU tensor!");

  vTensor v_ret{
      api::context(),
      src.sizes().vec(),
      convert_dtype(src.scalar_type()),
      storage_type,
      get_gpu_memory_layout(storage_type, src.suggest_memory_format()),
  };

  ops::pack_cpu_to_vulkan(src, v_ret);

  return v_ret;
}

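// Allocates a CPU tensor whose memory format matches the vTensor's GPU memory
// layout (width packed -> contiguous, channels packed -> channels last) and
// downloads the texture contents into it.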
at::Tensor from_vulkan(vTensor& v_src) {
  at::TensorOptions opt(at::kCPU);
  opt = opt.dtype(convert_dtype(v_src.dtype()));

  c10::MemoryFormat v_src_memory_format;

  switch (v_src.gpu_memory_layout()) {
    case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
      v_src_memory_format = c10::MemoryFormat::Contiguous;
      break;
    case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
      v_src_memory_format = c10::MemoryFormat::ChannelsLast;
      break;
    default:
      TORCH_CHECK(false, "No corresponding memory format");
  }

  at::Tensor ret = at::empty(v_src.sizes(), opt).to(v_src_memory_format);
  ops::pack_vulkan_to_cpu(v_src, ret);
  return ret;
}

//
// VulkanImpl
//

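// Bridges this backend to the device-agnostic at::vulkan interface: the core
// library queries Vulkan availability and dispatches copies through
// VulkanImplInterface, and the static VulkanImplRegistrar below registers this
// implementation at static initialization time.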
struct VulkanImpl final : public at::vulkan::VulkanImplInterface {
  bool is_vulkan_available() const override {
    return api::available();
  }

  Tensor& vulkan_copy_(Tensor& self, const Tensor& src) const override {
    return vulkan::ops::copy_(self, src);
  }
};
static at::vulkan::VulkanImplRegistrar g_vulkan_impl(new VulkanImpl());

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at