# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from torch import nn


class ASRJoiner(nn.Module):
    """
    ASR joiner implementation following the code in https://fburl.com/code/ierfau7c
    Having a local implementation has the benefit that we don't need to pull
    in the heavy dependencies and wait for a few minutes to run tests.
    """

    def __init__(self, B=1, H=10, T=1, U=1, D=768) -> None:
        """
        B: source batch size
        H: number of hypotheses for beam search
        T: source sequence length
        U: target sequence length
        D: encoding (embedding) dimension
        """
        super().__init__()
        self.B, self.H, self.T, self.U, self.D = B, H, T, U, D
        # The original module looks like:
        # SequentialContainer(
        #     (module_list): ModuleList(
        #         (0): ReLULayer(inplace=False)
        #         (1): LinearLayer(input_dim=768, output_dim=4096, bias=True, context_dim=0, pruning_aware_training=False, parameter_noise=0.1, qat_qconfig=None, freeze_rex_pattern=None)
        #     )
        # )
        self.module = nn.Sequential(
            nn.ReLU(),
            nn.Linear(D, 4096),
        )

    def forward(self, src_encodings, src_lengths, tgt_encodings, tgt_lengths):
        """
        One simplification we make here is that we assume src_encodings and
        tgt_encodings are not None. In the original implementation, either
        can be None.
        """
        H = tgt_encodings.shape[0] // src_encodings.shape[0]
        B = src_encodings.shape[0]
        # Repeat each source row H times so the source encodings line up
        # with the H hypotheses per source in tgt_encodings:
        # new_order = [0, ..., 0, 1, ..., 1, ..., B-1, ..., B-1]
        new_order = (
            (torch.arange(B).view(-1, 1).repeat(1, H).view(-1))
            .long()
            .to(device=src_encodings.device)
        )
        # src_encodings: (B, T, D) -> (B*H, T, D)
        src_encodings = torch.index_select(
            src_encodings, dim=0, index=new_order
        ).contiguous()

        # src_lengths: (B,) -> (B*H,)
        src_lengths = torch.index_select(
            src_lengths, dim=0, index=new_order.to(device=src_lengths.device)
        )

        # src_encodings: (B*H, T, D) -> (B*H, T, 1, D)
        src_encodings = src_encodings.unsqueeze(dim=2).contiguous()

        # tgt_encodings: (B*H, U, D) -> (B*H, 1, U, D)
        tgt_encodings = tgt_encodings.unsqueeze(dim=1).contiguous()

        # joint_encodings: (B*H, T, U, D), broadcasting the two singleton
        # dimensions against each other
        joint_encodings = src_encodings + tgt_encodings

        output = F.log_softmax(self.module(joint_encodings), dim=-1)

        return output, src_lengths, tgt_lengths

    def get_random_inputs(self):
        return (
            torch.rand(self.B, self.T, self.D),
            torch.randint(0, 10, (self.B,)),
            torch.rand(self.B * self.H, self.U, self.D),
            # tgt_lengths matches the (B*H, U, D) target encodings
            torch.randint(0, 10, (self.B * self.H,)),
        )
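

# A minimal usage sketch (not part of the original module): run the joiner on
# its own random inputs and check the expected output shape. With the default
# arguments, the joint encodings broadcast to (B*H, T, U, D), and the
# Linear(D, 4096) + log_softmax head yields (B*H, T, U, 4096).
if __name__ == "__main__":
    joiner = ASRJoiner()  # B=1, H=10, T=1, U=1, D=768
    output, src_lengths, tgt_lengths = joiner(*joiner.get_random_inputs())
    assert output.shape == (joiner.B * joiner.H, joiner.T, joiner.U, 4096)
    # Each (t, u) slot holds a log-probability distribution over 4096 symbols,
    # so exponentiating and summing over the last dim should give 1.
    assert torch.allclose(output.exp().sum(dim=-1), torch.ones_like(output[..., 0]))
    print("output shape:", tuple(output.shape))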