// /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/transformer.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NestedTensorImpl.h>

#include <torch/library.h>

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation.h>
#include <ATen/ops/_native_multi_head_attention.h>
#include <ATen/ops/_transformer_encoder_layer_fwd_native.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/layer_norm.h>
#endif

namespace at {

namespace native {

namespace {
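// Linear layer used by the feedforward block: computes
// bias + mat1 @ mat2.t(). When use_gelu has a value, the activation is
// fused into the matmul epilogue (GELU if true, ReLU if false); when it
// is nullopt, no activation is applied. Nested inputs dispatch to a
// dedicated nested addmm; dense 3-D inputs are flattened to 2-D for
// addmm and reshaped back afterwards.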
Tensor linear_for_ffn(
    const Tensor& bias,
    const Tensor& mat1,
    const Tensor& mat2,
    std::optional<bool> use_gelu) {
  if (mat1.is_nested()) {
    return NestedTensor_times_Tensor_plus_Tensor_addmm(
        bias, mat1, mat2.t(), 1, 1, use_gelu);
  }

  auto mat1_ = mat1.view({mat1.sizes()[0] * mat1.sizes()[1], mat1.sizes()[2]});
  Tensor result;
  if (use_gelu.has_value()) {
    result = at::_addmm_activation(bias, mat1_, mat2.t(), 1, 1, *use_gelu);
  } else {
    result = at::addmm(bias, mat1_, mat2.t());
  }
  return result.view({mat1.sizes()[0], mat1.sizes()[1], -1});
}

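// Two-layer feedforward sublayer: linear -> activation -> linear, where
// the first layer's activation is GELU or ReLU depending on use_gelu.
// Fusing the layer norm into this op (add_norm) is not implemented yet.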
Tensor ffn(
    const Tensor& input,
    const Tensor& w1,
    const Tensor& b1,
    const Tensor& w2,
    const Tensor& b2,
    bool use_gelu,
    bool add_norm) {
  TORCH_CHECK(!add_norm, "TODO: add_norm is not yet supported in FFN");
  TORCH_CHECK(input.dim() == 3, "batched input must be 3-dimensional");
  TORCH_CHECK(w1.dim() == 2, "2d weights expected");
  TORCH_CHECK(w2.dim() == 2, "2d weights expected");
  Tensor res = linear_for_ffn(b1, input, w1, use_gelu);
  res = linear_for_ffn(b2, res, w2, std::nullopt);
  return res;
}

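// Layer norm over the last (embedding) dimension. Note that
// use_nested_tensor is accepted but unused: at::layer_norm handles both
// dense and nested inputs.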
Tensor norm(
    const Tensor& input,
    const int64_t embed_dim,
    const double eps,
    const Tensor& weight,
    const Tensor& bias,
    const bool use_nested_tensor) {
  return at::layer_norm(
      input, {embed_dim}, weight, bias, eps, true /* cudnn_enable */);
}

} // namespace

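// Fast path for a single transformer encoder layer forward pass:
// a self-attention sublayer followed by a feedforward sublayer, each
// wrapped in a residual connection, with layer norm applied before each
// sublayer (norm_first == true, "pre-LN") or after it ("post-LN").
// Both dense and nested (variable sequence length) inputs are supported.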
Tensor transformer_encoder_layer_forward(
    const Tensor& src,
    const int64_t embed_dim,
    const int64_t num_heads,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const bool use_gelu,
    const bool norm_first,
    const double layer_norm_eps,
    const Tensor& layer_norm_weight_1,
    const Tensor& layer_norm_bias_1,
    const Tensor& layer_norm_weight_2,
    const Tensor& layer_norm_bias_2,
    const Tensor& ffn_weight_1,
    const Tensor& ffn_bias_1,
    const Tensor& ffn_weight_2,
    const Tensor& ffn_bias_2,
    const std::optional<Tensor>& mask,
    const std::optional<int64_t> mask_type) {
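  // Short-circuit on empty input: return an empty result of the same
  // layout (nested or dense) without launching any kernels.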
  {
    const Tensor& check_for_empty = src.is_nested() ? get_nested_tensor_impl(src)->get_buffer() : src;
    if (check_for_empty.numel() == 0) {
      return src.is_nested()
        ? at::detail::make_tensor<NestedTensorImpl>(check_for_empty, get_nested_tensor_impl(src)->get_nested_sizes())
        : src.clone();
    }
  }
  const bool use_nested_tensor = src.is_nested();
  Tensor x = src;
  if (norm_first) {
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
  }
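  // Self-attention sublayer; attention weights are not needed, so only
  // the output tensor of the returned tuple is kept.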
  x = std::get<0>(at::_native_multi_head_attention(
      x,
      x,
      x,
      embed_dim,
      num_heads,
      qkv_weight,
      qkv_bias,
      proj_weight,
      proj_bias,
      mask,
      false /* need_weights */,
      true /* average_attn_weights */,
      mask_type));

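  // Residual connection around attention; post-LN models normalize the
  // sum here.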
  x.add_(src);
  if (!norm_first) {
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
  }

  auto pre_ffn_res = x;

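  // Feedforward sublayer with its own residual connection (pre_ffn_res
  // holds the sublayer input).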
  if (norm_first) {
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_2, layer_norm_bias_2, use_nested_tensor);
  }
  x = ffn(
      x,
      ffn_weight_1,
      ffn_bias_1,
      ffn_weight_2,
      ffn_bias_2,
      use_gelu,
      false /* add_norm */);
  x.add_(pre_ffn_res);
  if (!norm_first) {
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_2, layer_norm_bias_2, use_nested_tensor);
  }
  return x;
}

} // namespace native
} // namespace at