1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12
13 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
14 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
15
16 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
17
18 namespace vkcompute {
19
calc_out_mean_sizes(api::vTensor & self,int64_t normalized_shape_dim)20 std::vector<int64_t> calc_out_mean_sizes(
21 api::vTensor& self,
22 int64_t normalized_shape_dim) {
23 std::vector<int64_t> output_size = self.sizes();
24 int64_t self_dim = self.sizes().size();
25 for (int64_t i = 0; i < normalized_shape_dim; ++i) {
26 output_size.at(self_dim - i - 1) = 1;
27 }
28 return output_size;
29 }
30
resize_native_layer_norm_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)31 void resize_native_layer_norm_node(
32 ComputeGraph* graph,
33 const std::vector<ArgGroup>& args,
34 const std::vector<ValueRef>& extra_args) {
35 vTensorPtr out = graph->get_tensor(args[0].refs[0]);
36 vTensorPtr mean = graph->get_tensor(args[0].refs[1]);
37 vTensorPtr rstd = graph->get_tensor(args[0].refs[2]);
38 vTensorPtr in = graph->get_tensor(args[1].refs[0]);
39 std::vector<int64_t> in_sizes = in->sizes();
40
41 const auto normalized_shape_dim = graph->get_int_list(extra_args[0])->size();
42
43 std::vector<int64_t> mean_size =
44 calc_out_mean_sizes(*in, normalized_shape_dim);
45
46 out->virtual_resize(in_sizes);
47 mean->virtual_resize(mean_size);
48 rstd->virtual_resize(mean_size);
49 }
50
// Validates that both the input and output tensors use a channels-packed
// memory layout — the only layout the native_layer_norm shader handles.
// Throws (via VK_CHECK_COND) if either check fails.
void check_args(const api::vTensor& in, const api::vTensor& out) {
  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
}
55
add_native_layer_norm_node(ComputeGraph & graph,const ValueRef in,const ValueRef normalized_shape,const ValueRef weight_data,const ValueRef bias_data,const ValueRef eps,const ValueRef out)56 void add_native_layer_norm_node(
57 ComputeGraph& graph,
58 const ValueRef in,
59 const ValueRef normalized_shape,
60 const ValueRef weight_data,
61 const ValueRef bias_data,
62 const ValueRef eps,
63 const ValueRef out) {
64 const auto normalized_shape_dim =
65 graph.get_int_list(normalized_shape)->size();
66 if (normalized_shape_dim > 1) {
67 VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
68 }
69
70 if (graph.val_is_none(weight_data)) {
71 VK_THROW("native_layer_norm requires weight to be non-None");
72 }
73
74 if (graph.val_is_none(bias_data)) {
75 VK_THROW("native_layer_norm requires bias to be non-None");
76 }
77
78 ValueRef arg_weight = prepack_standard_like(graph, weight_data, in);
79 ValueRef arg_bias = prepack_standard_like(graph, bias_data, in);
80
81 const auto out_val = graph.get_value_list(out);
82 vTensorPtr t_out = graph.get_tensor(out_val->at(0));
83 vTensorPtr t_mean = graph.get_tensor(out_val->at(1));
84 vTensorPtr t_input = graph.get_tensor(in);
85 float epsilon = graph.extract_scalar<float>(eps);
86
87 check_args(*t_input, *t_out);
88
89 std::vector<int64_t> in_sizes = t_input->sizes();
90
91 utils::uvec3 global_size = t_mean->logical_limits();
92 utils::uvec3 local_size = adaptive_work_group_size(global_size);
93
94 std::string kernel_name("native_layer_norm");
95 kernel_name.reserve(kShaderNameReserve);
96
97 add_dtype_suffix(kernel_name, *t_out);
98
99 graph.execute_nodes().emplace_back(new DispatchNode(
100 graph,
101 VK_KERNEL_FROM_STR(kernel_name),
102 global_size,
103 local_size,
104 // Inputs and Outputs
105 {{{out_val->at(0), out_val->at(1), out_val->at(2)},
106 vkapi::MemoryAccessType::WRITE},
107 {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
108 // Shader params buffers
109 {
110 t_out->logical_limits_ubo(),
111 t_out->sizes_ubo(),
112 graph.create_params_buffer(epsilon),
113 },
114 // Specialization Constants
115 {
116 t_input->hashed_layout(),
117 t_out->hashed_layout(),
118 },
119 // Resizing Logic
120 resize_native_layer_norm_node,
121 {normalized_shape}));
122 }
123
native_layer_norm(ComputeGraph & graph,const std::vector<ValueRef> & args)124 void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
125 return add_native_layer_norm_node(
126 graph, args[0], args[1], args[2], args[3], args[4], args[5]);
127 }
128
// Registers this Vulkan implementation under the ATen operator name so the
// graph builder can resolve aten.native_layer_norm.default to it.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.native_layer_norm.default, native_layer_norm);
}
132
133 } // namespace vkcompute
134