# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import logging
from typing import List, Optional, Tuple

import torch

from executorch.backends.cadence.aot.ops_registrations import *  # noqa
from executorch.backends.cadence.aot.export_example import export_model

from torchaudio.prototype.models import ConvEmformer

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":

    class _TimeReduction(torch.nn.Module):
        """Stacks every `stride` consecutive frames along the feature
        dimension, reducing the sequence length by a factor of `stride`."""

        def __init__(self, stride: int) -> None:
            super().__init__()
            self.stride = stride

        def forward(
            self, input: torch.Tensor, lengths: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            B, T, D = input.shape
            # Drop trailing frames so T is divisible by the stride.
            num_frames = T - (T % self.stride)
            input = input[:, :num_frames, :]
            lengths = lengths.div(self.stride, rounding_mode="trunc")
            T_max = num_frames // self.stride
            output = input.reshape(B, T_max, D * self.stride)
            output = output.contiguous()
            return output, lengths

    class ConvEmformerEncoder(torch.nn.Module):
        """Time reduction -> linear projection -> ConvEmformer -> linear
        projection -> layer norm, following the streaming RNN-T encoder
        recipe from torchaudio."""

        def __init__(
            self,
            *,
            input_dim: int,
            output_dim: int,
            segment_length: int,
            kernel_size: int,
            right_context_length: int,
            time_reduction_stride: int,
            transformer_input_dim: int,
            transformer_num_heads: int,
            transformer_ffn_dim: int,
            transformer_num_layers: int,
            transformer_left_context_length: int,
            transformer_dropout: float = 0.0,
            transformer_activation: str = "relu",
            transformer_max_memory_size: int = 0,
            transformer_weight_init_scale_strategy: str = "depthwise",
            transformer_tanh_on_mem: bool = False,
        ) -> None:
            super().__init__()
            self.time_reduction = _TimeReduction(time_reduction_stride)
            self.input_linear = torch.nn.Linear(
                input_dim * time_reduction_stride,
                transformer_input_dim,
                bias=False,
            )
            self.transformer = ConvEmformer(
                transformer_input_dim,
                transformer_num_heads,
                transformer_ffn_dim,
                transformer_num_layers,
                # Segment and right-context lengths are expressed in frames
                # before time reduction, so scale them down by the stride.
                segment_length // time_reduction_stride,
                kernel_size=kernel_size,
                dropout=transformer_dropout,
                ffn_activation=transformer_activation,
                left_context_length=transformer_left_context_length,
                right_context_length=right_context_length // time_reduction_stride,
                max_memory_size=transformer_max_memory_size,
                weight_init_scale_strategy=transformer_weight_init_scale_strategy,
                tanh_on_mem=transformer_tanh_on_mem,
                conv_activation="silu",
            )
            self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
            self.layer_norm = torch.nn.LayerNorm(output_dim)

        def forward(
            self, input: torch.Tensor, lengths: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            time_reduction_out, time_reduction_lengths = self.time_reduction(
                input, lengths
            )
            input_linear_out = self.input_linear(time_reduction_out)
            transformer_out, transformer_lengths = self.transformer(
                input_linear_out, time_reduction_lengths
            )
            output_linear_out = self.output_linear(transformer_out)
            layer_norm_out = self.layer_norm(output_linear_out)
            return layer_norm_out, transformer_lengths

        @torch.jit.export
        def infer(
            self,
            input: torch.Tensor,
            lengths: torch.Tensor,
            states: Optional[List[List[torch.Tensor]]],
        ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
            time_reduction_out, time_reduction_lengths = self.time_reduction(
                input, lengths
            )
            input_linear_out = self.input_linear(time_reduction_out)
            (
                transformer_out,
                transformer_lengths,
                transformer_states,
            ) = self.transformer.infer(
                input_linear_out, time_reduction_lengths, states
            )
            output_linear_out = self.output_linear(transformer_out)
            layer_norm_out = self.layer_norm(output_linear_out)
            return layer_norm_out, transformer_lengths, transformer_states

    # Instantiate model
    time_reduction_stride = 4
    encoder = ConvEmformerEncoder(
        input_dim=80,
        output_dim=256,
        segment_length=4 * time_reduction_stride,
        kernel_size=7,
        right_context_length=1 * time_reduction_stride,
        time_reduction_stride=time_reduction_stride,
        transformer_input_dim=128,
        transformer_num_heads=4,
        transformer_ffn_dim=512,
        transformer_num_layers=1,
        transformer_left_context_length=10,
        transformer_tanh_on_mem=True,
    )

    # Batch size
    batch_size = 1
    max_input_length = 100
    input_dim = 80
    right_context_length = 4

    # Dummy inputs; the input is right-padded with right-context frames.
    transcriber_input = torch.rand(
        batch_size, max_input_length + right_context_length, input_dim
    )
    transcriber_lengths = torch.randint(1, max_input_length + 1, (batch_size,))

    example_inputs = (
        transcriber_input,
        transcriber_lengths,
    )

    export_model(encoder, example_inputs)
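
    # Optional eager-mode sanity check (an illustrative addition, not part of
    # the original export flow): run the encoder on the same dummy inputs that
    # export_model traces with, and log the resulting shapes. The variable
    # names below are local to this sketch.
    with torch.no_grad():
        eager_out, eager_lengths = encoder(transcriber_input, transcriber_lengths)
    logging.info(
        "Eager sanity check: output shape %s, lengths %s",
        tuple(eager_out.shape),
        eager_lengths.tolist(),
    )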