"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

import os
import sys

import torch
import numpy as np

# Make the sibling '../osce' directory importable so the optional OSCE layers
# can be found. (The original nested append also pushed `None` onto sys.path.)
sys.path.append(os.path.join(os.path.dirname(__file__), '../osce'))
try:
    import utils.layers as osce_layers
    from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
    from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
    from utils.layers.td_shaper import TDShaper
    has_osce = True
except Exception:
    # The OSCE layers are optional: any import-time failure just disables
    # support for them, but KeyboardInterrupt/SystemExit are no longer swallowed.
    has_osce = False

from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer


def _to_numpy(tensor):
    """Return a detached, CPU-resident numpy copy of `tensor`."""
    return tensor.detach().cpu().numpy().copy()


def _optional_to_numpy(tensor):
    """Like `_to_numpy`, but maps a missing (None) tensor to None."""
    return None if tensor is None else _to_numpy(tensor)


def _bias_or_zeros(bias, size, dtype):
    """Return `bias` as a numpy copy, or `size` zeros of `dtype` when the layer has no bias."""
    if bias is None:
        return np.zeros(size, dtype=dtype)
    return _to_numpy(bias)


def _save_npy_files(where, arrays):
    """Create directory `where` if needed and save each `file name -> array` entry of `arrays` in it."""
    os.makedirs(where, exist_ok=True)
    for filename, array in arrays.items():
        np.save(os.path.join(where, filename), array)


def _load_bias(where, layer):
    """Load 'bias.npy' from `where` into `layer.bias` if the layer has a bias.

    Must be called inside a `torch.no_grad()` context.
    """
    if layer.bias is not None:
        bias = np.load(os.path.join(where, 'bias.npy'))
        layer.bias.set_(torch.from_numpy(bias))


def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
    """Dump the weights of an adaptive conv1d layer.

    Args:
        where: a CWriter (C source/header export) or a directory path (.npy export).
        adaconv: layer providing `conv_kernel` and `filter_gain` sub-layers plus the
            scalar attributes written to the header below.
        name: prefix for the C macros and exported layer names.
        scale: scale forwarded to `print_dense_layer` for the kernel layer.
        quantize: whether the kernel layer is quantized; if so, the kernel is
            zero-padded at the front to a multiple of 8 taps.

    Returns:
        None.
    """
    w_kernel = _to_numpy(adaconv.conv_kernel.weight)
    b_kernel = _to_numpy(adaconv.conv_kernel.bias)
    w_gain = _to_numpy(adaconv.filter_gain.weight)
    b_gain = _to_numpy(adaconv.filter_gain.bias)

    if not isinstance(where, CWriter):
        # Fixed: the original passed the file name as the array argument of
        # np.save, so the actual weights were never written.
        _save_npy_files(where, {
            'weight_kernel.npy': w_kernel,
            'bias_kernel.npy': b_kernel,
            'weight_gain.npy': w_gain,
            'bias_gain.npy': b_gain,
        })
        return

    left_padding = adaconv.padding[0]
    kernel_size = adaconv.kernel_size
    in_channels = adaconv.in_channels
    out_channels = adaconv.out_channels
    feature_dim = adaconv.feature_dim

    # pad kernel for quantization
    if quantize and kernel_size % 8:
        kernel_padding = 8 - (kernel_size % 8)
        w_kernel = np.concatenate(
            (np.zeros((out_channels, in_channels, kernel_padding, feature_dim)), w_kernel.reshape(out_channels, in_channels, kernel_size, feature_dim)),
            dtype=w_kernel.dtype,
            axis=2).reshape(-1, feature_dim)
        b_kernel = np.concatenate(
            (np.zeros((out_channels, in_channels, kernel_padding)), b_kernel.reshape(out_channels, in_channels, kernel_size)),
            dtype=b_kernel.dtype,
            axis=2).reshape(-1)
        # the zeros are prepended, so both the padding and the size grow
        left_padding += kernel_padding
        kernel_size += kernel_padding

    # write relevant scalar parameters to header file
    where.header.write(f"""
#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
#define {name.upper()}_SHAPE_GAIN {adaconv.shape_gain:f}f
#define {name.upper()}_KERNEL_SIZE {kernel_size}
#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
#define {name.upper()}_LEFT_PADDING {left_padding}
#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
#define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
#define {name.upper()}_NORM_P {adaconv.norm_p}
#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
"""
    )

    print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
    print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)


def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
    """Dump the weights of an adaptive comb1d layer.

    Args:
        where: a CWriter (C source/header export) or a directory path (.npy export).
        adaconv: layer providing `conv_kernel`, `filter_gain` and `global_filter_gain`
            sub-layers plus the scalar attributes written to the header below.
        name: prefix for the C macros and exported layer names.
        scale: scale forwarded to `print_dense_layer` for the kernel layer.
        quantize: whether the kernel layer is quantized; if so, the kernel rows are
            zero-padded at the front to a multiple of 8.

    Returns:
        None.
    """
    w_kernel = _to_numpy(adaconv.conv_kernel.weight)
    b_kernel = _to_numpy(adaconv.conv_kernel.bias)
    w_gain = _to_numpy(adaconv.filter_gain.weight)
    b_gain = _to_numpy(adaconv.filter_gain.bias)
    w_global_gain = _to_numpy(adaconv.global_filter_gain.weight)
    b_global_gain = _to_numpy(adaconv.global_filter_gain.bias)

    if not isinstance(where, CWriter):
        # Fixed: the original passed the file name as the array argument of
        # np.save, so the actual weights were never written.
        _save_npy_files(where, {
            'weight_kernel.npy': w_kernel,
            'bias_kernel.npy': b_kernel,
            'weight_gain.npy': w_gain,
            'bias_gain.npy': b_gain,
            'weight_global_gain.npy': w_global_gain,
            'bias_global_gain.npy': b_global_gain,
        })
        return

    left_padding = adaconv.padding[0]
    kernel_size = adaconv.kernel_size

    # pad kernel for quantization
    if quantize and w_kernel.shape[0] % 8:
        kernel_padding = 8 - (w_kernel.shape[0] % 8)
        w_kernel = np.concatenate((np.zeros((kernel_padding, w_kernel.shape[1])), w_kernel), dtype=w_kernel.dtype)
        b_kernel = np.concatenate((np.zeros((kernel_padding)), b_kernel), dtype=b_kernel.dtype)
        left_padding += kernel_padding
        kernel_size += kernel_padding

    # write relevant scalar parameters to header file
    where.header.write(f"""
#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
#define {name.upper()}_LOG_GAIN_LIMIT {adaconv.log_gain_limit:f}f
#define {name.upper()}_KERNEL_SIZE {kernel_size}
#define {name.upper()}_LEFT_PADDING {left_padding}
#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
#define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
#define {name.upper()}_NORM_P {adaconv.norm_p}
#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
#define {name.upper()}_MAX_LAG {adaconv.max_lag}
"""
    )

    print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
    print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)
    print_dense_layer(where, name + "_global_gain", w_global_gain, b_global_gain, format='torch', sparse=False, diagonal=False, quantize=False)


def dump_torch_tdshaper(where, shaper, name='tdshaper', quantize=False, scale=1/128):
    """Dump the conv1d sub-layers of a TDShaper and, for a CWriter, its scalar parameters.

    Args:
        where: a CWriter or a directory path, forwarded to `dump_torch_conv1d_weights`.
        shaper: layer providing the `feature_alpha*` conv1d sub-layers.
        name: prefix for the C macros and exported layer names.
        quantize: quantization flag, applied to the `alpha1_f` layer only.
        scale: scale applied to the `alpha1_f` layer only.

    Returns:
        None.
    """
    if isinstance(where, CWriter):
        where.header.write(f"""
#define {name.upper()}_FEATURE_DIM {shaper.feature_dim}
#define {name.upper()}_FRAME_SIZE {shaper.frame_size}
#define {name.upper()}_AVG_POOL_K {shaper.avg_pool_k}
#define {name.upper()}_INNOVATE {1 if shaper.innovate else 0}
#define {name.upper()}_POOL_AFTER {1 if shaper.pool_after else 0}
"""
        )

    # NOTE(review): when `where` is a directory, every call below writes the same
    # file names into the same directory, so later layers overwrite earlier ones.
    # Confirm whether directory export is meant to be supported here.
    dump_torch_conv1d_weights(where, shaper.feature_alpha1_f, name + "_alpha1_f", quantize=quantize, scale=scale)
    dump_torch_conv1d_weights(where, shaper.feature_alpha1_t, name + "_alpha1_t")
    dump_torch_conv1d_weights(where, shaper.feature_alpha2, name + "_alpha2")

    if shaper.innovate:
        dump_torch_conv1d_weights(where, shaper.feature_alpha1b, name + "_alpha1b")
        dump_torch_conv1d_weights(where, shaper.feature_alpha1c, name + "_alpha1c")
        dump_torch_conv1d_weights(where, shaper.feature_alpha2b, name + "_alpha2b")
        dump_torch_conv1d_weights(where, shaper.feature_alpha2c, name + "_alpha2c")


def _dump_gru_arrays(where, name, w_ih, w_hh, b_ih, b_hh, input_sparse, recurrent_sparse, quantize, scale, recurrent_scale):
    """Write GRU weight/bias arrays to a CWriter or as .npy files into directory `where`."""
    if isinstance(where, CWriter):
        return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='torch', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale)

    # NOTE(review): a missing bias is handed to np.save as None (unchanged from
    # the original); confirm that consumers of these files can read that.
    _save_npy_files(where, {
        'weight_ih_rzn.npy': w_ih,
        'weight_hh_rzn.npy': w_hh,
        'bias_ih_rzn.npy': b_ih,
        'bias_hh_rzn.npy': b_hh,
    })
    return None


def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
    """Dump the weights of a single-layer, unidirectional GRU.

    Args:
        where: a CWriter or a directory path.
        gru: module with `weight_ih_l0`/`weight_hh_l0` and optional `bias_ih_l0`/`bias_hh_l0`.
        name: layer name used for C export.
        input_sparse, recurrent_sparse, quantize, scale, recurrent_scale:
            forwarded to `print_gru_layer`.

    Returns:
        The result of `print_gru_layer` for a CWriter, otherwise None.

    Raises:
        AssertionError: if the GRU has more than one layer or is bidirectional.
    """
    assert gru.num_layers == 1
    assert not gru.bidirectional

    w_ih = _to_numpy(gru.weight_ih_l0)
    w_hh = _to_numpy(gru.weight_hh_l0)
    # the bias attributes may be absent; treat absent and None alike
    b_ih = _optional_to_numpy(getattr(gru, 'bias_ih_l0', None))
    b_hh = _optional_to_numpy(getattr(gru, 'bias_hh_l0', None))

    return _dump_gru_arrays(where, name, w_ih, w_hh, b_ih, b_hh, input_sparse, recurrent_sparse, quantize, scale, recurrent_scale)


def dump_torch_grucell_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
    """Dump the weights of a GRU cell.

    Args:
        where: a CWriter or a directory path.
        gru: module with `weight_ih`/`weight_hh` and optional `bias_ih`/`bias_hh`.
        name: layer name used for C export.
        input_sparse, recurrent_sparse, quantize, scale, recurrent_scale:
            forwarded to `print_gru_layer`.

    Returns:
        The result of `print_gru_layer` for a CWriter, otherwise None.
    """
    w_ih = _to_numpy(gru.weight_ih)
    w_hh = _to_numpy(gru.weight_hh)
    b_ih = _optional_to_numpy(getattr(gru, 'bias_ih', None))
    b_hh = _optional_to_numpy(getattr(gru, 'bias_hh', None))

    return _dump_gru_arrays(where, name, w_ih, w_hh, b_ih, b_hh, input_sparse, recurrent_sparse, quantize, scale, recurrent_scale)


def load_torch_gru_weights(where, gru):
    """Load GRU weights from the .npy files in directory `where` into `gru`.

    Biases are only loaded when the module actually has the corresponding
    bias attribute, mirroring the optional handling in the dump function.

    Raises:
        AssertionError: if the GRU has more than one layer or is bidirectional.
    """
    assert gru.num_layers == 1
    assert not gru.bidirectional

    has_bias_ih = getattr(gru, 'bias_ih_l0', None) is not None
    has_bias_hh = getattr(gru, 'bias_hh_l0', None) is not None

    # read everything before touching the module
    w_ih = np.load(os.path.join(where, 'weight_ih_rzn.npy'))
    w_hh = np.load(os.path.join(where, 'weight_hh_rzn.npy'))
    b_ih = np.load(os.path.join(where, 'bias_ih_rzn.npy')) if has_bias_ih else None
    b_hh = np.load(os.path.join(where, 'bias_hh_rzn.npy')) if has_bias_hh else None

    with torch.no_grad():
        gru.weight_ih_l0.set_(torch.from_numpy(w_ih))
        gru.weight_hh_l0.set_(torch.from_numpy(w_hh))
        if has_bias_ih:
            gru.bias_ih_l0.set_(torch.from_numpy(b_ih))
        if has_bias_hh:
            gru.bias_hh_l0.set_(torch.from_numpy(b_hh))


def dump_torch_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False):
    """Dump the weights of a dense (linear) layer; a missing bias is exported as zeros.

    Returns:
        The result of `print_dense_layer` for a CWriter, otherwise None
        ('weight.npy' and 'bias.npy' are written into directory `where`).
    """
    w = _to_numpy(dense.weight)
    b = _bias_or_zeros(dense.bias, dense.out_features, w.dtype)

    if isinstance(where, CWriter):
        return print_dense_layer(where, name, w, b, scale=scale, format='torch', sparse=sparse, diagonal=diagonal, quantize=quantize)

    _save_npy_files(where, {'weight.npy': w, 'bias.npy': b})
    return None


def load_torch_dense_weights(where, dense):
    """Load 'weight.npy' (and 'bias.npy' if the layer has a bias) from `where` into `dense`."""
    w = np.load(os.path.join(where, 'weight.npy'))

    with torch.no_grad():
        dense.weight.set_(torch.from_numpy(w))
        _load_bias(where, dense)


def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
    """Dump the weights of a conv1d layer; a missing bias is exported as zeros.

    Returns:
        The result of `print_conv1d_layer` for a CWriter, otherwise None
        ('weight_oik.npy' and 'bias.npy' are written into directory `where`).
    """
    w = _to_numpy(conv.weight)
    b = _bias_or_zeros(conv.bias, conv.out_channels, w.dtype)

    if isinstance(where, CWriter):
        return print_conv1d_layer(where, name, w, b, scale=scale, format='torch', quantize=quantize, sparse=sparse)

    _save_npy_files(where, {'weight_oik.npy': w, 'bias.npy': b})
    return None


def load_torch_conv1d_weights(where, conv):
    """Load 'weight_oik.npy' (and 'bias.npy' if the layer has a bias) from `where` into `conv`."""
    with torch.no_grad():
        w = np.load(os.path.join(where, 'weight_oik.npy'))
        conv.weight.set_(torch.from_numpy(w))
        _load_bias(where, conv)


def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
    """Dump the weights of a transposed conv1d layer; a missing bias is exported as zeros.

    Returns:
        The result of `print_tconv1d_layer` (which also receives `conv.stride[0]`)
        for a CWriter, otherwise None ('weight_oik.npy' and 'bias.npy' are
        written into directory `where`).
    """
    w = _to_numpy(conv.weight)
    b = _bias_or_zeros(conv.bias, conv.out_channels, w.dtype)

    if isinstance(where, CWriter):
        return print_tconv1d_layer(where, name, w, b, conv.stride[0], scale=scale, quantize=quantize, sparse=sparse)

    _save_npy_files(where, {'weight_oik.npy': w, 'bias.npy': b})
    return None


def load_torch_tconv1d_weights(where, conv):
    """Load 'weight_oik.npy' (and 'bias.npy' if the layer has a bias) from `where` into `conv`."""
    with torch.no_grad():
        w = np.load(os.path.join(where, 'weight_oik.npy'))
        conv.weight.set_(torch.from_numpy(w))
        _load_bias(where, conv)


def dump_torch_conv2d_weights(where, conv, name='conv', scale=1/128, quantize=False):
    """Dump the weights of a conv2d layer with the last two weight axes swapped.

    A missing bias is exported as zeros.

    Returns:
        The result of `print_conv2d_layer` for a CWriter, otherwise None
        ('weight_oiwh.npy' and 'bias.npy' are written into directory `where`).
    """
    # swap the two trailing axes before export
    w = conv.weight.detach().cpu().permute(0, 1, 3, 2).numpy().copy()
    b = _bias_or_zeros(conv.bias, conv.out_channels, w.dtype)

    if isinstance(where, CWriter):
        return print_conv2d_layer(where, name, w, b, scale=scale, quantize=quantize)

    _save_npy_files(where, {'weight_oiwh.npy': w, 'bias.npy': b})
    return None


def load_torch_conv2d_weights(where, conv):
    """Load 'weight_oiwh.npy' (undoing the axis swap of the dump) and the optional bias into `conv`."""
    with torch.no_grad():
        w = np.load(os.path.join(where, 'weight_oiwh.npy'))
        conv.weight.set_(torch.from_numpy(w).permute(0, 1, 3, 2))
        _load_bias(where, conv)


def dump_torch_embedding_weights(where, embed, name='embed', scale=1/128, sparse=False, diagonal=False, quantize=False):
    """Dump an embedding as a dense layer: transposed weight matrix plus an all-zero bias.

    Returns:
        The result of `print_dense_layer` for a CWriter, otherwise None
        ('weight.npy' and 'bias.npy' are written into directory `where`).
    """
    w = _to_numpy(embed.weight).transpose()
    b = np.zeros(w.shape[0], dtype=w.dtype)

    if isinstance(where, CWriter):
        return print_dense_layer(where, name, w, b, scale=scale, format='torch', sparse=sparse, diagonal=diagonal, quantize=quantize)

    # NOTE(review): the directory export stores the *transposed* matrix as well;
    # confirm that loaders of 'weight.npy' expect this layout.
    _save_npy_files(where, {'weight.npy': w, 'bias.npy': b})
    return None



def load_torch_embedding_weights(where, emb):
    """Load 'weight.npy' from directory `where` into the embedding `emb`.

    If the stored matrix has the transposed shape of `emb.weight` (the layout
    produced by dumping an embedding as a dense layer), it is transposed back
    before being installed. For a square matrix the two layouts cannot be
    told apart, so the file content is used as is.
    """
    w = np.load(os.path.join(where, 'weight.npy'))

    expected_shape = tuple(emb.weight.shape)
    if w.shape != expected_shape and w.T.shape == expected_shape:
        # make the transposed data contiguous before handing it to torch
        w = np.ascontiguousarray(w.T)

    with torch.no_grad():
        emb.weight.set_(torch.from_numpy(w))


def load_torch_grucell_weights(where, gru):
    """Load GRU cell weights (and biases, if the cell has them) from directory `where`.

    Reads the same file names that the GRU dump functions write:
    'weight_ih_rzn.npy', 'weight_hh_rzn.npy', 'bias_ih_rzn.npy', 'bias_hh_rzn.npy'.
    """
    w_ih = np.load(os.path.join(where, 'weight_ih_rzn.npy'))
    w_hh = np.load(os.path.join(where, 'weight_hh_rzn.npy'))

    with torch.no_grad():
        gru.weight_ih.set_(torch.from_numpy(w_ih))
        gru.weight_hh.set_(torch.from_numpy(w_hh))
        if getattr(gru, 'bias_ih', None) is not None:
            gru.bias_ih.set_(torch.from_numpy(np.load(os.path.join(where, 'bias_ih_rzn.npy'))))
        if getattr(gru, 'bias_hh', None) is not None:
            gru.bias_hh.set_(torch.from_numpy(np.load(os.path.join(where, 'bias_hh_rzn.npy'))))


def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
    """ generic function for dumping weights of some torch.nn.Module

    Args:
        where: a CWriter or a directory path, forwarded to the type-specific dumper.
        module: the layer to dump.
        name: layer name; when None, the type-specific dumper's default name is used.
        verbose: print a progress line (only when `name` is given).
        **kwargs: forwarded to the type-specific dumper.

    Returns:
        Whatever the type-specific dumper returns.

    Raises:
        ValueError: if the layer type is not supported.
    """
    if verbose and name is not None:
        print(f"printing layer {name} of type {type(module)}...")

    # Only forward `name` when it was given, so that the callee's default applies
    # instead of an explicit None (which the OSCE dumpers cannot handle).
    if name is not None:
        kwargs = {**kwargs, 'name': name}

    if isinstance(module, torch.nn.Linear):
        return dump_torch_dense_weights(where, module, **kwargs)
    if isinstance(module, torch.nn.GRU):
        return dump_torch_gru_weights(where, module, **kwargs)
    if isinstance(module, torch.nn.GRUCell):
        return dump_torch_grucell_weights(where, module, **kwargs)
    if isinstance(module, torch.nn.Conv1d):
        return dump_torch_conv1d_weights(where, module, **kwargs)
    if isinstance(module, torch.nn.Conv2d):
        return dump_torch_conv2d_weights(where, module, **kwargs)
    if isinstance(module, torch.nn.Embedding):
        return dump_torch_embedding_weights(where, module, **kwargs)
    if isinstance(module, torch.nn.ConvTranspose1d):
        return dump_torch_tconv1d_weights(where, module, **kwargs)

    # the OSCE layer classes only exist when the optional import succeeded
    if has_osce:
        if isinstance(module, LimitedAdaptiveConv1d):
            return dump_torch_adaptive_conv1d_weights(where, module, **kwargs)
        if isinstance(module, LimitedAdaptiveComb1d):
            return dump_torch_adaptive_comb1d_weights(where, module, **kwargs)
        if isinstance(module, TDShaper):
            return dump_torch_tdshaper(where, module, **kwargs)

    raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')


def load_torch_weights(where, module):
    """ generic function for loading weights of some torch.nn.Module

    Args:
        where: directory containing the .npy files written by the matching dump function.
        module: the layer to load the weights into.

    Raises:
        ValueError: if the layer type is not supported.
    """
    if isinstance(module, torch.nn.Linear):
        return load_torch_dense_weights(where, module)
    if isinstance(module, torch.nn.GRU):
        return load_torch_gru_weights(where, module)
    if isinstance(module, torch.nn.GRUCell):
        return load_torch_grucell_weights(where, module)
    if isinstance(module, torch.nn.Conv1d):
        return load_torch_conv1d_weights(where, module)
    if isinstance(module, torch.nn.Conv2d):
        return load_torch_conv2d_weights(where, module)
    if isinstance(module, torch.nn.Embedding):
        return load_torch_embedding_weights(where, module)
    if isinstance(module, torch.nn.ConvTranspose1d):
        return load_torch_tconv1d_weights(where, module)

    raise ValueError(f'load_torch_weights: layer of type {type(module)} not supported')