#include <c10/util/irange.h>
#include <torch/nn/init.h>
#include <torch/nn/modules/transformer.h>
#include <torch/nn/modules/transformercoder.h>
#include <torch/nn/modules/transformerlayer.h>

#include <limits>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {

// ========================TransformerEncoderLayerImpl=========================
TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(
    TransformerEncoderLayerOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerEncoderLayerImpl::reset() {
  // NOTE: reset() is for initializing the model only; calling reset() after
  // the model is created will throw exceptions. Call reset_parameters() if the
  // created model needs a reset.

  self_attn = this->register_module(
      "self_attn",
      MultiheadAttention(
          MultiheadAttentionOptions(options.d_model(), options.nhead())
              .dropout(options.dropout())));

  linear1 = this->register_module(
      "linear1", Linear(options.d_model(), options.dim_feedforward()));
  dropout = this->register_module("dropout", Dropout(options.dropout()));
  linear2 = this->register_module(
      "linear2", Linear(options.dim_feedforward(), options.d_model()));

  norm1 = this->register_module(
      "norm1", LayerNorm(LayerNormOptions({options.d_model()})));
  norm2 = this->register_module(
      "norm2", LayerNorm(LayerNormOptions({options.d_model()})));

  dropout1 = this->register_module("dropout1", Dropout(options.dropout()));
  dropout2 = this->register_module("dropout2", Dropout(options.dropout()));
}

void TransformerEncoderLayerImpl::reset_parameters() {
  // TODO xinyu: standardize reset_parameters virtual funcs
  self_attn->_reset_parameters();

  linear1->reset_parameters();
  // dropout->reset_parameters();
  linear2->reset_parameters();

  norm1->reset_parameters();
  norm2->reset_parameters();

  // dropout1->reset_parameters();
  // dropout2->reset_parameters();
}

Tensor TransformerEncoderLayerImpl::forward(
    const Tensor& src,
    const Tensor& src_mask,
    const Tensor& src_key_padding_mask) {
  // multihead attention
  Tensor src2 = std::get<0>(self_attn(
      src, src, src, src_key_padding_mask, /*need_weights=*/true, src_mask));
  // add & norm
  Tensor ret = norm1(src + dropout1(src2));

  // feedforward
  if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
    src2 = linear2(dropout(F::gelu(linear1(ret))));
  } else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
    src2 = linear2(dropout(F::relu(linear1(ret))));
  } else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
                 options.activation())) {
    auto callable_activation =
        *std::get_if<std::function<Tensor(const Tensor&)>>(
            &options.activation());
    src2 = linear2(dropout(callable_activation(linear1(ret))));
  } else {
    TORCH_CHECK(false, "activation should be kGELU, kReLU, or a callable");
  }

  // add & norm
  return norm2(ret + dropout2(src2));
}
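
// Usage sketch (illustrative only, not part of this translation unit; it
// assumes a caller that links against libtorch and includes <torch/torch.h>):
//
//   torch::nn::TransformerEncoderLayer layer(
//       torch::nn::TransformerEncoderLayerOptions(/*d_model=*/512, /*nhead=*/8)
//           .dim_feedforward(2048)
//           .dropout(0.1));
//   // src is (S, N, E): sequence length x batch size x d_model.
//   torch::Tensor src = torch::rand({10, 32, 512});
//   // Empty tensors for src_mask and src_key_padding_mask disable masking;
//   // the output has the same shape as src.
//   torch::Tensor out = layer->forward(src, {}, {});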

// ========================TransformerDecoderLayerImpl=========================
TransformerDecoderLayerImpl::TransformerDecoderLayerImpl(
    TransformerDecoderLayerOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerDecoderLayerImpl::reset() {
  // NOTE: reset() is for initializing the model only; calling reset() after
  // the model is created will throw exceptions. Call reset_parameters() if
  // the created model needs a reset.

  // initialize self attention
  self_attn = this->register_module(
      "self_attn",
      MultiheadAttention(
          MultiheadAttentionOptions(options.d_model(), options.nhead())
              .dropout(options.dropout())));

  // initialize multi-head attention
  multihead_attn = this->register_module(
      "multihead_attn",
      MultiheadAttention(
          MultiheadAttentionOptions(options.d_model(), options.nhead())
              .dropout(options.dropout())));

  // initialize feed forward first linear layer
  linear1 = this->register_module(
      "linear1", Linear(options.d_model(), options.dim_feedforward()));
  // initialize feed forward dropout layer
  dropout = this->register_module("dropout", Dropout(options.dropout()));
  // initialize feed forward second linear layer
  linear2 = this->register_module(
      "linear2", Linear(options.dim_feedforward(), options.d_model()));

  // initialize normalization, post self attention
  norm1 = this->register_module(
      "norm1", LayerNorm(LayerNormOptions({options.d_model()})));
  // initialize normalization, post multi-head attention
  norm2 = this->register_module(
      "norm2", LayerNorm(LayerNormOptions({options.d_model()})));
  // initialize normalization, post feed forward
  norm3 = this->register_module(
      "norm3", LayerNorm(LayerNormOptions({options.d_model()})));

  // initialize dropout, post self attention
  dropout1 = this->register_module("dropout1", Dropout(options.dropout()));
  // initialize dropout, post multi-head attention
  dropout2 = this->register_module("dropout2", Dropout(options.dropout()));
  // initialize dropout, post feed forward
  dropout3 = this->register_module("dropout3", Dropout(options.dropout()));
}

void TransformerDecoderLayerImpl::reset_parameters() {
  // TODO xinyu: standardize reset_parameters virtual funcs
  self_attn->_reset_parameters();
  multihead_attn->_reset_parameters();

  linear1->reset_parameters();
  // dropout->reset_parameters();
  linear2->reset_parameters();

  norm1->reset_parameters();
  norm2->reset_parameters();
  norm3->reset_parameters();
  // dropout1->reset_parameters();
  // dropout2->reset_parameters();
  // dropout3->reset_parameters();
}

/// Pass the inputs (and masks) through the decoder layer.
Tensor TransformerDecoderLayerImpl::forward(
    Tensor tgt,
    const Tensor& memory,
    const Tensor& tgt_mask,
    const Tensor& memory_mask,
    const Tensor& tgt_key_padding_mask,
    const Tensor& memory_key_padding_mask) {
  Tensor tgt2 = std::get<0>(self_attn(
      tgt, // query
      tgt, // key
      tgt, // value
      tgt_key_padding_mask, // key_padding_mask
      false, // need_weights
      tgt_mask) // attn_mask
  );
  tgt = tgt + dropout1(tgt2);
  tgt = norm1(tgt);

  tgt2 = std::get<0>(multihead_attn(
      tgt, // query
      memory, // key
      memory, // value
      memory_key_padding_mask, // key_padding_mask
      false, // need_weights
      memory_mask) // attn_mask
  );
  tgt = tgt + dropout2(tgt2);
  tgt = norm2(tgt);

  tgt2 = linear2(dropout(activation(linear1(tgt))));
  tgt = tgt + dropout3(tgt2);
  tgt = norm3(tgt);

  return tgt;
}
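
// Usage sketch (illustrative only; shapes follow the (S, N, E) convention used
// for the encoder layer above, with T the target sequence length):
//
//   torch::nn::TransformerDecoderLayer layer(
//       torch::nn::TransformerDecoderLayerOptions(/*d_model=*/512, /*nhead=*/8));
//   torch::Tensor memory = torch::rand({10, 32, 512}); // encoder output (S, N, E)
//   torch::Tensor tgt = torch::rand({20, 32, 512});    // target sequence (T, N, E)
//   // With all masks left empty, the result has the same shape as tgt.
//   torch::Tensor out = layer->forward(tgt, memory, {}, {}, {}, {});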

Tensor TransformerDecoderLayerImpl::activation(const Tensor& input) {
  if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
    return F::gelu(input);
  } else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
    return F::relu(input);
  } else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
                 options.activation())) {
    auto callable_activation =
        *std::get_if<std::function<Tensor(const Tensor&)>>(
            &options.activation());
    return callable_activation(input);
  } else {
    TORCH_CHECK(false, "activation should be kGELU, kReLU, or a callable");
  }
}

// ========================TransformerEncoderImpl=========================
TransformerEncoderImpl::TransformerEncoderImpl(
    TransformerEncoderOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerEncoderImpl::reset() {
  layers = this->register_module("layers", ModuleList());
  for (const auto i : c10::irange(options.num_layers())) {
    (void)i; // Suppress unused variable warning
    layers->push_back(options.encoder_layer()->clone());
  }

  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

void TransformerEncoderImpl::reset_parameters() {
  TORCH_CHECK(
      layers->size() == static_cast<size_t>(options.num_layers()),
      "TransformerEncoder should have ",
      options.num_layers(),
      " encoder layers, but got ",
      layers->size());

  size_t num_layers = layers->size();
  for (const auto i : c10::irange(num_layers)) {
    layers->at<TransformerEncoderLayerImpl>(i).reset_parameters();
  }
  // a. There is no way to know whether the module held in an AnyModule has an
  //    API to reset_parameters, so replace the module instead.
  // b. This also allows the user to add/delete the normalization module when
  //    resetting parameters.
  if (!norm.is_empty()) {
    this->unregister_module("norm");
    norm = AnyModule();
  }
  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

Tensor TransformerEncoderImpl::forward(
    const Tensor& src,
    const Tensor& src_mask,
    const Tensor& src_key_padding_mask) {
  size_t num_layers = layers->size();
  Tensor output;
  if (num_layers > 0) {
    output = layers->at<TransformerEncoderLayerImpl>(0).forward(
        src, src_mask, src_key_padding_mask);
  }
  for (const auto i : c10::irange(1, num_layers)) {
    output = layers->at<TransformerEncoderLayerImpl>(i).forward(
        output, src_mask, src_key_padding_mask);
  }

  if (!norm.is_empty()) {
    output = norm.forward<Tensor>(num_layers == 0 ? src : output);
  }
  return output;
}
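
// Usage sketch (illustrative only): the encoder clones the given layer
// num_layers times and feeds each layer's output into the next, applying the
// optional final normalization at the end.
//
//   torch::nn::TransformerEncoderLayer layer(
//       torch::nn::TransformerEncoderLayerOptions(512, 8));
//   torch::nn::TransformerEncoder encoder(
//       torch::nn::TransformerEncoderOptions(layer, /*num_layers=*/6));
//   torch::Tensor out =
//       encoder->forward(torch::rand({10, 32, 512}), {}, {}); // (S, N, E)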

// ========================TransformerDecoderImpl=========================
TransformerDecoderImpl::TransformerDecoderImpl(
    TransformerDecoderOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerDecoderImpl::reset() {
  layers = this->register_module("layers", ModuleList());
  for (const auto i : c10::irange(options.num_layers())) {
    (void)i; // Suppress unused variable warning
    layers->push_back(options.decoder_layer()->clone());
  }

  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

void TransformerDecoderImpl::reset_parameters() {
  TORCH_CHECK(
      layers->size() == static_cast<size_t>(options.num_layers()),
      "TransformerDecoder should have ",
      options.num_layers(),
      " decoder layers, but got ",
      layers->size());

  size_t num_layers = layers->size();
  for (const auto i : c10::irange(num_layers)) {
    layers->at<TransformerDecoderLayerImpl>(i).reset_parameters();
  }
  // a. There is no way to know whether the module held in an AnyModule has an
  //    API to reset_parameters, so replace the module instead.
  // b. This also allows the user to add/delete the normalization module when
  //    resetting parameters.
  if (!norm.is_empty()) {
    this->unregister_module("norm");
    norm = AnyModule();
  }
  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

Tensor TransformerDecoderImpl::forward(
    const Tensor& tgt,
    const Tensor& memory,
    const Tensor& tgt_mask,
    const Tensor& memory_mask,
    const Tensor& tgt_key_padding_mask,
    const Tensor& memory_key_padding_mask) {
  size_t num_layers = layers->size();
  Tensor output;
  if (num_layers > 0) {
    output = layers->at<TransformerDecoderLayerImpl>(0).forward(
        tgt,
        memory,
        tgt_mask,
        memory_mask,
        tgt_key_padding_mask,
        memory_key_padding_mask);
  }
  for (const auto i : c10::irange(1, num_layers)) {
    output = layers->at<TransformerDecoderLayerImpl>(i).forward(
        output,
        memory,
        tgt_mask,
        memory_mask,
        tgt_key_padding_mask,
        memory_key_padding_mask);
  }

  if (!norm.is_empty()) {
    output = norm.forward<Tensor>(num_layers == 0 ? tgt : output);
  }

  return output;
}
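
// Usage sketch (illustrative only), mirroring the encoder stack above:
//
//   torch::nn::TransformerDecoderLayer layer(
//       torch::nn::TransformerDecoderLayerOptions(512, 8));
//   torch::nn::TransformerDecoder decoder(
//       torch::nn::TransformerDecoderOptions(layer, /*num_layers=*/6));
//   torch::Tensor memory = torch::rand({10, 32, 512}); // (S, N, E)
//   torch::Tensor tgt = torch::rand({20, 32, 512});    // (T, N, E)
//   torch::Tensor out =
//       decoder->forward(tgt, memory, {}, {}, {}, {}); // (T, N, E)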

// =======================================TransformerImpl================================
TransformerImpl::TransformerImpl(TransformerOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerImpl::reset() {
  // set up encoder
  if (options.custom_encoder().is_empty()) {
    LayerNorm norm(LayerNormOptions({options.d_model()}));
    TransformerEncoder trans_encoder(
        TransformerEncoderOptions(
            TransformerEncoderLayerOptions(options.d_model(), options.nhead())
                .dim_feedforward(options.dim_feedforward())
                .dropout(options.dropout())
                .activation(options.activation()),
            options.num_encoder_layers())
            .norm(AnyModule(norm)));

    this->encoder = AnyModule(trans_encoder);
  } else {
    this->encoder = options.custom_encoder().clone();
  }
  this->register_module("encoder", this->encoder.ptr());

  // set up decoder
  if (options.custom_decoder().is_empty()) {
    LayerNorm norm(LayerNormOptions({options.d_model()}));
    TransformerDecoder trans_decoder(
        TransformerDecoderOptions(
            TransformerDecoderLayerOptions(options.d_model(), options.nhead())
                .dim_feedforward(options.dim_feedforward())
                .dropout(options.dropout())
                .activation(options.activation()),
            options.num_decoder_layers())
            .norm(AnyModule(norm)));

    this->decoder = AnyModule(trans_decoder);
  } else {
    this->decoder = options.custom_decoder().clone();
  }
  this->register_module("decoder", this->decoder.ptr());

  reset_parameters();
}

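// NOTE: torch::nn::init::xavier_uniform_ fills a tensor with values drawn from
// U(-a, a), where a = gain * sqrt(6 / (fan_in + fan_out)) and gain defaults to
// 1.0. Only parameters with more than one dimension (the weight matrices, not
// the biases) are reinitialized below, matching the Python implementation.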
void TransformerImpl::reset_parameters() {
  auto parameters = this->parameters();
  for (auto& param : parameters) {
    if (param.dim() > 1) {
      torch::nn::init::xavier_uniform_(param);
    }
  }
}

Tensor TransformerImpl::forward(
    const Tensor& src,
    const Tensor& tgt,
    const Tensor& src_mask,
    const Tensor& tgt_mask,
    const Tensor& memory_mask,
    const Tensor& src_key_padding_mask,
    const Tensor& tgt_key_padding_mask,
    const Tensor& memory_key_padding_mask) {
  TORCH_CHECK(
      src.dim() == 3 && tgt.dim() == 3,
      "src and tgt should have 3 dimensions, but got ",
      src.dim(),
      " and ",
      tgt.dim());

  TORCH_CHECK(
      src.size(1) == tgt.size(1),
      "src and tgt should have equal batch size (at dim 1), but got ",
      src.size(1),
      " and ",
      tgt.size(1));

  TORCH_CHECK(
      src.size(2) == options.d_model() && tgt.size(2) == options.d_model(),
      "src and tgt should have same feature size as d_model (at dim 2), but got ",
      src.size(2),
      " and ",
      tgt.size(2),
      " while d_model is ",
      options.d_model());

  Tensor memory =
      this->encoder.forward<Tensor>(src, src_mask, src_key_padding_mask);
  Tensor output = this->decoder.forward<Tensor>(
      tgt,
      memory,
      tgt_mask,
      memory_mask,
      tgt_key_padding_mask,
      memory_key_padding_mask);

  return output;
}
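
// Usage sketch (illustrative only), putting the full model together:
//
//   torch::nn::Transformer transformer(
//       torch::nn::TransformerOptions(/*d_model=*/512, /*nhead=*/8));
//   torch::Tensor src = torch::rand({10, 32, 512}); // (S, N, E)
//   torch::Tensor tgt = torch::rand({20, 32, 512}); // (T, N, E)
//   // A causal mask for the target keeps position i from attending to j > i.
//   torch::Tensor tgt_mask =
//       transformer->generate_square_subsequent_mask(tgt.size(0));
//   // Empty tensors leave the remaining masks disabled.
//   torch::Tensor out = transformer->forward(
//       src, tgt, /*src_mask=*/{}, tgt_mask, /*memory_mask=*/{},
//       /*src_key_padding_mask=*/{}, /*tgt_key_padding_mask=*/{},
//       /*memory_key_padding_mask=*/{}); // (T, N, E)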

Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) {
  // A size of 0 is treated as valid here.
  TORCH_CHECK(
      sz >= 0,
      "Input size must be non-negative to generate a valid square subsequent mask, but got ",
      sz);

  // Check for IEEE754 support, since -inf is not guaranteed to be valid on
  // non-IEEE754 platforms.
  if (std::numeric_limits<float>::is_iec559) {
    return torch::triu(
        torch::full({sz, sz}, -std::numeric_limits<float>::infinity()), 1);
  }
  // If IEEE754 is not supported, use the lowest finite float value on the
  // current platform instead.
  else {
    TORCH_WARN_ONCE(
        "IEEE754 is not supported on this platform, generate_square_subsequent_mask will fill "
        "the mask with the lowest finite float value on this platform instead of -inf");
    return torch::triu(
        torch::full({sz, sz}, std::numeric_limits<float>::lowest()), 1);
  }
}
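
// For example, on an IEEE754 platform, generate_square_subsequent_mask(3)
// yields
//
//   [[0., -inf, -inf],
//    [0.,   0., -inf],
//    [0.,   0.,   0.]]
//
// so that, when added to the attention scores, position i can only attend to
// positions j <= i.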

} // namespace nn
} // namespace torch