// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/impl/Packing.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/native/vulkan/api/api.h>
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 namespace at {
6 namespace native {
7 namespace vulkan {
8 namespace packing {
9 
10 api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst);
11 api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src);
12 
13 void record_nchw_to_image_op(
14     api::Context* const context,
15     api::ShaderInfo& compute_shader,
16     api::VulkanBuffer& src_buffer,
17     vTensor& v_dst,
18     api::PipelineBarrier pipeline_barrier,
19     VkFence fence_handle);
20 
21 bool record_image_to_nchw_op(
22     api::Context* const context,
23     api::ShaderInfo& compute_shader,
24     vTensor& v_src,
25     api::VulkanBuffer& dst_buffer,
26     api::PipelineBarrier pipeline_barrier,
27     VkFence fence_handle);
28 
29 void record_nchw_to_buffer_op(
30     api::Context* const context,
31     api::VulkanBuffer& src_buffer,
32     vTensor& v_dst,
33     api::PipelineBarrier pipeline_barrier,
34     VkFence fence_handle);
35 
36 bool record_buffer_to_nchw_op(
37     api::Context* const context,
38     vTensor& v_src,
39     api::VulkanBuffer& dst_buffer,
40     api::PipelineBarrier pipeline_barrier,
41     VkFence fence_handle);
42 
43 vTensor convert_image_channels_packed_to_height_packed(const vTensor& v_input);
44 
45 vTensor convert_image_channels_packed_to_width_packed(const vTensor& v_input);
46 
47 } // namespace packing
48 } // namespace vulkan
49 } // namespace native
50 } // namespace at
51