xref: /aosp_15_r20/external/libopus/dnn/training_tf2/mdense.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Lifrom tensorflow.keras import backend as K
2*a58d3d2aSXin Lifrom tensorflow.keras.layers import Layer, InputSpec
3*a58d3d2aSXin Lifrom tensorflow.keras import activations
4*a58d3d2aSXin Lifrom tensorflow.keras import initializers, regularizers, constraints
5*a58d3d2aSXin Liimport numpy as np
6*a58d3d2aSXin Liimport math
7*a58d3d2aSXin Li
class MDense(Layer):
    """Dense layer whose output units are sums of several tanh "channels".

    For each output unit ``u`` the layer computes::

        y[..., u] = activation(sum_c factor[u, c] * tanh(x . kernel[u, :, c] + bias[u, c]))

    where ``c`` ranges over ``channels``. The ``bias`` term is omitted when
    ``use_bias`` is False.

    Args:
        outputs: Number of output units (stored as ``self.units``).
        channels: Number of tanh channels summed per output unit.
        activation: Activation applied after the channel sum; resolved via
            ``activations.get`` (``None`` gives the linear activation).
        use_bias: Whether to add a per-(unit, channel) bias before the tanh.
        kernel_initializer: Initializer for the ``(units, input_dim, channels)`` kernel.
        bias_initializer: Initializer for the ``(units, channels)`` bias.
        kernel_regularizer: Regularizer for the kernel.
        bias_regularizer: Regularizer for the bias; also applied to ``factor``.
        activity_regularizer: Regularizer stored on the layer for its output.
        kernel_constraint: Constraint for the kernel.
        bias_constraint: Constraint for the bias; also applied to ``factor``.
        **kwargs: Forwarded to ``Layer.__init__``. A legacy ``input_dim``
            keyword is converted to ``input_shape=(input_dim,)``.
    """

    def __init__(self, outputs,
                 channels=2,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Accept the older ``input_dim`` spelling unless ``input_shape`` was given.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(MDense, self).__init__(**kwargs)
        self.units = outputs
        self.channels = channels
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    def build(self, input_shape):
        """Create ``kernel``, optional ``bias`` and ``factor`` weights.

        Args:
            input_shape: Shape of the input; must have rank >= 2. The last
                dimension is taken as the input feature size.
        """
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]

        # One (input_dim x channels) projection per output unit.
        self.kernel = self.add_weight(shape=(self.units, input_dim, self.channels),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units, self.channels),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Per-(unit, channel) scale applied to each tanh before summing over
        # channels. It deliberately shares the bias regularizer/constraint.
        self.factor = self.add_weight(shape=(self.units, self.channels),
                                    initializer='ones',
                                    name='factor',
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint)
        # Pin the feature dimension so later calls with a different size fail early.
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        """Apply the layer.

        Returns:
            Tensor whose last dimension is ``units`` (see
            ``compute_output_shape``).
        """
        # K.dot contracts the last axis of ``inputs`` with the second-to-last
        # axis of the 3-D kernel, giving (..., units, channels).
        output = K.dot(inputs, self.kernel)
        if self.use_bias:
            output = output + self.bias
        output = K.tanh(output) * self.factor
        # Collapse the channel axis: (..., units, channels) -> (..., units).
        output = K.sum(output, axis=-1)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last dimension replaced by ``units``."""
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

    def get_config(self):
        """Return the layer configuration.

        The output count is stored under the ``'units'`` key so that configs
        written by earlier versions stay readable. ``from_config`` maps it back
        onto the ``outputs`` constructor argument.
        """
        config = {
            'units': self.units,
            # Required to rebuild weights of the right shape; without it a
            # reloaded layer silently fell back to the default channel count.
            'channels': self.channels,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(MDense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        """Recreate a layer from a ``get_config()`` dictionary.

        ``__init__`` names the unit count ``outputs`` while ``get_config``
        stores it as ``'units'``. The base ``Layer.from_config`` simply calls
        ``cls(**config)``, which would fail on that mismatch, so translate the
        key here. Configs lacking ``'channels'`` (older versions) fall back to
        the constructor default.
        """
        config = dict(config)
        if 'units' in config:
            units = config.pop('units')
            config.setdefault('outputs', units)
        return cls(**config)
96