from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import activations
from tensorflow.keras import initializers, regularizers, constraints
import numpy as np
import math


class MDense(Layer):
    """Dense layer whose outputs are a weighted sum of several tanh branches.

    For each output unit ``u`` the layer computes::

        y[..., u] = activation(
            sum_c factor[u, c] * tanh(dot(x, kernel[u, :, c]) + bias[u, c]))

    where ``c`` ranges over ``channels``. The bias term is omitted when
    ``use_bias`` is False. The trainable ``factor`` weights start at one.

    Args:
        outputs: Number of output units (stored as ``self.units``).
        channels: Number of tanh branches summed per output unit.
        activation: Activation applied after the channel sum, resolved with
            ``activations.get``.
        use_bias: Whether to add a per-(unit, channel) bias before the tanh.
        kernel_initializer: Initializer for the kernel weights.
        bias_initializer: Initializer for the bias weights.
        kernel_regularizer: Optional regularizer for the kernel.
        bias_regularizer: Optional regularizer for the bias. It is also
            applied to ``factor``.
        activity_regularizer: Optional regularizer for the layer output.
        kernel_constraint: Optional constraint for the kernel.
        bias_constraint: Optional constraint for the bias. It is also
            applied to ``factor``.
        **kwargs: Forwarded to ``Layer``. A legacy ``input_dim`` is converted
            to ``input_shape``.
    """

    def __init__(self, outputs,
                 channels=2,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Accept the old Keras-1 style ``input_dim`` argument.
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super().__init__(**kwargs)
        self.units = outputs
        self.channels = channels
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

    def build(self, input_shape):
        """Create ``kernel``, optional ``bias``, and ``factor`` weights.

        Raises:
            ValueError: If the last input dimension is undefined.
        """
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        if input_dim is None:
            raise ValueError('The last dimension of the inputs to `MDense` '
                             'should be defined. Found `None`.')

        # Kernel layout is (units, input_dim, channels). K.dot contracts the
        # input's last axis with the kernel's second-to-last axis.
        self.kernel = self.add_weight(shape=(self.units, input_dim, self.channels),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units, self.channels),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Per-(unit, channel) output scale, initialized to ones. It reuses the
        # bias regularizer/constraint (unchanged from the original code).
        self.factor = self.add_weight(shape=(self.units, self.channels),
                                      initializer='ones',
                                      name='factor',
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def call(self, inputs):
        """Apply the layer: (..., input_dim) -> (..., units)."""
        # Shape after dot: (..., units, channels).
        output = K.dot(inputs, self.kernel)
        if self.use_bias:
            output = output + self.bias
        output = K.tanh(output) * self.factor
        # Collapse the channel axis, leaving (..., units).
        output = K.sum(output, axis=-1)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with its last dimension replaced by ``units``."""
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

    def get_config(self):
        """Return a config that ``from_config`` can pass back to ``__init__``.

        The keys must match ``__init__`` parameter names: ``outputs`` (not
        ``units``) and ``channels``. Otherwise ``cls(**config)`` fails, or
        rebuilds weights with the wrong shapes.
        """
        config = {
            'outputs': self.units,
            'channels': self.channels,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        return {**base_config, **config}