xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10 
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12 
13 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15 
16 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17 
18 namespace vkcompute {
19 
calc_out_mean_sizes(api::vTensor & self,int64_t normalized_shape_dim)20 std::vector<int64_t> calc_out_mean_sizes(
21     api::vTensor& self,
22     int64_t normalized_shape_dim) {
23   std::vector<int64_t> output_size = self.sizes();
24   int64_t self_dim = self.sizes().size();
25   for (int64_t i = 0; i < normalized_shape_dim; ++i) {
26     output_size.at(self_dim - i - 1) = 1;
27   }
28   return output_size;
29 }
30 
resize_native_layer_norm_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)31 void resize_native_layer_norm_node(
32     ComputeGraph* graph,
33     const std::vector<ArgGroup>& args,
34     const std::vector<ValueRef>& extra_args) {
35   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
36   vTensorPtr mean = graph->get_tensor(args[0].refs[1]);
37   vTensorPtr rstd = graph->get_tensor(args[0].refs[2]);
38   vTensorPtr in = graph->get_tensor(args[1].refs[0]);
39   std::vector<int64_t> in_sizes = in->sizes();
40 
41   const auto normalized_shape_dim = graph->get_int_list(extra_args[0])->size();
42 
43   std::vector<int64_t> mean_size =
44       calc_out_mean_sizes(*in, normalized_shape_dim);
45 
46   out->virtual_resize(in_sizes);
47   mean->virtual_resize(mean_size);
48   rstd->virtual_resize(mean_size);
49 }
50 
check_args(const api::vTensor & in,const api::vTensor & out)51 void check_args(const api::vTensor& in, const api::vTensor& out) {
52   VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
53   VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
54 }
55 
add_native_layer_norm_node(ComputeGraph & graph,const ValueRef in,const ValueRef normalized_shape,const ValueRef weight_data,const ValueRef bias_data,const ValueRef eps,const ValueRef out)56 void add_native_layer_norm_node(
57     ComputeGraph& graph,
58     const ValueRef in,
59     const ValueRef normalized_shape,
60     const ValueRef weight_data,
61     const ValueRef bias_data,
62     const ValueRef eps,
63     const ValueRef out) {
64   const auto normalized_shape_dim =
65       graph.get_int_list(normalized_shape)->size();
66   if (normalized_shape_dim > 1) {
67     VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
68   }
69 
70   if (graph.val_is_none(weight_data)) {
71     VK_THROW("native_layer_norm requires weight to be non-None");
72   }
73 
74   if (graph.val_is_none(bias_data)) {
75     VK_THROW("native_layer_norm requires bias to be non-None");
76   }
77 
78   ValueRef arg_weight = prepack_standard_like(graph, weight_data, in);
79   ValueRef arg_bias = prepack_standard_like(graph, bias_data, in);
80 
81   const auto out_val = graph.get_value_list(out);
82   vTensorPtr t_out = graph.get_tensor(out_val->at(0));
83   vTensorPtr t_mean = graph.get_tensor(out_val->at(1));
84   vTensorPtr t_input = graph.get_tensor(in);
85   float epsilon = graph.extract_scalar<float>(eps);
86 
87   check_args(*t_input, *t_out);
88 
89   std::vector<int64_t> in_sizes = t_input->sizes();
90 
91   utils::uvec3 global_size = t_mean->logical_limits();
92   utils::uvec3 local_size = adaptive_work_group_size(global_size);
93 
94   std::string kernel_name("native_layer_norm");
95   kernel_name.reserve(kShaderNameReserve);
96 
97   add_dtype_suffix(kernel_name, *t_out);
98 
99   graph.execute_nodes().emplace_back(new DispatchNode(
100       graph,
101       VK_KERNEL_FROM_STR(kernel_name),
102       global_size,
103       local_size,
104       // Inputs and Outputs
105       {{{out_val->at(0), out_val->at(1), out_val->at(2)},
106         vkapi::MemoryAccessType::WRITE},
107        {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
108       // Shader params buffers
109       {
110           t_out->logical_limits_ubo(),
111           t_out->sizes_ubo(),
112           graph.create_params_buffer(epsilon),
113       },
114       // Specialization Constants
115       {
116           t_input->hashed_layout(),
117           t_out->hashed_layout(),
118       },
119       // Resizing Logic
120       resize_native_layer_norm_node,
121       {normalized_shape}));
122 }
123 
native_layer_norm(ComputeGraph & graph,const std::vector<ValueRef> & args)124 void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
125   return add_native_layer_norm_node(
126       graph, args[0], args[1], args[2], args[3], args[4], args[5]);
127 }
128 
// Registers the Vulkan implementation above as the handler for the
// aten.native_layer_norm.default operator.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.native_layer_norm.default, native_layer_norm);
}
132 
133 } // namespace vkcompute
134