xref: /aosp_15_r20/external/libopus/dnn/torch/weight-exchange/wexchange/torch/torch.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
31import sys
32
33import torch
34import numpy as np
35
# Make the OSCE layer implementations (located at ../osce relative to this
# file) importable. sys.path.append returns None, so the directory must be
# appended exactly once rather than nesting the call (which also added None).
sys.path.append(os.path.join(os.path.dirname(__file__), '../osce'))
try:
    import utils.layers as osce_layers
    from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
    from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
    from utils.layers.td_shaper import TDShaper
    has_osce = True
except Exception:
    # OSCE support is optional: when its layers cannot be imported, the plain
    # torch exporters below still work and OSCE layer types are rejected by
    # dump_torch_weights. (Exception, not a bare except, so Ctrl-C and
    # SystemExit still propagate.)
    has_osce = False
45
46from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer
47
def dump_torch_adaptive_conv1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
    """Export the weights of an OSCE LimitedAdaptiveConv1d layer.

    The weight and bias of the layer's ``conv_kernel`` and ``filter_gain``
    sub-layers are exported.

    With a CWriter, the layer's scalar configuration is written to the header
    as ``#define``s and both sub-layers are emitted with print_dense_layer.
    Only the kernel sub-layer honours ``quantize``/``scale``. Otherwise
    ``where`` is treated as a directory that receives one .npy file per
    tensor, like the other dump_* helpers in this module.

    Args:
        where: CWriter instance, or path of the output directory.
        adaconv: the LimitedAdaptiveConv1d module to export.
        name: prefix for the emitted C symbols and #defines.
        scale: scale forwarded to print_dense_layer for the kernel sub-layer.
        quantize: quantize the kernel sub-layer (CWriter export only).
    """
    w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy()
    b_kernel = adaconv.conv_kernel.bias.detach().cpu().numpy().copy()
    w_gain = adaconv.filter_gain.weight.detach().cpu().numpy().copy()
    b_gain = adaconv.filter_gain.bias.detach().cpu().numpy().copy()

    if isinstance(where, CWriter):
        left_padding = adaconv.padding[0]
        kernel_size = adaconv.kernel_size
        in_channels = adaconv.in_channels
        out_channels = adaconv.out_channels
        feature_dim = adaconv.feature_dim

        if quantize and kernel_size % 8:
            # Pad the kernel for quantization. Prepend zero taps to every
            # (out, in) filter so the tap count becomes a multiple of 8. Then
            # extend the left padding by the same amount so that the extra
            # taps only see padding.
            kernel_padding = 8 - (kernel_size % 8)
            w_kernel = np.concatenate(
                (np.zeros((out_channels, in_channels, kernel_padding, feature_dim)),
                 w_kernel.reshape(out_channels, in_channels, kernel_size, feature_dim)),
                dtype=w_kernel.dtype,
                axis=2).reshape(-1, feature_dim)
            b_kernel = np.concatenate(
                (np.zeros((out_channels, in_channels, kernel_padding)),
                 b_kernel.reshape(out_channels, in_channels, kernel_size)),
                dtype=b_kernel.dtype,
                axis=2).reshape(-1)
            left_padding += kernel_padding
            kernel_size += kernel_padding

        # write relevant scalar parameters to header file
        where.header.write(f"""
#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
#define {name.upper()}_SHAPE_GAIN {adaconv.shape_gain:f}f
#define {name.upper()}_KERNEL_SIZE {kernel_size}
#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
#define {name.upper()}_LEFT_PADDING {left_padding}
#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
#define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
#define {name.upper()}_NORM_P {adaconv.norm_p}
#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
"""
        )

        print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
        print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)

    else:
        # Directory export: one file per tensor inside `where`. np.save takes
        # (file, array), so the target path must be joined explicitly.
        os.makedirs(where, exist_ok=True)
        np.save(os.path.join(where, 'weight_kernel.npy'), w_kernel)
        np.save(os.path.join(where, 'bias_kernel.npy'), b_kernel)
        np.save(os.path.join(where, 'weight_gain.npy'), w_gain)
        np.save(os.path.join(where, 'bias_gain.npy'), b_gain)
102
103
def dump_torch_adaptive_comb1d_weights(where, adaconv, name='adaconv', scale=1/128, quantize=False):
    """Export the weights of an OSCE LimitedAdaptiveComb1d layer.

    The weight and bias of the ``conv_kernel``, ``filter_gain`` and
    ``global_filter_gain`` sub-layers are exported.

    With a CWriter, the layer's scalar configuration is written to the header
    as ``#define``s and the sub-layers are emitted with print_dense_layer.
    Only the kernel sub-layer honours ``quantize``/``scale``. Otherwise
    ``where`` is treated as a directory that receives one .npy file per
    tensor, like the other dump_* helpers in this module.

    Args:
        where: CWriter instance, or path of the output directory.
        adaconv: the LimitedAdaptiveComb1d module to export.
        name: prefix for the emitted C symbols and #defines.
        scale: scale forwarded to print_dense_layer for the kernel sub-layer.
        quantize: quantize the kernel sub-layer (CWriter export only).
    """
    w_kernel = adaconv.conv_kernel.weight.detach().cpu().numpy().copy()
    b_kernel = adaconv.conv_kernel.bias.detach().cpu().numpy().copy()
    w_gain = adaconv.filter_gain.weight.detach().cpu().numpy().copy()
    b_gain = adaconv.filter_gain.bias.detach().cpu().numpy().copy()
    w_global_gain = adaconv.global_filter_gain.weight.detach().cpu().numpy().copy()
    b_global_gain = adaconv.global_filter_gain.bias.detach().cpu().numpy().copy()

    if isinstance(where, CWriter):
        left_padding = adaconv.padding[0]
        kernel_size = adaconv.kernel_size

        if quantize and w_kernel.shape[0] % 8:
            # Pad the kernel for quantization. Prepend zero rows up to a
            # multiple of 8 and extend the left padding by the same amount.
            # NOTE(review): unlike the Conv1d exporter, this pads the flat
            # first axis of the kernel weight rather than each (out, in)
            # filter. That only matches if the layer has a single channel
            # pair; confirm against LimitedAdaptiveComb1d.
            kernel_padding = 8 - (w_kernel.shape[0] % 8)
            w_kernel = np.concatenate((np.zeros((kernel_padding, w_kernel.shape[1])), w_kernel), dtype=w_kernel.dtype)
            b_kernel = np.concatenate((np.zeros((kernel_padding)), b_kernel), dtype=b_kernel.dtype)
            left_padding += kernel_padding
            kernel_size += kernel_padding

        # write relevant scalar parameters to header file
        where.header.write(f"""
#define {name.upper()}_FILTER_GAIN_A {adaconv.filter_gain_a:f}f
#define {name.upper()}_FILTER_GAIN_B {adaconv.filter_gain_b:f}f
#define {name.upper()}_LOG_GAIN_LIMIT {adaconv.log_gain_limit:f}f
#define {name.upper()}_KERNEL_SIZE {kernel_size}
#define {name.upper()}_LEFT_PADDING {left_padding}
#define {name.upper()}_FRAME_SIZE {adaconv.frame_size}
#define {name.upper()}_OVERLAP_SIZE {adaconv.overlap_size}
#define {name.upper()}_IN_CHANNELS {adaconv.in_channels}
#define {name.upper()}_OUT_CHANNELS {adaconv.out_channels}
#define {name.upper()}_NORM_P {adaconv.norm_p}
#define {name.upper()}_FEATURE_DIM {adaconv.feature_dim}
#define {name.upper()}_MAX_LAG {adaconv.max_lag}
"""
        )

        print_dense_layer(where, name + "_kernel", w_kernel, b_kernel, scale=scale, format='torch', sparse=False, diagonal=False, quantize=quantize)
        print_dense_layer(where, name + "_gain", w_gain, b_gain, format='torch', sparse=False, diagonal=False, quantize=False)
        print_dense_layer(where, name + "_global_gain", w_global_gain, b_global_gain, format='torch', sparse=False, diagonal=False, quantize=False)

    else:
        # Directory export: one file per tensor inside `where`. np.save takes
        # (file, array), so the target path must be joined explicitly.
        os.makedirs(where, exist_ok=True)
        np.save(os.path.join(where, 'weight_kernel.npy'), w_kernel)
        np.save(os.path.join(where, 'bias_kernel.npy'), b_kernel)
        np.save(os.path.join(where, 'weight_gain.npy'), w_gain)
        np.save(os.path.join(where, 'bias_gain.npy'), b_gain)
        np.save(os.path.join(where, 'weight_global_gain.npy'), w_global_gain)
        np.save(os.path.join(where, 'bias_global_gain.npy'), b_global_gain)
155
def dump_torch_tdshaper(where, shaper, name='tdshaper', quantize=False, scale=1/128):
    """Export an OSCE TDShaper module.

    With a CWriter, the shaper's configuration is written to the header as
    ``#define``s. Each conv1d sub-layer is then emitted as
    ``<name>_<sub-layer>``.

    Otherwise ``where`` is treated as a directory, and each conv1d sub-layer
    is written to its own subdirectory ``where/<name>_<sub-layer>``. A shared
    directory would make every sub-layer overwrite the same
    ``weight_oik.npy``/``bias.npy``.

    Only ``feature_alpha1_f`` is exported with the given ``quantize`` and
    ``scale``. The other sub-layers use dump_torch_conv1d_weights' defaults.
    The ``*b``/``*c`` sub-layers are exported only when ``shaper.innovate``
    is set.
    """
    to_cwriter = isinstance(where, CWriter)

    if to_cwriter:
        where.header.write(f"""
#define {name.upper()}_FEATURE_DIM {shaper.feature_dim}
#define {name.upper()}_FRAME_SIZE {shaper.frame_size}
#define {name.upper()}_AVG_POOL_K {shaper.avg_pool_k}
#define {name.upper()}_INNOVATE {1 if shaper.innovate else 0}
#define {name.upper()}_POOL_AFTER {1 if shaper.pool_after else 0}
"""
        )

    def _dump_conv(conv, suffix, **kwargs):
        # Export one conv1d sub-layer, giving it its own target directory for .npy export.
        layer_name = name + suffix
        target = where if to_cwriter else os.path.join(where, layer_name)
        dump_torch_conv1d_weights(target, conv, layer_name, **kwargs)

    _dump_conv(shaper.feature_alpha1_f, "_alpha1_f", quantize=quantize, scale=scale)
    _dump_conv(shaper.feature_alpha1_t, "_alpha1_t")
    _dump_conv(shaper.feature_alpha2, "_alpha2")

    if shaper.innovate:
        _dump_conv(shaper.feature_alpha1b, "_alpha1b")
        _dump_conv(shaper.feature_alpha1c, "_alpha1c")
        _dump_conv(shaper.feature_alpha2b, "_alpha2b")
        _dump_conv(shaper.feature_alpha2c, "_alpha2c")
177
178
179
def dump_torch_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
    """Export a single-layer, unidirectional torch.nn.GRU.

    Weights and biases are exported exactly as PyTorch stores them, i.e. with
    the gates stacked in (r|z|n) order, hence the ``_rzn`` file suffix.

    Args:
        where: CWriter instance, or path of a directory for .npy export.
        gru: the GRU to export; num_layers must be 1 and bidirectional
            must be False (checked with assert).
        name: C symbol prefix (CWriter export only).
        input_sparse, recurrent_sparse, quantize, scale, recurrent_scale:
            forwarded unchanged to print_gru_layer (CWriter export only).

    Returns:
        print_gru_layer's result for CWriter export, otherwise None.
    """
    assert gru.num_layers == 1
    assert not gru.bidirectional

    def as_array(param):
        return param.detach().cpu().numpy().copy()

    w_ih = as_array(gru.weight_ih_l0)
    w_hh = as_array(gru.weight_hh_l0)
    # The bias attributes may be absent (e.g. bias=False); None is passed on then.
    b_ih = as_array(gru.bias_ih_l0) if hasattr(gru, 'bias_ih_l0') else None
    b_hh = as_array(gru.bias_hh_l0) if hasattr(gru, 'bias_hh_l0') else None

    if not isinstance(where, CWriter):
        os.makedirs(where, exist_ok=True)
        # NOTE(review): np.save on a None bias stores a pickled object array,
        # which np.load refuses by default.
        for file_name, array in (('weight_ih_rzn.npy', w_ih),
                                 ('weight_hh_rzn.npy', w_hh),
                                 ('bias_ih_rzn.npy', b_ih),
                                 ('bias_hh_rzn.npy', b_hh)):
            np.save(os.path.join(where, file_name), array)
        return None

    return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='torch', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale)
205
206
def dump_torch_grucell_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
    """Export a torch.nn.GRUCell.

    Uses the same file names and C export path as dump_torch_gru_weights.
    Tensors are exported unchanged, in PyTorch's (r|z|n) gate stacking.

    Args:
        where: CWriter instance, or path of a directory for .npy export.
        gru: the GRUCell to export.
        name: C symbol prefix (CWriter export only).
        input_sparse, recurrent_sparse, quantize, scale, recurrent_scale:
            forwarded unchanged to print_gru_layer (CWriter export only).

    Returns:
        print_gru_layer's result for CWriter export, otherwise None.
    """
    def as_array(param):
        return param.detach().cpu().numpy().copy()

    w_ih = as_array(gru.weight_ih)
    w_hh = as_array(gru.weight_hh)
    # A missing or None bias (e.g. bias=False) is passed on as None.
    bias_ih = getattr(gru, 'bias_ih', None)
    bias_hh = getattr(gru, 'bias_hh', None)
    b_ih = None if bias_ih is None else as_array(bias_ih)
    b_hh = None if bias_hh is None else as_array(bias_hh)

    if not isinstance(where, CWriter):
        os.makedirs(where, exist_ok=True)
        for file_name, array in (('weight_ih_rzn.npy', w_ih),
                                 ('weight_hh_rzn.npy', w_hh),
                                 ('bias_ih_rzn.npy', b_ih),
                                 ('bias_hh_rzn.npy', b_hh)):
            np.save(os.path.join(where, file_name), array)
        return None

    return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='torch', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale)
229
230
231
def load_torch_gru_weights(where, gru):
    """Load a single-layer, unidirectional torch.nn.GRU from a directory.

    Reads ``weight_ih_rzn.npy`` and ``weight_hh_rzn.npy`` from ``where``, as
    written by dump_torch_gru_weights. ``bias_ih_rzn.npy`` and
    ``bias_hh_rzn.npy`` are read only if the GRU actually has the
    corresponding bias parameter, so GRUs built with bias=False can be loaded.
    All files are read before any parameter is modified.
    """
    assert gru.num_layers == 1
    assert not gru.bidirectional

    def _read(file_name):
        return torch.from_numpy(np.load(os.path.join(where, file_name)))

    has_bias_ih = getattr(gru, 'bias_ih_l0', None) is not None
    has_bias_hh = getattr(gru, 'bias_hh_l0', None) is not None

    w_ih = _read('weight_ih_rzn.npy')
    w_hh = _read('weight_hh_rzn.npy')
    b_ih = _read('bias_ih_rzn.npy') if has_bias_ih else None
    b_hh = _read('bias_hh_rzn.npy') if has_bias_hh else None

    with torch.no_grad():
        gru.weight_ih_l0.set_(w_ih)
        gru.weight_hh_l0.set_(w_hh)
        if b_ih is not None:
            gru.bias_ih_l0.set_(b_ih)
        if b_hh is not None:
            gru.bias_hh_l0.set_(b_hh)
247
248
def dump_torch_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False):
    """Export a torch.nn.Linear layer.

    A layer without bias is exported with an all-zero bias of length
    ``out_features``, so the output always contains a bias vector.

    Args:
        where: CWriter instance, or path of a directory for .npy export
            (``weight.npy`` in torch's (out, in) layout, plus ``bias.npy``).
        dense: the Linear layer to export.
        name, scale, sparse, diagonal, quantize: forwarded to
            print_dense_layer (CWriter export only).

    Returns:
        print_dense_layer's result for CWriter export, otherwise None.
    """
    weight = dense.weight.detach().cpu().numpy().copy()
    if dense.bias is not None:
        bias = dense.bias.detach().cpu().numpy().copy()
    else:
        bias = np.zeros(dense.out_features, dtype=weight.dtype)

    if not isinstance(where, CWriter):
        os.makedirs(where, exist_ok=True)
        np.save(os.path.join(where, 'weight.npy'), weight)
        np.save(os.path.join(where, 'bias.npy'), bias)
        return None

    return print_dense_layer(where, name, weight, bias, scale=scale, format='torch', sparse=sparse, diagonal=diagonal, quantize=quantize)
265
266
def load_torch_dense_weights(where, dense):
    """Load a torch.nn.Linear layer from a directory.

    Reads ``weight.npy`` and, only if the layer has a bias, ``bias.npy`` from
    ``where``, as written by dump_torch_dense_weights. A bias-free layer
    therefore needs no bias file.
    """
    weight = torch.from_numpy(np.load(os.path.join(where, 'weight.npy')))
    bias = None
    if dense.bias is not None:
        bias = torch.from_numpy(np.load(os.path.join(where, 'bias.npy')))

    with torch.no_grad():
        dense.weight.set_(weight)
        if bias is not None:
            dense.bias.set_(bias)
276
277
def dump_torch_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
    """Export a torch.nn.Conv1d layer.

    A layer without bias is exported with an all-zero bias of length
    ``out_channels``.

    Args:
        where: CWriter instance, or path of a directory for .npy export
            (``weight_oik.npy`` holding the torch weight unchanged, plus
            ``bias.npy``).
        conv: the Conv1d layer to export.
        name, scale, quantize, sparse: forwarded to print_conv1d_layer
            (CWriter export only).

    Returns:
        print_conv1d_layer's result for CWriter export, otherwise None.
    """
    weight = conv.weight.detach().cpu().numpy().copy()
    if conv.bias is not None:
        bias = conv.bias.detach().cpu().numpy().copy()
    else:
        bias = np.zeros(conv.out_channels, dtype=weight.dtype)

    if not isinstance(where, CWriter):
        os.makedirs(where, exist_ok=True)
        np.save(os.path.join(where, 'weight_oik.npy'), weight)
        np.save(os.path.join(where, 'bias.npy'), bias)
        return None

    return print_conv1d_layer(where, name, weight, bias, scale=scale, format='torch', quantize=quantize, sparse=sparse)
295
296
def load_torch_conv1d_weights(where, conv):
    """Load a torch.nn.Conv1d layer from a directory.

    Reads ``weight_oik.npy`` and, only if the layer has a bias, ``bias.npy``
    from ``where``, as written by dump_torch_conv1d_weights.
    """
    with torch.no_grad():
        w = np.load(os.path.join(where, 'weight_oik.npy'))
        conv.weight.set_(torch.from_numpy(w))
        if conv.bias is not None:
            b = np.load(os.path.join(where, 'bias.npy'))
            conv.bias.set_(torch.from_numpy(b))
306
307
def dump_torch_tconv1d_weights(where, conv, name='conv', scale=1/128, quantize=False, sparse=False):
    """Export a torch.nn.ConvTranspose1d layer.

    A layer without bias is exported with an all-zero bias of length
    ``out_channels``.

    Args:
        where: CWriter instance, or path of a directory for .npy export
            (``weight_oik.npy`` holding the torch weight unchanged, plus
            ``bias.npy``).
        conv: the ConvTranspose1d layer to export.
        name, scale, quantize, sparse: forwarded to print_tconv1d_layer
            together with the layer's stride (CWriter export only).

    Returns:
        print_tconv1d_layer's result for CWriter export, otherwise None.
    """
    weight = conv.weight.detach().cpu().numpy().copy()
    if conv.bias is not None:
        bias = conv.bias.detach().cpu().numpy().copy()
    else:
        bias = np.zeros(conv.out_channels, dtype=weight.dtype)

    if not isinstance(where, CWriter):
        os.makedirs(where, exist_ok=True)
        np.save(os.path.join(where, 'weight_oik.npy'), weight)
        np.save(os.path.join(where, 'bias.npy'), bias)
        return None

    stride = conv.stride[0]
    return print_tconv1d_layer(where, name, weight, bias, stride, scale=scale, quantize=quantize, sparse=sparse)
325
326
def load_torch_tconv1d_weights(where, conv):
    """Load a torch.nn.ConvTranspose1d layer from a directory.

    Reads ``weight_oik.npy`` and, only if the layer has a bias, ``bias.npy``
    from ``where``, as written by dump_torch_tconv1d_weights.
    """
    with torch.no_grad():
        w = np.load(os.path.join(where, 'weight_oik.npy'))
        conv.weight.set_(torch.from_numpy(w))
        if conv.bias is not None:
            b = np.load(os.path.join(where, 'bias.npy'))
            conv.bias.set_(torch.from_numpy(b))
336
337
def dump_torch_conv2d_weights(where, conv, name='conv', scale=1/128, quantize=False):
    """Export a torch.nn.Conv2d layer.

    The kernel's last two axes are swapped before export, turning torch's
    (out, in, kh, kw) weight into (out, in, kw, kh), hence ``oiwh``. A layer
    without bias is exported with an all-zero bias of length ``out_channels``.

    Args:
        where: CWriter instance, or path of a directory for .npy export
            (``weight_oiwh.npy`` plus ``bias.npy``).
        conv: the Conv2d layer to export.
        name, scale, quantize: forwarded to print_conv2d_layer (CWriter
            export only).

    Returns:
        print_conv2d_layer's result for CWriter export, otherwise None.
    """
    swapped = conv.weight.detach().cpu().permute(0, 1, 3, 2)
    weight = swapped.numpy().copy()
    if conv.bias is not None:
        bias = conv.bias.detach().cpu().numpy().copy()
    else:
        bias = np.zeros(conv.out_channels, dtype=weight.dtype)

    if not isinstance(where, CWriter):
        os.makedirs(where, exist_ok=True)
        np.save(os.path.join(where, 'weight_oiwh.npy'), weight)
        np.save(os.path.join(where, 'bias.npy'), bias)
        return None

    return print_conv2d_layer(where, name, weight, bias, scale=scale, quantize=quantize)
354
def load_torch_conv2d_weights(where, conv):
    """Load a torch.nn.Conv2d layer from a directory.

    ``weight_oiwh.npy`` (written by dump_torch_conv2d_weights) stores the
    kernel with its last two axes swapped, i.e. (out, in, kw, kh). They are
    swapped back to torch's (out, in, kh, kw) layout. The result is made
    contiguous so the parameter is not left as a strided view of the loaded
    array. ``bias.npy`` is read only if the layer has a bias.
    """
    with torch.no_grad():
        w = np.load(os.path.join(where, 'weight_oiwh.npy'))
        conv.weight.set_(torch.from_numpy(w).permute(0, 1, 3, 2).contiguous())
        if conv.bias is not None:
            b = np.load(os.path.join(where, 'bias.npy'))
            conv.bias.set_(torch.from_numpy(b))
363
364
def dump_torch_embedding_weights(where, embed, name='embed', scale=1/128, sparse=False, diagonal=False, quantize=False):
    """Export a torch.nn.Embedding as if it were a dense layer.

    The table is transposed to (embedding_dim, num_embeddings), the (out, in)
    layout of a torch Linear weight. It is paired with an all-zero bias of
    length embedding_dim, so that print_dense_layer can emit it.

    Args:
        where: CWriter instance, or path of a directory for .npy export
            (transposed ``weight.npy`` plus the zero ``bias.npy``).
        embed: the Embedding to export.
        name, scale, sparse, diagonal, quantize: forwarded to
            print_dense_layer (CWriter export only).

    Returns:
        print_dense_layer's result for CWriter export, otherwise None.
    """
    table = embed.weight.detach().cpu().numpy().copy()
    w = table.transpose()
    b = np.zeros(w.shape[0], dtype=w.dtype)

    if not isinstance(where, CWriter):
        os.makedirs(where, exist_ok=True)
        np.save(os.path.join(where, 'weight.npy'), w)
        np.save(os.path.join(where, 'bias.npy'), b)
        return None

    return print_dense_layer(where, name, w, b, scale=scale, format='torch', sparse=sparse, diagonal=diagonal, quantize=quantize)
378
379
def load_torch_embedding_weights(where, emb):
    """Load a torch.nn.Embedding from a directory.

    dump_torch_embedding_weights stores the table transposed, with shape
    (embedding_dim, num_embeddings), so that it can be treated as a dense
    layer. It is transposed back here to torch's (num_embeddings,
    embedding_dim) layout. Only ``weight.npy`` is read; the zero bias
    written alongside it is not needed.
    """
    w = np.load(os.path.join(where, 'weight.npy'))
    # Undo the dumper's transpose; copy so the parameter owns contiguous memory.
    w = np.ascontiguousarray(w.T)

    with torch.no_grad():
        emb.weight.set_(torch.from_numpy(w))
386
def dump_torch_weights(where, module, name=None, verbose=False, **kwargs):
    """ generic function for dumping weights of some torch.nn.Module

    Dispatches on the module type to the matching dump_torch_* helper and
    returns that helper's result. The OSCE layer types (LimitedAdaptiveConv1d,
    LimitedAdaptiveComb1d, TDShaper) are handled only when their
    implementations could be imported (``has_osce``).

    Args:
        where: CWriter instance or output directory, passed through.
        module: the layer to export.
        name: layer name passed through; also printed when ``verbose``.
        verbose: print a progress line before exporting.
        **kwargs: forwarded to the selected dump helper.

    Raises:
        ValueError: if the module type is not supported.
    """
    if verbose and name is not None:
        print(f"printing layer {name} of type {type(module)}...")

    if isinstance(module, torch.nn.Linear):
        return dump_torch_dense_weights(where, module, name, **kwargs)
    if isinstance(module, torch.nn.GRU):
        return dump_torch_gru_weights(where, module, name, **kwargs)
    if isinstance(module, torch.nn.GRUCell):
        return dump_torch_grucell_weights(where, module, name, **kwargs)
    if isinstance(module, torch.nn.Conv1d):
        return dump_torch_conv1d_weights(where, module, name, **kwargs)
    if isinstance(module, torch.nn.Conv2d):
        return dump_torch_conv2d_weights(where, module, name, **kwargs)
    if isinstance(module, torch.nn.Embedding):
        return dump_torch_embedding_weights(where, module, name, **kwargs)
    if isinstance(module, torch.nn.ConvTranspose1d):
        return dump_torch_tconv1d_weights(where, module, name, **kwargs)

    if has_osce:
        # Return the helpers' results here too, like every branch above.
        if isinstance(module, LimitedAdaptiveConv1d):
            return dump_torch_adaptive_conv1d_weights(where, module, name, **kwargs)
        if isinstance(module, LimitedAdaptiveComb1d):
            return dump_torch_adaptive_comb1d_weights(where, module, name, **kwargs)
        if isinstance(module, TDShaper):
            return dump_torch_tdshaper(where, module, name, **kwargs)

    raise ValueError(f'dump_torch_weights: layer of type {type(module)} not supported')
417
def load_torch_weights(where, module):
    """ generic function for loading weights of some torch.nn.Module

    Picks the first load_torch_* helper whose layer type ``module`` is an
    instance of and loads the weights from the directory ``where``.

    Raises:
        ValueError: if the module type is not supported.
    """
    # Checked in order; the first matching type wins.
    loaders = (
        (torch.nn.Linear, load_torch_dense_weights),
        (torch.nn.GRU, load_torch_gru_weights),
        (torch.nn.Conv1d, load_torch_conv1d_weights),
        (torch.nn.Conv2d, load_torch_conv2d_weights),
        (torch.nn.Embedding, load_torch_embedding_weights),
        (torch.nn.ConvTranspose1d, load_torch_tconv1d_weights),
    )
    for layer_type, loader in loaders:
        if isinstance(module, layer_type):
            return loader(where, module)

    raise ValueError(f'load_torch_weights: layer of type {type(module)} not supported')
434