// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/transformer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/util/irange.h>
#include <torch/nn/init.h>
#include <torch/nn/modules/transformer.h>
#include <torch/nn/modules/transformercoder.h>
#include <torch/nn/modules/transformerlayer.h>

#include <limits>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {

// ========================TransformerEncoderLayerImpl=========================
TransformerEncoderLayerImpl::TransformerEncoderLayerImpl(
    TransformerEncoderLayerOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerEncoderLayerImpl::reset() {
  // NOTE: reset() is for initializing the model only; calling reset() after
  // the model is created will throw exceptions. Call reset_parameters() if the
  // created model needs a reset.

  self_attn = this->register_module(
      "self_attn",
      MultiheadAttention(
          MultiheadAttentionOptions(options.d_model(), options.nhead())
              .dropout(options.dropout())));

  linear1 = this->register_module(
      "linear1", Linear(options.d_model(), options.dim_feedforward()));
  dropout = this->register_module("dropout", Dropout(options.dropout()));
  linear2 = this->register_module(
      "linear2", Linear(options.dim_feedforward(), options.d_model()));

  norm1 = this->register_module(
      "norm1", LayerNorm(LayerNormOptions({options.d_model()})));
  norm2 = this->register_module(
      "norm2", LayerNorm(LayerNormOptions({options.d_model()})));

  dropout1 = this->register_module("dropout1", Dropout(options.dropout()));
  dropout2 = this->register_module("dropout2", Dropout(options.dropout()));
}

void TransformerEncoderLayerImpl::reset_parameters() {
  // TODO xinyu: standardize reset_parameters virtual funcs
  self_attn->_reset_parameters();

  linear1->reset_parameters();
  // dropout->reset_parameters();
  linear2->reset_parameters();

  norm1->reset_parameters();
  norm2->reset_parameters();

  // dropout1->reset_parameters();
  // dropout2->reset_parameters();
}

Tensor TransformerEncoderLayerImpl::forward(
    const Tensor& src,
    const Tensor& src_mask,
    const Tensor& src_key_padding_mask) {
  // multihead attention
  Tensor src2 = std::get<0>(self_attn(
      src, src, src, src_key_padding_mask, /*need_weights=*/true, src_mask));
  // add & norm
  Tensor ret = norm1(src + dropout1(src2));

  // feedforward
  if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
    src2 = linear2(dropout(F::gelu(linear1(ret))));
  } else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
    src2 = linear2(dropout(F::relu(linear1(ret))));
  } else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
                 options.activation())) {
    auto callable_activation =
        *std::get_if<std::function<Tensor(const Tensor&)>>(
            &options.activation());
    src2 = linear2(dropout(callable_activation(linear1(ret))));
  } else {
    TORCH_CHECK(false, "activation should be kGELU, kReLU, or a callable");
  }

  // add & norm
  return norm2(ret + dropout2(src2));
}

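// A minimal usage sketch for TransformerEncoderLayer (illustrative only, not
// part of this file): the layer applies post-norm residual blocks, i.e.
// self-attention -> add & norm -> feedforward -> add & norm. The shapes below
// assume the (seq_len, batch, d_model) layout used by MultiheadAttention, and
// the variable names are hypothetical; the call with a single argument assumes
// the mask parameters default to empty tensors.
//
//   TransformerEncoderLayer layer(
//       TransformerEncoderLayerOptions(/*d_model=*/512, /*nhead=*/8)
//           .dropout(0.1));
//   auto src = torch::rand({10, 32, 512}); // (seq_len, batch, d_model)
//   auto out = layer(src);                 // same shape as src
//
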
// ========================TransformerDecoderLayerImpl=========================
TransformerDecoderLayerImpl::TransformerDecoderLayerImpl(
    TransformerDecoderLayerOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerDecoderLayerImpl::reset() {
  // NOTE: reset() is for initializing the model only; calling reset() after
  // the model is created will throw exceptions. Call reset_parameters() if
  // the created model needs a reset.

  // initialize self attention
  self_attn = this->register_module(
      "self_attn",
      MultiheadAttention(
          MultiheadAttentionOptions(options.d_model(), options.nhead())
              .dropout(options.dropout())));

  // initialize multihead attention
  multihead_attn = this->register_module(
      "multihead_attn",
      MultiheadAttention(
          MultiheadAttentionOptions(options.d_model(), options.nhead())
              .dropout(options.dropout())));

  // initialize feed forward first linear layer
  linear1 = this->register_module(
      "linear1", Linear(options.d_model(), options.dim_feedforward()));
  // initialize feed forward dropout layer
  dropout = this->register_module("dropout", Dropout(options.dropout()));
  // initialize feed forward second linear layer
  linear2 = this->register_module(
      "linear2", Linear(options.dim_feedforward(), options.d_model()));

  // initialize normalization, post self attention
  norm1 = this->register_module(
      "norm1", LayerNorm(LayerNormOptions({options.d_model()})));
  // initialize normalization, post multi-headed attention
  norm2 = this->register_module(
      "norm2", LayerNorm(LayerNormOptions({options.d_model()})));
  // initialize normalization, post feed forward
  norm3 = this->register_module(
      "norm3", LayerNorm(LayerNormOptions({options.d_model()})));

  // initialize dropout, post self attention
  dropout1 = this->register_module("dropout1", Dropout(options.dropout()));
  // initialize dropout, post multi-headed attention
  dropout2 = this->register_module("dropout2", Dropout(options.dropout()));
  // initialize dropout, post feed forward
  dropout3 = this->register_module("dropout3", Dropout(options.dropout()));
}

void TransformerDecoderLayerImpl::reset_parameters() {
  // TODO xinyu: standardize reset_parameters virtual funcs
  self_attn->_reset_parameters();
  multihead_attn->_reset_parameters();

  linear1->reset_parameters();
  // dropout->reset_parameters();
  linear2->reset_parameters();

  norm1->reset_parameters();
  norm2->reset_parameters();
  norm3->reset_parameters();
  // dropout1->reset_parameters();
  // dropout2->reset_parameters();
  // dropout3->reset_parameters();
}

/// Pass the inputs (and mask) through the decoder layer.
Tensor TransformerDecoderLayerImpl::forward(
    Tensor tgt,
    const Tensor& memory,
    const Tensor& tgt_mask,
    const Tensor& memory_mask,
    const Tensor& tgt_key_padding_mask,
    const Tensor& memory_key_padding_mask) {
  Tensor tgt2 = std::get<0>(self_attn(
      tgt, // query
      tgt, // key
      tgt, // value
      tgt_key_padding_mask, // key_padding_mask
      false, // need_weights
      tgt_mask) // attn_mask
  );
  tgt = tgt + dropout1(tgt2);
  tgt = norm1(tgt);

  tgt2 = std::get<0>(multihead_attn(
      tgt, // query
      memory, // key
      memory, // value
      memory_key_padding_mask, // key_padding_mask
      false, // need_weights
      memory_mask) // attn_mask
  );
  tgt = tgt + dropout2(tgt2);
  tgt = norm2(tgt);

  tgt2 = linear2(dropout(activation(linear1(tgt))));
  tgt = tgt + dropout3(tgt2);
  tgt = norm3(tgt);

  return tgt;
}

Tensor TransformerDecoderLayerImpl::activation(const Tensor& input) {
  if (std::holds_alternative<enumtype::kGELU>(options.activation())) {
    return F::gelu(input);
  } else if (std::holds_alternative<enumtype::kReLU>(options.activation())) {
    return F::relu(input);
  } else if (std::holds_alternative<std::function<Tensor(const Tensor&)>>(
                 options.activation())) {
    auto callable_activation =
        *std::get_if<std::function<Tensor(const Tensor&)>>(
            &options.activation());
    return callable_activation(input);
  } else {
    TORCH_CHECK(false, "activation should be kGELU, kReLU, or a callable");
  }
}

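// A minimal usage sketch for TransformerDecoderLayer (illustrative only, not
// part of this file): the layer runs self-attention over tgt, then
// cross-attention against the encoder memory, then the feedforward block, each
// followed by add & norm. Names and shapes below are hypothetical, and the
// two-argument call assumes the mask parameters default to empty tensors.
//
//   TransformerDecoderLayer layer(
//       TransformerDecoderLayerOptions(/*d_model=*/512, /*nhead=*/8));
//   auto memory = torch::rand({10, 32, 512}); // encoder output
//   auto tgt = torch::rand({20, 32, 512});    // (tgt_len, batch, d_model)
//   auto out = layer(tgt, memory);            // same shape as tgt
//
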
// ========================TransformerEncoderImpl=========================
TransformerEncoderImpl::TransformerEncoderImpl(
    TransformerEncoderOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerEncoderImpl::reset() {
  layers = this->register_module("layers", ModuleList());
  for (const auto i : c10::irange(options.num_layers())) {
    (void)i; // Suppress unused variable warning
    layers->push_back(options.encoder_layer()->clone());
  }

  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

void TransformerEncoderImpl::reset_parameters() {
  TORCH_CHECK(
      layers->size() == static_cast<size_t>(options.num_layers()),
      "TransformerEncoder should have ",
      options.num_layers(),
      " encoder layers, but got ",
      layers->size());

  size_t num_layers = layers->size();
  for (const auto i : c10::irange(num_layers)) {
    layers->at<TransformerEncoderLayerImpl>(i).reset_parameters();
  }
  // a. There is no way to know whether the module held in an AnyModule exposes
  // a reset_parameters() API, so replace it instead.
  // b. This also allows the user to add/delete the normalization module when
  // resetting parameters.
  if (!norm.is_empty()) {
    this->unregister_module("norm");
    norm = AnyModule();
  }
  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

Tensor TransformerEncoderImpl::forward(
    const Tensor& src,
    const Tensor& src_mask,
    const Tensor& src_key_padding_mask) {
  size_t num_layers = layers->size();
  Tensor output;
  if (num_layers > 0) {
    output = layers->at<TransformerEncoderLayerImpl>(0).forward(
        src, src_mask, src_key_padding_mask);
  }
  for (const auto i : c10::irange(1, num_layers)) {
    output = layers->at<TransformerEncoderLayerImpl>(i).forward(
        output, src_mask, src_key_padding_mask);
  }

  if (!norm.is_empty()) {
    output = norm.forward<Tensor>(num_layers == 0 ? src : output);
  }
  return output;
}

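// A minimal usage sketch for TransformerEncoder (illustrative only, not part
// of this file): reset() clones the configured encoder layer num_layers()
// times, and forward() threads the input through each clone in sequence,
// applying the optional final norm at the end. The single-argument call below
// assumes the mask parameters default to empty tensors; names are
// hypothetical.
//
//   TransformerEncoder encoder(TransformerEncoderOptions(
//       TransformerEncoderLayerOptions(/*d_model=*/512, /*nhead=*/8),
//       /*num_layers=*/6));
//   auto out = encoder(torch::rand({10, 32, 512}));
//
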
// ========================TransformerDecoderImpl=========================
TransformerDecoderImpl::TransformerDecoderImpl(
    TransformerDecoderOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerDecoderImpl::reset() {
  layers = this->register_module("layers", ModuleList());
  for (const auto i : c10::irange(options.num_layers())) {
    (void)i; // Suppress unused variable warning
    layers->push_back(options.decoder_layer()->clone());
  }

  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

void TransformerDecoderImpl::reset_parameters() {
  TORCH_CHECK(
      layers->size() == static_cast<size_t>(options.num_layers()),
      "TransformerDecoder should have ",
      options.num_layers(),
      " decoder layers, but got ",
      layers->size());

  size_t num_layers = layers->size();
  for (const auto i : c10::irange(num_layers)) {
    layers->at<TransformerDecoderLayerImpl>(i).reset_parameters();
  }
  // a. There is no way to know whether the module held in an AnyModule exposes
  // a reset_parameters() API, so replace it instead.
  // b. This also allows the user to add/delete the normalization module when
  // resetting parameters.
  if (!norm.is_empty()) {
    this->unregister_module("norm");
    norm = AnyModule();
  }
  if (!options.norm().is_empty()) {
    norm = options.norm().clone();
    this->register_module("norm", norm.ptr());
  }
}

Tensor TransformerDecoderImpl::forward(
    const Tensor& tgt,
    const Tensor& memory,
    const Tensor& tgt_mask,
    const Tensor& memory_mask,
    const Tensor& tgt_key_padding_mask,
    const Tensor& memory_key_padding_mask) {
  size_t num_layers = layers->size();
  Tensor output;
  if (num_layers > 0) {
    output = layers->at<TransformerDecoderLayerImpl>(0).forward(
        tgt,
        memory,
        tgt_mask,
        memory_mask,
        tgt_key_padding_mask,
        memory_key_padding_mask);
  }
  for (const auto i : c10::irange(1, num_layers)) {
    output = layers->at<TransformerDecoderLayerImpl>(i).forward(
        output,
        memory,
        tgt_mask,
        memory_mask,
        tgt_key_padding_mask,
        memory_key_padding_mask);
  }

  if (!norm.is_empty()) {
    output = norm.forward<Tensor>(num_layers == 0 ? tgt : output);
  }

  return output;
}

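// A minimal usage sketch for TransformerDecoder (illustrative only, not part
// of this file): as with the encoder, reset() clones the configured decoder
// layer num_layers() times and forward() chains them, passing the same memory
// and masks to every layer. Names are hypothetical and the two-argument call
// assumes the mask parameters default to empty tensors.
//
//   TransformerDecoder decoder(TransformerDecoderOptions(
//       TransformerDecoderLayerOptions(/*d_model=*/512, /*nhead=*/8),
//       /*num_layers=*/6));
//   auto memory = torch::rand({10, 32, 512});
//   auto tgt = torch::rand({20, 32, 512});
//   auto out = decoder(tgt, memory);
//
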
// =======================================TransformerImpl================================
TransformerImpl::TransformerImpl(TransformerOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void TransformerImpl::reset() {
  // set up encoder
  if (options.custom_encoder().is_empty()) {
    LayerNorm norm(LayerNormOptions({options.d_model()}));
    TransformerEncoder trans_encoder(
        TransformerEncoderOptions(
            TransformerEncoderLayerOptions(options.d_model(), options.nhead())
                .dim_feedforward(options.dim_feedforward())
                .dropout(options.dropout())
                .activation(options.activation()),
            options.num_encoder_layers())
            .norm(AnyModule(norm)));

    this->encoder = AnyModule(trans_encoder);
  } else {
    this->encoder = options.custom_encoder().clone();
  }
  this->register_module("encoder", this->encoder.ptr());

  // set up decoder
  if (options.custom_decoder().is_empty()) {
    LayerNorm norm(LayerNormOptions({options.d_model()}));
    TransformerDecoder trans_decoder(
        TransformerDecoderOptions(
            TransformerDecoderLayerOptions(options.d_model(), options.nhead())
                .dim_feedforward(options.dim_feedforward())
                .dropout(options.dropout())
                .activation(options.activation()),
            options.num_decoder_layers())
            .norm(AnyModule(norm)));

    this->decoder = AnyModule(trans_decoder);
  } else {
    this->decoder = options.custom_decoder().clone();
  }
  this->register_module("decoder", this->decoder.ptr());

  reset_parameters();
}

void TransformerImpl::reset_parameters() {
  auto parameters = this->parameters();
  for (auto& param : parameters) {
    if (param.dim() > 1) {
      torch::nn::init::xavier_uniform_(param);
    }
  }
}

Tensor TransformerImpl::forward(
    const Tensor& src,
    const Tensor& tgt,
    const Tensor& src_mask,
    const Tensor& tgt_mask,
    const Tensor& memory_mask,
    const Tensor& src_key_padding_mask,
    const Tensor& tgt_key_padding_mask,
    const Tensor& memory_key_padding_mask) {
  TORCH_CHECK(
      src.dim() == 3 && tgt.dim() == 3,
      "src and tgt should have 3 dimensions, but got ",
      src.dim(),
      " and ",
      tgt.dim());

  TORCH_CHECK(
      src.size(1) == tgt.size(1),
      "src and tgt should have equal batch size (at dim 1), but got ",
      src.size(1),
      " and ",
      tgt.size(1));

  TORCH_CHECK(
      src.size(2) == options.d_model() && tgt.size(2) == options.d_model(),
      "src and tgt should have same feature size as d_model (at dim 2), but got ",
      src.size(2),
      " and ",
      tgt.size(2),
      " while d_model is ",
      options.d_model());

  Tensor memory =
      this->encoder.forward<Tensor>(src, src_mask, src_key_padding_mask);
  Tensor output = this->decoder.forward<Tensor>(
      tgt,
      memory,
      tgt_mask,
      memory_mask,
      tgt_key_padding_mask,
      memory_key_padding_mask);

  return output;
}

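// A minimal end-to-end usage sketch for Transformer (illustrative only, not
// part of this file): forward() encodes src into memory and decodes tgt
// against it, so src and tgt must be 3-D, share the batch size at dim 1, and
// have feature size d_model at dim 2. Names and shapes are hypothetical, and
// the two-argument call assumes the mask parameters default to empty tensors.
//
//   Transformer transformer(
//       TransformerOptions(/*d_model=*/512, /*nhead=*/8)
//           .num_encoder_layers(6)
//           .num_decoder_layers(6));
//   auto src = torch::rand({10, 32, 512}); // (src_len, batch, d_model)
//   auto tgt = torch::rand({20, 32, 512}); // (tgt_len, batch, d_model)
//   auto out = transformer(src, tgt);      // (tgt_len, batch, d_model)
//
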
Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) {
  // A size of 0 is treated as valid here.
  TORCH_CHECK(
      sz >= 0,
      "Input size must be non-negative to generate a valid square subsequent mask, but got ",
      sz);

  // Check for IEEE754 support here, since -inf is not guaranteed to be valid
  // on non-IEEE754 platforms.
  if (std::numeric_limits<float>::is_iec559) {
    return torch::triu(
        torch::full({sz, sz}, -std::numeric_limits<float>::infinity()), 1);
  }
  // If IEEE754 is not supported, use the smallest float number on the current
  // platform.
  else {
    TORCH_WARN_ONCE(
        "IEEE754 is not supported on this platform, generate_square_subsequent_mask will fill "
        "the mask with smallest float number on this platform instead of -inf");
    return torch::triu(
        torch::full({sz, sz}, std::numeric_limits<float>::lowest()), 1);
  }
}

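// For example (illustrative only), generate_square_subsequent_mask(3) on an
// IEEE754 platform produces a float mask in which position i may only attend
// to positions <= i:
//
//   [[0., -inf, -inf],
//    [0.,   0., -inf],
//    [0.,   0.,   0.]]
//
// A typical (hypothetical) use is to pass it as tgt_mask when decoding,
// assuming the remaining mask parameters default to empty tensors:
//
//   auto tgt_mask = transformer->generate_square_subsequent_mask(tgt.size(0));
//   auto out = transformer(src, tgt, /*src_mask=*/torch::Tensor(), tgt_mask);
//
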
} // namespace nn
} // namespace torch