#include <ATen/native/vulkan/ops/Common.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;

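// Records a Vulkan compute job that copies input_arg into v_output with its
// dimensions reordered. in_size/out_size are the input and output sizes
// generalized to 4D, and out_dims maps each output dimension to the source
// dimension it is read from.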
Tensor permute_4d(
    const Tensor& input_arg,
    const uvec4& in_size,
    const uvec4& out_size,
    const uvec4& out_dims,
    vTensor& v_output) {
  api::Context* const context = api::context();

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const vTensor& v_self = convert(input);

  uint32_t out_channels = out_size.data[1u];
  uint32_t in_channels = in_size.data[1u];

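  // Align the channel counts up to a multiple of 4: the Vulkan backend packs
  // four channels per texel, so the shader needs the padded channel counts to
  // index into the packed layout correctly.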
  uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u);
  uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u);

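  // Mirrors the uniform block consumed by the permute_4d shader; fill0/fill1
  // pad each ivec3 out to a 16-byte boundary so the struct matches the
  // shader's uniform field alignment.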
  const struct Block final {
    ivec3 out_extents;
    int32_t fill0;
    ivec3 in_extents;
    int32_t fill1;
    uvec4 out_tensor_size;
    uvec4 in_tensor_size;
    uvec4 out_ndims;
    uvec2 ch_info;
  } block{
      api::utils::make_ivec3(v_output.extents()),
      0,
      api::utils::make_ivec3(v_self.extents()),
      0,
      out_size,
      in_size,
      out_dims,
      {out_c_aligned, in_c_aligned},
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(permute_4d),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      v_output.extents(),
      // local work group size
      adaptive_work_group_size(v_output.extents()),
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::READ | api::MemoryAccessType::WRITE),
      v_self.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

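// Vulkan implementation of aten::permute. Sizes and dims are generalized to
// 4D (padded with leading 1s and identity dims) before being handed to the
// permute_4d shader; an identity permutation returns the input unchanged.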
Tensor permute(const Tensor& self, IntArrayRef dims) {
  auto nDims = safe_downcast<uint32_t>(self.dim());
  TORCH_CHECK(
      dims.size() == (size_t)nDims, "number of dims don't match in permute");

  uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u};
  uvec4 out_dims{0u, 1u, 2u, 3u};

  auto oldSizes = self.sizes();
  DimVector newSizes(nDims);
  bool sameDims = true;
  std::vector<bool> seen(nDims);
  for (const auto i : c10::irange(nDims)) {
    auto dim = safe_downcast<uint32_t>(maybe_wrap_dim(dims[i], nDims));
    TORCH_CHECK(!seen[dim], "repeated dim in permute");
    seen[dim] = true;
    newSizes[i] = oldSizes[dim];
    if (dim != i) {
      sameDims = false;
    }
    // generalize into 4D tensor
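    // e.g. a 3D input of sizes (s0, s1, s2) permuted with dims = (2, 0, 1)
    // yields in_size = {1, s0, s1, s2}, out_size = {1, s2, s0, s1} and
    // out_dims = {0, 3, 1, 2}.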
    in_size.data[(4u - nDims) + i] = self.sizes()[i];
    out_size.data[(4u - nDims) + i] = self.sizes()[dim];
    out_dims.data[(4u - nDims) + i] = dim + (4u - nDims);
  }

  if (sameDims) {
    return self;
  }

  IntArrayRef output_sizes(newSizes);
  vTensor v_output{
      api::context(),
      output_sizes.vec(),
      convert_dtype(self.scalar_type()),
  };

  return permute_4d(self, in_size, out_size, out_dims, v_output);
}

#ifdef USE_VULKAN_API

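// Register the Vulkan kernel so that aten::permute dispatches here for
// tensors on the Vulkan backend.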
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::permute"), TORCH_FN(permute));
}

#endif /* USE_VULKAN_API */

} // namespace
} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at