# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting an RNN-T joiner model to flatbuffer

import logging

import torch

from executorch.backends.cadence.aot.ops_registrations import *  # noqa

from executorch.backends.cadence.aot.export_example import export_model


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":

    class Joiner(torch.nn.Module):
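        """RNN-T joiner module.

        Combines the acoustic encoder output (source encodings) with the
        prediction network output (target encodings) by broadcast addition,
        applies a pointwise activation, and projects the joint representation
        to the output dimension with a single linear layer.
        """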
        def __init__(
            self, input_dim: int, output_dim: int, activation: str = "relu"
        ) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(input_dim, output_dim, bias=True)
            if activation == "relu":
                # pyre-fixme[4]: Attribute must be annotated.
                self.activation = torch.nn.ReLU()
            elif activation == "tanh":
                self.activation = torch.nn.Tanh()
            else:
                raise ValueError(f"Unsupported activation {activation}")

        def forward(
            self,
            source_encodings: torch.Tensor,
            target_encodings: torch.Tensor,
        ) -> torch.Tensor:
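            # source_encodings: (B, T, D); target_encodings: (B, U, D).
            # unsqueeze(2) and unsqueeze(1) align the two for broadcasting,
            # producing joint encodings of shape (B, T, U, D).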
            joint_encodings = (
                source_encodings.unsqueeze(2).contiguous()
                + target_encodings.unsqueeze(1).contiguous()
            )
            activation_out = self.activation(joint_encodings)
            output = self.linear(activation_out)
            return output

    # Instantiate the joiner: 256-dim input encodings, 128-dim output.
    model = Joiner(256, 128)

    # Dummy joiner inputs: a batch of source (encoder) encodings and
    # target (predictor) encodings sharing the 256-dim feature size.
    source_encodings = torch.randn(1, 25, 256)
    target_encodings = torch.randn(1, 10, 256)
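    # With these shapes, the joint encodings broadcast to (1, 25, 10, 256)
    # and the joiner output to (1, 25, 10, 128).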
    example_inputs = (
        source_encodings,
        target_encodings,
    )
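    # Export the joiner with the Cadence AoT export_model helper.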
    export_model(model, example_inputs)