import torchaudio_models as models
from utils import check_for_functorch, extract_weights, GetterReturnType, load_weights

import torch
from torch import nn, Tensor


has_functorch = check_for_functorch()


def get_wav2letter(device: torch.device) -> GetterReturnType:
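    """Wav2Letter with NLLLoss on random waveform input.

    Like the other getters in this file, returns a functional ``forward`` that
    reloads the given parameters into the model before computing the loss,
    together with the extracted parameter tensors (see ``GetterReturnType``).
    """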
    N = 10
    input_frames = 700
    vocab_size = 28
    model = models.Wav2Letter(num_classes=vocab_size)
    criterion = torch.nn.NLLLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 1, input_frames], device=device)
    labels = torch.rand(N, 3, device=device).mul(vocab_size).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params


def get_deepspeech(device: torch.device) -> GetterReturnType:
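    """Bidirectional-LSTM DeepSpeech with CTCLoss on random spectrogram input."""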
    sample_rate = 16000
    window_size = 0.02
    window = "hamming"
    audio_conf = dict(
        sample_rate=sample_rate, window_size=window_size, window=window, noise_dir=None
    )

    N = 10
    num_classes = 10
    spectrogram_size = 161
    # The commented values are the original sizes used in the upstream code
    seq_length = 500  # 1343
    target_length = 10  # 50
    labels = torch.rand(num_classes, device=device)
    inputs = torch.rand(N, 1, spectrogram_size, seq_length, device=device)
    # Sequence length for each input
    inputs_sizes = (
        torch.rand(N, device=device).mul(seq_length * 0.1).add(seq_length * 0.8)
    )
    targets = torch.rand(N, target_length, device=device)
    targets_sizes = torch.full((N,), target_length, dtype=torch.int, device=device)
    model = models.DeepSpeech(
        rnn_type=nn.LSTM,
        labels=labels,
        rnn_hidden_size=1024,
        nb_layers=5,
        audio_conf=audio_conf,
        bidirectional=True,
    )

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        # Patch batch norm modules for compatibility with functorch transforms
        replace_all_batch_norm_modules_(model)

    model = model.to(device)
    criterion = nn.CTCLoss()
    params, names = extract_weights(model)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out, out_sizes = model(inputs, inputs_sizes)
        out = out.transpose(0, 1)  # CTC loss expects (T, N, C)

        loss = criterion(out, targets, out_sizes, targets_sizes)
        return loss

    return forward, params


def get_transformer(device: torch.device) -> GetterReturnType:
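    """TransformerModel language model with NLLLoss on random token sequences."""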
    # For most SOTA research you would use an embedding size (ninp) of 720,
    # nhead=12, bsz=64, and tgt_len/src_len=128.
    N = 64
    seq_length = 128
    ntoken = 50
    model = models.TransformerModel(
        ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2
    )
    model.to(device)

    if has_functorch:
        # Disable dropout for consistency checking
        model.eval()

    criterion = nn.NLLLoss()
    params, names = extract_weights(model)

    data = torch.rand(N, seq_length + 1, device=device).mul(ntoken).long()
    # inputs and targets are the same sequence shifted by one token
    inputs = data.narrow(1, 0, seq_length)
    targets = data.narrow(1, 1, seq_length)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(
            out.reshape(N * seq_length, ntoken), targets.reshape(N * seq_length)
        )
        return loss

    return forward, params


def get_multiheadattn(device: torch.device) -> GetterReturnType:
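    """torchtext-style MultiheadAttentionContainer with a sum-of-outputs loss."""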
    # From https://github.com/pytorch/text/blob/master/test/data/test_modules.py#L10
    embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
    # Build torchtext MultiheadAttention module
    in_proj = models.InProjContainer(
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
    )

    model = models.MultiheadAttentionContainer(
        nhead,
        in_proj,
        models.ScaledDotProduct(),
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
    )
    model.to(device)
    params, names = extract_weights(model)

    query = torch.rand((tgt_len, bsz, embed_dim), device=device)
    key = value = torch.rand((src_len, bsz, embed_dim), device=device)
    attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len), device=device).to(
        torch.bool
    )
    bias_k = bias_v = torch.rand((1, 1, embed_dim), device=device)

    # Expand the mask to one copy per batch element and head, and reshape the
    # biases to match
    attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
    bias_k = bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)
    bias_v = bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        mha_output, attn_weights = model(
            query, key, value, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v
        )

        # No specific loss here; just backprop through both outputs
        loss = mha_output.sum() + attn_weights.sum()

        return loss

    return forward, params
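

if __name__ == "__main__":
    # Usage sketch, not part of the original benchmark: the benchmark driver
    # consumes these getters through its own harness, so this block only
    # illustrates the (forward, params) contract that every getter above returns.
    # It runs a single forward/backward for the Wav2Letter setup on CPU.
    forward, params = get_wav2letter(torch.device("cpu"))
    loss = forward(*params)
    grads = torch.autograd.grad(loss, params)
    print(f"wav2letter loss: {loss.item():.4f}, parameter tensors: {len(grads)}")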