# Taken from https://github.com/pytorch/audio/blob/master/torchaudio/models/wav2letter.py
# So that we don't need torchaudio to be installed

import math
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor


__all__ = ["Wav2Letter"]


class Wav2Letter(nn.Module):
    r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System"
    <https://arxiv.org/abs/1609.03193>`_ paper.
    :math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}`
    Args:
        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
        input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
            or ``mfcc`` (Default: ``waveform``).
        num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
    """

    def __init__(
        self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1
    ) -> None:
        super().__init__()

        acoustic_num_features = 250 if input_type == "waveform" else num_features
        acoustic_model = nn.Sequential(
            nn.Conv1d(
                in_channels=acoustic_num_features,
                out_channels=250,
                kernel_size=48,
                stride=2,
                padding=23,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                in_channels=2000,
                out_channels=num_classes,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ReLU(inplace=True),
        )

        if input_type == "waveform":
            waveform_model = nn.Sequential(
                nn.Conv1d(
                    in_channels=num_features,
                    out_channels=250,
                    kernel_size=250,
                    stride=160,
                    padding=45,
                ),
                nn.ReLU(inplace=True),
            )
            self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)

        if input_type in ["power_spectrum", "mfcc"]:
            self.acoustic_model = acoustic_model

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): Tensor of dimension (batch_size, num_features, input_length).
        Returns:
            Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length).
        """

        x = self.acoustic_model(x)
        x = nn.functional.log_softmax(x, dim=1)
        return x
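

# A minimal usage sketch (not part of the upstream file; shapes follow the docstring
# above): with input_type="waveform" and num_features=1 the model takes a batch of raw
# audio shaped (batch_size, 1, num_samples) and returns per-frame class log-probabilities.
#
#   model = Wav2Letter(num_classes=40, input_type="waveform", num_features=1)
#   waveform = torch.randn(4, 1, 16000)   # hypothetical batch of 1-second 16 kHz clips
#   log_probs = model(waveform)           # -> (4, 40, reduced_time)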


# Taken from https://github.com/SeanNaren/deepspeech.pytorch with modifications
class SequenceWise(nn.Module):
    def __init__(self, module):
        """
        Collapses input of dim T*N*H to (T*N)*H, and applies the given module.
        Allows handling of variable sequence lengths and minibatch sizes.
        :param module: Module to apply input to.
        """
        super().__init__()
        self.module = module

    def forward(self, x):
        t, n = x.size(0), x.size(1)
        x = x.view(t * n, -1)
        x = self.module(x)
        x = x.view(t, n, -1)
        return x

    def __repr__(self):
        tmpstr = self.__class__.__name__ + " (\n"
        tmpstr += self.module.__repr__()
        tmpstr += ")"
        return tmpstr


class MaskConv(nn.Module):
    def __init__(self, seq_module):
        """
        Zeroes out the output of each module past the given sequence lengths. This is to ensure that the
        results of the model do not change when batch sizes change during inference.
        Input needs to be in the shape of (BxCxDxT)
        :param seq_module: The sequential module containing the conv stack.
        """
        super().__init__()
        self.seq_module = seq_module

    def forward(self, x, lengths):
        """
        :param x: The input of size BxCxDxT
        :param lengths: The actual length of each sequence in the batch
        :return: Masked output from the module
        """
        for module in self.seq_module:
            x = module(x)
            mask = torch.BoolTensor(x.size()).fill_(0)
            if x.is_cuda:
                mask = mask.cuda()
            for i, length in enumerate(lengths):
                length = length.item()
                if (mask[i].size(2) - length) > 0:
                    mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)
            x = x.masked_fill(mask, 0)
        return x, lengths


class InferenceBatchSoftmax(nn.Module):
    def forward(self, input_):
        if not self.training:
            return F.softmax(input_, dim=-1)
        else:
            return input_


class BatchRNN(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        rnn_type=nn.LSTM,
        bidirectional=False,
        batch_norm=True,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.batch_norm = (
            SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None
        )
        self.rnn = rnn_type(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=bidirectional,
            bias=True,
        )
        self.num_directions = 2 if bidirectional else 1

    def flatten_parameters(self):
        self.rnn.flatten_parameters()

    def forward(self, x, output_lengths):
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        x = nn.utils.rnn.pack_padded_sequence(x, output_lengths, enforce_sorted=False)
        x, h = self.rnn(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x)
        if self.bidirectional:
            x = (
                x.view(x.size(0), x.size(1), 2, -1)
                .sum(2)
                .view(x.size(0), x.size(1), -1)
            )  # (TxNxH*2) -> (TxNxH) by sum
        return x
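

# Shape sketch for the bidirectional sum in BatchRNN.forward (illustrative only, values
# assumed): pad_packed_sequence returns (T, N, 2 * H) when the RNN is bidirectional; the
# view/sum/view above collapses the two directions into one H-sized feature per time step
# instead of concatenating them.
#
#   rnn = BatchRNN(input_size=32, hidden_size=64, bidirectional=True, batch_norm=False)
#   x = torch.randn(50, 8, 32)                       # (T, N, input_size)
#   lengths = torch.full((8,), 50, dtype=torch.int)  # every sequence uses all 50 steps
#   out = rnn(x, lengths)                            # -> (50, 8, 64), not (50, 8, 128)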


class Lookahead(nn.Module):
    # Wang et al., 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    # input shape - sequence, batch, feature - TxNxH
    # output shape - same as input
    def __init__(self, n_features, context):
        super().__init__()
        assert context > 0
        self.context = context
        self.n_features = n_features
        self.pad = (0, self.context - 1)
        self.conv = nn.Conv1d(
            self.n_features,
            self.n_features,
            kernel_size=self.context,
            stride=1,
            groups=self.n_features,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        x = x.transpose(0, 1).transpose(1, 2)
        x = F.pad(x, pad=self.pad, value=0)
        x = self.conv(x)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()
        return x

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + "n_features="
            + str(self.n_features)
            + ", context="
            + str(self.context)
            + ")"
        )
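

# Illustrative sketch of Lookahead (not from the upstream repo; sizes assumed): right-padding
# the time axis by (context - 1) and applying a depthwise (groups=n_features) convolution of
# kernel size `context` makes each output frame t a learned combination of frames
# t .. t + context - 1, so the layer peeks at future context without changing the length.
#
#   look = Lookahead(n_features=16, context=20)
#   x = torch.randn(100, 4, 16)   # (T, N, H)
#   y = look(x)                   # -> (100, 4, 16), same shape as the input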


class DeepSpeech(nn.Module):
    def __init__(
        self,
        rnn_type,
        labels,
        rnn_hidden_size,
        nb_layers,
        audio_conf,
        bidirectional,
        context=20,
    ):
        super().__init__()

        self.hidden_size = rnn_hidden_size
        self.hidden_layers = nb_layers
        self.rnn_type = rnn_type
        self.audio_conf = audio_conf
        self.labels = labels
        self.bidirectional = bidirectional

        sample_rate = self.audio_conf["sample_rate"]
        window_size = self.audio_conf["window_size"]
        num_classes = len(self.labels)

        self.conv = MaskConv(
            nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
                nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
                nn.BatchNorm2d(32),
                nn.Hardtanh(0, 20, inplace=True),
            )
        )
        # Based on above convolutions and spectrogram size using conv formula (W - F + 2P) / S + 1
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)
        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)
        rnn_input_size *= 32

        rnns = []
        rnn = BatchRNN(
            input_size=rnn_input_size,
            hidden_size=rnn_hidden_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional,
            batch_norm=False,
        )
        rnns.append(("0", rnn))
        for x in range(nb_layers - 1):
            rnn = BatchRNN(
                input_size=rnn_hidden_size,
                hidden_size=rnn_hidden_size,
                rnn_type=rnn_type,
                bidirectional=bidirectional,
            )
            rnns.append(("%d" % (x + 1), rnn))
        self.rnns = nn.Sequential(OrderedDict(rnns))
        self.lookahead = (
            nn.Sequential(
                # consider adding batch norm?
                Lookahead(rnn_hidden_size, context=context),
                nn.Hardtanh(0, 20, inplace=True),
            )
            if not bidirectional
            else None
        )

        fully_connected = nn.Sequential(
            nn.BatchNorm1d(rnn_hidden_size),
            nn.Linear(rnn_hidden_size, num_classes, bias=False),
        )
        self.fc = nn.Sequential(
            SequenceWise(fully_connected),
        )
        self.inference_softmax = InferenceBatchSoftmax()

    def forward(self, x, lengths):
        lengths = lengths.cpu().int()
        output_lengths = self.get_seq_lens(lengths)
        x, _ = self.conv(x, output_lengths)

        sizes = x.size()
        x = x.view(
            sizes[0], sizes[1] * sizes[2], sizes[3]
        )  # Collapse feature dimension
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH

        for rnn in self.rnns:
            x = rnn(x, output_lengths)

        if not self.bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        # identity in training mode, softmax in eval mode
        x = self.inference_softmax(x)
        return x, output_lengths

    def get_seq_lens(self, input_length):
        """
        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable
        containing the sequence lengths that will be output by the network.
        :param input_length: 1D Tensor
        :return: 1D Tensor scaled by model
        """
        seq_len = input_length
        for m in self.conv.modules():
            if type(m) == nn.modules.conv.Conv2d:
                seq_len = (
                    seq_len
                    + 2 * m.padding[1]
                    - m.dilation[1] * (m.kernel_size[1] - 1)
                    - 1
                )
                seq_len = seq_len.true_divide(m.stride[1]) + 1
        return seq_len.int()
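

# Worked example (assumed, not from the upstream repo) of the rnn_input_size arithmetic
# above, using a hypothetical audio_conf of sample_rate=16000 and window_size=0.02:
#   spectrogram bins:              floor(16000 * 0.02 / 2) + 1 = 161
#   after conv1 (F=41, P=20, S=2): (161 + 40 - 41) / 2 + 1     = 81
#   after conv2 (F=21, P=10, S=2): (81 + 20 - 21) / 2 + 1      = 41
#   times 32 output channels:      41 * 32                     = 1312
# so the first BatchRNN would see 1312 features per time step for that configuration.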


# Taken from https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L108-L152
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
    where :math:`pos` is the word position and :math:`i` is the embed idx.
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)


class TransformerModel(nn.Module):
    """Container module with an encoder, a recurrent or transformer module, and a decoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super().__init__()
        try:
            from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except Exception as e:
            raise ImportError(
                "TransformerEncoder module does not exist in PyTorch 1.1 or lower."
            ) from e
        self.model_type = "Transformer"
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        # Not sure how this works in the original code
        # nn.init.zeros_(self.decoder)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            # This will be created once during warmup
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(
                    device
                )
                self.src_mask = mask
        else:
            self.src_mask = None

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)
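

# A minimal usage sketch for TransformerModel (values assumed, not part of the upstream
# example): src holds token indices shaped (seq_len, batch), and the output holds
# log-probabilities over the vocabulary for every position.
#
#   model = TransformerModel(ntoken=10000, ninp=200, nhead=2, nhid=200, nlayers=2)
#   src = torch.randint(0, 10000, (35, 20))   # (seq_len, batch) of token ids
#   out = model(src)                          # -> (35, 20, 10000) log-probabilities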


# From https://github.com/pytorch/text/blob/master/torchtext/modules
class MultiheadAttentionContainer(torch.nn.Module):
    def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
        r"""A multi-head attention container
        Args:
            nhead: the number of heads in the multiheadattention model
            in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear).
            attention_layer: The attention layer.
            out_proj: The multi-head out-projection layer (a.k.a nn.Linear).
        Examples::
            >>> import torch
            >>> embed_dim, num_heads, bsz = 10, 5, 64
            >>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                                    torch.nn.Linear(embed_dim, embed_dim),
                                                    torch.nn.Linear(embed_dim, embed_dim))
            >>> MHA = MultiheadAttentionContainer(num_heads,
                                                  in_proj_container,
                                                  ScaledDotProduct(),
                                                  torch.nn.Linear(embed_dim, embed_dim))
            >>> query = torch.rand((21, bsz, embed_dim))
            >>> key = value = torch.rand((16, bsz, embed_dim))
            >>> attn_output, attn_weights = MHA(query, key, value)
            >>> print(attn_output.shape)
            torch.Size([21, 64, 10])
        """
        super().__init__()
        self.nhead = nhead
        self.in_proj_container = in_proj_container
        self.attention_layer = attention_layer
        self.out_proj = out_proj

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        bias_k: Optional[torch.Tensor] = None,
        bias_v: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            query, key, value (Tensor): map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            attn_mask, bias_k and bias_v (Tensor, optional): keyword arguments passed to the attention layer.
                See the definitions in the attention layer.
        Shape:
            - Inputs:
                - query: :math:`(L, N, E)`
                - key: :math:`(S, N, E)`
                - value: :math:`(S, N, E)`
                - attn_mask, bias_k and bias_v: same as the shape of the corresponding args in the attention layer.
            - Outputs:
                - attn_output: :math:`(L, N, E)`
                - attn_output_weights: :math:`(N * H, L, S)`
            where L is the target length, S is the sequence length, H is the number of attention heads,
            N is the batch size, and E is the embedding dimension.
        """
        tgt_len, src_len, bsz, embed_dim = (
            query.size(-3),
            key.size(-3),
            query.size(-2),
            query.size(-1),
        )
        q, k, v = self.in_proj_container(query, key, value)
        assert (
            q.size(-1) % self.nhead == 0
        ), "query's embed_dim must be divisible by the number of heads"
        head_dim = q.size(-1) // self.nhead
        q = q.reshape(tgt_len, bsz * self.nhead, head_dim)

        assert (
            k.size(-1) % self.nhead == 0
        ), "key's embed_dim must be divisible by the number of heads"
        head_dim = k.size(-1) // self.nhead
        k = k.reshape(src_len, bsz * self.nhead, head_dim)

        assert (
            v.size(-1) % self.nhead == 0
        ), "value's embed_dim must be divisible by the number of heads"
        head_dim = v.size(-1) // self.nhead
        v = v.reshape(src_len, bsz * self.nhead, head_dim)

        attn_output, attn_output_weights = self.attention_layer(
            q, k, v, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v
        )
        attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_output_weights
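

# Head-split sketch for MultiheadAttentionContainer.forward (illustrative, sizes assumed):
# with embed_dim E=10 and nhead H=5, each projected tensor of shape (L, N, E) is reshaped
# to (L, N * H, E / H) so that ScaledDotProduct treats every head as an extra batch entry.
#
#   q = torch.rand(21, 64, 10)           # (L, N, E)
#   q_heads = q.reshape(21, 64 * 5, 2)   # (L, N * H, E / H), mirroring forward()
#   assert q_heads.shape == (21, 320, 2)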


class ScaledDotProduct(torch.nn.Module):
    def __init__(self, dropout=0.0):
        r"""Processes a projected query and key-value pair to apply
        scaled dot product attention.
        Args:
            dropout (float): probability of dropping an attention weight.
        Examples::
            >>> SDP = torchtext.models.ScaledDotProduct(0.1)
            >>> q = torch.randn(256, 21, 3)
            >>> k = v = torch.randn(256, 21, 3)
            >>> attn_output, attn_weights = SDP(q, k, v)
            >>> print(attn_output.shape, attn_weights.shape)
            torch.Size([256, 21, 3]) torch.Size([21, 256, 256])
        """
        super().__init__()
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        bias_k: Optional[torch.Tensor] = None,
        bias_v: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Uses a scaled dot product with the projected key-value pair to update
        the projected query.
        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            value (Tensor): Projected value
            attn_mask (BoolTensor, optional): 3D mask that prevents attention to certain positions.
            bias_k and bias_v (Tensor, optional): one more key and value sequence to be added at
                sequence dim (dim=-3). Those are used for incremental decoding. Users should provide
                non-None to both arguments in order to activate them.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - value: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend
              while ``False`` values will be unchanged.
            - bias_k and bias_v: :math:`(1, N * H, E / H)`
            - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)`
            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        if bias_k is not None and bias_v is not None:
            assert (
                key.size(-1) == bias_k.size(-1)
                and key.size(-2) == bias_k.size(-2)
                and bias_k.size(-3) == 1
            ), "Shape of bias_k is not supported"
            assert (
                value.size(-1) == bias_v.size(-1)
                and value.size(-2) == bias_v.size(-2)
                and bias_v.size(-3) == 1
            ), "Shape of bias_v is not supported"
            key = torch.cat([key, bias_k])
            value = torch.cat([value, bias_v])
            if attn_mask is not None:
                _attn_mask = attn_mask
                attn_mask = torch.nn.functional.pad(_attn_mask, [0, 1])

        tgt_len, head_dim = query.size(-3), query.size(-1)
        assert (
            query.size(-1) == key.size(-1) == value.size(-1)
        ), "The feature dim of query, key, value must be equal."
        assert key.size() == value.size(), "Shape of key, value must match"
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # Scale query
        query, key, value = (
            query.transpose(-2, -3),
            key.transpose(-2, -3),
            value.transpose(-2, -3),
        )
        query = query * (float(head_dim) ** -0.5)
        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError("attn_mask must be a 3D tensor.")
            if (
                (attn_mask.size(-1) != src_len)
                or (attn_mask.size(-2) != tgt_len)
                or (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads)
            ):
                raise RuntimeError("The size of the attn_mask is not correct.")
            if attn_mask.dtype != torch.bool:
                raise RuntimeError("Only bool tensor is supported for attn_mask")

        # Dot product of q, k
        attn_output_weights = torch.matmul(query, key.mT)
        if attn_mask is not None:
            attn_output_weights.masked_fill_(
                attn_mask,
                -1e8,
            )
        attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output_weights = torch.nn.functional.dropout(
            attn_output_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_output_weights, value)
        return attn_output.transpose(-2, -3), attn_output_weights
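

# Illustrative sketch (values assumed) of how ScaledDotProduct consumes head-split tensors
# and a boolean mask: entries of attn_mask that are True are filled with a large negative
# value before the softmax, so those positions receive ~zero attention weight.
#
#   sdp = ScaledDotProduct(dropout=0.0)
#   L, S, NH, HD = 4, 6, 10, 8                      # target len, source len, batch*heads, head dim
#   q = torch.randn(L, NH, HD)
#   k = v = torch.randn(S, NH, HD)
#   mask = torch.zeros(NH, L, S, dtype=torch.bool)  # nothing masked
#   out, w = sdp(q, k, v, attn_mask=mask)           # out: (L, NH, HD), w: (NH, L, S)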


class InProjContainer(torch.nn.Module):
    def __init__(self, query_proj, key_proj, value_proj):
        r"""An in-proj container to process inputs.
        Args:
            query_proj: a proj layer for query.
            key_proj: a proj layer for key.
            value_proj: a proj layer for value.
        """

        super().__init__()
        self.query_proj = query_proj
        self.key_proj = key_proj
        self.value_proj = value_proj

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Projects the input sequences using the in-proj layers.
        Args:
            query, key, value (Tensors): sequence to be projected
        Shape:
            - query, key, value: :math:`(S, N, E)`
            - Output: :math:`(S, N, E)`
            where S is the sequence length, N is the batch size, and E is the embedding dimension.
        """
        return self.query_proj(query), self.key_proj(key), self.value_proj(value)
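

# A small, self-contained smoke test (an illustrative addition, not part of the upstream
# torchaudio/deepspeech/torchtext sources). It wires InProjContainer, ScaledDotProduct and
# MultiheadAttentionContainer together with assumed sizes and checks the output shapes;
# it only runs when this module is executed directly.
if __name__ == "__main__":
    embed_dim, num_heads, bsz = 10, 5, 64
    in_proj = InProjContainer(
        torch.nn.Linear(embed_dim, embed_dim),
        torch.nn.Linear(embed_dim, embed_dim),
        torch.nn.Linear(embed_dim, embed_dim),
    )
    mha = MultiheadAttentionContainer(
        num_heads, in_proj, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)
    )
    query = torch.rand(21, bsz, embed_dim)
    key = value = torch.rand(16, bsz, embed_dim)
    attn_output, attn_weights = mha(query, key, value)
    assert attn_output.shape == (21, bsz, embed_dim)
    assert attn_weights.shape == (bsz * num_heads, 21, 16)
    print("MultiheadAttentionContainer smoke test passed:", attn_output.shape)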