# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
"""Built-in loss functions."""

import abc
import functools

from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


@keras_export('keras.losses.Loss')
class Loss:
  """Loss base class.

  To be implemented by subclasses:
  * `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`.

  Example subclass implementation:

  ```python
  class MeanSquaredError(Loss):

    def call(self, y_true, y_pred):
      y_pred = tf.convert_to_tensor(y_pred)
      y_true = tf.cast(y_true, y_pred.dtype)
      return tf.reduce_mean(tf.square(y_pred - y_true), axis=-1)
  ```

  When used with `tf.distribute.Strategy`, outside of built-in training loops
  such as `tf.keras` `compile` and `fit`, please use 'SUM' or 'NONE' reduction
  types, and reduce losses explicitly in your training loop. Using 'AUTO' or
  'SUM_OVER_BATCH_SIZE' will raise an error.

  Please see this custom training [tutorial](
  https://www.tensorflow.org/tutorials/distribute/custom_training) for more
  details.

  You can implement 'SUM_OVER_BATCH_SIZE' using the global batch size like:

  ```python
  with strategy.scope():
    loss_obj = tf.keras.losses.CategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    ....
    loss = (tf.reduce_sum(loss_obj(labels, predictions)) *
            (1.
/ global_batch_size)) 84 ``` 85 """ 86 87 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name=None): 88 """Initializes `Loss` class. 89 90 Args: 91 reduction: Type of `tf.keras.losses.Reduction` to apply to 92 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 93 option will be determined by the usage context. For almost all cases 94 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 95 `tf.distribute.Strategy`, outside of built-in training loops such as 96 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 97 will raise an error. Please see this custom training [tutorial]( 98 https://www.tensorflow.org/tutorials/distribute/custom_training) for 99 more details. 100 name: Optional name for the instance. 101 """ 102 losses_utils.ReductionV2.validate(reduction) 103 self.reduction = reduction 104 self.name = name 105 # SUM_OVER_BATCH is only allowed in losses managed by `fit` or 106 # CannedEstimators. 107 self._allow_sum_over_batch_size = False 108 self._set_name_scope() 109 110 def _set_name_scope(self): 111 """Creates a valid `name_scope` name.""" 112 if self.name is None: 113 self._name_scope = self.__class__.__name__ 114 elif self.name == '<lambda>': 115 self._name_scope = 'lambda' 116 else: 117 # E.g. '_my_loss' => 'my_loss' 118 self._name_scope = self.name.strip('_') 119 120 def __call__(self, y_true, y_pred, sample_weight=None): 121 """Invokes the `Loss` instance. 122 123 Args: 124 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`, except 125 sparse loss functions such as sparse categorical crossentropy where 126 shape = `[batch_size, d0, .. dN-1]` 127 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]` 128 sample_weight: Optional `sample_weight` acts as a coefficient for the 129 loss. If a scalar is provided, then the loss is simply scaled by the 130 given value. If `sample_weight` is a tensor of size `[batch_size]`, then 131 the total loss for each sample of the batch is rescaled by the 132 corresponding element in the `sample_weight` vector. If the shape of 133 `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted to 134 this shape), then each loss element of `y_pred` is scaled 135 by the corresponding value of `sample_weight`. (Note on`dN-1`: all loss 136 functions reduce by 1 dimension, usually axis=-1.) 137 138 Returns: 139 Weighted loss float `Tensor`. If `reduction` is `NONE`, this has 140 shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1` 141 because all loss functions reduce by 1 dimension, usually axis=-1.) 142 143 Raises: 144 ValueError: If the shape of `sample_weight` is invalid. 145 """ 146 # If we are wrapping a lambda function strip '<>' from the name as it is not 147 # accepted in scope name. 148 graph_ctx = tf_utils.graph_context_for_symbolic_tensors( 149 y_true, y_pred, sample_weight) 150 with backend.name_scope(self._name_scope), graph_ctx: 151 if context.executing_eagerly(): 152 call_fn = self.call 153 else: 154 call_fn = autograph.tf_convert(self.call, ag_ctx.control_status_ctx()) 155 losses = call_fn(y_true, y_pred) 156 return losses_utils.compute_weighted_loss( 157 losses, sample_weight, reduction=self._get_reduction()) 158 159 @classmethod 160 def from_config(cls, config): 161 """Instantiates a `Loss` from its config (output of `get_config()`). 162 163 Args: 164 config: Output of `get_config()`. 165 166 Returns: 167 A `Loss` instance. 
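
    Example (an illustrative sketch, shown with a built-in loss; any `Loss`
    subclass round-trips through its config the same way):

    ```python
    mse = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.SUM)
    config = mse.get_config()
    # Re-create an equivalent loss instance from the saved config.
    restored = tf.keras.losses.MeanSquaredError.from_config(config)
    ```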
168 """ 169 return cls(**config) 170 171 def get_config(self): 172 """Returns the config dictionary for a `Loss` instance.""" 173 return {'reduction': self.reduction, 'name': self.name} 174 175 @abc.abstractmethod 176 @doc_controls.for_subclass_implementers 177 def call(self, y_true, y_pred): 178 """Invokes the `Loss` instance. 179 180 Args: 181 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`, except 182 sparse loss functions such as sparse categorical crossentropy where 183 shape = `[batch_size, d0, .. dN-1]` 184 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]` 185 186 Returns: 187 Loss values with the shape `[batch_size, d0, .. dN-1]`. 188 """ 189 raise NotImplementedError('Must be implemented in subclasses.') 190 191 def _get_reduction(self): 192 """Handles `AUTO` reduction cases and returns the reduction value.""" 193 if (not self._allow_sum_over_batch_size and 194 distribution_strategy_context.has_strategy() and 195 (self.reduction == losses_utils.ReductionV2.AUTO or 196 self.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)): 197 raise ValueError( 198 'Please use `tf.keras.losses.Reduction.SUM` or ' 199 '`tf.keras.losses.Reduction.NONE` for loss reduction when losses are ' 200 'used with `tf.distribute.Strategy` outside of the built-in training ' 201 'loops. You can implement ' 202 '`tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` using global batch ' 203 'size like:\n```\nwith strategy.scope():\n' 204 ' loss_obj = tf.keras.losses.CategoricalCrossentropy(' 205 'reduction=tf.keras.losses.Reduction.NONE)\n....\n' 206 ' loss = tf.reduce_sum(loss_obj(labels, predictions)) * ' 207 '(1. / global_batch_size)\n```\nPlease see ' 208 'https://www.tensorflow.org/tutorials/distribute/custom_training' 209 ' for more details.') 210 211 if self.reduction == losses_utils.ReductionV2.AUTO: 212 return losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE 213 return self.reduction 214 215 216class LossFunctionWrapper(Loss): 217 """Wraps a loss function in the `Loss` class.""" 218 219 def __init__(self, 220 fn, 221 reduction=losses_utils.ReductionV2.AUTO, 222 name=None, 223 **kwargs): 224 """Initializes `LossFunctionWrapper` class. 225 226 Args: 227 fn: The loss function to wrap, with signature `fn(y_true, y_pred, 228 **kwargs)`. 229 reduction: Type of `tf.keras.losses.Reduction` to apply to 230 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 231 option will be determined by the usage context. For almost all cases 232 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 233 `tf.distribute.Strategy`, outside of built-in training loops such as 234 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 235 will raise an error. Please see this custom training [tutorial]( 236 https://www.tensorflow.org/tutorials/distribute/custom_training) for 237 more details. 238 name: Optional name for the instance. 239 **kwargs: The keyword arguments that are passed on to `fn`. 240 """ 241 super().__init__(reduction=reduction, name=name) 242 self.fn = fn 243 self._fn_kwargs = kwargs 244 245 def call(self, y_true, y_pred): 246 """Invokes the `LossFunctionWrapper` instance. 247 248 Args: 249 y_true: Ground truth values. 250 y_pred: The predicted values. 251 252 Returns: 253 Loss values per sample. 
254 """ 255 if tensor_util.is_tf_type(y_pred) and tensor_util.is_tf_type(y_true): 256 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(y_pred, y_true) 257 258 ag_fn = autograph.tf_convert(self.fn, ag_ctx.control_status_ctx()) 259 return ag_fn(y_true, y_pred, **self._fn_kwargs) 260 261 def get_config(self): 262 config = {} 263 for k, v in self._fn_kwargs.items(): 264 config[k] = backend.eval(v) if tf_utils.is_tensor_or_variable(v) else v 265 base_config = super().get_config() 266 return dict(list(base_config.items()) + list(config.items())) 267 268 269@keras_export('keras.losses.MeanSquaredError') 270class MeanSquaredError(LossFunctionWrapper): 271 """Computes the mean of squares of errors between labels and predictions. 272 273 `loss = square(y_true - y_pred)` 274 275 Standalone usage: 276 277 >>> y_true = [[0., 1.], [0., 0.]] 278 >>> y_pred = [[1., 1.], [1., 0.]] 279 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 280 >>> mse = tf.keras.losses.MeanSquaredError() 281 >>> mse(y_true, y_pred).numpy() 282 0.5 283 284 >>> # Calling with 'sample_weight'. 285 >>> mse(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 286 0.25 287 288 >>> # Using 'sum' reduction type. 289 >>> mse = tf.keras.losses.MeanSquaredError( 290 ... reduction=tf.keras.losses.Reduction.SUM) 291 >>> mse(y_true, y_pred).numpy() 292 1.0 293 294 >>> # Using 'none' reduction type. 295 >>> mse = tf.keras.losses.MeanSquaredError( 296 ... reduction=tf.keras.losses.Reduction.NONE) 297 >>> mse(y_true, y_pred).numpy() 298 array([0.5, 0.5], dtype=float32) 299 300 Usage with the `compile()` API: 301 302 ```python 303 model.compile(optimizer='sgd', loss=tf.keras.losses.MeanSquaredError()) 304 ``` 305 """ 306 307 def __init__(self, 308 reduction=losses_utils.ReductionV2.AUTO, 309 name='mean_squared_error'): 310 """Initializes `MeanSquaredError` instance. 311 312 Args: 313 reduction: Type of `tf.keras.losses.Reduction` to apply to 314 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 315 option will be determined by the usage context. For almost all cases 316 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 317 `tf.distribute.Strategy`, outside of built-in training loops such as 318 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 319 will raise an error. Please see this custom training [tutorial]( 320 https://www.tensorflow.org/tutorials/distribute/custom_training) for 321 more details. 322 name: Optional name for the instance. Defaults to 'mean_squared_error'. 323 """ 324 super().__init__(mean_squared_error, name=name, reduction=reduction) 325 326 327@keras_export('keras.losses.MeanAbsoluteError') 328class MeanAbsoluteError(LossFunctionWrapper): 329 """Computes the mean of absolute difference between labels and predictions. 330 331 `loss = abs(y_true - y_pred)` 332 333 Standalone usage: 334 335 >>> y_true = [[0., 1.], [0., 0.]] 336 >>> y_pred = [[1., 1.], [1., 0.]] 337 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 338 >>> mae = tf.keras.losses.MeanAbsoluteError() 339 >>> mae(y_true, y_pred).numpy() 340 0.5 341 342 >>> # Calling with 'sample_weight'. 343 >>> mae(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 344 0.25 345 346 >>> # Using 'sum' reduction type. 347 >>> mae = tf.keras.losses.MeanAbsoluteError( 348 ... reduction=tf.keras.losses.Reduction.SUM) 349 >>> mae(y_true, y_pred).numpy() 350 1.0 351 352 >>> # Using 'none' reduction type. 353 >>> mae = tf.keras.losses.MeanAbsoluteError( 354 ... 
reduction=tf.keras.losses.Reduction.NONE) 355 >>> mae(y_true, y_pred).numpy() 356 array([0.5, 0.5], dtype=float32) 357 358 Usage with the `compile()` API: 359 360 ```python 361 model.compile(optimizer='sgd', loss=tf.keras.losses.MeanAbsoluteError()) 362 ``` 363 """ 364 365 def __init__(self, 366 reduction=losses_utils.ReductionV2.AUTO, 367 name='mean_absolute_error'): 368 """Initializes `MeanAbsoluteError` instance. 369 370 Args: 371 reduction: Type of `tf.keras.losses.Reduction` to apply to 372 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 373 option will be determined by the usage context. For almost all cases 374 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 375 `tf.distribute.Strategy`, outside of built-in training loops such as 376 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 377 will raise an error. Please see this custom training [tutorial]( 378 https://www.tensorflow.org/tutorials/distribute/custom_training) for 379 more details. 380 name: Optional name for the instance. Defaults to 'mean_absolute_error'. 381 """ 382 super().__init__(mean_absolute_error, name=name, reduction=reduction) 383 384 385@keras_export('keras.losses.MeanAbsolutePercentageError') 386class MeanAbsolutePercentageError(LossFunctionWrapper): 387 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 388 389 `loss = 100 * abs(y_true - y_pred) / y_true` 390 391 Standalone usage: 392 393 >>> y_true = [[2., 1.], [2., 3.]] 394 >>> y_pred = [[1., 1.], [1., 0.]] 395 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 396 >>> mape = tf.keras.losses.MeanAbsolutePercentageError() 397 >>> mape(y_true, y_pred).numpy() 398 50. 399 400 >>> # Calling with 'sample_weight'. 401 >>> mape(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 402 20. 403 404 >>> # Using 'sum' reduction type. 405 >>> mape = tf.keras.losses.MeanAbsolutePercentageError( 406 ... reduction=tf.keras.losses.Reduction.SUM) 407 >>> mape(y_true, y_pred).numpy() 408 100. 409 410 >>> # Using 'none' reduction type. 411 >>> mape = tf.keras.losses.MeanAbsolutePercentageError( 412 ... reduction=tf.keras.losses.Reduction.NONE) 413 >>> mape(y_true, y_pred).numpy() 414 array([25., 75.], dtype=float32) 415 416 Usage with the `compile()` API: 417 418 ```python 419 model.compile(optimizer='sgd', 420 loss=tf.keras.losses.MeanAbsolutePercentageError()) 421 ``` 422 """ 423 424 def __init__(self, 425 reduction=losses_utils.ReductionV2.AUTO, 426 name='mean_absolute_percentage_error'): 427 """Initializes `MeanAbsolutePercentageError` instance. 428 429 Args: 430 reduction: Type of `tf.keras.losses.Reduction` to apply to 431 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 432 option will be determined by the usage context. For almost all cases 433 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 434 `tf.distribute.Strategy`, outside of built-in training loops such as 435 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 436 will raise an error. Please see this custom training [tutorial]( 437 https://www.tensorflow.org/tutorials/distribute/custom_training) for 438 more details. 439 name: Optional name for the instance. Defaults to 440 'mean_absolute_percentage_error'. 
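
    For example (illustrative values), `y_true = [[2., 1.]]` and
    `y_pred = [[1., 1.]]` give
    `100 * mean([abs(2. - 1.) / 2., abs(1. - 1.) / 1.]) = 100 * 0.25 = 25.`.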
441 """ 442 super().__init__( 443 mean_absolute_percentage_error, name=name, reduction=reduction) 444 445 446@keras_export('keras.losses.MeanSquaredLogarithmicError') 447class MeanSquaredLogarithmicError(LossFunctionWrapper): 448 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 449 450 `loss = square(log(y_true + 1.) - log(y_pred + 1.))` 451 452 Standalone usage: 453 454 >>> y_true = [[0., 1.], [0., 0.]] 455 >>> y_pred = [[1., 1.], [1., 0.]] 456 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 457 >>> msle = tf.keras.losses.MeanSquaredLogarithmicError() 458 >>> msle(y_true, y_pred).numpy() 459 0.240 460 461 >>> # Calling with 'sample_weight'. 462 >>> msle(y_true, y_pred, sample_weight=[0.7, 0.3]).numpy() 463 0.120 464 465 >>> # Using 'sum' reduction type. 466 >>> msle = tf.keras.losses.MeanSquaredLogarithmicError( 467 ... reduction=tf.keras.losses.Reduction.SUM) 468 >>> msle(y_true, y_pred).numpy() 469 0.480 470 471 >>> # Using 'none' reduction type. 472 >>> msle = tf.keras.losses.MeanSquaredLogarithmicError( 473 ... reduction=tf.keras.losses.Reduction.NONE) 474 >>> msle(y_true, y_pred).numpy() 475 array([0.240, 0.240], dtype=float32) 476 477 Usage with the `compile()` API: 478 479 ```python 480 model.compile(optimizer='sgd', 481 loss=tf.keras.losses.MeanSquaredLogarithmicError()) 482 ``` 483 """ 484 485 def __init__(self, 486 reduction=losses_utils.ReductionV2.AUTO, 487 name='mean_squared_logarithmic_error'): 488 """Initializes `MeanSquaredLogarithmicError` instance. 489 490 Args: 491 reduction: Type of `tf.keras.losses.Reduction` to apply to 492 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 493 option will be determined by the usage context. For almost all cases 494 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 495 `tf.distribute.Strategy`, outside of built-in training loops such as 496 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 497 will raise an error. Please see this custom training [tutorial]( 498 https://www.tensorflow.org/tutorials/distribute/custom_training) for 499 more details. 500 name: Optional name for the instance. Defaults to 501 'mean_squared_logarithmic_error'. 502 """ 503 super().__init__( 504 mean_squared_logarithmic_error, name=name, reduction=reduction) 505 506 507@keras_export('keras.losses.BinaryCrossentropy') 508class BinaryCrossentropy(LossFunctionWrapper): 509 """Computes the cross-entropy loss between true labels and predicted labels. 510 511 Use this cross-entropy loss for binary (0 or 1) classification applications. 512 The loss function requires the following inputs: 513 514 - `y_true` (true label): This is either 0 or 1. 515 - `y_pred` (predicted value): This is the model's prediction, i.e, a single 516 floating-point value which either represents a 517 [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf] 518 when `from_logits=True`) or a probability (i.e, value in [0., 1.] when 519 `from_logits=False`). 520 521 **Recommended Usage:** (set `from_logits=True`) 522 523 With `tf.keras` API: 524 525 ```python 526 model.compile( 527 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 528 .... 
529 ) 530 ``` 531 532 As a standalone function: 533 534 >>> # Example 1: (batch_size = 1, number of samples = 4) 535 >>> y_true = [0, 1, 0, 0] 536 >>> y_pred = [-18.6, 0.51, 2.94, -12.8] 537 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True) 538 >>> bce(y_true, y_pred).numpy() 539 0.865 540 541 >>> # Example 2: (batch_size = 2, number of samples = 4) 542 >>> y_true = [[0, 1], [0, 0]] 543 >>> y_pred = [[-18.6, 0.51], [2.94, -12.8]] 544 >>> # Using default 'auto'/'sum_over_batch_size' reduction type. 545 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True) 546 >>> bce(y_true, y_pred).numpy() 547 0.865 548 >>> # Using 'sample_weight' attribute 549 >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 550 0.243 551 >>> # Using 'sum' reduction` type. 552 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, 553 ... reduction=tf.keras.losses.Reduction.SUM) 554 >>> bce(y_true, y_pred).numpy() 555 1.730 556 >>> # Using 'none' reduction type. 557 >>> bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, 558 ... reduction=tf.keras.losses.Reduction.NONE) 559 >>> bce(y_true, y_pred).numpy() 560 array([0.235, 1.496], dtype=float32) 561 562 **Default Usage:** (set `from_logits=False`) 563 564 >>> # Make the following updates to the above "Recommended Usage" section 565 >>> # 1. Set `from_logits=False` 566 >>> tf.keras.losses.BinaryCrossentropy() # OR ...('from_logits=False') 567 >>> # 2. Update `y_pred` to use probabilities instead of logits 568 >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]] 569 """ 570 571 def __init__(self, 572 from_logits=False, 573 label_smoothing=0, 574 axis=-1, 575 reduction=losses_utils.ReductionV2.AUTO, 576 name='binary_crossentropy'): 577 """Initializes `BinaryCrossentropy` instance. 578 579 Args: 580 from_logits: Whether to interpret `y_pred` as a tensor of 581 [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we 582 assume that `y_pred` contains probabilities (i.e., values in [0, 1]). 583 label_smoothing: Float in [0, 1]. When 0, no smoothing occurs. When > 0, 584 we compute the loss between the predicted labels and a smoothed version 585 of the true labels, where the smoothing squeezes the labels towards 0.5. 586 Larger values of `label_smoothing` correspond to heavier smoothing. 587 axis: The axis along which to compute crossentropy (the features axis). 588 Defaults to -1. 589 reduction: Type of `tf.keras.losses.Reduction` to apply to 590 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 591 option will be determined by the usage context. For almost all cases 592 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 593 `tf.distribute.Strategy`, outside of built-in training loops such as 594 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 595 will raise an error. Please see this custom training [tutorial]( 596 https://www.tensorflow.org/tutorials/distribute/custom_training) for 597 more details. 598 name: Name for the op. Defaults to 'binary_crossentropy'. 599 """ 600 super().__init__( 601 binary_crossentropy, 602 name=name, 603 reduction=reduction, 604 from_logits=from_logits, 605 label_smoothing=label_smoothing, 606 axis=axis) 607 self.from_logits = from_logits 608 609 610@keras_export('keras.losses.CategoricalCrossentropy') 611class CategoricalCrossentropy(LossFunctionWrapper): 612 """Computes the crossentropy loss between the labels and predictions. 613 614 Use this crossentropy loss function when there are two or more label classes. 
  We expect labels to be provided in a `one_hot` representation. If you want to
  provide labels as integers, please use `SparseCategoricalCrossentropy` loss.
  There should be `# classes` floating point values per feature.

  In the snippet below, there are `# classes` floating point values per
  example. The shape of both `y_pred` and `y_true` is
  `[batch_size, num_classes]`.

  Standalone usage:

  >>> y_true = [[0, 1, 0], [0, 0, 1]]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
  >>> cce = tf.keras.losses.CategoricalCrossentropy()
  >>> cce(y_true, y_pred).numpy()
  1.177

  >>> # Calling with 'sample_weight'.
  >>> cce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy()
  0.814

  >>> # Using 'sum' reduction type.
  >>> cce = tf.keras.losses.CategoricalCrossentropy(
  ...     reduction=tf.keras.losses.Reduction.SUM)
  >>> cce(y_true, y_pred).numpy()
  2.354

  >>> # Using 'none' reduction type.
  >>> cce = tf.keras.losses.CategoricalCrossentropy(
  ...     reduction=tf.keras.losses.Reduction.NONE)
  >>> cce(y_true, y_pred).numpy()
  array([0.0513, 2.303], dtype=float32)

  Usage with the `compile()` API:

  ```python
  model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalCrossentropy())
  ```
  """

  def __init__(self,
               from_logits=False,
               label_smoothing=0,
               axis=-1,
               reduction=losses_utils.ReductionV2.AUTO,
               name='categorical_crossentropy'):
    """Initializes `CategoricalCrossentropy` instance.

    Args:
      from_logits: Whether `y_pred` is expected to be a logits tensor. By
        default, we assume that `y_pred` encodes a probability distribution.
      label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,
        meaning the confidence on label values is relaxed. For example, if
        `0.1`, use `0.1 / num_classes` for non-target labels and
        `0.9 + 0.1 / num_classes` for target labels.
      axis: The axis along which to compute crossentropy (the features axis).
        Defaults to -1.
      reduction: Type of `tf.keras.losses.Reduction` to apply to
        loss. Default value is `AUTO`. `AUTO` indicates that the reduction
        option will be determined by the usage context. For almost all cases
        this defaults to `SUM_OVER_BATCH_SIZE`. When used with
        `tf.distribute.Strategy`, outside of built-in training loops such as
        `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE`
        will raise an error. Please see this custom training [tutorial](
        https://www.tensorflow.org/tutorials/distribute/custom_training) for
        more details.
      name: Optional name for the instance.
        Defaults to 'categorical_crossentropy'.
    """
    super().__init__(
        categorical_crossentropy,
        name=name,
        reduction=reduction,
        from_logits=from_logits,
        label_smoothing=label_smoothing,
        axis=axis)


@keras_export('keras.losses.SparseCategoricalCrossentropy')
class SparseCategoricalCrossentropy(LossFunctionWrapper):
  """Computes the crossentropy loss between the labels and predictions.

  Use this crossentropy loss function when there are two or more label classes.
  We expect labels to be provided as integers. If you want to provide labels
  using `one-hot` representation, please use `CategoricalCrossentropy` loss.
700 There should be `# classes` floating point values per feature for `y_pred` 701 and a single floating point value per feature for `y_true`. 702 703 In the snippet below, there is a single floating point value per example for 704 `y_true` and `# classes` floating pointing values per example for `y_pred`. 705 The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is 706 `[batch_size, num_classes]`. 707 708 Standalone usage: 709 710 >>> y_true = [1, 2] 711 >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] 712 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 713 >>> scce = tf.keras.losses.SparseCategoricalCrossentropy() 714 >>> scce(y_true, y_pred).numpy() 715 1.177 716 717 >>> # Calling with 'sample_weight'. 718 >>> scce(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy() 719 0.814 720 721 >>> # Using 'sum' reduction type. 722 >>> scce = tf.keras.losses.SparseCategoricalCrossentropy( 723 ... reduction=tf.keras.losses.Reduction.SUM) 724 >>> scce(y_true, y_pred).numpy() 725 2.354 726 727 >>> # Using 'none' reduction type. 728 >>> scce = tf.keras.losses.SparseCategoricalCrossentropy( 729 ... reduction=tf.keras.losses.Reduction.NONE) 730 >>> scce(y_true, y_pred).numpy() 731 array([0.0513, 2.303], dtype=float32) 732 733 Usage with the `compile()` API: 734 735 ```python 736 model.compile(optimizer='sgd', 737 loss=tf.keras.losses.SparseCategoricalCrossentropy()) 738 ``` 739 """ 740 741 def __init__(self, 742 from_logits=False, 743 reduction=losses_utils.ReductionV2.AUTO, 744 name='sparse_categorical_crossentropy'): 745 """Initializes `SparseCategoricalCrossentropy` instance. 746 747 Args: 748 from_logits: Whether `y_pred` is expected to be a logits tensor. By 749 default, we assume that `y_pred` encodes a probability distribution. 750 reduction: Type of `tf.keras.losses.Reduction` to apply to 751 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 752 option will be determined by the usage context. For almost all cases 753 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 754 `tf.distribute.Strategy`, outside of built-in training loops such as 755 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 756 will raise an error. Please see this custom training [tutorial]( 757 https://www.tensorflow.org/tutorials/distribute/custom_training) for 758 more details. 759 name: Optional name for the instance. Defaults to 760 'sparse_categorical_crossentropy'. 761 """ 762 super().__init__( 763 sparse_categorical_crossentropy, 764 name=name, 765 reduction=reduction, 766 from_logits=from_logits) 767 768 769@keras_export('keras.losses.Hinge') 770class Hinge(LossFunctionWrapper): 771 """Computes the hinge loss between `y_true` and `y_pred`. 772 773 `loss = maximum(1 - y_true * y_pred, 0)` 774 775 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 776 provided we will convert them to -1 or 1. 777 778 Standalone usage: 779 780 >>> y_true = [[0., 1.], [0., 0.]] 781 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 782 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 783 >>> h = tf.keras.losses.Hinge() 784 >>> h(y_true, y_pred).numpy() 785 1.3 786 787 >>> # Calling with 'sample_weight'. 788 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 789 0.55 790 791 >>> # Using 'sum' reduction type. 792 >>> h = tf.keras.losses.Hinge( 793 ... reduction=tf.keras.losses.Reduction.SUM) 794 >>> h(y_true, y_pred).numpy() 795 2.6 796 797 >>> # Using 'none' reduction type. 798 >>> h = tf.keras.losses.Hinge( 799 ... 
reduction=tf.keras.losses.Reduction.NONE) 800 >>> h(y_true, y_pred).numpy() 801 array([1.1, 1.5], dtype=float32) 802 803 Usage with the `compile()` API: 804 805 ```python 806 model.compile(optimizer='sgd', loss=tf.keras.losses.Hinge()) 807 ``` 808 """ 809 810 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='hinge'): 811 """Initializes `Hinge` instance. 812 813 Args: 814 reduction: Type of `tf.keras.losses.Reduction` to apply to 815 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 816 option will be determined by the usage context. For almost all cases 817 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 818 `tf.distribute.Strategy`, outside of built-in training loops such as 819 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 820 will raise an error. Please see this custom training [tutorial]( 821 https://www.tensorflow.org/tutorials/distribute/custom_training) for 822 more details. 823 name: Optional name for the instance. Defaults to 'hinge'. 824 """ 825 super().__init__(hinge, name=name, reduction=reduction) 826 827 828@keras_export('keras.losses.SquaredHinge') 829class SquaredHinge(LossFunctionWrapper): 830 """Computes the squared hinge loss between `y_true` and `y_pred`. 831 832 `loss = square(maximum(1 - y_true * y_pred, 0))` 833 834 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 835 provided we will convert them to -1 or 1. 836 837 Standalone usage: 838 839 >>> y_true = [[0., 1.], [0., 0.]] 840 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 841 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 842 >>> h = tf.keras.losses.SquaredHinge() 843 >>> h(y_true, y_pred).numpy() 844 1.86 845 846 >>> # Calling with 'sample_weight'. 847 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 848 0.73 849 850 >>> # Using 'sum' reduction type. 851 >>> h = tf.keras.losses.SquaredHinge( 852 ... reduction=tf.keras.losses.Reduction.SUM) 853 >>> h(y_true, y_pred).numpy() 854 3.72 855 856 >>> # Using 'none' reduction type. 857 >>> h = tf.keras.losses.SquaredHinge( 858 ... reduction=tf.keras.losses.Reduction.NONE) 859 >>> h(y_true, y_pred).numpy() 860 array([1.46, 2.26], dtype=float32) 861 862 Usage with the `compile()` API: 863 864 ```python 865 model.compile(optimizer='sgd', loss=tf.keras.losses.SquaredHinge()) 866 ``` 867 """ 868 869 def __init__(self, 870 reduction=losses_utils.ReductionV2.AUTO, 871 name='squared_hinge'): 872 """Initializes `SquaredHinge` instance. 873 874 Args: 875 reduction: Type of `tf.keras.losses.Reduction` to apply to 876 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 877 option will be determined by the usage context. For almost all cases 878 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 879 `tf.distribute.Strategy`, outside of built-in training loops such as 880 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 881 will raise an error. Please see this custom training [tutorial]( 882 https://www.tensorflow.org/tutorials/distribute/custom_training) for 883 more details. 884 name: Optional name for the instance. Defaults to 'squared_hinge'. 885 """ 886 super().__init__(squared_hinge, name=name, reduction=reduction) 887 888 889@keras_export('keras.losses.CategoricalHinge') 890class CategoricalHinge(LossFunctionWrapper): 891 """Computes the categorical hinge loss between `y_true` and `y_pred`. 
892 893 `loss = maximum(neg - pos + 1, 0)` 894 where `neg=maximum((1-y_true)*y_pred) and pos=sum(y_true*y_pred)` 895 896 Standalone usage: 897 898 >>> y_true = [[0, 1], [0, 0]] 899 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 900 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 901 >>> h = tf.keras.losses.CategoricalHinge() 902 >>> h(y_true, y_pred).numpy() 903 1.4 904 905 >>> # Calling with 'sample_weight'. 906 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 907 0.6 908 909 >>> # Using 'sum' reduction type. 910 >>> h = tf.keras.losses.CategoricalHinge( 911 ... reduction=tf.keras.losses.Reduction.SUM) 912 >>> h(y_true, y_pred).numpy() 913 2.8 914 915 >>> # Using 'none' reduction type. 916 >>> h = tf.keras.losses.CategoricalHinge( 917 ... reduction=tf.keras.losses.Reduction.NONE) 918 >>> h(y_true, y_pred).numpy() 919 array([1.2, 1.6], dtype=float32) 920 921 Usage with the `compile()` API: 922 923 ```python 924 model.compile(optimizer='sgd', loss=tf.keras.losses.CategoricalHinge()) 925 ``` 926 """ 927 928 def __init__(self, 929 reduction=losses_utils.ReductionV2.AUTO, 930 name='categorical_hinge'): 931 """Initializes `CategoricalHinge` instance. 932 933 Args: 934 reduction: Type of `tf.keras.losses.Reduction` to apply to 935 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 936 option will be determined by the usage context. For almost all cases 937 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 938 `tf.distribute.Strategy`, outside of built-in training loops such as 939 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 940 will raise an error. Please see this custom training [tutorial]( 941 https://www.tensorflow.org/tutorials/distribute/custom_training) for 942 more details. 943 name: Optional name for the instance. Defaults to 'categorical_hinge'. 944 """ 945 super().__init__(categorical_hinge, name=name, reduction=reduction) 946 947 948@keras_export('keras.losses.Poisson') 949class Poisson(LossFunctionWrapper): 950 """Computes the Poisson loss between `y_true` and `y_pred`. 951 952 `loss = y_pred - y_true * log(y_pred)` 953 954 Standalone usage: 955 956 >>> y_true = [[0., 1.], [0., 0.]] 957 >>> y_pred = [[1., 1.], [0., 0.]] 958 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 959 >>> p = tf.keras.losses.Poisson() 960 >>> p(y_true, y_pred).numpy() 961 0.5 962 963 >>> # Calling with 'sample_weight'. 964 >>> p(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 965 0.4 966 967 >>> # Using 'sum' reduction type. 968 >>> p = tf.keras.losses.Poisson( 969 ... reduction=tf.keras.losses.Reduction.SUM) 970 >>> p(y_true, y_pred).numpy() 971 0.999 972 973 >>> # Using 'none' reduction type. 974 >>> p = tf.keras.losses.Poisson( 975 ... reduction=tf.keras.losses.Reduction.NONE) 976 >>> p(y_true, y_pred).numpy() 977 array([0.999, 0.], dtype=float32) 978 979 Usage with the `compile()` API: 980 981 ```python 982 model.compile(optimizer='sgd', loss=tf.keras.losses.Poisson()) 983 ``` 984 """ 985 986 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='poisson'): 987 """Initializes `Poisson` instance. 988 989 Args: 990 reduction: Type of `tf.keras.losses.Reduction` to apply to 991 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 992 option will be determined by the usage context. For almost all cases 993 this defaults to `SUM_OVER_BATCH_SIZE`. 
When used with 994 `tf.distribute.Strategy`, outside of built-in training loops such as 995 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 996 will raise an error. Please see this custom training [tutorial]( 997 https://www.tensorflow.org/tutorials/distribute/custom_training) for 998 more details. 999 name: Optional name for the instance. Defaults to 'poisson'. 1000 """ 1001 super().__init__(poisson, name=name, reduction=reduction) 1002 1003 1004@keras_export('keras.losses.LogCosh') 1005class LogCosh(LossFunctionWrapper): 1006 """Computes the logarithm of the hyperbolic cosine of the prediction error. 1007 1008 `logcosh = log((exp(x) + exp(-x))/2)`, 1009 where x is the error `y_pred - y_true`. 1010 1011 Standalone usage: 1012 1013 >>> y_true = [[0., 1.], [0., 0.]] 1014 >>> y_pred = [[1., 1.], [0., 0.]] 1015 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 1016 >>> l = tf.keras.losses.LogCosh() 1017 >>> l(y_true, y_pred).numpy() 1018 0.108 1019 1020 >>> # Calling with 'sample_weight'. 1021 >>> l(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 1022 0.087 1023 1024 >>> # Using 'sum' reduction type. 1025 >>> l = tf.keras.losses.LogCosh( 1026 ... reduction=tf.keras.losses.Reduction.SUM) 1027 >>> l(y_true, y_pred).numpy() 1028 0.217 1029 1030 >>> # Using 'none' reduction type. 1031 >>> l = tf.keras.losses.LogCosh( 1032 ... reduction=tf.keras.losses.Reduction.NONE) 1033 >>> l(y_true, y_pred).numpy() 1034 array([0.217, 0.], dtype=float32) 1035 1036 Usage with the `compile()` API: 1037 1038 ```python 1039 model.compile(optimizer='sgd', loss=tf.keras.losses.LogCosh()) 1040 ``` 1041 """ 1042 1043 def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='log_cosh'): 1044 """Initializes `LogCosh` instance. 1045 1046 Args: 1047 reduction: Type of `tf.keras.losses.Reduction` to apply to 1048 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 1049 option will be determined by the usage context. For almost all cases 1050 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 1051 `tf.distribute.Strategy`, outside of built-in training loops such as 1052 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 1053 will raise an error. Please see this custom training [tutorial]( 1054 https://www.tensorflow.org/tutorials/distribute/custom_training) for 1055 more details. 1056 name: Optional name for the instance. Defaults to 'log_cosh'. 1057 """ 1058 super().__init__(log_cosh, name=name, reduction=reduction) 1059 1060 1061@keras_export('keras.losses.KLDivergence') 1062class KLDivergence(LossFunctionWrapper): 1063 """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`. 1064 1065 `loss = y_true * log(y_true / y_pred)` 1066 1067 See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence 1068 1069 Standalone usage: 1070 1071 >>> y_true = [[0, 1], [0, 0]] 1072 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 1073 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 1074 >>> kl = tf.keras.losses.KLDivergence() 1075 >>> kl(y_true, y_pred).numpy() 1076 0.458 1077 1078 >>> # Calling with 'sample_weight'. 1079 >>> kl(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy() 1080 0.366 1081 1082 >>> # Using 'sum' reduction type. 1083 >>> kl = tf.keras.losses.KLDivergence( 1084 ... reduction=tf.keras.losses.Reduction.SUM) 1085 >>> kl(y_true, y_pred).numpy() 1086 0.916 1087 1088 >>> # Using 'none' reduction type. 1089 >>> kl = tf.keras.losses.KLDivergence( 1090 ... 
reduction=tf.keras.losses.Reduction.NONE) 1091 >>> kl(y_true, y_pred).numpy() 1092 array([0.916, -3.08e-06], dtype=float32) 1093 1094 Usage with the `compile()` API: 1095 1096 ```python 1097 model.compile(optimizer='sgd', loss=tf.keras.losses.KLDivergence()) 1098 ``` 1099 """ 1100 1101 def __init__(self, 1102 reduction=losses_utils.ReductionV2.AUTO, 1103 name='kl_divergence'): 1104 """Initializes `KLDivergence` instance. 1105 1106 Args: 1107 reduction: Type of `tf.keras.losses.Reduction` to apply to 1108 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 1109 option will be determined by the usage context. For almost all cases 1110 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 1111 `tf.distribute.Strategy`, outside of built-in training loops such as 1112 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 1113 will raise an error. Please see this custom training [tutorial]( 1114 https://www.tensorflow.org/tutorials/distribute/custom_training) for 1115 more details. 1116 name: Optional name for the instance. Defaults to 'kl_divergence'. 1117 """ 1118 super().__init__(kl_divergence, name=name, reduction=reduction) 1119 1120 1121@keras_export('keras.losses.Huber') 1122class Huber(LossFunctionWrapper): 1123 """Computes the Huber loss between `y_true` and `y_pred`. 1124 1125 For each value x in `error = y_true - y_pred`: 1126 1127 ``` 1128 loss = 0.5 * x^2 if |x| <= d 1129 loss = 0.5 * d^2 + d * (|x| - d) if |x| > d 1130 ``` 1131 where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss 1132 1133 Standalone usage: 1134 1135 >>> y_true = [[0, 1], [0, 0]] 1136 >>> y_pred = [[0.6, 0.4], [0.4, 0.6]] 1137 >>> # Using 'auto'/'sum_over_batch_size' reduction type. 1138 >>> h = tf.keras.losses.Huber() 1139 >>> h(y_true, y_pred).numpy() 1140 0.155 1141 1142 >>> # Calling with 'sample_weight'. 1143 >>> h(y_true, y_pred, sample_weight=[1, 0]).numpy() 1144 0.09 1145 1146 >>> # Using 'sum' reduction type. 1147 >>> h = tf.keras.losses.Huber( 1148 ... reduction=tf.keras.losses.Reduction.SUM) 1149 >>> h(y_true, y_pred).numpy() 1150 0.31 1151 1152 >>> # Using 'none' reduction type. 1153 >>> h = tf.keras.losses.Huber( 1154 ... reduction=tf.keras.losses.Reduction.NONE) 1155 >>> h(y_true, y_pred).numpy() 1156 array([0.18, 0.13], dtype=float32) 1157 1158 Usage with the `compile()` API: 1159 1160 ```python 1161 model.compile(optimizer='sgd', loss=tf.keras.losses.Huber()) 1162 ``` 1163 """ 1164 1165 def __init__(self, 1166 delta=1.0, 1167 reduction=losses_utils.ReductionV2.AUTO, 1168 name='huber_loss'): 1169 """Initializes `Huber` instance. 1170 1171 Args: 1172 delta: A float, the point where the Huber loss function changes from a 1173 quadratic to linear. 1174 reduction: Type of `tf.keras.losses.Reduction` to apply to 1175 loss. Default value is `AUTO`. `AUTO` indicates that the reduction 1176 option will be determined by the usage context. For almost all cases 1177 this defaults to `SUM_OVER_BATCH_SIZE`. When used with 1178 `tf.distribute.Strategy`, outside of built-in training loops such as 1179 `tf.keras` `compile` and `fit`, using `AUTO` or `SUM_OVER_BATCH_SIZE` 1180 will raise an error. Please see this custom training [tutorial]( 1181 https://www.tensorflow.org/tutorials/distribute/custom_training) for 1182 more details. 1183 name: Optional name for the instance. Defaults to 'huber_loss'. 
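
    For example (illustrative values), with the default `delta=1.0` an absolute
    error of 0.5 falls in the quadratic region and contributes
    `0.5 * 0.5**2 = 0.125` to the loss, while an absolute error of 2.0 falls in
    the linear region and contributes `0.5 * 1.0**2 + 1.0 * (2.0 - 1.0) = 1.5`.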
1184 """ 1185 super().__init__(huber, name=name, reduction=reduction, delta=delta) 1186 1187 1188@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse', 1189 'keras.metrics.MSE', 'keras.losses.mean_squared_error', 1190 'keras.losses.mse', 'keras.losses.MSE') 1191@dispatch.add_dispatch_support 1192def mean_squared_error(y_true, y_pred): 1193 """Computes the mean squared error between labels and predictions. 1194 1195 After computing the squared distance between the inputs, the mean value over 1196 the last dimension is returned. 1197 1198 `loss = mean(square(y_true - y_pred), axis=-1)` 1199 1200 Standalone usage: 1201 1202 >>> y_true = np.random.randint(0, 2, size=(2, 3)) 1203 >>> y_pred = np.random.random(size=(2, 3)) 1204 >>> loss = tf.keras.losses.mean_squared_error(y_true, y_pred) 1205 >>> assert loss.shape == (2,) 1206 >>> assert np.array_equal( 1207 ... loss.numpy(), np.mean(np.square(y_true - y_pred), axis=-1)) 1208 1209 Args: 1210 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1211 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1212 1213 Returns: 1214 Mean squared error values. shape = `[batch_size, d0, .. dN-1]`. 1215 """ 1216 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1217 y_true = math_ops.cast(y_true, y_pred.dtype) 1218 return backend.mean(math_ops.squared_difference(y_pred, y_true), axis=-1) 1219 1220 1221def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred, y_pred_extra_dim=False): 1222 """Apply a loss function on a per batch basis. 1223 1224 Args: 1225 loss_fn: The loss function 1226 y_true: truth values (RaggedTensor) 1227 y_pred: predicted values (RaggedTensor) 1228 y_pred_extra_dim: whether y_pred has an additional dimension compared to 1229 y_true 1230 1231 Returns: 1232 Loss-function result. A dense tensor if the output has a single dimension 1233 (per-batch loss value); a ragged tensor otherwise. 1234 """ 1235 1236 def rt_is_equiv_dense(rt): 1237 """Returns true if this RaggedTensor has the same row_lenghts across 1238 1239 all ragged dimensions and thus can be converted to a dense tensor 1240 without loss of information. 1241 1242 Args: 1243 rt: RaggedTensor. 1244 """ 1245 return math_ops.reduce_all([ 1246 math_ops.equal( 1247 math_ops.reduce_variance(math_ops.cast(row_lens, backend.floatx())), 1248 constant_op.constant([0.])) for row_lens in rt.nested_row_lengths() 1249 ]) 1250 1251 def _convert_to_dense(inputs): 1252 return tuple( 1253 rt.to_tensor() if isinstance(rt, ragged_tensor.RaggedTensor) else rt 1254 for rt in inputs) 1255 1256 def _call_loss(inputs, ragged_output): 1257 """ Adapt the result to ragged or dense tensor according to the expected 1258 1259 output type. This is done so that all the return values of the map 1260 operation have the same type. 
1261 """ 1262 r = loss_fn(*inputs) 1263 if ragged_output and not isinstance(r, ragged_tensor.RaggedTensor): 1264 r = ragged_tensor.RaggedTensor.from_tensor(r) 1265 elif not ragged_output and isinstance(r, ragged_tensor.RaggedTensor): 1266 r = r.to_tensor() 1267 return r 1268 1269 def _wrapper(inputs, ragged_output): 1270 _, y_pred = inputs 1271 if isinstance(y_pred, ragged_tensor.RaggedTensor): 1272 return control_flow_ops.cond( 1273 rt_is_equiv_dense(y_pred), 1274 lambda: _call_loss(_convert_to_dense(inputs), ragged_output), 1275 lambda: _call_loss(inputs, ragged_output)) 1276 1277 return loss_fn(*inputs) 1278 1279 if not isinstance(y_true, ragged_tensor.RaggedTensor): 1280 return loss_fn(y_true, y_pred.to_tensor()) 1281 1282 lshape = y_pred.shape.as_list()[1:-1] 1283 if len(lshape) > 0: 1284 spec = ragged_tensor.RaggedTensorSpec(shape=lshape, dtype=y_pred.dtype) 1285 else: 1286 spec = tensor_spec.TensorSpec(shape=[], dtype=y_pred.dtype) 1287 1288 nested_splits_list = [rt.nested_row_splits for rt in (y_true, y_pred)] 1289 if y_pred_extra_dim: 1290 # The last dimension of a categorical prediction may be ragged or not. 1291 rdims = [len(slist) for slist in nested_splits_list] 1292 if rdims[0] == rdims[1] - 1: 1293 nested_splits_list[1] = nested_splits_list[1][:-1] 1294 1295 map_fn = functools.partial(_wrapper, ragged_output=len(lshape) > 1) 1296 1297 assertion_list = ragged_util.assert_splits_match(nested_splits_list) 1298 with ops.control_dependencies(assertion_list): 1299 return ragged_map_ops.map_fn(map_fn, elems=(y_true, y_pred), dtype=spec) 1300 1301 1302@dispatch.dispatch_for_types(mean_squared_error, ragged_tensor.RaggedTensor) 1303def _ragged_tensor_mse(y_true, y_pred): 1304 """Implements support for handling RaggedTensors. 1305 1306 Args: 1307 y_true: RaggedTensor truth values. shape = `[batch_size, d0, .. dN]`. 1308 y_pred: RaggedTensor predicted values. shape = `[batch_size, d0, .. dN]`. 1309 1310 Returns: 1311 Mean squared error values. shape = `[batch_size, d0, .. dN-1]`. 1312 When the number of dimensions of the batch feature vector [d0, .. dN] is 1313 greater than one the return value is a RaggedTensor. Otherwise a Dense 1314 tensor with dimensions [batch_size] is returned. 1315 """ 1316 return _ragged_tensor_apply_loss(mean_squared_error, y_true, y_pred) 1317 1318 1319@keras_export('keras.metrics.mean_absolute_error', 'keras.metrics.mae', 1320 'keras.metrics.MAE', 'keras.losses.mean_absolute_error', 1321 'keras.losses.mae', 'keras.losses.MAE') 1322@dispatch.add_dispatch_support 1323def mean_absolute_error(y_true, y_pred): 1324 """Computes the mean absolute error between labels and predictions. 1325 1326 `loss = mean(abs(y_true - y_pred), axis=-1)` 1327 1328 Standalone usage: 1329 1330 >>> y_true = np.random.randint(0, 2, size=(2, 3)) 1331 >>> y_pred = np.random.random(size=(2, 3)) 1332 >>> loss = tf.keras.losses.mean_absolute_error(y_true, y_pred) 1333 >>> assert loss.shape == (2,) 1334 >>> assert np.array_equal( 1335 ... loss.numpy(), np.mean(np.abs(y_true - y_pred), axis=-1)) 1336 1337 Args: 1338 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1339 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1340 1341 Returns: 1342 Mean absolute error values. shape = `[batch_size, d0, .. dN-1]`. 
1343 """ 1344 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1345 y_true = math_ops.cast(y_true, y_pred.dtype) 1346 return backend.mean(math_ops.abs(y_pred - y_true), axis=-1) 1347 1348 1349@dispatch.dispatch_for_types(mean_absolute_error, ragged_tensor.RaggedTensor) 1350def _ragged_tensor_mae(y_true, y_pred): 1351 """RaggedTensor adapter for mean_absolute_error.""" 1352 return _ragged_tensor_apply_loss(mean_absolute_error, y_true, y_pred) 1353 1354 1355@keras_export('keras.metrics.mean_absolute_percentage_error', 1356 'keras.metrics.mape', 'keras.metrics.MAPE', 1357 'keras.losses.mean_absolute_percentage_error', 1358 'keras.losses.mape', 'keras.losses.MAPE') 1359@dispatch.add_dispatch_support 1360def mean_absolute_percentage_error(y_true, y_pred): 1361 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 1362 1363 `loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)` 1364 1365 Standalone usage: 1366 1367 >>> y_true = np.random.random(size=(2, 3)) 1368 >>> y_true = np.maximum(y_true, 1e-7) # Prevent division by zero 1369 >>> y_pred = np.random.random(size=(2, 3)) 1370 >>> loss = tf.keras.losses.mean_absolute_percentage_error(y_true, y_pred) 1371 >>> assert loss.shape == (2,) 1372 >>> assert np.array_equal( 1373 ... loss.numpy(), 1374 ... 100. * np.mean(np.abs((y_true - y_pred) / y_true), axis=-1)) 1375 1376 Args: 1377 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1378 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1379 1380 Returns: 1381 Mean absolute percentage error values. shape = `[batch_size, d0, .. dN-1]`. 1382 """ 1383 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1384 y_true = math_ops.cast(y_true, y_pred.dtype) 1385 diff = math_ops.abs( 1386 (y_true - y_pred) / backend.maximum(math_ops.abs(y_true), 1387 backend.epsilon())) 1388 return 100. * backend.mean(diff, axis=-1) 1389 1390 1391@dispatch.dispatch_for_types(mean_absolute_percentage_error, 1392 ragged_tensor.RaggedTensor) 1393def _ragged_tensor_mape(y_true, y_pred): 1394 """Support RaggedTensors.""" 1395 return _ragged_tensor_apply_loss(mean_absolute_percentage_error, y_true, 1396 y_pred) 1397 1398 1399@keras_export('keras.metrics.mean_squared_logarithmic_error', 1400 'keras.metrics.msle', 'keras.metrics.MSLE', 1401 'keras.losses.mean_squared_logarithmic_error', 1402 'keras.losses.msle', 'keras.losses.MSLE') 1403@dispatch.add_dispatch_support 1404def mean_squared_logarithmic_error(y_true, y_pred): 1405 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 1406 1407 `loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)` 1408 1409 Standalone usage: 1410 1411 >>> y_true = np.random.randint(0, 2, size=(2, 3)) 1412 >>> y_pred = np.random.random(size=(2, 3)) 1413 >>> loss = tf.keras.losses.mean_squared_logarithmic_error(y_true, y_pred) 1414 >>> assert loss.shape == (2,) 1415 >>> y_true = np.maximum(y_true, 1e-7) 1416 >>> y_pred = np.maximum(y_pred, 1e-7) 1417 >>> assert np.allclose( 1418 ... loss.numpy(), 1419 ... np.mean( 1420 ... np.square(np.log(y_true + 1.) - np.log(y_pred + 1.)), axis=-1)) 1421 1422 Args: 1423 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 1424 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1425 1426 Returns: 1427 Mean squared logarithmic error values. shape = `[batch_size, d0, .. dN-1]`. 
1428 """ 1429 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1430 y_true = math_ops.cast(y_true, y_pred.dtype) 1431 first_log = math_ops.log(backend.maximum(y_pred, backend.epsilon()) + 1.) 1432 second_log = math_ops.log(backend.maximum(y_true, backend.epsilon()) + 1.) 1433 return backend.mean( 1434 math_ops.squared_difference(first_log, second_log), axis=-1) 1435 1436 1437@dispatch.dispatch_for_types(mean_squared_logarithmic_error, 1438 ragged_tensor.RaggedTensor) 1439def _ragged_tensor_msle(y_true, y_pred): 1440 """Implements support for handling RaggedTensors.""" 1441 return _ragged_tensor_apply_loss(mean_squared_logarithmic_error, y_true, 1442 y_pred) 1443 1444 1445def _maybe_convert_labels(y_true): 1446 """Converts binary labels into -1/1.""" 1447 are_zeros = math_ops.equal(y_true, 0) 1448 are_ones = math_ops.equal(y_true, 1) 1449 is_binary = math_ops.reduce_all(math_ops.logical_or(are_zeros, are_ones)) 1450 1451 def _convert_binary_labels(): 1452 # Convert the binary labels to -1 or 1. 1453 return 2. * y_true - 1. 1454 1455 updated_y_true = smart_cond.smart_cond(is_binary, _convert_binary_labels, 1456 lambda: y_true) 1457 return updated_y_true 1458 1459 1460@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge') 1461@dispatch.add_dispatch_support 1462def squared_hinge(y_true, y_pred): 1463 """Computes the squared hinge loss between `y_true` and `y_pred`. 1464 1465 `loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1)` 1466 1467 Standalone usage: 1468 1469 >>> y_true = np.random.choice([-1, 1], size=(2, 3)) 1470 >>> y_pred = np.random.random(size=(2, 3)) 1471 >>> loss = tf.keras.losses.squared_hinge(y_true, y_pred) 1472 >>> assert loss.shape == (2,) 1473 >>> assert np.array_equal( 1474 ... loss.numpy(), 1475 ... np.mean(np.square(np.maximum(1. - y_true * y_pred, 0.)), axis=-1)) 1476 1477 Args: 1478 y_true: The ground truth values. `y_true` values are expected to be -1 or 1. 1479 If binary (0 or 1) labels are provided we will convert them to -1 or 1. 1480 shape = `[batch_size, d0, .. dN]`. 1481 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1482 1483 Returns: 1484 Squared hinge loss values. shape = `[batch_size, d0, .. dN-1]`. 1485 """ 1486 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1487 y_true = math_ops.cast(y_true, y_pred.dtype) 1488 y_true = _maybe_convert_labels(y_true) 1489 return backend.mean( 1490 math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1) 1491 1492 1493@keras_export('keras.metrics.hinge', 'keras.losses.hinge') 1494@dispatch.add_dispatch_support 1495def hinge(y_true, y_pred): 1496 """Computes the hinge loss between `y_true` and `y_pred`. 1497 1498 `loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1)` 1499 1500 Standalone usage: 1501 1502 >>> y_true = np.random.choice([-1, 1], size=(2, 3)) 1503 >>> y_pred = np.random.random(size=(2, 3)) 1504 >>> loss = tf.keras.losses.hinge(y_true, y_pred) 1505 >>> assert loss.shape == (2,) 1506 >>> assert np.array_equal( 1507 ... loss.numpy(), 1508 ... np.mean(np.maximum(1. - y_true * y_pred, 0.), axis=-1)) 1509 1510 Args: 1511 y_true: The ground truth values. `y_true` values are expected to be -1 or 1. 1512 If binary (0 or 1) labels are provided they will be converted to -1 or 1. 1513 shape = `[batch_size, d0, .. dN]`. 1514 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 1515 1516 Returns: 1517 Hinge loss values. shape = `[batch_size, d0, .. dN-1]`. 
1518 """ 1519 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1520 y_true = math_ops.cast(y_true, y_pred.dtype) 1521 y_true = _maybe_convert_labels(y_true) 1522 return backend.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1) 1523 1524 1525@keras_export('keras.losses.categorical_hinge') 1526@dispatch.add_dispatch_support 1527def categorical_hinge(y_true, y_pred): 1528 """Computes the categorical hinge loss between `y_true` and `y_pred`. 1529 1530 `loss = maximum(neg - pos + 1, 0)` 1531 where `neg=maximum((1-y_true)*y_pred) and pos=sum(y_true*y_pred)` 1532 1533 Standalone usage: 1534 1535 >>> y_true = np.random.randint(0, 3, size=(2,)) 1536 >>> y_true = tf.keras.utils.to_categorical(y_true, num_classes=3) 1537 >>> y_pred = np.random.random(size=(2, 3)) 1538 >>> loss = tf.keras.losses.categorical_hinge(y_true, y_pred) 1539 >>> assert loss.shape == (2,) 1540 >>> pos = np.sum(y_true * y_pred, axis=-1) 1541 >>> neg = np.amax((1. - y_true) * y_pred, axis=-1) 1542 >>> assert np.array_equal(loss.numpy(), np.maximum(0., neg - pos + 1.)) 1543 1544 Args: 1545 y_true: The ground truth values. `y_true` values are expected to be 1546 either `{-1, +1}` or `{0, 1}` (i.e. a one-hot-encoded tensor). 1547 y_pred: The predicted values. 1548 1549 Returns: 1550 Categorical hinge loss values. 1551 """ 1552 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 1553 y_true = math_ops.cast(y_true, y_pred.dtype) 1554 pos = math_ops.reduce_sum(y_true * y_pred, axis=-1) 1555 neg = math_ops.reduce_max((1. - y_true) * y_pred, axis=-1) 1556 zero = math_ops.cast(0., y_pred.dtype) 1557 return math_ops.maximum(neg - pos + 1., zero) 1558 1559 1560@keras_export('keras.losses.huber', v1=[]) 1561@dispatch.add_dispatch_support 1562def huber(y_true, y_pred, delta=1.0): 1563 """Computes Huber loss value. 1564 1565 For each value x in `error = y_true - y_pred`: 1566 1567 ``` 1568 loss = 0.5 * x^2 if |x| <= d 1569 loss = d * |x| - 0.5 * d^2 if |x| > d 1570 ``` 1571 where d is `delta`. See: https://en.wikipedia.org/wiki/Huber_loss 1572 1573 Args: 1574 y_true: tensor of true targets. 1575 y_pred: tensor of predicted targets. 1576 delta: A float, the point where the Huber loss function changes from a 1577 quadratic to linear. 1578 1579 Returns: 1580 Tensor with one scalar loss entry per sample. 1581 """ 1582 y_pred = math_ops.cast(y_pred, dtype=backend.floatx()) 1583 y_true = math_ops.cast(y_true, dtype=backend.floatx()) 1584 delta = math_ops.cast(delta, dtype=backend.floatx()) 1585 error = math_ops.subtract(y_pred, y_true) 1586 abs_error = math_ops.abs(error) 1587 half = ops.convert_to_tensor_v2_with_dispatch(0.5, dtype=abs_error.dtype) 1588 return backend.mean( 1589 array_ops.where_v2(abs_error <= delta, half * math_ops.square(error), 1590 delta * abs_error - half * math_ops.square(delta)), 1591 axis=-1) 1592 1593 1594@keras_export('keras.losses.log_cosh', 'keras.losses.logcosh', 1595 'keras.metrics.log_cosh', 'keras.metrics.logcosh') 1596@dispatch.add_dispatch_support 1597def log_cosh(y_true, y_pred): 1598 """Logarithm of the hyperbolic cosine of the prediction error. 1599 1600 `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and 1601 to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly 1602 like the mean squared error, but will not be so strongly affected by the 1603 occasional wildly incorrect prediction. 

  Standalone usage:

  >>> y_true = np.random.random(size=(2, 3))
  >>> y_pred = np.random.random(size=(2, 3))
  >>> loss = tf.keras.losses.logcosh(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> x = y_pred - y_true
  >>> assert np.allclose(
  ...     loss.numpy(),
  ...     np.mean(x + np.log(np.exp(-2. * x) + 1.) - np.log(2.), axis=-1),
  ...     atol=1e-5)

  Args:
    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

  Returns:
    Logcosh error values. shape = `[batch_size, d0, .. dN-1]`.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)

  def _logcosh(x):
    return x + math_ops.softplus(-2. * x) - math_ops.cast(
        math_ops.log(2.), x.dtype)

  return backend.mean(_logcosh(y_pred - y_true), axis=-1)


@keras_export('keras.metrics.categorical_crossentropy',
              'keras.losses.categorical_crossentropy')
@dispatch.add_dispatch_support
def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0,
                             axis=-1):
  """Computes the categorical crossentropy loss.

  Standalone usage:

  >>> y_true = [[0, 1, 0], [0, 0, 1]]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> loss.numpy()
  array([0.0513, 2.303], dtype=float32)

  Args:
    y_true: Tensor of one-hot true targets.
    y_pred: Tensor of predicted targets.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
      example, if `0.1`, use `0.1 / num_classes` for non-target labels
      and `0.9 + 0.1 / num_classes` for target labels.
    axis: Defaults to -1. The dimension along which the entropy is
      computed.

  Returns:
    Categorical crossentropy loss value.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
      label_smoothing, dtype=backend.floatx())

  def _smooth_labels():
    num_classes = math_ops.cast(array_ops.shape(y_true)[-1], y_pred.dtype)
    return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)

  y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels,
                                 lambda: y_true)

  return backend.categorical_crossentropy(
      y_true, y_pred, from_logits=from_logits, axis=axis)


@dispatch.dispatch_for_types(categorical_crossentropy,
                             ragged_tensor.RaggedTensor)
def _ragged_tensor_categorical_crossentropy(y_true,
                                            y_pred,
                                            from_logits=False,
                                            label_smoothing=0,
                                            axis=-1):
  """Implements support for handling RaggedTensors.

  Args:
    y_true: Tensor of one-hot true targets.
    y_pred: Tensor of predicted targets.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For
      example, if `0.1`, use `0.1 / num_classes` for non-target labels
      and `0.9 + 0.1 / num_classes` for target labels.
    axis: The axis along which to compute crossentropy (the features axis).
      Defaults to -1.

  Returns:
    Categorical crossentropy loss value.

  Expected shape: (batch, sequence_len, n_classes) with sequence_len
  being variable per batch.
  Return shape: (batch, sequence_len).

  When used by CategoricalCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
  number of elements independent of the batch. E.g. if the RaggedTensor
  has 2 batches with [2, 1] values respectively, the resulting loss is
  the sum of the individual loss values divided by 3.
  """
  fn = functools.partial(
      categorical_crossentropy,
      from_logits=from_logits,
      label_smoothing=label_smoothing,
      axis=axis)
  return _ragged_tensor_apply_loss(fn, y_true, y_pred)


@keras_export('keras.metrics.sparse_categorical_crossentropy',
              'keras.losses.sparse_categorical_crossentropy')
@dispatch.add_dispatch_support
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
  """Computes the sparse categorical crossentropy loss.

  Standalone usage:

  >>> y_true = [1, 2]
  >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
  >>> loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> loss.numpy()
  array([0.0513, 2.303], dtype=float32)

  Args:
    y_true: Ground truth values.
    y_pred: The predicted values.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    axis: Defaults to -1. The dimension along which the entropy is
      computed.

  Returns:
    Sparse categorical crossentropy loss value.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  return backend.sparse_categorical_crossentropy(
      y_true, y_pred, from_logits=from_logits, axis=axis)


@dispatch.dispatch_for_types(sparse_categorical_crossentropy,
                             ragged_tensor.RaggedTensor)
def _ragged_tensor_sparse_categorical_crossentropy(y_true,
                                                   y_pred,
                                                   from_logits=False,
                                                   axis=-1):
  """Implements support for handling RaggedTensors.

  Expected y_pred shape: (batch, sequence_len, n_classes) with sequence_len
  being variable per batch.
  Return shape: (batch, sequence_len).

  When used by SparseCategoricalCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
  number of elements independent of the batch. E.g. if the RaggedTensor
  has 2 batches with [2, 1] values respectively, the resulting loss is
  the sum of the individual loss values divided by 3.
  """
  fn = functools.partial(
      sparse_categorical_crossentropy, from_logits=from_logits, axis=axis)
  return _ragged_tensor_apply_loss(fn, y_true, y_pred, y_pred_extra_dim=True)


@keras_export('keras.metrics.binary_crossentropy',
              'keras.losses.binary_crossentropy')
@dispatch.add_dispatch_support
def binary_crossentropy(y_true,
                        y_pred,
                        from_logits=False,
                        label_smoothing=0,
                        axis=-1):
  """Computes the binary crossentropy loss.
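
  With `from_logits=False` and no label smoothing, this is (up to the
  numerical-stability clipping of `y_pred` performed by the backend):

  `loss = mean(-(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred)),
               axis=-1)`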

  Standalone usage:

  >>> y_true = [[0, 1], [0, 0]]
  >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]
  >>> loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> loss.numpy()
  array([0.916 , 0.714], dtype=float32)

  Args:
    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by
      squeezing them towards 0.5. That is, using `1. - 0.5 * label_smoothing`
      for the target class and `0.5 * label_smoothing` for the non-target class.
    axis: The axis along which the mean is computed. Defaults to -1.

  Returns:
    Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor_v2_with_dispatch(
      label_smoothing, dtype=backend.floatx())

  def _smooth_labels():
    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

  y_true = smart_cond.smart_cond(label_smoothing, _smooth_labels,
                                 lambda: y_true)

  return backend.mean(
      backend.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
      axis=axis)


@dispatch.dispatch_for_types(binary_crossentropy, ragged_tensor.RaggedTensor)
def _ragged_tensor_binary_crossentropy(y_true,
                                       y_pred,
                                       from_logits=False,
                                       label_smoothing=0,
                                       axis=-1):
  """Implements support for handling RaggedTensors.

  Args:
    y_true: Tensor of true targets.
    y_pred: Tensor of predicted targets.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      we assume that `y_pred` encodes a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then smooth the labels by
      squeezing them towards 0.5. That is, using `1. - 0.5 * label_smoothing`
      for the target class and `0.5 * label_smoothing` for the non-target class.
    axis: Axis along which to compute crossentropy.

  Returns:
    Binary crossentropy loss value.

  Expected shape: (batch, sequence_len) with sequence_len being variable
  per batch.
  Return shape: (batch,); returns the per batch mean of the loss values.

  When used by BinaryCrossentropy() with the default reduction
  (SUM_OVER_BATCH_SIZE), the reduction averages the per batch losses over
  the number of batches.
  """
  fn = functools.partial(
      binary_crossentropy,
      from_logits=from_logits,
      label_smoothing=label_smoothing,
      axis=axis)
  return _ragged_tensor_apply_loss(fn, y_true, y_pred)


@keras_export('keras.metrics.kl_divergence',
              'keras.metrics.kullback_leibler_divergence', 'keras.metrics.kld',
              'keras.metrics.KLD', 'keras.losses.kl_divergence',
              'keras.losses.kullback_leibler_divergence', 'keras.losses.kld',
              'keras.losses.KLD')
@dispatch.add_dispatch_support
def kl_divergence(y_true, y_pred):
  """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`.

  `loss = y_true * log(y_true / y_pred)`

  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

  Standalone usage:

  >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float64)
  >>> y_pred = np.random.random(size=(2, 3))
  >>> loss = tf.keras.losses.kullback_leibler_divergence(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> y_true = tf.keras.backend.clip(y_true, 1e-7, 1)
  >>> y_pred = tf.keras.backend.clip(y_pred, 1e-7, 1)
  >>> assert np.array_equal(
  ...     loss.numpy(), np.sum(y_true * np.log(y_true / y_pred), axis=-1))

  Args:
    y_true: Tensor of true targets.
    y_pred: Tensor of predicted targets.

  Returns:
    A `Tensor` with loss.

  Raises:
    TypeError: If `y_true` cannot be cast to the `y_pred.dtype`.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  y_true = backend.clip(y_true, backend.epsilon(), 1)
  y_pred = backend.clip(y_pred, backend.epsilon(), 1)
  return math_ops.reduce_sum(y_true * math_ops.log(y_true / y_pred), axis=-1)


@keras_export('keras.metrics.poisson', 'keras.losses.poisson')
@dispatch.add_dispatch_support
def poisson(y_true, y_pred):
  """Computes the Poisson loss between `y_true` and `y_pred`.

  The Poisson loss is the mean of the elements of the `Tensor`
  `y_pred - y_true * log(y_pred)`.

  Standalone usage:

  >>> y_true = np.random.randint(0, 2, size=(2, 3))
  >>> y_pred = np.random.random(size=(2, 3))
  >>> loss = tf.keras.losses.poisson(y_true, y_pred)
  >>> assert loss.shape == (2,)
  >>> y_pred = y_pred + 1e-7
  >>> assert np.allclose(
  ...     loss.numpy(), np.mean(y_pred - y_true * np.log(y_pred), axis=-1),
  ...     atol=1e-5)

  Args:
    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.

  Returns:
    Poisson loss value. shape = `[batch_size, d0, .. dN-1]`.

  Raises:
    InvalidArgumentError: If `y_true` and `y_pred` have incompatible shapes.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  return backend.mean(
      y_pred - y_true * math_ops.log(y_pred + backend.epsilon()), axis=-1)


@keras_export(
    'keras.losses.cosine_similarity',
    v1=[
        'keras.metrics.cosine_proximity',
        'keras.metrics.cosine',
        'keras.losses.cosine_proximity',
        'keras.losses.cosine',
        'keras.losses.cosine_similarity',
    ])
@dispatch.add_dispatch_support
def cosine_similarity(y_true, y_pred, axis=-1):
  """Computes the cosine similarity between labels and predictions.

  Note that it is a number between -1 and 1. When it is a negative number
  between -1 and 0, 0 indicates orthogonality and values closer to -1
  indicate greater similarity. Values closer to 1 indicate greater
  dissimilarity. This makes it usable as a loss function in a setting
  where you try to maximize the proximity between predictions and
  targets. If either `y_true` or `y_pred` is a zero vector, cosine
  similarity will be 0 regardless of the proximity between predictions
  and targets.

  `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))`

  Standalone usage:

  >>> y_true = [[0., 1.], [1., 1.], [1., 1.]]
  >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]
  >>> loss = tf.keras.losses.cosine_similarity(y_true, y_pred, axis=1)
  >>> loss.numpy()
  array([-0., -0.999, 0.999], dtype=float32)

  Args:
    y_true: Tensor of true targets.
    y_pred: Tensor of predicted targets.
    axis: Axis along which to determine similarity.

  Returns:
    Cosine similarity tensor.
  """
  y_true = nn.l2_normalize(y_true, axis=axis)
  y_pred = nn.l2_normalize(y_pred, axis=axis)
  return -math_ops.reduce_sum(y_true * y_pred, axis=axis)


@keras_export('keras.losses.CosineSimilarity')
class CosineSimilarity(LossFunctionWrapper):
  """Computes the cosine similarity between labels and predictions.

  Note that it is a number between -1 and 1. When it is a negative number
  between -1 and 0, 0 indicates orthogonality and values closer to -1
  indicate greater similarity. Values closer to 1 indicate greater
  dissimilarity. This makes it usable as a loss function in a setting
  where you try to maximize the proximity between predictions and targets.
  If either `y_true` or `y_pred` is a zero vector, cosine similarity will be 0
  regardless of the proximity between predictions and targets.

  `loss = -sum(l2_norm(y_true) * l2_norm(y_pred))`

  Standalone usage:

  >>> y_true = [[0., 1.], [1., 1.]]
  >>> y_pred = [[1., 0.], [1., 1.]]
  >>> # Using 'auto'/'sum_over_batch_size' reduction type.
  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1)
  >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
  >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
  >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
  >>> # loss = -mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
  >>> #      = -((0. + 0.) + (0.5 + 0.5)) / 2
  >>> cosine_loss(y_true, y_pred).numpy()
  -0.5

  >>> # Calling with 'sample_weight'.
  >>> cosine_loss(y_true, y_pred, sample_weight=[0.8, 0.2]).numpy()
  -0.0999

  >>> # Using 'sum' reduction type.
  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
  ...     reduction=tf.keras.losses.Reduction.SUM)
  >>> cosine_loss(y_true, y_pred).numpy()
  -0.999

  >>> # Using 'none' reduction type.
  >>> cosine_loss = tf.keras.losses.CosineSimilarity(axis=1,
  ...     reduction=tf.keras.losses.Reduction.NONE)
  >>> cosine_loss(y_true, y_pred).numpy()
  array([-0., -0.999], dtype=float32)

  Usage with the `compile()` API:

  ```python
  model.compile(optimizer='sgd', loss=tf.keras.losses.CosineSimilarity(axis=1))
  ```

  Args:
    axis: The axis along which the cosine similarity is computed
      (the features axis). Defaults to -1.
    reduction: Type of `tf.keras.losses.Reduction` to apply to loss.
      Default value is `AUTO`. `AUTO` indicates that the reduction option will
      be determined by the usage context. For almost all cases this defaults to
      `SUM_OVER_BATCH_SIZE`. When used with `tf.distribute.Strategy`, outside of
      built-in training loops such as `tf.keras` `compile` and `fit`, using
      `AUTO` or `SUM_OVER_BATCH_SIZE` will raise an error. Please see this
      custom training [tutorial](
      https://www.tensorflow.org/tutorials/distribute/custom_training) for more
      details.
    name: Optional name for the instance.
  """

  def __init__(self,
               axis=-1,
               reduction=losses_utils.ReductionV2.AUTO,
               name='cosine_similarity'):
    super().__init__(
        cosine_similarity, reduction=reduction, name=name, axis=axis)


# Aliases.

bce = BCE = binary_crossentropy
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
kld = KLD = kullback_leibler_divergence = kl_divergence
logcosh = log_cosh
huber_loss = huber


def is_categorical_crossentropy(loss):
  result = ((isinstance(loss, CategoricalCrossentropy) or
             (isinstance(loss, LossFunctionWrapper) and
              loss.fn == categorical_crossentropy) or
             (hasattr(loss, '__name__') and
              loss.__name__ == 'categorical_crossentropy') or
             (loss == 'categorical_crossentropy')))
  return result


@keras_export('keras.losses.serialize')
def serialize(loss):
  """Serializes loss function or `Loss` instance.

  Args:
    loss: A Keras `Loss` instance or a loss function.

  Returns:
    Loss configuration dictionary.
  """
  return serialize_keras_object(loss)


@keras_export('keras.losses.deserialize')
def deserialize(name, custom_objects=None):
  """Deserializes a serialized loss class/function instance.

  Args:
    name: Loss configuration.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras `Loss` instance or a loss function.
  """
  return deserialize_keras_object(
      name,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='loss function')


@keras_export('keras.losses.get')
def get(identifier):
  """Retrieves a Keras loss as a `function`/`Loss` class instance.

  The `identifier` may be the string name of a loss function or `Loss` class.

  >>> loss = tf.keras.losses.get("categorical_crossentropy")
  >>> type(loss)
  <class 'function'>
  >>> loss = tf.keras.losses.get("CategoricalCrossentropy")
  >>> type(loss)
  <class '...keras.losses.CategoricalCrossentropy'>

  You can also specify `config` of the loss to this function by passing a dict
  containing `class_name` and `config` as an identifier. Also note that the
  `class_name` must map to a `Loss` class.

  >>> identifier = {"class_name": "CategoricalCrossentropy",
  ...               "config": {"from_logits": True}}
  >>> loss = tf.keras.losses.get(identifier)
  >>> type(loss)
  <class '...keras.losses.CategoricalCrossentropy'>

  Args:
    identifier: A loss identifier. One of None or string name of a loss
      function/class or loss configuration dictionary or a loss function or a
      loss class instance.

  Returns:
    A Keras loss as a `function`/`Loss` class instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
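
  If `identifier` is already a callable (a loss function or a `Loss`
  instance), it is expected to be returned unchanged:

  >>> loss_fn = tf.keras.losses.mean_squared_error
  >>> tf.keras.losses.get(loss_fn) is loss_fn
  True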
2145 """ 2146 if identifier is None: 2147 return None 2148 if isinstance(identifier, str): 2149 identifier = str(identifier) 2150 return deserialize(identifier) 2151 if isinstance(identifier, dict): 2152 return deserialize(identifier) 2153 if callable(identifier): 2154 return identifier 2155 raise ValueError( 2156 f'Could not interpret loss function identifier: {identifier}') 2157 2158 2159LABEL_DTYPES_FOR_LOSSES = { 2160 losses_impl.sparse_softmax_cross_entropy: 'int32', 2161 sparse_categorical_crossentropy: 'int32' 2162} 2163