#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NestedTensorImpl.h>

#include <torch/library.h>

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation.h>
#include <ATen/ops/_native_multi_head_attention.h>
#include <ATen/ops/_transformer_encoder_layer_fwd_native.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/layer_norm.h>
#endif

namespace at {

namespace native {

namespace {
// Linear layer for the transformer feed-forward sublayer:
// computes bias + mat1 @ mat2.t(), optionally fused with an activation.
//
// - bias:     1-D bias added to each output row.
// - mat1:     3-D batched input (regular tensor) or a nested tensor.
// - mat2:     2-D weight matrix; transposed before the matmul.
// - use_gelu: when set, use the fused addmm+activation kernel (GELU if
//             *use_gelu is true); when nullopt, plain addmm with no
//             activation.
//
// Fixes the coverage-tool-garbled signature line from the annotated dump;
// also hoists the repeated mat1.sizes() calls into one local.
Tensor linear_for_ffn(
    const Tensor& bias,
    const Tensor& mat1,
    const Tensor& mat2,
    std::optional<bool> use_gelu) {
  if (mat1.is_nested()) {
    // Nested tensors take a dedicated fused addmm(+activation) path.
    return NestedTensor_times_Tensor_plus_Tensor_addmm(
        bias, mat1, mat2.t(), 1, 1, use_gelu);
  }

  // Flatten (B, T, C) -> (B*T, C) so a single 2-D addmm can be used.
  const auto sizes = mat1.sizes();
  auto mat1_ = mat1.view({sizes[0] * sizes[1], sizes[2]});
  Tensor result;
  if (use_gelu.has_value()) {
    // Fused addmm + activation (GELU when *use_gelu; otherwise the op's
    // non-GELU activation).
    result = at::_addmm_activation(bias, mat1_, mat2.t(), 1, 1, *use_gelu);
  } else {
    result = at::addmm(bias, mat1_, mat2.t());
  }
  // Restore the (B, T, -1) batched shape.
  return result.view({sizes[0], sizes[1], -1});
}
// Transformer feed-forward network: two linear layers with an optional
// activation between them.
//
// - input:    3-D batched input (or nested tensor, handled downstream).
// - w1/b1:    first linear layer weight (2-D) and bias.
// - w2/b2:    second linear layer weight (2-D) and bias.
// - use_gelu: activation selector forwarded to the first linear layer.
// - add_norm: not yet implemented; must be false.
//
// Fixes the coverage-tool-garbled signature line from the annotated dump.
Tensor ffn(
    const Tensor& input,
    const Tensor& w1,
    const Tensor& b1,
    const Tensor& w2,
    const Tensor& b2,
    bool use_gelu,
    bool add_norm) {
  TORCH_CHECK(add_norm == false, "TODO add_norm to be supported in FFN");
  TORCH_CHECK(input.dim() == 3, "batched input size should be 3");
  TORCH_CHECK(w1.dim() == 2, "2d weights expected");
  TORCH_CHECK(w2.dim() == 2, "2d weights expected");
  // First projection carries the activation; the second is plain linear
  // (std::nullopt -> no fused activation).
  Tensor res = linear_for_ffn(b1, input, w1, use_gelu);
  res = linear_for_ffn(b2, res, w2, std::nullopt);
  return res;
}
// Applies layer normalization over the last (embedding) dimension.
//
// `use_nested_tensor` is currently unused: at::layer_norm is invoked
// uniformly for both regular and nested inputs. The parameter is kept to
// preserve the existing call signature.
//
// Fixes the coverage-tool-garbled signature line from the annotated dump.
Tensor norm(
    const Tensor& input,
    const int64_t embed_dim,
    const double eps,
    const Tensor& weight,
    const Tensor& bias,
    const bool use_nested_tensor) {
  (void)use_nested_tensor; // silence unused-parameter warning
  // Final `true` enables the cudnn-capable path of layer_norm.
  return at::layer_norm(input, {embed_dim}, weight, bias, eps, true);
}

} // namespace

// Fused forward pass of a single transformer encoder layer:
// self-attention + residual + norm, then FFN + residual + norm.
//
// `norm_first` selects pre-norm (norm applied before each sublayer) vs
// post-norm (norm applied after each residual add). `mask`/`mask_type`
// are forwarded to the fused multi-head-attention kernel. Supports both
// regular 3-D inputs and nested tensors.
//
// Fixes the coverage-tool-garbled signature line from the annotated dump;
// the computation itself is unchanged.
Tensor transformer_encoder_layer_forward(
    const Tensor& src,
    const int64_t embed_dim,
    const int64_t num_heads,
    const Tensor& qkv_weight,
    const Tensor& qkv_bias,
    const Tensor& proj_weight,
    const Tensor& proj_bias,
    const bool use_gelu,
    const bool norm_first,
    const double layer_norm_eps,
    const Tensor& layer_norm_weight_1,
    const Tensor& layer_norm_bias_1,
    const Tensor& layer_norm_weight_2,
    const Tensor& layer_norm_bias_2,
    const Tensor& ffn_weight_1,
    const Tensor& ffn_bias_1,
    const Tensor& ffn_weight_2,
    const Tensor& ffn_bias_2,
    const std::optional<Tensor>& mask,
    const std::optional<int64_t> mask_type) {
  {
    // Fast path for empty input: return an equivalent empty result without
    // running any kernels. For nested tensors, emptiness is determined by
    // the underlying buffer.
    const Tensor& check_for_empty = src.is_nested() ? get_nested_tensor_impl(src)->get_buffer() : src;
    if (check_for_empty.numel() == 0) {
      return src.is_nested()
          ? at::detail::make_tensor<NestedTensorImpl>(check_for_empty, get_nested_tensor_impl(src)->get_nested_sizes())
          : src.clone();
    }
  }
  const bool use_nested_tensor = src.is_nested();
  Tensor x = src;
  // --- Self-attention sublayer ---
  if (norm_first) {
    // Pre-norm: normalize before attention.
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
  }
  // Self-attention: query, key, and value are all `x`.
  x = std::get<0>(at::_native_multi_head_attention(
      x,
      x,
      x,
      embed_dim,
      num_heads,
      qkv_weight,
      qkv_bias,
      proj_weight,
      proj_bias,
      mask,
      false /* need_weights */,
      true /* average_attn_weights */,
      mask_type));

  // Residual connection around the attention sublayer (in-place add).
  x.add_(src);
  if (!norm_first) {
    // Post-norm: normalize after the residual add.
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
  }

  // --- Feed-forward sublayer ---
  // Save the FFN residual input before any pre-norm is applied.
  auto pre_ffn_res = x;

  if (norm_first) {
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_2, layer_norm_bias_2, use_nested_tensor);
  }
  x = ffn(
      x,
      ffn_weight_1,
      ffn_bias_1,
      ffn_weight_2,
      ffn_bias_2,
      use_gelu,
      /* add_norm* */ false);
  // Residual connection around the FFN (in-place add).
  x.add_(pre_ffn_res);
  if (!norm_first) {
    x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_2, layer_norm_bias_2, use_nested_tensor);
  }
  return x;
}

} // namespace native
} // namespace at