# Taken from https://github.com/pytorch/audio/blob/master/torchaudio/models/wav2letter.py
# So that we don't need torchaudio to be installed

import math
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor


__all__ = ["Wav2Letter"]


class Wav2Letter(nn.Module):
    r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System"
     <https://arxiv.org/abs/1609.03193>`_ paper.
     :math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}`
    Args:
        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
        input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
         or ``mfcc`` (Default: ``waveform``).
        num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
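
    Example (illustrative sketch; the batch size and input length below are arbitrary)::
        >>> model = Wav2Letter(num_classes=40, input_type="waveform", num_features=1)
        >>> waveform = torch.randn(3, 1, 16000)  # (batch_size, num_features, input_length)
        >>> log_probs = model(waveform)  # (batch_size, num_classes, output_length)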
    """

    def __init__(
        self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1
    ) -> None:
        super().__init__()

        acoustic_num_features = 250 if input_type == "waveform" else num_features
        acoustic_model = nn.Sequential(
            nn.Conv1d(
                in_channels=acoustic_num_features,
                out_channels=250,
                kernel_size=48,
                stride=2,
                padding=23,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=2000,
                out_channels=num_classes,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ReLU(inplace=True),
        )

        if input_type == "waveform":
            waveform_model = nn.Sequential(
                nn.Conv1d(
                    in_channels=num_features,
                    out_channels=250,
                    kernel_size=250,
                    stride=160,
                    padding=45,
                ),
                nn.ReLU(inplace=True),
            )
            self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)

        if input_type in ["power_spectrum", "mfcc"]:
            self.acoustic_model = acoustic_model

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): Tensor of dimension (batch_size, num_features, input_length).
        Returns:
            Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length).
        """

        x = self.acoustic_model(x)
        x = nn.functional.log_softmax(x, dim=1)
        return x


# Taken from https://github.com/SeanNaren/deepspeech.pytorch with modifications
class SequenceWise(nn.Module):
    def __init__(self, module):
        """
        Collapses input of dim T*N*H to (T*N)*H, applies the given module, and restores the
        time and batch dimensions afterwards.
        Allows handling of variable sequence lengths and minibatch sizes.
        :param module: Module to apply the input to.
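
        Example (illustrative; the module and shapes below are arbitrary)::
            >>> seq_bn = SequenceWise(nn.BatchNorm1d(16))
            >>> x = torch.randn(50, 8, 16)  # (T, N, H)
            >>> y = seq_bn(x)  # (T, N, H)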
        """
        super().__init__()
        self.module = module

    def forward(self, x):
        t, n = x.size(0), x.size(1)
        x = x.view(t * n, -1)
        x = self.module(x)
        x = x.view(t, n, -1)
        return x

    def __repr__(self):
        tmpstr = self.__class__.__name__ + " (\n"
        tmpstr += self.module.__repr__()
        tmpstr += ")"
        return tmpstr


class MaskConv(nn.Module):
    def __init__(self, seq_module):
        """
        Zero-masks the output of each module in the stack beyond the given sequence lengths. This ensures
        that the results of the model do not change when batch sizes change during inference.
        Input needs to be in the shape of (BxCxDxT)
        :param seq_module: The sequential module containing the conv stack.
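
        Example (illustrative; a minimal conv stack and arbitrary shapes)::
            >>> conv = MaskConv(nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1)))
            >>> x = torch.randn(2, 1, 8, 20)  # (B, C, D, T)
            >>> lengths = torch.tensor([20, 15])
            >>> y, _ = conv(x, lengths)  # positions past each length are zeroed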
        """
        super().__init__()
        self.seq_module = seq_module

    def forward(self, x, lengths):
        """
        :param x: The input of size BxCxDxT
        :param lengths: The actual length of each sequence in the batch
        :return: Masked output from the module
        """
        for module in self.seq_module:
            x = module(x)
            mask = torch.BoolTensor(x.size()).fill_(0)
            if x.is_cuda:
                mask = mask.cuda()
            for i, length in enumerate(lengths):
                length = length.item()
                if (mask[i].size(2) - length) > 0:
                    mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
            x = x.masked_fill(mask, 0)
        return x, lengths


class InferenceBatchSoftmax(nn.Module):
    def forward(self, input_):
        if not self.training:
            return F.softmax(input_, dim=-1)
        else:
            return input_


class BatchRNN(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        rnn_type=nn.LSTM,
        bidirectional=False,
        batch_norm=True,
    ):
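        r"""RNN layer with an optional SequenceWise batch norm applied before the RNN; when the RNN
        is bidirectional, the outputs of the two directions are summed so the output stays hidden_size wide.

        Example (illustrative; sizes and lengths below are arbitrary)::
            >>> rnn = BatchRNN(input_size=32, hidden_size=64, rnn_type=nn.LSTM, bidirectional=True)
            >>> x = torch.randn(50, 4, 32)  # (T, N, input_size)
            >>> lengths = torch.tensor([50, 42, 30, 12])  # one length per batch element, on CPU
            >>> out = rnn(x, lengths)  # (T, N, hidden_size)
        """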
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.batch_norm = (
            SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
        )
        self.rnn = rnn_type(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=bidirectional,
            bias=True,
        )
        self.num_directions = 2 if bidirectional else 1

    def flatten_parameters(self):
        self.rnn.flatten_parameters()

    def forward(self, x, output_lengths):
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        x = nn.utils.rnn.pack_padded_sequence(x, output_lengths, enforce_sorted=False)
        x, h = self.rnn(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        if self.bidirectional:
            x = (
                x.view(x.size(0), x.size(1), 2, -1)
                .sum(2)
                .view(x.size(0), x.size(1), -1)
            )  # (TxNxH*2) -> (TxNxH) by sum
        return x


class Lookahead(nn.Module):
    # Wang et al., 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    # input shape - sequence, batch, feature - TxNxH
    # output shape - same as input
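    #
    # Illustrative example (not part of the benchmark; sizes below are arbitrary):
    #   lookahead = Lookahead(n_features=64, context=20)
    #   x = torch.randn(50, 4, 64)  # (T, N, H)
    #   y = lookahead(x)            # (T, N, H)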
    def __init__(self, n_features, context):
        super().__init__()
        assert context > 0
        self.context = context
        self.n_features = n_features
        self.pad = (0, self.context - 1)
        self.conv = nn.Conv1d(
            self.n_features,
            self.n_features,
            kernel_size=self.context,
            stride=1,
            groups=self.n_features,
            padding=0,
            bias=None,
        )

    def forward(self, x):
        x = x.transpose(0, 1).transpose(1, 2)
        x = F.pad(x, pad=self.pad, value=0)
        x = self.conv(x)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()
        return x

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + "n_features="
            + str(self.n_features)
            + ", context="
            + str(self.context)
            + ")"
        )


class DeepSpeech(nn.Module):
    def __init__(
        self,
        rnn_type,
        labels,
        rnn_hidden_size,
        nb_layers,
        audio_conf,
        bidirectional,
        context=20,
    ):
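        r"""Container for a masked 2D convolution stack, a stack of BatchRNN layers, an optional
        Lookahead layer (used only when not bidirectional), and a SequenceWise fully connected output layer.

        Example (illustrative sketch; the labels, audio_conf values, and shapes below are arbitrary)::
            >>> model = DeepSpeech(rnn_type=nn.LSTM, labels=["_", "'", "A", "B", "C"],
            ...                    rnn_hidden_size=64, nb_layers=2,
            ...                    audio_conf={"sample_rate": 16000, "window_size": 0.02},
            ...                    bidirectional=True)
            >>> spect = torch.randn(2, 1, 161, 100)  # (N, 1, freq, time); freq = 16000 * 0.02 / 2 + 1
            >>> lengths = torch.tensor([100, 80])
            >>> out, out_lengths = model(spect, lengths)  # out: (N, T', num_classes)
        """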
        super().__init__()

        self.hidden_size = rnn_hidden_size
        self.hidden_layers = nb_layers
        self.rnn_type = rnn_type
        self.audio_conf = audio_conf
        self.labels = labels
        self.bidirectional = bidirectional

        sample_rate = self.audio_conf["sample_rate"]
        window_size = self.audio_conf["window_size"]
        num_classes = len(self.labels)

        self.conv = MaskConv(
            nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
                nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
            )
        )
        # Based on above convolutions and spectrogram size using conv formula (W - F + 2P) / S + 1
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32
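        # For example (illustrative values only): sample_rate=16000 and window_size=0.02 give
        # floor(16000 * 0.02 / 2) + 1 = 161 input features; the two convs above then give
        # (161 + 2 * 20 - 41) // 2 + 1 = 81 and (81 + 2 * 10 - 21) // 2 + 1 = 41,
        # so rnn_input_size = 41 * 32 = 1312.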

        rnns = []
        rnn = BatchRNN(
            input_size=rnn_input_size,
            hidden_size=rnn_hidden_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional,
            batch_norm=False,
        )
        rnns.append(("0", rnn))
        for x in range(nb_layers - 1):
            rnn = BatchRNN(
                input_size=rnn_hidden_size,
                hidden_size=rnn_hidden_size,
                rnn_type=rnn_type,
                bidirectional=bidirectional,
            )
            rnns.append(("%d" % (x + 1), rnn))
        self.rnns = nn.Sequential(OrderedDict(rnns))
        self.lookahead = (
            nn.Sequential(
                # consider adding batch norm?
                Lookahead(rnn_hidden_size, context=context),
                nn.Hardtanh(0, 20, inplace=True),
            )
            if not bidirectional
            else None
        )

        fully_connected = nn.Sequential(
            nn.BatchNorm1d(rnn_hidden_size),
            nn.Linear(rnn_hidden_size, num_classes, bias=False),
        )
        self.fc = nn.Sequential(
            SequenceWise(fully_connected),
        )
        self.inference_softmax = InferenceBatchSoftmax()

    def forward(self, x, lengths):
        lengths = lengths.cpu().int()
        output_lengths = self.get_seq_lens(lengths)
        x, _ = self.conv(x, output_lengths)

        sizes = x.size()
        x = x.view(
            sizes[0], sizes[1] * sizes[2], sizes[3]
        )  # Collapse feature dimension
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH

        for rnn in self.rnns:
            x = rnn(x, output_lengths)

        if not self.bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        # identity in training mode, softmax in eval mode
        x = self.inference_softmax(x)
        return x, output_lengths

    def get_seq_lens(self, input_length):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the sizes of the sequences that will be output by the network.
        :param input_length: 1D Tensor
        :return: 1D Tensor scaled by model
        """
        seq_len = input_length
        for m in self.conv.modules():
            if type(m) == nn.modules.conv.Conv2d:
                seq_len = (
                    seq_len
                    + 2 * m.padding[1]
                    - m.dilation[1] * (m.kernel_size[1] - 1)
                    - 1
                )
                seq_len = seq_len.true_divide(m.stride[1]) + 1
        return seq_len.int()


# Taken from https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L108-L152
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
        in the sequence. The positional encodings have the same dimension as
        the embeddings, so that the two can be summed. Here, we use sine and cosine
        functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{\text{model}}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{\text{model}}})
    where pos is the word position and i is the embed idx.
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class TransformerModel(nn.Module):
    """Container module with an encoder, a transformer encoder, and a decoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()
        try:
            from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except Exception as e:
            raise ImportError(
                "TransformerEncoder module does not exist in PyTorch 1.1 or lower."
            ) from e
        self.model_type = "Transformer"
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        # Not sure how this works in the original code
        # nn.init.zeros_(self.decoder)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            # This will be created once during warmup
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(
                    device
                )
                self.src_mask = mask
        else:
            self.src_mask = None

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)

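# Illustrative TransformerModel usage (not part of the benchmark; all sizes below are arbitrary):
#   model = TransformerModel(ntoken=1000, ninp=64, nhead=4, nhid=256, nlayers=2)
#   src = torch.randint(0, 1000, (35, 8))  # (seq_len, batch) of token indices
#   log_probs = model(src)                 # (seq_len, batch, ntoken)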

# From https://github.com/pytorch/text/blob/master/torchtext/modules
class MultiheadAttentionContainer(torch.nn.Module):
    def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
        r"""A multi-head attention container
        Args:
            nhead: the number of heads in the multiheadattention model
            in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear).
            attention_layer: The attention layer.
            out_proj: The multi-head out-projection layer (a.k.a nn.Linear).
        Examples::
            >>> import torch
            >>> embed_dim, num_heads, bsz = 10, 5, 64
            >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
            ...                                     torch.nn.Linear(embed_dim, embed_dim),
            ...                                     torch.nn.Linear(embed_dim, embed_dim))
            >>> MHA = MultiheadAttentionContainer(num_heads,
            ...                                   in_proj_container,
            ...                                   ScaledDotProduct(),
            ...                                   torch.nn.Linear(embed_dim, embed_dim))
            >>> query = torch.rand((21, bsz, embed_dim))
            >>> key = value = torch.rand((16, bsz, embed_dim))
            >>> attn_output, attn_weights = MHA(query, key, value)
            >>> print(attn_output.shape)
            torch.Size([21, 64, 10])
        """
        super().__init__()
        self.nhead = nhead
        self.in_proj_container = in_proj_container
        self.attention_layer = attention_layer
        self.out_proj = out_proj

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        bias_k: Optional[torch.Tensor] = None,
        bias_v: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            query, key, value (Tensor): map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            attn_mask, bias_k and bias_v (Tensor, optional): keyword arguments passed to the attention layer.
                See the definitions in the attention layer.
        Shape:
            - Inputs:
            - query: :math:`(L, N, E)`
            - key: :math:`(S, N, E)`
            - value: :math:`(S, N, E)`
            - attn_mask, bias_k and bias_v: same as the shape of the corresponding args in the attention layer.
            - Outputs:
            - attn_output: :math:`(L, N, E)`
            - attn_output_weights: :math:`(N * H, L, S)`
            where L is the target length, S is the sequence length, H is the number of attention heads,
                N is the batch size, and E is the embedding dimension.
        """
        tgt_len, src_len, bsz, embed_dim = (
            query.size(-3),
            key.size(-3),
            query.size(-2),
            query.size(-1),
        )
        q, k, v = self.in_proj_container(query, key, value)
        assert (
            q.size(-1) % self.nhead == 0
        ), "query's embed_dim must be divisible by the number of heads"
        head_dim = q.size(-1) // self.nhead
        q = q.reshape(tgt_len, bsz * self.nhead, head_dim)

        assert (
            k.size(-1) % self.nhead == 0
        ), "key's embed_dim must be divisible by the number of heads"
        head_dim = k.size(-1) // self.nhead
        k = k.reshape(src_len, bsz * self.nhead, head_dim)

        assert (
            v.size(-1) % self.nhead == 0
        ), "value's embed_dim must be divisible by the number of heads"
        head_dim = v.size(-1) // self.nhead
        v = v.reshape(src_len, bsz * self.nhead, head_dim)

        attn_output, attn_output_weights = self.attention_layer(
            q, k, v, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v
        )
        attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_output_weights


class ScaledDotProduct(torch.nn.Module):
    def __init__(self, dropout=0.0):
        r"""Processes a projected query and key-value pair to apply
        scaled dot product attention.
        Args:
            dropout (float): probability of dropping an attention weight.
        Examples::
            >>> SDP = ScaledDotProduct(0.1)
            >>> q = torch.randn(256, 21, 3)
            >>> k = v = torch.randn(256, 21, 3)
            >>> attn_output, attn_weights = SDP(q, k, v)
            >>> print(attn_output.shape, attn_weights.shape)
            torch.Size([256, 21, 3]) torch.Size([21, 256, 256])
        """
        super().__init__()
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        bias_k: Optional[torch.Tensor] = None,
        bias_v: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Uses a scaled dot product with the projected key-value pair to update
        the projected query.
        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            value (Tensor): Projected value
            attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions.
            bias_k and bias_v (Tensor, optional): one more key and value sequence to be added to the
                sequence dim (dim=-3). These are used for incremental decoding. Users should provide
                non-None values for both arguments in order to activate them.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - value: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend
                while ``False`` values will be unchanged.
            - bias_k and bias_v: :math:`(1, N * H, E / H)`
            - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)`
            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        if bias_k is not None and bias_v is not None:
            assert (
                key.size(-1) == bias_k.size(-1)
                and key.size(-2) == bias_k.size(-2)
                and bias_k.size(-3) == 1
            ), "Shape of bias_k is not supported"
            assert (
                value.size(-1) == bias_v.size(-1)
                and value.size(-2) == bias_v.size(-2)
                and bias_v.size(-3) == 1
            ), "Shape of bias_v is not supported"
            key = torch.cat([key, bias_k])
            value = torch.cat([value, bias_v])
            if attn_mask is not None:
                _attn_mask = attn_mask
                attn_mask = torch.nn.functional.pad(_attn_mask, [0, 1])

        tgt_len, head_dim = query.size(-3), query.size(-1)
        assert (
            query.size(-1) == key.size(-1) == value.size(-1)
        ), "The feature dim of query, key, value must be equal."
        assert key.size() == value.size(), "Shape of key, value must match"
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # Scale query
        query, key, value = (
            query.transpose(-2, -3),
            key.transpose(-2, -3),
            value.transpose(-2, -3),
        )
        query = query * (float(head_dim) ** -0.5)
        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError("attn_mask must be a 3D tensor.")
            if (
                (attn_mask.size(-1) != src_len)
                or (attn_mask.size(-2) != tgt_len)
                or (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads)
            ):
                raise RuntimeError("The size of the attn_mask is not correct.")
            if attn_mask.dtype != torch.bool:
                raise RuntimeError("Only bool tensor is supported for attn_mask")

        # Dot product of q, k
        attn_output_weights = torch.matmul(query, key.mT)
        if attn_mask is not None:
            attn_output_weights.masked_fill_(
                attn_mask,
                -1e8,
            )
        attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output_weights = torch.nn.functional.dropout(
            attn_output_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_output_weights, value)
        return attn_output.transpose(-2, -3), attn_output_weights


class InProjContainer(torch.nn.Module):
    def __init__(self, query_proj, key_proj, value_proj):
        r"""An in-proj container to process inputs.
        Args:
            query_proj: a proj layer for query.
            key_proj: a proj layer for key.
            value_proj: a proj layer for value.
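
        Example (illustrative; the embedding dim and shapes below are arbitrary)::
            >>> in_proj = InProjContainer(torch.nn.Linear(8, 8),
            ...                           torch.nn.Linear(8, 8),
            ...                           torch.nn.Linear(8, 8))
            >>> q = k = v = torch.rand(5, 2, 8)  # (S, N, E)
            >>> q, k, v = in_proj(q, k, v)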
        """

        super().__init__()
        self.query_proj = query_proj
        self.key_proj = key_proj
        self.value_proj = value_proj

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Projects the input sequences using in-proj layers.
        Args:
            query, key, value (Tensors): sequence to be projected
        Shape:
            - query, key, value: :math:`(S, N, E)`
            - Output: :math:`(S, N, E)`
            where S is the sequence length, N is the batch size, and E is the embedding dimension.
        """
        return self.query_proj(query), self.key_proj(key), self.value_proj(value)
699