1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Transpose.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <algorithm>
#include <memory>
18
19 namespace vkcompute {
20
resize_transpose_view_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)21 void resize_transpose_view_node(
22 ComputeGraph* graph,
23 const std::vector<ArgGroup>& args,
24 const std::vector<ValueRef>& extra_args) {
25 (void)args;
26 vTensorPtr out = graph->get_tensor(extra_args[0]);
27 vTensorPtr in = graph->get_tensor(extra_args[1]);
28
29 const int64_t dim0 = graph->extract_scalar<int64_t>(extra_args[2]);
30 const int64_t dim1 = graph->extract_scalar<int64_t>(extra_args[3]);
31
32 std::vector<int64_t> new_sizes = in->sizes();
33 // Transpose the resized input sizes
34 std::iter_swap(new_sizes.begin() + dim0, new_sizes.begin() + dim1);
35 out->virtual_resize(new_sizes);
36 }
37
/*
 * Validates the arguments of a transpose view before the node is built.
 *
 * Checks that out_ref was created as a view of in_ref. Also checks that both
 * dim0 and dim1 lie in [0, ndim) of the input. Negative (wrap-around) dims
 * are rejected here rather than normalized. Failures are reported through
 * VK_CHECK_COND.
 *
 * @param graph    graph that owns both values.
 * @param in_ref   input tensor.
 * @param dim0     first dim to swap (already extracted from its scalar ref).
 * @param dim1     second dim to swap (already extracted from its scalar ref).
 * @param out_ref  output tensor; expected to be a view of in_ref.
 */
void check_transpose_view_args(
    ComputeGraph& graph,
    ValueRef in_ref,
    const int64_t dim0,
    const int64_t dim1,
    ValueRef out_ref) {
  // This op only reinterprets the input's metadata (see the virtual_clone /
  // virtual_transpose calls in add_transpose_view_node), so the output must
  // be a view of the input.
  VK_CHECK_COND(
      graph.val_is_view_of(out_ref, in_ref),
      "output tensor must be a view of the input tensor");

  const int64_t in_ndim = graph.dim_of(in_ref);
  // resize_transpose_view_node offsets into the size vector with these dims
  // directly, so they must already be valid, non-negative indices.
  VK_CHECK_COND(
      dim0 >= 0 && dim0 < in_ndim, "dim0 is not in the range of [0, in_ndim)");
  VK_CHECK_COND(
      dim1 >= 0 && dim1 < in_ndim, "dim1 is not in the range of [0, in_ndim)");
}
54
add_transpose_view_node(ComputeGraph & graph,ValueRef input_ref,ValueRef dim0_ref,ValueRef dim1_ref,ValueRef out_ref)55 void add_transpose_view_node(
56 ComputeGraph& graph,
57 ValueRef input_ref,
58 ValueRef dim0_ref,
59 ValueRef dim1_ref,
60 ValueRef out_ref) {
61 const int64_t dim0 = graph.extract_scalar<int64_t>(dim0_ref);
62 const int64_t dim1 = graph.extract_scalar<int64_t>(dim1_ref);
63
64 check_transpose_view_args(graph, input_ref, dim0, dim1, out_ref);
65 const vTensorPtr in = graph.get_tensor(input_ref);
66 graph.get_tensor(out_ref)->virtual_clone(*in);
67 graph.get_tensor(out_ref)->virtual_transpose(dim0, dim1);
68
69 graph.execute_nodes().emplace_back(new ExecuteNode(
70 resize_transpose_view_node, {out_ref, input_ref, dim0_ref, dim1_ref}));
71 }
72
transpose(ComputeGraph & graph,const std::vector<ValueRef> & args)73 void transpose(ComputeGraph& graph, const std::vector<ValueRef>& args) {
74 const ValueRef out = args[3];
75 return add_transpose_view_node(
76 graph,
77 args[0], // input
78 args[1], // dim0
79 args[2], // dim1
80 out);
81 }
82
// Associates the aten.transpose.int operator name with the transpose()
// builder defined in this file. Presumably the operator registry consults
// this when constructing the Vulkan graph — confirm in OperatorRegistry.h.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.transpose.int, transpose);
}
86
87 } // namespace vkcompute
88