"""Sample a packet-loss flag sequence from a trained LossGen model.

Usage: python <script> MODEL PERCENTAGE OUTPUT [--length N] [--seed S]

The model is run autoregressively: at each step it receives the previous
flag and the requested loss percentage and returns a logit, which is mapped
to a probability with a sigmoid. A uniform draw from numpy's global RNG
below that probability yields flag 1, otherwise 0. The flags are written to
OUTPUT as integers, one per line.

NOTE(review): the polarity (1 = packet lost) and the scale of PERCENTAGE
(percent vs. fraction) are inferred from names only -- confirm against the
training code.
"""
import lossgen
import os
import sys
import argparse
import torch
import numpy as np


def parse_args(argv=None):
    """Parse the command line (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser()

    parser.add_argument('model', type=str, help='LossGen model checkpoint')
    parser.add_argument('percentage', type=float, help='percentage loss')
    parser.add_argument('output', type=str, help='path to output file (ascii)')

    parser.add_argument('--length', type=int, help="length of sequence to generate", default=500)
    parser.add_argument('--seed', type=int, default=None,
                        help="seed for numpy's global RNG (default: unseeded)")

    return parser.parse_args(argv)


def load_model(path):
    """Build a ``lossgen.LossGen`` from the checkpoint at ``path``.

    The checkpoint must be a mapping with ``model_args``, ``model_kwargs``
    and ``state_dict`` entries. Loading is non-strict, but any missing or
    unexpected parameter keys are reported on stderr instead of being
    silently ignored.
    """
    # torch.load unpickles the file, which can execute arbitrary code:
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location='cpu')
    model = lossgen.LossGen(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"warning: non-strict checkpoint load: "
              f"missing={list(result.missing_keys)} "
              f"unexpected={list(result.unexpected_keys)}",
              file=sys.stderr)
    # NOTE(review): model.eval() is not called (matching the previous
    # behaviour); add it if LossGen ever gains dropout/batch-norm layers.
    return model


def generate_sequence(model, percentage, length):
    """Autoregressively sample ``length`` 0/1 flags from ``model``.

    Args:
        model: callable as ``model(last, perc, states=states)`` returning
            ``(logit, states)``; ``logit`` must hold a single element.
        percentage: conditioning value passed to the model at every step.
        length: number of flags to draw; values <= 0 give an empty result.

    Returns:
        int ndarray of shape ``(max(length, 0), 1)`` containing 0s and 1s.
    """
    one = torch.ones((1, 1, 1))
    zero = torch.zeros((1, 1, 1))
    perc = torch.tensor((percentage,))[None, None, :]

    last = zero      # the first step is conditioned on a 0 flag
    states = None    # the model initialises its own recurrent state
    flags = []
    # Pure inference: no_grad keeps autograd from chaining a graph through
    # the recurrent state across steps, so no manual detach is needed.
    with torch.no_grad():
        for _ in range(length):
            logit, states = model(last, perc, states=states)
            prob = torch.sigmoid(logit).item()
            # Flag is 1 with probability ``prob``.
            last = one if np.random.rand() < prob else zero
            flags.append(int(last.item()))
    # Collect plain ints and convert once, instead of torch.cat-ing onto a
    # growing tensor every step (quadratic in ``length``).
    return np.array(flags, dtype=int).reshape(-1, 1)


def main(argv=None):
    """Command-line entry point: parse args, load the model, sample, save."""
    args = parse_args(argv)
    if args.seed is not None:
        np.random.seed(args.seed)
    model = load_model(args.model)
    flags = generate_sequence(model, args.percentage, args.length)
    np.savetxt(args.output, flags, fmt='%d')


if __name__ == '__main__':
    main()