#!/usr/bin/python
"""Convert a trained Keras model into RNNoise C source and a text model file.

Usage: <script> MODEL C_OUTPUT TEXT_OUTPUT MODEL_NAME

* MODEL is passed to ``keras.models.load_model``.
* C_OUTPUT receives ``static const rnn_weight`` arrays, one ``DenseLayer`` or
  ``GRULayer`` definition per weighted layer, and a
  ``const struct RNNModel rnnoise_model_<MODEL_NAME>`` table.
* TEXT_OUTPUT receives the same quantized weights in the
  "rnnoise-nu model file version 1" text format.

Each weight ``w`` is stored as ``round(256 * w)`` clamped to [-128, 127].
"""

import re
import sys

import numpy as np

# Fixed-point scale: a weight w is stored as round(w * WEIGHT_SCALE).
WEIGHT_SCALE = 256
# Clamp range for a quantized weight. The upper limit of 127 implies
# rnn_weight is a signed 8-bit type (declared in rnn.h, not visible here),
# so the lower limit must be clamped as well.
WEIGHT_MIN = -128
WEIGHT_MAX = 127
# Number of values per row in the generated C array initializers.
VALUES_PER_ROW = 8
# Activation codes used in the text model file; any other activation is 0.
ACTIVATION_CODES = {'SIGMOID': 1, 'RELU': 2}

# Fallback for activations without __name__: parses "<function NAME at 0x...>".
_FUNCTION_REPR_RE = re.compile(r'function (.*) at')


def _quantize(value):
    """Return *value* scaled to fixed point and clamped to the rnn_weight range."""
    return max(WEIGHT_MIN, min(WEIGHT_MAX, int(round(WEIGHT_SCALE * value))))


def printVector(f, ft, vector, name):
    """Write one weight tensor, flattened and quantized, to both outputs.

    Args:
        f: C output stream. Receives
            ``static const rnn_weight <name>[N] = { ... };`` with
            VALUES_PER_ROW comma-separated values per row.
        ft: text model stream. Receives the same values space-separated on
            a single line.
        vector: array-like of float weights, flattened with ``np.reshape``.
        name: C identifier for the generated array.
    """
    values = [str(_quantize(x)) for x in np.reshape(vector, (-1))]
    rows = [', '.join(values[start:start + VALUES_PER_ROW])
            for start in range(0, len(values), VALUES_PER_ROW)]
    f.write('static const rnn_weight {}[{}] = {{\n '.format(name, len(values)))
    # Rows are joined with ',' so that the last value has no trailing comma.
    f.write(',\n '.join(rows))
    f.write('\n};\n\n')
    ft.write(' '.join(values))
    ft.write('\n')


def _activation_name(layer):
    """Return the upper-cased name of ``layer.activation`` (e.g. ``'SIGMOID'``).

    Raises:
        ValueError: if the name can be derived neither from ``__name__`` nor
            from the function's repr.
    """
    fn = layer.activation
    name = getattr(fn, '__name__', None)
    if name is None:
        match = _FUNCTION_REPR_RE.search(str(fn))
        if match is None:
            raise ValueError('cannot determine activation of layer {!r}: {!r}'
                             .format(layer.name, fn))
        name = match.group(1)
    return name.upper()


def _layer_dims(weights):
    """Return ``(inputs, units, is_gru)`` for a layer's weight list.

    A layer with more than two weight arrays (kernel, recurrent kernel, bias)
    is emitted as a GRU. Its kernel width is divided by 3, presumably because
    Keras concatenates the three gate matrices. Integer division keeps the
    size an int under Python 3; ``/`` would emit e.g. ``96.0``.
    """
    is_gru = len(weights) > 2
    inputs = weights[0].shape[0]
    units = weights[0].shape[1] // 3 if is_gru else weights[0].shape[1]
    return inputs, units, is_gru


def printLayer(f, ft, layer):
    """Emit the weight arrays and layer definition for one Keras layer.

    The text file gets a header line "<inputs> <units> <activation code>",
    followed by the kernel, the recurrent kernel (GRU only) and the bias. The C
    file gets the same arrays, followed by a ``GRULayer`` or ``DenseLayer``
    initializer named after ``layer.name``.
    """
    weights = layer.get_weights()
    activation = _activation_name(layer)
    inputs, units, is_gru = _layer_dims(weights)
    name = layer.name

    ft.write('{} {} '.format(inputs, units))
    ft.write('{}\n'.format(ACTIVATION_CODES.get(activation, 0)))

    printVector(f, ft, weights[0], name + '_weights')
    if is_gru:
        printVector(f, ft, weights[1], name + '_recurrent_weights')
    printVector(f, ft, weights[-1], name + '_bias')

    if is_gru:
        f.write('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
                .format(name, name, name, name, inputs, units, activation))
    else:
        f.write('static const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
                .format(name, name, name, inputs, units, activation))


def structLayer(f, layer):
    """Write this layer's ``<units>, &<name>,`` entry of the RNNModel initializer."""
    _, units, _ = _layer_dims(layer.get_weights())
    f.write(' {},\n'.format(units))
    f.write(' &{},\n'.format(layer.name))


def foo(c, name):
    """Stand-in for the training-time ``WeightClip`` constraint.

    Returning None means the constraint is dropped when the model is loaded.
    The argument names presumably mirror WeightClip's config keys; confirm
    against the training script.
    """
    return None


def mean_squared_sqrt_error(y_true, y_pred):
    """Loss used during training: mean over the last axis of (sqrt(pred) - sqrt(true))**2."""
    from keras import backend as K  # deferred so helpers can be used without Keras

    return K.mean(K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1)


def main(argv=None):
    """Run the conversion described in the module docstring.

    Args:
        argv: argument vector, defaulting to ``sys.argv``.

    Returns:
        The process exit status: 0 on success, 2 on a usage error.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 5:
        sys.stderr.write('usage: {} MODEL C_OUTPUT TEXT_OUTPUT MODEL_NAME\n'
                         .format(argv[0] if argv else '<script>'))
        return 2

    from keras.models import load_model  # only the conversion itself needs Keras

    # The custom losses/constraints only need to resolve for the load to succeed.
    model = load_model(argv[1], custom_objects={'msse': mean_squared_sqrt_error, 'mean_squared_sqrt_error': mean_squared_sqrt_error, 'my_crossentropy': mean_squared_sqrt_error, 'mycost': mean_squared_sqrt_error, 'WeightClip': foo})
    layers = [layer for layer in model.layers if len(layer.get_weights()) > 0]

    with open(argv[2], 'w') as f, open(argv[3], 'w') as ft:
        f.write('/*This file is automatically generated from a Keras model*/\n\n')
        f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n#include "rnn_data.h"\n\n')
        ft.write('rnnoise-nu model file version 1\n')

        for layer in layers:
            printLayer(f, ft, layer)

        f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(argv[4]))
        for layer in layers:
            structLayer(f, layer)
        f.write('};\n')
    return 0


if __name__ == '__main__':
    sys.exit(main())