1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-classes-have-attributes
16# pylint: disable=g-doc-return-or-yield
17"""Built-in metrics."""
18
19import abc
20import types
21import warnings
22
23import numpy as np
24
25from tensorflow.python.autograph.core import ag_ctx
26from tensorflow.python.autograph.impl import api as autograph
27from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
28from tensorflow.python.eager import context
29from tensorflow.python.eager import def_function
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.keras import activations
35from tensorflow.python.keras import backend
36from tensorflow.python.keras.engine import base_layer
37from tensorflow.python.keras.engine import base_layer_utils
38from tensorflow.python.keras.engine import keras_tensor
39from tensorflow.python.keras.losses import binary_crossentropy
40from tensorflow.python.keras.losses import categorical_crossentropy
41from tensorflow.python.keras.losses import categorical_hinge
42from tensorflow.python.keras.losses import hinge
43from tensorflow.python.keras.losses import kullback_leibler_divergence
44from tensorflow.python.keras.losses import logcosh
45from tensorflow.python.keras.losses import mean_absolute_error
46from tensorflow.python.keras.losses import mean_absolute_percentage_error
47from tensorflow.python.keras.losses import mean_squared_error
48from tensorflow.python.keras.losses import mean_squared_logarithmic_error
49from tensorflow.python.keras.losses import poisson
50from tensorflow.python.keras.losses import sparse_categorical_crossentropy
51from tensorflow.python.keras.losses import squared_hinge
52from tensorflow.python.keras.saving.saved_model import metric_serialization
53from tensorflow.python.keras.utils import generic_utils
54from tensorflow.python.keras.utils import losses_utils
55from tensorflow.python.keras.utils import metrics_utils
56from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
57from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
58from tensorflow.python.keras.utils.generic_utils import to_list
59from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
60from tensorflow.python.ops import array_ops
61from tensorflow.python.ops import check_ops
62from tensorflow.python.ops import confusion_matrix
63from tensorflow.python.ops import init_ops
64from tensorflow.python.ops import math_ops
65from tensorflow.python.ops import nn
66from tensorflow.python.ops import variables as variables_module
67from tensorflow.python.ops import weights_broadcast_ops
68from tensorflow.python.util import dispatch
69from tensorflow.python.util import nest
70from tensorflow.python.util.tf_export import keras_export
71from tensorflow.tools.docs import doc_controls
72
73
74@keras_export('keras.metrics.Metric')
75class Metric(base_layer.Layer, metaclass=abc.ABCMeta):
76  """Encapsulates metric logic and state.
77
78  Args:
79    name: (Optional) string name of the metric instance.
80    dtype: (Optional) data type of the metric result.
81    **kwargs: Additional layer keyword arguments.
82
83  Standalone usage:
84
85  ```python
86  m = SomeMetric(...)
87  for input in ...:
88    m.update_state(input)
89  print('Final result: ', m.result().numpy())
90  ```
91
92  Usage with `compile()` API:
93
94  ```python
95  model = tf.keras.Sequential()
96  model.add(tf.keras.layers.Dense(64, activation='relu'))
97  model.add(tf.keras.layers.Dense(64, activation='relu'))
98  model.add(tf.keras.layers.Dense(10, activation='softmax'))
99
100  model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
101                loss=tf.keras.losses.CategoricalCrossentropy(),
102                metrics=[tf.keras.metrics.CategoricalAccuracy()])
103
104  data = np.random.random((1000, 32))
105  labels = np.random.random((1000, 10))
106
107  dataset = tf.data.Dataset.from_tensor_slices((data, labels))
108  dataset = dataset.batch(32)
109
110  model.fit(dataset, epochs=10)
111  ```
112
113  To be implemented by subclasses:
114  * `__init__()`: All state variables should be created in this method by
115    calling `self.add_weight()` like: `self.var = self.add_weight(...)`
116  * `update_state()`: Contains all updates to the state variables, like:
117    `self.var.assign_add(...)`.
118  * `result()`: Computes and returns a value for the metric
119    from the state variables.
120
121  Example subclass implementation:
122
123  ```python
124  class BinaryTruePositives(tf.keras.metrics.Metric):
125
126    def __init__(self, name='binary_true_positives', **kwargs):
127      super(BinaryTruePositives, self).__init__(name=name, **kwargs)
128      self.true_positives = self.add_weight(name='tp', initializer='zeros')
129
130    def update_state(self, y_true, y_pred, sample_weight=None):
131      y_true = tf.cast(y_true, tf.bool)
132      y_pred = tf.cast(y_pred, tf.bool)
133
134      values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
135      values = tf.cast(values, self.dtype)
136      if sample_weight is not None:
137        sample_weight = tf.cast(sample_weight, self.dtype)
138        sample_weight = tf.broadcast_to(sample_weight, values.shape)
139        values = tf.multiply(values, sample_weight)
140      self.true_positives.assign_add(tf.reduce_sum(values))
141
142    def result(self):
143      return self.true_positives
144  ```
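
  The example metric above could then be used on its own, e.g. (a minimal
  sketch, assuming the `BinaryTruePositives` definition above):

  ```python
  m = BinaryTruePositives()
  m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
  # Only the second sample is a true positive (label 1, prediction 1).
  print(m.result().numpy())  # 1.0
  ```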
145  """
146
147  def __init__(self, name=None, dtype=None, **kwargs):
148    super(Metric, self).__init__(name=name, dtype=dtype, **kwargs)
149    self.stateful = True  # All metric layers are stateful.
150    self.built = True
151    if not base_layer_utils.v2_dtype_behavior_enabled():
152      # We only do this when the V2 behavior is not enabled, as when it is
153      # enabled, the dtype already defaults to floatx.
154      self._dtype = (backend.floatx() if dtype is None
155                     else dtypes.as_dtype(dtype).name)
156
157  def __new__(cls, *args, **kwargs):
158    obj = super(Metric, cls).__new__(cls)
159
160    # If `update_state` is not in eager/tf.function and it is not from a
161    # built-in metric, wrap it in `tf.function`. This is so that users writing
162    # custom metrics in v1 need not worry about control dependencies and
163    # return ops.
164    if (base_layer_utils.is_in_eager_or_tf_function() or
165        is_built_in(cls)):
166      obj_update_state = obj.update_state
167
168      def update_state_fn(*args, **kwargs):
169        control_status = ag_ctx.control_status_ctx()
170        ag_update_state = autograph.tf_convert(obj_update_state, control_status)
171        return ag_update_state(*args, **kwargs)
172    else:
173      if isinstance(obj.update_state, def_function.Function):
174        update_state_fn = obj.update_state
175      else:
176        update_state_fn = def_function.function(obj.update_state)
177
178    obj.update_state = types.MethodType(
179        metrics_utils.update_state_wrapper(update_state_fn), obj)
180
181    obj_result = obj.result
182
183    def result_fn(*args, **kwargs):
184      control_status = ag_ctx.control_status_ctx()
185      ag_result = autograph.tf_convert(obj_result, control_status)
186      return ag_result(*args, **kwargs)
187
188    obj.result = types.MethodType(metrics_utils.result_wrapper(result_fn), obj)
189
190    return obj
191
192  def __call__(self, *args, **kwargs):
193    """Accumulates statistics and then computes metric result value.
194
195    Args:
196      *args:
197      **kwargs: A mini-batch of inputs to the Metric,
198        passed on to `update_state()`.
199
200    Returns:
201      The metric value tensor.
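
    For example (a minimal sketch using the built-in `tf.keras.metrics.Mean`;
    calling the metric object is equivalent to calling `update_state()`
    followed by `result()`):

    ```python
    m = tf.keras.metrics.Mean()
    # Accumulates the values and returns the running mean as a tensor.
    result = m([1, 3, 5, 7])
    print(result.numpy())  # 4.0
    ```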
202    """
203
204    def replica_local_fn(*args, **kwargs):
205      """Updates the state of the metric in a replica-local context."""
206      if any(
207          isinstance(arg, keras_tensor.KerasTensor)
208          for arg in nest.flatten((args, kwargs))):
209        update_op = None
210      else:
211        update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
212      update_ops = []
213      if update_op is not None:
214        update_ops.append(update_op)
215      with ops.control_dependencies(update_ops):
216        result_t = self.result()  # pylint: disable=not-callable
217
218        # We are adding the metric object as metadata on the result tensor.
219        # This is required when we want to use a metric with `add_metric` API on
220        # a Model/Layer in graph mode. This metric instance will later be used
221        # to reset variable state after each epoch of training.
222        # Example:
223        #   model = Model()
224        #   mean = Mean()
225        #   model.add_metric(mean(values), name='mean')
226        result_t._metric_obj = self  # pylint: disable=protected-access
227        return result_t
228
229    from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
230    return distributed_training_utils.call_replica_local_fn(
231        replica_local_fn, *args, **kwargs)
232
233  @property
234  def dtype(self):
235    return self._dtype
236
237  def get_config(self):
238    """Returns the serializable config of the metric."""
239    return {'name': self.name, 'dtype': self.dtype}
240
241  def reset_state(self):
242    """Resets all of the metric state variables.
243
244    This function is called between epochs/steps,
245    when a metric is evaluated during training.
246    """
247    if not generic_utils.is_default(self.reset_states):
248      warnings.warn('Metric %s implements a `reset_states()` method; rename it '
249                    'to `reset_state()` (without the final "s"). The name '
250                    '`reset_states()` has been deprecated to improve API '
251                    'consistency.' % (self.__class__.__name__,))
252      return self.reset_states()
253    else:
254      backend.batch_set_value([(v, 0) for v in self.variables])
255
256  @abc.abstractmethod
257  def update_state(self, *args, **kwargs):
258    """Accumulates statistics for the metric.
259
260    Note: This function is executed as a graph function in graph mode.
261    This means:
262      a) Operations on the same resource are executed in textual order.
263         This should make it easier to do things like add the updated
264         value of a variable to another, for example.
265      b) You don't need to worry about collecting the update ops to execute.
266         All update ops added to the graph by this function will be executed.
267      As a result, code should generally work the same way with graph or
268      eager execution.
269
270    Args:
271      *args:
272      **kwargs: A mini-batch of inputs to the Metric.
273    """
274    raise NotImplementedError('Must be implemented in subclasses.')
275
276  @abc.abstractmethod
277  def result(self):
278    """Computes and returns the metric value tensor.
279
280    Result computation is an idempotent operation that simply calculates the
281    metric value using the state variables.
282    """
283    raise NotImplementedError('Must be implemented in subclasses.')
284
285  ### For use by subclasses ###
286  @doc_controls.for_subclass_implementers
287  def add_weight(
288      self,
289      name,
290      shape=(),
291      aggregation=variables_module.VariableAggregation.SUM,
292      synchronization=variables_module.VariableSynchronization.ON_READ,
293      initializer=None,
294      dtype=None):
295    """Adds state variable. Only for use by subclasses."""
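    # Note: with the defaults above, under a `tf.distribute` strategy each
    # replica updates its own copy of the variable (`ON_READ` synchronization)
    # and the per-replica values are summed (`SUM` aggregation) when the
    # variable is read, e.g. in `result()`.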
296    if distribute_ctx.has_strategy():
297      strategy = distribute_ctx.get_strategy()
298    else:
299      strategy = None
300
301    # TODO(b/120571621): Make `ON_READ` work with Keras metrics on TPU.
302    if backend.is_tpu_strategy(strategy):
303      synchronization = variables_module.VariableSynchronization.ON_WRITE
304
305    with ops.init_scope():
306      return super(Metric, self).add_weight(
307          name=name,
308          shape=shape,
309          dtype=self._dtype if dtype is None else dtype,
310          trainable=False,
311          initializer=initializer,
312          collections=[],
313          synchronization=synchronization,
314          aggregation=aggregation)
315
316  ### End: For use by subclasses ###
317
318  @property
319  def trainable_weights(self):
320    # Overridden from Layer class to track submetric weights.
321    if self.trainable:
322      trainable_weights = self._trainable_weights
323      for m in self._metrics:
324        trainable_weights += m.trainable_weights
325      return self._dedup_weights(trainable_weights)
326    else:
327      return []
328
329  @property
330  def non_trainable_weights(self):
331    # Overridden from Layer class to track submetric weights.
332    if self.trainable:
333      non_trainable_weights = self._non_trainable_weights
334      for m in self._metrics:
335        non_trainable_weights += m.non_trainable_weights
336    else:
337      non_trainable_weights = (
338          self._non_trainable_weights + self._trainable_weights)
339      for m in self._metrics:
340        non_trainable_weights += m.weights
341    return self._dedup_weights(non_trainable_weights)
342
343  @property
344  def _trackable_saved_model_saver(self):
345    return metric_serialization.MetricSavedModelSaver(self)
346
347  @generic_utils.default
348  @doc_controls.do_not_generate_docs
349  def reset_states(self):
350    # Backwards compatibility alias of `reset_state`. New classes should
351    # only implement `reset_state`.
352    return self.reset_state()
353
354
355class Reduce(Metric):
356  """Encapsulates metrics that perform a reduce operation on the values.
357
358  Args:
359    reduction: a `tf.keras.metrics.Reduction` enum value.
360    name: string name of the metric instance.
361    dtype: (Optional) data type of the metric result.
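
  For example, a rough sketch of the `WEIGHTED_MEAN` reduction (the behavior
  of the `Mean` subclass below): `total` accumulates the weighted sum of the
  values and `count` accumulates the sum of the weights, so the result is
  `total / count`.

  ```python
  m = tf.keras.metrics.Mean()  # uses Reduction.WEIGHTED_MEAN
  m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
  # total = 1*1 + 3*1 + 5*0 + 7*0 = 4, count = 1 + 1 + 0 + 0 = 2
  print(m.result().numpy())  # 4 / 2 = 2.0
  ```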
362  """
363
364  def __init__(self, reduction, name, dtype=None):
365    super(Reduce, self).__init__(name=name, dtype=dtype)
366    self.reduction = reduction
367    self.total = self.add_weight(
368        'total', initializer=init_ops.zeros_initializer)
369    if reduction in [metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
370                     metrics_utils.Reduction.WEIGHTED_MEAN]:
371      self.count = self.add_weight(
372          'count', initializer=init_ops.zeros_initializer)
373
374  def update_state(self, values, sample_weight=None):
375    """Accumulates statistics for computing the metric.
376
377    Args:
378      values: Per-example value.
379      sample_weight: Optional weighting of each example. Defaults to 1.
380
381    Returns:
382      Update op.
383    """
384    [values], sample_weight = \
385        metrics_utils.ragged_assert_compatible_and_get_flat_values(
386            [values], sample_weight)
387    try:
388      values = math_ops.cast(values, self._dtype)
389    except (ValueError, TypeError):
390      msg = ('The output of a metric function can only be a single Tensor. '
391             'Got: %s' % (values,))
392      if isinstance(values, dict):
393        msg += ('. To return a dict of values, implement a custom Metric '
394                'subclass.')
395      raise RuntimeError(msg)
396    if sample_weight is not None:
397      sample_weight = math_ops.cast(sample_weight, self._dtype)
398      # Update dimensions of weights to match with values if possible.
399      values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
400          values, sample_weight=sample_weight)
401      try:
402        # Broadcast weights if possible.
403        sample_weight = weights_broadcast_ops.broadcast_weights(
404            sample_weight, values)
405      except ValueError:
406        # Reduce values to same ndim as weight array
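        # For example, per-example weights of shape (batch,) with values of
        # shape (batch, d0, d1): the extra value dimensions are summed (for
        # the SUM reduction) or averaged (for the other reductions) so that
        # the weights can then be applied elementwise.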
407        ndim = backend.ndim(values)
408        weight_ndim = backend.ndim(sample_weight)
409        if self.reduction == metrics_utils.Reduction.SUM:
410          values = math_ops.reduce_sum(
411              values, axis=list(range(weight_ndim, ndim)))
412        else:
413          values = math_ops.reduce_mean(
414              values, axis=list(range(weight_ndim, ndim)))
415      values = math_ops.multiply(values, sample_weight)
416
417    value_sum = math_ops.reduce_sum(values)
418    with ops.control_dependencies([value_sum]):
419      update_total_op = self.total.assign_add(value_sum)
420
421    # Exit early if the reduction doesn't have a denominator.
422    if self.reduction == metrics_utils.Reduction.SUM:
423      return update_total_op
424
425    # Update `count` for reductions that require a denominator.
426    if self.reduction == metrics_utils.Reduction.SUM_OVER_BATCH_SIZE:
427      num_values = math_ops.cast(array_ops.size(values), self._dtype)
428    elif self.reduction == metrics_utils.Reduction.WEIGHTED_MEAN:
429      if sample_weight is None:
430        num_values = math_ops.cast(array_ops.size(values), self._dtype)
431      else:
432        num_values = math_ops.reduce_sum(sample_weight)
433    else:
434      raise NotImplementedError(
435          'reduction [%s] not implemented' % self.reduction)
436
437    with ops.control_dependencies([update_total_op]):
438      return self.count.assign_add(num_values)
439
440  def result(self):
441    if self.reduction == metrics_utils.Reduction.SUM:
442      return array_ops.identity(self.total)
443    elif self.reduction in [
444        metrics_utils.Reduction.WEIGHTED_MEAN,
445        metrics_utils.Reduction.SUM_OVER_BATCH_SIZE
446    ]:
447      return math_ops.div_no_nan(self.total, self.count)
448    else:
449      raise NotImplementedError(
450          'reduction [%s] not implemented' % self.reduction)
451
452
453@keras_export('keras.metrics.Sum')
454class Sum(Reduce):
455  """Computes the (weighted) sum of the given values.
456
457  For example, if values is [1, 3, 5, 7] then the sum is 16.
458  If the weights were specified as [1, 1, 0, 0] then the sum would be 4.
459
460  This metric creates one variable, `total`, that is used to compute the sum of
461  `values`. This is ultimately returned as `sum`.
462
463  If `sample_weight` is `None`, weights default to 1.  Use `sample_weight` of 0
464  to mask values.
465
466  Args:
467    name: (Optional) string name of the metric instance.
468    dtype: (Optional) data type of the metric result.
469
470  Standalone usage:
471
472  >>> m = tf.keras.metrics.Sum()
473  >>> m.update_state([1, 3, 5, 7])
474  >>> m.result().numpy()
475  16.0
476
477  Usage with `compile()` API:
478
479  ```python
480  model.add_metric(tf.keras.metrics.Sum(name='sum_1')(outputs))
481  model.compile(optimizer='sgd', loss='mse')
482  ```
483  """
484
485  def __init__(self, name='sum', dtype=None):
486    super(Sum, self).__init__(reduction=metrics_utils.Reduction.SUM,
487                              name=name, dtype=dtype)
488
489
490@keras_export('keras.metrics.Mean')
491class Mean(Reduce):
492  """Computes the (weighted) mean of the given values.
493
494  For example, if values is [1, 3, 5, 7] then the mean is 4.
495  If the weights were specified as [1, 1, 0, 0] then the mean would be 2.
496
497  This metric creates two variables, `total` and `count` that are used to
498  compute the average of `values`. This average is ultimately returned as `mean`,
499  which is an idempotent operation that simply divides `total` by `count`.
500
501  If `sample_weight` is `None`, weights default to 1.
502  Use `sample_weight` of 0 to mask values.
503
504  Args:
505    name: (Optional) string name of the metric instance.
506    dtype: (Optional) data type of the metric result.
507
508  Standalone usage:
509
510  >>> m = tf.keras.metrics.Mean()
511  >>> m.update_state([1, 3, 5, 7])
512  >>> m.result().numpy()
513  4.0
514  >>> m.reset_state()
515  >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])
516  >>> m.result().numpy()
517  2.0
518
519  Usage with `compile()` API:
520
521  ```python
522  model.add_metric(tf.keras.metrics.Mean(name='mean_1')(outputs))
523  model.compile(optimizer='sgd', loss='mse')
524  ```
525  """
526
527  def __init__(self, name='mean', dtype=None):
528    super(Mean, self).__init__(
529        reduction=metrics_utils.Reduction.WEIGHTED_MEAN, name=name, dtype=dtype)
530
531
532@keras_export('keras.metrics.MeanRelativeError')
533class MeanRelativeError(Mean):
534  """Computes the mean relative error by normalizing with the given values.
535
536  This metric creates two local variables, `total` and `count` that are used to
537  compute the mean relative error. This is weighted by `sample_weight`, and
538  it is ultimately returned as `mean_relative_error`:
539  an idempotent operation that simply divides `total` by `count`.
540
541  If `sample_weight` is `None`, weights default to 1.
542  Use `sample_weight` of 0 to mask values.
543
544  Args:
545    normalizer: The normalizer values with the same shape as predictions.
546    name: (Optional) string name of the metric instance.
547    dtype: (Optional) data type of the metric result.
548
549  Standalone usage:
550
551  >>> m = tf.keras.metrics.MeanRelativeError(normalizer=[1, 3, 2, 3])
552  >>> m.update_state([1, 3, 2, 3], [2, 4, 6, 8])
553
554  >>> # metric = mean(|y_pred - y_true| / normalizer)
555  >>> #        = mean([1, 1, 4, 5] / [1, 3, 2, 3]) = mean([1, 1/3, 2, 5/3])
556  >>> #        = 5/4 = 1.25
557  >>> m.result().numpy()
558  1.25
559
560  Usage with `compile()` API:
561
562  ```python
563  model.compile(
564    optimizer='sgd',
565    loss='mse',
566    metrics=[tf.keras.metrics.MeanRelativeError(normalizer=[1, 3])])
567  ```
568  """
569
570  def __init__(self, normalizer, name=None, dtype=None):
571    super(MeanRelativeError, self).__init__(name=name, dtype=dtype)
572    normalizer = math_ops.cast(normalizer, self._dtype)
573    self.normalizer = normalizer
574
575  def update_state(self, y_true, y_pred, sample_weight=None):
576    """Accumulates metric statistics.
577
578    Args:
579      y_true: The ground truth values.
580      y_pred: The predicted values.
581      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
582        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
583        be broadcastable to `y_true`.
584
585    Returns:
586      Update op.
587    """
588    y_true = math_ops.cast(y_true, self._dtype)
589    y_pred = math_ops.cast(y_pred, self._dtype)
590    [y_pred, y_true], sample_weight = \
591        metrics_utils.ragged_assert_compatible_and_get_flat_values(
592            [y_pred, y_true], sample_weight)
593    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
594        y_pred, y_true)
595
596    y_pred, self.normalizer = losses_utils.remove_squeezable_dimensions(
597        y_pred, self.normalizer)
598    y_pred.shape.assert_is_compatible_with(y_true.shape)
599    relative_errors = math_ops.div_no_nan(
600        math_ops.abs(y_true - y_pred), self.normalizer)
601
602    return super(MeanRelativeError, self).update_state(
603        relative_errors, sample_weight=sample_weight)
604
605  def get_config(self):
606    n = self.normalizer
607    config = {'normalizer': backend.eval(n) if is_tensor_or_variable(n) else n}
608    base_config = super(MeanRelativeError, self).get_config()
609    return dict(list(base_config.items()) + list(config.items()))
610
611
612@keras_export('keras.metrics.MeanMetricWrapper')
613class MeanMetricWrapper(Mean):
614  """Wraps a stateless metric function with the Mean metric.
615
616  You could use this class to quickly build a mean metric from a function. The
617  function needs to have the signature `fn(y_true, y_pred)` and return a
618  per-sample metric value array. `MeanMetricWrapper.result()` will return
619  the average metric value across all samples seen so far.
620
621  For example:
622
623  ```python
624  def accuracy(y_true, y_pred):
625    return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)
626
627  accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy)
628
629  keras_model.compile(..., metrics=accuracy_metric)
630  ```
631
632  Args:
633    fn: The metric function to wrap, with signature `fn(y_true, y_pred,
634      **kwargs)`.
635    name: (Optional) string name of the metric instance.
636    dtype: (Optional) data type of the metric result.
637    **kwargs: Keyword arguments to pass on to `fn`.
638  """
639
640  def __init__(self, fn, name=None, dtype=None, **kwargs):
641    super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
642    self._fn = fn
643    self._fn_kwargs = kwargs
644
645  def update_state(self, y_true, y_pred, sample_weight=None):
646    """Accumulates metric statistics.
647
648    `y_true` and `y_pred` should have the same shape.
649
650    Args:
651      y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
652      y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
653      sample_weight: Optional `sample_weight` acts as a
654        coefficient for the metric. If a scalar is provided, then the metric is
655        simply scaled by the given value. If `sample_weight` is a tensor of size
656        `[batch_size]`, then the metric for each sample of the batch is rescaled
657        by the corresponding element in the `sample_weight` vector. If the shape
658        of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
659        to this shape), then each metric element of `y_pred` is scaled by the
660        corresponding value of `sample_weight`. (Note on `dN-1`: all metric
661        functions reduce by 1 dimension, usually the last axis (-1)).
662
663    Returns:
664      Update op.
665    """
666    y_true = math_ops.cast(y_true, self._dtype)
667    y_pred = math_ops.cast(y_pred, self._dtype)
668    [y_true, y_pred], sample_weight = (
669        metrics_utils.ragged_assert_compatible_and_get_flat_values(
670            [y_true, y_pred], sample_weight))
671    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
672        y_pred, y_true)
673
674    ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
675    matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
676    return super(MeanMetricWrapper, self).update_state(
677        matches, sample_weight=sample_weight)
678
679  def get_config(self):
680    config = {}
681
682    if type(self) is MeanMetricWrapper:  # pylint: disable=unidiomatic-typecheck
683      # Only include function argument when the object is a MeanMetricWrapper
684      # and not a subclass.
685      config['fn'] = self._fn
686
687    for k, v in self._fn_kwargs.items():
688      config[k] = backend.eval(v) if is_tensor_or_variable(v) else v
689    base_config = super(MeanMetricWrapper, self).get_config()
690    return dict(list(base_config.items()) + list(config.items()))
691
692  @classmethod
693  def from_config(cls, config):
694    # Note that while MeanMetricWrapper itself isn't public, objects of this
695    # class may be created and added to the model by calling model.compile.
696    fn = config.pop('fn', None)
697    if cls is MeanMetricWrapper:
698      return cls(get(fn), **config)
699    return super(MeanMetricWrapper, cls).from_config(config)
700
701
702@keras_export('keras.metrics.Accuracy')
703class Accuracy(MeanMetricWrapper):
704  """Calculates how often predictions equal labels.
705
706  This metric creates two local variables, `total` and `count` that are used to
707  compute the frequency with which `y_pred` matches `y_true`. This frequency is
708  ultimately returned as `binary accuracy`: an idempotent operation that simply
709  divides `total` by `count`.
710
711  If `sample_weight` is `None`, weights default to 1.
712  Use `sample_weight` of 0 to mask values.
713
714  Args:
715    name: (Optional) string name of the metric instance.
716    dtype: (Optional) data type of the metric result.
717
718  Standalone usage:
719
720  >>> m = tf.keras.metrics.Accuracy()
721  >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])
722  >>> m.result().numpy()
723  0.75
724
725  >>> m.reset_state()
726  >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]],
727  ...                sample_weight=[1, 1, 0, 0])
728  >>> m.result().numpy()
729  0.5
730
731  Usage with `compile()` API:
732
733  ```python
734  model.compile(optimizer='sgd',
735                loss='mse',
736                metrics=[tf.keras.metrics.Accuracy()])
737  ```
738  """
739
740  def __init__(self, name='accuracy', dtype=None):
741    super(Accuracy, self).__init__(accuracy, name, dtype=dtype)
742
743
744@keras_export('keras.metrics.BinaryAccuracy')
745class BinaryAccuracy(MeanMetricWrapper):
746  """Calculates how often predictions match binary labels.
747
748  This metric creates two local variables, `total` and `count` that are used to
749  compute the frequency with which `y_pred` matches `y_true`. This frequency is
750  ultimately returned as `binary accuracy`: an idempotent operation that simply
751  divides `total` by `count`.
752
753  If `sample_weight` is `None`, weights default to 1.
754  Use `sample_weight` of 0 to mask values.
755
756  Args:
757    name: (Optional) string name of the metric instance.
758    dtype: (Optional) data type of the metric result.
759    threshold: (Optional) Float representing the threshold for deciding
760      whether prediction values are 1 or 0.
761
762  Standalone usage:
763
764  >>> m = tf.keras.metrics.BinaryAccuracy()
765  >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])
766  >>> m.result().numpy()
767  0.75
768
769  >>> m.reset_state()
770  >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],
771  ...                sample_weight=[1, 0, 0, 1])
772  >>> m.result().numpy()
773  0.5
774
775  Usage with `compile()` API:
776
777  ```python
778  model.compile(optimizer='sgd',
779                loss='mse',
780                metrics=[tf.keras.metrics.BinaryAccuracy()])
781  ```
782  """
783
784  def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
785    super(BinaryAccuracy, self).__init__(
786        binary_accuracy, name, dtype=dtype, threshold=threshold)
787
788
789@keras_export('keras.metrics.CategoricalAccuracy')
790class CategoricalAccuracy(MeanMetricWrapper):
791  """Calculates how often predictions match one-hot labels.
792
793  You can provide logits of classes as `y_pred`, since the argmax of
794  logits and of probabilities is the same.
795
796  This metric creates two local variables, `total` and `count` that are used to
797  compute the frequency with which `y_pred` matches `y_true`. This frequency is
798  ultimately returned as `categorical accuracy`: an idempotent operation that
799  simply divides `total` by `count`.
800
801  `y_pred` should be passed in as probabilities (or logits) and `y_true` as a
802  one-hot encoded vector; if necessary, use `tf.one_hot` to expand `y_true`.
803
804  If `sample_weight` is `None`, weights default to 1.
805  Use `sample_weight` of 0 to mask values.
806
807  Args:
808    name: (Optional) string name of the metric instance.
809    dtype: (Optional) data type of the metric result.
810
811  Standalone usage:
812
813  >>> m = tf.keras.metrics.CategoricalAccuracy()
814  >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
815  ...                 [0.05, 0.95, 0]])
816  >>> m.result().numpy()
817  0.5
818
819  >>> m.reset_state()
820  >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],
821  ...                 [0.05, 0.95, 0]],
822  ...                sample_weight=[0.7, 0.3])
823  >>> m.result().numpy()
824  0.3
825
826  Usage with `compile()` API:
827
828  ```python
829  model.compile(
830    optimizer='sgd',
831    loss='mse',
832    metrics=[tf.keras.metrics.CategoricalAccuracy()])
833  ```
834  """
835
836  def __init__(self, name='categorical_accuracy', dtype=None):
837    super(CategoricalAccuracy, self).__init__(
838        categorical_accuracy, name, dtype=dtype)
839
840
841@keras_export('keras.metrics.SparseCategoricalAccuracy')
842class SparseCategoricalAccuracy(MeanMetricWrapper):
843  """Calculates how often predictions match integer labels.
844
845  ```python
846  acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)))
847  ```
848
849  You can provide logits of classes as `y_pred`, since the argmax of
850  logits and of probabilities is the same.
851
852  This metric creates two local variables, `total` and `count` that are used to
853  compute the frequency with which `y_pred` matches `y_true`. This frequency is
854  ultimately returned as `sparse categorical accuracy`: an idempotent operation
855  that simply divides `total` by `count`.
856
857  If `sample_weight` is `None`, weights default to 1.
858  Use `sample_weight` of 0 to mask values.
859
860  Args:
861    name: (Optional) string name of the metric instance.
862    dtype: (Optional) data type of the metric result.
863
864  Standalone usage:
865
866  >>> m = tf.keras.metrics.SparseCategoricalAccuracy()
867  >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
868  >>> m.result().numpy()
869  0.5
870
871  >>> m.reset_state()
872  >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],
873  ...                sample_weight=[0.7, 0.3])
874  >>> m.result().numpy()
875  0.3
876
877  Usage with `compile()` API:
878
879  ```python
880  model.compile(
881      optimizer='sgd',
882      loss='mse',
883      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
884  ```
885  """
886
887  def __init__(self, name='sparse_categorical_accuracy', dtype=None):
888    super(SparseCategoricalAccuracy, self).__init__(
889        sparse_categorical_accuracy, name, dtype=dtype)
890
891
892@keras_export('keras.metrics.TopKCategoricalAccuracy')
893class TopKCategoricalAccuracy(MeanMetricWrapper):
894  """Computes how often targets are in the top `K` predictions.
895
896  Args:
897    k: (Optional) Number of top elements to look at for computing accuracy.
898      Defaults to 5.
899    name: (Optional) string name of the metric instance.
900    dtype: (Optional) data type of the metric result.
901
902  Standalone usage:
903
904  >>> m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
905  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
906  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
907  >>> m.result().numpy()
908  0.5
909
910  >>> m.reset_state()
911  >>> m.update_state([[0, 0, 1], [0, 1, 0]],
912  ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
913  ...                sample_weight=[0.7, 0.3])
914  >>> m.result().numpy()
915  0.3
916
917  Usage with `compile()` API:
918
919  ```python
920  model.compile(optimizer='sgd',
921                loss='mse',
922                metrics=[tf.keras.metrics.TopKCategoricalAccuracy()])
923  ```
924  """
925
926  def __init__(self, k=5, name='top_k_categorical_accuracy', dtype=None):
927    super(TopKCategoricalAccuracy, self).__init__(
928        top_k_categorical_accuracy, name, dtype=dtype, k=k)
929
930
931@keras_export('keras.metrics.SparseTopKCategoricalAccuracy')
932class SparseTopKCategoricalAccuracy(MeanMetricWrapper):
933  """Computes how often integer targets are in the top `K` predictions.
934
935  Args:
936    k: (Optional) Number of top elements to look at for computing accuracy.
937      Defaults to 5.
938    name: (Optional) string name of the metric instance.
939    dtype: (Optional) data type of the metric result.
940
941  Standalone usage:
942
943  >>> m = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1)
944  >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
945  >>> m.result().numpy()
946  0.5
947
948  >>> m.reset_state()
949  >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
950  ...                sample_weight=[0.7, 0.3])
951  >>> m.result().numpy()
952  0.3
953
954  Usage with `compile()` API:
955
956  ```python
957  model.compile(
958    optimizer='sgd',
959    loss='mse',
960    metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy()])
961  ```
962  """
963
964  def __init__(self, k=5, name='sparse_top_k_categorical_accuracy', dtype=None):
965    super(SparseTopKCategoricalAccuracy, self).__init__(
966        sparse_top_k_categorical_accuracy, name, dtype=dtype, k=k)
967
968
969class _ConfusionMatrixConditionCount(Metric):
970  """Calculates the number of the given confusion matrix condition.
971
972  Args:
973    confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix` conditions.
974    thresholds: (Optional) Defaults to 0.5. A float value or a python list/tuple
975      of float threshold values in [0, 1]. A threshold is compared with
976      prediction values to determine the truth value of predictions (i.e., above
977      the threshold is `true`, below is `false`). One metric value is generated
978      for each threshold value.
979    name: (Optional) string name of the metric instance.
980    dtype: (Optional) data type of the metric result.
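
  A rough sketch of the multi-threshold behavior, via the public
  `FalsePositives` subclass (one accumulator slot is kept per threshold, and
  `result()` returns the full vector when more than one threshold is given):

  ```python
  m = tf.keras.metrics.FalsePositives(thresholds=[0.25, 0.75])
  m.update_state([0, 0, 1, 0], [0.1, 0.9, 0.6, 0.3])
  # threshold 0.25: predictions 0.9, 0.6 and 0.3 count as positive -> 2 FPs
  # threshold 0.75: only 0.9 counts as positive                    -> 1 FP
  print(m.result().numpy())  # [2., 1.]
  ```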
981  """
982
983  def __init__(self,
984               confusion_matrix_cond,
985               thresholds=None,
986               name=None,
987               dtype=None):
988    super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
989    self._confusion_matrix_cond = confusion_matrix_cond
990    self.init_thresholds = thresholds
991    self.thresholds = metrics_utils.parse_init_thresholds(
992        thresholds, default_threshold=0.5)
993    self._thresholds_distributed_evenly = (
994        metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
995    self.accumulator = self.add_weight(
996        'accumulator',
997        shape=(len(self.thresholds),),
998        initializer=init_ops.zeros_initializer)
999
1000  def update_state(self, y_true, y_pred, sample_weight=None):
1001    """Accumulates the metric statistics.
1002
1003    Args:
1004      y_true: The ground truth values.
1005      y_pred: The predicted values.
1006      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1007        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1008        be broadcastable to `y_true`.
1009
1010    Returns:
1011      Update op.
1012    """
1013    return metrics_utils.update_confusion_matrix_variables(
1014        {self._confusion_matrix_cond: self.accumulator},
1015        y_true,
1016        y_pred,
1017        thresholds=self.thresholds,
1018        thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1019        sample_weight=sample_weight)
1020
1021  def result(self):
1022    if len(self.thresholds) == 1:
1023      result = self.accumulator[0]
1024    else:
1025      result = self.accumulator
1026    return ops.convert_to_tensor_v2_with_dispatch(result)
1027
1028  def reset_state(self):
1029    num_thresholds = len(to_list(self.thresholds))
1030    backend.batch_set_value(
1031        [(v, np.zeros((num_thresholds,))) for v in self.variables])
1032
1033  def get_config(self):
1034    config = {'thresholds': self.init_thresholds}
1035    base_config = super(_ConfusionMatrixConditionCount, self).get_config()
1036    return dict(list(base_config.items()) + list(config.items()))
1037
1038
1039@keras_export('keras.metrics.FalsePositives')
1040class FalsePositives(_ConfusionMatrixConditionCount):
1041  """Calculates the number of false positives.
1042
1043  If `sample_weight` is given, calculates the sum of the weights of
1044  false positives. This metric creates one local variable, `accumulator`
1045  that is used to keep track of the number of false positives.
1046
1047  If `sample_weight` is `None`, weights default to 1.
1048  Use `sample_weight` of 0 to mask values.
1049
1050  Args:
1051    thresholds: (Optional) Defaults to 0.5. A float value or a python
1052      list/tuple of float threshold values in [0, 1]. A threshold is compared
1053      with prediction values to determine the truth value of predictions
1054      (i.e., above the threshold is `true`, below is `false`). One metric
1055      value is generated for each threshold value.
1056    name: (Optional) string name of the metric instance.
1057    dtype: (Optional) data type of the metric result.
1058
1059  Standalone usage:
1060
1061  >>> m = tf.keras.metrics.FalsePositives()
1062  >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
1063  >>> m.result().numpy()
1064  2.0
1065
1066  >>> m.reset_state()
1067  >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1068  >>> m.result().numpy()
1069  1.0
1070
1071  Usage with `compile()` API:
1072
1073  ```python
1074  model.compile(optimizer='sgd',
1075                loss='mse',
1076                metrics=[tf.keras.metrics.FalsePositives()])
1077  ```
1078  """
1079
1080  def __init__(self, thresholds=None, name=None, dtype=None):
1081    super(FalsePositives, self).__init__(
1082        confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,
1083        thresholds=thresholds,
1084        name=name,
1085        dtype=dtype)
1086
1087
1088@keras_export('keras.metrics.FalseNegatives')
1089class FalseNegatives(_ConfusionMatrixConditionCount):
1090  """Calculates the number of false negatives.
1091
1092  If `sample_weight` is given, calculates the sum of the weights of
1093  false negatives. This metric creates one local variable, `accumulator`
1094  that is used to keep track of the number of false negatives.
1095
1096  If `sample_weight` is `None`, weights default to 1.
1097  Use `sample_weight` of 0 to mask values.
1098
1099  Args:
1100    thresholds: (Optional) Defaults to 0.5. A float value or a python
1101      list/tuple of float threshold values in [0, 1]. A threshold is compared
1102      with prediction values to determine the truth value of predictions
1103      (i.e., above the threshold is `true`, below is `false`). One metric
1104      value is generated for each threshold value.
1105    name: (Optional) string name of the metric instance.
1106    dtype: (Optional) data type of the metric result.
1107
1108  Standalone usage:
1109
1110  >>> m = tf.keras.metrics.FalseNegatives()
1111  >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
1112  >>> m.result().numpy()
1113  2.0
1114
1115  >>> m.reset_state()
1116  >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])
1117  >>> m.result().numpy()
1118  1.0
1119
1120  Usage with `compile()` API:
1121
1122  ```python
1123  model.compile(optimizer='sgd',
1124                loss='mse',
1125                metrics=[tf.keras.metrics.FalseNegatives()])
1126  ```
1127  """
1128
1129  def __init__(self, thresholds=None, name=None, dtype=None):
1130    super(FalseNegatives, self).__init__(
1131        confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,
1132        thresholds=thresholds,
1133        name=name,
1134        dtype=dtype)
1135
1136
1137@keras_export('keras.metrics.TrueNegatives')
1138class TrueNegatives(_ConfusionMatrixConditionCount):
1139  """Calculates the number of true negatives.
1140
1141  If `sample_weight` is given, calculates the sum of the weights of
1142  true negatives. This metric creates one local variable, `accumulator`
1143  that is used to keep track of the number of true negatives.
1144
1145  If `sample_weight` is `None`, weights default to 1.
1146  Use `sample_weight` of 0 to mask values.
1147
1148  Args:
1149    thresholds: (Optional) Defaults to 0.5. A float value or a python
1150      list/tuple of float threshold values in [0, 1]. A threshold is compared
1151      with prediction values to determine the truth value of predictions
1152      (i.e., above the threshold is `true`, below is `false`). One metric
1153      value is generated for each threshold value.
1154    name: (Optional) string name of the metric instance.
1155    dtype: (Optional) data type of the metric result.
1156
1157  Standalone usage:
1158
1159  >>> m = tf.keras.metrics.TrueNegatives()
1160  >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
1161  >>> m.result().numpy()
1162  2.0
1163
1164  >>> m.reset_state()
1165  >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])
1166  >>> m.result().numpy()
1167  1.0
1168
1169  Usage with `compile()` API:
1170
1171  ```python
1172  model.compile(optimizer='sgd',
1173                loss='mse',
1174                metrics=[tf.keras.metrics.TrueNegatives()])
1175  ```
1176  """
1177
1178  def __init__(self, thresholds=None, name=None, dtype=None):
1179    super(TrueNegatives, self).__init__(
1180        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,
1181        thresholds=thresholds,
1182        name=name,
1183        dtype=dtype)
1184
1185
1186@keras_export('keras.metrics.TruePositives')
1187class TruePositives(_ConfusionMatrixConditionCount):
1188  """Calculates the number of true positives.
1189
1190  If `sample_weight` is given, calculates the sum of the weights of
1191  true positives. This metric creates one local variable, `accumulator`
1192  that is used to keep track of the number of true positives.
1193
1194  If `sample_weight` is `None`, weights default to 1.
1195  Use `sample_weight` of 0 to mask values.
1196
1197  Args:
1198    thresholds: (Optional) Defaults to 0.5. A float value or a python
1199      list/tuple of float threshold values in [0, 1]. A threshold is compared
1200      with prediction values to determine the truth value of predictions
1201      (i.e., above the threshold is `true`, below is `false`). One metric
1202      value is generated for each threshold value.
1203    name: (Optional) string name of the metric instance.
1204    dtype: (Optional) data type of the metric result.
1205
1206  Standalone usage:
1207
1208  >>> m = tf.keras.metrics.TruePositives()
1209  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1210  >>> m.result().numpy()
1211  2.0
1212
1213  >>> m.reset_state()
1214  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1215  >>> m.result().numpy()
1216  1.0
1217
1218  Usage with `compile()` API:
1219
1220  ```python
1221  model.compile(optimizer='sgd',
1222                loss='mse',
1223                metrics=[tf.keras.metrics.TruePositives()])
1224  ```
1225  """
1226
1227  def __init__(self, thresholds=None, name=None, dtype=None):
1228    super(TruePositives, self).__init__(
1229        confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,
1230        thresholds=thresholds,
1231        name=name,
1232        dtype=dtype)
1233
1234
1235@keras_export('keras.metrics.Precision')
1236class Precision(Metric):
1237  """Computes the precision of the predictions with respect to the labels.
1238
1239  The metric creates two local variables, `true_positives` and `false_positives`
1240  that are used to compute the precision. This value is ultimately returned as
1241  `precision`, an idempotent operation that simply divides `true_positives`
1242  by the sum of `true_positives` and `false_positives`.
1243
1244  If `sample_weight` is `None`, weights default to 1.
1245  Use `sample_weight` of 0 to mask values.
1246
1247  If `top_k` is set, precision is computed as how often, on average, a class
1248  among the top-k classes with the highest predicted values of a batch entry
1249  is correct, i.e. can be found in the labels for that entry.
1250
1251  If `class_id` is specified, we calculate precision by considering only the
1252  entries in the batch for which `class_id` is above the threshold and/or in the
1253  top-k highest predictions, and computing the fraction of them for which
1254  `class_id` is indeed a correct label.
1255
1256  Args:
1257    thresholds: (Optional) A float value or a python list/tuple of float
1258      threshold values in [0, 1]. A threshold is compared with prediction
1259      values to determine the truth value of predictions (i.e., above the
1260      threshold is `true`, below is `false`). One metric value is generated
1261      for each threshold value. If neither thresholds nor top_k are set, the
1262      default is to calculate precision with `thresholds=0.5`.
1263    top_k: (Optional) Unset by default. An int value specifying the top-k
1264      predictions to consider when calculating precision.
1265    class_id: (Optional) Integer class ID for which we want binary metrics.
1266      This must be in the half-open interval `[0, num_classes)`, where
1267      `num_classes` is the last dimension of predictions.
1268    name: (Optional) string name of the metric instance.
1269    dtype: (Optional) data type of the metric result.
1270
1271  Standalone usage:
1272
1273  >>> m = tf.keras.metrics.Precision()
1274  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1275  >>> m.result().numpy()
1276  0.6666667
1277
1278  >>> m.reset_state()
1279  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1280  >>> m.result().numpy()
1281  1.0
1282
1283  >>> # With top_k=2, it will calculate precision over y_true[:2] and y_pred[:2]
1284  >>> m = tf.keras.metrics.Precision(top_k=2)
1285  >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
1286  >>> m.result().numpy()
1287  0.0
1288
1289  >>> # With top_k=4, it will calculate precision over y_true[:4] and y_pred[:4]
1290  >>> m = tf.keras.metrics.Precision(top_k=4)
1291  >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
1292  >>> m.result().numpy()
1293  0.5
1294
1295  Usage with `compile()` API:
1296
1297  ```python
1298  model.compile(optimizer='sgd',
1299                loss='mse',
1300                metrics=[tf.keras.metrics.Precision()])
1301  ```
1302  """
1303
1304  def __init__(self,
1305               thresholds=None,
1306               top_k=None,
1307               class_id=None,
1308               name=None,
1309               dtype=None):
1310    super(Precision, self).__init__(name=name, dtype=dtype)
1311    self.init_thresholds = thresholds
1312    self.top_k = top_k
1313    self.class_id = class_id
1314
1315    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1316    self.thresholds = metrics_utils.parse_init_thresholds(
1317        thresholds, default_threshold=default_threshold)
1318    self._thresholds_distributed_evenly = (
1319        metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
1320    self.true_positives = self.add_weight(
1321        'true_positives',
1322        shape=(len(self.thresholds),),
1323        initializer=init_ops.zeros_initializer)
1324    self.false_positives = self.add_weight(
1325        'false_positives',
1326        shape=(len(self.thresholds),),
1327        initializer=init_ops.zeros_initializer)
1328
1329  def update_state(self, y_true, y_pred, sample_weight=None):
1330    """Accumulates true positive and false positive statistics.
1331
1332    Args:
1333      y_true: The ground truth values, with the same dimensions as `y_pred`.
1334        Will be cast to `bool`.
1335      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1336      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1337        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1338        be broadcastable to `y_true`.
1339
1340    Returns:
1341      Update op.
1342    """
1343    return metrics_utils.update_confusion_matrix_variables(
1344        {
1345            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1346            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives
1347        },
1348        y_true,
1349        y_pred,
1350        thresholds=self.thresholds,
1351        thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1352        top_k=self.top_k,
1353        class_id=self.class_id,
1354        sample_weight=sample_weight)
1355
1356  def result(self):
1357    result = math_ops.div_no_nan(self.true_positives,
1358                                 self.true_positives + self.false_positives)
1359    return result[0] if len(self.thresholds) == 1 else result
1360
1361  def reset_state(self):
1362    num_thresholds = len(to_list(self.thresholds))
1363    backend.batch_set_value([(v, np.zeros((num_thresholds,)))
1364                             for v in (self.true_positives,
1365                                       self.false_positives)])
1366
1367  def get_config(self):
1368    config = {
1369        'thresholds': self.init_thresholds,
1370        'top_k': self.top_k,
1371        'class_id': self.class_id
1372    }
1373    base_config = super(Precision, self).get_config()
1374    return dict(list(base_config.items()) + list(config.items()))
1375
1376
1377@keras_export('keras.metrics.Recall')
1378class Recall(Metric):
1379  """Computes the recall of the predictions with respect to the labels.
1380
1381  This metric creates two local variables, `true_positives` and
1382  `false_negatives`, that are used to compute the recall. This value is
1383  ultimately returned as `recall`, an idempotent operation that simply divides
1384  `true_positives` by the sum of `true_positives` and `false_negatives`.
1385
1386  If `sample_weight` is `None`, weights default to 1.
1387  Use `sample_weight` of 0 to mask values.
1388
1389  If `top_k` is set, recall will be computed as how often on average a class
1390  among the labels of a batch entry is in the top-k predictions.
1391
1392  If `class_id` is specified, we calculate recall by considering only the
1393  entries in the batch for which `class_id` is in the label, and computing the
1394  fraction of them for which `class_id` is above the threshold and/or in the
1395  top-k predictions.
1396
1397  Args:
1398    thresholds: (Optional) A float value or a python list/tuple of float
1399      threshold values in [0, 1]. A threshold is compared with prediction
1400      values to determine the truth value of predictions (i.e., above the
1401      threshold is `true`, below is `false`). One metric value is generated
1402      for each threshold value. If neither thresholds nor top_k are set, the
1403      default is to calculate recall with `thresholds=0.5`.
1404    top_k: (Optional) Unset by default. An int value specifying the top-k
1405      predictions to consider when calculating recall.
1406    class_id: (Optional) Integer class ID for which we want binary metrics.
1407      This must be in the half-open interval `[0, num_classes)`, where
1408      `num_classes` is the last dimension of predictions.
1409    name: (Optional) string name of the metric instance.
1410    dtype: (Optional) data type of the metric result.
1411
1412  Standalone usage:
1413
1414  >>> m = tf.keras.metrics.Recall()
1415  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
1416  >>> m.result().numpy()
1417  0.6666667
1418
1419  >>> m.reset_state()
1420  >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])
1421  >>> m.result().numpy()
1422  1.0
1423
1424  Usage with `compile()` API:
1425
1426  ```python
1427  model.compile(optimizer='sgd',
1428                loss='mse',
1429                metrics=[tf.keras.metrics.Recall()])
1430  ```
1431  """
1432
1433  def __init__(self,
1434               thresholds=None,
1435               top_k=None,
1436               class_id=None,
1437               name=None,
1438               dtype=None):
1439    super(Recall, self).__init__(name=name, dtype=dtype)
1440    self.init_thresholds = thresholds
1441    self.top_k = top_k
1442    self.class_id = class_id
1443
1444    default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF
1445    self.thresholds = metrics_utils.parse_init_thresholds(
1446        thresholds, default_threshold=default_threshold)
1447    self._thresholds_distributed_evenly = (
1448        metrics_utils.is_evenly_distributed_thresholds(self.thresholds))
1449    self.true_positives = self.add_weight(
1450        'true_positives',
1451        shape=(len(self.thresholds),),
1452        initializer=init_ops.zeros_initializer)
1453    self.false_negatives = self.add_weight(
1454        'false_negatives',
1455        shape=(len(self.thresholds),),
1456        initializer=init_ops.zeros_initializer)
1457
1458  def update_state(self, y_true, y_pred, sample_weight=None):
1459    """Accumulates true positive and false negative statistics.
1460
1461    Args:
1462      y_true: The ground truth values, with the same dimensions as `y_pred`.
1463        Will be cast to `bool`.
1464      y_pred: The predicted values. Each element must be in the range `[0, 1]`.
1465      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1466        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1467        be broadcastable to `y_true`.
1468
1469    Returns:
1470      Update op.
1471    """
1472    return metrics_utils.update_confusion_matrix_variables(
1473        {
1474            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1475            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives
1476        },
1477        y_true,
1478        y_pred,
1479        thresholds=self.thresholds,
1480        thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1481        top_k=self.top_k,
1482        class_id=self.class_id,
1483        sample_weight=sample_weight)
1484
1485  def result(self):
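    # recall = true_positives / (true_positives + false_negatives); a scalar is
    # returned when a single threshold is used, otherwise one value per threshold.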
1486    result = math_ops.div_no_nan(self.true_positives,
1487                                 self.true_positives + self.false_negatives)
1488    return result[0] if len(self.thresholds) == 1 else result
1489
1490  def reset_state(self):
1491    num_thresholds = len(to_list(self.thresholds))
1492    backend.batch_set_value([(v, np.zeros((num_thresholds,)))
1493                             for v in (self.true_positives,
1494                                       self.false_negatives)])
1495
1496  def get_config(self):
1497    config = {
1498        'thresholds': self.init_thresholds,
1499        'top_k': self.top_k,
1500        'class_id': self.class_id
1501    }
1502    base_config = super(Recall, self).get_config()
1503    return dict(list(base_config.items()) + list(config.items()))
1504
1505
1506class SensitivitySpecificityBase(Metric, metaclass=abc.ABCMeta):
1507  """Abstract base class for computing sensitivity and specificity.
1508
1509  For additional information about specificity and sensitivity, see
1510  [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1511  """
1512
1513  def __init__(self,
1514               value,
1515               num_thresholds=200,
1516               class_id=None,
1517               name=None,
1518               dtype=None):
1519    super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
1520    if num_thresholds <= 0:
1521      raise ValueError('`num_thresholds` must be > 0.')
1522    self.value = value
1523    self.class_id = class_id
1524    self.true_positives = self.add_weight(
1525        'true_positives',
1526        shape=(num_thresholds,),
1527        initializer=init_ops.zeros_initializer)
1528    self.true_negatives = self.add_weight(
1529        'true_negatives',
1530        shape=(num_thresholds,),
1531        initializer=init_ops.zeros_initializer)
1532    self.false_positives = self.add_weight(
1533        'false_positives',
1534        shape=(num_thresholds,),
1535        initializer=init_ops.zeros_initializer)
1536    self.false_negatives = self.add_weight(
1537        'false_negatives',
1538        shape=(num_thresholds,),
1539        initializer=init_ops.zeros_initializer)
1540
1541    # Compute `num_thresholds` thresholds in [0, 1]
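    # (e.g. num_thresholds=5 yields [0.0, 0.25, 0.5, 0.75, 1.0]).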
1542    if num_thresholds == 1:
1543      self.thresholds = [0.5]
1544      self._thresholds_distributed_evenly = False
1545    else:
1546      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
1547                    for i in range(num_thresholds - 2)]
1548      self.thresholds = [0.0] + thresholds + [1.0]
1549      self._thresholds_distributed_evenly = True
1550
1551  def update_state(self, y_true, y_pred, sample_weight=None):
1552    """Accumulates confusion matrix statistics.
1553
1554    Args:
1555      y_true: The ground truth values.
1556      y_pred: The predicted values.
1557      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
1558        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
1559        be broadcastable to `y_true`.
1560
1561    Returns:
1562      Update op.
1563    """
1564    return metrics_utils.update_confusion_matrix_variables(
1565        {
1566            metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,
1567            metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,
1568            metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,
1569            metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,
1570        },
1571        y_true,
1572        y_pred,
1573        thresholds=self.thresholds,
1574        thresholds_distributed_evenly=self._thresholds_distributed_evenly,
1575        class_id=self.class_id,
1576        sample_weight=sample_weight)
1577
1578  def reset_state(self):
1579    num_thresholds = len(self.thresholds)
1580    confusion_matrix_variables = (self.true_positives, self.true_negatives,
1581                                  self.false_positives, self.false_negatives)
1582    backend.batch_set_value([
1583        (v, np.zeros((num_thresholds,))) for v in confusion_matrix_variables
1584    ])
1585
1586  def get_config(self):
1587    config = {'class_id': self.class_id}
1588    base_config = super(SensitivitySpecificityBase, self).get_config()
1589    return dict(list(base_config.items()) + list(config.items()))
1590
1591  def _find_max_under_constraint(self, constrained, dependent, predicate):
1592    """Returns the maximum of dependent_statistic that satisfies the constraint.
1593
1594    Args:
1595      constrained: Over these values the constraint
1596        is specified. A rank-1 tensor.
1597      dependent: From these values the maximum that satiesfies the
1598        constraint is selected. Values in this tensor and in
1599        `constrained` are linked by having the same threshold at each
1600        position, hence this tensor must have the same shape.
1601      predicate: A binary boolean functor to be applied to arguments
1602      `constrained` and `self.value`, e.g. `tf.greater`.
1603
1604    Returns maximal dependent value, if no value satiesfies the constraint 0.0.
1605    """
1606    feasible = array_ops.where_v2(predicate(constrained, self.value))
1607    feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
1608    max_dependent = math_ops.reduce_max(array_ops.gather(dependent, feasible))
1609
1610    return array_ops.where_v2(feasible_exists, max_dependent, 0.0)
1611
1612
1613@keras_export('keras.metrics.SensitivityAtSpecificity')
1614class SensitivityAtSpecificity(SensitivitySpecificityBase):
1615  """Computes best sensitivity where specificity is >= specified value.
1618
1619  `Sensitivity` measures the proportion of actual positives that are correctly
1620  identified as such (tp / (tp + fn)).
1621  `Specificity` measures the proportion of actual negatives that are correctly
1622  identified as such (tn / (tn + fp)).
1623
1624  This metric creates four local variables, `true_positives`, `true_negatives`,
1625  `false_positives` and `false_negatives` that are used to compute the
1626  sensitivity at the given specificity. The threshold for the given specificity
1627  value is computed and used to evaluate the corresponding sensitivity.
1628
1629  If `sample_weight` is `None`, weights default to 1.
1630  Use `sample_weight` of 0 to mask values.
1631
1632  If `class_id` is specified, we calculate the metric by considering only the
1633  entries in the batch for which the prediction for `class_id` is above the
1634  threshold, and computing the fraction of them for which `class_id` is indeed
1635  a correct label.
1636
1637  For additional information about specificity and sensitivity, see
1638  [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1639
1640  Args:
1641    specificity: A scalar value in range `[0, 1]`.
1642    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1643      use for matching the given specificity.
1644    class_id: (Optional) Integer class ID for which we want binary metrics.
1645      This must be in the half-open interval `[0, num_classes)`, where
1646      `num_classes` is the last dimension of predictions.
1647    name: (Optional) string name of the metric instance.
1648    dtype: (Optional) data type of the metric result.
1649
1650  Standalone usage:
1651
1652  >>> m = tf.keras.metrics.SensitivityAtSpecificity(0.5)
1653  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1654  >>> m.result().numpy()
1655  0.5
1656
1657  >>> m.reset_state()
1658  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1659  ...                sample_weight=[1, 1, 2, 2, 1])
1660  >>> m.result().numpy()
1661  0.333333
1662
1663  Usage with `compile()` API:
1664
1665  ```python
1666  model.compile(
1667      optimizer='sgd',
1668      loss='mse',
1669      metrics=[tf.keras.metrics.SensitivityAtSpecificity(specificity=0.5)])
1670  ```
1671  """
1672
1673  def __init__(self,
1674               specificity,
1675               num_thresholds=200,
1676               class_id=None,
1677               name=None,
1678               dtype=None):
1679    if specificity < 0 or specificity > 1:
1680      raise ValueError('`specificity` must be in the range [0, 1].')
1681    self.specificity = specificity
1682    self.num_thresholds = num_thresholds
1683    super(SensitivityAtSpecificity, self).__init__(
1684        specificity,
1685        num_thresholds=num_thresholds,
1686        class_id=class_id,
1687        name=name,
1688        dtype=dtype)
1689
1690  def result(self):
1691    specificities = math_ops.div_no_nan(
1692        self.true_negatives, self.true_negatives + self.false_positives)
1693    sensitivities = math_ops.div_no_nan(
1694        self.true_positives, self.true_positives + self.false_negatives)
1695    return self._find_max_under_constraint(
1696        specificities, sensitivities, math_ops.greater_equal)
1697
1698  def get_config(self):
1699    config = {
1700        'num_thresholds': self.num_thresholds,
1701        'specificity': self.specificity
1702    }
1703    base_config = super(SensitivityAtSpecificity, self).get_config()
1704    return dict(list(base_config.items()) + list(config.items()))
1705
1706
1707@keras_export('keras.metrics.SpecificityAtSensitivity')
1708class SpecificityAtSensitivity(SensitivitySpecificityBase):
1709  """Computes best specificity where sensitivity is >= specified value.
1710
1711  `Sensitivity` measures the proportion of actual positives that are correctly
1712  identified as such (tp / (tp + fn)).
1713  `Specificity` measures the proportion of actual negatives that are correctly
1714  identified as such (tn / (tn + fp)).
1715
1716  This metric creates four local variables, `true_positives`, `true_negatives`,
1717  `false_positives` and `false_negatives` that are used to compute the
1718  specificity at the given sensitivity. The threshold for the given sensitivity
1719  value is computed and used to evaluate the corresponding specificity.
1720
1721  If `sample_weight` is `None`, weights default to 1.
1722  Use `sample_weight` of 0 to mask values.
1723
1724  If `class_id` is specified, we calculate the metric by considering only the
1725  entries in the batch for which the prediction for `class_id` is above the
1726  threshold, and computing the fraction of them for which `class_id` is indeed
1727  a correct label.
1728
1729  For additional information about specificity and sensitivity, see
1730  [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
1731
1732  Args:
1733    sensitivity: A scalar value in range `[0, 1]`.
1734    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1735      use for matching the given sensitivity.
1736    class_id: (Optional) Integer class ID for which we want binary metrics.
1737      This must be in the half-open interval `[0, num_classes)`, where
1738      `num_classes` is the last dimension of predictions.
1739    name: (Optional) string name of the metric instance.
1740    dtype: (Optional) data type of the metric result.
1741
1742  Standalone usage:
1743
1744  >>> m = tf.keras.metrics.SpecificityAtSensitivity(0.5)
1745  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1746  >>> m.result().numpy()
1747  0.66666667
1748
1749  >>> m.reset_state()
1750  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1751  ...                sample_weight=[1, 1, 2, 2, 2])
1752  >>> m.result().numpy()
1753  0.5
1754
1755  Usage with `compile()` API:
1756
1757  ```python
1758  model.compile(
1759      optimizer='sgd',
1760      loss='mse',
1761      metrics=[tf.keras.metrics.SpecificityAtSensitivity(sensitivity=0.5)])
1762  ```
1763  """
1764
1765  def __init__(self,
1766               sensitivity,
1767               num_thresholds=200,
1768               class_id=None,
1769               name=None,
1770               dtype=None):
1771    if sensitivity < 0 or sensitivity > 1:
1772      raise ValueError('`sensitivity` must be in the range [0, 1].')
1773    self.sensitivity = sensitivity
1774    self.num_thresholds = num_thresholds
1775    super(SpecificityAtSensitivity, self).__init__(
1776        sensitivity,
1777        num_thresholds=num_thresholds,
1778        class_id=class_id,
1779        name=name,
1780        dtype=dtype)
1781
1782  def result(self):
1783    sensitivities = math_ops.div_no_nan(
1784        self.true_positives, self.true_positives + self.false_negatives)
1785    specificities = math_ops.div_no_nan(
1786        self.true_negatives, self.true_negatives + self.false_positives)
1787    return self._find_max_under_constraint(
1788        sensitivities, specificities, math_ops.greater_equal)
1789
1790  def get_config(self):
1791    config = {
1792        'num_thresholds': self.num_thresholds,
1793        'sensitivity': self.sensitivity
1794    }
1795    base_config = super(SpecificityAtSensitivity, self).get_config()
1796    return dict(list(base_config.items()) + list(config.items()))
1797
1798
1799@keras_export('keras.metrics.PrecisionAtRecall')
1800class PrecisionAtRecall(SensitivitySpecificityBase):
1801  """Computes best precision where recall is >= specified value.
1802
1803  This metric creates four local variables, `true_positives`, `true_negatives`,
1804  `false_positives` and `false_negatives` that are used to compute the
1805  precision at the given recall. The threshold for the given recall
1806  value is computed and used to evaluate the corresponding precision.
1807
1808  If `sample_weight` is `None`, weights default to 1.
1809  Use `sample_weight` of 0 to mask values.
1810
1811  If `class_id` is specified, we calculate precision by considering only the
1812  entries in the batch for which the prediction for `class_id` is above the
1813  threshold, and computing the fraction of them for which `class_id` is indeed
1814  a correct label.
1815
1816  Args:
1817    recall: A scalar value in range `[0, 1]`.
1818    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1819      use for matching the given recall.
1820    class_id: (Optional) Integer class ID for which we want binary metrics.
1821      This must be in the half-open interval `[0, num_classes)`, where
1822      `num_classes` is the last dimension of predictions.
1823    name: (Optional) string name of the metric instance.
1824    dtype: (Optional) data type of the metric result.
1825
1826  Standalone usage:
1827
1828  >>> m = tf.keras.metrics.PrecisionAtRecall(0.5)
1829  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])
1830  >>> m.result().numpy()
1831  0.5
1832
1833  >>> m.reset_state()
1834  >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],
1835  ...                sample_weight=[2, 2, 2, 1, 1])
1836  >>> m.result().numpy()
1837  0.33333333
1838
1839  Usage with `compile()` API:
1840
1841  ```python
1842  model.compile(
1843      optimizer='sgd',
1844      loss='mse',
1845      metrics=[tf.keras.metrics.PrecisionAtRecall(recall=0.8)])
1846  ```
1847  """
1848
1849  def __init__(self,
1850               recall,
1851               num_thresholds=200,
1852               class_id=None,
1853               name=None,
1854               dtype=None):
1855    if recall < 0 or recall > 1:
1856      raise ValueError('`recall` must be in the range [0, 1].')
1857    self.recall = recall
1858    self.num_thresholds = num_thresholds
1859    super(PrecisionAtRecall, self).__init__(
1860        value=recall,
1861        num_thresholds=num_thresholds,
1862        class_id=class_id,
1863        name=name,
1864        dtype=dtype)
1865
1866  def result(self):
1867    recalls = math_ops.div_no_nan(
1868        self.true_positives, self.true_positives + self.false_negatives)
1869    precisions = math_ops.div_no_nan(
1870        self.true_positives, self.true_positives + self.false_positives)
1871    return self._find_max_under_constraint(
1872        recalls, precisions, math_ops.greater_equal)
1873
1874  def get_config(self):
1875    config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
1876    base_config = super(PrecisionAtRecall, self).get_config()
1877    return dict(list(base_config.items()) + list(config.items()))
1878
1879
1880@keras_export('keras.metrics.RecallAtPrecision')
1881class RecallAtPrecision(SensitivitySpecificityBase):
1882  """Computes best recall where precision is >= specified value.
1883
1884  For a given score-label-distribution the required precision might not
1885  be achievable, in this case 0.0 is returned as recall.
1886
1887  This metric creates four local variables, `true_positives`, `true_negatives`,
1888  `false_positives` and `false_negatives` that are used to compute the
1889  recall at the given precision. The threshold for the given precision
1890  value is computed and used to evaluate the corresponding recall.
1891
1892  If `sample_weight` is `None`, weights default to 1.
1893  Use `sample_weight` of 0 to mask values.
1894
1895  If `class_id` is specified, we calculate precision by considering only the
1896  entries in the batch for which the prediction for `class_id` is above the
1897  threshold, and computing the fraction of them for which `class_id` is indeed
1898  a correct label.
1899
1900  Args:
1901    precision: A scalar value in range `[0, 1]`.
1902    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
1903      use for matching the given precision.
1904    class_id: (Optional) Integer class ID for which we want binary metrics.
1905      This must be in the half-open interval `[0, num_classes)`, where
1906      `num_classes` is the last dimension of predictions.
1907    name: (Optional) string name of the metric instance.
1908    dtype: (Optional) data type of the metric result.
1909
1910  Standalone usage:
1911
1912  >>> m = tf.keras.metrics.RecallAtPrecision(0.8)
1913  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
1914  >>> m.result().numpy()
1915  0.5
1916
1917  >>> m.reset_state()
1918  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
1919  ...                sample_weight=[1, 0, 0, 1])
1920  >>> m.result().numpy()
1921  1.0
1922
1923  Usage with `compile()` API:
1924
1925  ```python
1926  model.compile(
1927      optimizer='sgd',
1928      loss='mse',
1929      metrics=[tf.keras.metrics.RecallAtPrecision(precision=0.8)])
1930  ```
1931  """
1932
1933  def __init__(self,
1934               precision,
1935               num_thresholds=200,
1936               class_id=None,
1937               name=None,
1938               dtype=None):
1939    if precision < 0 or precision > 1:
1940      raise ValueError('`precision` must be in the range [0, 1].')
1941    self.precision = precision
1942    self.num_thresholds = num_thresholds
1943    super(RecallAtPrecision, self).__init__(
1944        value=precision,
1945        num_thresholds=num_thresholds,
1946        class_id=class_id,
1947        name=name,
1948        dtype=dtype)
1949
1950  def result(self):
1951    precisions = math_ops.div_no_nan(
1952        self.true_positives, self.true_positives + self.false_positives)
1953    recalls = math_ops.div_no_nan(
1954        self.true_positives, self.true_positives + self.false_negatives)
1955    return self._find_max_under_constraint(
1956        precisions, recalls, math_ops.greater_equal)
1957
1958  def get_config(self):
1959    config = {'num_thresholds': self.num_thresholds,
1960              'precision': self.precision}
1961    base_config = super(RecallAtPrecision, self).get_config()
1962    return dict(list(base_config.items()) + list(config.items()))
1963
1964
1965@keras_export('keras.metrics.AUC')
1966class AUC(Metric):
1967  """Approximates the AUC (Area under the curve) of the ROC or PR curves.
1968
1969  The AUC (Area under the curve) of the ROC (Receiver operating
1970  characteristic; default) or PR (Precision Recall) curves is a quality measure
1971  of binary classifiers. Unlike accuracy, and like cross-entropy
1972  losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.
1973
1974  This class approximates AUCs using a Riemann sum. During the metric
1975  accumulation phase, predictions are accumulated within predefined buckets
1976  by value. The AUC is then computed by interpolating per-bucket averages. These
1977  buckets define the evaluated operational points.
1978
1979  This metric creates four local variables, `true_positives`, `true_negatives`,
1980  `false_positives` and `false_negatives` that are used to compute the AUC.
1981  To discretize the AUC curve, a linearly spaced set of thresholds is used to
1982  compute pairs of recall and precision values. The area under the ROC-curve is
1983  therefore computed using the height of the recall values by the false positive
1984  rate, while the area under the PR-curve is computed using the height of
1985  the precision values by the recall.
1986
1987  This value is ultimately returned as `auc`, an idempotent operation that
1988  computes the area under a discretized curve of precision versus recall values
1989  (computed using the aforementioned variables). The `num_thresholds` variable
1990  controls the degree of discretization, with larger numbers of thresholds more
1991  closely approximating the true AUC. The quality of the approximation may vary
1992  dramatically depending on `num_thresholds`. The `thresholds` parameter can be
1993  used to manually specify thresholds which split the predictions more evenly.
1994
1995  For the best approximation of the real AUC, `predictions` should be distributed
1996  approximately uniformly in the range [0, 1] (if `from_logits=False`). The
1997  quality of the AUC approximation may be poor if this is not the case. Setting
1998  `summation_method` to 'minoring' or 'majoring' can help quantify the error in
1999  the approximation by providing a lower or upper bound estimate of the AUC.
2000
2001  If `sample_weight` is `None`, weights default to 1.
2002  Use `sample_weight` of 0 to mask values.
2003
2004  Args:
2005    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
2006      use when discretizing the roc curve. Values must be > 1.
2007    curve: (Optional) Specifies the name of the curve to be computed, 'ROC'
2008      [default] or 'PR' for the Precision-Recall-curve.
2009    summation_method: (Optional) Specifies the [Riemann summation method](
2010        https://en.wikipedia.org/wiki/Riemann_sum) used.
2011        'interpolation' (default) applies the mid-point summation scheme for `ROC`.
2012        For PR-AUC, interpolates (true/false) positives but not the ratio that
2013        is precision (see Davis & Goadrich 2006 for details);
2014        'minoring' applies left summation
2015        for increasing intervals and right summation for decreasing intervals;
2016        'majoring' does the opposite.
2017    name: (Optional) string name of the metric instance.
2018    dtype: (Optional) data type of the metric result.
2019    thresholds: (Optional) A list of floating point values to use as the
2020      thresholds for discretizing the curve. If set, the `num_thresholds`
2021      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
2022      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will
2023      be automatically included with these to correctly handle predictions
2024      equal to exactly 0 or 1.
2025    multi_label: boolean indicating whether multilabel data should be
2026      treated as such, wherein AUC is computed separately for each label and
2027  then averaged across labels, or (when False) whether the data should be
2028      flattened into a single label before AUC computation. In the latter
2029      case, when multilabel data is passed to AUC, each label-prediction pair
2030      is treated as an individual data point. Should be set to False for
2031      multi-class data.
2032    num_labels: (Optional) The number of labels, used when `multi_label` is
2033      True. If `num_labels` is not specified, then state variables get created
2034      on the first call to `update_state`.
2035    label_weights: (Optional) list, array, or tensor of non-negative weights
2036      used to compute AUCs for multilabel data. When `multi_label` is True,
2037      the weights are applied to the individual label AUCs when they are
2038      averaged to produce the multi-label AUC. When it's False, they are used
2039      to weight the individual label predictions in computing the confusion
2040      matrix on the flattened data. Note that this is unlike class_weights in
2041      that class_weights weights the example depending on the value of its
2042      label, whereas label_weights depends only on the index of that label
2043      before flattening; therefore `label_weights` should not be used for
2044      multi-class data.
2045    from_logits: boolean indicating whether the predictions (`y_pred` in
2046      `update_state`) are probabilities or sigmoid logits. As a rule of thumb,
2047      when using a keras loss, the `from_logits` constructor argument of the
2048      loss should match the AUC `from_logits` constructor argument.
2049
2050  Standalone usage:
2051
2052  >>> m = tf.keras.metrics.AUC(num_thresholds=3)
2053  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
2054  >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
2055  >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
2056  >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
2057  >>> # auc = ((((1+0.5)/2)*(1-0)) + (((0.5+0)/2)*(0-0))) = 0.75
2058  >>> m.result().numpy()
2059  0.75
2060
2061  >>> m.reset_state()
2062  >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],
2063  ...                sample_weight=[1, 0, 0, 1])
2064  >>> m.result().numpy()
2065  1.0
2066
2067  Usage with `compile()` API:
2068
2069  ```python
2070  # Reports the AUC of a model outputting a probability.
2071  model.compile(optimizer='sgd',
2072                loss=tf.keras.losses.BinaryCrossentropy(),
2073                metrics=[tf.keras.metrics.AUC()])
2074
2075  # Reports the AUC of a model outputting a logit.
2076  model.compile(optimizer='sgd',
2077                loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
2078                metrics=[tf.keras.metrics.AUC(from_logits=True)])
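
  # A sketch of multi-label usage (illustrative; assumes a model with 5
  # independent sigmoid outputs): AUC is computed per label, then averaged.
  model.compile(optimizer='sgd',
                loss=tf.keras.losses.BinaryCrossentropy(),
                metrics=[tf.keras.metrics.AUC(multi_label=True, num_labels=5)])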
2079  ```
2080  """
2081
2082  def __init__(self,
2083               num_thresholds=200,
2084               curve='ROC',
2085               summation_method='interpolation',
2086               name=None,
2087               dtype=None,
2088               thresholds=None,
2089               multi_label=False,
2090               num_labels=None,
2091               label_weights=None,
2092               from_logits=False):
2093    # Validate configurations.
2094    if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(
2095        metrics_utils.AUCCurve):
2096      raise ValueError('Invalid curve: "{}". Valid options are: "{}"'.format(
2097          curve, list(metrics_utils.AUCCurve)))
2098    if isinstance(
2099        summation_method,
2100        metrics_utils.AUCSummationMethod) and summation_method not in list(
2101            metrics_utils.AUCSummationMethod):
2102      raise ValueError(
2103          'Invalid summation method: "{}". Valid options are: "{}"'.format(
2104              summation_method, list(metrics_utils.AUCSummationMethod)))
2105
2106    # Update properties.
2107    if thresholds is not None:
2108      # If specified, use the supplied thresholds.
2109      self.num_thresholds = len(thresholds) + 2
2110      thresholds = sorted(thresholds)
2111      self._thresholds_distributed_evenly = (
2112          metrics_utils.is_evenly_distributed_thresholds(
2113              np.array([0.0] + thresholds + [1.0])))
2114    else:
2115      if num_thresholds <= 1:
2116        raise ValueError('`num_thresholds` must be > 1.')
2117
2118      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
2119      # (0, 1).
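      # (e.g. num_thresholds=5 gives interior thresholds [0.25, 0.5, 0.75]).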
2120      self.num_thresholds = num_thresholds
2121      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
2122                    for i in range(num_thresholds - 2)]
2123      self._thresholds_distributed_evenly = True
2124
2125    # Add an endpoint "threshold" below zero and above one for either
2126    # threshold method to account for floating point imprecisions.
2127    self._thresholds = np.array([0.0 - backend.epsilon()] + thresholds +
2128                                [1.0 + backend.epsilon()])
2129
2130    if isinstance(curve, metrics_utils.AUCCurve):
2131      self.curve = curve
2132    else:
2133      self.curve = metrics_utils.AUCCurve.from_str(curve)
2134    if isinstance(summation_method, metrics_utils.AUCSummationMethod):
2135      self.summation_method = summation_method
2136    else:
2137      self.summation_method = metrics_utils.AUCSummationMethod.from_str(
2138          summation_method)
2139    super(AUC, self).__init__(name=name, dtype=dtype)
2140
2141    # Handle multilabel arguments.
2142    self.multi_label = multi_label
2143    if label_weights is not None:
2144      label_weights = constant_op.constant(label_weights, dtype=self.dtype)
2145      checks = [
2146          check_ops.assert_non_negative(
2147              label_weights,
2148              message='All values of `label_weights` must be non-negative.')
2149      ]
2150      with ops.control_dependencies(checks):
2151        self.label_weights = label_weights
2152
2153    else:
2154      self.label_weights = None
2155
2156    self._from_logits = from_logits
2157
2158    self._built = False
2159    if self.multi_label:
2160      if num_labels:
2161        shape = tensor_shape.TensorShape([None, num_labels])
2162        self._build(shape)
2163    else:
2164      if num_labels:
2165        raise ValueError(
2166            '`num_labels` is needed only when `multi_label` is True.')
2167      self._build(None)
2168
2169  @property
2170  def thresholds(self):
2171    """The thresholds used for evaluating AUC."""
2172    return list(self._thresholds)
2173
2174  def _build(self, shape):
2175    """Initialize TP, FP, TN, and FN tensors, given the shape of the data."""
2176    if self.multi_label:
2177      if shape.ndims != 2:
2178        raise ValueError('`y_true` must have rank=2 when `multi_label` is '
2179                         'True. Found rank %s.' % shape.ndims)
2180      self._num_labels = shape[1]
2181      variable_shape = tensor_shape.TensorShape(
2182          [tensor_shape.Dimension(self.num_thresholds), self._num_labels])
2183
2184    else:
2185      variable_shape = tensor_shape.TensorShape(
2186          [tensor_shape.Dimension(self.num_thresholds)])
2187    self._build_input_shape = shape
2188    # Create metric variables
2189    self.true_positives = self.add_weight(
2190        'true_positives',
2191        shape=variable_shape,
2192        initializer=init_ops.zeros_initializer)
2193    self.true_negatives = self.add_weight(
2194        'true_negatives',
2195        shape=variable_shape,
2196        initializer=init_ops.zeros_initializer)
2197    self.false_positives = self.add_weight(
2198        'false_positives',
2199        shape=variable_shape,
2200        initializer=init_ops.zeros_initializer)
2201    self.false_negatives = self.add_weight(
2202        'false_negatives',
2203        shape=variable_shape,
2204        initializer=init_ops.zeros_initializer)
2205
2206    if self.multi_label:
2207      with ops.init_scope():
2208        # This should only be necessary for handling v1 behavior. In v2, AUC
2209        # should be initialized outside of any tf.functions, and therefore in
2210        # eager mode.
2211        if not context.executing_eagerly():
2212          backend._initialize_variables(backend._get_session())  # pylint: disable=protected-access
2213
2214    self._built = True
2215
2216  def update_state(self, y_true, y_pred, sample_weight=None):
2217    """Accumulates confusion matrix statistics.
2218
2219    Args:
2220      y_true: The ground truth values.
2221      y_pred: The predicted values.
2222      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2223        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2224        be broadcastable to `y_true`.
2225
2226    Returns:
2227      Update op.
2228    """
2229    deps = []
2230    if not self._built:
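      # Build lazily, deducing the variable shapes (and, for multi-label data,
      # the number of labels) from the first batch of predictions.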
2231      self._build(tensor_shape.TensorShape(y_pred.shape))
2232
2233    if self.multi_label or (self.label_weights is not None):
2234      # y_true should have shape (number of examples, number of labels).
2235      shapes = [
2236          (y_true, ('N', 'L'))
2237      ]
2238      if self.multi_label:
2239        # TP, TN, FP, and FN should all have shape
2240        # (number of thresholds, number of labels).
2241        shapes.extend([(self.true_positives, ('T', 'L')),
2242                       (self.true_negatives, ('T', 'L')),
2243                       (self.false_positives, ('T', 'L')),
2244                       (self.false_negatives, ('T', 'L'))])
2245      if self.label_weights is not None:
2246        # label_weights should be of length equal to the number of labels.
2247        shapes.append((self.label_weights, ('L',)))
2248      deps = [
2249          check_ops.assert_shapes(
2250              shapes, message='Number of labels is not consistent.')
2251      ]
2252
2253    # Only forward label_weights to update_confusion_matrix_variables when
2254    # multi_label is False. Otherwise the averaging of individual label AUCs is
2255    # handled in AUC.result
2256    label_weights = None if self.multi_label else self.label_weights
2257
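    # Sigmoid logits are converted to probabilities before thresholding.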
2258    if self._from_logits:
2259      y_pred = activations.sigmoid(y_pred)
2260
2261    with ops.control_dependencies(deps):
2262      return metrics_utils.update_confusion_matrix_variables(
2263          {
2264              metrics_utils.ConfusionMatrix.TRUE_POSITIVES:
2265                  self.true_positives,
2266              metrics_utils.ConfusionMatrix.TRUE_NEGATIVES:
2267                  self.true_negatives,
2268              metrics_utils.ConfusionMatrix.FALSE_POSITIVES:
2269                  self.false_positives,
2270              metrics_utils.ConfusionMatrix.FALSE_NEGATIVES:
2271                  self.false_negatives,
2272          },
2273          y_true,
2274          y_pred,
2275          self._thresholds,
2276          thresholds_distributed_evenly=self._thresholds_distributed_evenly,
2277          sample_weight=sample_weight,
2278          multi_label=self.multi_label,
2279          label_weights=label_weights)
2280
2281  def interpolate_pr_auc(self):
2282    """Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
2283
2284    https://www.biostat.wisc.edu/~page/rocpr.pdf
2285
2286    Note here we derive & use a closed formula not present in the paper
2287    as follows:
2288
2289      Precision = TP / (TP + FP) = TP / P
2290
2291    Modeling all of TP (true positive), FP (false positive) and their sum
2292    P = TP + FP (predicted positive) as varying linearly within each interval
2293    [A, B] between successive thresholds, we get
2294
2295      Precision slope = dTP / dP
2296                      = (TP_B - TP_A) / (P_B - P_A)
2297                      = (TP - TP_A) / (P - P_A)
2298      Precision = (TP_A + slope * (P - P_A)) / P
2299
2300    The area within the interval is (slope / total_pos_weight) times
2301
2302      int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
2303      int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
2304
2305    where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
2306
2307      int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
2308
2309    Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
2310
2311      slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
2312
2313    where dTP == TP_B - TP_A.
2314
2315    Note that when P_A == 0 the above calculation simplifies into
2316
2317      int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
2318
2319    which is really equivalent to imputing constant precision throughout the
2320    first bucket having >0 true positives.
2321
2322    Returns:
2323      pr_auc: an approximation of the area under the P-R curve.
2324    """
2325    dtp = self.true_positives[:self.num_thresholds -
2326                              1] - self.true_positives[1:]
2327    p = self.true_positives + self.false_positives
2328    dp = p[:self.num_thresholds - 1] - p[1:]
2329    prec_slope = math_ops.div_no_nan(
2330        dtp, math_ops.maximum(dp, 0), name='prec_slope')
2331    intercept = self.true_positives[1:] - math_ops.multiply(prec_slope, p[1:])
2332
2333    safe_p_ratio = array_ops.where(
2334        math_ops.logical_and(p[:self.num_thresholds - 1] > 0, p[1:] > 0),
2335        math_ops.div_no_nan(
2336            p[:self.num_thresholds - 1],
2337            math_ops.maximum(p[1:], 0),
2338            name='recall_relative_ratio'),
2339        array_ops.ones_like(p[1:]))
2340
2341    pr_auc_increment = math_ops.div_no_nan(
2342        prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
2343        math_ops.maximum(self.true_positives[1:] + self.false_negatives[1:], 0),
2344        name='pr_auc_increment')
2345
2346    if self.multi_label:
2347      by_label_auc = math_ops.reduce_sum(
2348          pr_auc_increment, name=self.name + '_by_label', axis=0)
2349      if self.label_weights is None:
2350        # Evenly weighted average of the label AUCs.
2351        return math_ops.reduce_mean(by_label_auc, name=self.name)
2352      else:
2353        # Weighted average of the label AUCs.
2354        return math_ops.div_no_nan(
2355            math_ops.reduce_sum(
2356                math_ops.multiply(by_label_auc, self.label_weights)),
2357            math_ops.reduce_sum(self.label_weights),
2358            name=self.name)
2359    else:
2360      return math_ops.reduce_sum(pr_auc_increment, name='interpolate_pr_auc')
2361
2362  def result(self):
2363    if (self.curve == metrics_utils.AUCCurve.PR and
2364        self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION
2365       ):
2366      # This use case is different and is handled separately.
2367      return self.interpolate_pr_auc()
2368
2369    # Set `x` and `y` values for the curves based on `curve` config.
2370    recall = math_ops.div_no_nan(self.true_positives,
2371                                 self.true_positives + self.false_negatives)
2372    if self.curve == metrics_utils.AUCCurve.ROC:
2373      fp_rate = math_ops.div_no_nan(self.false_positives,
2374                                    self.false_positives + self.true_negatives)
2375      x = fp_rate
2376      y = recall
2377    else:  # curve == 'PR'.
2378      precision = math_ops.div_no_nan(
2379          self.true_positives, self.true_positives + self.false_positives)
2380      x = recall
2381      y = precision
2382
2383    # Find the rectangle heights based on `summation_method`.
2384    if self.summation_method == metrics_utils.AUCSummationMethod.INTERPOLATION:
2385      # Note: the case ('PR', 'interpolation') has been handled above.
2386      heights = (y[:self.num_thresholds - 1] + y[1:]) / 2.
2387    elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:
2388      heights = math_ops.minimum(y[:self.num_thresholds - 1], y[1:])
2389    else:  # self.summation_method = metrics_utils.AUCSummationMethod.MAJORING:
2390      heights = math_ops.maximum(y[:self.num_thresholds - 1], y[1:])
2391
2392    # Sum up the areas of all the rectangles.
2393    if self.multi_label:
2394      riemann_terms = math_ops.multiply(x[:self.num_thresholds - 1] - x[1:],
2395                                        heights)
2396      by_label_auc = math_ops.reduce_sum(
2397          riemann_terms, name=self.name + '_by_label', axis=0)
2398
2399      if self.label_weights is None:
2400        # Unweighted average of the label AUCs.
2401        return math_ops.reduce_mean(by_label_auc, name=self.name)
2402      else:
2403        # Weighted average of the label AUCs.
2404        return math_ops.div_no_nan(
2405            math_ops.reduce_sum(
2406                math_ops.multiply(by_label_auc, self.label_weights)),
2407            math_ops.reduce_sum(self.label_weights),
2408            name=self.name)
2409    else:
2410      return math_ops.reduce_sum(
2411          math_ops.multiply(x[:self.num_thresholds - 1] - x[1:], heights),
2412          name=self.name)
2413
2414  def reset_state(self):
2415    if self._built:
2416      confusion_matrix_variables = (self.true_positives, self.true_negatives,
2417                                    self.false_positives, self.false_negatives)
2418      if self.multi_label:
2419        backend.batch_set_value(
2420            [(v, np.zeros((self.num_thresholds, self._num_labels)))
2421             for v in confusion_matrix_variables])
2422      else:
2423        backend.batch_set_value([(v, np.zeros((self.num_thresholds,)))
2424                                 for v in confusion_matrix_variables])
2425
2426  def get_config(self):
2427    if is_tensor_or_variable(self.label_weights):
2428      label_weights = backend.eval(self.label_weights)
2429    else:
2430      label_weights = self.label_weights
2431    config = {
2432        'num_thresholds': self.num_thresholds,
2433        'curve': self.curve.value,
2434        'summation_method': self.summation_method.value,
2435        # We remove the endpoint thresholds as an inverse of how the thresholds
2436        # were initialized. This ensures that a metric initialized from this
2437        # config has the same thresholds.
2438        'thresholds': self.thresholds[1:-1],
2439        'multi_label': self.multi_label,
2440        'label_weights': label_weights
2441    }
2442    base_config = super(AUC, self).get_config()
2443    return dict(list(base_config.items()) + list(config.items()))
2444
2445
2446@keras_export('keras.metrics.CosineSimilarity')
2447class CosineSimilarity(MeanMetricWrapper):
2448  """Computes the cosine similarity between the labels and predictions.
2449
2450  `cosine similarity = (a . b) / ||a|| ||b||`
2451
2452  See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity).
2453
2454  This metric keeps the average cosine similarity between `predictions` and
2455  `labels` over a stream of data.
2456
2457  Args:
2458    name: (Optional) string name of the metric instance.
2459    dtype: (Optional) data type of the metric result.
2460    axis: (Optional) Defaults to -1. The dimension along which the cosine
2461      similarity is computed.
2462
2463  Standalone usage:
2464
2465  >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
2466  >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
2467  >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
2468  >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
2469  >>> #        = ((0. + 0.) +  (0.5 + 0.5)) / 2
2470  >>> m = tf.keras.metrics.CosineSimilarity(axis=1)
2471  >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]])
2472  >>> m.result().numpy()
2473  0.49999997
2474
2475  >>> m.reset_state()
2476  >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]],
2477  ...                sample_weight=[0.3, 0.7])
2478  >>> m.result().numpy()
2479  0.6999999
2480
2481  Usage with `compile()` API:
2482
2483  ```python
2484  model.compile(
2485      optimizer='sgd',
2486      loss='mse',
2487      metrics=[tf.keras.metrics.CosineSimilarity(axis=1)])
2488  ```
2489  """
2490
2491  def __init__(self, name='cosine_similarity', dtype=None, axis=-1):
2492    super(CosineSimilarity, self).__init__(
2493        cosine_similarity, name, dtype=dtype, axis=axis)
2494
2495
2496@keras_export('keras.metrics.MeanAbsoluteError')
2497class MeanAbsoluteError(MeanMetricWrapper):
2498  """Computes the mean absolute error between the labels and predictions.
2499
2500  Args:
2501    name: (Optional) string name of the metric instance.
2502    dtype: (Optional) data type of the metric result.
2503
2504  Standalone usage:
2505
2506  >>> m = tf.keras.metrics.MeanAbsoluteError()
2507  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2508  >>> m.result().numpy()
2509  0.25
2510
2511  >>> m.reset_state()
2512  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2513  ...                sample_weight=[1, 0])
2514  >>> m.result().numpy()
2515  0.5
2516
2517  Usage with `compile()` API:
2518
2519  ```python
2520  model.compile(
2521      optimizer='sgd',
2522      loss='mse',
2523      metrics=[tf.keras.metrics.MeanAbsoluteError()])
2524  ```
2525  """
2526
2527  def __init__(self, name='mean_absolute_error', dtype=None):
2528    super(MeanAbsoluteError, self).__init__(
2529        mean_absolute_error, name, dtype=dtype)
2530
2531
2532@keras_export('keras.metrics.MeanAbsolutePercentageError')
2533class MeanAbsolutePercentageError(MeanMetricWrapper):
2534  """Computes the mean absolute percentage error between `y_true` and `y_pred`.
2535
2536  Args:
2537    name: (Optional) string name of the metric instance.
2538    dtype: (Optional) data type of the metric result.
2539
2540  Standalone usage:
2541
2542  >>> m = tf.keras.metrics.MeanAbsolutePercentageError()
2543  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2544  >>> m.result().numpy()
2545  250000000.0
2546
2547  >>> m.reset_state()
2548  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2549  ...                sample_weight=[1, 0])
2550  >>> m.result().numpy()
2551  500000000.0
2552
2553  Usage with `compile()` API:
2554
2555  ```python
2556  model.compile(
2557      optimizer='sgd',
2558      loss='mse',
2559      metrics=[tf.keras.metrics.MeanAbsolutePercentageError()])
2560  ```
2561  """
2562
2563  def __init__(self, name='mean_absolute_percentage_error', dtype=None):
2564    super(MeanAbsolutePercentageError, self).__init__(
2565        mean_absolute_percentage_error, name, dtype=dtype)
2566
2567
2568@keras_export('keras.metrics.MeanSquaredError')
2569class MeanSquaredError(MeanMetricWrapper):
2570  """Computes the mean squared error between `y_true` and `y_pred`.
2571
2572  Args:
2573    name: (Optional) string name of the metric instance.
2574    dtype: (Optional) data type of the metric result.
2575
2576  Standalone usage:
2577
2578  >>> m = tf.keras.metrics.MeanSquaredError()
2579  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2580  >>> m.result().numpy()
2581  0.25
2582
2583  >>> m.reset_state()
2584  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2585  ...                sample_weight=[1, 0])
2586  >>> m.result().numpy()
2587  0.5
2588
2589  Usage with `compile()` API:
2590
2591  ```python
2592  model.compile(
2593      optimizer='sgd',
2594      loss='mse',
2595      metrics=[tf.keras.metrics.MeanSquaredError()])
2596  ```
2597  """
2598
2599  def __init__(self, name='mean_squared_error', dtype=None):
2600    super(MeanSquaredError, self).__init__(
2601        mean_squared_error, name, dtype=dtype)
2602
2603
2604@keras_export('keras.metrics.MeanSquaredLogarithmicError')
2605class MeanSquaredLogarithmicError(MeanMetricWrapper):
2606  """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
2607
2608  Args:
2609    name: (Optional) string name of the metric instance.
2610    dtype: (Optional) data type of the metric result.
2611
2612  Standalone usage:
2613
2614  >>> m = tf.keras.metrics.MeanSquaredLogarithmicError()
2615  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2616  >>> m.result().numpy()
2617  0.12011322
2618
2619  >>> m.reset_state()
2620  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2621  ...                sample_weight=[1, 0])
2622  >>> m.result().numpy()
2623  0.24022643
2624
2625  Usage with `compile()` API:
2626
2627  ```python
2628  model.compile(
2629      optimizer='sgd',
2630      loss='mse',
2631      metrics=[tf.keras.metrics.MeanSquaredLogarithmicError()])
2632  ```
2633  """
2634
2635  def __init__(self, name='mean_squared_logarithmic_error', dtype=None):
2636    super(MeanSquaredLogarithmicError, self).__init__(
2637        mean_squared_logarithmic_error, name, dtype=dtype)
2638
2639
2640@keras_export('keras.metrics.Hinge')
2641class Hinge(MeanMetricWrapper):
2642  """Computes the hinge metric between `y_true` and `y_pred`.
2643
2644  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
2645  provided, we will convert them to -1 or 1.
2646
2647  Args:
2648    name: (Optional) string name of the metric instance.
2649    dtype: (Optional) data type of the metric result.
2650
2651  Standalone usage:
2652
2653  >>> m = tf.keras.metrics.Hinge()
2654  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2655  >>> m.result().numpy()
2656  1.3
2657
2658  >>> m.reset_state()
2659  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2660  ...                sample_weight=[1, 0])
2661  >>> m.result().numpy()
2662  1.1
2663
2664  Usage with `compile()` API:
2665
2666  ```python
2667  model.compile(optimizer='sgd', loss='mse', metrics=[tf.keras.metrics.Hinge()])
2668  ```
2669  """
2670
2671  def __init__(self, name='hinge', dtype=None):
2672    super(Hinge, self).__init__(hinge, name, dtype=dtype)
2673
2674
2675@keras_export('keras.metrics.SquaredHinge')
2676class SquaredHinge(MeanMetricWrapper):
2677  """Computes the squared hinge metric between `y_true` and `y_pred`.
2678
2679  `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
2680  provided, we will convert them to -1 or 1.
2681
2682  Args:
2683    name: (Optional) string name of the metric instance.
2684    dtype: (Optional) data type of the metric result.
2685
2686  Standalone usage:
2687
2688  >>> m = tf.keras.metrics.SquaredHinge()
2689  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2690  >>> m.result().numpy()
2691  1.86
2692
2693  >>> m.reset_state()
2694  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2695  ...                sample_weight=[1, 0])
2696  >>> m.result().numpy()
2697  1.46
2698
2699  Usage with `compile()` API:
2700
2701  ```python
2702  model.compile(
2703      optimizer='sgd',
2704      loss='mse',
2705      metrics=[tf.keras.metrics.SquaredHinge()])
2706  ```
2707  """
2708
2709  def __init__(self, name='squared_hinge', dtype=None):
2710    super(SquaredHinge, self).__init__(squared_hinge, name, dtype=dtype)
2711
2712
2713@keras_export('keras.metrics.CategoricalHinge')
2714class CategoricalHinge(MeanMetricWrapper):
2715  """Computes the categorical hinge metric between `y_true` and `y_pred`.
2716
2717  Args:
2718    name: (Optional) string name of the metric instance.
2719    dtype: (Optional) data type of the metric result.
2720
2721  Standalone usage:
2722
2723  >>> m = tf.keras.metrics.CategoricalHinge()
2724  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2725  >>> m.result().numpy()
2726  1.4000001
2727
2728  >>> m.reset_state()
2729  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2730  ...                sample_weight=[1, 0])
2731  >>> m.result().numpy()
2732  1.2
2733
2734  Usage with `compile()` API:
2735
2736  ```python
2737  model.compile(
2738      optimizer='sgd',
2739      loss='mse',
2740      metrics=[tf.keras.metrics.CategoricalHinge()])
2741  ```
2742  """
2743
2744  def __init__(self, name='categorical_hinge', dtype=None):
2745    super(CategoricalHinge, self).__init__(categorical_hinge, name, dtype=dtype)
2746
2747
2748@keras_export('keras.metrics.RootMeanSquaredError')
2749class RootMeanSquaredError(Mean):
2750  """Computes root mean squared error metric between `y_true` and `y_pred`.
2751
2752  Standalone usage:
2753
2754  >>> m = tf.keras.metrics.RootMeanSquaredError()
2755  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2756  >>> m.result().numpy()
2757  0.5
2758
2759  >>> m.reset_state()
2760  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2761  ...                sample_weight=[1, 0])
2762  >>> m.result().numpy()
2763  0.70710677
2764
2765  Usage with `compile()` API:
2766
2767  ```python
2768  model.compile(
2769      optimizer='sgd',
2770      loss='mse',
2771      metrics=[tf.keras.metrics.RootMeanSquaredError()])
2772  ```
2773  """
2774
2775  def __init__(self, name='root_mean_squared_error', dtype=None):
2776    super(RootMeanSquaredError, self).__init__(name, dtype=dtype)
2777
2778  def update_state(self, y_true, y_pred, sample_weight=None):
2779    """Accumulates root mean squared error statistics.
2780
2781    Args:
2782      y_true: The ground truth values.
2783      y_pred: The predicted values.
2784      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2785        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2786        be broadcastable to `y_true`.
2787
2788    Returns:
2789      Update op.
2790    """
2791    y_true = math_ops.cast(y_true, self._dtype)
2792    y_pred = math_ops.cast(y_pred, self._dtype)
2793    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
2794        y_pred, y_true)
2795    error_sq = math_ops.squared_difference(y_pred, y_true)
2796    return super(RootMeanSquaredError, self).update_state(
2797        error_sq, sample_weight=sample_weight)
2798
2799  def result(self):
2800    return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count))
2801
2802
2803@keras_export('keras.metrics.LogCoshError')
2804class LogCoshError(MeanMetricWrapper):
2805  """Computes the logarithm of the hyperbolic cosine of the prediction error.
2806
2807  `logcosh = log((exp(x) + exp(-x))/2)`, where x is the error (y_pred - y_true)
2808
2809  Args:
2810    name: (Optional) string name of the metric instance.
2811    dtype: (Optional) data type of the metric result.
2812
2813  Standalone usage:
2814
2815  >>> m = tf.keras.metrics.LogCoshError()
2816  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2817  >>> m.result().numpy()
2818  0.10844523
2819
2820  >>> m.reset_state()
2821  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2822  ...                sample_weight=[1, 0])
2823  >>> m.result().numpy()
2824  0.21689045
2825
2826  Usage with `compile()` API:
2827
2828  ```python
2829  model.compile(optimizer='sgd',
2830                loss='mse',
2831                metrics=[tf.keras.metrics.LogCoshError()])
2832  ```
2833  """
2834
2835  def __init__(self, name='logcosh', dtype=None):
2836    super(LogCoshError, self).__init__(logcosh, name, dtype=dtype)
2837
2838
2839@keras_export('keras.metrics.Poisson')
2840class Poisson(MeanMetricWrapper):
2841  """Computes the Poisson metric between `y_true` and `y_pred`.
2842
2843  `metric = y_pred - y_true * log(y_pred)`
2844
2845  Args:
2846    name: (Optional) string name of the metric instance.
2847    dtype: (Optional) data type of the metric result.
2848
2849  Standalone usage:
2850
2851  >>> m = tf.keras.metrics.Poisson()
2852  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])
2853  >>> m.result().numpy()
2854  0.49999997
2855
2856  >>> m.reset_state()
2857  >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],
2858  ...                sample_weight=[1, 0])
2859  >>> m.result().numpy()
2860  0.99999994
2861
2862  Usage with `compile()` API:
2863
2864  ```python
2865  model.compile(optimizer='sgd',
2866                loss='mse',
2867                metrics=[tf.keras.metrics.Poisson()])
2868  ```
2869  """
2870
2871  def __init__(self, name='poisson', dtype=None):
2872    super(Poisson, self).__init__(poisson, name, dtype=dtype)
2873
2874
2875@keras_export('keras.metrics.KLDivergence')
2876class KLDivergence(MeanMetricWrapper):
2877  """Computes Kullback-Leibler divergence metric between `y_true` and `y_pred`.
2878
2879  `metric = y_true * log(y_true / y_pred)`
2880
2881  Args:
2882    name: (Optional) string name of the metric instance.
2883    dtype: (Optional) data type of the metric result.
2884
2885  Standalone usage:
2886
2887  >>> m = tf.keras.metrics.KLDivergence()
2888  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
2889  >>> m.result().numpy()
2890  0.45814306
2891
2892  >>> m.reset_state()
2893  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
2894  ...                sample_weight=[1, 0])
2895  >>> m.result().numpy()
2896  0.9162892
2897
2898  Usage with `compile()` API:
2899
2900  ```python
2901  model.compile(optimizer='sgd',
2902                loss='mse',
2903                metrics=[tf.keras.metrics.KLDivergence()])
2904  ```
2905  """
2906
2907  def __init__(self, name='kullback_leibler_divergence', dtype=None):
2908    super(KLDivergence, self).__init__(
2909        kullback_leibler_divergence, name, dtype=dtype)
2910
2911
2912@keras_export('keras.metrics.MeanIoU')
2913class MeanIoU(Metric):
2914  """Computes the mean Intersection-Over-Union metric.
2915
2916  Mean Intersection-Over-Union is a common evaluation metric for semantic image
2917  segmentation, which first computes the IOU for each semantic class and then
2918  computes the average over classes. IOU is defined as follows:
2919    IOU = true_positive / (true_positive + false_positive + false_negative).
2920  The predictions are accumulated in a confusion matrix, weighted by
2921  `sample_weight`, and the metric is then calculated from it.
2922
2923  If `sample_weight` is `None`, weights default to 1.
2924  Use `sample_weight` of 0 to mask values.
2925
2926  Args:
2927    num_classes: The possible number of labels the prediction task can have.
2928      This value must be provided, since a confusion matrix of dimension =
2929      [num_classes, num_classes] will be allocated.
2930    name: (Optional) string name of the metric instance.
2931    dtype: (Optional) data type of the metric result.
2932
2933  Standalone usage:
2934
2935  >>> # cm = [[1, 1],
2936  >>> #        [1, 1]]
2937  >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]
2938  >>> # iou = true_positives / (sum_row + sum_col - true_positives)
2939  >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33
2940  >>> m = tf.keras.metrics.MeanIoU(num_classes=2)
2941  >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])
2942  >>> m.result().numpy()
2943  0.33333334
2944
2945  >>> m.reset_state()
2946  >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],
2947  ...                sample_weight=[0.3, 0.3, 0.3, 0.1])
2948  >>> m.result().numpy()
2949  0.23809525
2950
2951  Usage with `compile()` API:
2952
2953  ```python
2954  model.compile(
2955    optimizer='sgd',
2956    loss='mse',
2957    metrics=[tf.keras.metrics.MeanIoU(num_classes=2)])
2958  ```
2959  """
2960
2961  def __init__(self, num_classes, name=None, dtype=None):
2962    super(MeanIoU, self).__init__(name=name, dtype=dtype)
2963    self.num_classes = num_classes
2964
2965    # Variable to accumulate the predictions in the confusion matrix.
2966    self.total_cm = self.add_weight(
2967        'total_confusion_matrix',
2968        shape=(num_classes, num_classes),
2969        initializer=init_ops.zeros_initializer)
2970
2971  def update_state(self, y_true, y_pred, sample_weight=None):
2972    """Accumulates the confusion matrix statistics.
2973
2974    Args:
2975      y_true: The ground truth values.
2976      y_pred: The predicted values.
2977      sample_weight: Optional weighting of each example. Defaults to 1. Can be a
2978        `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
2979        be broadcastable to `y_true`.
2980
2981    Returns:
2982      Update op.
2983    """
2984
2985    y_true = math_ops.cast(y_true, self._dtype)
2986    y_pred = math_ops.cast(y_pred, self._dtype)
2987
2988    # Flatten the input if its rank > 1.
2989    if y_pred.shape.ndims > 1:
2990      y_pred = array_ops.reshape(y_pred, [-1])
2991
2992    if y_true.shape.ndims > 1:
2993      y_true = array_ops.reshape(y_true, [-1])
2994
2995    if sample_weight is not None:
2996      sample_weight = math_ops.cast(sample_weight, self._dtype)
2997      if sample_weight.shape.ndims > 1:
2998        sample_weight = array_ops.reshape(sample_weight, [-1])
2999
3000    # Accumulate the prediction to current confusion matrix.
3001    current_cm = confusion_matrix.confusion_matrix(
3002        y_true,
3003        y_pred,
3004        self.num_classes,
3005        weights=sample_weight,
3006        dtype=self._dtype)
3007    return self.total_cm.assign_add(current_cm)
3008
3009  def result(self):
3010    """Compute the mean intersection-over-union via the confusion matrix."""
3011    sum_over_row = math_ops.cast(
3012        math_ops.reduce_sum(self.total_cm, axis=0), dtype=self._dtype)
3013    sum_over_col = math_ops.cast(
3014        math_ops.reduce_sum(self.total_cm, axis=1), dtype=self._dtype)
3015    true_positives = math_ops.cast(
3016        array_ops.tensor_diag_part(self.total_cm), dtype=self._dtype)
3017
3018    # sum_over_row + sum_over_col =
3019    #     2 * true_positives + false_positives + false_negatives.
3020    denominator = sum_over_row + sum_over_col - true_positives
3021
3022    # The mean is only computed over classes that appear in the
3023    # label or prediction tensor. If the denominator is 0, we need to
3024    # ignore the class.
3025    num_valid_entries = math_ops.reduce_sum(
3026        math_ops.cast(math_ops.not_equal(denominator, 0), dtype=self._dtype))
3027
3028    iou = math_ops.div_no_nan(true_positives, denominator)
3029
3030    return math_ops.div_no_nan(
3031        math_ops.reduce_sum(iou, name='mean_iou'), num_valid_entries)
3032
3033  def reset_state(self):
3034    backend.set_value(
3035        self.total_cm, np.zeros((self.num_classes, self.num_classes)))
3036
3037  def get_config(self):
3038    config = {'num_classes': self.num_classes}
3039    base_config = super(MeanIoU, self).get_config()
3040    return dict(list(base_config.items()) + list(config.items()))
3041
3042
3043@keras_export('keras.metrics.MeanTensor')
3044class MeanTensor(Metric):
3045  """Computes the element-wise (weighted) mean of the given tensors.
3046
3047  `MeanTensor` returns a tensor with the same shape as the input tensors. The
3048  mean value is updated by keeping local variables `total` and `count`. The
3049  `total` tracks the sum of the weighted values, and `count` stores the sum of
3050  the weighted counts.
3051
3052  Args:
3053    name: (Optional) string name of the metric instance.
3054    dtype: (Optional) data type of the metric result.
3055    shape: (Optional) A list of integers, a tuple of integers, or a 1-D Tensor
3056      of type int32. If not specified, the shape is inferred from the values at
3057      the first call of update_state.
3058
3059  Standalone usage:
3060
3061  >>> m = tf.keras.metrics.MeanTensor()
3062  >>> m.update_state([0, 1, 2, 3])
3063  >>> m.update_state([4, 5, 6, 7])
3064  >>> m.result().numpy()
3065  array([2., 3., 4., 5.], dtype=float32)
3066
3067  >>> m.update_state([12, 10, 8, 6], sample_weight=[0, 0.2, 0.5, 1])
3068  >>> m.result().numpy()
3069  array([2.       , 3.6363635, 4.8      , 5.3333335], dtype=float32)
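  >>> # Illustrative note (not in the original docstring): each element keeps a
  >>> # weighted running total and count; for the second element above,
  >>> # total = 1 + 5 + 10 * 0.2 = 8 and count = 1 + 1 + 0.2 = 2.2, so the
  >>> # result is 8 / 2.2 ~= 3.6363635.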
3070
3071  >>> m = tf.keras.metrics.MeanTensor(dtype=tf.float64, shape=(1, 4))
3072  >>> m.result().numpy()
3073  array([[0., 0., 0., 0.]])
3074  >>> m.update_state([[0, 1, 2, 3]])
3075  >>> m.update_state([[4, 5, 6, 7]])
3076  >>> m.result().numpy()
3077  array([[2., 3., 4., 5.]])
3078  """
3079
3080  def __init__(self, name='mean_tensor', dtype=None, shape=None):
3081    super(MeanTensor, self).__init__(name=name, dtype=dtype)
3082    self._shape = None
3083    self._total = None
3084    self._count = None
3085    self._built = False
3086    if shape is not None:
3087      self._build(shape)
3088
3089  def _build(self, shape):
3090    self._shape = tensor_shape.TensorShape(shape)
3091    self._build_input_shape = self._shape
3092    # Create new state variables
3093    self._total = self.add_weight(
3094        'total', shape=shape, initializer=init_ops.zeros_initializer)
3095    self._count = self.add_weight(
3096        'count', shape=shape, initializer=init_ops.zeros_initializer)
3097    with ops.init_scope():
3098      if not context.executing_eagerly():
3099        backend._initialize_variables(backend._get_session())  # pylint: disable=protected-access
3100    self._built = True
3101
3102  @property
3103  def total(self):
3104    return self._total if self._built else None
3105
3106  @property
3107  def count(self):
3108    return self._count if self._built else None
3109
3110  def update_state(self, values, sample_weight=None):
3111    """Accumulates statistics for computing the element-wise mean.
3112
3113    Args:
3114      values: Per-example value.
3115      sample_weight: Optional weighting of each example. Defaults to 1.
3116
3117    Returns:
3118      Update op.
3119    """
3120    values = math_ops.cast(values, self._dtype)
3121    if not self._built:
3122      self._build(values.shape)
3123    elif values.shape != self._shape:
3124      raise ValueError('MeanTensor input values must always have the same '
3125                       'shape. Expected shape (set during the first call): {}. '
3126                       'Got: {}'.format(self._shape, values.shape))
3127
3128    num_values = array_ops.ones_like(values)
3129    if sample_weight is not None:
3130      sample_weight = math_ops.cast(sample_weight, self._dtype)
3131
3132      # Update dimensions of weights to match with values if possible.
3133      values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
3134          values, sample_weight=sample_weight)
3135      try:
3136        # Broadcast weights if possible.
3137        sample_weight = weights_broadcast_ops.broadcast_weights(
3138            sample_weight, values)
3139      except ValueError:
3140        # Reduce values to same ndim as weight array
3141        ndim = backend.ndim(values)
3142        weight_ndim = backend.ndim(sample_weight)
3143        values = math_ops.reduce_mean(
3144            values, axis=list(range(weight_ndim, ndim)))
3145
3146      num_values = math_ops.multiply(num_values, sample_weight)
3147      values = math_ops.multiply(values, sample_weight)
3148
3149    update_total_op = self._total.assign_add(values)
3150    with ops.control_dependencies([update_total_op]):
3151      return self._count.assign_add(num_values)
3152
3153  def result(self):
3154    if not self._built:
3155      raise ValueError(
3156          'MeanTensor does not have any result yet. Please call the MeanTensor '
3157          'instance or use `.update_state(value)` before retrieving the result.'
3158          )
3159    return math_ops.div_no_nan(self.total, self.count)
3160
3161  def reset_state(self):
3162    if self._built:
3163      backend.batch_set_value(
3164          [(v, np.zeros(self._shape.as_list())) for v in self.variables])
3165
3166
3167@keras_export('keras.metrics.BinaryCrossentropy')
3168class BinaryCrossentropy(MeanMetricWrapper):
3169  """Computes the crossentropy metric between the labels and predictions.
3170
3171  This is the crossentropy metric class to be used when there are only two
3172  label classes (0 and 1).
3173
3174  Args:
3175    name: (Optional) string name of the metric instance.
3176    dtype: (Optional) data type of the metric result.
3177    from_logits: (Optional) Whether output is expected to be a logits tensor.
3178      By default, we consider that output encodes a probability distribution.
3179    label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
3180      smoothed, meaning the confidence on label values are relaxed.
3181      e.g. `label_smoothing=0.2` means that we will use a value of `0.1` for
3182      label `0` and `0.9` for label `1`".
3183
3184  Standalone usage:
3185
3186  >>> m = tf.keras.metrics.BinaryCrossentropy()
3187  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])
3188  >>> m.result().numpy()
3189  0.81492424
3190
3191  >>> m.reset_state()
3192  >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],
3193  ...                sample_weight=[1, 0])
3194  >>> m.result().numpy()
3195  0.9162905
3196
3197  Usage with `compile()` API:
3198
3199  ```python
3200  model.compile(
3201      optimizer='sgd',
3202      loss='mse',
3203      metrics=[tf.keras.metrics.BinaryCrossentropy()])
3204  ```
3205  """
3206
3207  def __init__(self,
3208               name='binary_crossentropy',
3209               dtype=None,
3210               from_logits=False,
3211               label_smoothing=0):
3212    super(BinaryCrossentropy, self).__init__(
3213        binary_crossentropy,
3214        name,
3215        dtype=dtype,
3216        from_logits=from_logits,
3217        label_smoothing=label_smoothing)
3218
3219
3220@keras_export('keras.metrics.CategoricalCrossentropy')
3221class CategoricalCrossentropy(MeanMetricWrapper):
3222  """Computes the crossentropy metric between the labels and predictions.
3223
3224  This is the crossentropy metric class to be used when there are multiple
3225  label classes (2 or more). Here we assume that labels are given as a `one_hot`
3226  representation. For example, when the label values are [2, 0, 1],
3227  `y_true` = [[0, 0, 1], [1, 0, 0], [0, 1, 0]].
3228
3229  Args:
3230    name: (Optional) string name of the metric instance.
3231    dtype: (Optional) data type of the metric result.
3232    from_logits: (Optional) Whether output is expected to be a logits tensor.
3233      By default, we consider that output encodes a probability distribution.
3234    label_smoothing: (Optional) Float in [0, 1]. When > 0, label values are
3235      smoothed, meaning the confidence on label values is relaxed. e.g.
3236      `label_smoothing=0.2` means that we will use a value of `0.1` for label
3237      `0` and `0.9` for label `1`.
3238
3239  Standalone usage:
3240
3241  >>> # EPSILON = 1e-7, y = y_true, y` = y_pred
3242  >>> # y` = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)
3243  >>> # y` = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
3244  >>> # xent = -sum(y * log(y'), axis = -1)
3245  >>> #      = -((log 0.95), (log 0.1))
3246  >>> #      = [0.051, 2.302]
3247  >>> # Reduced xent = (0.051 + 2.302) / 2
3248  >>> m = tf.keras.metrics.CategoricalCrossentropy()
3249  >>> m.update_state([[0, 1, 0], [0, 0, 1]],
3250  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
3251  >>> m.result().numpy()
3252  1.1769392
3253
3254  >>> m.reset_state()
3255  >>> m.update_state([[0, 1, 0], [0, 0, 1]],
3256  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
3257  ...                sample_weight=tf.constant([0.3, 0.7]))
3258  >>> m.result().numpy()
3259  1.6271976
3260
3261  Usage with `compile()` API:
3262
3263  ```python
3264  model.compile(
3265    optimizer='sgd',
3266    loss='mse',
3267    metrics=[tf.keras.metrics.CategoricalCrossentropy()])
3268  ```
3269  """
3270
3271  def __init__(self,
3272               name='categorical_crossentropy',
3273               dtype=None,
3274               from_logits=False,
3275               label_smoothing=0):
3276    super(CategoricalCrossentropy, self).__init__(
3277        categorical_crossentropy,
3278        name,
3279        dtype=dtype,
3280        from_logits=from_logits,
3281        label_smoothing=label_smoothing)
3282
3283
3284@keras_export('keras.metrics.SparseCategoricalCrossentropy')
3285class SparseCategoricalCrossentropy(MeanMetricWrapper):
3286  """Computes the crossentropy metric between the labels and predictions.
3287
3288  Use this crossentropy metric when there are two or more label classes.
3289  We expect labels to be provided as integers. If you want to provide labels
3290  using `one-hot` representation, please use `CategoricalCrossentropy` metric.
3291  There should be `# classes` floating point values per feature for `y_pred`
3292  and a single floating point value per feature for `y_true`.
3293
3294  In the snippet below, there is a single floating point value per example for
3295  `y_true` and `# classes` floating point values per example for `y_pred`.
3296  The shape of `y_true` is `[batch_size]` and the shape of `y_pred` is
3297  `[batch_size, num_classes]`.
3298
3299  Args:
3300    name: (Optional) string name of the metric instance.
3301    dtype: (Optional) data type of the metric result.
3302    from_logits: (Optional) Whether output is expected to be a logits tensor.
3303      By default, we consider that output encodes a probability distribution.
3304    axis: (Optional) Defaults to -1. The dimension along which the metric is
3305      computed.
3306
3307  Standalone usage:
3308
3309  >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]
3310  >>> # logits = log(y_pred)
3311  >>> # softmax = exp(logits) / sum(exp(logits), axis=-1)
3312  >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]
3313  >>> # xent = -sum(y * log(softmax), 1)
3314  >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181],
3315  >>> #                [-2.3026, -0.2231, -2.3026]]
3316  >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]
3317  >>> # xent = [0.0513, 2.3026]
3318  >>> # Reduced xent = (0.0513 + 2.3026) / 2
3319  >>> m = tf.keras.metrics.SparseCategoricalCrossentropy()
3320  >>> m.update_state([1, 2],
3321  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])
3322  >>> m.result().numpy()
3323  1.1769392
3324
3325  >>> m.reset_state()
3326  >>> m.update_state([1, 2],
3327  ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],
3328  ...                sample_weight=tf.constant([0.3, 0.7]))
3329  >>> m.result().numpy()
3330  1.6271976
3331
3332  Usage with `compile()` API:
3333
3334  ```python
3335  model.compile(
3336    optimizer='sgd',
3337    loss='mse',
3338    metrics=[tf.keras.metrics.SparseCategoricalCrossentropy()])
3339  ```
3340  """
3341
3342  def __init__(self,
3343               name='sparse_categorical_crossentropy',
3344               dtype=None,
3345               from_logits=False,
3346               axis=-1):
3347    super(SparseCategoricalCrossentropy, self).__init__(
3348        sparse_categorical_crossentropy,
3349        name,
3350        dtype=dtype,
3351        from_logits=from_logits,
3352        axis=axis)
3353
3354
3355class SumOverBatchSize(Reduce):
3356  """Computes the weighted sum over batch size of the given values.
3357
3358  For example, if the values are [1, 3, 5, 7] then the metric value is 4.
3359  If the weights were specified as [1, 1, 0, 0] then the value would be 1.
3360
3361  This metric creates two variables, `total` and `count`, that are used to
3362  compute the average of `values`. This average is ultimately returned as the
3363  sum over batch size, which is an idempotent operation that simply divides
3364  `total` by `count`.
3365
3366  If `sample_weight` is `None`, weights default to 1.  Use `sample_weight` of 0
3367  to mask values.
3368  """
3369
3370  def __init__(self, name='sum_over_batch_size', dtype=None):
3371    super(SumOverBatchSize, self).__init__(
3372        reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
3373        name=name,
3374        dtype=dtype)
3375
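# Illustrative sketch (not part of the original file): `SumOverBatchSize`
# divides the weighted running total by the number of values seen, matching the
# worked example in the docstring above.
#
#   m = SumOverBatchSize()
#   m.update_state([1., 3., 5., 7.])
#   m.result()  # (1 + 3 + 5 + 7) / 4 = 4.0
#   m.reset_state()
#   m.update_state([1., 3., 5., 7.], sample_weight=[1., 1., 0., 0.])
#   m.result()  # (1 + 3) / 4 = 1.0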
3376
3377class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
3378  """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric."""
3379
3380  def __init__(self, fn, name=None, dtype=None, **kwargs):
3381    """Creates a `SumOverBatchSizeMetricWrapper` instance.
3382
3383    Args:
3384      fn: The metric function to wrap, with signature `fn(y_true, y_pred,
3385        **kwargs)`.
3386      name: (Optional) string name of the metric instance.
3387      dtype: (Optional) data type of the metric result.
3388      **kwargs: The keyword arguments that are passed on to `fn`.
3389    """
3390    super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype)
3391    self._fn = fn
3392    self._fn_kwargs = kwargs
3393
3394  def update_state(self, y_true, y_pred, sample_weight=None):
3395    y_true = math_ops.cast(y_true, self._dtype)
3396    y_pred = math_ops.cast(y_pred, self._dtype)
3397    y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
3398        y_pred, y_true)
3399
3400    ag_fn = autograph.tf_convert(self._fn, ag_ctx.control_status_ctx())
3401    matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
3402    return super(SumOverBatchSizeMetricWrapper, self).update_state(
3403        matches, sample_weight=sample_weight)
3404
3405  def get_config(self):
3406    config = {}
3407    for k, v in self._fn_kwargs.items():
3408      config[k] = backend.eval(v) if is_tensor_or_variable(v) else v
3409    base_config = super(SumOverBatchSizeMetricWrapper, self).get_config()
3410    return dict(list(base_config.items()) + list(config.items()))
3411
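# Illustrative sketch (not part of the original file): wrapping a simple
# elementwise function so its per-example values are averaged over the batch
# size. `abs_error` below is a hypothetical helper, not defined in this module.
#
#   def abs_error(y_true, y_pred):
#     return math_ops.abs(y_pred - y_true)
#
#   m = SumOverBatchSizeMetricWrapper(abs_error, name='abs_error')
#   m.update_state([1., 2.], [1.5, 2.5])
#   m.result()  # (0.5 + 0.5) / 2 = 0.5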
3412
3413def accuracy(y_true, y_pred):
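  """Computes elementwise equality of `y_true` and `y_pred` as 0./1. floats.

  Ragged inputs are flattened to compatible flat values, shape compatibility is
  asserted, `y_pred` is cast to the dtype of `y_true`, and the elementwise
  equality is returned in `backend.floatx()`.
  """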
3414  [y_pred, y_true], _ = \
3415      metrics_utils.ragged_assert_compatible_and_get_flat_values(
3416          [y_pred, y_true])
3417  y_true.shape.assert_is_compatible_with(y_pred.shape)
3418  if y_true.dtype != y_pred.dtype:
3419    y_pred = math_ops.cast(y_pred, y_true.dtype)
3420  return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx())
3421
3422
3423@keras_export('keras.metrics.binary_accuracy')
3424@dispatch.add_dispatch_support
3425def binary_accuracy(y_true, y_pred, threshold=0.5):
3426  """Calculates how often predictions match binary labels.
3427
3428  Standalone usage:
3429  >>> y_true = [[1], [1], [0], [0]]
3430  >>> y_pred = [[1], [1], [0], [0]]
3431  >>> m = tf.keras.metrics.binary_accuracy(y_true, y_pred)
3432  >>> assert m.shape == (4,)
3433  >>> m.numpy()
3434  array([1., 1., 1., 1.], dtype=float32)
3435
3436  Args:
3437    y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
3438    y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
3439    threshold: (Optional) Float representing the threshold for deciding whether
3440      prediction values are 1 or 0.
3441
3442  Returns:
3443    Binary accuracy values. shape = `[batch_size, d0, .. dN-1]`
3444  """
3445  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
3446  threshold = math_ops.cast(threshold, y_pred.dtype)
3447  y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype)
3448  return backend.mean(math_ops.equal(y_true, y_pred), axis=-1)
3449
3450
3451@keras_export('keras.metrics.categorical_accuracy')
3452@dispatch.add_dispatch_support
3453def categorical_accuracy(y_true, y_pred):
3454  """Calculates how often predictions match one-hot labels.
3455
3456  Standalone usage:
3457  >>> y_true = [[0, 0, 1], [0, 1, 0]]
3458  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3459  >>> m = tf.keras.metrics.categorical_accuracy(y_true, y_pred)
3460  >>> assert m.shape == (2,)
3461  >>> m.numpy()
3462  array([0., 1.], dtype=float32)
3463
3464  You can provide logits of classes as `y_pred`, since the argmax of
3465  logits and of probabilities is the same.
3466
3467  Args:
3468    y_true: One-hot ground truth values.
3469    y_pred: The prediction values.
3470
3471  Returns:
3472    Categorical accuracy values.
3473  """
3474  return math_ops.cast(
3475      math_ops.equal(
3476          math_ops.argmax(y_true, axis=-1), math_ops.argmax(y_pred, axis=-1)),
3477      backend.floatx())
3478
3479
3480@keras_export('keras.metrics.sparse_categorical_accuracy')
3481@dispatch.add_dispatch_support
3482def sparse_categorical_accuracy(y_true, y_pred):
3483  """Calculates how often predictions match integer labels.
3484
3485  Standalone usage:
3486  >>> y_true = [2, 1]
3487  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3488  >>> m = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
3489  >>> assert m.shape == (2,)
3490  >>> m.numpy()
3491  array([0., 1.], dtype=float32)
3492
3493  You can provide logits of classes as `y_pred`, since the argmax of
3494  logits and of probabilities is the same.
3495
3496  Args:
3497    y_true: Integer ground truth values.
3498    y_pred: The prediction values.
3499
3500  Returns:
3501    Sparse categorical accuracy values.
3502  """
3503  y_pred = ops.convert_to_tensor_v2_with_dispatch(y_pred)
3504  y_true = ops.convert_to_tensor_v2_with_dispatch(y_true)
3505  y_pred_rank = y_pred.shape.ndims
3506  y_true_rank = y_true.shape.ndims
3507  # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
3508  if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
3509      backend.int_shape(y_true)) == len(backend.int_shape(y_pred))):
3510    y_true = array_ops.squeeze(y_true, [-1])
3511  y_pred = math_ops.argmax(y_pred, axis=-1)
3512
3513  # If the predicted output and actual output types don't match, force cast them
3514  # to match.
3515  if backend.dtype(y_pred) != backend.dtype(y_true):
3516    y_pred = math_ops.cast(y_pred, backend.dtype(y_true))
3517
3518  return math_ops.cast(math_ops.equal(y_true, y_pred), backend.floatx())
3519
3520
3521@keras_export('keras.metrics.top_k_categorical_accuracy')
3522@dispatch.add_dispatch_support
3523def top_k_categorical_accuracy(y_true, y_pred, k=5):
3524  """Computes how often targets are in the top `K` predictions.
3525
3526  Standalone usage:
3527  >>> y_true = [[0, 0, 1], [0, 1, 0]]
3528  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3529  >>> m = tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)
3530  >>> assert m.shape == (2,)
3531  >>> m.numpy()
3532  array([1., 1.], dtype=float32)
3533
3534  Args:
3535    y_true: The ground truth values.
3536    y_pred: The prediction values.
3537    k: (Optional) Number of top elements to look at for computing accuracy.
3538      Defaults to 5.
3539
3540  Returns:
3541    Top K categorical accuracy value.
3542  """
3543  return math_ops.cast(
3544      nn.in_top_k(
3545          y_pred, math_ops.argmax(y_true, axis=-1), k), backend.floatx())
3546
3547
3548@keras_export('keras.metrics.sparse_top_k_categorical_accuracy')
3549@dispatch.add_dispatch_support
3550def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
3551  """Computes how often integer targets are in the top `K` predictions.
3552
3553  Standalone usage:
3554  >>> y_true = [2, 1]
3555  >>> y_pred = [[0.1, 0.9, 0.8], [0.05, 0.95, 0]]
3556  >>> m = tf.keras.metrics.sparse_top_k_categorical_accuracy(
3557  ...     y_true, y_pred, k=3)
3558  >>> assert m.shape == (2,)
3559  >>> m.numpy()
3560  array([1., 1.], dtype=float32)
3561
3562  Args:
3563    y_true: tensor of true targets.
3564    y_pred: tensor of predicted targets.
3565    k: (Optional) Number of top elements to look at for computing accuracy.
3566      Defaults to 5.
3567
3568  Returns:
3569    Sparse top K categorical accuracy value.
3570  """
3571  y_pred_rank = ops.convert_to_tensor_v2_with_dispatch(y_pred).shape.ndims
3572  y_true_rank = ops.convert_to_tensor_v2_with_dispatch(y_true).shape.ndims
3573  # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)
3574  if (y_true_rank is not None) and (y_pred_rank is not None):
3575    if y_pred_rank > 2:
3576      y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]])
3577    if y_true_rank > 1:
3578      y_true = array_ops.reshape(y_true, [-1])
3579
3580  return math_ops.cast(
3581      nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), backend.floatx())
3582
3583
3584def cosine_proximity(y_true, y_pred, axis=-1):
3585  """Computes the cosine similarity between labels and predictions.
3586
3587  Args:
3588    y_true: The ground truth values.
3589    y_pred: The prediction values.
3590    axis: (Optional) Defaults to -1. The dimension along which the cosine
3591      similarity is computed.
3592
3593  Returns:
3594    Cosine similarity value.
3595  """
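  # Illustrative note (not in the original source): after L2 normalization the
  # reduced product is the cosine of the angle between each pair of vectors,
  # e.g. identical directions give 1.0 and orthogonal directions give 0.0.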
3596  y_true = nn.l2_normalize(y_true, axis=axis)
3597  y_pred = nn.l2_normalize(y_pred, axis=axis)
3598  return math_ops.reduce_sum(y_true * y_pred, axis=axis)
3599
3600# Aliases
3601
3602acc = ACC = accuracy
3603bce = BCE = binary_crossentropy
3604mse = MSE = mean_squared_error
3605mae = MAE = mean_absolute_error
3606mape = MAPE = mean_absolute_percentage_error
3607msle = MSLE = mean_squared_logarithmic_error
3608cosine_similarity = cosine_proximity
3609log_cosh = logcosh
3610
3611
3612def clone_metric(metric):
3613  """Returns a clone of the metric if stateful, otherwise returns it as is."""
3614  if isinstance(metric, Metric):
3615    with ops.init_scope():
3616      return metric.__class__.from_config(metric.get_config())
3617  return metric
3618
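# Illustrative note (not part of the original file): stateful `Metric`
# instances are rebuilt from their config so the clone starts with fresh state,
# while plain metric functions are returned unchanged, e.g.
#
#   clone_metric(Mean())     # a new `Mean` instance with zeroed state
#   clone_metric(accuracy)   # the `accuracy` function itself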
3619
3620def clone_metrics(metrics):
3621  """Clones the given metric list/dict."""
3622  return nest.map_structure(clone_metric, metrics)
3623
3624
3625@keras_export('keras.metrics.serialize')
3626def serialize(metric):
3627  """Serializes metric function or `Metric` instance.
3628
3629  Args:
3630    metric: A Keras `Metric` instance or a metric function.
3631
3632  Returns:
3633    Metric configuration dictionary.
3634  """
3635  return serialize_keras_object(metric)
3636
3637
3638@keras_export('keras.metrics.deserialize')
3639def deserialize(config, custom_objects=None):
3640  """Deserializes a serialized metric class/function instance.
3641
3642  Args:
3643    config: Metric configuration.
3644    custom_objects: Optional dictionary mapping names (strings) to custom
3645      objects (classes and functions) to be considered during deserialization.
3646
3647  Returns:
3648      A Keras `Metric` instance or a metric function.
3649  """
3650  return deserialize_keras_object(
3651      config,
3652      module_objects=globals(),
3653      custom_objects=custom_objects,
3654      printable_module_name='metric function')
3655
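# Illustrative sketch (not part of the original file): `serialize` and
# `deserialize` round-trip a metric through its configuration dictionary.
#
#   config = serialize(RootMeanSquaredError())
#   # config is a dict of the form
#   # {'class_name': 'RootMeanSquaredError', 'config': {...}}
#   metric = deserialize(config)  # a fresh RootMeanSquaredError instance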
3656
3657@keras_export('keras.metrics.get')
3658def get(identifier):
3659  """Retrieves a Keras metric as a `function`/`Metric` class instance.
3660
3661  The `identifier` may be the string name of a metric function or class.
3662
3663  >>> metric = tf.keras.metrics.get("categorical_crossentropy")
3664  >>> type(metric)
3665  <class 'function'>
3666  >>> metric = tf.keras.metrics.get("CategoricalCrossentropy")
3667  >>> type(metric)
3668  <class '...keras.metrics.CategoricalCrossentropy'>
3669
3670  You can also specify the `config` of the metric to this function by passing a
3671  dict containing `class_name` and `config` as an identifier. Also note that the
3672  `class_name` must map to a `Metric` class.
3673
3674  >>> identifier = {"class_name": "CategoricalCrossentropy",
3675  ...               "config": {"from_logits": True}}
3676  >>> metric = tf.keras.metrics.get(identifier)
3677  >>> type(metric)
3678  <class '...keras.metrics.CategoricalCrossentropy'>
3679
3680  Args:
3681    identifier: A metric identifier. One of `None`, a string name of a metric
3682      function/class, a metric configuration dictionary, a metric function, or
3683      a metric class instance.
3684
3685  Returns:
3686    A Keras metric as a `function`/ `Metric` class instance.
3687
3688  Raises:
3689    ValueError: If `identifier` cannot be interpreted.
3690  """
3691  if isinstance(identifier, dict):
3692    return deserialize(identifier)
3693  elif isinstance(identifier, str):
3694    return deserialize(str(identifier))
3695  elif callable(identifier):
3696    return identifier
3697  else:
3698    raise ValueError(
3699        'Could not interpret metric function identifier: {}'.format(identifier))
3700
3701
3702def is_built_in(cls):
3703  return cls.__module__ == Metric.__module__
3704