xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

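// Resize callback for the repeat_interleave node: recomputes the output
// tensor's sizes from the current input sizes by scaling the (normalized)
// repeat dimension by the number of repeats.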
void resize_repeat_interleave_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);

  const int64_t nrepeats = graph->extract_scalar<int64_t>(extra_args[0]);
  int64_t repeat_dim = graph->extract_scalar<int64_t>(extra_args[1]);

  std::vector<int64_t> new_sizes = in->sizes();
  repeat_dim = normalize(repeat_dim, new_sizes.size());
  new_sizes.at(repeat_dim) *= nrepeats;

  out->virtual_resize(new_sizes);
}

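// Builds a DispatchNode that runs the repeat_interleave compute shader and
// appends it to the graph's execute nodes.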
void add_repeat_interleave_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef num_repeats,
    const ValueRef dim,
    const ValueRef out) {
  const int32_t nrepeats = graph.extract_scalar<int32_t>(num_repeats);
  const int32_t repeat_dim =
      graph.extract_whcn_dim<int32_t>(dim, graph.dim_of(in));

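  // The repeat dimension may not be the packed dimension of either the input
  // or the output tensor.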
  VK_CHECK_COND(repeat_dim != graph.packed_dim_of(out));
  VK_CHECK_COND(repeat_dim != graph.packed_dim_of(in));

  std::string kernel_name = "repeat_interleave";
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  const utils::uvec3 global_wg_size = graph.logical_limits_of(in);
  const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      // Shader
      VK_KERNEL_FROM_STR(kernel_name),
      // Workgroup sizes
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {in, vkapi::MemoryAccessType::READ}},
      // Parameter buffers
      {graph.logical_limits_ubo(in)},
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(in),
       nrepeats,
       repeat_dim},
      // Resizing Logic
      resize_repeat_interleave_node,
      {num_repeats, dim}));
}

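// Operator entry point: unpacks the argument list produced by the exported
// graph and delegates to add_repeat_interleave_node.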
void repeat_interleave(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  int args_i = 0;
  const ValueRef in = args[args_i++];
  const ValueRef num_repeats = args[args_i++];
  const ValueRef dim = args[args_i++];
  const ValueRef output_size = args[args_i++];
  const ValueRef out = args[args_i++];

  // Output size is not used in the kernel
  (void)output_size;

  add_repeat_interleave_node(graph, in, num_repeats, dim, out);
}

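// Register this implementation for the aten.repeat_interleave.self_int op.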
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.repeat_interleave.self_int, repeat_interleave);
}

} // namespace vkcompute