xref: /aosp_15_r20/external/libopus/dnn/torch/weight-exchange/wexchange/tf/tf.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
31
32import tensorflow as tf
33import numpy as np
34
35from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer
36
def dump_tf_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128):
    """Export the weights of a Keras GRU layer.

    If `where` is a CWriter, the layer is emitted as C code via
    print_gru_layer (Keras gate order, format='tf'); otherwise `where`
    is treated as a directory and the weights are saved as .npy files
    with gates reordered from Keras (z, r, n) to (r, z, n).
    """

    # Only the standard CuDNN-compatible GRU configuration is supported.
    assert gru.activation == tf.keras.activations.tanh
    assert gru.recurrent_activation == tf.keras.activations.sigmoid
    assert gru.reset_after == True

    # Keras stores kernels as (input_dim, 3*units); transpose so the
    # gate axis comes first. copy() detaches from the TF buffer so the
    # in-place gate swap below is safe.
    w_ih = gru.weights[0].numpy().transpose().copy()
    w_hh = gru.weights[1].numpy().transpose().copy()
    # weights[2] stacks input bias (row 0) and recurrent bias (row 1).
    b_ih = gru.weights[2].numpy()[0].copy()
    b_hh = gru.weights[2].numpy()[1].copy()

    if isinstance(where, CWriter):
        return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='tf', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale)

    os.makedirs(where, exist_ok=True)

    # Swap the z and r gate blocks in place: zrn => rzn.
    N = w_ih.shape[0] // 3
    for arr in (w_ih, w_hh, b_ih, b_hh):
        z_block = arr[0:N].copy()
        arr[0:N] = arr[N:2*N]
        arr[N:2*N] = z_block

    np.save(os.path.join(where, 'weight_ih_rzn.npy'), w_ih)
    np.save(os.path.join(where, 'weight_hh_rzn.npy'), w_hh)
    np.save(os.path.join(where, 'bias_ih_rzn.npy'), b_ih)
    np.save(os.path.join(where, 'bias_hh_rzn.npy'), b_hh)
65
66
def load_tf_gru_weights(path, gru):
    """Load GRU weights saved by dump_tf_gru_weights back into a Keras layer.

    Reads the .npy files from `path`, converts the gate order from the
    stored (r, z, n) back to Keras (z, r, n), and assigns the kernels
    and stacked bias into `gru`.
    """

    # Only the standard CuDNN-compatible GRU configuration is supported.
    assert gru.activation == tf.keras.activations.tanh
    assert gru.recurrent_activation == tf.keras.activations.sigmoid
    assert gru.reset_after == True

    w_ih = np.load(os.path.join(path, 'weight_ih_rzn.npy'))
    w_hh = np.load(os.path.join(path, 'weight_hh_rzn.npy'))
    b_ih = np.load(os.path.join(path, 'bias_ih_rzn.npy'))
    b_hh = np.load(os.path.join(path, 'bias_hh_rzn.npy'))

    # Swap the r and z gate blocks in place: rzn => zrn.
    N = w_ih.shape[0] // 3
    for arr in (w_ih, w_hh, b_ih, b_hh):
        r_block = arr[0:N].copy()
        arr[0:N] = arr[N:2*N]
        arr[N:2*N] = r_block

    # Keras expects kernels as (input_dim, 3*units) and the two bias
    # vectors stacked into a single (2, 3*units) tensor.
    gru.weights[0].assign(tf.convert_to_tensor(w_ih.transpose()))
    gru.weights[1].assign(tf.convert_to_tensor(w_hh.transpose()))
    gru.weights[2].assign(tf.convert_to_tensor(np.vstack((b_ih, b_hh))))
88
89
def dump_tf_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False):
    """Export the weights of a Keras Dense layer.

    If `where` is a CWriter, the layer is emitted as C code via
    print_dense_layer; otherwise `where` is treated as a directory and
    the transposed kernel and the bias are saved as .npy files.
    """

    kernel = dense.weights[0].numpy()
    # A layer built with use_bias=False has no bias tensor; substitute
    # zeros so downstream consumers always get a bias vector.
    if dense.bias is None:
        bias = np.zeros(dense.units, dtype=kernel.dtype)
    else:
        bias = dense.bias.numpy()

    if isinstance(where, CWriter):
        return print_dense_layer(where, name, kernel, bias, scale=scale, format='tf', sparse=sparse, diagonal=diagonal, quantize=quantize)

    os.makedirs(where, exist_ok=True)

    # Keras stores the kernel as (input_dim, units); save it transposed.
    np.save(os.path.join(where, 'weight.npy'), kernel.transpose())
    np.save(os.path.join(where, 'bias.npy'), bias)
108
109
def load_tf_dense_weights(path, dense):
    """Load Dense weights saved by dump_tf_dense_weights into a Keras layer.

    The stored kernel is transposed back to the Keras (input_dim, units)
    layout; the bias is assigned only if the layer actually has one.
    """

    kernel = np.load(os.path.join(path, 'weight.npy')).transpose()
    bias = np.load(os.path.join(path, 'bias.npy'))

    dense.weights[0].assign(tf.convert_to_tensor(kernel))
    # Skip the bias for layers built with use_bias=False.
    if dense.bias is not None:
        dense.weights[1].assign(tf.convert_to_tensor(bias))
118
119
def dump_tf_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False):
    """Export the weights of a Keras Conv1D layer.

    If `where` is a CWriter, the layer is emitted as C code via
    print_conv1d_layer; otherwise `where` is treated as a directory and
    the kernel (reordered per the 'weight_oik' naming) and the bias are
    saved as .npy files.
    """

    # Only the channels-last kernel layout is supported.
    assert conv.data_format == 'channels_last'

    kernel = conv.weights[0].numpy().copy()
    # A layer built with use_bias=False has no bias tensor; substitute
    # zeros so downstream consumers always get a bias vector.
    if conv.bias is None:
        bias = np.zeros(conv.filters, dtype=kernel.dtype)
    else:
        bias = conv.bias.numpy()

    if isinstance(where, CWriter):
        return print_conv1d_layer(where, name, kernel, bias, scale=scale, format='tf', quantize=quantize)

    os.makedirs(where, exist_ok=True)

    # Reverse the axis order before saving; the file name suggests the
    # stored layout is (out, in, kernel).
    np.save(os.path.join(where, 'weight_oik.npy'), np.transpose(kernel, (2, 1, 0)))
    np.save(os.path.join(where, 'bias.npy'), bias)
138
139
def load_tf_conv1d_weights(path, conv):
    """Load Conv1D weights saved by dump_tf_conv1d_weights into a Keras layer.

    The stored kernel axes are reversed back to the Keras layout; the
    bias is assigned only if the layer actually has one.
    """

    kernel = np.load(os.path.join(path, 'weight_oik.npy'))
    bias = np.load(os.path.join(path, 'bias.npy'))

    # Undo the (2, 1, 0) transpose applied when the kernel was dumped.
    conv.weights[0].assign(tf.convert_to_tensor(np.transpose(kernel, (2, 1, 0))))
    # Skip the bias for layers built with use_bias=False.
    if conv.bias is not None:
        conv.weights[1].assign(tf.convert_to_tensor(bias))
150
151
def dump_tf_embedding_weights(path, emb):
    """Save the embedding table of a Keras Embedding layer to path/weight.npy."""
    os.makedirs(path, exist_ok=True)

    table = emb.weights[0].numpy()
    np.save(os.path.join(path, 'weight.npy'), table)
157
158
159
def load_tf_embedding_weights(path, emb):
    """Load an embedding table saved by dump_tf_embedding_weights into a Keras layer."""
    table = np.load(os.path.join(path, 'weight.npy'))
    emb.weights[0].assign(tf.convert_to_tensor(table))
164
165
def dump_tf_weights(path, module):
    """Dispatch `module` to the matching dump_tf_*_weights helper.

    Raises ValueError for unsupported layer types. Checks are ordered;
    the first matching isinstance wins. Always returns None.
    """
    dispatch = (
        (tf.keras.layers.Dense, dump_tf_dense_weights),
        (tf.keras.layers.GRU, dump_tf_gru_weights),
        (tf.keras.layers.Conv1D, dump_tf_conv1d_weights),
        (tf.keras.layers.Embedding, dump_tf_embedding_weights),
    )
    for layer_type, dump_fn in dispatch:
        if isinstance(module, layer_type):
            dump_fn(path, module)
            return
    raise ValueError(f'dump_tf_weights: layer of type {type(module)} not supported')
177
def load_tf_weights(path, module):
    """Dispatch `module` to the matching load_tf_*_weights helper.

    Mirrors dump_tf_weights: the first matching isinstance check wins.

    Raises:
        ValueError: if `module` is not one of the supported layer types.
    """
    if isinstance(module, tf.keras.layers.Dense):
        load_tf_dense_weights(path, module)
    elif isinstance(module, tf.keras.layers.GRU):
        load_tf_gru_weights(path, module)
    elif isinstance(module, tf.keras.layers.Conv1D):
        load_tf_conv1d_weights(path, module)
    elif isinstance(module, tf.keras.layers.Embedding):
        load_tf_embedding_weights(path, module)
    else:
        # Bug fix: the message previously said 'dump_tf_weights', which
        # misattributed the failure to the wrong function.
        raise ValueError(f'load_tf_weights: layer of type {type(module)} not supported')