xref: /aosp_15_r20/external/executorch/exir/tests/asr_joiner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8import torch.nn.functional as F
9from torch import nn
10
11
class ASRJoiner(nn.Module):
    """
    ASR joiner implementation following the code in https://fburl.com/code/ierfau7c
    Having a local implementation has the benefit that we don't need to pull in
    the heavy dependencies and wait for a few minutes to run tests.

    The joiner adds every source-frame encoding to every target-token encoding
    and passes the sum through ReLU -> Linear(D, 4096) followed by a
    log-softmax over the last dimension.
    """

    def __init__(
        self, B: int = 1, H: int = 10, T: int = 1, U: int = 1, D: int = 768
    ) -> None:
        """
        The sizes below are stored on the instance and are only used by
        get_random_inputs() (plus D, which sizes the linear layer); forward()
        reads the actual sizes from its inputs.

        B: source batch size
        H: number of hypotheses for beam search
        T: source sequence length
        U: target sequence length
        D: encoding dimension (input feature size of the linear layer)
        """
        super().__init__()
        self.B, self.H, self.T, self.U, self.D = B, H, T, U, D
        # The module looks like:
        # SequentialContainer(
        #   (module_list): ModuleList(
        #     (0): ReLULayer(inplace=False)
        #     (1): LinearLayer(input_dim=768, output_dim=4096, bias=True, context_dim=0, pruning_aware_training=False, parameter_noise=0.1, qat_qconfig=None, freeze_rex_pattern=None)
        #   )
        # )
        self.module = nn.Sequential(
            nn.ReLU(),
            nn.Linear(D, 4096),
        )

    def forward(
        self,
        src_encodings: torch.Tensor,
        src_lengths: torch.Tensor,
        tgt_encodings: torch.Tensor,
        tgt_lengths: torch.Tensor,
    ):
        """
        Join source and target encodings and return per-position log-probabilities.

        One simplification we make here is we assume src_encodings and tgt_encodings
        are not None. In the original implementation, either can be None.

        Args:
            src_encodings: (B, T, D) source encodings.
            src_lengths: (B,) source lengths.
            tgt_encodings: (B*H, U, D) target encodings. The number of
                hypotheses H is derived as tgt batch // src batch, so the
                target batch is assumed to be a multiple of the source batch.
            tgt_lengths: target lengths; returned unchanged.

        Returns:
            A tuple (output, src_lengths, tgt_lengths) where output is the
            (B*H, T, U, 4096) log-softmax of the joined encodings, src_lengths
            has been expanded to (B*H,), and tgt_lengths is the input as-is.
        """
        H = tgt_encodings.shape[0] // src_encodings.shape[0]
        B = src_encodings.shape[0]
        # new_order is [0]*H + [1]*H + ... + [B-1]*H, i.e. every source row is
        # repeated H consecutive times so it lines up with its hypotheses.
        new_order = (
            (torch.arange(B).view(-1, 1).repeat(1, H).view(-1))
            .long()
            .to(device=src_encodings.device)
        )
        # src_encodings: (B, T, D) -> (B*H, T, D)
        src_encodings = torch.index_select(
            src_encodings, dim=0, index=new_order
        ).contiguous()

        # src_lengths: (B,) -> (B*H,)
        src_lengths = torch.index_select(
            src_lengths, dim=0, index=new_order.to(device=src_lengths.device)
        )

        # src_encodings: (B*H, T, D) -> (B*H, T, 1, D)
        src_encodings = src_encodings.unsqueeze(dim=2).contiguous()

        # tgt_encodings: (B*H, U, D) -> (B*H, 1, U, D)
        tgt_encodings = tgt_encodings.unsqueeze(dim=1).contiguous()

        # joint_encodings: (B*H, T, U, D), produced by broadcasting the two
        # singleton dimensions against each other.
        joint_encodings = src_encodings + tgt_encodings

        output = F.log_softmax(self.module(joint_encodings), dim=-1)

        return output, src_lengths, tgt_lengths

    def get_random_inputs(self):
        """
        Build a random (src_encodings, src_lengths, tgt_encodings, tgt_lengths)
        tuple sized from the B/H/T/U/D values given to the constructor.

        The length tensors hold arbitrary integers in [0, 10); they are not
        tied to T or U. tgt_lengths has one entry per target row (B*H), so
        that it matches tgt_encodings and the expanded src_lengths returned
        by forward().
        """
        return (
            torch.rand(self.B, self.T, self.D),
            torch.randint(0, 10, (self.B,)),
            torch.rand(self.B * self.H, self.U, self.D),
            torch.randint(0, 10, (self.B * self.H,)),
        )
83