// aten/src/ATen/native/vulkan/ops/Permute.cpp
#include <ATen/native/vulkan/ops/Common.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

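// Writes the permuted contents of input_arg into v_output by dispatching the
// permute_4d compute shader. Sizes and the dim mapping arrive as 4-element
// vectors; the caller pads smaller tensors to 4D beforehand.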
Tensor permute_4d(
    const Tensor& input_arg,
    const uvec4& in_size,
    const uvec4& out_size,
    const uvec4& out_dims,
    vTensor& v_output) {
  api::Context* const context = api::context();

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const vTensor& v_self = convert(input);

  uint32_t out_channels = out_size.data[1u];
  uint32_t in_channels = in_size.data[1u];

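  // Vulkan tensors pack four channels per texel, so hand the shader the
  // channel counts rounded up to the next multiple of 4.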
  uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
  uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);

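  // Uniform parameter block for the shader; fill0/fill1 pad each ivec3 out
  // to a 16-byte boundary to match the GLSL layout rules.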
  const struct Block final {
    ivec3 out_extents;
    int32_t fill0;
    ivec3 in_extents;
    int32_t fill1;
    uvec4 out_tensor_size;
    uvec4 in_tensor_size;
    uvec4 out_ndims;
    uvec2 ch_info;
  } block{
      api::utils::make_ivec3(v_output.extents()),
      0,
      api::utils::make_ivec3(v_self.extents()),
      0,
      out_size,
      in_size,
      out_dims,
      {out_c_aligned, in_c_aligned},
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

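  // Dispatch one invocation per output texel: the global work group spans
  // the output extents, with the local size derived from them.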
  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(permute_4d),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

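// Vulkan implementation of aten::permute. Validates dims, generalizes the
// tensor to 4D (leading dimensions padded with size 1), and hands off to
// permute_4d; an identity permutation returns the input unchanged.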
Tensor permute(const Tensor& self, IntArrayRef dims) {
  auto nDims = safe_downcast<uint32_t>(self.dim());
  TORCH_CHECK(
      dims.size() == (size_t)nDims, "number of dims don't match in permute");

  uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
  uvec4 out_dims{0u, 1u, 2u, 3u};

  auto oldSizes = self.sizes();
  DimVector newSizes(nDims);
  bool sameDims = true;
  std::vector<bool> seen(nDims);
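  // Walk the requested permutation: reject repeated dims, record the
  // permuted sizes, and build the padded 4D mappings with a (4 - nDims)
  // offset so smaller tensors occupy the trailing dimensions.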
  for (const auto i : c10::irange(nDims)) {
    auto dim = safe_downcast<uint32_t>(maybe_wrap_dim(dims[i], nDims));
    TORCH_CHECK(!seen[dim], "repeated dim in permute");
    seen[dim] = true;
    newSizes[i] = oldSizes[dim];
    if (dim != i) {
      sameDims = false;
    }
    // generalize into 4D tensor
    in_size.data[(4u - nDims) + i] = self.sizes()[i];
    out_size.data[(4u - nDims) + i] = self.sizes()[dim];
    out_dims.data[(4u - nDims) + i] = dim + (4u - nDims);
  }

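  // Identity permutation: nothing to shuffle, so return the input as-is.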
  if (sameDims) {
    return self;
  }

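  // Allocate the output texture with the permuted sizes and the input dtype.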
  IntArrayRef output_sizes(newSizes);
  vTensor v_output{
      api::context(),
      output_sizes.vec(),
      convert_dtype(self.scalar_type()),
  };

  return permute_4d(self, in_size, out_size, out_dims, v_output);
}

#ifdef USE_VULKAN_API

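// Registers permute above as the Vulkan backend's implementation of
// aten::permute. A minimal usage sketch, assuming a Vulkan-enabled build of
// PyTorch (the tensor shape and dim order here are illustrative only):
//
//   at::Tensor t = at::rand({2, 3, 4, 5});
//   at::Tensor out = t.vulkan().permute({0, 2, 3, 1}).cpu();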
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::permute"), TORCH_FN(permute));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at