1""" 2/* Copyright (c) 2023 Amazon 3 Written by Jan Buethe */ 4/* 5 Redistribution and use in source and binary forms, with or without 6 modification, are permitted provided that the following conditions 7 are met: 8 9 - Redistributions of source code must retain the above copyright 10 notice, this list of conditions and the following disclaimer. 11 12 - Redistributions in binary form must reproduce the above copyright 13 notice, this list of conditions and the following disclaimer in the 14 documentation and/or other materials provided with the distribution. 15 16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27*/ 28""" 29 30import os 31 32import tensorflow as tf 33import numpy as np 34 35from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer 36 37def dump_tf_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128): 38 39 40 assert gru.activation == tf.keras.activations.tanh 41 assert gru.recurrent_activation == tf.keras.activations.sigmoid 42 assert gru.reset_after == True 43 44 w_ih = gru.weights[0].numpy().transpose().copy() 45 w_hh = gru.weights[1].numpy().transpose().copy() 46 b_ih = gru.weights[2].numpy()[0].copy() 47 b_hh = gru.weights[2].numpy()[1].copy() 48 49 if isinstance(where, CWriter): 50 return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='tf', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale) 51 else: 52 os.makedirs(where, exist_ok=True) 53 54 # zrn => rzn 55 N = w_ih.shape[0] // 3 56 for x in [w_ih, w_hh, b_ih, b_hh]: 57 tmp = x[0:N].copy() 58 x[0:N] = x[N:2*N] 59 x[N:2*N] = tmp 60 61 np.save(os.path.join(where, 'weight_ih_rzn.npy'), w_ih) 62 np.save(os.path.join(where, 'weight_hh_rzn.npy'), w_hh) 63 np.save(os.path.join(where, 'bias_ih_rzn.npy'), b_ih) 64 np.save(os.path.join(where, 'bias_hh_rzn.npy'), b_hh) 65 66 67def load_tf_gru_weights(path, gru): 68 69 assert gru.activation == tf.keras.activations.tanh 70 assert gru.recurrent_activation == tf.keras.activations.sigmoid 71 assert gru.reset_after == True 72 73 w_ih = np.load(os.path.join(path, 'weight_ih_rzn.npy')) 74 w_hh = np.load(os.path.join(path, 'weight_hh_rzn.npy')) 75 b_ih = np.load(os.path.join(path, 'bias_ih_rzn.npy')) 76 b_hh = np.load(os.path.join(path, 'bias_hh_rzn.npy')) 77 78 # rzn => zrn 79 N = w_ih.shape[0] // 3 80 for x in [w_ih, w_hh, b_ih, b_hh]: 81 tmp = x[0:N].copy() 82 x[0:N] = x[N:2*N] 83 x[N:2*N] = tmp 84 85 gru.weights[0].assign(tf.convert_to_tensor(w_ih.transpose())) 86 
gru.weights[1].assign(tf.convert_to_tensor(w_hh.transpose())) 87 gru.weights[2].assign(tf.convert_to_tensor(np.vstack((b_ih, b_hh)))) 88 89 90def dump_tf_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False): 91 92 w = dense.weights[0].numpy() 93 if dense.bias is None: 94 b = np.zeros(dense.units, dtype=w.dtype) 95 else: 96 b = dense.bias.numpy() 97 98 99 100 if isinstance(where, CWriter): 101 return print_dense_layer(where, name, w, b, scale=scale, format='tf', sparse=sparse, diagonal=diagonal, quantize=quantize) 102 103 else: 104 os.makedirs(where, exist_ok=True) 105 106 np.save(os.path.join(where, 'weight.npy'), w.transpose()) 107 np.save(os.path.join(where, 'bias.npy'), b) 108 109 110def load_tf_dense_weights(path, dense): 111 112 w = np.load(os.path.join(path, 'weight.npy')).transpose() 113 b = np.load(os.path.join(path, 'bias.npy')) 114 115 dense.weights[0].assign(tf.convert_to_tensor(w)) 116 if dense.bias is not None: 117 dense.weights[1].assign(tf.convert_to_tensor(b)) 118 119 120def dump_tf_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): 121 122 assert conv.data_format == 'channels_last' 123 124 w = conv.weights[0].numpy().copy() 125 if conv.bias is None: 126 b = np.zeros(conv.filters, dtype=w.dtype) 127 else: 128 b = conv.bias.numpy() 129 130 if isinstance(where, CWriter): 131 return print_conv1d_layer(where, name, w, b, scale=scale, format='tf', quantize=quantize) 132 else: 133 os.makedirs(where, exist_ok=True) 134 135 w = np.transpose(w, (2, 1, 0)) 136 np.save(os.path.join(where, 'weight_oik.npy'), w) 137 np.save(os.path.join(where, 'bias.npy'), b) 138 139 140def load_tf_conv1d_weights(path, conv): 141 142 w = np.load(os.path.join(path, 'weight_oik.npy')) 143 b = np.load(os.path.join(path, 'bias.npy')) 144 145 w = np.transpose(w, (2, 1, 0)) 146 147 conv.weights[0].assign(tf.convert_to_tensor(w)) 148 if conv.bias is not None: 149 conv.weights[1].assign(tf.convert_to_tensor(b)) 150 151 
def dump_tf_embedding_weights(path, emb):
    """Dump the weights of a Keras Embedding layer as weight.npy in directory ``path``.

    NOTE(review): unlike the other dump_* helpers this one has no CWriter
    export path — confirm whether a C export is needed for embeddings.
    """
    os.makedirs(path, exist_ok=True)

    w = emb.weights[0].numpy()
    np.save(os.path.join(path, 'weight.npy'), w)


def load_tf_embedding_weights(path, emb):
    """Load Embedding weights saved by dump_tf_embedding_weights."""

    w = np.load(os.path.join(path, 'weight.npy'))
    emb.weights[0].assign(tf.convert_to_tensor(w))


def dump_tf_weights(path, module):
    """Dump the weights of a supported Keras layer, dispatching on its type.

    Supported: Dense, GRU, Conv1D, Embedding. Raises ValueError otherwise.
    """
    if isinstance(module, tf.keras.layers.Dense):
        dump_tf_dense_weights(path, module)
    elif isinstance(module, tf.keras.layers.GRU):
        dump_tf_gru_weights(path, module)
    elif isinstance(module, tf.keras.layers.Conv1D):
        dump_tf_conv1d_weights(path, module)
    elif isinstance(module, tf.keras.layers.Embedding):
        dump_tf_embedding_weights(path, module)
    else:
        raise ValueError(f'dump_tf_weights: layer of type {type(module)} not supported')


def load_tf_weights(path, module):
    """Load the weights of a supported Keras layer, dispatching on its type.

    Supported: Dense, GRU, Conv1D, Embedding. Raises ValueError otherwise.
    """
    if isinstance(module, tf.keras.layers.Dense):
        load_tf_dense_weights(path, module)
    elif isinstance(module, tf.keras.layers.GRU):
        load_tf_gru_weights(path, module)
    elif isinstance(module, tf.keras.layers.Conv1D):
        load_tf_conv1d_weights(path, module)
    elif isinstance(module, tf.keras.layers.Embedding):
        load_tf_embedding_weights(path, module)
    else:
        # fixed copy-paste bug: message previously blamed 'dump_tf_weights'
        raise ValueError(f'load_tf_weights: layer of type {type(module)} not supported')