import torchaudio_models as models
from utils import check_for_functorch, extract_weights, GetterReturnType, load_weights

import torch
from torch import nn, Tensor


has_functorch = check_for_functorch()


def get_wav2letter(device: torch.device) -> GetterReturnType:
    N = 10
    input_frames = 700
    vocab_size = 28
    model = models.Wav2Letter(num_classes=vocab_size)
    criterion = torch.nn.NLLLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 1, input_frames], device=device)
    labels = torch.rand(N, 3, device=device).mul(vocab_size).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params


def get_deepspeech(device: torch.device) -> GetterReturnType:
    sample_rate = 16000
    window_size = 0.02
    window = "hamming"
    audio_conf = dict(
        sample_rate=sample_rate, window_size=window_size, window=window, noise_dir=None
    )

    N = 10
    num_classes = 10
    spectrogram_size = 161
    # The commented values are the original sizes from the upstream code
    seq_length = 500  # 1343
    target_length = 10  # 50
    labels = torch.rand(num_classes, device=device)
    inputs = torch.rand(N, 1, spectrogram_size, seq_length, device=device)
    # Sequence length for each input
    inputs_sizes = (
        torch.rand(N, device=device).mul(seq_length * 0.1).add(seq_length * 0.8)
    )
    # CTCLoss expects integer class indices; index 0 is reserved for the blank
    # symbol, so draw targets from [1, num_classes).
    targets = torch.randint(1, num_classes, (N, target_length), device=device)
    targets_sizes = torch.full((N,), target_length, dtype=torch.int, device=device)

    model = models.DeepSpeech(
        rnn_type=nn.LSTM,
        labels=labels,
        rnn_hidden_size=1024,
        nb_layers=5,
        audio_conf=audio_conf,
        bidirectional=True,
    )

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        # BatchNorm updates its running stats in-place, which does not compose
        # with functorch's functional transforms, so swap in stateless variants.
        replace_all_batch_norm_modules_(model)

    model = model.to(device)
    criterion = nn.CTCLoss()
    params, names = extract_weights(model)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out, out_sizes = model(inputs, inputs_sizes)
        out = out.transpose(0, 1)  # CTCLoss expects (T, N, C)

        loss = criterion(out, targets, out_sizes, targets_sizes)
        return loss

    return forward, params


def get_transformer(device: torch.device) -> GetterReturnType:
    # For most SOTA research you would typically set embed to 720, nhead to 12,
    # bsz to 64, and tgt_len/src_len to 128.
    N = 64
    seq_length = 128
    ntoken = 50
    model = models.TransformerModel(
        ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2
    )
    model.to(device)

    if has_functorch:
        # Disable dropout for consistency checking
        model.eval()

    criterion = nn.NLLLoss()
    params, names = extract_weights(model)

    data = torch.rand(N, seq_length + 1, device=device).mul(ntoken).long()
    inputs = data.narrow(1, 0, seq_length)
    targets = data.narrow(1, 1, seq_length)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(
            out.reshape(N * seq_length, ntoken), targets.reshape(N * seq_length)
        )
        return loss

    return forward, params


def get_multiheadattn(device: torch.device) -> GetterReturnType:
    # From https://github.com/pytorch/text/blob/master/test/data/test_modules.py#L10
    embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
    # Build torchtext MultiheadAttention module
    in_proj = models.InProjContainer(
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
    )

    model = models.MultiheadAttentionContainer(
        nhead,
        in_proj,
        models.ScaledDotProduct(),
        torch.nn.Linear(embed_dim, embed_dim, bias=False),
    )
    model.to(device)
    params, names = extract_weights(model)

    query = torch.rand((tgt_len, bsz, embed_dim), device=device)
    key = value = torch.rand((src_len, bsz, embed_dim), device=device)
    attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len), device=device).to(torch.bool)
    bias_k = bias_v = torch.rand((1, 1, embed_dim), device=device)

    attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
    bias_k = bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)
    bias_v = bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        mha_output, attn_weights = model(
            query, key, value, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v
        )

        # No specific loss is tested here; summing both outputs just yields a
        # scalar that backprops through each of them.
        loss = mha_output.sum() + attn_weights.sum()

        return loss

    return forward, params
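

if __name__ == "__main__":
    # Smoke-test sketch (an assumption, not part of the upstream benchmark
    # harness): every getter returns a pure `forward` closure plus the
    # flattened parameter tensors, so gradients can be taken functionally.
    forward, params = get_wav2letter(torch.device("cpu"))
    loss = forward(*params)
    # extract_weights is expected to hand back detached tensors with
    # requires_grad=True, which is what makes this direct call valid.
    grads = torch.autograd.grad(loss, params)
    print(f"wav2letter loss={loss.item():.4f}, grads for {len(grads)} tensors")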