""" module for handling extra model parameters for tf.keras models """

import tensorflow as tf


def _find_parameter_weights(model, parameter_name):
    """Return every weight of ``model`` whose name identifies ``parameter_name``.

    Keras 2 (``tf.keras`` up to TF 2.15) names variables ``<name>:0``, while
    Keras 3 drops the ``:0`` output-index suffix, so both spellings match.
    NOTE(review): under Keras 3 a bare name can collide with layer weights of
    the same short name (e.g. ``kernel``); choose distinctive parameter names.
    """
    accepted_names = {parameter_name, parameter_name + ":0"}
    return [weight for weight in model.weights if weight.name in accepted_names]


def set_parameter(model, parameter_name, parameter_value, dtype='float32'):
    """Store ``parameter_value`` on ``model`` as a non-trainable weight named ``parameter_name``.

    If no matching weight exists yet, one is created with ``model.add_weight``,
    initialised to ``parameter_value`` with the given ``dtype``. Otherwise the
    existing weight is updated in place via ``assign``; ``dtype`` is only used
    when the weight is created.

    Args:
        model: the tf.keras model that holds the parameter.
        parameter_name: name of the weight used to store the parameter.
        parameter_value: value to store.
        dtype: dtype used when the weight has to be created.

    Raises:
        ValueError: if more than one weight in ``model`` matches ``parameter_name``.
    """
    weights = _find_parameter_weights(model, parameter_name)

    if not weights:
        # ``name`` is passed by keyword: in Keras 3 the first positional
        # parameter of ``add_weight`` is ``shape``, not ``name``.
        model.add_weight(
            name=parameter_name,
            trainable=False,
            initializer=tf.keras.initializers.Constant(parameter_value),
            dtype=dtype,
        )
    elif len(weights) == 1:
        weights[0].assign(parameter_value)
    else:
        raise ValueError(f"more than one weight starting with {parameter_name}:0 in model")


def get_parameter(model, parameter_name, default=None):
    """Return the value of parameter ``parameter_name`` stored on ``model``, else ``default``.

    The stored weight is converted with ``.numpy().item()``, so the result is a
    plain Python scalar. This requires the weight to hold exactly one element,
    which is the case for weights created by ``set_parameter``.

    Args:
        model: the tf.keras model that may hold the parameter.
        parameter_name: name of the weight used to store the parameter.
        default: value returned when no matching weight exists.

    Raises:
        ValueError: if more than one weight in ``model`` matches ``parameter_name``.
    """
    weights = _find_parameter_weights(model, parameter_name)

    if not weights:
        return default
    if len(weights) > 1:
        raise ValueError(f"more than one weight starting with {parameter_name}:0 in model")
    return weights[0].numpy().item()