1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12 #include <executorch/backends/vulkan/runtime/graph/Logging.h>
13
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
16 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17
18 namespace vkcompute {
19
check_args(const api::vTensor & t_in,int64_t dim,int64_t index,const api::vTensor & t_out)20 void check_args(
21 const api::vTensor& t_in,
22 int64_t dim,
23 int64_t index,
24 const api::vTensor& t_out) {
25 VK_CHECK_COND(check_packed_dim_is(t_in, WHCN::kChannelsDim));
26 VK_CHECK_COND(check_packed_dim_is(t_out, WHCN::kChannelsDim));
27
28 const int64_t in_dim = t_in.dim();
29 VK_CHECK_COND(
30 in_dim == 3 || in_dim == 4,
31 "Vulkan select only support 3d or 4d tensors!");
32
33 const int64_t in_size = t_in.size(dim);
34
35 if (index < -in_size || index >= in_size) {
36 VK_CHECK_COND(
37 false,
38 "select(): index ",
39 index,
40 " t_outof range for tensor of size ",
41 in_size,
42 " at dimension ",
43 dim);
44 }
45 }
46
add_select_int_node(ComputeGraph & graph,const ValueRef in,const ValueRef dim_ref,const ValueRef index_ref,const ValueRef out)47 void add_select_int_node(
48 ComputeGraph& graph,
49 const ValueRef in,
50 const ValueRef dim_ref,
51 const ValueRef index_ref,
52 const ValueRef out) {
53 vTensorPtr t_in = graph.get_tensor(in);
54 vTensorPtr t_out = graph.get_tensor(out);
55 int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
56 int64_t index = graph.extract_scalar<int64_t>(index_ref);
57
58 check_args(*t_in, dim, index, *t_out);
59
60 const int64_t in_size = t_in->size(dim);
61
62 if (index < 0) {
63 index += in_size;
64 }
65
66 std::string kernel_name;
67
68 // for 3d tensors, these values are not used by the shader.
69 int32_t num_texel_per_batch = 1;
70 int32_t num_batches = 1;
71
72 int64_t in_dim = t_in->dim();
73 if (in_dim == 3) {
74 if (dim == 0) {
75 kernel_name = "select_channel_3d";
76 } else if (dim == 1) {
77 kernel_name = "select_height_3d";
78 } else if (dim == 2) {
79 kernel_name = "select_width_3d";
80 } else {
81 VK_CHECK_COND(
82 false, "Unexpected dim value=", dim, "for the input 3d tensor");
83 }
84 } else { // self.dim() == 4
85 num_texel_per_batch =
86 static_cast<int32_t>(std::ceil(static_cast<float>(t_in->size(1)) / 4));
87 num_batches = t_in->size(0);
88 if (dim == 0) {
89 kernel_name = "select_batch_4d";
90 } else if (dim == 1) {
91 kernel_name = "select_channel_4d";
92 } else if (dim == 2) {
93 kernel_name = "select_height_4d";
94 } else if (dim == 3) {
95 kernel_name = "select_width_4d";
96 } else {
97 VK_CHECK_COND(
98 false, "Unexpected dim value=", dim, "for the input 4d tensor");
99 }
100 }
101
102 kernel_name.reserve(kShaderNameReserve);
103 add_dtype_suffix(kernel_name, *t_out);
104
105 // TODO: add resizing to support dynamic shapes.
106 graph.execute_nodes().emplace_back(new DispatchNode(
107 graph,
108 VK_KERNEL_FROM_STR(kernel_name),
109 graph.create_global_wg_size(out),
110 graph.create_local_wg_size(out),
111 // Inputs and Outputs
112 {{out, vkapi::MemoryAccessType::WRITE},
113 {in, vkapi::MemoryAccessType::READ}},
114 // Parameter buffers
115 {t_out->logical_limits_ubo(),
116 t_out->sizes_ubo(),
117 // TODO: num_batches and num_texel_per_batch are provided by
118 // t_out->sizes. Can change the following to reduce params
119 // created.
120 graph.create_params_buffer(
121 utils::make_ivec4({index, num_batches, num_texel_per_batch, 0}))},
122 // Specialization Constants
123 {}));
124 }
125
select_int(ComputeGraph & graph,const std::vector<ValueRef> & args)126 void select_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
127 return add_select_int_node(graph, args[0], args[1], args[2], args[3]);
128 }
129
130 REGISTER_OPERATORS {
131 VK_REGISTER_OP(aten.select.int, select_int);
132 VK_REGISTER_OP(aten.select_copy.int, select_int);
133 }
134
135 } // namespace vkcompute
136