import os
import sys
import argparse

import torch
from torch import nn


sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
# Explicit import of the export backend: CWriter is referenced through
# wexchange.c_export below, so we do not rely on it being pulled in as a
# side effect of importing wexchange.torch.
import wexchange.c_export

from models import model_dict

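# Layers that are kept unquantized even when --quantize is given (see the
# description string below).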
unquantized = [
    'bfcc_with_corr_upsampler.fc',
    'cont_net.0',
    'fwc6.cont_fc.0',
    'fwc6.fc.0',
    'fwc6.fc.1.gate',
    'fwc7.cont_fc.0',
    'fwc7.fc.0',
    'fwc7.fc.1.gate'
]

description = f"""
This is an unsafe dumping script for FWGAN models. It assumes that all weights are contained in Linear, Conv1d or GRU layers
and will fail to export any other weights.

Furthermore, the quantize option relies on the following explicit list of layers being excluded from quantization:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fwgan_data', help='filename for source and header file (.c and .h will be added), defaults to fwgan_data')
parser.add_argument('--struct-name', type=str, default='FWGAN', help='name for C struct, defaults to FWGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')

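# Example invocation (file names are illustrative):
#   python dump_model_weights.py fwgan400 checkpoint.pth exported --quantize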
if __name__ == "__main__":
    args = parser.parse_args()

    model = model_dict[args.model]()

    print(f"loading weights from {args.weightfile}...")
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    model.load_state_dict(saved_gen)
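    # Strip weight-norm reparameterizations so that plain weight tensors are
    # exported; modules without weight norm raise ValueError and are skipped.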
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)


    print("dumping model...")
    quantize_model = args.quantize

    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

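    # CWriter produces <export_filename>.c and <export_filename>.h containing
    # the dumped weights and the C model struct.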
    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)

    for name, module in model.named_modules():

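        # Quantize every layer except those listed in `unquantized`;
        # scale=None leaves scale selection to the export routines, while
        # unquantized exports use a fixed scale of 1/128.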
        if quantize_model:
            quantize = name not in unquantized
            scale = None if quantize else 1/128
        else:
            quantize = False
            scale = 1/128

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale)

        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale)

    writer.close()