xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/Transpose.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10 
11 #include <executorch/backends/vulkan/runtime/graph/Logging.h>
12 
13 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Transpose.h>
14 
15 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
16 
17 #include <algorithm>
18 
19 namespace vkcompute {
20 
resize_transpose_view_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)21 void resize_transpose_view_node(
22     ComputeGraph* graph,
23     const std::vector<ArgGroup>& args,
24     const std::vector<ValueRef>& extra_args) {
25   (void)args;
26   vTensorPtr out = graph->get_tensor(extra_args[0]);
27   vTensorPtr in = graph->get_tensor(extra_args[1]);
28 
29   const int64_t dim0 = graph->extract_scalar<int64_t>(extra_args[2]);
30   const int64_t dim1 = graph->extract_scalar<int64_t>(extra_args[3]);
31 
32   std::vector<int64_t> new_sizes = in->sizes();
33   // Transpose the resized input sizes
34   std::iter_swap(new_sizes.begin() + dim0, new_sizes.begin() + dim1);
35   out->virtual_resize(new_sizes);
36 }
37 
check_transpose_view_args(ComputeGraph & graph,ValueRef in_ref,const int64_t dim0,const int64_t dim1,ValueRef out_ref)38 void check_transpose_view_args(
39     ComputeGraph& graph,
40     ValueRef in_ref,
41     const int64_t dim0,
42     const int64_t dim1,
43     ValueRef out_ref) {
44   VK_CHECK_COND(
45       graph.val_is_view_of(out_ref, in_ref),
46       "output tensor must be a view of the input tensor");
47 
48   const int64_t in_ndim = graph.dim_of(in_ref);
49   VK_CHECK_COND(
50       dim0 >= 0 && dim0 < in_ndim, "dim0 is not in the range of [0, in_ndim)");
51   VK_CHECK_COND(
52       dim1 >= 0 && dim1 < in_ndim, "dim1 is not in the range of [0, in_ndim)");
53 }
54 
add_transpose_view_node(ComputeGraph & graph,ValueRef input_ref,ValueRef dim0_ref,ValueRef dim1_ref,ValueRef out_ref)55 void add_transpose_view_node(
56     ComputeGraph& graph,
57     ValueRef input_ref,
58     ValueRef dim0_ref,
59     ValueRef dim1_ref,
60     ValueRef out_ref) {
61   const int64_t dim0 = graph.extract_scalar<int64_t>(dim0_ref);
62   const int64_t dim1 = graph.extract_scalar<int64_t>(dim1_ref);
63 
64   check_transpose_view_args(graph, input_ref, dim0, dim1, out_ref);
65   const vTensorPtr in = graph.get_tensor(input_ref);
66   graph.get_tensor(out_ref)->virtual_clone(*in);
67   graph.get_tensor(out_ref)->virtual_transpose(dim0, dim1);
68 
69   graph.execute_nodes().emplace_back(new ExecuteNode(
70       resize_transpose_view_node, {out_ref, input_ref, dim0_ref, dim1_ref}));
71 }
72 
transpose(ComputeGraph & graph,const std::vector<ValueRef> & args)73 void transpose(ComputeGraph& graph, const std::vector<ValueRef>& args) {
74   const ValueRef out = args[3];
75   return add_transpose_view_node(
76       graph,
77       args[0], // input
78       args[1], // dim0
79       args[2], // dim1
80       out);
81 }
82 
83 REGISTER_OPERATORS {
84   VK_REGISTER_OP(aten.transpose.int, transpose);
85 }
86 
87 } // namespace vkcompute
88