# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from torch import nn


class ASRJoiner(nn.Module):
    """
    ASR joiner implementation following the code in https://fburl.com/code/ierfau7c
    Having a local implementation has the benefit that we don't need to pull in
    the heavy dependencies and wait a few minutes for tests to run.
    """

    def __init__(self, B=1, H=10, T=1, U=1, D=768) -> None:
        """
        B: source batch size
        H: number of hypotheses for beam search
        T: source sequence length
        U: target sequence length
        D: encoding dimension
        """
        super().__init__()
        self.B, self.H, self.T, self.U, self.D = B, H, T, U, D
        # The original module looks like:
        # SequentialContainer(
        #   (module_list): ModuleList(
        #     (0): ReLULayer(inplace=False)
        #     (1): LinearLayer(input_dim=768, output_dim=4096, bias=True, context_dim=0, pruning_aware_training=False, parameter_noise=0.1, qat_qconfig=None, freeze_rex_pattern=None)
        #   )
        # )
        self.module = nn.Sequential(
            nn.ReLU(),
            nn.Linear(D, 4096),
        )

    def forward(self, src_encodings, src_lengths, tgt_encodings, tgt_lengths):
        """
        One simplification we make here is that we assume src_encodings and
        tgt_encodings are not None. In the original implementation, either
        can be None.
        """
        H = tgt_encodings.shape[0] // src_encodings.shape[0]
        B = src_encodings.shape[0]
        # Repeat each batch index H times (e.g. [0, 1] -> [0, 0, 1, 1] for H=2)
        # so each source sequence lines up with its H hypotheses.
        new_order = (
            (torch.arange(B).view(-1, 1).repeat(1, H).view(-1))
            .long()
            .to(device=src_encodings.device)
        )
        # src_encodings: (B, T, D) -> (B*H, T, D)
        src_encodings = torch.index_select(
            src_encodings, dim=0, index=new_order
        ).contiguous()

        # src_lengths: (B,) -> (B*H,)
        src_lengths = torch.index_select(
            src_lengths, dim=0, index=new_order.to(device=src_lengths.device)
        )

        # src_encodings: (B*H, T, D) -> (B*H, T, 1, D)
        src_encodings = src_encodings.unsqueeze(dim=2).contiguous()

        # tgt_encodings: (B*H, U, D) -> (B*H, 1, U, D)
        tgt_encodings = tgt_encodings.unsqueeze(dim=1).contiguous()

        # joint_encodings: (B*H, T, U, D), broadcast over the T and U axes
        joint_encodings = src_encodings + tgt_encodings

        output = F.log_softmax(self.module(joint_encodings), dim=-1)

        return output, src_lengths, tgt_lengths

    def get_random_inputs(self):
        return (
            torch.rand(self.B, self.T, self.D),
            torch.randint(0, 10, (self.B,)),
            torch.rand(self.B * self.H, self.U, self.D),
            torch.randint(0, 10, (self.B,)),
        )
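
# A minimal smoke test sketching how the joiner is exercised; this block is an
# illustrative addition, not part of the original file. It builds a joiner with
# the default dimensions, runs the random inputs from get_random_inputs()
# through forward(), and checks the shape of the joint log-probabilities.
if __name__ == "__main__":
    joiner = ASRJoiner()
    src_encodings, src_lengths, tgt_encodings, tgt_lengths = (
        joiner.get_random_inputs()
    )
    output, src_lengths, tgt_lengths = joiner(
        src_encodings, src_lengths, tgt_encodings, tgt_lengths
    )
    # One 4096-way log-softmax distribution per (source, target) position pair.
    assert output.shape == (joiner.B * joiner.H, joiner.T, joiner.U, 4096)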