xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

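// Validate that the two inputs and the output use the same packed dimension
// and that the output sizes match the broadcasted sizes of the inputs.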
void check_binary_op_args(
    const api::vTensor& self,
    const api::vTensor& other,
    const api::vTensor& out) {
  VK_CHECK_COND(check_same_packed_dim(self, other, out));
  std::vector<int64_t> broadcasted_sizes =
      calculate_broadcasted_output_size(self, other);
  VK_CHECK_COND(out.sizes() == broadcasted_sizes);
}

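// Resize callback for the DispatchNode created below: recomputes the
// broadcasted output sizes from the current input sizes and virtually resizes
// the output tensor to match.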
void resize_binary_op_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);

  // TODO(T183442143): Verify tensors are broadcastable.
  vTensorPtr self = graph->get_tensor(args[1].refs[0]);
  vTensorPtr other = graph->get_tensor(args[1].refs[1]);

  std::vector<int64_t> new_out_sizes =
      calculate_broadcasted_output_size(*self, *other);

  out->virtual_resize(new_out_sizes);
}

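// Build and append the compute node for an elementwise binary op. The inputs
// are prepacked (using the output as the layout reference), the shader is
// selected from the op name and output dtype ("binary_" + op_name + dtype
// suffix), and broadcast parameters plus alpha are passed via param buffers.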
void add_binary_op_node(
    ComputeGraph& graph,
    const ValueRef in1,
    const ValueRef in2,
    const ValueRef alpha,
    const ValueRef out,
    const std::string& op_name) {
  ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
  ValueRef arg2 = prepack_standard_like(graph, in2, out, true);

  vTensorPtr t_in1 = graph.get_tensor(arg1);
  vTensorPtr t_in2 = graph.get_tensor(arg2);
  vTensorPtr t_out = graph.get_tensor(out);

  check_binary_op_args(*t_in1, *t_in2, *t_out);

  float alpha_val = 1.0f;
  // The string type is checked because floor_div passes an unused string
  // argument in place of alpha.
  if (is_valid(alpha) && !graph.val_is_string(alpha)) {
    alpha_val = graph.extract_scalar<float>(alpha);
  }

  const utils::ivec2 broadcast_params = create_broadcast_params(*t_in1, *t_in2);

  std::string kernel_name("binary_");
  kernel_name.reserve(kShaderNameReserve);
  kernel_name += op_name;
  add_dtype_suffix(kernel_name, *t_out);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{arg1, arg2}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {t_out->sizes_ubo(),
       t_in1->sizes_ubo(),
       t_in2->sizes_ubo(),
       graph.create_params_buffer(broadcast_params),
       graph.create_params_buffer(alpha_val)},
      // Specialization Constants
      {t_out->hashed_layout(), t_in1->hashed_layout(), t_in2->hashed_layout()},
      // Resizing Logic
      resize_binary_op_node,
      {}));
}

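// Convenience macros for defining the operator entry points. The WITH_ALPHA
// variant expects args = {in1, in2, alpha, out}; the plain variant expects
// args = {in1, in2, out} and passes kDummyValueRef in place of alpha.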
#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name)                          \
  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
    return add_binary_op_node(                                           \
        graph, args[0], args[1], args[2], args[3], #op_name);            \
  }

#define DEFINE_BINARY_OP_FN(op_name)                                     \
  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
    return add_binary_op_node(                                           \
        graph, args[0], args[1], kDummyValueRef, args[2], #op_name);     \
  }

DEFINE_BINARY_OP_WITH_ALPHA_FN(add);
DEFINE_BINARY_OP_WITH_ALPHA_FN(sub);

// Floor div does not have an alpha, but a string argument (which is unused) is
// passed in at the same location as the alpha argument in the other ops.
DEFINE_BINARY_OP_WITH_ALPHA_FN(floor_divide);

DEFINE_BINARY_OP_FN(mul);
DEFINE_BINARY_OP_FN(div);
DEFINE_BINARY_OP_FN(pow);
DEFINE_BINARY_OP_FN(minimum);

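// Register the entry points defined above against their ATen operator names.
// As a sketch, a hypothetical new elementwise op could be wired up the same
// way, assuming a matching "binary_<op>" compute shader exists: define it with
// DEFINE_BINARY_OP_FN(maximum); above and add
// VK_REGISTER_OP(aten.maximum.default, maximum); inside the block below.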
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.add.Tensor, add);
  VK_REGISTER_OP(aten.sub.Tensor, sub);
  VK_REGISTER_OP(aten.mul.Tensor, mul);
  VK_REGISTER_OP(aten.div.Tensor, div);
  VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
  VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
  VK_REGISTER_OP(aten.minimum.default, minimum);
}

} // namespace vkcompute