xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/Select.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10 
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12 #include <executorch/backends/vulkan/runtime/graph/Logging.h>
13 
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17 
18 namespace vkcompute {
19 
check_args(const api::vTensor & t_in,int64_t dim,int64_t index,const api::vTensor & t_out)20 void check_args(
21     const api::vTensor& t_in,
22     int64_t dim,
23     int64_t index,
24     const api::vTensor& t_out) {
25   VK_CHECK_COND(check_packed_dim_is(t_in, WHCN::kChannelsDim));
26   VK_CHECK_COND(check_packed_dim_is(t_out, WHCN::kChannelsDim));
27 
28   const int64_t in_dim = t_in.dim();
29   VK_CHECK_COND(
30       in_dim == 3 || in_dim == 4,
31       "Vulkan select only support 3d or 4d tensors!");
32 
33   const int64_t in_size = t_in.size(dim);
34 
35   if (index < -in_size || index >= in_size) {
36     VK_CHECK_COND(
37         false,
38         "select(): index ",
39         index,
40         " t_outof range for tensor of size ",
41         in_size,
42         " at dimension ",
43         dim);
44   }
45 }
46 
add_select_int_node(ComputeGraph & graph,const ValueRef in,const ValueRef dim_ref,const ValueRef index_ref,const ValueRef out)47 void add_select_int_node(
48     ComputeGraph& graph,
49     const ValueRef in,
50     const ValueRef dim_ref,
51     const ValueRef index_ref,
52     const ValueRef out) {
53   vTensorPtr t_in = graph.get_tensor(in);
54   vTensorPtr t_out = graph.get_tensor(out);
55   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
56   int64_t index = graph.extract_scalar<int64_t>(index_ref);
57 
58   check_args(*t_in, dim, index, *t_out);
59 
60   const int64_t in_size = t_in->size(dim);
61 
62   if (index < 0) {
63     index += in_size;
64   }
65 
66   std::string kernel_name;
67 
68   // for 3d tensors, these values are not used by the shader.
69   int32_t num_texel_per_batch = 1;
70   int32_t num_batches = 1;
71 
72   int64_t in_dim = t_in->dim();
73   if (in_dim == 3) {
74     if (dim == 0) {
75       kernel_name = "select_channel_3d";
76     } else if (dim == 1) {
77       kernel_name = "select_height_3d";
78     } else if (dim == 2) {
79       kernel_name = "select_width_3d";
80     } else {
81       VK_CHECK_COND(
82           false, "Unexpected dim value=", dim, "for the input 3d tensor");
83     }
84   } else { // self.dim() == 4
85     num_texel_per_batch =
86         static_cast<int32_t>(std::ceil(static_cast<float>(t_in->size(1)) / 4));
87     num_batches = t_in->size(0);
88     if (dim == 0) {
89       kernel_name = "select_batch_4d";
90     } else if (dim == 1) {
91       kernel_name = "select_channel_4d";
92     } else if (dim == 2) {
93       kernel_name = "select_height_4d";
94     } else if (dim == 3) {
95       kernel_name = "select_width_4d";
96     } else {
97       VK_CHECK_COND(
98           false, "Unexpected dim value=", dim, "for the input 4d tensor");
99     }
100   }
101 
102   kernel_name.reserve(kShaderNameReserve);
103   add_dtype_suffix(kernel_name, *t_out);
104 
105   // TODO: add resizing to support dynamic shapes.
106   graph.execute_nodes().emplace_back(new DispatchNode(
107       graph,
108       VK_KERNEL_FROM_STR(kernel_name),
109       graph.create_global_wg_size(out),
110       graph.create_local_wg_size(out),
111       // Inputs and Outputs
112       {{out, vkapi::MemoryAccessType::WRITE},
113        {in, vkapi::MemoryAccessType::READ}},
114       // Parameter buffers
115       {t_out->logical_limits_ubo(),
116        t_out->sizes_ubo(),
117        // TODO: num_batches and num_texel_per_batch are provided by
118        // t_out->sizes. Can change the following to reduce params
119        // created.
120        graph.create_params_buffer(
121            utils::make_ivec4({index, num_batches, num_texel_per_batch, 0}))},
122       // Specialization Constants
123       {}));
124 }
125 
select_int(ComputeGraph & graph,const std::vector<ValueRef> & args)126 void select_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
127   return add_select_int_node(graph, args[0], args[1], args[2], args[3]);
128 }
129 
130 REGISTER_OPERATORS {
131   VK_REGISTER_OP(aten.select.int, select_int);
132   VK_REGISTER_OP(aten.select_copy.int, select_int);
133 }
134 
135 } // namespace vkcompute
136