import os
import sys
import argparse

import torch
from torch import nn


sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
import wexchange.c_export.c_writer  # imported explicitly so wexchange.c_export.c_writer.CWriter (used below) is guaranteed to resolve

import fargan
#from models import model_dict

unquantized = [ 'cond_net.pembed', 'cond_net.fdense1', 'sig_net.cond_gain_dense', 'sig_net.gain_dense_out' ]

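# Broader exclusion list; kept for reference, but not used anywhere in this script.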
unquantized2 = [
    'cond_net.pembed',
    'cond_net.fdense1',
    'cond_net.fconv1',
    'cond_net.fconv2',
    'cont_net.0',
    'sig_net.cond_gain_dense',
    'sig_net.fwc0.conv',
    'sig_net.fwc0.glu.gate',
    'sig_net.dense1_glu.gate',
    'sig_net.gru1_glu.gate',
    'sig_net.gru2_glu.gate',
    'sig_net.gru3_glu.gate',
    'sig_net.skip_glu.gate',
    'sig_net.skip_dense',
    'sig_net.sig_dense_out',
    'sig_net.gain_dense_out'
]

description = f"""
This is an unsafe dumping script for FARGAN models. It assumes that all weights are contained in Linear, Conv1d, GRU,
GRUCell or Embedding layers and will fail to export any other weights.

Furthermore, the quantize option relies on the following explicit list of layers being excluded from quantization:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fargan_data', help='filename for source and header file (.c and .h will be added), defaults to fargan_data')
parser.add_argument('--struct-name', type=str, default='FARGAN', help='name for C struct, defaults to FARGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')
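
# Example invocation (paths are hypothetical):
#   python dump_fargan_weights.py checkpoints/fargan.pth exported_fargan --quantize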

if __name__ == "__main__":
    args = parser.parse_args()

    print(f"loading weights from {args.weightfile}...")
    saved_gen = torch.load(args.weightfile, map_location='cpu')
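    # Force the constructor arguments regardless of what was stored in the checkpoint.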
    saved_gen['model_args']    = ()
    saved_gen['model_kwargs']  = {'cond_size': 256, 'gamma': 0.9}

    model = fargan.FARGAN(*saved_gen['model_args'], **saved_gen['model_kwargs'])
    model.load_state_dict(saved_gen['state_dict'], strict=False)
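    # Fold weight norm back into plain weight tensors so the layers can be dumped directly.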
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)


    print("dumping model...")
    quantize_model = args.quantize

    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)
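    # The writer produces <export_filename>.c and <export_filename>.h in the export folder.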

    for name, module in model.named_modules():

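        # With --quantize, layers in the exclusion list stay in float at a fixed 1/128 scale
        # while everything else is quantized (scale=None, presumably letting the exporter
        # choose a scale); without --quantize, all layers are dumped in float.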
        if quantize_model:
            quantize = name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize = False
            scale = 1/128

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.GRUCell):
            print(f"dumping GRUCell layer {name}...")
            wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.Embedding):
            print(f"dumping Embedding layer {name}...")
            wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)
            #wexchange.torch.dump_torch_embedding_weights(writer, module)

        else:
            print(f"Ignoring layer {name}...")

    writer.close()