# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in regularizers."""
# pylint: disable=invalid-name

import math

from tensorflow.python.keras import backend
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


def _check_penalty_number(x):
  """Checks that `x` is a valid penalty number; raises ValueError if not."""
  if not isinstance(x, (float, int)):
    raise ValueError(('Value: {} is not a valid regularization penalty number, '
                      'expected an int or float value').format(x))

  if math.isinf(x) or math.isnan(x):
    raise ValueError(
        ('Value: {} is not a valid regularization penalty number, '
         'a positive/negative infinity or NaN is not a proper value'
        ).format(x))


def _none_to_default(inputs, default):
  return default if inputs is None else inputs


@keras_export('keras.regularizers.Regularizer')
class Regularizer(object):
  """Regularizer base class.

  Regularizers allow you to apply penalties on layer parameters or layer
  activity during optimization. These penalties are summed into the loss
  function that the network optimizes.

  Regularization penalties are applied on a per-layer basis. The exact API will
  depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D` and
  `Conv3D`) have a unified API.

  These layers expose 3 keyword arguments:

  - `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel
  - `bias_regularizer`: Regularizer to apply a penalty on the layer's bias
  - `activity_regularizer`: Regularizer to apply a penalty on the layer's output

  All layers (including custom layers) expose `activity_regularizer` as a
  settable property, whether or not it is in the constructor arguments.

  The value returned by the `activity_regularizer` is divided by the input
  batch size so that the relative weighting between the weight regularizers and
  the activity regularizers does not change with the batch size.

  You can access a layer's regularization penalties by calling `layer.losses`
  after calling the layer on inputs.

  ## Example

  >>> layer = tf.keras.layers.Dense(
  ...     5, input_dim=5,
  ...     kernel_initializer='ones',
  ...     kernel_regularizer=tf.keras.regularizers.L1(0.01),
  ...     activity_regularizer=tf.keras.regularizers.L2(0.01))
  >>> tensor = tf.ones(shape=(5, 5)) * 2.0
  >>> out = layer(tensor)

  >>> # The kernel regularization term is 0.25
  >>> # The activity regularization term (after dividing by the batch size) is 5
  >>> tf.math.reduce_sum(layer.losses)
  <tf.Tensor: shape=(), dtype=float32, numpy=5.25>

  ## Available penalties

  ```python
  tf.keras.regularizers.L1(0.3)  # L1 Regularization Penalty
  tf.keras.regularizers.L2(0.1)  # L2 Regularization Penalty
  tf.keras.regularizers.L1L2(l1=0.01, l2=0.01)  # L1 + L2 penalties
  ```

  ## Directly calling a regularizer

  Compute a regularization loss on a tensor by directly calling a regularizer
  as if it were a one-argument function.

  E.g.
  >>> regularizer = tf.keras.regularizers.L2(2.)
  >>> tensor = tf.ones(shape=(5, 5))
  >>> regularizer(tensor)
  <tf.Tensor: shape=(), dtype=float32, numpy=50.0>


  ## Developing new regularizers

  Any function that takes in a weight matrix and returns a scalar
  tensor can be used as a regularizer, e.g.:

  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l1')
  ... def l1_reg(weight_matrix):
  ...   return 0.01 * tf.math.reduce_sum(tf.math.abs(weight_matrix))
  ...
  >>> layer = tf.keras.layers.Dense(5, input_dim=5,
  ...     kernel_initializer='ones', kernel_regularizer=l1_reg)
  >>> tensor = tf.ones(shape=(5, 5))
  >>> out = layer(tensor)
  >>> layer.losses
  [<tf.Tensor: shape=(), dtype=float32, numpy=0.25>]

  Alternatively, you can write your custom regularizers in an
  object-oriented way by extending this regularizer base class, e.g.:

  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l2')
  ... class L2Regularizer(tf.keras.regularizers.Regularizer):
  ...   def __init__(self, l2=0.):  # pylint: disable=redefined-outer-name
  ...     self.l2 = l2
  ...
  ...   def __call__(self, x):
  ...     return self.l2 * tf.math.reduce_sum(tf.math.square(x))
  ...
  ...   def get_config(self):
  ...     return {'l2': float(self.l2)}
  ...
  >>> layer = tf.keras.layers.Dense(
  ...     5, input_dim=5, kernel_initializer='ones',
  ...     kernel_regularizer=L2Regularizer(l2=0.5))

  >>> tensor = tf.ones(shape=(5, 5))
  >>> out = layer(tensor)
  >>> layer.losses
  [<tf.Tensor: shape=(), dtype=float32, numpy=12.5>]

  ### A note on serialization and deserialization:

  Registering the regularizers as serializable is optional if you are just
  training and executing models, exporting to and from SavedModels, or saving
  and loading weight checkpoints.

  Registration is required for Keras `model_to_estimator`, saving and
  loading models to HDF5 formats, Keras model cloning, some visualization
  utilities, and exporting models to and from JSON. If using this
  functionality, you must make sure any Python process running your model has
  also defined and registered your custom regularizer.

  `tf.keras.utils.register_keras_serializable` is only available in TF 2.1 and
  beyond. In earlier versions of TensorFlow you must pass your custom
  regularizer to the `custom_objects` argument of methods that expect custom
  regularizers to be registered as serializable.
  """

  def __call__(self, x):
    """Compute a regularization penalty from an input tensor."""
    return 0.

  @classmethod
  def from_config(cls, config):
    """Creates a regularizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same regularizer from the config
    dictionary.

    This method is used by Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Args:
      config: A Python dictionary, typically the output of `get_config`.

    Returns:
      A regularizer instance.
    """
    return cls(**config)

  def get_config(self):
    """Returns the config of the regularizer.

    A regularizer config is a Python dictionary (serializable)
    containing all configuration parameters of the regularizer.
    The same regularizer can be reinstantiated later
    (without any saved state) from this configuration.

    This method is optional if you are just training and executing models,
    exporting to and from SavedModels, or using weight checkpoints.

    This method is required for Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Returns:
      Python dictionary.
    """
    raise NotImplementedError(str(self) + ' does not implement get_config()')


@keras_export('keras.regularizers.L1L2')
class L1L2(Regularizer):
  """A regularizer that applies both L1 and L2 regularization penalties.

  The L1 regularization penalty is computed as:
  `loss = l1 * reduce_sum(abs(x))`

  The L2 regularization penalty is computed as:
  `loss = l2 * reduce_sum(square(x))`

  L1L2 may be passed to a layer as a string identifier:

  >>> dense = tf.keras.layers.Dense(3, kernel_regularizer='l1_l2')

  In this case, the default values used are `l1=0.01` and `l2=0.01`.
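
  You can also call an `L1L2` instance directly on a tensor to compute its
  penalty; for example, on a 5x5 tensor of ones the combined penalty should
  work out to `0.01 * 25 + 0.01 * 25 = 0.5`:

  >>> regularizer = tf.keras.regularizers.L1L2(l1=0.01, l2=0.01)
  >>> regularizer(tf.ones(shape=(5, 5)))
  <tf.Tensor: shape=(), dtype=float32, numpy=0.5>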

  Attributes:
    l1: Float; L1 regularization factor.
    l2: Float; L2 regularization factor.
  """

  def __init__(self, l1=0., l2=0.):  # pylint: disable=redefined-outer-name
    # The default values for l1 and l2 differ from those used by the 'l1_l2'
    # string identifier, for backwards-compatibility reasons. E.g.,
    # L1L2(l2=0.1) applies only an L2 penalty and no L1 penalty.
    l1 = 0. if l1 is None else l1
    l2 = 0. if l2 is None else l2
    _check_penalty_number(l1)
    _check_penalty_number(l2)

    self.l1 = backend.cast_to_floatx(l1)
    self.l2 = backend.cast_to_floatx(l2)

  def __call__(self, x):
    regularization = backend.constant(0., dtype=x.dtype)
    if self.l1:
      regularization += self.l1 * math_ops.reduce_sum(math_ops.abs(x))
    if self.l2:
      regularization += self.l2 * math_ops.reduce_sum(math_ops.square(x))
    return regularization

  def get_config(self):
    return {'l1': float(self.l1), 'l2': float(self.l2)}


@keras_export('keras.regularizers.L1', 'keras.regularizers.l1')
class L1(Regularizer):
  """A regularizer that applies an L1 regularization penalty.

  The L1 regularization penalty is computed as:
  `loss = l1 * reduce_sum(abs(x))`

  L1 may be passed to a layer as a string identifier:

  >>> dense = tf.keras.layers.Dense(3, kernel_regularizer='l1')

  In this case, the default value used is `l1=0.01`.
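
  Calling an `L1` instance directly computes the penalty; e.g. on a 5x5
  tensor of ones it should evaluate to `0.01 * 25 = 0.25`:

  >>> regularizer = tf.keras.regularizers.L1(0.01)
  >>> regularizer(tf.ones(shape=(5, 5)))
  <tf.Tensor: shape=(), dtype=float32, numpy=0.25>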

  Attributes:
    l1: Float; L1 regularization factor.
  """

  def __init__(self, l1=0.01, **kwargs):  # pylint: disable=redefined-outer-name
    l1 = kwargs.pop('l', l1)  # Backwards compatibility
    if kwargs:
      raise TypeError('Argument(s) not recognized: %s' % (kwargs,))

    l1 = 0.01 if l1 is None else l1
    _check_penalty_number(l1)

    self.l1 = backend.cast_to_floatx(l1)

  def __call__(self, x):
    return self.l1 * math_ops.reduce_sum(math_ops.abs(x))

  def get_config(self):
    return {'l1': float(self.l1)}


@keras_export('keras.regularizers.L2', 'keras.regularizers.l2')
class L2(Regularizer):
  """A regularizer that applies an L2 regularization penalty.

  The L2 regularization penalty is computed as:
  `loss = l2 * reduce_sum(square(x))`

  L2 may be passed to a layer as a string identifier:

  >>> dense = tf.keras.layers.Dense(3, kernel_regularizer='l2')

  In this case, the default value used is `l2=0.01`.
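
  Likewise, calling an `L2` instance directly computes the penalty; e.g. on a
  5x5 tensor of ones it should evaluate to `0.01 * 25 = 0.25`:

  >>> regularizer = tf.keras.regularizers.L2(0.01)
  >>> regularizer(tf.ones(shape=(5, 5)))
  <tf.Tensor: shape=(), dtype=float32, numpy=0.25>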

  Attributes:
    l2: Float; L2 regularization factor.
  """

  def __init__(self, l2=0.01, **kwargs):  # pylint: disable=redefined-outer-name
    l2 = kwargs.pop('l', l2)  # Backwards compatibility
    if kwargs:
      raise TypeError('Argument(s) not recognized: %s' % (kwargs,))

    l2 = 0.01 if l2 is None else l2
    _check_penalty_number(l2)

    self.l2 = backend.cast_to_floatx(l2)

  def __call__(self, x):
    return self.l2 * math_ops.reduce_sum(math_ops.square(x))

  def get_config(self):
    return {'l2': float(self.l2)}


@keras_export('keras.regularizers.l1_l2')
def l1_l2(l1=0.01, l2=0.01):  # pylint: disable=redefined-outer-name
  r"""Create a regularizer that applies both L1 and L2 penalties.

  The L1 regularization penalty is computed as:
  `loss = l1 * reduce_sum(abs(x))`

  The L2 regularization penalty is computed as:
  `loss = l2 * reduce_sum(square(x))`

  Args:
    l1: Float; L1 regularization factor.
    l2: Float; L2 regularization factor.

  Returns:
    An L1L2 Regularizer with the given regularization factors.
  """
  return L1L2(l1=l1, l2=l2)


# Deserialization aliases.
l1 = L1
l2 = L2


@keras_export('keras.regularizers.serialize')
def serialize(regularizer):
  return serialize_keras_object(regularizer)


@keras_export('keras.regularizers.deserialize')
def deserialize(config, custom_objects=None):
  if config == 'l1_l2':
    # Special case necessary since the defaults used for "l1_l2" (string)
    # differ from those of the L1L2 class.
    return L1L2(l1=0.01, l2=0.01)
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='regularizer')


@keras_export('keras.regularizers.get')
def get(identifier):
  """Retrieve a regularizer instance from a config or identifier."""
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, str):
    return deserialize(str(identifier))
  elif callable(identifier):
    return identifier
  else:
    raise ValueError(
        'Could not interpret regularizer identifier: {}'.format(identifier))