# xref: /aosp_15_r20/external/executorch/exir/tests/asr_joiner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7*523fa7a6SAndroid Build Coastguard Workerimport torch
8*523fa7a6SAndroid Build Coastguard Workerimport torch.nn.functional as F
9*523fa7a6SAndroid Build Coastguard Workerfrom torch import nn
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker
class ASRJoiner(nn.Module):
    """
    ASR joiner implementation following the code in https://fburl.com/code/ierfau7c

    Having a local implementation means we don't need to pull in the heavy
    dependencies of the original and wait several minutes for tests to run.

    The joiner expands every source item to its H beam-search hypotheses,
    adds source and target encodings on a (T, U) grid, projects the sum with
    ReLU -> Linear, and returns log-softmax scores over the last dimension.
    """

    def __init__(
        self,
        B: int = 1,
        H: int = 10,
        T: int = 1,
        U: int = 1,
        D: int = 768,
        V: int = 4096,
    ) -> None:
        """
        B: source batch size
        H: number of hypotheses for beam search
        T: source sequence length
        U: target sequence length
        D: encoding dimension (input size of the output projection)
        V: output size of the projection (presumably the vocabulary size);
           defaults to 4096 to match the original model described below
        """
        super().__init__()
        self.B, self.H, self.T, self.U, self.D = B, H, T, U, D
        self.V = V
        # Mirrors the original module, which looks like:
        # SequentialContainer(
        #   (module_list): ModuleList(
        #     (0): ReLULayer(inplace=False)
        #     (1): LinearLayer(input_dim=768, output_dim=4096, bias=True, context_dim=0, pruning_aware_training=False, parameter_noise=0.1, qat_qconfig=None, freeze_rex_pattern=None)
        #   )
        # )
        # Only the ReLU -> Linear structure is reproduced here; the extra
        # LinearLayer options (parameter noise, QAT config, ...) are not.
        self.module = nn.Sequential(
            nn.ReLU(),
            nn.Linear(D, V),
        )

    def forward(
        self,
        src_encodings: torch.Tensor,
        src_lengths: torch.Tensor,
        tgt_encodings: torch.Tensor,
        tgt_lengths: torch.Tensor,
    ):
        """
        Join source and target encodings into per-(t, u) log-probabilities.

        One simplification we make here is that we assume both src_encodings
        and tgt_encodings are not None. In the original implementation,
        either can be None.

        Args:
            src_encodings: (B, T, D) source encodings.
            src_lengths: (B,) source lengths; expanded to (B*H,) on output.
            tgt_encodings: (B*H, U, D) target encodings, H hypotheses per
                source item. H is inferred as
                tgt_encodings.shape[0] // src_encodings.shape[0], so the
                first dimension is expected to be a multiple of B.
            tgt_lengths: target lengths; returned unchanged.

        Returns:
            Tuple (output, src_lengths, tgt_lengths). output has shape
            (B*H, T, U, V) and holds log_softmax over its last dimension.
            src_lengths is the (B*H,) expanded version of the input.
        """
        H = tgt_encodings.shape[0] // src_encodings.shape[0]
        B = src_encodings.shape[0]
        # [0]*H + [1]*H + ... + [B-1]*H: repeats each source index once per
        # hypothesis so that row i of the expanded source lines up with row i
        # of tgt_encodings.
        new_order = (
            (torch.arange(B).view(-1, 1).repeat(1, H).view(-1))
            .long()
            .to(device=src_encodings.device)
        )
        # src_encodings: (B, T, D) -> (B*H, T, D)
        src_encodings = torch.index_select(
            src_encodings, dim=0, index=new_order
        ).contiguous()

        # src_lengths: (B,) -> (B*H,)
        src_lengths = torch.index_select(
            src_lengths, dim=0, index=new_order.to(device=src_lengths.device)
        )

        # src_encodings: (B*H, T, D) -> (B*H, T, 1, D)
        src_encodings = src_encodings.unsqueeze(dim=2).contiguous()

        # tgt_encodings: (B*H, U, D) -> (B*H, 1, U, D)
        tgt_encodings = tgt_encodings.unsqueeze(dim=1).contiguous()

        # joint_encodings: (B*H, T, U, D), produced by broadcasting T against U.
        joint_encodings = src_encodings + tgt_encodings

        # (B*H, T, U, D) -> (B*H, T, U, V), normalized over the output dim.
        output = F.log_softmax(self.module(joint_encodings), dim=-1)

        return output, src_lengths, tgt_lengths

    def get_random_inputs(self):
        """
        Return random example inputs shaped the way ``forward`` expects.

        Returns:
            (src_encodings (B, T, D), src_lengths (B,),
             tgt_encodings (B*H, U, D), tgt_lengths (B*H,)).
            Lengths are drawn from [0, 10) and are not tied to T or U.
            forward only re-indexes or passes them through, so their values
            do not affect the output scores.
        """
        return (
            torch.rand(self.B, self.T, self.D),
            torch.randint(0, 10, (self.B,)),
            torch.rand(self.B * self.H, self.U, self.D),
            # One target length per hypothesis, matching tgt_encodings' batch
            # dimension (previously this was wrongly sized (B,)).
            torch.randint(0, 10, (self.B * self.H,)),
        )
83