"""Dump FWGAN model weights into C source/header files using wexchange's CWriter.

Run with ``--help`` for the assumptions this script makes about layer types and
the list of layers that are kept unquantized.
"""

import os
import sys
import argparse

import torch
from torch import nn


# Make the sibling weight-exchange package importable relative to this script.
sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
# Imported explicitly: CWriter is referenced through this module path below and
# must not depend on wexchange.torch happening to import it as a side effect.
import wexchange.c_export.c_writer

from models import model_dict

# Dotted module names that are exported without quantization even when
# --quantize is given.
unquantized = [
    'bfcc_with_corr_upsampler.fc',
    'cont_net.0',
    'fwc6.cont_fc.0',
    'fwc6.fc.0',
    'fwc6.fc.1.gate',
    'fwc7.cont_fc.0',
    'fwc7.fc.0',
    'fwc7.fc.1.gate'
]

description=f"""
This is an unsafe dumping script for FWGAN models. It assumes that all weights are included in Linear, Conv1d or GRU layers
and will fail to export any other weights.

Furthermore, the quantize option relies on the following explicit list of layers to be excluded:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('model', choices=['fwgan400', 'fwgan500'], help='model name')
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fwgan_data', help='filename for source and header file (.c and .h will be added), defaults to fwgan_data')
parser.add_argument('--struct-name', type=str, default='FWGAN', help='name for C struct, defaults to FWGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')

if __name__ == "__main__":
    args = parser.parse_args()

    model = model_dict[args.model]()

    print(f"loading weights from {args.weightfile}...")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load trusted checkpoints.
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    model.load_state_dict(saved_gen)

    def _remove_weight_norm(m):
        """Remove weight normalization from module ``m``; no-op if it has none."""
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return

    # Fold any weight-norm reparametrization back into plain weights before export.
    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model = args.quantize

    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    writer = wexchange.c_export.c_writer.CWriter(
        os.path.join(output_folder, args.export_filename),
        model_struct_name=args.struct_name,
    )

    seen_module_names = set()
    for name, module in model.named_modules():
        seen_module_names.add(name)

        if quantize_model:
            quantize = name not in unquantized
            # scale=None for quantized layers -- presumably lets wexchange pick
            # the scale itself; TODO confirm against wexchange.
            scale = None if quantize else 1/128
        else:
            quantize = False
            scale = 1/128

        # C identifiers cannot contain '.', so flatten the dotted module path.
        c_name = name.replace('.', '_')

        # Linear, Conv1d and GRU are disjoint types, so at most one branch applies.
        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, c_name, quantize=quantize, scale=scale)
        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, c_name, quantize=quantize, scale=scale)
        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, c_name, quantize=quantize, scale=scale, recurrent_scale=scale)
        elif list(module.parameters(recurse=False)):
            # Parameters owned directly by an unsupported module type would
            # otherwise be dropped silently from the export.
            print(f"warning: skipping unsupported layer {name} ({type(module).__name__}); its weights are not exported", file=sys.stderr)

    writer.close()

    if quantize_model:
        # A typo in `unquantized` would silently quantize the intended layer.
        for entry in unquantized:
            if entry not in seen_module_names:
                print(f"warning: unquantized entry {entry} does not match any module in {args.model}", file=sys.stderr)