// aten/src/ATen/native/vulkan/ops/Transpose.cpp
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

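// Dispatches the permute_4d compute shader to copy input_arg into v_output
// with its dimensions reordered according to out_dims. All sizes are passed
// as 4-element vectors (uvec4); tensors with fewer than four dimensions are
// expected to have been padded with leading 1s by the caller (see
// transpose() below).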
Tensor transpose_4d(
    const Tensor& input_arg,
    const uvec4& in_size,
    const uvec4& out_size,
    const uvec4& out_dims,
    vTensor& v_output) {
  api::Context* const context = api::context();

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const vTensor& v_self = convert(input);

  uint32_t out_channels = out_size.data[1u];
  uint32_t in_channels = in_size.data[1u];

  uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
  uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);

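  // Note: the shader's uniform block lays each vec3 out on a 16-byte
  // boundary (std140-style alignment), so each ivec3 below is followed by an
  // explicit int32_t pad to keep this host-side struct byte-compatible with
  // the GLSL layout.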
  const struct Block final {
    ivec3 out_extents;
    int32_t fill0;
    ivec3 in_extents;
    int32_t fill1;
    uvec4 out_tensor_size;
    uvec4 in_tensor_size;
    uvec4 out_ndims;
    uvec2 ch_info;
  } block{
      api::utils::make_ivec3(v_output.extents()),
      0,
      api::utils::make_ivec3(v_self.extents()),
      0,
      out_size,
      in_size,
      out_dims,
      {out_c_aligned, in_c_aligned},
  };

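  // The block is uploaded as a uniform buffer so the shader can read the
  // extents, tensor sizes, and permutation order of both tensors.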
  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(permute_4d),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

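// Implements aten::transpose.int for the Vulkan backend. The input is
// generalized to 4-D, the sizes at the two (wrapped) indices are swapped,
// and the work is delegated to transpose_4d. Unlike a CPU transpose, which
// returns a view, this materializes the permuted result into a new tensor.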
Tensor transpose(const Tensor& self, int64_t index0, int64_t index1) {
  TORCH_CHECK(
      self.dim() <= 4,
      "Vulkan transpose only supports tensors <= 4 dimensions");

  auto nDims = safe_downcast<uint32_t>(self.dim());
  uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
  uvec4 out_dims{0u, 1u, 2u, 3u};

  auto oldSizes = self.sizes();
  DimVector newSizes(nDims);
  auto new_index0 = safe_downcast<uint32_t>(maybe_wrap_dim(index0, nDims));
  auto new_index1 = safe_downcast<uint32_t>(maybe_wrap_dim(index1, nDims));
  if (new_index0 == new_index1) {
    return self.detach();
  }

  // generalize the input and output to 4-D, e.g. a 3-D input of shape
  // [2, 3, 4] is padded at the batch dim, so the input becomes 4-D with
  // in_size = [1, 2, 3, 4]
  for (const auto i : c10::irange(nDims)) {
    in_size.data[(4u - nDims) + i] = self.sizes()[i];
    out_size.data[(4u - nDims) + i] = self.sizes()[i];
    newSizes[i] = oldSizes[i];
  }

  // get the size of the output by swapping the input sizes at index0 and
  // index1. Continuing with the example above, if index0 = 0 and index1 = 2,
  // then the output is of size out_size = [1, 4, 3, 2].
  // Note: indices are shifted by (4u - nDims) since the input is generalized
  // to 4-D.
  out_size.data[(4u - nDims) + new_index0] =
      in_size.data[(4u - nDims) + new_index1];
  out_size.data[(4u - nDims) + new_index1] =
      in_size.data[(4u - nDims) + new_index0];

  // get the desired ordering of dimensions; again, indices are shifted by
  // (4u - nDims). Using the example above, out_dims = [0, 3, 2, 1]
  auto temp_dim = out_dims.data[(4u - nDims) + new_index0];
  out_dims.data[(4u - nDims) + new_index0] =
      out_dims.data[(4u - nDims) + new_index1];
  out_dims.data[(4u - nDims) + new_index1] = temp_dim;

  // get the size of the output by swapping sizes of the input. Continuing
  // with the example, newSizes = [1, 4, 3, 2]
  newSizes[new_index0] = oldSizes[new_index1];
  newSizes[new_index1] = oldSizes[new_index0];

  IntArrayRef output_size(newSizes);
  vTensor v_output{
      api::context(),
      output_size.vec(),
      convert_dtype(self.scalar_type()),
  };

  return transpose_4d(self, in_size, out_size, out_dims, v_output);
}

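// Implements aten::t for the Vulkan backend. t() is only defined for
// tensors of up to 2 dimensions: for 0-D and 1-D tensors it is a no-op,
// and for 2-D tensors it swaps dims 0 and 1.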
Tensor t(const Tensor& self) {
  TORCH_CHECK(self.dim() <= 2, "t() only supports tensors <= 2 dimensions");
  return transpose(self.detach(), 0, self.dim() < 2 ? 0 : 1);
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::t"), TORCH_FN(t));
  m.impl(TORCH_SELECTIVE_NAME("aten::transpose.int"), TORCH_FN(transpose));
}

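// Example host-side usage (a hypothetical sketch; assumes a Vulkan-capable
// device and a build with USE_VULKAN_API enabled):
//
//   at::Tensor x = at::rand({2, 3, 4});
//   at::Tensor y = at::transpose(x.vulkan(), 0, 2); // dispatches here
//   at::Tensor y_cpu = y.cpu(); // y_cpu.sizes() == [4, 3, 2]
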
#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at