"""Dump FARGAN model weights from a PyTorch checkpoint into C source/header files.

Run as a script; see ``description`` below (or ``--help``) for usage details.
"""

import os
import sys
import argparse

import torch
from torch import nn


# Make the sibling weight-exchange package importable without installing it.
sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange'))
import wexchange.torch
# CWriter is used below; import its package explicitly instead of relying on
# wexchange.torch importing it as a side effect.
import wexchange.c_export

import fargan

# Layers kept unquantized when --quantize is given.
unquantized = ['cond_net.pembed', 'cond_net.fdense1', 'sig_net.cond_gain_dense', 'sig_net.gain_dense_out']

# Longer exclusion list (a superset of `unquantized`).
# NOTE(review): not referenced by this script; swap it in for `unquantized` manually if needed.
unquantized2 = [
    'cond_net.pembed',
    'cond_net.fdense1',
    'cond_net.fconv1',
    'cond_net.fconv2',
    'cont_net.0',
    'sig_net.cond_gain_dense',
    'sig_net.fwc0.conv',
    'sig_net.fwc0.glu.gate',
    'sig_net.dense1_glu.gate',
    'sig_net.gru1_glu.gate',
    'sig_net.gru2_glu.gate',
    'sig_net.gru3_glu.gate',
    'sig_net.skip_glu.gate',
    'sig_net.skip_dense',
    'sig_net.sig_dense_out',
    'sig_net.gain_dense_out',
]

description = f"""
This is an unsafe dumping script for FARGAN models. It assumes that all weights are included in Linear, Conv1d, GRU, GRUCell or Embedding layers
and will fail to export any other weights.

Furthermore, the quantize option relies on the following explicit list of layers to be excluded:
{unquantized}.

Modify this script manually if adjustments are needed.
"""

parser = argparse.ArgumentParser(description=description)
parser.add_argument('weightfile', type=str, help='weight file path')
parser.add_argument('export_folder', type=str)
parser.add_argument('--export-filename', type=str, default='fargan_data', help='filename for source and header file (.c and .h will be added), defaults to fargan_data')
parser.add_argument('--struct-name', type=str, default='FARGAN', help='name for C struct, defaults to FARGAN')
parser.add_argument('--quantize', action='store_true', help='apply quantization')


def _remove_weight_norm(m):
    """Strip weight normalization from module ``m`` in place, if it has any."""
    try:
        torch.nn.utils.remove_weight_norm(m)
    except ValueError:  # this module didn't have weight norm
        return


def _layer_params(name, quantize_model):
    """Return the ``(quantize, scale)`` pair passed to the wexchange dumpers for layer ``name``.

    A layer is quantized only when quantization is requested and the layer is not
    listed in ``unquantized``. Quantized layers get ``scale=None`` (presumably the
    dumper then picks the scale itself -- TODO confirm in wexchange). All other
    layers use a fixed scale of 1/128.
    """
    quantize = quantize_model and name not in unquantized
    scale = None if quantize else 1/128
    return quantize, scale


def main():
    """Parse CLI arguments, load the checkpoint, and write every supported layer to C files."""
    args = parser.parse_args()

    print(f"loading weights from {args.weightfile}...")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    saved_gen = torch.load(args.weightfile, map_location='cpu')
    # Model hyper-parameters are hard-coded here and override whatever the checkpoint stores.
    saved_gen['model_args'] = ()
    saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9}

    model = fargan.FARGAN(*saved_gen['model_args'], **saved_gen['model_kwargs'])
    # strict=False tolerates key mismatches; report them so a wrong checkpoint does not go unnoticed.
    load_result = model.load_state_dict(saved_gen['state_dict'], strict=False)
    if load_result.missing_keys:
        print(f"warning: missing keys in state_dict: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"warning: unexpected keys in state_dict: {load_result.unexpected_keys}")
    model.apply(_remove_weight_norm)

    print("dumping model...")
    quantize_model = args.quantize

    output_folder = args.export_folder
    os.makedirs(output_folder, exist_ok=True)

    writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name)

    for name, module in model.named_modules():
        quantize, scale = _layer_params(name, quantize_model)
        # C identifiers cannot contain dots, so nested module paths are flattened with underscores.
        c_name = name.replace('.', '_')

        if isinstance(module, nn.Linear):
            print(f"dumping linear layer {name}...")
            wexchange.torch.dump_torch_dense_weights(writer, module, c_name, quantize=quantize, scale=scale)

        elif isinstance(module, nn.Conv1d):
            print(f"dumping conv1d layer {name}...")
            wexchange.torch.dump_torch_conv1d_weights(writer, module, c_name, quantize=quantize, scale=scale)

        elif isinstance(module, nn.GRU):
            print(f"dumping GRU layer {name}...")
            wexchange.torch.dump_torch_gru_weights(writer, module, c_name, quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.GRUCell):
            print(f"dumping GRUCell layer {name}...")
            wexchange.torch.dump_torch_grucell_weights(writer, module, c_name, quantize=quantize, scale=scale, recurrent_scale=scale)

        elif isinstance(module, nn.Embedding):
            print(f"dumping Embedding layer {name}...")
            wexchange.torch.dump_torch_embedding_weights(writer, module, c_name, quantize=quantize, scale=scale)

        else:
            print(f"Ignoring layer {name}...")

    writer.close()


if __name__ == "__main__":
    main()