1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.h>
12
13 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
14 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
15
16 namespace vkcompute {
17
resize_repeat_interleave_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)18 void resize_repeat_interleave_node(
19 ComputeGraph* graph,
20 const std::vector<ArgGroup>& args,
21 const std::vector<ValueRef>& extra_args) {
22 (void)extra_args;
23 vTensorPtr out = graph->get_tensor(args[0].refs[0]);
24 vTensorPtr in = graph->get_tensor(args[1].refs[0]);
25
26 const int64_t nrepeats = graph->extract_scalar<int64_t>(extra_args[0]);
27 int64_t repeat_dim = graph->extract_scalar<int64_t>(extra_args[1]);
28
29 std::vector<int64_t> new_sizes = in->sizes();
30 repeat_dim = normalize(repeat_dim, new_sizes.size());
31 new_sizes.at(repeat_dim) *= nrepeats;
32
33 out->virtual_resize(new_sizes);
34 }
35
add_repeat_interleave_node(ComputeGraph & graph,const ValueRef in,const ValueRef num_repeats,const ValueRef dim,const ValueRef out)36 void add_repeat_interleave_node(
37 ComputeGraph& graph,
38 const ValueRef in,
39 const ValueRef num_repeats,
40 const ValueRef dim,
41 const ValueRef out) {
42 const int32_t nrepeats = graph.extract_scalar<int32_t>(num_repeats);
43 const int32_t repeat_dim =
44 graph.extract_whcn_dim<int32_t>(dim, graph.dim_of(in));
45
46 VK_CHECK_COND(repeat_dim != graph.packed_dim_of(out));
47 VK_CHECK_COND(repeat_dim != graph.packed_dim_of(in));
48
49 std::string kernel_name = "repeat_interleave";
50 add_dtype_suffix(kernel_name, graph.dtype_of(out));
51
52 const utils::uvec3 global_wg_size = graph.logical_limits_of(in);
53 const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
54
55 graph.execute_nodes().emplace_back(new DispatchNode(
56 graph,
57 // Shader
58 VK_KERNEL_FROM_STR(kernel_name),
59 // Workgroup sizes
60 global_wg_size,
61 local_wg_size,
62 // Inputs and Outputs
63 {{out, vkapi::MemoryAccessType::WRITE},
64 {in, vkapi::MemoryAccessType::READ}},
65 // Parameter buffers
66 {graph.logical_limits_ubo(in)},
67 // Specialization Constants
68 {graph.hashed_layout_of(out),
69 graph.hashed_layout_of(in),
70 nrepeats,
71 repeat_dim},
72 // Resizing Logic
73 resize_repeat_interleave_node,
74 {num_repeats, dim}));
75 }
76
repeat_interleave(ComputeGraph & graph,const std::vector<ValueRef> & args)77 void repeat_interleave(ComputeGraph& graph, const std::vector<ValueRef>& args) {
78 int args_i = 0;
79 const ValueRef in = args[args_i++];
80 const ValueRef num_repeats = args[args_i++];
81 const ValueRef dim = args[args_i++];
82 const ValueRef output_size = args[args_i++];
83 const ValueRef out = args[args_i++];
84
85 // Output size is not used in the kernel
86 (void)output_size;
87
88 add_repeat_interleave_node(graph, in, num_repeats, dim, out);
89 }
90
// Register the Vulkan implementation for the corresponding ATen operator
// so the graph builder can resolve it by name.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.repeat_interleave.self_int, repeat_interleave);
}
94
95 } // namespace vkcompute
96