xref: /aosp_15_r20/external/executorch/examples/models/wav2letter/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8
9import torch
10from torchaudio import models
11
12from ..model_base import EagerModelBase
13
14
class Wav2LetterModel(EagerModelBase):
    """Eager-mode wrapper around torchaudio's ``Wav2Letter`` for export examples.

    The constructor defaults reproduce the original hard-coded configuration
    (batch of 10, 700 input frames, 4096 output classes). Callers may override
    any of them, and existing no-argument construction keeps working unchanged.
    """

    def __init__(
        self,
        batch_size: int = 10,
        input_frames: int = 700,
        vocab_size: int = 4096,
    ) -> None:
        """Store the model and example-input configuration.

        Args:
            batch_size: Leading dimension of the example input tensor.
            input_frames: Size of the last dimension of the example input tensor.
            vocab_size: Passed to ``Wav2Letter`` as ``num_classes``.

        Raises:
            ValueError: If any size is not strictly positive.
        """
        for name, value in (
            ("batch_size", batch_size),
            ("input_frames", input_frames),
            ("vocab_size", vocab_size),
        ):
            # Fail early with a clear message rather than deep inside torch.
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value!r}")
        self.batch_size = batch_size
        self.input_frames = input_frames
        self.vocab_size = vocab_size

    def get_eager_model(self) -> torch.nn.Module:
        """Construct and return a ``Wav2Letter`` module with ``vocab_size`` classes.

        No checkpoint is loaded here, so the returned module carries whatever
        initial weights ``Wav2Letter`` creates on construction.
        """
        logging.info("Loading wav2letter model")
        wav2letter = models.Wav2Letter(num_classes=self.vocab_size)
        logging.info("Loaded wav2letter model")
        return wav2letter

    def get_example_inputs(self):
        """Return a 1-tuple holding a random ``(batch_size, 1, input_frames)`` tensor.

        The singleton middle dimension is presumably the single input-feature
        channel of Wav2Letter's default configuration. TODO: confirm this if the
        model construction above ever changes.
        """
        input_shape = (self.batch_size, 1, self.input_frames)
        return (torch.randn(input_shape),)
30