# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-doc-return-or-yield
"""Built-in metrics."""

import abc
import types
import warnings

import numpy as np

from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.losses import binary_crossentropy
from tensorflow.python.keras.losses import categorical_crossentropy
from tensorflow.python.keras.losses import categorical_hinge
from tensorflow.python.keras.losses import hinge
from tensorflow.python.keras.losses import kullback_leibler_divergence
from tensorflow.python.keras.losses import logcosh
from tensorflow.python.keras.losses import mean_absolute_error
from tensorflow.python.keras.losses import mean_absolute_percentage_error
from tensorflow.python.keras.losses import mean_squared_error
from tensorflow.python.keras.losses import mean_squared_logarithmic_error
from tensorflow.python.keras.losses import poisson
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.losses import squared_hinge
from tensorflow.python.keras.saving.saved_model import metric_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


@keras_export('keras.metrics.Metric')
class Metric(base_layer.Layer, metaclass=abc.ABCMeta):
  """Encapsulates metric logic and state.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    **kwargs: Additional layer keyword arguments.

  Standalone usage:

  ```python
  m = SomeMetric(...)
  for input in ...:
    m.update_state(input)
  print('Final result: ', m.result().numpy())
  ```

  Usage with `compile()` API:

  ```python
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(64, activation='relu'))
  model.add(tf.keras.layers.Dense(64, activation='relu'))
  model.add(tf.keras.layers.Dense(10, activation='softmax'))

  model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
                loss=tf.keras.losses.CategoricalCrossentropy(),
                metrics=[tf.keras.metrics.CategoricalAccuracy()])

  data = np.random.random((1000, 32))
  labels = np.random.random((1000, 10))

  dataset = tf.data.Dataset.from_tensor_slices((data, labels))
  dataset = dataset.batch(32)

  model.fit(dataset, epochs=10)
  ```

  To be implemented by subclasses:
  * `__init__()`: All state variables should be created in this method by
    calling `self.add_weight()` like: `self.var = self.add_weight(...)`
  * `update_state()`: Has all updates to the state variables like:
    `self.var.assign_add(...)`.
  * `result()`: Computes and returns a value for the metric
    from the state variables.

  Example subclass implementation:

  ```python
  class BinaryTruePositives(tf.keras.metrics.Metric):

    def __init__(self, name='binary_true_positives', **kwargs):
      super(BinaryTruePositives, self).__init__(name=name, **kwargs)
      self.true_positives = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
      y_true = tf.cast(y_true, tf.bool)
      y_pred = tf.cast(y_pred, tf.bool)

      values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
      values = tf.cast(values, self.dtype)
      if sample_weight is not None:
        sample_weight = tf.cast(sample_weight, self.dtype)
        sample_weight = tf.broadcast_to(sample_weight, values.shape)
        values = tf.multiply(values, sample_weight)
      self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
      return self.true_positives
  ```
  """

  def __init__(self, name=None, dtype=None, **kwargs):
    super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
    self.stateful = True  # All metric layers are stateful.
    self.built = True
    if not base_layer_utils.v2_dtype_behavior_enabled():
      # We only do this when the V2 behavior is not enabled, as when it is
      # enabled, the dtype already defaults to floatx.
      self._dtype = (backend.floatx() if dtype is None
                     else dtypes.as_dtype(dtype).name)

  def __new__(cls, *args, **kwargs):
    obj = super(Metric, cls).__new__(cls)

    # If `update_state` is not in eager/tf.function and it is not from a
    # built-in metric, wrap it in `tf.function`. This is so that users writing
    # custom metrics in v1 need not worry about control dependencies and
    # return ops.
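    # (Because both `update_state` and `result` are wrapped here, subclasses
    # can typically write them as plain Python; the wrapping takes care of
    # running them as graph functions and of the control dependencies and
    # update ops mentioned above.)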
    if (base_layer_utils.is_in_eager_or_tf_function() or
        is_built_in(cls)):
      obj_update_state = obj.update_state

      def update_state_fn(*args, **kwargs):
        control_status = ag_ctx.control_status_ctx()
        ag_update_state = autograph.tf_convert(obj_update_state, control_status)
        return ag_update_state(*args, **kwargs)
    else:
      if isinstance(obj.update_state, def_function.Function):
        update_state_fn = obj.update_state
      else:
        update_state_fn = def_function.function(obj.update_state)

    obj.update_state = types.MethodType(
        metrics_utils.update_state_wrapper(update_state_fn), obj)

    obj_result = obj.result

    def result_fn(*args, **kwargs):
      control_status = ag_ctx.control_status_ctx()
      ag_result = autograph.tf_convert(obj_result, control_status)
      return ag_result(*args, **kwargs)

    obj.result = types.MethodType(metrics_utils.result_wrapper(result_fn), obj)

    return obj

  def __call__(self, *args, **kwargs):
    """Accumulates statistics and then computes metric result value.

    Args:
      *args:
      **kwargs: A mini-batch of inputs to the Metric,
        passed on to `update_state()`.

    Returns:
      The metric value tensor.
    """

    def replica_local_fn(*args, **kwargs):
      """Updates the state of the metric in a replica-local context."""
      if any(
          isinstance(arg, keras_tensor.KerasTensor)
          for arg in nest.flatten((args, kwargs))):
        update_op = None
      else:
        update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
      update_ops = []
      if update_op is not None:
        update_ops.append(update_op)
      with ops.control_dependencies(update_ops):
        result_t = self.result()  # pylint: disable=not-callable

        # We are adding the metric object as metadata on the result tensor.
        # This is required when we want to use a metric with `add_metric` API on
        # a Model/Layer in graph mode. This metric instance will later be used
        # to reset variable state after each epoch of training.
        # Example:
        # model = Model()
        # mean = Mean()
        # model.add_metric(mean(values), name='mean')
        result_t._metric_obj = self  # pylint: disable=protected-access
        return result_t

    from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
    return distributed_training_utils.call_replica_local_fn(
        replica_local_fn, *args, **kwargs)

  @property
  def dtype(self):
    return self._dtype

  def get_config(self):
    """Returns the serializable config of the metric."""
    return {'name': self.name, 'dtype': self.dtype}

  def reset_state(self):
    """Resets all of the metric state variables.

    This function is called between epochs/steps,
    when a metric is evaluated during training.
    """
    if not generic_utils.is_default(self.reset_states):
      warnings.warn('Metric %s implements a `reset_states()` method; rename it '
                    'to `reset_state()` (without the final "s"). The name '
                    '`reset_states()` has been deprecated to improve API '
                    'consistency.' % (self.__class__.__name__,))
      return self.reset_states()
    else:
      backend.batch_set_value([(v, 0) for v in self.variables])

  @abc.abstractmethod
  def update_state(self, *args, **kwargs):
    """Accumulates statistics for the metric.

    Note: This function is executed as a graph function in graph mode.
    This means:
      a) Operations on the same resource are executed in textual order.
         This should make it easier to do things like add the updated
         value of a variable to another, for example.
      b) You don't need to worry about collecting the update ops to execute.
         All update ops added to the graph by this function will be executed.
      As a result, code should generally work the same way with graph or
      eager execution.

    Args:
      *args:
      **kwargs: A mini-batch of inputs to the Metric.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def result(self):
    """Computes and returns the metric value tensor.

    Result computation is an idempotent operation that simply calculates the
    metric value using the state variables.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  ### For use by subclasses ###
  @doc_controls.for_subclass_implementers
  def add_weight(
      self,
      name,
      shape=(),
      aggregation=variables_module.VariableAggregation.SUM,
      synchronization=variables_module.VariableSynchronization.ON_READ,
      initializer=None,
      dtype=None):
    """Adds state variable. Only for use by subclasses."""
    if distribute_ctx.has_strategy():
      strategy = distribute_ctx.get_strategy()
    else:
      strategy = None

    # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU.
    if backend.is_tpu_strategy(strategy):
      synchronization = variables_module.VariableSynchronization.ON_WRITE

    with ops.init_scope():
      return super(Metric, self).add_weight(
          name=name,
          shape=shape,
          dtype=self._dtype if dtype is None else dtype,
          trainable=False,
          initializer=initializer,
          collections=[],
          synchronization=synchronization,
          aggregation=aggregation)

  ### End: For use by subclasses ###

  @property
  def trainable_weights(self):
    # Overridden from Layer class to track submetric weights.
    if self.trainable:
      trainable_weights = self._trainable_weights
      for m in self._metrics:
        trainable_weights += m.trainable_weights
      return self._dedup_weights(trainable_weights)
    else:
      return []

  @property
  def non_trainable_weights(self):
    # Overridden from Layer class to track submetric weights.
    if self.trainable:
      non_trainable_weights = self._non_trainable_weights
      for m in self._metrics:
        non_trainable_weights += m.non_trainable_weights
    else:
      non_trainable_weights = (
          self._non_trainable_weights + self._trainable_weights)
      for m in self._metrics:
        non_trainable_weights += m.weights
    return self._dedup_weights(non_trainable_weights)

  @property
  def _trackable_saved_model_saver(self):
    return metric_serialization.MetricSavedModelSaver(self)

  @generic_utils.default
  @doc_controls.do_not_generate_docs
  def reset_states(self):
    # Backwards compatibility alias of `reset_state`. New classes should
    # only implement `reset_state`.
    return self.reset_state()


class Reduce(Metric):
  """Encapsulates metrics that perform a reduce operation on the values.

  Args:
    reduction: a `tf.keras.metrics.Reduction` enum value.
    name: string name of the metric instance.
    dtype: (Optional) data type of the metric result.
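
  As implemented below, `Reduction.SUM` tracks only a `total` variable, while
  `Reduction.WEIGHTED_MEAN` and `Reduction.SUM_OVER_BATCH_SIZE` also track a
  `count` variable, with `result()` returning `total / count`.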
362 """ 363 364 def __init__(self, reduction, name, dtype=None): 365 super(Reduce, self).__init__(name=name, dtype=dtype) 366 self.reduction = reduction 367 self.total = self.add_weight( 368 'total', initializer=init_ops.zeros_initializer) 369 if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE, 370 metrics_utils.Reduction.WEIGHTED_MEAN]: 371 self.count = self.add_weight( 372 'count', initializer=init_ops.zeros_initializer) 373 374 def update_state(self, values, sample_weight=None): 375 """Accumulates statistics for computing the metric. 376 377 Args: 378 values: Per-example value. 379 sample_weight: Optional weighting of each example. Defaults to 1. 380 381 Returns: 382 Update op. 383 """ 384 [values], sample_weight = \ 385 metrics_utils.ragged_assert_compatible_and_get_flat_values( 386 [values], sample_weight) 387 try: 388 values = math_ops.cast(values, self._dtype) 389 except (ValueError, TypeError): 390 msg = ('The output of a metric function can only be a single Tensor. ' 391 'Got: %s' % (values,)) 392 if isinstance(values, dict): 393 msg += ('. To return a dict of values, implement a custom Metric ' 394 'subclass.') 395 raise RuntimeError(msg) 396 if sample_weight is not None: 397 sample_weight = math_ops.cast(sample_weight, self._dtype) 398 # Update dimensions of weights to match with values if possible. 399 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( 400 values, sample_weight=sample_weight) 401 try: 402 # Broadcast weights if possible. 403 sample_weight = weights_broadcast_ops.broadcast_weights( 404 sample_weight, values) 405 except ValueError: 406 # Reduce values to same ndim as weight array 407 ndim = backend.ndim(values) 408 weight_ndim = backend.ndim(sample_weight) 409 if self.reduction == metrics_utils.Reduction.SUM: 410 values = math_ops.reduce_sum( 411 values, axis=list(range(weight_ndim, ndim))) 412 else: 413 values = math_ops.reduce_mean( 414 values, axis=list(range(weight_ndim, ndim))) 415 values = math_ops.multiply(values, sample_weight) 416 417 value_sum = math_ops.reduce_sum(values) 418 with ops.control_dependencies([value_sum]): 419 update_total_op = self.total.assign_add(value_sum) 420 421 # Exit early if the reduction doesn't have a denominator. 422 if self.reduction == metrics_utils.Reduction.SUM: 423 return update_total_op 424 425 # Update `count` for reductions that require a denominator. 426 if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE: 427 num_values = math_ops.cast(array_ops.size(values), self._dtype) 428 elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN: 429 if sample_weight is None: 430 num_values = math_ops.cast(array_ops.size(values), self._dtype) 431 else: 432 num_values = math_ops.reduce_sum(sample_weight) 433 else: 434 raise NotImplementedError( 435 'reduction [%s] not implemented' % self.reduction) 436 437 with ops.control_dependencies([update_total_op]): 438 return self.count.assign_add(num_values) 439 440 def result(self): 441 if self.reduction == metrics_utils.Reduction.SUM: 442 return array_ops.identity(self.total) 443 elif self.reduction in [ 444 metrics_utils.Reduction.WEIGHTED_MEAN, 445 metrics_utils.Reduction.SUM_OVER_BATCH_SIZE 446 ]: 447 return math_ops.div_no_nan(self.total, self.count) 448 else: 449 raise NotImplementedError( 450 'reduction [%s] not implemented' % self.reduction) 451 452 453@keras_export('keras.metrics.Sum') 454class Sum(Reduce): 455 """Computes the (weighted) sum of the given values. 456 457 For example, if values is [1, 3, 5, 7] then the sum is 16. 
  If the weights were specified as [1, 1, 0, 0] then the sum would be 4.

  This metric creates one variable, `total`, that is used to compute the sum of
  `values`. This is ultimately returned as `sum`.

  If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0
  to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.Sum()
  >>> m.update_state([1, 3, 5, 7])
  >>> m.result().numpy()
  16.0

  Usage with `compile()` API:

  ```python
  model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs))
  model.compile(optimizer='sgd', loss='mse')
  ```
  """

  def __init__(self, name='sum', dtype=None):
    super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
                              name=name, dtype=dtype)


@keras_export('keras.metrics.Mean')
class Mean(Reduce):
  """Computes the (weighted) mean of the given values.

  For example, if values is [1, 3, 5, 7] then the mean is 4.
  If the weights were specified as [1, 1, 0, 0] then the mean would be 2.

  This metric creates two variables, `total` and `count` that are used to
  compute the average of `values`. This average is ultimately returned as `mean`
  which is an idempotent operation that simply divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.Mean()
  >>> m.update_state([1, 3, 5, 7])
  >>> m.result().numpy()
  4.0
  >>> m.reset_state()
  >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  >>> m.result().numpy()
  2.0

  Usage with `compile()` API:

  ```python
  model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs))
  model.compile(optimizer='sgd', loss='mse')
  ```
  """

  def __init__(self, name='mean', dtype=None):
    super(Mean, self).__init__(
        reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype)


@keras_export('keras.metrics.MeanRelativeError')
class MeanRelativeError(Mean):
  """Computes the mean relative error by normalizing with the given values.

  This metric creates two local variables, `total` and `count` that are used to
  compute the mean relative error. This is weighted by `sample_weight`, and
  it is ultimately returned as `mean_relative_error`:
  an idempotent operation that simply divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    normalizer: The normalizer values with same shape as predictions.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3])
  >>> m.update_state([1, 3, 2, 3], [2, 4, 6, 8])

  >>> # metric = mean(|y_pred - y_true| / normalizer)
  >>> #        = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3])
  >>> #        = 5/4 = 1.25
  >>> m.result().numpy()
  1.25

  Usage with `compile()` API:

  ```python
  model.compile(
      optimizer='sgd',
      loss='mse',
      metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])])
  ```
  """

  def __init__(self, normalizer, name=None, dtype=None):
    super(MeanRelativeError, self).__init__(name=name, dtype=dtype)
    normalizer = math_ops.cast(normalizer, self._dtype)
    self.normalizer = normalizer

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates metric statistics.

    Args:
      y_true: The ground truth values.
      y_pred: The predicted values.
      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
        be broadcastable to `y_true`.

    Returns:
      Update op.
    """
    y_true = math_ops.cast(y_true, self._dtype)
    y_pred = math_ops.cast(y_pred, self._dtype)
    [y_pred, y_true], sample_weight = \
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight)
    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
        y_pred, y_true)

    y_pred, self.normalizer = losses_utils.remove_squeezable_dimensions(
        y_pred, self.normalizer)
    y_pred.shape.assert_is_compatible_with(y_true.shape)
    relative_errors = math_ops.div_no_nan(
        math_ops.abs(y_true - y_pred), self.normalizer)

    return super(MeanRelativeError, self).update_state(
        relative_errors, sample_weight=sample_weight)

  def get_config(self):
    n = self.normalizer
    config = {'normalizer': backend.eval(n) if is_tensor_or_variable(n) else n}
    base_config = super(MeanRelativeError, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


@keras_export('keras.metrics.MeanMetricWrapper')
class MeanMetricWrapper(Mean):
  """Wraps a stateless metric function with the Mean metric.

  You could use this class to quickly build a mean metric from a function. The
  function needs to have the signature `fn(y_true, y_pred)` and return a
  per-sample loss array. `MeanMetricWrapper.result()` will return
  the average metric value across all samples seen so far.

  For example:

  ```python
  def accuracy(y_true, y_pred):
    return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)

  accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy)

  keras_model.compile(..., metrics=accuracy_metric)
  ```

  Args:
    fn: The metric function to wrap, with signature `fn(y_true, y_pred,
      **kwargs)`.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    **kwargs: Keyword arguments to pass on to `fn`.
  """

  def __init__(self, fn, name=None, dtype=None, **kwargs):
    super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
    self._fn = fn
    self._fn_kwargs = kwargs

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates metric statistics.

    `y_true` and `y_pred` should have the same shape.

    Args:
      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
      sample_weight: Optional `sample_weight` acts as a
        coefficient for the metric. If a scalar is provided, then the metric is
        simply scaled by the given value. If `sample_weight` is a tensor of size
        `[batch_size]`, then the metric for each sample of the batch is rescaled
        by the corresponding element in the `sample_weight` vector. If the shape
        of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
        to this shape), then each metric element of `y_pred` is scaled by the
        corresponding value of `sample_weight`. (Note on `dN-1`: all metric
        functions reduce by 1 dimension, usually the last axis (-1)).

    Returns:
      Update op.
    """
    y_true = math_ops.cast(y_true, self._dtype)
    y_pred = math_ops.cast(y_pred, self._dtype)
    [y_true, y_pred], sample_weight = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [y_true, y_pred], sample_weight))
    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
        y_pred, y_true)

    ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
    matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
    return super(MeanMetricWrapper, self).update_state(
        matches, sample_weight=sample_weight)

  def get_config(self):
    config = {}

    if type(self) is MeanMetricWrapper:  # pylint: disable=unidiomatic-typecheck
      # Only include function argument when the object is a MeanMetricWrapper
      # and not a subclass.
      config['fn'] = self._fn

    for k, v in self._fn_kwargs.items():
      config[k] = backend.eval(v) if is_tensor_or_variable(v) else v
    base_config = super(MeanMetricWrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    # Note that while MeanMetricWrapper itself isn't public, objects of this
    # class may be created and added to the model by calling model.compile.
    fn = config.pop('fn', None)
    if cls is MeanMetricWrapper:
      return cls(get(fn), **config)
    return super(MeanMetricWrapper, cls).from_config(config)


@keras_export('keras.metrics.Accuracy')
class Accuracy(MeanMetricWrapper):
  """Calculates how often predictions equal labels.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `binary accuracy`: an idempotent operation that simply
  divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.Accuracy()
  >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])
  >>> m.result().numpy()
  0.75

  >>> m.reset_state()
  >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]],
  ...                sample_weight=[1, 1, 0, 0])
  >>> m.result().numpy()
  0.5

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Accuracy()])
  ```
  """

  def __init__(self, name='accuracy', dtype=None):
    super(Accuracy, self).__init__(accuracy, name, dtype=dtype)


@keras_export('keras.metrics.BinaryAccuracy')
class BinaryAccuracy(MeanMetricWrapper):
  """Calculates how often predictions match binary labels.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `binary accuracy`: an idempotent operation that simply
  divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    threshold: (Optional) Float representing the threshold for deciding
      whether prediction values are 1 or 0.

  Standalone usage:

  >>> m = tf.keras.metrics.BinaryAccuracy()
  >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
  >>> m.result().numpy()
  0.75

  >>> m.reset_state()
  >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],
  ...                sample_weight=[1, 0, 0, 1])
  >>> m.result().numpy()
  0.5

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.BinaryAccuracy()])
  ```
  """

  def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
    super(BinaryAccuracy, self).__init__(
        binary_accuracy, name, dtype=dtype, threshold=threshold)


@keras_export('keras.metrics.CategoricalAccuracy')
class CategoricalAccuracy(MeanMetricWrapper):
  """Calculates how often predictions match one-hot labels.

  You can provide logits of classes as `y_pred`, since argmax of
  logits and probabilities are the same.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `categorical accuracy`: an idempotent operation that
  simply divides `total` by `count`.

  `y_pred` and `y_true` should be passed in as vectors of probabilities, rather
  than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.CategoricalAccuracy()
  >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
  ...                 [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_state()
  >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
  ...                 [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(
      optimizer='sgd',
      loss='mse',
      metrics=[tf.keras.metrics.CategoricalAccuracy()])
  ```
  """

  def __init__(self, name='categorical_accuracy', dtype=None):
    super(CategoricalAccuracy, self).__init__(
        categorical_accuracy, name, dtype=dtype)


@keras_export('keras.metrics.SparseCategoricalAccuracy')
class SparseCategoricalAccuracy(MeanMetricWrapper):
  """Calculates how often predictions match integer labels.

  ```python
  acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)))
  ```

  You can provide logits of classes as `y_pred`, since argmax of
  logits and probabilities are the same.

  This metric creates two local variables, `total` and `count` that are used to
  compute the frequency with which `y_pred` matches `y_true`. This frequency is
  ultimately returned as `sparse categorical accuracy`: an idempotent operation
  that simply divides `total` by `count`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.SparseCategoricalAccuracy()
  >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_state()
  >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(
      optimizer='sgd',
      loss='mse',
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  ```
  """

  def __init__(self, name='sparse_categorical_accuracy', dtype=None):
    super(SparseCategoricalAccuracy, self).__init__(
        sparse_categorical_accuracy, name, dtype=dtype)


@keras_export('keras.metrics.TopKCategoricalAccuracy')
class TopKCategoricalAccuracy(MeanMetricWrapper):
  """Computes how often targets are in the top `K` predictions.

  Args:
    k: (Optional) Number of top elements to look at for computing accuracy.
      Defaults to 5.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_state()
  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.TopKCategoricalAccuracy()])
  ```
  """

  def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
    super(TopKCategoricalAccuracy, self).__init__(
        top_k_categorical_accuracy, name, dtype=dtype, k=k)


@keras_export('keras.metrics.SparseTopKCategoricalAccuracy')
class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
  """Computes how often integer targets are in the top `K` predictions.

  Args:
    k: (Optional) Number of top elements to look at for computing accuracy.
      Defaults to 5.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
  >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
  >>> m.result().numpy()
  0.5

  >>> m.reset_state()
  >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
  ...                sample_weight=[0.7, 0.3])
  >>> m.result().numpy()
  0.3

  Usage with `compile()` API:

  ```python
  model.compile(
      optimizer='sgd',
      loss='mse',
      metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])
  ```
  """

  def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
    super(SparseTopKCategoricalAccuracy, self).__init__(
        sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k)


class _ConfusionMatrixConditionCount(Metric):
  """Calculates the number of the given confusion matrix condition.

  Args:
    confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
    thresholds: (Optional) Defaults to 0.5. A float value or a python list/tuple
      of float threshold values in [0, 1]. A threshold is compared with
      prediction values to determine the truth value of predictions (i.e., above
      the threshold is `true`, below is `false`). One metric value is generated
      for each threshold value.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
  """

  def __init__(self,
               confusion_matrix_cond,
               thresholds=None,
               name=None,
               dtype=None):
    super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
    self._confusion_matrix_cond = confusion_matrix_cond
    self.init_thresholds = thresholds
    self.thresholds = metrics_utils.parse_init_thresholds(
        thresholds, default_threshold=0.5)
    self._thresholds_distributed_evenly = (
        metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
    self.accumulator = self.add_weight(
        'accumulator',
        shape=(len(self.thresholds),),
        initializer=init_ops.zeros_initializer)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates the metric statistics.

    Args:
      y_true: The ground truth values.
      y_pred: The predicted values.
      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
        be broadcastable to `y_true`.

    Returns:
      Update op.
1012 """ 1013 return metrics_utils.update_confusion_matrix_variables( 1014 {self._confusion_matrix_cond: self.accumulator}, 1015 y_true, 1016 y_pred, 1017 thresholds=self.thresholds, 1018 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 1019 sample_weight=sample_weight) 1020 1021 def result(self): 1022 if len(self.thresholds) == 1: 1023 result = self.accumulator[0] 1024 else: 1025 result = self.accumulator 1026 return ops.convert_to_tensor_v2_with_dispatch(result) 1027 1028 def reset_state(self): 1029 num_thresholds = len(to_list(self.thresholds)) 1030 backend.batch_set_value( 1031 [(v, np.zeros((num_thresholds,))) for v in self.variables]) 1032 1033 def get_config(self): 1034 config = {'thresholds': self.init_thresholds} 1035 base_config = super(_ConfusionMatrixConditionCount, self).get_config() 1036 return dict(list(base_config.items()) + list(config.items())) 1037 1038 1039@keras_export('keras.metrics.FalsePositives') 1040class FalsePositives(_ConfusionMatrixConditionCount): 1041 """Calculates the number of false positives. 1042 1043 If `sample_weight` is given, calculates the sum of the weights of 1044 false positives. This metric creates one local variable, `accumulator` 1045 that is used to keep track of the number of false positives. 1046 1047 If `sample_weight` is `None`, weights default to 1. 1048 Use `sample_weight` of 0 to mask values. 1049 1050 Args: 1051 thresholds: (Optional) Defaults to 0.5. A float value or a python 1052 list/tuple of float threshold values in [0, 1]. A threshold is compared 1053 with prediction values to determine the truth value of predictions 1054 (i.e., above the threshold is `true`, below is `false`). One metric 1055 value is generated for each threshold value. 1056 name: (Optional) string name of the metric instance. 1057 dtype: (Optional) data type of the metric result. 1058 1059 Standalone usage: 1060 1061 >>> m = tf.keras.metrics.FalsePositives() 1062 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1]) 1063 >>> m.result().numpy() 1064 2.0 1065 1066 >>> m.reset_state() 1067 >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0]) 1068 >>> m.result().numpy() 1069 1.0 1070 1071 Usage with `compile()` API: 1072 1073 ```python 1074 model.compile(optimizer='sgd', 1075 loss='mse', 1076 metrics=[tf.keras.metrics.FalsePositives()]) 1077 ``` 1078 """ 1079 1080 def __init__(self, thresholds=None, name=None, dtype=None): 1081 super(FalsePositives, self).__init__( 1082 confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES, 1083 thresholds=thresholds, 1084 name=name, 1085 dtype=dtype) 1086 1087 1088@keras_export('keras.metrics.FalseNegatives') 1089class FalseNegatives(_ConfusionMatrixConditionCount): 1090 """Calculates the number of false negatives. 1091 1092 If `sample_weight` is given, calculates the sum of the weights of 1093 false negatives. This metric creates one local variable, `accumulator` 1094 that is used to keep track of the number of false negatives. 1095 1096 If `sample_weight` is `None`, weights default to 1. 1097 Use `sample_weight` of 0 to mask values. 1098 1099 Args: 1100 thresholds: (Optional) Defaults to 0.5. A float value or a python 1101 list/tuple of float threshold values in [0, 1]. A threshold is compared 1102 with prediction values to determine the truth value of predictions 1103 (i.e., above the threshold is `true`, below is `false`). One metric 1104 value is generated for each threshold value. 1105 name: (Optional) string name of the metric instance. 
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.FalseNegatives()
  >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
  >>> m.result().numpy()
  2.0

  >>> m.reset_state()
  >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.FalseNegatives()])
  ```
  """

  def __init__(self, thresholds=None, name=None, dtype=None):
    super(FalseNegatives, self).__init__(
        confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
        thresholds=thresholds,
        name=name,
        dtype=dtype)


@keras_export('keras.metrics.TrueNegatives')
class TrueNegatives(_ConfusionMatrixConditionCount):
  """Calculates the number of true negatives.

  If `sample_weight` is given, calculates the sum of the weights of
  true negatives. This metric creates one local variable, `accumulator`
  that is used to keep track of the number of true negatives.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    thresholds: (Optional) Defaults to 0.5. A float value or a python
      list/tuple of float threshold values in [0, 1]. A threshold is compared
      with prediction values to determine the truth value of predictions
      (i.e., above the threshold is `true`, below is `false`). One metric
      value is generated for each threshold value.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.TrueNegatives()
  >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
  >>> m.result().numpy()
  2.0

  >>> m.reset_state()
  >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.TrueNegatives()])
  ```
  """

  def __init__(self, thresholds=None, name=None, dtype=None):
    super(TrueNegatives, self).__init__(
        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
        thresholds=thresholds,
        name=name,
        dtype=dtype)


@keras_export('keras.metrics.TruePositives')
class TruePositives(_ConfusionMatrixConditionCount):
  """Calculates the number of true positives.

  If `sample_weight` is given, calculates the sum of the weights of
  true positives. This metric creates one local variable, `true_positives`
  that is used to keep track of the number of true positives.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  Args:
    thresholds: (Optional) Defaults to 0.5. A float value or a python
      list/tuple of float threshold values in [0, 1]. A threshold is compared
      with prediction values to determine the truth value of predictions
      (i.e., above the threshold is `true`, below is `false`). One metric
      value is generated for each threshold value.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Standalone usage:

  >>> m = tf.keras.metrics.TruePositives()
  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
  >>> m.result().numpy()
  2.0

  >>> m.reset_state()
  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.TruePositives()])
  ```
  """

  def __init__(self, thresholds=None, name=None, dtype=None):
    super(TruePositives, self).__init__(
        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
        thresholds=thresholds,
        name=name,
        dtype=dtype)


@keras_export('keras.metrics.Precision')
class Precision(Metric):
  """Computes the precision of the predictions with respect to the labels.

  The metric creates two local variables, `true_positives` and `false_positives`
  that are used to compute the precision. This value is ultimately returned as
  `precision`, an idempotent operation that simply divides `true_positives`
  by the sum of `true_positives` and `false_positives`.

  If `sample_weight` is `None`, weights default to 1.
  Use `sample_weight` of 0 to mask values.

  If `top_k` is set, we'll calculate precision as how often on average a class
  among the top-k classes with the highest predicted values of a batch entry is
  correct and can be found in the label for that entry.

  If `class_id` is specified, we calculate precision by considering only the
  entries in the batch for which `class_id` is above the threshold and/or in the
  top-k highest predictions, and computing the fraction of them for which
  `class_id` is indeed a correct label.

  Args:
    thresholds: (Optional) A float value or a python list/tuple of float
      threshold values in [0, 1]. A threshold is compared with prediction
      values to determine the truth value of predictions (i.e., above the
      threshold is `true`, below is `false`). One metric value is generated
      for each threshold value. If neither thresholds nor top_k are set, the
      default is to calculate precision with `thresholds=0.5`.
    top_k: (Optional) Unset by default. An int value specifying the top-k
      predictions to consider when calculating precision.
    class_id: (Optional) Integer class ID for which we want binary metrics.
      This must be in the half-open interval `[0, num_classes)`, where
      `num_classes` is the last dimension of predictions.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
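
  An illustrative sketch (not a doctest): passing a list of thresholds
  produces one precision value per threshold:

  ```python
  m = tf.keras.metrics.Precision(thresholds=[0.3, 0.6, 0.9])
  m.update_state([0, 1, 1], [0.2, 0.55, 0.95])
  # m.result() is now a tensor of three precision values, one per threshold.
  ```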

  Standalone usage:

  >>> m = tf.keras.metrics.Precision()
  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
  >>> m.result().numpy()
  0.6666667

  >>> m.reset_state()
  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  >>> # With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2]
  >>> m = tf.keras.metrics.Precision(top_k=2)
  >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
  >>> m.result().numpy()
  0.0

  >>> # With top_k=4, it will calculate precision over y_true[:4] and y_pred[:4]
  >>> m = tf.keras.metrics.Precision(top_k=4)
  >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
  >>> m.result().numpy()
  0.5

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Precision()])
  ```
  """

  def __init__(self,
               thresholds=None,
               top_k=None,
               class_id=None,
               name=None,
               dtype=None):
    super(Precision, self).__init__(name=name, dtype=dtype)
    self.init_thresholds = thresholds
    self.top_k = top_k
    self.class_id = class_id

    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
    self.thresholds = metrics_utils.parse_init_thresholds(
        thresholds, default_threshold=default_threshold)
    self._thresholds_distributed_evenly = (
        metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
    self.true_positives = self.add_weight(
        'true_positives',
        shape=(len(self.thresholds),),
        initializer=init_ops.zeros_initializer)
    self.false_positives = self.add_weight(
        'false_positives',
        shape=(len(self.thresholds),),
        initializer=init_ops.zeros_initializer)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates true positive and false positive statistics.

    Args:
      y_true: The ground truth values, with the same dimensions as `y_pred`.
        Will be cast to `bool`.
      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
        be broadcastable to `y_true`.

    Returns:
      Update op.
1342 """ 1343 return metrics_utils.update_confusion_matrix_variables( 1344 { 1345 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1346 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives 1347 }, 1348 y_true, 1349 y_pred, 1350 thresholds=self.thresholds, 1351 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 1352 top_k=self.top_k, 1353 class_id=self.class_id, 1354 sample_weight=sample_weight) 1355 1356 def result(self): 1357 result = math_ops.div_no_nan(self.true_positives, 1358 self.true_positives + self.false_positives) 1359 return result[0] if len(self.thresholds) == 1 else result 1360 1361 def reset_state(self): 1362 num_thresholds = len(to_list(self.thresholds)) 1363 backend.batch_set_value([(v, np.zeros((num_thresholds,))) 1364 for v in (self.true_positives, 1365 self.false_positives)]) 1366 1367 def get_config(self): 1368 config = { 1369 'thresholds': self.init_thresholds, 1370 'top_k': self.top_k, 1371 'class_id': self.class_id 1372 } 1373 base_config = super(Precision, self).get_config() 1374 return dict(list(base_config.items()) + list(config.items())) 1375 1376 1377@keras_export('keras.metrics.Recall') 1378class Recall(Metric): 1379 """Computes the recall of the predictions with respect to the labels. 1380 1381 This metric creates two local variables, `true_positives` and 1382 `false_negatives`, that are used to compute the recall. This value is 1383 ultimately returned as `recall`, an idempotent operation that simply divides 1384 `true_positives` by the sum of `true_positives` and `false_negatives`. 1385 1386 If `sample_weight` is `None`, weights default to 1. 1387 Use `sample_weight` of 0 to mask values. 1388 1389 If `top_k` is set, recall will be computed as how often on average a class 1390 among the labels of a batch entry is in the top-k predictions. 1391 1392 If `class_id` is specified, we calculate recall by considering only the 1393 entries in the batch for which `class_id` is in the label, and computing the 1394 fraction of them for which `class_id` is above the threshold and/or in the 1395 top-k predictions. 1396 1397 Args: 1398 thresholds: (Optional) A float value or a python list/tuple of float 1399 threshold values in [0, 1]. A threshold is compared with prediction 1400 values to determine the truth value of predictions (i.e., above the 1401 threshold is `true`, below is `false`). One metric value is generated 1402 for each threshold value. If neither thresholds nor top_k are set, the 1403 default is to calculate recall with `thresholds=0.5`. 1404 top_k: (Optional) Unset by default. An int value specifying the top-k 1405 predictions to consider when calculating recall. 1406 class_id: (Optional) Integer class ID for which we want binary metrics. 1407 This must be in the half-open interval `[0, num_classes)`, where 1408 `num_classes` is the last dimension of predictions. 1409 name: (Optional) string name of the metric instance. 1410 dtype: (Optional) data type of the metric result. 

  Standalone usage:

  >>> m = tf.keras.metrics.Recall()
  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
  >>> m.result().numpy()
  0.6666667

  >>> m.reset_state()
  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
  >>> m.result().numpy()
  1.0

  Usage with `compile()` API:

  ```python
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Recall()])
  ```
  """

  def __init__(self,
               thresholds=None,
               top_k=None,
               class_id=None,
               name=None,
               dtype=None):
    super(Recall, self).__init__(name=name, dtype=dtype)
    self.init_thresholds = thresholds
    self.top_k = top_k
    self.class_id = class_id

    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
    self.thresholds = metrics_utils.parse_init_thresholds(
        thresholds, default_threshold=default_threshold)
    self._thresholds_distributed_evenly = (
        metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
    self.true_positives = self.add_weight(
        'true_positives',
        shape=(len(self.thresholds),),
        initializer=init_ops.zeros_initializer)
    self.false_negatives = self.add_weight(
        'false_negatives',
        shape=(len(self.thresholds),),
        initializer=init_ops.zeros_initializer)

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates true positive and false negative statistics.

    Args:
      y_true: The ground truth values, with the same dimensions as `y_pred`.
        Will be cast to `bool`.
      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
        be broadcastable to `y_true`.

    Returns:
      Update op.
    """
    return metrics_utils.update_confusion_matrix_variables(
        {
            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives
        },
        y_true,
        y_pred,
        thresholds=self.thresholds,
        thresholds_distributed_evenly=self._thresholds_distributed_evenly,
        top_k=self.top_k,
        class_id=self.class_id,
        sample_weight=sample_weight)

  def result(self):
    result = math_ops.div_no_nan(self.true_positives,
                                 self.true_positives + self.false_negatives)
    return result[0] if len(self.thresholds) == 1 else result

  def reset_state(self):
    num_thresholds = len(to_list(self.thresholds))
    backend.batch_set_value([(v, np.zeros((num_thresholds,)))
                             for v in (self.true_positives,
                                       self.false_negatives)])

  def get_config(self):
    config = {
        'thresholds': self.init_thresholds,
        'top_k': self.top_k,
        'class_id': self.class_id
    }
    base_config = super(Recall, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


class SensitivitySpecificityBase(Metric, metaclass=abc.ABCMeta):
  """Abstract base class for computing sensitivity and specificity.

  For additional information about specificity and sensitivity, see
  [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
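
  Subclasses accumulate a confusion matrix evaluated at `num_thresholds`
  threshold values and implement `result()` by searching over those thresholds
  for the best value of one statistic subject to a constraint on the other
  (see `_find_max_under_constraint`).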
1511 """ 1512 1513 def __init__(self, 1514 value, 1515 num_thresholds=200, 1516 class_id=None, 1517 name=None, 1518 dtype=None): 1519 super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype) 1520 if num_thresholds <= 0: 1521 raise ValueError('`num_thresholds` must be > 0.') 1522 self.value = value 1523 self.class_id = class_id 1524 self.true_positives = self.add_weight( 1525 'true_positives', 1526 shape=(num_thresholds,), 1527 initializer=init_ops.zeros_initializer) 1528 self.true_negatives = self.add_weight( 1529 'true_negatives', 1530 shape=(num_thresholds,), 1531 initializer=init_ops.zeros_initializer) 1532 self.false_positives = self.add_weight( 1533 'false_positives', 1534 shape=(num_thresholds,), 1535 initializer=init_ops.zeros_initializer) 1536 self.false_negatives = self.add_weight( 1537 'false_negatives', 1538 shape=(num_thresholds,), 1539 initializer=init_ops.zeros_initializer) 1540 1541 # Compute `num_thresholds` thresholds in [0, 1] 1542 if num_thresholds == 1: 1543 self.thresholds = [0.5] 1544 self._thresholds_distributed_evenly = False 1545 else: 1546 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 1547 for i in range(num_thresholds - 2)] 1548 self.thresholds = [0.0] + thresholds + [1.0] 1549 self._thresholds_distributed_evenly = True 1550 1551 def update_state(self, y_true, y_pred, sample_weight=None): 1552 """Accumulates confusion matrix statistics. 1553 1554 Args: 1555 y_true: The ground truth values. 1556 y_pred: The predicted values. 1557 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 1558 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 1559 be broadcastable to `y_true`. 1560 1561 Returns: 1562 Update op. 1563 """ 1564 return metrics_utils.update_confusion_matrix_variables( 1565 { 1566 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives, 1567 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives, 1568 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives, 1569 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives, 1570 }, 1571 y_true, 1572 y_pred, 1573 thresholds=self.thresholds, 1574 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 1575 class_id=self.class_id, 1576 sample_weight=sample_weight) 1577 1578 def reset_state(self): 1579 num_thresholds = len(self.thresholds) 1580 confusion_matrix_variables = (self.true_positives, self.true_negatives, 1581 self.false_positives, self.false_negatives) 1582 backend.batch_set_value([ 1583 (v, np.zeros((num_thresholds,))) for v in confusion_matrix_variables 1584 ]) 1585 1586 def get_config(self): 1587 config = {'class_id': self.class_id} 1588 base_config = super(SensitivitySpecificityBase, self).get_config() 1589 return dict(list(base_config.items()) + list(config.items())) 1590 1591 def _find_max_under_constraint(self, constrained, dependent, predicate): 1592 """Returns the maximum of dependent_statistic that satisfies the constraint. 1593 1594 Args: 1595 constrained: Over these values the constraint 1596 is specified. A rank-1 tensor. 1597 dependent: From these values the maximum that satiesfies the 1598 constraint is selected. Values in this tensor and in 1599 `constrained` are linked by having the same threshold at each 1600 position, hence this tensor must have the same shape. 1601 predicate: A binary boolean functor to be applied to arguments 1602 `constrained` and `self.value`, e.g. `tf.greater`. 1603 1604 Returns maximal dependent value, if no value satiesfies the constraint 0.0. 
1605 """ 1606 feasible = array_ops.where_v2(predicate(constrained, self.value)) 1607 feasible_exists = math_ops.greater(array_ops.size(feasible), 0) 1608 max_dependent = math_ops.reduce_max(array_ops.gather(dependent, feasible)) 1609 1610 return array_ops.where_v2(feasible_exists, max_dependent, 0.0) 1611 1612 1613@keras_export('keras.metrics.SensitivityAtSpecificity') 1614class SensitivityAtSpecificity(SensitivitySpecificityBase): 1615 """Computes best sensitivity where specificity is >= specified value. 1616 1617 the sensitivity at a given specificity. 1618 1619 `Sensitivity` measures the proportion of actual positives that are correctly 1620 identified as such (tp / (tp + fn)). 1621 `Specificity` measures the proportion of actual negatives that are correctly 1622 identified as such (tn / (tn + fp)). 1623 1624 This metric creates four local variables, `true_positives`, `true_negatives`, 1625 `false_positives` and `false_negatives` that are used to compute the 1626 sensitivity at the given specificity. The threshold for the given specificity 1627 value is computed and used to evaluate the corresponding sensitivity. 1628 1629 If `sample_weight` is `None`, weights default to 1. 1630 Use `sample_weight` of 0 to mask values. 1631 1632 If `class_id` is specified, we calculate precision by considering only the 1633 entries in the batch for which `class_id` is above the threshold predictions, 1634 and computing the fraction of them for which `class_id` is indeed a correct 1635 label. 1636 1637 For additional information about specificity and sensitivity, see 1638 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 1639 1640 Args: 1641 specificity: A scalar value in range `[0, 1]`. 1642 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1643 use for matching the given specificity. 1644 class_id: (Optional) Integer class ID for which we want binary metrics. 1645 This must be in the half-open interval `[0, num_classes)`, where 1646 `num_classes` is the last dimension of predictions. 1647 name: (Optional) string name of the metric instance. 1648 dtype: (Optional) data type of the metric result. 1649 1650 Standalone usage: 1651 1652 >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5) 1653 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 1654 >>> m.result().numpy() 1655 0.5 1656 1657 >>> m.reset_state() 1658 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 1659 ... 
sample_weight=[1, 1, 2, 2, 1]) 1660 >>> m.result().numpy() 1661 0.333333 1662 1663 Usage with `compile()` API: 1664 1665 ```python 1666 model.compile( 1667 optimizer='sgd', 1668 loss='mse', 1669 metrics=[tf.keras.metrics.SensitivityAtSpecificity(specificity=0.5)]) 1670 ``` 1671 """ 1672 1673 def __init__(self, 1674 specificity, 1675 num_thresholds=200, 1676 class_id=None, 1677 name=None, 1678 dtype=None): 1679 if specificity < 0 or specificity > 1: 1680 raise ValueError('`specificity` must be in the range [0, 1].') 1681 self.specificity = specificity 1682 self.num_thresholds = num_thresholds 1683 super(SensitivityAtSpecificity, self).__init__( 1684 specificity, 1685 num_thresholds=num_thresholds, 1686 class_id=class_id, 1687 name=name, 1688 dtype=dtype) 1689 1690 def result(self): 1691 specificities = math_ops.div_no_nan( 1692 self.true_negatives, self.true_negatives + self.false_positives) 1693 sensitivities = math_ops.div_no_nan( 1694 self.true_positives, self.true_positives + self.false_negatives) 1695 return self._find_max_under_constraint( 1696 specificities, sensitivities, math_ops.greater_equal) 1697 1698 def get_config(self): 1699 config = { 1700 'num_thresholds': self.num_thresholds, 1701 'specificity': self.specificity 1702 } 1703 base_config = super(SensitivityAtSpecificity, self).get_config() 1704 return dict(list(base_config.items()) + list(config.items())) 1705 1706 1707@keras_export('keras.metrics.SpecificityAtSensitivity') 1708class SpecificityAtSensitivity(SensitivitySpecificityBase): 1709 """Computes best specificity where sensitivity is >= specified value. 1710 1711 `Sensitivity` measures the proportion of actual positives that are correctly 1712 identified as such (tp / (tp + fn)). 1713 `Specificity` measures the proportion of actual negatives that are correctly 1714 identified as such (tn / (tn + fp)). 1715 1716 This metric creates four local variables, `true_positives`, `true_negatives`, 1717 `false_positives` and `false_negatives` that are used to compute the 1718 specificity at the given sensitivity. The threshold for the given sensitivity 1719 value is computed and used to evaluate the corresponding specificity. 1720 1721 If `sample_weight` is `None`, weights default to 1. 1722 Use `sample_weight` of 0 to mask values. 1723 1724 If `class_id` is specified, we calculate precision by considering only the 1725 entries in the batch for which `class_id` is above the threshold predictions, 1726 and computing the fraction of them for which `class_id` is indeed a correct 1727 label. 1728 1729 For additional information about specificity and sensitivity, see 1730 [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). 1731 1732 Args: 1733 sensitivity: A scalar value in range `[0, 1]`. 1734 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1735 use for matching the given sensitivity. 1736 class_id: (Optional) Integer class ID for which we want binary metrics. 1737 This must be in the half-open interval `[0, num_classes)`, where 1738 `num_classes` is the last dimension of predictions. 1739 name: (Optional) string name of the metric instance. 1740 dtype: (Optional) data type of the metric result. 1741 1742 Standalone usage: 1743 1744 >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5) 1745 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 1746 >>> m.result().numpy() 1747 0.66666667 1748 1749 >>> m.reset_state() 1750 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 1751 ...
sample_weight=[1, 1, 2, 2, 2]) 1752 >>> m.result().numpy() 1753 0.5 1754 1755 Usage with `compile()` API: 1756 1757 ```python 1758 model.compile( 1759 optimizer='sgd', 1760 loss='mse', 1761 metrics=[tf.keras.metrics.SpecificityAtSensitivity(sensitivity=0.5)]) 1762 ``` 1763 """ 1764 1765 def __init__(self, 1766 sensitivity, 1767 num_thresholds=200, 1768 class_id=None, 1769 name=None, 1770 dtype=None): 1771 if sensitivity < 0 or sensitivity > 1: 1772 raise ValueError('`sensitivity` must be in the range [0, 1].') 1773 self.sensitivity = sensitivity 1774 self.num_thresholds = num_thresholds 1775 super(SpecificityAtSensitivity, self).__init__( 1776 sensitivity, 1777 num_thresholds=num_thresholds, 1778 class_id=class_id, 1779 name=name, 1780 dtype=dtype) 1781 1782 def result(self): 1783 sensitivities = math_ops.div_no_nan( 1784 self.true_positives, self.true_positives + self.false_negatives) 1785 specificities = math_ops.div_no_nan( 1786 self.true_negatives, self.true_negatives + self.false_positives) 1787 return self._find_max_under_constraint( 1788 sensitivities, specificities, math_ops.greater_equal) 1789 1790 def get_config(self): 1791 config = { 1792 'num_thresholds': self.num_thresholds, 1793 'sensitivity': self.sensitivity 1794 } 1795 base_config = super(SpecificityAtSensitivity, self).get_config() 1796 return dict(list(base_config.items()) + list(config.items())) 1797 1798 1799@keras_export('keras.metrics.PrecisionAtRecall') 1800class PrecisionAtRecall(SensitivitySpecificityBase): 1801 """Computes best precision where recall is >= specified value. 1802 1803 This metric creates four local variables, `true_positives`, `true_negatives`, 1804 `false_positives` and `false_negatives` that are used to compute the 1805 precision at the given recall. The threshold for the given recall 1806 value is computed and used to evaluate the corresponding precision. 1807 1808 If `sample_weight` is `None`, weights default to 1. 1809 Use `sample_weight` of 0 to mask values. 1810 1811 If `class_id` is specified, we calculate precision by considering only the 1812 entries in the batch for which `class_id` is above the threshold predictions, 1813 and computing the fraction of them for which `class_id` is indeed a correct 1814 label. 1815 1816 Args: 1817 recall: A scalar value in range `[0, 1]`. 1818 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1819 use for matching the given recall. 1820 class_id: (Optional) Integer class ID for which we want binary metrics. 1821 This must be in the half-open interval `[0, num_classes)`, where 1822 `num_classes` is the last dimension of predictions. 1823 name: (Optional) string name of the metric instance. 1824 dtype: (Optional) data type of the metric result. 1825 1826 Standalone usage: 1827 1828 >>> m = tf.keras.metrics.PrecisionAtRecall(0.5) 1829 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8]) 1830 >>> m.result().numpy() 1831 0.5 1832 1833 >>> m.reset_state() 1834 >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8], 1835 ...
sample_weight=[2, 2, 2, 1, 1]) 1836 >>> m.result().numpy() 1837 0.33333333 1838 1839 Usage with `compile()` API: 1840 1841 ```python 1842 model.compile( 1843 optimizer='sgd', 1844 loss='mse', 1845 metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)]) 1846 ``` 1847 """ 1848 1849 def __init__(self, 1850 recall, 1851 num_thresholds=200, 1852 class_id=None, 1853 name=None, 1854 dtype=None): 1855 if recall < 0 or recall > 1: 1856 raise ValueError('`recall` must be in the range [0, 1].') 1857 self.recall = recall 1858 self.num_thresholds = num_thresholds 1859 super(PrecisionAtRecall, self).__init__( 1860 value=recall, 1861 num_thresholds=num_thresholds, 1862 class_id=class_id, 1863 name=name, 1864 dtype=dtype) 1865 1866 def result(self): 1867 recalls = math_ops.div_no_nan( 1868 self.true_positives, self.true_positives + self.false_negatives) 1869 precisions = math_ops.div_no_nan( 1870 self.true_positives, self.true_positives + self.false_positives) 1871 return self._find_max_under_constraint( 1872 recalls, precisions, math_ops.greater_equal) 1873 1874 def get_config(self): 1875 config = {'num_thresholds': self.num_thresholds, 'recall': self.recall} 1876 base_config = super(PrecisionAtRecall, self).get_config() 1877 return dict(list(base_config.items()) + list(config.items())) 1878 1879 1880@keras_export('keras.metrics.RecallAtPrecision') 1881class RecallAtPrecision(SensitivitySpecificityBase): 1882 """Computes best recall where precision is >= specified value. 1883 1884 For a given score-label-distribution the required precision might not 1885 be achievable, in this case 0.0 is returned as recall. 1886 1887 This metric creates four local variables, `true_positives`, `true_negatives`, 1888 `false_positives` and `false_negatives` that are used to compute the 1889 recall at the given precision. The threshold for the given precision 1890 value is computed and used to evaluate the corresponding recall. 1891 1892 If `sample_weight` is `None`, weights default to 1. 1893 Use `sample_weight` of 0 to mask values. 1894 1895 If `class_id` is specified, we calculate precision by considering only the 1896 entries in the batch for which `class_id` is above the threshold predictions, 1897 and computing the fraction of them for which `class_id` is indeed a correct 1898 label. 1899 1900 Args: 1901 precision: A scalar value in range `[0, 1]`. 1902 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 1903 use for matching the given precision. 1904 class_id: (Optional) Integer class ID for which we want binary metrics. 1905 This must be in the half-open interval `[0, num_classes)`, where 1906 `num_classes` is the last dimension of predictions. 1907 name: (Optional) string name of the metric instance. 1908 dtype: (Optional) data type of the metric result. 1909 1910 Standalone usage: 1911 1912 >>> m = tf.keras.metrics.RecallAtPrecision(0.8) 1913 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 1914 >>> m.result().numpy() 1915 0.5 1916 1917 >>> m.reset_state() 1918 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 1919 ... 
sample_weight=[1, 0, 0, 1]) 1920 >>> m.result().numpy() 1921 1.0 1922 1923 Usage with `compile()` API: 1924 1925 ```python 1926 model.compile( 1927 optimizer='sgd', 1928 loss='mse', 1929 metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)]) 1930 ``` 1931 """ 1932 1933 def __init__(self, 1934 precision, 1935 num_thresholds=200, 1936 class_id=None, 1937 name=None, 1938 dtype=None): 1939 if precision < 0 or precision > 1: 1940 raise ValueError('`precision` must be in the range [0, 1].') 1941 self.precision = precision 1942 self.num_thresholds = num_thresholds 1943 super(RecallAtPrecision, self).__init__( 1944 value=precision, 1945 num_thresholds=num_thresholds, 1946 class_id=class_id, 1947 name=name, 1948 dtype=dtype) 1949 1950 def result(self): 1951 precisions = math_ops.div_no_nan( 1952 self.true_positives, self.true_positives + self.false_positives) 1953 recalls = math_ops.div_no_nan( 1954 self.true_positives, self.true_positives + self.false_negatives) 1955 return self._find_max_under_constraint( 1956 precisions, recalls, math_ops.greater_equal) 1957 1958 def get_config(self): 1959 config = {'num_thresholds': self.num_thresholds, 1960 'precision': self.precision} 1961 base_config = super(RecallAtPrecision, self).get_config() 1962 return dict(list(base_config.items()) + list(config.items())) 1963 1964 1965@keras_export('keras.metrics.AUC') 1966class AUC(Metric): 1967 """Approximates the AUC (Area under the curve) of the ROC or PR curves. 1968 1969 The AUC (Area under the curve) of the ROC (Receiver operating 1970 characteristic; default) or PR (Precision Recall) curves are quality measures 1971 of binary classifiers. Unlike the accuracy, and like cross-entropy 1972 losses, ROC-AUC and PR-AUC evaluate all the operational points of a model. 1973 1974 This class approximates AUCs using a Riemann sum. During the metric 1975 accumulation phase, predictions are accumulated within predefined buckets 1976 by value. The AUC is then computed by interpolating per-bucket averages. These 1977 buckets define the evaluated operational points. 1978 1979 This metric creates four local variables, `true_positives`, `true_negatives`, 1980 `false_positives` and `false_negatives` that are used to compute the AUC. 1981 To discretize the AUC curve, a linearly spaced set of thresholds is used to 1982 compute pairs of recall and precision values. The area under the ROC-curve is 1983 therefore computed using the height of the recall values by the false positive 1984 rate, while the area under the PR-curve is computed using the height of 1985 the precision values by the recall. 1986 1987 This value is ultimately returned as `auc`, an idempotent operation that 1988 computes the area under a discretized curve of precision versus recall values 1989 (computed using the aforementioned variables). The `num_thresholds` variable 1990 controls the degree of discretization with larger numbers of thresholds more 1991 closely approximating the true AUC. The quality of the approximation may vary 1992 dramatically depending on `num_thresholds`. The `thresholds` parameter can be 1993 used to manually specify thresholds which split the predictions more evenly. 1994 1995 For the best approximation of the real AUC, `predictions` should be distributed 1996 approximately uniformly in the range [0, 1] (if `from_logits=False`). The 1997 quality of the AUC approximation may be poor if this is not the case.
Setting 1998 `summation_method` to 'minoring' or 'majoring' can help quantify the error in 1999 the approximation by providing a lower or upper bound estimate of the AUC. 2000 2001 If `sample_weight` is `None`, weights default to 1. 2002 Use `sample_weight` of 0 to mask values. 2003 2004 Args: 2005 num_thresholds: (Optional) Defaults to 200. The number of thresholds to 2006 use when discretizing the ROC curve. Values must be > 1. 2007 curve: (Optional) Specifies the name of the curve to be computed, 'ROC' 2008 [default] or 'PR' for the Precision-Recall curve. 2009 summation_method: (Optional) Specifies the [Riemann summation method]( 2010 https://en.wikipedia.org/wiki/Riemann_sum) used. 2011 'interpolation' (default) applies a mid-point summation scheme for `ROC`. 2012 For PR-AUC, interpolates (true/false) positives but not the ratio that 2013 is precision (see Davis & Goadrich 2006 for details); 2014 'minoring' applies left summation 2015 for increasing intervals and right summation for decreasing intervals; 2016 'majoring' does the opposite. 2017 name: (Optional) string name of the metric instance. 2018 dtype: (Optional) data type of the metric result. 2019 thresholds: (Optional) A list of floating point values to use as the 2020 thresholds for discretizing the curve. If set, the `num_thresholds` 2021 parameter is ignored. Values should be in [0, 1]. Endpoint thresholds 2022 equal to {-epsilon, 1+epsilon} for a small positive epsilon value will 2023 be automatically included with these to correctly handle predictions 2024 equal to exactly 0 or 1. 2025 multi_label: boolean indicating whether multilabel data should be 2026 treated as such, wherein AUC is computed separately for each label and 2027 then averaged across labels, or (when False) if the data should be 2028 flattened into a single label before AUC computation. In the latter 2029 case, when multilabel data is passed to AUC, each label-prediction pair 2030 is treated as an individual data point. Should be set to False for 2031 multi-class data. 2032 num_labels: (Optional) The number of labels, used when `multi_label` is 2033 True. If `num_labels` is not specified, then state variables get created 2034 on the first call to `update_state`. 2035 label_weights: (Optional) list, array, or tensor of non-negative weights 2036 used to compute AUCs for multilabel data. When `multi_label` is True, 2037 the weights are applied to the individual label AUCs when they are 2038 averaged to produce the multi-label AUC. When it's False, they are used 2039 to weight the individual label predictions in computing the confusion 2040 matrix on the flattened data. Note that this is unlike class_weights in 2041 that class_weights weights the example depending on the value of its 2042 label, whereas label_weights depends only on the index of that label 2043 before flattening; therefore `label_weights` should not be used for 2044 multi-class data. 2045 from_logits: boolean indicating whether the predictions (`y_pred` in 2046 `update_state`) are probabilities or sigmoid logits. As a rule of thumb, 2047 when using a keras loss, the `from_logits` constructor argument of the 2048 loss should match the AUC `from_logits` constructor argument.
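  As an illustrative sketch of the threshold handling described above (not an
  exhaustive reference):

  ```python
  m = tf.keras.metrics.AUC(thresholds=[0.25, 0.5, 0.75])
  m.num_thresholds  # 5: the three supplied thresholds plus the two endpoints

  m = tf.keras.metrics.AUC(num_thresholds=3)
  m.thresholds      # approximately [0.0, 0.5, 1.0] (endpoints offset by epsilon)
  ```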
2049 2050 Standalone usage: 2051 2052 >>> m = tf.keras.metrics.AUC(num_thresholds=3) 2053 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9]) 2054 >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7] 2055 >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2] 2056 >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0] 2057 >>> # auc = ((((1+0.5)/2)*(1-0)) + (((0.5+0)/2)*(0-0))) = 0.75 2058 >>> m.result().numpy() 2059 0.75 2060 2061 >>> m.reset_state() 2062 >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9], 2063 ... sample_weight=[1, 0, 0, 1]) 2064 >>> m.result().numpy() 2065 1.0 2066 2067 Usage with `compile()` API: 2068 2069 ```python 2070 # Reports the AUC of a model outputting a probability. 2071 model.compile(optimizer='sgd', 2072 loss=tf.keras.losses.BinaryCrossentropy(), 2073 metrics=[tf.keras.metrics.AUC()]) 2074 2075 # Reports the AUC of a model outputting a logit. 2076 model.compile(optimizer='sgd', 2077 loss=tf.keras.losses.BinaryCrossentropy(from_logits=True), 2078 metrics=[tf.keras.metrics.AUC(from_logits=True)]) 2079 ``` 2080 """ 2081 2082 def __init__(self, 2083 num_thresholds=200, 2084 curve='ROC', 2085 summation_method='interpolation', 2086 name=None, 2087 dtype=None, 2088 thresholds=None, 2089 multi_label=False, 2090 num_labels=None, 2091 label_weights=None, 2092 from_logits=False): 2093 # Validate configurations. 2094 if isinstance(curve, metrics_utils.AUCCurve) and curve not in list( 2095 metrics_utils.AUCCurve): 2096 raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format( 2097 curve, list(metrics_utils.AUCCurve))) 2098 if isinstance( 2099 summation_method, 2100 metrics_utils.AUCSummationMethod) and summation_method not in list( 2101 metrics_utils.AUCSummationMethod): 2102 raise ValueError( 2103 'Invalid summation method: "{}". Valid options are: "{}"'.format( 2104 summation_method, list(metrics_utils.AUCSummationMethod))) 2105 2106 # Update properties. 2107 if thresholds is not None: 2108 # If specified, use the supplied thresholds. 2109 self.num_thresholds = len(thresholds) + 2 2110 thresholds = sorted(thresholds) 2111 self._thresholds_distributed_evenly = ( 2112 metrics_utils.is_evenly_distributed_thresholds( 2113 np.array([0.0] + thresholds + [1.0]))) 2114 else: 2115 if num_thresholds <= 1: 2116 raise ValueError('`num_thresholds` must be > 1.') 2117 2118 # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in 2119 # (0, 1). 2120 self.num_thresholds = num_thresholds 2121 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 2122 for i in range(num_thresholds - 2)] 2123 self._thresholds_distributed_evenly = True 2124 2125 # Add an endpoint "threshold" below zero and above one for either 2126 # threshold method to account for floating point imprecisions. 2127 self._thresholds = np.array([0.0 - backend.epsilon()] + thresholds + 2128 [1.0 + backend.epsilon()]) 2129 2130 if isinstance(curve, metrics_utils.AUCCurve): 2131 self.curve = curve 2132 else: 2133 self.curve = metrics_utils.AUCCurve.from_str(curve) 2134 if isinstance(summation_method, metrics_utils.AUCSummationMethod): 2135 self.summation_method = summation_method 2136 else: 2137 self.summation_method = metrics_utils.AUCSummationMethod.from_str( 2138 summation_method) 2139 super(AUC, self).__init__(name=name, dtype=dtype) 2140 2141 # Handle multilabel arguments.
2142 self.multi_label = multi_label 2143 if label_weights is not None: 2144 label_weights = constant_op.constant(label_weights, dtype=self.dtype) 2145 checks = [ 2146 check_ops.assert_non_negative( 2147 label_weights, 2148 message='All values of `label_weights` must be non-negative.') 2149 ] 2150 with ops.control_dependencies(checks): 2151 self.label_weights = label_weights 2152 2153 else: 2154 self.label_weights = None 2155 2156 self._from_logits = from_logits 2157 2158 self._built = False 2159 if self.multi_label: 2160 if num_labels: 2161 shape = tensor_shape.TensorShape([None, num_labels]) 2162 self._build(shape) 2163 else: 2164 if num_labels: 2165 raise ValueError( 2166 '`num_labels` is needed only when `multi_label` is True.') 2167 self._build(None) 2168 2169 @property 2170 def thresholds(self): 2171 """The thresholds used for evaluating AUC.""" 2172 return list(self._thresholds) 2173 2174 def _build(self, shape): 2175 """Initialize TP, FP, TN, and FN tensors, given the shape of the data.""" 2176 if self.multi_label: 2177 if shape.ndims != 2: 2178 raise ValueError('`y_true` must have rank=2 when `multi_label` is ' 2179 'True. Found rank %s.' % shape.ndims) 2180 self._num_labels = shape[1] 2181 variable_shape = tensor_shape.TensorShape( 2182 [tensor_shape.Dimension(self.num_thresholds), self._num_labels]) 2183 2184 else: 2185 variable_shape = tensor_shape.TensorShape( 2186 [tensor_shape.Dimension(self.num_thresholds)]) 2187 self._build_input_shape = shape 2188 # Create metric variables 2189 self.true_positives = self.add_weight( 2190 'true_positives', 2191 shape=variable_shape, 2192 initializer=init_ops.zeros_initializer) 2193 self.true_negatives = self.add_weight( 2194 'true_negatives', 2195 shape=variable_shape, 2196 initializer=init_ops.zeros_initializer) 2197 self.false_positives = self.add_weight( 2198 'false_positives', 2199 shape=variable_shape, 2200 initializer=init_ops.zeros_initializer) 2201 self.false_negatives = self.add_weight( 2202 'false_negatives', 2203 shape=variable_shape, 2204 initializer=init_ops.zeros_initializer) 2205 2206 if self.multi_label: 2207 with ops.init_scope(): 2208 # This should only be necessary for handling v1 behavior. In v2, AUC 2209 # should be initialized outside of any tf.functions, and therefore in 2210 # eager mode. 2211 if not context.executing_eagerly(): 2212 backend._initialize_variables(backend._get_session()) # pylint: disable=protected-access 2213 2214 self._built = True 2215 2216 def update_state(self, y_true, y_pred, sample_weight=None): 2217 """Accumulates confusion matrix statistics. 2218 2219 Args: 2220 y_true: The ground truth values. 2221 y_pred: The predicted values. 2222 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2223 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2224 be broadcastable to `y_true`. 2225 2226 Returns: 2227 Update op. 2228 """ 2229 deps = [] 2230 if not self._built: 2231 self._build(tensor_shape.TensorShape(y_pred.shape)) 2232 2233 if self.multi_label or (self.label_weights is not None): 2234 # y_true should have shape (number of examples, number of labels). 2235 shapes = [ 2236 (y_true, ('N', 'L')) 2237 ] 2238 if self.multi_label: 2239 # TP, TN, FP, and FN should all have shape 2240 # (number of thresholds, number of labels). 
2241 shapes.extend([(self.true_positives, ('T', 'L')), 2242 (self.true_negatives, ('T', 'L')), 2243 (self.false_positives, ('T', 'L')), 2244 (self.false_negatives, ('T', 'L'))]) 2245 if self.label_weights is not None: 2246 # label_weights should be of length equal to the number of labels. 2247 shapes.append((self.label_weights, ('L',))) 2248 deps = [ 2249 check_ops.assert_shapes( 2250 shapes, message='Number of labels is not consistent.') 2251 ] 2252 2253 # Only forward label_weights to update_confusion_matrix_variables when 2254 # multi_label is False. Otherwise the averaging of individual label AUCs is 2255 # handled in AUC.result 2256 label_weights = None if self.multi_label else self.label_weights 2257 2258 if self._from_logits: 2259 y_pred = activations.sigmoid(y_pred) 2260 2261 with ops.control_dependencies(deps): 2262 return metrics_utils.update_confusion_matrix_variables( 2263 { 2264 metrics_utils.ConfusionMatrix.TRUE_POSITIVES: 2265 self.true_positives, 2266 metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: 2267 self.true_negatives, 2268 metrics_utils.ConfusionMatrix.FALSE_POSITIVES: 2269 self.false_positives, 2270 metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: 2271 self.false_negatives, 2272 }, 2273 y_true, 2274 y_pred, 2275 self._thresholds, 2276 thresholds_distributed_evenly=self._thresholds_distributed_evenly, 2277 sample_weight=sample_weight, 2278 multi_label=self.multi_label, 2279 label_weights=label_weights) 2280 2281 def interpolate_pr_auc(self): 2282 """Interpolation formula inspired by section 4 of Davis & Goadrich 2006. 2283 2284 https://www.biostat.wisc.edu/~page/rocpr.pdf 2285 2286 Note here we derive & use a closed formula not present in the paper 2287 as follows: 2288 2289 Precision = TP / (TP + FP) = TP / P 2290 2291 Modeling all of TP (true positive), FP (false positive) and their sum 2292 P = TP + FP (predicted positive) as varying linearly within each interval 2293 [A, B] between successive thresholds, we get 2294 2295 Precision slope = dTP / dP 2296 = (TP_B - TP_A) / (P_B - P_A) 2297 = (TP - TP_A) / (P - P_A) 2298 Precision = (TP_A + slope * (P - P_A)) / P 2299 2300 The area within the interval is (slope / total_pos_weight) times 2301 2302 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} 2303 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} 2304 2305 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in 2306 2307 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) 2308 2309 Bringing back the factor (slope / total_pos_weight) we'd put aside, we get 2310 2311 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight 2312 2313 where dTP == TP_B - TP_A. 2314 2315 Note that when P_A == 0 the above calculation simplifies into 2316 2317 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) 2318 2319 which is really equivalent to imputing constant precision throughout the 2320 first bucket having >0 true positives. 2321 2322 Returns: 2323 pr_auc: an approximation of the area under the P-R curve. 
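    A small numeric sketch of the per-bucket increment above, using hypothetical
    confusion-matrix counts for a single threshold interval [A, B]:

    ```python
    import math
    tp, fp, fn = [2.0, 1.0], [2.0, 0.0], [0.0, 1.0]  # hypothetical counts at A, B
    dtp = tp[0] - tp[1]                               # 1.0
    p = [tp[0] + fp[0], tp[1] + fp[1]]                # [4.0, 1.0]
    slope = dtp / (p[0] - p[1])                       # 1/3
    intercept = tp[1] - slope * p[1]                  # 2/3
    increment = slope * (dtp + intercept * math.log(p[0] / p[1])) / (tp[1] + fn[1])
    # increment ~= 0.32
    ```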
2324 """ 2325 dtp = self.true_positives[:self.num_thresholds - 2326 1] - self.true_positives[1:] 2327 p = self.true_positives + self.false_positives 2328 dp = p[:self.num_thresholds - 1] - p[1:] 2329 prec_slope = math_ops.div_no_nan( 2330 dtp, math_ops.maximum(dp, 0), name='prec_slope') 2331 intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:]) 2332 2333 safe_p_ratio = array_ops.where( 2334 math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0), 2335 math_ops.div_no_nan( 2336 p[:self.num_thresholds - 1], 2337 math_ops.maximum(p[1:], 0), 2338 name='recall_relative_ratio'), 2339 array_ops.ones_like(p[1:])) 2340 2341 pr_auc_increment = math_ops.div_no_nan( 2342 prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), 2343 math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:], 0), 2344 name='pr_auc_increment') 2345 2346 if self.multi_label: 2347 by_label_auc = math_ops.reduce_sum( 2348 pr_auc_increment, name=self.name + '_by_label', axis=0) 2349 if self.label_weights is None: 2350 # Evenly weighted average of the label AUCs. 2351 return math_ops.reduce_mean(by_label_auc, name=self.name) 2352 else: 2353 # Weighted average of the label AUCs. 2354 return math_ops.div_no_nan( 2355 math_ops.reduce_sum( 2356 math_ops.multiply(by_label_auc, self.label_weights)), 2357 math_ops.reduce_sum(self.label_weights), 2358 name=self.name) 2359 else: 2360 return math_ops.reduce_sum(pr_auc_increment, name='interpolate_pr_auc') 2361 2362 def result(self): 2363 if (self.curve == metrics_utils.AUCCurve.PR and 2364 self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION 2365 ): 2366 # This use case is different and is handled separately. 2367 return self.interpolate_pr_auc() 2368 2369 # Set `x` and `y` values for the curves based on `curve` config. 2370 recall = math_ops.div_no_nan(self.true_positives, 2371 self.true_positives + self.false_negatives) 2372 if self.curve == metrics_utils.AUCCurve.ROC: 2373 fp_rate = math_ops.div_no_nan(self.false_positives, 2374 self.false_positives + self.true_negatives) 2375 x = fp_rate 2376 y = recall 2377 else: # curve == 'PR'. 2378 precision = math_ops.div_no_nan( 2379 self.true_positives, self.true_positives + self.false_positives) 2380 x = recall 2381 y = precision 2382 2383 # Find the rectangle heights based on `summation_method`. 2384 if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION: 2385 # Note: the case ('PR', 'interpolation') has been handled above. 2386 heights = (y[:self.num_thresholds - 1] + y[1:]) / 2. 2387 elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING: 2388 heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:]) 2389 else: # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING: 2390 heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:]) 2391 2392 # Sum up the areas of all the rectangles. 2393 if self.multi_label: 2394 riemann_terms = math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], 2395 heights) 2396 by_label_auc = math_ops.reduce_sum( 2397 riemann_terms, name=self.name + '_by_label', axis=0) 2398 2399 if self.label_weights is None: 2400 # Unweighted average of the label AUCs. 2401 return math_ops.reduce_mean(by_label_auc, name=self.name) 2402 else: 2403 # Weighted average of the label AUCs. 
2404 return math_ops.div_no_nan( 2405 math_ops.reduce_sum( 2406 math_ops.multiply(by_label_auc, self.label_weights)), 2407 math_ops.reduce_sum(self.label_weights), 2408 name=self.name) 2409 else: 2410 return math_ops.reduce_sum( 2411 math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights), 2412 name=self.name) 2413 2414 def reset_state(self): 2415 if self._built: 2416 confusion_matrix_variables = (self.true_positives, self.true_negatives, 2417 self.false_positives, self.false_negatives) 2418 if self.multi_label: 2419 backend.batch_set_value( 2420 [(v, np.zeros((self.num_thresholds, self._num_labels))) 2421 for v in confusion_matrix_variables]) 2422 else: 2423 backend.batch_set_value([(v, np.zeros((self.num_thresholds,))) 2424 for v in confusion_matrix_variables]) 2425 2426 def get_config(self): 2427 if is_tensor_or_variable(self.label_weights): 2428 label_weights = backend.eval(self.label_weights) 2429 else: 2430 label_weights = self.label_weights 2431 config = { 2432 'num_thresholds': self.num_thresholds, 2433 'curve': self.curve.value, 2434 'summation_method': self.summation_method.value, 2435 # We remove the endpoint thresholds as an inverse of how the thresholds 2436 # were initialized. This ensures that a metric initialized from this 2437 # config has the same thresholds. 2438 'thresholds': self.thresholds[1:-1], 2439 'multi_label': self.multi_label, 2440 'label_weights': label_weights 2441 } 2442 base_config = super(AUC, self).get_config() 2443 return dict(list(base_config.items()) + list(config.items())) 2444 2445 2446@keras_export('keras.metrics.CosineSimilarity') 2447class CosineSimilarity(MeanMetricWrapper): 2448 """Computes the cosine similarity between the labels and predictions. 2449 2450 `cosine similarity = (a . b) / ||a|| ||b||` 2451 2452 See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity). 2453 2454 This metric keeps the average cosine similarity between `predictions` and 2455 `labels` over a stream of data. 2456 2457 Args: 2458 name: (Optional) string name of the metric instance. 2459 dtype: (Optional) data type of the metric result. 2460 axis: (Optional) Defaults to -1. The dimension along which the cosine 2461 similarity is computed. 2462 2463 Standalone usage: 2464 2465 >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]] 2466 >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]] 2467 >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] 2468 >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) 2469 >>> # = ((0. + 0.) + (0.5 + 0.5)) / 2 2470 >>> m = tf.keras.metrics.CosineSimilarity(axis=1) 2471 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]]) 2472 >>> m.result().numpy() 2473 0.49999997 2474 2475 >>> m.reset_state() 2476 >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]], 2477 ... sample_weight=[0.3, 0.7]) 2478 >>> m.result().numpy() 2479 0.6999999 2480 2481 Usage with `compile()` API: 2482 2483 ```python 2484 model.compile( 2485 optimizer='sgd', 2486 loss='mse', 2487 metrics=[tf.keras.metrics.CosineSimilarity(axis=1)]) 2488 ``` 2489 """ 2490 2491 def __init__(self, name='cosine_similarity', dtype=None, axis=-1): 2492 super(CosineSimilarity, self).__init__( 2493 cosine_similarity, name, dtype=dtype, axis=axis) 2494 2495 2496@keras_export('keras.metrics.MeanAbsoluteError') 2497class MeanAbsoluteError(MeanMetricWrapper): 2498 """Computes the mean absolute error between the labels and predictions. 2499 2500 Args: 2501 name: (Optional) string name of the metric instance. 
2502 dtype: (Optional) data type of the metric result. 2503 2504 Standalone usage: 2505 2506 >>> m = tf.keras.metrics.MeanAbsoluteError() 2507 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2508 >>> m.result().numpy() 2509 0.25 2510 2511 >>> m.reset_state() 2512 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2513 ... sample_weight=[1, 0]) 2514 >>> m.result().numpy() 2515 0.5 2516 2517 Usage with `compile()` API: 2518 2519 ```python 2520 model.compile( 2521 optimizer='sgd', 2522 loss='mse', 2523 metrics=[tf.keras.metrics.MeanAbsoluteError()]) 2524 ``` 2525 """ 2526 2527 def __init__(self, name='mean_absolute_error', dtype=None): 2528 super(MeanAbsoluteError, self).__init__( 2529 mean_absolute_error, name, dtype=dtype) 2530 2531 2532@keras_export('keras.metrics.MeanAbsolutePercentageError') 2533class MeanAbsolutePercentageError(MeanMetricWrapper): 2534 """Computes the mean absolute percentage error between `y_true` and `y_pred`. 2535 2536 Args: 2537 name: (Optional) string name of the metric instance. 2538 dtype: (Optional) data type of the metric result. 2539 2540 Standalone usage: 2541 2542 >>> m = tf.keras.metrics.MeanAbsolutePercentageError() 2543 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2544 >>> m.result().numpy() 2545 250000000.0 2546 2547 >>> m.reset_state() 2548 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2549 ... sample_weight=[1, 0]) 2550 >>> m.result().numpy() 2551 500000000.0 2552 2553 Usage with `compile()` API: 2554 2555 ```python 2556 model.compile( 2557 optimizer='sgd', 2558 loss='mse', 2559 metrics=[tf.keras.metrics.MeanAbsolutePercentageError()]) 2560 ``` 2561 """ 2562 2563 def __init__(self, name='mean_absolute_percentage_error', dtype=None): 2564 super(MeanAbsolutePercentageError, self).__init__( 2565 mean_absolute_percentage_error, name, dtype=dtype) 2566 2567 2568@keras_export('keras.metrics.MeanSquaredError') 2569class MeanSquaredError(MeanMetricWrapper): 2570 """Computes the mean squared error between `y_true` and `y_pred`. 2571 2572 Args: 2573 name: (Optional) string name of the metric instance. 2574 dtype: (Optional) data type of the metric result. 2575 2576 Standalone usage: 2577 2578 >>> m = tf.keras.metrics.MeanSquaredError() 2579 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2580 >>> m.result().numpy() 2581 0.25 2582 2583 >>> m.reset_state() 2584 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2585 ... sample_weight=[1, 0]) 2586 >>> m.result().numpy() 2587 0.5 2588 2589 Usage with `compile()` API: 2590 2591 ```python 2592 model.compile( 2593 optimizer='sgd', 2594 loss='mse', 2595 metrics=[tf.keras.metrics.MeanSquaredError()]) 2596 ``` 2597 """ 2598 2599 def __init__(self, name='mean_squared_error', dtype=None): 2600 super(MeanSquaredError, self).__init__( 2601 mean_squared_error, name, dtype=dtype) 2602 2603 2604@keras_export('keras.metrics.MeanSquaredLogarithmicError') 2605class MeanSquaredLogarithmicError(MeanMetricWrapper): 2606 """Computes the mean squared logarithmic error between `y_true` and `y_pred`. 2607 2608 Args: 2609 name: (Optional) string name of the metric instance. 2610 dtype: (Optional) data type of the metric result. 2611 2612 Standalone usage: 2613 2614 >>> m = tf.keras.metrics.MeanSquaredLogarithmicError() 2615 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2616 >>> m.result().numpy() 2617 0.12011322 2618 2619 >>> m.reset_state() 2620 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2621 ... 
sample_weight=[1, 0]) 2622 >>> m.result().numpy() 2623 0.24022643 2624 2625 Usage with `compile()` API: 2626 2627 ```python 2628 model.compile( 2629 optimizer='sgd', 2630 loss='mse', 2631 metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()]) 2632 ``` 2633 """ 2634 2635 def __init__(self, name='mean_squared_logarithmic_error', dtype=None): 2636 super(MeanSquaredLogarithmicError, self).__init__( 2637 mean_squared_logarithmic_error, name, dtype=dtype) 2638 2639 2640@keras_export('keras.metrics.Hinge') 2641class Hinge(MeanMetricWrapper): 2642 """Computes the hinge metric between `y_true` and `y_pred`. 2643 2644 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 2645 provided we will convert them to -1 or 1. 2646 2647 Args: 2648 name: (Optional) string name of the metric instance. 2649 dtype: (Optional) data type of the metric result. 2650 2651 Standalone usage: 2652 2653 >>> m = tf.keras.metrics.Hinge() 2654 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2655 >>> m.result().numpy() 2656 1.3 2657 2658 >>> m.reset_state() 2659 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2660 ... sample_weight=[1, 0]) 2661 >>> m.result().numpy() 2662 1.1 2663 2664 Usage with `compile()` API: 2665 2666 ```python 2667 model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()]) 2668 ``` 2669 """ 2670 2671 def __init__(self, name='hinge', dtype=None): 2672 super(Hinge, self).__init__(hinge, name, dtype=dtype) 2673 2674 2675@keras_export('keras.metrics.SquaredHinge') 2676class SquaredHinge(MeanMetricWrapper): 2677 """Computes the squared hinge metric between `y_true` and `y_pred`. 2678 2679 `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are 2680 provided we will convert them to -1 or 1. 2681 2682 Args: 2683 name: (Optional) string name of the metric instance. 2684 dtype: (Optional) data type of the metric result. 2685 2686 Standalone usage: 2687 2688 >>> m = tf.keras.metrics.SquaredHinge() 2689 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2690 >>> m.result().numpy() 2691 1.86 2692 2693 >>> m.reset_state() 2694 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2695 ... sample_weight=[1, 0]) 2696 >>> m.result().numpy() 2697 1.46 2698 2699 Usage with `compile()` API: 2700 2701 ```python 2702 model.compile( 2703 optimizer='sgd', 2704 loss='mse', 2705 metrics=[tf.keras.metrics.SquaredHinge()]) 2706 ``` 2707 """ 2708 2709 def __init__(self, name='squared_hinge', dtype=None): 2710 super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype) 2711 2712 2713@keras_export('keras.metrics.CategoricalHinge') 2714class CategoricalHinge(MeanMetricWrapper): 2715 """Computes the categorical hinge metric between `y_true` and `y_pred`. 2716 2717 Args: 2718 name: (Optional) string name of the metric instance. 2719 dtype: (Optional) data type of the metric result. 2720 2721 Standalone usage: 2722 2723 >>> m = tf.keras.metrics.CategoricalHinge() 2724 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2725 >>> m.result().numpy() 2726 1.4000001 2727 2728 >>> m.reset_state() 2729 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2730 ... 
sample_weight=[1, 0]) 2731 >>> m.result().numpy() 2732 1.2 2733 2734 Usage with `compile()` API: 2735 2736 ```python 2737 model.compile( 2738 optimizer='sgd', 2739 loss='mse', 2740 metrics=[tf.keras.metrics.CategoricalHinge()]) 2741 ``` 2742 """ 2743 2744 def __init__(self, name='categorical_hinge', dtype=None): 2745 super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype) 2746 2747 2748@keras_export('keras.metrics.RootMeanSquaredError') 2749class RootMeanSquaredError(Mean): 2750 """Computes root mean squared error metric between `y_true` and `y_pred`. 2751 2752 Standalone usage: 2753 2754 >>> m = tf.keras.metrics.RootMeanSquaredError() 2755 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2756 >>> m.result().numpy() 2757 0.5 2758 2759 >>> m.reset_state() 2760 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2761 ... sample_weight=[1, 0]) 2762 >>> m.result().numpy() 2763 0.70710677 2764 2765 Usage with `compile()` API: 2766 2767 ```python 2768 model.compile( 2769 optimizer='sgd', 2770 loss='mse', 2771 metrics=[tf.keras.metrics.RootMeanSquaredError()]) 2772 ``` 2773 """ 2774 2775 def __init__(self, name='root_mean_squared_error', dtype=None): 2776 super(RootMeanSquaredError, self).__init__(name, dtype=dtype) 2777 2778 def update_state(self, y_true, y_pred, sample_weight=None): 2779 """Accumulates root mean squared error statistics. 2780 2781 Args: 2782 y_true: The ground truth values. 2783 y_pred: The predicted values. 2784 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2785 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2786 be broadcastable to `y_true`. 2787 2788 Returns: 2789 Update op. 2790 """ 2791 y_true = math_ops.cast(y_true, self._dtype) 2792 y_pred = math_ops.cast(y_pred, self._dtype) 2793 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 2794 y_pred, y_true) 2795 error_sq = math_ops.squared_difference(y_pred, y_true) 2796 return super(RootMeanSquaredError, self).update_state( 2797 error_sq, sample_weight=sample_weight) 2798 2799 def result(self): 2800 return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count)) 2801 2802 2803@keras_export('keras.metrics.LogCoshError') 2804class LogCoshError(MeanMetricWrapper): 2805 """Computes the logarithm of the hyperbolic cosine of the prediction error. 2806 2807 `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true) 2808 2809 Args: 2810 name: (Optional) string name of the metric instance. 2811 dtype: (Optional) data type of the metric result. 2812 2813 Standalone usage: 2814 2815 >>> m = tf.keras.metrics.LogCoshError() 2816 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2817 >>> m.result().numpy() 2818 0.10844523 2819 2820 >>> m.reset_state() 2821 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2822 ... sample_weight=[1, 0]) 2823 >>> m.result().numpy() 2824 0.21689045 2825 2826 Usage with `compile()` API: 2827 2828 ```python 2829 model.compile(optimizer='sgd', 2830 loss='mse', 2831 metrics=[tf.keras.metrics.LogCoshError()]) 2832 ``` 2833 """ 2834 2835 def __init__(self, name='logcosh', dtype=None): 2836 super(LogCoshError, self).__init__(logcosh, name, dtype=dtype) 2837 2838 2839@keras_export('keras.metrics.Poisson') 2840class Poisson(MeanMetricWrapper): 2841 """Computes the Poisson metric between `y_true` and `y_pred`. 2842 2843 `metric = y_pred - y_true * log(y_pred)` 2844 2845 Args: 2846 name: (Optional) string name of the metric instance. 2847 dtype: (Optional) data type of the metric result. 
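  As an illustrative hand computation (mirroring the first standalone example
  below; the implementation adds a small epsilon inside the log for numerical
  stability):

  ```python
  # y_true = [[0, 1], [0, 0]], y_pred = [[1, 1], [0, 0]]
  # per element: y_pred - y_true * log(y_pred + epsilon)
  # row 1 -> [1., 1.] (mean 1.0); row 2 -> [0., 0.] (mean 0.0)
  # unweighted result = (1.0 + 0.0) / 2 = 0.5
  ```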
2848 2849 Standalone usage: 2850 2851 >>> m = tf.keras.metrics.Poisson() 2852 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]]) 2853 >>> m.result().numpy() 2854 0.49999997 2855 2856 >>> m.reset_state() 2857 >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]], 2858 ... sample_weight=[1, 0]) 2859 >>> m.result().numpy() 2860 0.99999994 2861 2862 Usage with `compile()` API: 2863 2864 ```python 2865 model.compile(optimizer='sgd', 2866 loss='mse', 2867 metrics=[tf.keras.metrics.Poisson()]) 2868 ``` 2869 """ 2870 2871 def __init__(self, name='poisson', dtype=None): 2872 super(Poisson, self).__init__(poisson, name, dtype=dtype) 2873 2874 2875@keras_export('keras.metrics.KLDivergence') 2876class KLDivergence(MeanMetricWrapper): 2877 """Computes Kullback-Leibler divergence metric between `y_true` and `y_pred`. 2878 2879 `metric = y_true * log(y_true / y_pred)` 2880 2881 Args: 2882 name: (Optional) string name of the metric instance. 2883 dtype: (Optional) data type of the metric result. 2884 2885 Standalone usage: 2886 2887 >>> m = tf.keras.metrics.KLDivergence() 2888 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 2889 >>> m.result().numpy() 2890 0.45814306 2891 2892 >>> m.reset_state() 2893 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 2894 ... sample_weight=[1, 0]) 2895 >>> m.result().numpy() 2896 0.9162892 2897 2898 Usage with `compile()` API: 2899 2900 ```python 2901 model.compile(optimizer='sgd', 2902 loss='mse', 2903 metrics=[tf.keras.metrics.KLDivergence()]) 2904 ``` 2905 """ 2906 2907 def __init__(self, name='kullback_leibler_divergence', dtype=None): 2908 super(KLDivergence, self).__init__( 2909 kullback_leibler_divergence, name, dtype=dtype) 2910 2911 2912@keras_export('keras.metrics.MeanIoU') 2913class MeanIoU(Metric): 2914 """Computes the mean Intersection-Over-Union metric. 2915 2916 Mean Intersection-Over-Union is a common evaluation metric for semantic image 2917 segmentation, which first computes the IOU for each semantic class and then 2918 computes the average over classes. IOU is defined as follows: 2919 IOU = true_positive / (true_positive + false_positive + false_negative). 2920 The predictions are accumulated in a confusion matrix, weighted by 2921 `sample_weight` and the metric is then calculated from it. 2922 2923 If `sample_weight` is `None`, weights default to 1. 2924 Use `sample_weight` of 0 to mask values. 2925 2926 Args: 2927 num_classes: The possible number of labels the prediction task can have. 2928 This value must be provided, since a confusion matrix of dimension = 2929 [num_classes, num_classes] will be allocated. 2930 name: (Optional) string name of the metric instance. 2931 dtype: (Optional) data type of the metric result. 2932 2933 Standalone usage: 2934 2935 >>> # cm = [[1, 1], 2936 >>> # [1, 1]] 2937 >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1] 2938 >>> # iou = true_positives / (sum_row + sum_col - true_positives)) 2939 >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33 2940 >>> m = tf.keras.metrics.MeanIoU(num_classes=2) 2941 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1]) 2942 >>> m.result().numpy() 2943 0.33333334 2944 2945 >>> m.reset_state() 2946 >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1], 2947 ... 
sample_weight=[0.3, 0.3, 0.3, 0.1]) 2948 >>> m.result().numpy() 2949 0.23809525 2950 2951 Usage with `compile()` API: 2952 2953 ```python 2954 model.compile( 2955 optimizer='sgd', 2956 loss='mse', 2957 metrics=[tf.keras.metrics.MeanIoU(num_classes=2)]) 2958 ``` 2959 """ 2960 2961 def __init__(self, num_classes, name=None, dtype=None): 2962 super(MeanIoU, self).__init__(name=name, dtype=dtype) 2963 self.num_classes = num_classes 2964 2965 # Variable to accumulate the predictions in the confusion matrix. 2966 self.total_cm = self.add_weight( 2967 'total_confusion_matrix', 2968 shape=(num_classes, num_classes), 2969 initializer=init_ops.zeros_initializer) 2970 2971 def update_state(self, y_true, y_pred, sample_weight=None): 2972 """Accumulates the confusion matrix statistics. 2973 2974 Args: 2975 y_true: The ground truth values. 2976 y_pred: The predicted values. 2977 sample_weight: Optional weighting of each example. Defaults to 1. Can be a 2978 `Tensor` whose rank is either 0, or the same rank as `y_true`, and must 2979 be broadcastable to `y_true`. 2980 2981 Returns: 2982 Update op. 2983 """ 2984 2985 y_true = math_ops.cast(y_true, self._dtype) 2986 y_pred = math_ops.cast(y_pred, self._dtype) 2987 2988 # Flatten the input if its rank > 1. 2989 if y_pred.shape.ndims > 1: 2990 y_pred = array_ops.reshape(y_pred, [-1]) 2991 2992 if y_true.shape.ndims > 1: 2993 y_true = array_ops.reshape(y_true, [-1]) 2994 2995 if sample_weight is not None: 2996 sample_weight = math_ops.cast(sample_weight, self._dtype) 2997 if sample_weight.shape.ndims > 1: 2998 sample_weight = array_ops.reshape(sample_weight, [-1]) 2999 3000 # Accumulate the prediction to current confusion matrix. 3001 current_cm = confusion_matrix.confusion_matrix( 3002 y_true, 3003 y_pred, 3004 self.num_classes, 3005 weights=sample_weight, 3006 dtype=self._dtype) 3007 return self.total_cm.assign_add(current_cm) 3008 3009 def result(self): 3010 """Compute the mean intersection-over-union via the confusion matrix.""" 3011 sum_over_row = math_ops.cast( 3012 math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype) 3013 sum_over_col = math_ops.cast( 3014 math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype) 3015 true_positives = math_ops.cast( 3016 array_ops.tensor_diag_part(self.total_cm), dtype=self._dtype) 3017 3018 # sum_over_row + sum_over_col = 3019 # 2 * true_positives + false_positives + false_negatives. 3020 denominator = sum_over_row + sum_over_col - true_positives 3021 3022 # The mean is only computed over classes that appear in the 3023 # label or prediction tensor. If the denominator is 0, we need to 3024 # ignore the class. 3025 num_valid_entries = math_ops.reduce_sum( 3026 math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype)) 3027 3028 iou = math_ops.div_no_nan(true_positives, denominator) 3029 3030 return math_ops.div_no_nan( 3031 math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries) 3032 3033 def reset_state(self): 3034 backend.set_value( 3035 self.total_cm, np.zeros((self.num_classes, self.num_classes))) 3036 3037 def get_config(self): 3038 config = {'num_classes': self.num_classes} 3039 base_config = super(MeanIoU, self).get_config() 3040 return dict(list(base_config.items()) + list(config.items())) 3041 3042 3043@keras_export('keras.metrics.MeanTensor') 3044class MeanTensor(Metric): 3045 """Computes the element-wise (weighted) mean of the given tensors. 3046 3047 `MeanTensor` returns a tensor with the same shape of the input tensors. 
The 3048 mean value is updated by keeping local variables `total` and `count`. The 3049 `total` tracks the sum of the weighted values, and `count` stores the sum of 3050 the weighted counts. 3051 3052 Args: 3053 name: (Optional) string name of the metric instance. 3054 dtype: (Optional) data type of the metric result. 3055 shape: (Optional) A list of integers, a tuple of integers, or a 1-D Tensor 3056 of type int32. If not specified, the shape is inferred from the values at 3057 the first call of update_state. 3058 3059 Standalone usage: 3060 3061 >>> m = tf.keras.metrics.MeanTensor() 3062 >>> m.update_state([0, 1, 2, 3]) 3063 >>> m.update_state([4, 5, 6, 7]) 3064 >>> m.result().numpy() 3065 array([2., 3., 4., 5.], dtype=float32) 3066 3067 >>> m.update_state([12, 10, 8, 6], sample_weight= [0, 0.2, 0.5, 1]) 3068 >>> m.result().numpy() 3069 array([2. , 3.6363635, 4.8 , 5.3333335], dtype=float32) 3070 3071 >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4)) 3072 >>> m.result().numpy() 3073 array([[0., 0., 0., 0.]]) 3074 >>> m.update_state([[0, 1, 2, 3]]) 3075 >>> m.update_state([[4, 5, 6, 7]]) 3076 >>> m.result().numpy() 3077 array([[2., 3., 4., 5.]]) 3078 """ 3079 3080 def __init__(self, name='mean_tensor', dtype=None, shape=None): 3081 super(MeanTensor, self).__init__(name=name, dtype=dtype) 3082 self._shape = None 3083 self._total = None 3084 self._count = None 3085 self._built = False 3086 if shape is not None: 3087 self._build(shape) 3088 3089 def _build(self, shape): 3090 self._shape = tensor_shape.TensorShape(shape) 3091 self._build_input_shape = self._shape 3092 # Create new state variables 3093 self._total = self.add_weight( 3094 'total', shape=shape, initializer=init_ops.zeros_initializer) 3095 self._count = self.add_weight( 3096 'count', shape=shape, initializer=init_ops.zeros_initializer) 3097 with ops.init_scope(): 3098 if not context.executing_eagerly(): 3099 backend._initialize_variables(backend._get_session()) # pylint: disable=protected-access 3100 self._built = True 3101 3102 @property 3103 def total(self): 3104 return self._total if self._built else None 3105 3106 @property 3107 def count(self): 3108 return self._count if self._built else None 3109 3110 def update_state(self, values, sample_weight=None): 3111 """Accumulates statistics for computing the element-wise mean. 3112 3113 Args: 3114 values: Per-example value. 3115 sample_weight: Optional weighting of each example. Defaults to 1. 3116 3117 Returns: 3118 Update op. 3119 """ 3120 values = math_ops.cast(values, self._dtype) 3121 if not self._built: 3122 self._build(values.shape) 3123 elif values.shape != self._shape: 3124 raise ValueError('MeanTensor input values must always have the same ' 3125 'shape. Expected shape (set during the first call): {}. ' 3126 'Got: {}'.format(self._shape, values.shape)) 3127 3128 num_values = array_ops.ones_like(values) 3129 if sample_weight is not None: 3130 sample_weight = math_ops.cast(sample_weight, self._dtype) 3131 3132 # Update dimensions of weights to match with values if possible. 3133 values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( 3134 values, sample_weight=sample_weight) 3135 try: 3136 # Broadcast weights if possible. 
3137 sample_weight = weights_broadcast_ops.broadcast_weights( 3138 sample_weight, values) 3139 except ValueError: 3140 # Reduce values to same ndim as weight array 3141 ndim = backend.ndim(values) 3142 weight_ndim = backend.ndim(sample_weight) 3143 values = math_ops.reduce_mean( 3144 values, axis=list(range(weight_ndim, ndim))) 3145 3146 num_values = math_ops.multiply(num_values, sample_weight) 3147 values = math_ops.multiply(values, sample_weight) 3148 3149 update_total_op = self._total.assign_add(values) 3150 with ops.control_dependencies([update_total_op]): 3151 return self._count.assign_add(num_values) 3152 3153 def result(self): 3154 if not self._built: 3155 raise ValueError( 3156 'MeanTensor does not have any result yet. Please call the MeanTensor ' 3157 'instance or use `.update_state(value)` before retrieving the result.' 3158 ) 3159 return math_ops.div_no_nan(self.total, self.count) 3160 3161 def reset_state(self): 3162 if self._built: 3163 backend.batch_set_value( 3164 [(v, np.zeros(self._shape.as_list())) for v in self.variables]) 3165 3166 3167@keras_export('keras.metrics.BinaryCrossentropy') 3168class BinaryCrossentropy(MeanMetricWrapper): 3169 """Computes the crossentropy metric between the labels and predictions. 3170 3171 This is the crossentropy metric class to be used when there are only two 3172 label classes (0 and 1). 3173 3174 Args: 3175 name: (Optional) string name of the metric instance. 3176 dtype: (Optional) data type of the metric result. 3177 from_logits: (Optional) Whether output is expected to be a logits tensor. 3178 By default, we consider that output encodes a probability distribution. 3179 label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are 3180 smoothed, meaning the confidence on label values is relaxed. 3181 e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for 3182 label `0` and `0.9` for label `1`. 3183 3184 Standalone usage: 3185 3186 >>> m = tf.keras.metrics.BinaryCrossentropy() 3187 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]]) 3188 >>> m.result().numpy() 3189 0.81492424 3190 3191 >>> m.reset_state() 3192 >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]], 3193 ... sample_weight=[1, 0]) 3194 >>> m.result().numpy() 3195 0.9162905 3196 3197 Usage with `compile()` API: 3198 3199 ```python 3200 model.compile( 3201 optimizer='sgd', 3202 loss='mse', 3203 metrics=[tf.keras.metrics.BinaryCrossentropy()]) 3204 ``` 3205 """ 3206 3207 def __init__(self, 3208 name='binary_crossentropy', 3209 dtype=None, 3210 from_logits=False, 3211 label_smoothing=0): 3212 super(BinaryCrossentropy, self).__init__( 3213 binary_crossentropy, 3214 name, 3215 dtype=dtype, 3216 from_logits=from_logits, 3217 label_smoothing=label_smoothing) 3218 3219 3220@keras_export('keras.metrics.CategoricalCrossentropy') 3221class CategoricalCrossentropy(MeanMetricWrapper): 3222 """Computes the crossentropy metric between the labels and predictions. 3223 3224 This is the crossentropy metric class to be used when there are multiple 3225 label classes (2 or more). Here we assume that labels are given as a `one_hot` 3226 representation. e.g., when label values are [2, 0, 1], 3227 `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]. 3228 3229 Args: 3230 name: (Optional) string name of the metric instance. 3231 dtype: (Optional) data type of the metric result. 3232 from_logits: (Optional) Whether output is expected to be a logits tensor. 3233 By default, we consider that output encodes a probability distribution.


@keras_export('keras.metrics.CategoricalCrossentropy')
class CategoricalCrossentropy(MeanMetricWrapper):
  """Computes the crossentropy metric between the labels and predictions.

  This is the crossentropy metric class to be used when there are multiple
  label classes (2 or more). Here we assume that labels are given as a
  `one_hot` representation. e.g., when the label values are [2, 0, 1],
  `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]].

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    from_logits: (Optional) Whether output is expected to be a logits tensor.
      By default, we consider that output encodes a probability distribution.
    label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
      smoothed, meaning the confidence on label values is relaxed. e.g.
      `label_smoothing=0.2` means that we will use a value of `0.1` for label
      `0` and `0.9` for label `1`.

  Standalone usage:

  >>> # EPSILON = 1e-7, y = y_true, y` = y_pred
  >>> # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
  >>> # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
  >>> # xent = -sum(y * log(y'), axis = -1)
  >>> #      = -((log 0.95), (log 0.1))
  >>> #      = [0.051, 2.302]
  >>> # Reduced xent = (0.051 + 2.302) / 2
  >>> m = tf.keras.metrics.CategoricalCrossentropy()
  >>> m.update_state([[0, 1, 0], [0, 0, 1]],
  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
  >>> m.result().numpy()
  1.1769392

  >>> m.reset_state()
  >>> m.update_state([[0, 1, 0], [0, 0, 1]],
  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
  ...                sample_weight=tf.constant([0.3, 0.7]))
  >>> m.result().numpy()
  1.6271976

  Usage with `compile()` API:

  ```python
  model.compile(
      optimizer='sgd',
      loss='mse',
      metrics=[tf.keras.metrics.CategoricalCrossentropy()])
  ```
  """

  def __init__(self,
               name='categorical_crossentropy',
               dtype=None,
               from_logits=False,
               label_smoothing=0):
    super(CategoricalCrossentropy, self).__init__(
        categorical_crossentropy,
        name,
        dtype=dtype,
        from_logits=from_logits,
        label_smoothing=label_smoothing)
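

# A numpy re-derivation sketch (helper name hypothetical, not part of the API)
# of the unweighted example in the docstring above: clip the predictions, take
# -sum(y_true * log(y_pred), axis=-1) per example, then average over the batch.
def _categorical_crossentropy_by_hand_sketch():
  epsilon = 1e-7
  y_true = np.array([[0., 1., 0.], [0., 0., 1.]])
  y_pred = np.clip(np.array([[0.05, 0.95, 0.], [0.1, 0.8, 0.1]]),
                   epsilon, 1. - epsilon)
  xent = -np.sum(y_true * np.log(y_pred), axis=-1)  # ~[0.0513, 2.3026]
  return xent.mean()  # ~1.1769, matching the docstring value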


@keras_export('keras.metrics.SparseCategoricalCrossentropy')
class SparseCategoricalCrossentropy(MeanMetricWrapper):
  """Computes the crossentropy metric between the labels and predictions.

  Use this crossentropy metric when there are two or more label classes.
  We expect labels to be provided as integers. If you want to provide labels
  using `one-hot` representation, please use `CategoricalCrossentropy` metric.
  There should be `# classes` floating point values per feature for `y_pred`
  and a single floating point value per feature for `y_true`.

  In the snippet below, there is a single floating point value per example for
  `y_true` and `# classes` floating point values per example for `y_pred`.
  The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
  `[batch_size, num_classes]`.

  Args:
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.
    from_logits: (Optional) Whether output is expected to be a logits tensor.
      By default, we consider that output encodes a probability distribution.
    axis: (Optional) Defaults to -1. The dimension along which the metric is
      computed.

  Standalone usage:

  >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
  >>> # logits = log(y_pred)
  >>> # softmax = exp(logits) / sum(exp(logits), axis=-1)
  >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
  >>> # xent = -sum(y * log(softmax), 1)
  >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181],
  >>> #                 [-2.3026, -0.2231, -2.3026]]
  >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]
  >>> # xent = [0.0513, 2.3026]
  >>> # Reduced xent = (0.0513 + 2.3026) / 2
  >>> m = tf.keras.metrics.SparseCategoricalCrossentropy()
  >>> m.update_state([1, 2],
  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
  >>> m.result().numpy()
  1.1769392

  >>> m.reset_state()
  >>> m.update_state([1, 2],
  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
  ...                sample_weight=tf.constant([0.3, 0.7]))
  >>> m.result().numpy()
  1.6271976

  Usage with `compile()` API:

  ```python
  model.compile(
      optimizer='sgd',
      loss='mse',
      metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()])
  ```
  """

  def __init__(self,
               name='sparse_categorical_crossentropy',
               dtype=None,
               from_logits=False,
               axis=-1):
    super(SparseCategoricalCrossentropy, self).__init__(
        sparse_categorical_crossentropy,
        name,
        dtype=dtype,
        from_logits=from_logits,
        axis=axis)
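

# A minimal sketch (helper name hypothetical) of the relationship spelled out
# above: integer labels passed to `SparseCategoricalCrossentropy` and their
# one-hot encoding passed to `CategoricalCrossentropy` are expected to yield
# the same metric value.
def _sparse_vs_one_hot_crossentropy_sketch():
  y_pred = [[0.05, 0.95, 0.], [0.1, 0.8, 0.1]]
  sparse = SparseCategoricalCrossentropy()
  sparse.update_state([1, 2], y_pred)
  dense = CategoricalCrossentropy()
  dense.update_state([[0, 1, 0], [0, 0, 1]], y_pred)
  # Both results should be ~1.1769392 when executed eagerly.
  return sparse.result(), dense.result()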
3389 """ 3390 super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype) 3391 self._fn = fn 3392 self._fn_kwargs = kwargs 3393 3394 def update_state(self, y_true, y_pred, sample_weight=None): 3395 y_true = math_ops.cast(y_true, self._dtype) 3396 y_pred = math_ops.cast(y_pred, self._dtype) 3397 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 3398 y_pred, y_true) 3399 3400 ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx()) 3401 matches = ag_fn(y_true, y_pred, **self._fn_kwargs) 3402 return super(SumOverBatchSizeMetricWrapper, self).update_state( 3403 matches, sample_weight=sample_weight) 3404 3405 def get_config(self): 3406 config = {} 3407 for k, v in self._fn_kwargs.items(): 3408 config[k] = backend.eval(v) if is_tensor_or_variable(v) else v 3409 base_config = super(SumOverBatchSizeMetricWrapper, self).get_config() 3410 return dict(list(base_config.items()) + list(config.items())) 3411 3412 3413def accuracy(y_true, y_pred): 3414 [y_pred, y_true], _ = \ 3415 metrics_utils.ragged_assert_compatible_and_get_flat_values( 3416 [y_pred, y_true]) 3417 y_true.shape.assert_is_compatible_with(y_pred.shape) 3418 if y_true.dtype != y_pred.dtype: 3419 y_pred = math_ops.cast(y_pred, y_true.dtype) 3420 return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx()) 3421 3422 3423@keras_export('keras.metrics.binary_accuracy') 3424@dispatch.add_dispatch_support 3425def binary_accuracy(y_true, y_pred, threshold=0.5): 3426 """Calculates how often predictions match binary labels. 3427 3428 Standalone usage: 3429 >>> y_true = [[1], [1], [0], [0]] 3430 >>> y_pred = [[1], [1], [0], [0]] 3431 >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred) 3432 >>> assert m.shape == (4,) 3433 >>> m.numpy() 3434 array([1., 1., 1., 1.], dtype=float32) 3435 3436 Args: 3437 y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`. 3438 y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`. 3439 threshold: (Optional) Float representing the threshold for deciding whether 3440 prediction values are 1 or 0. 3441 3442 Returns: 3443 Binary accuracy values. shape = `[batch_size, d0, .. dN-1]` 3444 """ 3445 y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred) 3446 threshold = math_ops.cast(threshold, y_pred.dtype) 3447 y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype) 3448 return backend.mean(math_ops.equal(y_true, y_pred), axis=-1) 3449 3450 3451@keras_export('keras.metrics.categorical_accuracy') 3452@dispatch.add_dispatch_support 3453def categorical_accuracy(y_true, y_pred): 3454 """Calculates how often predictions match one-hot labels. 3455 3456 Standalone usage: 3457 >>> y_true = [[0, 0, 1], [0, 1, 0]] 3458 >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] 3459 >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred) 3460 >>> assert m.shape == (2,) 3461 >>> m.numpy() 3462 array([0., 1.], dtype=float32) 3463 3464 You can provide logits of classes as `y_pred`, since argmax of 3465 logits and probabilities are same. 3466 3467 Args: 3468 y_true: One-hot ground truth values. 3469 y_pred: The prediction values. 3470 3471 Returns: 3472 Categorical accuracy values. 3473 """ 3474 return math_ops.cast( 3475 math_ops.equal( 3476 math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)), 3477 backend.floatx()) 3478 3479 3480@keras_export('keras.metrics.sparse_categorical_accuracy') 3481@dispatch.add_dispatch_support 3482def sparse_categorical_accuracy(y_true, y_pred): 3483 """Calculates how often predictions match integer labels. 


@keras_export('keras.metrics.categorical_accuracy')
@dispatch.add_dispatch_support
def categorical_accuracy(y_true, y_pred):
  """Calculates how often predictions match one-hot labels.

  Standalone usage:
  >>> y_true = [[0, 0, 1], [0, 1, 0]]
  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
  >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred)
  >>> assert m.shape == (2,)
  >>> m.numpy()
  array([0., 1.], dtype=float32)

  You can provide logits of classes as `y_pred`, since argmax of
  logits and probabilities are the same.

  Args:
    y_true: One-hot ground truth values.
    y_pred: The prediction values.

  Returns:
    Categorical accuracy values.
  """
  return math_ops.cast(
      math_ops.equal(
          math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
      backend.floatx())


@keras_export('keras.metrics.sparse_categorical_accuracy')
@dispatch.add_dispatch_support
def sparse_categorical_accuracy(y_true, y_pred):
  """Calculates how often predictions match integer labels.

  Standalone usage:
  >>> y_true = [2, 1]
  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
  >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
  >>> assert m.shape == (2,)
  >>> m.numpy()
  array([0., 1.], dtype=float32)

  You can provide logits of classes as `y_pred`, since argmax of
  logits and probabilities are the same.

  Args:
    y_true: Integer ground truth values.
    y_pred: The prediction values.

  Returns:
    Sparse categorical accuracy values.
  """
  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
  y_true = ops.convert_to_tensor_v2_with_dispatch(y_true)
  y_pred_rank = y_pred.shape.ndims
  y_true_rank = y_true.shape.ndims
  # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
  if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
      backend.int_shape(y_true)) == len(backend.int_shape(y_pred))):
    y_true = array_ops.squeeze(y_true, [-1])
  y_pred = math_ops.argmax(y_pred, axis=-1)

  # If the predicted output and actual output types don't match, force cast
  # them to match.
  if backend.dtype(y_pred) != backend.dtype(y_true):
    y_pred = math_ops.cast(y_pred, backend.dtype(y_true))

  return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx())
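

# A minimal sketch (helper name hypothetical) of the squeezing behaviour noted
# in `sparse_categorical_accuracy` above: integer labels of shape
# `(num_samples,)` and `(num_samples, 1)` are expected to give the same result.
def _sparse_categorical_accuracy_label_shapes_sketch():
  y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0.]]
  flat = sparse_categorical_accuracy([2, 1], y_pred)         # -> [0., 1.]
  column = sparse_categorical_accuracy([[2], [1]], y_pred)   # -> [0., 1.]
  return flat, column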
3570 """ 3571 y_pred_rank = ops.convert_to_tensor_v2_with_dispatch(y_pred).shape.ndims 3572 y_true_rank = ops.convert_to_tensor_v2_with_dispatch(y_true).shape.ndims 3573 # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,) 3574 if (y_true_rank is not None) and (y_pred_rank is not None): 3575 if y_pred_rank > 2: 3576 y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) 3577 if y_true_rank > 1: 3578 y_true = array_ops.reshape(y_true, [-1]) 3579 3580 return math_ops.cast( 3581 nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), backend.floatx()) 3582 3583 3584def cosine_proximity(y_true, y_pred, axis=-1): 3585 """Computes the cosine similarity between labels and predictions. 3586 3587 Args: 3588 y_true: The ground truth values. 3589 y_pred: The prediction values. 3590 axis: (Optional) Defaults to -1. The dimension along which the cosine 3591 similarity is computed. 3592 3593 Returns: 3594 Cosine similarity value. 3595 """ 3596 y_true = nn.l2_normalize(y_true, axis=axis) 3597 y_pred = nn.l2_normalize(y_pred, axis=axis) 3598 return math_ops.reduce_sum(y_true * y_pred, axis=axis) 3599 3600# Aliases 3601 3602acc = ACC = accuracy 3603bce = BCE = binary_crossentropy 3604mse = MSE = mean_squared_error 3605mae = MAE = mean_absolute_error 3606mape = MAPE = mean_absolute_percentage_error 3607msle = MSLE = mean_squared_logarithmic_error 3608cosine_similarity = cosine_proximity 3609log_cosh = logcosh 3610 3611 3612def clone_metric(metric): 3613 """Returns a clone of the metric if stateful, otherwise returns it as is.""" 3614 if isinstance(metric, Metric): 3615 with ops.init_scope(): 3616 return metric.__class__.from_config(metric.get_config()) 3617 return metric 3618 3619 3620def clone_metrics(metrics): 3621 """Clones the given metric list/dict.""" 3622 return nest.map_structure(clone_metric, metrics) 3623 3624 3625@keras_export('keras.metrics.serialize') 3626def serialize(metric): 3627 """Serializes metric function or `Metric` instance. 3628 3629 Args: 3630 metric: A Keras `Metric` instance or a metric function. 3631 3632 Returns: 3633 Metric configuration dictionary. 3634 """ 3635 return serialize_keras_object(metric) 3636 3637 3638@keras_export('keras.metrics.deserialize') 3639def deserialize(config, custom_objects=None): 3640 """Deserializes a serialized metric class/function instance. 3641 3642 Args: 3643 config: Metric configuration. 3644 custom_objects: Optional dictionary mapping names (strings) to custom 3645 objects (classes and functions) to be considered during deserialization. 3646 3647 Returns: 3648 A Keras `Metric` instance or a metric function. 3649 """ 3650 return deserialize_keras_object( 3651 config, 3652 module_objects=globals(), 3653 custom_objects=custom_objects, 3654 printable_module_name='metric function') 3655 3656 3657@keras_export('keras.metrics.get') 3658def get(identifier): 3659 """Retrieves a Keras metric as a `function`/`Metric` class instance. 3660 3661 The `identifier` may be the string name of a metric function or class. 3662 3663 >>> metric = tf.keras.metrics.get("categorical_crossentropy") 3664 >>> type(metric) 3665 <class 'function'> 3666 >>> metric = tf.keras.metrics.get("CategoricalCrossentropy") 3667 >>> type(metric) 3668 <class '...keras.metrics.CategoricalCrossentropy'> 3669 3670 You can also specify `config` of the metric to this function by passing dict 3671 containing `class_name` and `config` as an identifier. 


@keras_export('keras.metrics.serialize')
def serialize(metric):
  """Serializes metric function or `Metric` instance.

  Args:
    metric: A Keras `Metric` instance or a metric function.

  Returns:
    Metric configuration dictionary.
  """
  return serialize_keras_object(metric)


@keras_export('keras.metrics.deserialize')
def deserialize(config, custom_objects=None):
  """Deserializes a serialized metric class/function instance.

  Args:
    config: Metric configuration.
    custom_objects: Optional dictionary mapping names (strings) to custom
      objects (classes and functions) to be considered during deserialization.

  Returns:
    A Keras `Metric` instance or a metric function.
  """
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='metric function')


@keras_export('keras.metrics.get')
def get(identifier):
  """Retrieves a Keras metric as a `function`/`Metric` class instance.

  The `identifier` may be the string name of a metric function or class.

  >>> metric = tf.keras.metrics.get("categorical_crossentropy")
  >>> type(metric)
  <class 'function'>
  >>> metric = tf.keras.metrics.get("CategoricalCrossentropy")
  >>> type(metric)
  <class '...keras.metrics.CategoricalCrossentropy'>

  You can also specify the `config` of the metric to this function by passing
  a dict containing `class_name` and `config` as an identifier. Also note that
  the `class_name` must map to a `Metric` class.

  >>> identifier = {"class_name": "CategoricalCrossentropy",
  ...               "config": {"from_logits": True}}
  >>> metric = tf.keras.metrics.get(identifier)
  >>> type(metric)
  <class '...keras.metrics.CategoricalCrossentropy'>

  Args:
    identifier: A metric identifier. One of None, the string name of a metric
      function/class, a metric configuration dictionary, a metric function,
      or a metric class instance.

  Returns:
    A Keras metric as a `function`/ `Metric` class instance.

  Raises:
    ValueError: If `identifier` cannot be interpreted.
  """
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, str):
    return deserialize(str(identifier))
  elif callable(identifier):
    return identifier
  else:
    raise ValueError(
        'Could not interpret metric function identifier: {}'.format(identifier))


def is_built_in(cls):
  return cls.__module__ == Metric.__module__
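

# A hedged round-trip sketch (helper name hypothetical) of the serialization
# utilities above: `serialize` turns a `Metric` instance into a config dict,
# and `deserialize`/`get` rebuild an equivalent instance from it.
def _serialize_round_trip_sketch():
  original = CategoricalCrossentropy(from_logits=True)
  config = serialize(original)        # dict with 'class_name' and 'config'
  rebuilt = deserialize(config)       # new CategoricalCrossentropy instance
  via_get = get({'class_name': 'CategoricalCrossentropy',
                 'config': {'from_logits': True}})
  return rebuilt, via_get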