# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.

  If running in a `DistributionStrategy` context, the variable will be
  "sync on read". This means:

  *   The returned object will be a container with separate variables
      per replica of the model.

  *   When writing to the variable, e.g. using `assign_add` in a metric
      update, the update will be applied to the variable local to the
      replica.

  *   To get a metric's result value, we need to sum the variable values
      across the replicas before computing the final answer. Furthermore,
      the final answer should be computed once instead of in every
      replica. Both of these are accomplished by running the computation
      of the final result value inside
      `distribution_strategy_context.get_replica_context().merge_call(fn)`.
      Inside the `merge_call()`, ops are only added to the graph once
      and access to a sync on read variable in a computation returns
      the sum across all replicas.

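  As an illustration only, the read pattern used by the metrics in this
  module looks roughly like the sketch below (with hypothetical `total`
  and `count` metric variables; this is not a public API):

  ```python
  def result_fn(total, count):
    def compute(_, t, c):  # Runs once; reads are summed across replicas.
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0))
    replica_context = distribution_strategy_context.get_replica_context()
    return replica_context.merge_call(compute, args=(total, count))
  ```
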
  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a sync on read variable container.
  """
  # Note that synchronization "ON_READ" implies trainable=False.
  return variable_scope.variable(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      synchronization=variable_scope.VariableSynchronization.ON_READ,
      aggregation=variable_scope.VariableAggregation.SUM,
      name=name)


def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

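  For example (static shapes, values illustrative):

  *   `predictions` [2, 3] and `labels` [2, 3, 1]: `labels` is squeezed to
      [2, 3].
  *   `predictions` [2, 3] and `weights` [2]: `weights` is expanded to [2, 1].
  *   Scalar `weights` are returned unchanged.
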
  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    return predictions, labels, weights

  predictions_shape = predictions.get_shape()
  predictions_rank = predictions_shape.ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif predictions_rank - weights_rank == 1:
      weights = array_ops.expand_dims(weights, [-1])
  else:
    # Use dynamic rank.
    weights_rank_tensor = array_ops.rank(weights)
    rank_diff = weights_rank_tensor - array_ops.rank(predictions)

    def _maybe_expand_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, -1),
          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)

    # Don't attempt squeeze if it will fail based on static check.
    if ((weights_rank is not None) and
        (not weights_shape.dims[-1].is_compatible_with(1))):
      maybe_squeeze_weights = lambda: weights
    else:
      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])

    def _maybe_adjust_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
          _maybe_expand_weights)

    # If weights are scalar, do nothing. Otherwise, try to add or remove a
    # dimension to match predictions.
    weights = control_flow_ops.cond(
        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
        _maybe_adjust_weights)
  return predictions, labels, weights


def _maybe_expand_labels(labels, predictions):
  """If necessary, expand `labels` along last dimension to match `predictions`.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

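  For example, `labels` with shape [2, 3] and `predictions` with shape
  [2, 3, 5] yield `labels` expanded to shape [2, 3, 1].
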
  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if `labels` has invalid shape.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    # If sparse, expand sparse shape.
    if isinstance(labels, sparse_tensor.SparseTensor):
      return control_flow_ops.cond(
          math_ops.equal(
              array_ops.rank(predictions),
              array_ops.size(labels.dense_shape) + 1),
          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
              labels,
              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
              name=scope),
          lambda: labels)

    # Otherwise, try to use static shape.
    labels_rank = labels.get_shape().ndims
    if labels_rank is not None:
      predictions_rank = predictions.get_shape().ndims
      if predictions_rank is not None:
        if predictions_rank == labels_rank:
          return labels
        if predictions_rank == labels_rank + 1:
          return array_ops.expand_dims(labels, -1, name=scope)
        raise ValueError(
            f'Unexpected labels shape {labels.get_shape()} for predictions '
            f'shape {predictions.get_shape()}. Predictions rank should be the '
            'same rank as labels rank or labels rank plus one.')

    # Otherwise, use dynamic shape.
    return control_flow_ops.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)


def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return math_ops.div_no_nan(numerator, denominator, name=name)


def _streaming_confusion_matrix(labels, predictions, num_classes,
                                weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
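
  For illustration, a rough usage sketch in graph mode (tensors here are
  hypothetical):

  ```python
  total_cm, update_op = _streaming_confusion_matrix(
      labels, predictions, num_classes=3)
  # Run `update_op` once per batch; `total_cm` then holds the accumulated
  # [3, 3] confusion matrix.
  ```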
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.cast(predictions, dtypes.int64)
  labels = math_ops.cast(labels, dtypes.int64)
  num_classes = math_ops.cast(num_classes, dtypes.int64)

  # Flatten the input if its rank > 1.
  if predictions.get_shape().ndims > 1:
    predictions = array_ops.reshape(predictions, [-1])

  if labels.get_shape().ndims > 1:
    labels = array_ops.reshape(labels, [-1])

  if (weights is not None) and (weights.get_shape().ndims > 1):
    weights = array_ops.reshape(weights, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op


def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args):
  """Aggregate metric value across replicas."""
  def fn(distribution, *a):
    """Call `metric_value_fn` in the correct control flow context."""
    if hasattr(distribution.extended, '_outer_control_flow_context'):
      # If there was an outer context captured before this method was called,
      # then we enter that context to create the metric value op. If the
      # captured context is `None`, ops.control_dependencies(None) gives the
      # desired behavior. Else we use `Enter` and `Exit` to enter and exit the
      # captured context.
      # This special handling is needed because sometimes the metric is created
      # inside a while_loop (and perhaps a TPU rewrite context). But we don't
      # want the value op to be evaluated every step or on the TPU. So we
      # create it outside so that it can be evaluated at the end on the host,
      # once the update ops have been evaluated.

      # pylint: disable=protected-access
      if distribution.extended._outer_control_flow_context is None:
        with ops.control_dependencies(None):
          metric_value = metric_value_fn(distribution, *a)
      else:
        distribution.extended._outer_control_flow_context.Enter()
        metric_value = metric_value_fn(distribution, *a)
        distribution.extended._outer_control_flow_context.Exit()
        # pylint: enable=protected-access
    else:
      metric_value = metric_value_fn(distribution, *a)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric_value)
    return metric_value

  return distribution_strategy_context.get_replica_context().merge_call(
      fn, args=args)


@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  `tf.compat.v1.metrics.mean` is not compatible with eager
  execution or `tf.function`.
  Please use `tf.keras.metrics.Mean` instead for TF2 migration. After
  instantiating a `tf.keras.metrics.Mean` object, you can first call the
  `update_state()` method to record the new values, and then call the
  `result()` method to get the mean eagerly. You can also attach it to a
  Keras model with the `add_metric` method. Please refer to the [migration
  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
  for more details.

  #### Structural Mapping to TF2

  Before:

  ```python
  mean, update_op = tf.compat.v1.metrics.mean(
    values=values,
    weights=weights,
    metrics_collections=metrics_collections,
    updates_collections=updates_collections,
    name=name)
  ```

  After:

  ```python
   m = tf.keras.metrics.Mean(
     name=name)

   m.update_state(
     values=values,
     sample_weight=weights)

   mean = m.result()
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `values`              | `values`        | In `update_state()` method |
  | `weights`             | `sample_weight` | In `update_state()` method |
  | `metrics_collections` | Not supported   | Metrics should be tracked  |
  :                       :                 : explicitly or with Keras   :
  :                       :                 : APIs, for example,         :
  :                       :                 : [add_metric][add_metric],  :
  :                       :                 : instead of via collections :
  | `updates_collections` | Not supported   | -                          |
  | `name`                | `name`          | In constructor             |

  [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric


  #### Before & After Usage Example

  Before:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   values = [1, 2, 3]
  ...   mean, update_op = tf.compat.v1.metrics.mean(values)
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   local_init = tf.compat.v1.local_variables_initializer()
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run([global_init, local_init])
  >>> sess.run(update_op)
  >>> sess.run(mean)
  2.0


  After:

  >>> m = tf.keras.metrics.Mean()
  >>> m.update_state([1, 2, 3])
  >>> m.result().numpy()
  2.0

  ```python
  # Used within Keras model
  model.add_metric(tf.keras.metrics.Mean()(values))
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

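    # Streaming update: `total` accumulates sum(values * weights) and `count`
    # accumulates sum(weights); the reported mean is total / count.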
    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op


@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  `tf.compat.v1.metrics.accuracy` is not compatible with eager
  execution or `tf.function`.
  Please use `tf.keras.metrics.Accuracy` instead for TF2 migration. After
  instantiating a `tf.keras.metrics.Accuracy` object, you can first call the
  `update_state()` method to record the prediction/labels, and then call the
  `result()` method to get the accuracy eagerly. You can also attach it to a
  Keras model when calling the `compile` method. Please refer to [this
  guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses)
  for more details.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  accuracy, update_op = tf.compat.v1.metrics.accuracy(
    labels=labels,
    predictions=predictions,
    weights=weights,
    metrics_collections=metrics_collections,
    updates_collections=updates_collections,
    name=name)
  ```

  After:

  ```python
   m = tf.keras.metrics.Accuracy(
     name=name,
     dtype=None)

   m.update_state(
     y_true=labels,
     y_pred=predictions,
     sample_weight=weights)

   accuracy = m.result()
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `labels`              | `y_true`        | In `update_state()` method |
  | `predictions`         | `y_pred`        | In `update_state()` method |
  | `weights`             | `sample_weight` | In `update_state()` method |
  | `metrics_collections` | Not supported   | Metrics should be tracked  |
  :                       :                 : explicitly or with Keras   :
  :                       :                 : APIs, for example,         :
  :                       :                 : [add_metric][add_metric],  :
  :                       :                 : instead of via collections :
  | `updates_collections` | Not supported   | -                          |
  | `name`                | `name`          | In constructor             |

  [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric


  #### Before & After Usage Example

  Before:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   logits = [1, 2, 3]
  ...   labels = [0, 2, 3]
  ...   acc, acc_op = tf.compat.v1.metrics.accuracy(labels, logits)
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   local_init = tf.compat.v1.local_variables_initializer()
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run([global_init, local_init])
  >>> print(sess.run([acc, acc_op]))
  [0.0, 0.66667]


  After:

  >>> m = tf.keras.metrics.Accuracy()
  >>> m.update_state([1, 2, 3], [0, 2, 3])
  >>> m.result().numpy()
  0.66667

  ```python
  # Used within Keras model
  model.compile(optimizer='sgd',
                loss='mse',
                metrics=[tf.keras.metrics.Accuracy()])
  ```

  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.cast(
      math_ops.equal(predictions, labels), dtypes.float32)
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')


def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
        defaults to all four.

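  For example, with `predictions` [0.2, 0.8], `labels` [False, True] and
  `thresholds` [0.5] (unweighted), the resulting variables are tp=[1],
  fn=[0], tn=[1] and fp=[0].
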
  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increment the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError(f'Invalid key: {include}')

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtypes.float32),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])
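  # `predictions_2d` has shape [num_predictions, 1] and `labels_2d` has shape
  # [1, num_predictions]; both are compared below against a
  # [num_thresholds, num_predictions] tiling of `thresholds`.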

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds.
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(weights, dtypes.float32), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops


def _aggregate_variable(v, collections):
  f = lambda distribution, value: distribution.extended.read_var(value)
  return _aggregate_across_replicas(collections, f, v)


@tf_export(v1=['metrics.auc'])
@deprecated(None,
            'The value of AUC returned by this may race with the update so '
            'this is deprecated. Please use tf.keras.metrics.AUC instead.')
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal',
        thresholds=None):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC curve is therefore computed using the heights of the recall
  values by the false positive rate, while the area under the PR curve is
  computed using the heights of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` parameter
  controls the degree of discretization, with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation
  by providing lower or upper bound estimates of the AUC. The `thresholds`
  parameter can be used to manually specify thresholds that split the
  predictions more evenly.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the ROC
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used
      (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that
      applies the trapezoidal rule; 'careful_interpolation', a variant of it
      differing only by a more correct interpolation scheme for PR-AUC -
      interpolating (true/false) positives but not the ratio that is precision;
      'minoring' that applies left summation for increasing intervals and right
      summation for decreasing intervals; 'majoring' that does the opposite.
      Note that 'careful_interpolation' is strictly preferred to 'trapezoidal'
      (to be deprecated soon) as it applies the same method for ROC, and a
      better one (see Davis & Goadrich 2006 for details) for the PR curve.
    thresholds: An optional list of floating point values to use as the
      thresholds for discretizing the curve. If set, the `num_thresholds`
      parameter is ignored. Values should be in [0, 1]. Endpoint thresholds
      equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be
      automatically included with these to correctly handle predictions equal to
      exactly 0 or 1.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError(f'Curve must be either ROC or PR. Curve {curve} is '
                       'unknown.')

    kepsilon = 1e-7  # To account for floating point imprecisions.
    if thresholds is not None:
      # If specified, use the supplied thresholds.
      thresholds = sorted(thresholds)
      num_thresholds = len(thresholds) + 2
    else:
      # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in
      # (0, 1).
      thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                    for i in range(num_thresholds - 2)]

    # Add an endpoint "threshold" below zero and above one for either threshold
    # method.
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
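    # For example, with num_thresholds=5 this yields
    # [0.0 - 1e-7, 0.25, 0.5, 0.75, 1.0 + 1e-7].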

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def interpolate_pr_auc(tp, fp, fn):
      """Interpolation formula inspired by section 4 of (Davis et al., 2006).

      Note here we derive & use a closed formula not present in the paper,
      as follows:
      Modeling all of TP (true positive weight),
      FP (false positive weight) and their sum P = TP + FP (positive weight)
      as varying linearly within each interval [A, B] between successive
      thresholds, we get
        Precision = (TP_A + slope * (P - P_A)) / P
      with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A).
      The area within the interval is thus (slope / total_pos_weight) times
        int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
        int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
      where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in
        int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
      Bringing back the factor (slope / total_pos_weight) we'd put aside, we get
        slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight
      where dTP == TP_B - TP_A.
      Note that when P_A == 0 the above calculation simplifies into
        int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
      which is really equivalent to imputing constant precision throughout the
      first bucket having >0 true positives.

      Args:
        tp: true positive counts
        fp: false positive counts
        fn: false negative counts

      Returns:
        pr_auc: an approximation of the area under the P-R curve.

      References:
        The Relationship Between Precision-Recall and ROC Curves:
          [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874)
          ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf))
      """
      dtp = tp[:num_thresholds - 1] - tp[1:]
      p = tp + fp
      prec_slope = math_ops.div_no_nan(
          dtp,
          math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0),
          name='prec_slope')
      intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:])
      safe_p_ratio = array_ops.where(
          math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0),
          math_ops.div_no_nan(
              p[:num_thresholds - 1],
              math_ops.maximum(p[1:], 0),
              name='recall_relative_ratio'), array_ops.ones_like(p[1:]))
      return math_ops.reduce_sum(
          math_ops.div_no_nan(
              prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)),
              math_ops.maximum(tp[1:] + fn[1:], 0),
              name='pr_auc_increment'),
          name='interpolate_pr_auc')

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      if curve == 'PR':
        if summation_method == 'trapezoidal':
          logging.warning(
              'Trapezoidal rule is known to produce incorrect PR-AUCs; '
              'please switch to "careful_interpolation" instead.')
        elif summation_method == 'careful_interpolation':
          # This one is a bit tricky and is handled separately.
          return interpolate_pr_auc(tp, fp, fn)
      rec = math_ops.divide(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.divide(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.divide(tp + epsilon, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method in ('trapezoidal', 'careful_interpolation'):
        # Note that the case ('PR', 'careful_interpolation') has been handled
        # above.
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError(f'Invalid summation_method: {summation_method}. '
                         'summation_method should be \'trapezoidal\', '
                         '\'careful_interpolation\', \'minoring\', or '
                         '\'majoring\'.')

    # Sum up the areas of all the trapeziums.
    def compute_auc_value(_, values):
      return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'],
                         'value')

    auc_value = _aggregate_across_replicas(
        metrics_collections, compute_auc_value, values)
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op


@tf_export(v1=['metrics.mean_absolute_error'])
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  absolute_errors = math_ops.abs(predictions - labels)
  return mean(absolute_errors, weights, metrics_collections,
              updates_collections, name or 'mean_absolute_error')


@tf_export(v1=['metrics.mean_cosine_distance'])
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the cosine distance between the labels and predictions.

  The `mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension). Also,
      dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
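  # For unit-normalized inputs, the reduced sum of elementwise products along
  # `dim` is the cosine similarity; the metric below reports
  # 1 - mean(similarity).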
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, axis=[dim], keepdims=True)
  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
                                  'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op


@tf_export(v1=['metrics.mean_per_class_accuracy'])
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Calculates the mean of the per-class accuracies.

  Calculates the accuracy for each class, then takes the mean of that.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.cast(labels, dtypes.int64)

    # Flatten the input if its rank > 1.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    is_correct = math_ops.cast(
        math_ops.equal(predictions, labels), dtypes.float32)

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.cast(weights, dtypes.float32)

      is_correct *= weights
      ones *= weights

    update_total_op = state_ops.scatter_add(total, labels, ones)
    update_count_op = state_ops.scatter_add(count, labels, is_correct)
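    # Note on naming: `total` accumulates the (weighted) number of examples
    # seen per class, while `count` accumulates the (weighted) number correct
    # per class, so per-class accuracy is count / total.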
1260
1261    def compute_mean_accuracy(_, count, total):
1262      per_class_accuracy = math_ops.div_no_nan(
1263          count, math_ops.maximum(total, 0), name=None)
1264      mean_accuracy_v = math_ops.reduce_mean(
1265          per_class_accuracy, name='mean_accuracy')
1266      return mean_accuracy_v
1267
1268    mean_accuracy_v = _aggregate_across_replicas(
1269        metrics_collections, compute_mean_accuracy, count, total)
1270
1271    update_op = math_ops.div_no_nan(
1272        update_count_op, math_ops.maximum(update_total_op, 0), name='update_op')
1273    if updates_collections:
1274      ops.add_to_collections(updates_collections, update_op)
1275
1276    return mean_accuracy_v, update_op
1277
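# Editor's sketch: minimal usage of `mean_per_class_accuracy` (illustrative
# values; assumes graph mode, e.g. via `tf.compat.v1` with eager execution
# disabled):
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   labels = tf.constant([0, 0, 1])
#   predictions = tf.constant([0, 1, 1])
#   acc, update_op = tf.metrics.mean_per_class_accuracy(
#       labels, predictions, num_classes=2)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(acc))  # 0.75: mean of class 0 (1/2) and class 1 (1/1)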
1278
1279@tf_export(v1=['metrics.mean_iou'])
1280def mean_iou(labels,
1281             predictions,
1282             num_classes,
1283             weights=None,
1284             metrics_collections=None,
1285             updates_collections=None,
1286             name=None):
1287  """Calculate per-step mean Intersection-Over-Union (mIOU).
1288
1289  Mean Intersection-Over-Union is a common evaluation metric for
1290  semantic image segmentation, which first computes the IOU for each
1291  semantic class and then computes the average over classes.
1292  IOU is defined as follows:
1293    IOU = true_positive / (true_positive + false_positive + false_negative).
1294  The predictions are accumulated in a confusion matrix, weighted by `weights`,
1295  and mIOU is then calculated from it.
1296
1297  For estimation of the metric over a stream of data, the function creates an
1298  `update_op` operation that updates these variables and returns the `mean_iou`.
1299
1300  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1301
1302  Args:
1303    labels: A `Tensor` of ground truth labels with shape [batch size] and of
1304      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
1305    predictions: A `Tensor` of prediction results for semantic labels, whose
1306      shape is [batch size] and type `int32` or `int64`. The tensor will be
1307      flattened if its rank > 1.
1308    num_classes: The possible number of labels the prediction task can
1309      have. This value must be provided, since a confusion matrix of
1310      dimension = [num_classes, num_classes] will be allocated.
1311    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1312      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1313      be either `1`, or the same as the corresponding `labels` dimension).
1314    metrics_collections: An optional list of collections that `mean_iou`
1315      should be added to.
1316    updates_collections: An optional list of collections `update_op` should be
1317      added to.
1318    name: An optional variable_scope name.
1319
1320  Returns:
1321    mean_iou: A `Tensor` representing the mean intersection-over-union.
1322    update_op: An operation that increments the confusion matrix.
1323
1324  Raises:
1325    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1326      `weights` is not `None` and its shape doesn't match `predictions`, or if
1327      either `metrics_collections` or `updates_collections` are not a list or
1328      tuple.
1329    RuntimeError: If eager execution is enabled.
1330  """
1331  if context.executing_eagerly():
1332    raise RuntimeError('tf.metrics.mean_iou is not supported when '
1333                       'eager execution is enabled.')
1334
1335  with variable_scope.variable_scope(name, 'mean_iou',
1336                                     (predictions, labels, weights)):
1337    # Check if shape is compatible.
1338    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
1339
1340    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
1341                                                      num_classes, weights)
1342
1343    def compute_mean_iou(_, total_cm):
1344      """Compute the mean intersection-over-union via the confusion matrix."""
1345      sum_over_row = math_ops.cast(
1346          math_ops.reduce_sum(total_cm, 0), dtypes.float32)
1347      sum_over_col = math_ops.cast(
1348          math_ops.reduce_sum(total_cm, 1), dtypes.float32)
1349      cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32)
1350      denominator = sum_over_row + sum_over_col - cm_diag
1351
1352      # The mean is only computed over classes that appear in the
1353      # label or prediction tensor. If the denominator is 0, we need to
1354      # ignore the class.
1355      num_valid_entries = math_ops.reduce_sum(
1356          math_ops.cast(
1357              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))
1358
1359      # If the value of the denominator is 0, set it to 1 to avoid
1360      # zero division.
1361      denominator = array_ops.where(
1362          math_ops.greater(denominator, 0), denominator,
1363          array_ops.ones_like(denominator))
1364      iou = math_ops.divide(cm_diag, denominator)
1365
1366      # If the number of valid entries is 0 (no classes) we return 0.
1367      result = array_ops.where(
1368          math_ops.greater(num_valid_entries, 0),
1369          math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0)
1370      return result
1371
1372    # TODO(priyag): Use outside_compilation if in TPU context.
1373    mean_iou_v = _aggregate_across_replicas(
1374        metrics_collections, compute_mean_iou, total_cm)
1375
1376    if updates_collections:
1377      ops.add_to_collections(updates_collections, update_op)
1378
1379    return mean_iou_v, update_op
1380
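# Editor's sketch: minimal usage of `mean_iou` (illustrative values; same
# graph-mode assumptions as the sketch above):
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   labels = tf.constant([0, 0, 1, 1])
#   predictions = tf.constant([0, 1, 1, 1])
#   miou, update_op = tf.metrics.mean_iou(labels, predictions, num_classes=2)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(miou))  # ~0.5833: mean of IoU(class 0)=1/2 and
#                            # IoU(class 1)=2/3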
1381
1382@tf_export(v1=['metrics.mean_relative_error'])
1383def mean_relative_error(labels,
1384                        predictions,
1385                        normalizer,
1386                        weights=None,
1387                        metrics_collections=None,
1388                        updates_collections=None,
1389                        name=None):
1390  """Computes the mean relative error by normalizing with the given values.
1391
1392  The `mean_relative_error` function creates two local variables,
1393  `total` and `count` that are used to compute the mean relative absolute error.
1394  This average is weighted by `weights`, and it is ultimately returned as
1395  `mean_relative_error`: an idempotent operation that simply divides `total` by
1396  `count`.
1397
1398  For estimation of the metric over a stream of data, the function creates an
1399  `update_op` operation that updates these variables and returns the
1400  `mean_relative_error`. Internally, a `relative_errors` operation divides the
1401  absolute value of the differences between `predictions` and `labels` by the
1402  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
1403  product of `weights` and `relative_errors`, and it increments `count` with the
1404  reduced sum of `weights`.
1405
1406  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1407
1408  Args:
1409    labels: A `Tensor` of the same shape as `predictions`.
1410    predictions: A `Tensor` of arbitrary shape.
1411    normalizer: A `Tensor` of the same shape as `predictions`.
1412    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1413      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1414      be either `1`, or the same as the corresponding `labels` dimension).
1415    metrics_collections: An optional list of collections that
1416      `mean_relative_error` should be added to.
1417    updates_collections: An optional list of collections that `update_op` should
1418      be added to.
1419    name: An optional variable_scope name.
1420
1421  Returns:
1422    mean_relative_error: A `Tensor` representing the current mean, the value of
1423      `total` divided by `count`.
1424    update_op: An operation that increments the `total` and `count` variables
1425      appropriately and whose value matches `mean_relative_error`.
1426
1427  Raises:
1428    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1429      `weights` is not `None` and its shape doesn't match `predictions`, or if
1430      either `metrics_collections` or `updates_collections` are not a list or
1431      tuple.
1432    RuntimeError: If eager execution is enabled.
1433  """
1434  if context.executing_eagerly():
1435    raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
1436                       'eager execution is enabled.')
1437
1438  predictions, labels, weights = _remove_squeezable_dimensions(
1439      predictions=predictions, labels=labels, weights=weights)
1440
1441  predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
1442      predictions, normalizer)
1443  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
1444  relative_errors = array_ops.where(
1445      math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
1446      math_ops.divide(math_ops.abs(labels - predictions), normalizer))
1447  return mean(relative_errors, weights, metrics_collections,
1448              updates_collections, name or 'mean_relative_error')
1449
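# Editor's sketch: `mean_relative_error` with the labels themselves as the
# normalizer (illustrative values; graph mode assumed):
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   labels = tf.constant([1., 2., 4.])
#   predictions = tf.constant([1.5, 2., 3.])
#   mre, update_op = tf.metrics.mean_relative_error(
#       labels, predictions, normalizer=labels)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op)
#     print(sess.run(mre))  # 0.25: mean of [0.5, 0.0, 0.25]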
1450
1451@tf_export(v1=['metrics.mean_squared_error'])
1452def mean_squared_error(labels,
1453                       predictions,
1454                       weights=None,
1455                       metrics_collections=None,
1456                       updates_collections=None,
1457                       name=None):
1458  """Computes the mean squared error between the labels and predictions.
1459
1460  The `mean_squared_error` function creates two local variables,
1461  `total` and `count` that are used to compute the mean squared error.
1462  This average is weighted by `weights`, and it is ultimately returned as
1463  `mean_squared_error`: an idempotent operation that simply divides `total` by
1464  `count`.
1465
1466  For estimation of the metric over a stream of data, the function creates an
1467  `update_op` operation that updates these variables and returns the
1468  `mean_squared_error`. Internally, a `squared_error` operation computes the
1469  element-wise square of the difference between `predictions` and `labels`. Then
1470  `update_op` increments `total` with the reduced sum of the product of
1471  `weights` and `squared_error`, and it increments `count` with the reduced sum
1472  of `weights`.
1473
1474  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1475
1476  Args:
1477    labels: A `Tensor` of the same shape as `predictions`.
1478    predictions: A `Tensor` of arbitrary shape.
1479    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1480      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1481      be either `1`, or the same as the corresponding `labels` dimension).
1482    metrics_collections: An optional list of collections that
1483      `mean_squared_error` should be added to.
1484    updates_collections: An optional list of collections that `update_op` should
1485      be added to.
1486    name: An optional variable_scope name.
1487
1488  Returns:
1489    mean_squared_error: A `Tensor` representing the current mean, the value of
1490      `total` divided by `count`.
1491    update_op: An operation that increments the `total` and `count` variables
1492      appropriately and whose value matches `mean_squared_error`.
1493
1494  Raises:
1495    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1496      `weights` is not `None` and its shape doesn't match `predictions`, or if
1497      either `metrics_collections` or `updates_collections` are not a list or
1498      tuple.
1499    RuntimeError: If eager execution is enabled.
1500  """
1501  if context.executing_eagerly():
1502    raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
1503                       'eager execution is enabled.')
1504
1505  predictions, labels, weights = _remove_squeezable_dimensions(
1506      predictions=predictions, labels=labels, weights=weights)
1507  squared_error = math_ops.squared_difference(labels, predictions)
1508  return mean(squared_error, weights, metrics_collections, updates_collections,
1509              name or 'mean_squared_error')
1510
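# Editor's sketch of `mean_squared_error` (illustrative values; same graph-mode
# and `tf.Session` boilerplate as in the sketches above):
#
#   labels = tf.constant([0., 1., 2.])
#   predictions = tf.constant([0., 2., 4.])
#   mse, update_op = tf.metrics.mean_squared_error(labels, predictions)
#   # after sess.run(update_op): sess.run(mse) == (0 + 1 + 4) / 3 ~= 1.667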
1511
1512@tf_export(v1=['metrics.mean_tensor'])
1513def mean_tensor(values,
1514                weights=None,
1515                metrics_collections=None,
1516                updates_collections=None,
1517                name=None):
1518  """Computes the element-wise (weighted) mean of the given tensors.
1519
1520  In contrast to the `mean` function, which returns a scalar containing
1521  the mean, this function returns an average tensor with the same shape as the
1522  input tensors.
1523
1524  The `mean_tensor` function creates two local variables,
1525  `total_tensor` and `count_tensor` that are used to compute the average of
1526  `values`. This average is ultimately returned as `mean` which is an idempotent
1527  operation that simply divides `total` by `count`.
1528
1529  For estimation of the metric over a stream of data, the function creates an
1530  `update_op` operation that updates these variables and returns the `mean`.
1531  `update_op` increments `total` with the reduced sum of the product of `values`
1532  and `weights`, and it increments `count` with the reduced sum of `weights`.
1533
1534  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1535
1536  Args:
1537    values: A `Tensor` of arbitrary dimensions.
1538    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1539      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1540      be either `1`, or the same as the corresponding `values` dimension).
1541    metrics_collections: An optional list of collections that `mean`
1542      should be added to.
1543    updates_collections: An optional list of collections that `update_op`
1544      should be added to.
1545    name: An optional variable_scope name.
1546
1547  Returns:
1548    mean: A float `Tensor` representing the current mean, the value of `total`
1549      divided by `count`.
1550    update_op: An operation that increments the `total` and `count` variables
1551      appropriately and whose value matches `mean`.
1552
1553  Raises:
1554    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1555      or if either `metrics_collections` or `updates_collections` are not a list
1556      or tuple.
1557    RuntimeError: If eager execution is enabled.
1558  """
1559  if context.executing_eagerly():
1560    raise RuntimeError('tf.metrics.mean_tensor is not supported when '
1561                       'eager execution is enabled.')
1562
1563  with variable_scope.variable_scope(name, 'mean', (values, weights)):
1564    values = math_ops.cast(values, dtypes.float32)
1565    total = metric_variable(
1566        values.get_shape(), dtypes.float32, name='total_tensor')
1567    count = metric_variable(
1568        values.get_shape(), dtypes.float32, name='count_tensor')
1569
1570    num_values = array_ops.ones_like(values)
1571    if weights is not None:
1572      values, _, weights = _remove_squeezable_dimensions(
1573          predictions=values, labels=None, weights=weights)
1574      weights = weights_broadcast_ops.broadcast_weights(
1575          math_ops.cast(weights, dtypes.float32), values)
1576      values = math_ops.multiply(values, weights)
1577      num_values = math_ops.multiply(num_values, weights)
1578
1579    update_total_op = state_ops.assign_add(total, values)
1580    with ops.control_dependencies([values]):
1581      update_count_op = state_ops.assign_add(count, num_values)
1582
1583    compute_mean = lambda _, t, c: math_ops.div_no_nan(  # pylint: disable=g-long-lambda
1584        t, math_ops.maximum(c, 0), name='value')
1585
1586    mean_t = _aggregate_across_replicas(
1587        metrics_collections, compute_mean, total, count)
1588
1589    update_op = math_ops.div_no_nan(
1590        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')
1591    if updates_collections:
1592      ops.add_to_collections(updates_collections, update_op)
1593
1594    return mean_t, update_op
1595
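# Editor's sketch of `mean_tensor` accumulating over two batches (illustrative
# values; graph mode assumed):
#
#   v = tf.placeholder(tf.float32, shape=[2])
#   mean_t, update_op = tf.metrics.mean_tensor(v)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op, feed_dict={v: [1., 3.]})
#     sess.run(update_op, feed_dict={v: [3., 5.]})
#     print(sess.run(mean_t))  # [2., 4.]: element-wise mean of both batches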
1596
1597@tf_export(v1=['metrics.percentage_below'])
1598def percentage_below(values,
1599                     threshold,
1600                     weights=None,
1601                     metrics_collections=None,
1602                     updates_collections=None,
1603                     name=None):
1604  """Computes the percentage of values less than the given threshold.
1605
1606  The `percentage_below` function creates two local variables,
1607  `total` and `count` that are used to compute the percentage of `values` that
1608  fall below `threshold`. This rate is weighted by `weights`, and it is
1609  ultimately returned as `percentage` which is an idempotent operation that
1610  simply divides `total` by `count`.
1611
1612  For estimation of the metric over a stream of data, the function creates an
1613  `update_op` operation that updates these variables and returns the
1614  `percentage`.
1615
1616  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1617
1618  Args:
1619    values: A numeric `Tensor` of arbitrary size.
1620    threshold: A scalar threshold.
1621    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1622      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1623      be either `1`, or the same as the corresponding `values` dimension).
1624    metrics_collections: An optional list of collections that the metric
1625      value variable should be added to.
1626    updates_collections: An optional list of collections that the metric update
1627      ops should be added to.
1628    name: An optional variable_scope name.
1629
1630  Returns:
1631    percentage: A `Tensor` representing the current mean, the value of `total`
1632      divided by `count`.
1633    update_op: An operation that increments the `total` and `count` variables
1634      appropriately.
1635
1636  Raises:
1637    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1638      or if either `metrics_collections` or `updates_collections` are not a list
1639      or tuple.
1640    RuntimeError: If eager execution is enabled.
1641  """
1642  if context.executing_eagerly():
1643    raise RuntimeError('tf.metrics.percentage_below is not supported when '
1644                       'eager execution is enabled.')
1645
1646  is_below_threshold = math_ops.cast(
1647      math_ops.less(values, threshold), dtypes.float32)
1648  return mean(is_below_threshold, weights, metrics_collections,
1649              updates_collections, name or 'percentage_below_threshold')
1650
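# Editor's sketch of `percentage_below` (illustrative values; session
# boilerplate as above):
#
#   values = tf.constant([1., 2., 3., 4.])
#   pct, update_op = tf.metrics.percentage_below(values, threshold=3.)
#   # after sess.run(update_op): sess.run(pct) == 0.5 (two of four values < 3)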
1651
1652def _count_condition(values,
1653                     weights=None,
1654                     metrics_collections=None,
1655                     updates_collections=None):
1656  """Sums the weights of cases where the given values are True.
1657
1658  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1659
1660  Args:
1661    values: A `bool` `Tensor` of arbitrary size.
1662    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1663      `values`, and must be broadcastable to `values` (i.e., all dimensions must
1664      be either `1`, or the same as the corresponding `values` dimension).
1665    metrics_collections: An optional list of collections that the metric
1666      value variable should be added to.
1667    updates_collections: An optional list of collections that the metric update
1668      ops should be added to.
1669
1670  Returns:
1671    value_tensor: A `Tensor` representing the current value of the metric.
1672    update_op: An operation that accumulates the error from a batch of data.
1673
1674  Raises:
1675    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
1676      or if either `metrics_collections` or `updates_collections` are not a list
1677      or tuple.
1678  """
1679  check_ops.assert_type(values, dtypes.bool)
1680  count = metric_variable([], dtypes.float32, name='count')
1681
1682  values = math_ops.cast(values, dtypes.float32)
1683  if weights is not None:
1684    with ops.control_dependencies((check_ops.assert_rank_in(
1685        weights, (0, array_ops.rank(values))),)):
1686      weights = math_ops.cast(weights, dtypes.float32)
1687      values = math_ops.multiply(values, weights)
1688
1689  value_tensor = _aggregate_variable(count, metrics_collections)
1690
1691  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
1692  if updates_collections:
1693    ops.add_to_collections(updates_collections, update_op)
1694
1695  return value_tensor, update_op
1696
1697
1698@tf_export(v1=['metrics.false_negatives'])
1699def false_negatives(labels,
1700                    predictions,
1701                    weights=None,
1702                    metrics_collections=None,
1703                    updates_collections=None,
1704                    name=None):
1705  """Computes the total number of false negatives.
1706
1707  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1708
1709  Args:
1710    labels: The ground truth values, a `Tensor` whose dimensions must match
1711      `predictions`. Will be cast to `bool`.
1712    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1713      be cast to `bool`.
1714    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1715      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1716      be either `1`, or the same as the corresponding `labels` dimension).
1717    metrics_collections: An optional list of collections that the metric
1718      value variable should be added to.
1719    updates_collections: An optional list of collections that the metric update
1720      ops should be added to.
1721    name: An optional variable_scope name.
1722
1723  Returns:
1724    value_tensor: A `Tensor` representing the current value of the metric.
1725    update_op: An operation that accumulates the error from a batch of data.
1726
1727  Raises:
1728    ValueError: If `weights` is not `None` and its shape doesn't match
1729      `predictions`,
1729      or if either `metrics_collections` or `updates_collections` are not a list
1730      or tuple.
1731    RuntimeError: If eager execution is enabled.
1732  """
1733  if context.executing_eagerly():
1734    raise RuntimeError('tf.metrics.false_negatives is not supported when '
1735                       'eager execution is enabled.')
1736
1737  with variable_scope.variable_scope(name, 'false_negatives',
1738                                     (predictions, labels, weights)):
1739
1740    predictions, labels, weights = _remove_squeezable_dimensions(
1741        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1742        labels=math_ops.cast(labels, dtype=dtypes.bool),
1743        weights=weights)
1744    is_false_negative = math_ops.logical_and(
1745        math_ops.equal(labels, True), math_ops.equal(predictions, False))
1746    return _count_condition(is_false_negative, weights, metrics_collections,
1747                            updates_collections)
1748
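# Editor's sketch of `false_negatives` (illustrative values; session
# boilerplate as above):
#
#   labels = tf.constant([True, True, False, False])
#   predictions = tf.constant([True, False, False, True])
#   fn, update_op = tf.metrics.false_negatives(labels, predictions)
#   # after sess.run(update_op): sess.run(fn) == 1.0, the one positive label
#   # that was predicted negative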
1749
1750@tf_export(v1=['metrics.false_negatives_at_thresholds'])
1751def false_negatives_at_thresholds(labels,
1752                                  predictions,
1753                                  thresholds,
1754                                  weights=None,
1755                                  metrics_collections=None,
1756                                  updates_collections=None,
1757                                  name=None):
1758  """Computes false negatives at provided threshold values.
1759
1760  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1761
1762  Args:
1763    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1764      `bool`.
1765    predictions: A floating point `Tensor` of arbitrary shape and whose values
1766      are in the range `[0, 1]`.
1767    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1768    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1769      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1770      be either `1`, or the same as the corresponding `labels` dimension).
1771    metrics_collections: An optional list of collections that `false_negatives`
1772      should be added to.
1773    updates_collections: An optional list of collections that `update_op` should
1774      be added to.
1775    name: An optional variable_scope name.
1776
1777  Returns:
1778    false_negatives: A float `Tensor` of shape `[len(thresholds)]`.
1779    update_op: An operation that updates the `false_negatives` variable and
1780      returns its current value.
1781
1782  Raises:
1783    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1784      `weights` is not `None` and its shape doesn't match `predictions`, or if
1785      either `metrics_collections` or `updates_collections` are not a list or
1786      tuple.
1787    RuntimeError: If eager execution is enabled.
1788  """
1789  if context.executing_eagerly():
1790    raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
1791                       'supported when eager execution is enabled.')
1792
1793  with variable_scope.variable_scope(name, 'false_negatives',
1794                                     (predictions, labels, weights)):
1795    values, update_ops = _confusion_matrix_at_thresholds(
1796        labels, predictions, thresholds, weights=weights, includes=('fn',))
1797
1798    fn_value = _aggregate_variable(values['fn'], metrics_collections)
1799
1800    if updates_collections:
1801      ops.add_to_collections(updates_collections, update_ops['fn'])
1802
1803    return fn_value, update_ops['fn']
1804
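# Editor's sketch of `false_negatives_at_thresholds` (illustrative values; a
# prediction is treated as positive when strictly above a threshold):
#
#   labels = tf.constant([True, True, False])
#   predictions = tf.constant([0.9, 0.3, 0.6])
#   fn, update_op = tf.metrics.false_negatives_at_thresholds(
#       labels, predictions, thresholds=[0.25, 0.5])
#   # after sess.run(update_op): sess.run(fn) == [0., 1.]; at 0.25 both
#   # positives score above the threshold, at 0.5 the 0.3 prediction does not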
1805
1806@tf_export(v1=['metrics.false_positives'])
1807def false_positives(labels,
1808                    predictions,
1809                    weights=None,
1810                    metrics_collections=None,
1811                    updates_collections=None,
1812                    name=None):
1813  """Sum the weights of false positives.
1814
1815  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1816
1817  Args:
1818    labels: The ground truth values, a `Tensor` whose dimensions must match
1819      `predictions`. Will be cast to `bool`.
1820    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1821      be cast to `bool`.
1822    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1823      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1824      be either `1`, or the same as the corresponding `labels` dimension).
1825    metrics_collections: An optional list of collections that the metric
1826      value variable should be added to.
1827    updates_collections: An optional list of collections that the metric update
1828      ops should be added to.
1829    name: An optional variable_scope name.
1830
1831  Returns:
1832    value_tensor: A `Tensor` representing the current value of the metric.
1833    update_op: An operation that accumulates the error from a batch of data.
1834
1835  Raises:
1836    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1837      `weights` is not `None` and its shape doesn't match `predictions`, or if
1838      either `metrics_collections` or `updates_collections` are not a list or
1839      tuple.
1840    RuntimeError: If eager execution is enabled.
1841  """
1842  if context.executing_eagerly():
1843    raise RuntimeError('tf.metrics.false_positives is not supported when '
1844                       'eager execution is enabled.')
1845
1846  with variable_scope.variable_scope(name, 'false_positives',
1847                                     (predictions, labels, weights)):
1848
1849    predictions, labels, weights = _remove_squeezable_dimensions(
1850        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1851        labels=math_ops.cast(labels, dtype=dtypes.bool),
1852        weights=weights)
1853    is_false_positive = math_ops.logical_and(
1854        math_ops.equal(labels, False), math_ops.equal(predictions, True))
1855    return _count_condition(is_false_positive, weights, metrics_collections,
1856                            updates_collections)
1857
1858
1859@tf_export(v1=['metrics.false_positives_at_thresholds'])
1860def false_positives_at_thresholds(labels,
1861                                  predictions,
1862                                  thresholds,
1863                                  weights=None,
1864                                  metrics_collections=None,
1865                                  updates_collections=None,
1866                                  name=None):
1867  """Computes false positives at provided threshold values.
1868
1869  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1870
1871  Args:
1872    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1873      `bool`.
1874    predictions: A floating point `Tensor` of arbitrary shape and whose values
1875      are in the range `[0, 1]`.
1876    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1877    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1878      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1879      be either `1`, or the same as the corresponding `labels` dimension).
1880    metrics_collections: An optional list of collections that `false_positives`
1881      should be added to.
1882    updates_collections: An optional list of collections that `update_op` should
1883      be added to.
1884    name: An optional variable_scope name.
1885
1886  Returns:
1887    false_positives: A float `Tensor` of shape `[len(thresholds)]`.
1888    update_op: An operation that updates the `false_positives` variable and
1889      returns its current value.
1890
1891  Raises:
1892    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1893      `weights` is not `None` and its shape doesn't match `predictions`, or if
1894      either `metrics_collections` or `updates_collections` are not a list or
1895      tuple.
1896    RuntimeError: If eager execution is enabled.
1897  """
1898  if context.executing_eagerly():
1899    raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
1900                       'supported when eager execution is enabled.')
1901
1902  with variable_scope.variable_scope(name, 'false_positives',
1903                                     (predictions, labels, weights)):
1904    values, update_ops = _confusion_matrix_at_thresholds(
1905        labels, predictions, thresholds, weights=weights, includes=('fp',))
1906
1907    fp_value = _aggregate_variable(values['fp'], metrics_collections)
1908
1909    if updates_collections:
1910      ops.add_to_collections(updates_collections, update_ops['fp'])
1911
1912    return fp_value, update_ops['fp']
1913
1914
1915@tf_export(v1=['metrics.true_negatives'])
1916def true_negatives(labels,
1917                   predictions,
1918                   weights=None,
1919                   metrics_collections=None,
1920                   updates_collections=None,
1921                   name=None):
1922  """Sum the weights of true_negatives.
1923
1924  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1925
1926  Args:
1927    labels: The ground truth values, a `Tensor` whose dimensions must match
1928      `predictions`. Will be cast to `bool`.
1929    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
1930      be cast to `bool`.
1931    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1932      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1933      be either `1`, or the same as the corresponding `labels` dimension).
1934    metrics_collections: An optional list of collections that the metric
1935      value variable should be added to.
1936    updates_collections: An optional list of collections that the metric update
1937      ops should be added to.
1938    name: An optional variable_scope name.
1939
1940  Returns:
1941    value_tensor: A `Tensor` representing the current value of the metric.
1942    update_op: An operation that accumulates the error from a batch of data.
1943
1944  Raises:
1945    ValueError: If `predictions` and `labels` have mismatched shapes, or if
1946      `weights` is not `None` and its shape doesn't match `predictions`, or if
1947      either `metrics_collections` or `updates_collections` are not a list or
1948      tuple.
1949    RuntimeError: If eager execution is enabled.
1950  """
1951  if context.executing_eagerly():
1952    raise RuntimeError('tf.metrics.true_negatives is not '
1953                       'supported when eager execution is enabled.')
1954
1955  with variable_scope.variable_scope(name, 'true_negatives',
1956                                     (predictions, labels, weights)):
1957
1958    predictions, labels, weights = _remove_squeezable_dimensions(
1959        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
1960        labels=math_ops.cast(labels, dtype=dtypes.bool),
1961        weights=weights)
1962    is_true_negative = math_ops.logical_and(
1963        math_ops.equal(labels, False), math_ops.equal(predictions, False))
1964    return _count_condition(is_true_negative, weights, metrics_collections,
1965                            updates_collections)
1966
1967
1968@tf_export(v1=['metrics.true_negatives_at_thresholds'])
1969def true_negatives_at_thresholds(labels,
1970                                 predictions,
1971                                 thresholds,
1972                                 weights=None,
1973                                 metrics_collections=None,
1974                                 updates_collections=None,
1975                                 name=None):
1976  """Computes true negatives at provided threshold values.
1977
1978  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
1979
1980  Args:
1981    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
1982      `bool`.
1983    predictions: A floating point `Tensor` of arbitrary shape and whose values
1984      are in the range `[0, 1]`.
1985    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
1986    weights: Optional `Tensor` whose rank is either 0, or the same rank as
1987      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
1988      be either `1`, or the same as the corresponding `labels` dimension).
1989    metrics_collections: An optional list of collections that `true_negatives`
1990      should be added to.
1991    updates_collections: An optional list of collections that `update_op` should
1992      be added to.
1993    name: An optional variable_scope name.
1994
1995  Returns:
1996    true_negatives: A float `Tensor` of shape `[len(thresholds)]`.
1997    update_op: An operation that updates the `true_negatives` variable and
1998      returns its current value.
1999
2000  Raises:
2001    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2002      `weights` is not `None` and its shape doesn't match `predictions`, or if
2003      either `metrics_collections` or `updates_collections` are not a list or
2004      tuple.
2005    RuntimeError: If eager execution is enabled.
2006  """
2007  if context.executing_eagerly():
2008    raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
2009                       'supported when eager execution is enabled.')
2010
2011  with variable_scope.variable_scope(name, 'true_negatives',
2012                                     (predictions, labels, weights)):
2013    values, update_ops = _confusion_matrix_at_thresholds(
2014        labels, predictions, thresholds, weights=weights, includes=('tn',))
2015
2016    tn_value = _aggregate_variable(values['tn'], metrics_collections)
2017
2018    if updates_collections:
2019      ops.add_to_collections(updates_collections, update_ops['tn'])
2020
2021    return tn_value, update_ops['tn']
2022
2023
2024@tf_export(v1=['metrics.true_positives'])
2025def true_positives(labels,
2026                   predictions,
2027                   weights=None,
2028                   metrics_collections=None,
2029                   updates_collections=None,
2030                   name=None):
2031  """Sum the weights of true_positives.
2032
2033  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2034
2035  Args:
2036    labels: The ground truth values, a `Tensor` whose dimensions must match
2037      `predictions`. Will be cast to `bool`.
2038    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2039      be cast to `bool`.
2040    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2041      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2042      be either `1`, or the same as the corresponding `labels` dimension).
2043    metrics_collections: An optional list of collections that the metric
2044      value variable should be added to.
2045    updates_collections: An optional list of collections that the metric update
2046      ops should be added to.
2047    name: An optional variable_scope name.
2048
2049  Returns:
2050    value_tensor: A `Tensor` representing the current value of the metric.
2051    update_op: An operation that accumulates the error from a batch of data.
2052
2053  Raises:
2054    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2055      `weights` is not `None` and its shape doesn't match `predictions`, or if
2056      either `metrics_collections` or `updates_collections` are not a list or
2057      tuple.
2058    RuntimeError: If eager execution is enabled.
2059  """
2060  if context.executing_eagerly():
2061    raise RuntimeError('tf.metrics.true_positives is not '
2062                       'supported when eager execution is enabled.')
2063
2064  with variable_scope.variable_scope(name, 'true_positives',
2065                                     (predictions, labels, weights)):
2066
2067    predictions, labels, weights = _remove_squeezable_dimensions(
2068        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2069        labels=math_ops.cast(labels, dtype=dtypes.bool),
2070        weights=weights)
2071    is_true_positive = math_ops.logical_and(
2072        math_ops.equal(labels, True), math_ops.equal(predictions, True))
2073    return _count_condition(is_true_positive, weights, metrics_collections,
2074                            updates_collections)
2075
2076
2077@tf_export(v1=['metrics.true_positives_at_thresholds'])
2078def true_positives_at_thresholds(labels,
2079                                 predictions,
2080                                 thresholds,
2081                                 weights=None,
2082                                 metrics_collections=None,
2083                                 updates_collections=None,
2084                                 name=None):
2085  """Computes true positives at provided threshold values.
2086
2087  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2088
2089  Args:
2090    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
2091      `bool`.
2092    predictions: A floating point `Tensor` of arbitrary shape and whose values
2093      are in the range `[0, 1]`.
2094    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2095    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2096      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2097      be either `1`, or the same as the corresponding `labels` dimension).
2098    metrics_collections: An optional list of collections that `true_positives`
2099      should be added to.
2100    updates_collections: An optional list of collections that `update_op` should
2101      be added to.
2102    name: An optional variable_scope name.
2103
2104  Returns:
2105    true_positives: A float `Tensor` of shape `[len(thresholds)]`.
2106    update_op: An operation that updates the `true_positives` variable and
2107      returns its current value.
2108
2109  Raises:
2110    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2111      `weights` is not `None` and its shape doesn't match `predictions`, or if
2112      either `metrics_collections` or `updates_collections` are not a list or
2113      tuple.
2114    RuntimeError: If eager execution is enabled.
2115  """
2116  if context.executing_eagerly():
2117    raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
2118                       'supported when eager execution is enabled.')
2119
2120  with variable_scope.variable_scope(name, 'true_positives',
2121                                     (predictions, labels, weights)):
2122    values, update_ops = _confusion_matrix_at_thresholds(
2123        labels, predictions, thresholds, weights=weights, includes=('tp',))
2124
2125    tp_value = _aggregate_variable(values['tp'], metrics_collections)
2126
2127    if updates_collections:
2128      ops.add_to_collections(updates_collections, update_ops['tp'])
2129
2130    return tp_value, update_ops['tp']
2131
2132
2133@tf_export(v1=['metrics.precision'])
2134def precision(labels,
2135              predictions,
2136              weights=None,
2137              metrics_collections=None,
2138              updates_collections=None,
2139              name=None):
2140  """Computes the precision of the predictions with respect to the labels.
2141
2142  The `precision` function creates two local variables,
2143  `true_positives` and `false_positives`, that are used to compute the
2144  precision. This value is ultimately returned as `precision`, an idempotent
2145  operation that simply divides `true_positives` by the sum of `true_positives`
2146  and `false_positives`.
2147
2148  For estimation of the metric over a stream of data, the function creates an
2149  `update_op` operation that updates these variables and returns the
2150  `precision`. `update_op` weights each prediction by the corresponding value in
2151  `weights`.
2152
2153  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2154
2155  Args:
2156    labels: The ground truth values, a `Tensor` whose dimensions must match
2157      `predictions`. Will be cast to `bool`.
2158    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2159      be cast to `bool`.
2160    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2161      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2162      be either `1`, or the same as the corresponding `labels` dimension).
2163    metrics_collections: An optional list of collections that `precision` should
2164      be added to.
2165    updates_collections: An optional list of collections that `update_op` should
2166      be added to.
2167    name: An optional variable_scope name.
2168
2169  Returns:
2170    precision: Scalar float `Tensor` with the value of `true_positives`
2171      divided by the sum of `true_positives` and `false_positives`.
2172    update_op: `Operation` that increments `true_positives` and
2173      `false_positives` variables appropriately and whose value matches
2174      `precision`.
2175
2176  Raises:
2177    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2178      `weights` is not `None` and its shape doesn't match `predictions`, or if
2179      either `metrics_collections` or `updates_collections` are not a list or
2180      tuple.
2181    RuntimeError: If eager execution is enabled.
2182  """
2183  if context.executing_eagerly():
2184    raise RuntimeError('tf.metrics.precision is not '
2185                       'supported when eager execution is enabled.')
2186
2187  with variable_scope.variable_scope(name, 'precision',
2188                                     (predictions, labels, weights)):
2189
2190    predictions, labels, weights = _remove_squeezable_dimensions(
2191        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2192        labels=math_ops.cast(labels, dtype=dtypes.bool),
2193        weights=weights)
2194
2195    true_p, true_positives_update_op = true_positives(
2196        labels,
2197        predictions,
2198        weights,
2199        metrics_collections=None,
2200        updates_collections=None,
2201        name=None)
2202    false_p, false_positives_update_op = false_positives(
2203        labels,
2204        predictions,
2205        weights,
2206        metrics_collections=None,
2207        updates_collections=None,
2208        name=None)
2209
2210    def compute_precision(tp, fp, name):
2211      return array_ops.where(
2212          math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name)
2213
2214    def once_across_replicas(_, true_p, false_p):
2215      return compute_precision(true_p, false_p, 'value')
2216
2217    p = _aggregate_across_replicas(metrics_collections, once_across_replicas,
2218                                   true_p, false_p)
2219
2220    update_op = compute_precision(true_positives_update_op,
2221                                  false_positives_update_op, 'update_op')
2222    if updates_collections:
2223      ops.add_to_collections(updates_collections, update_op)
2224
2225    return p, update_op
2226
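# Editor's sketch of `precision` (illustrative values; session boilerplate as
# above):
#
#   labels = tf.constant([True, True, False, False])
#   predictions = tf.constant([True, False, True, True])
#   prec, update_op = tf.metrics.precision(labels, predictions)
#   # after sess.run(update_op): sess.run(prec) == 1/3 (tp=1, fp=2)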
2227
2228@tf_export(v1=['metrics.precision_at_thresholds'])
2229def precision_at_thresholds(labels,
2230                            predictions,
2231                            thresholds,
2232                            weights=None,
2233                            metrics_collections=None,
2234                            updates_collections=None,
2235                            name=None):
2236  """Computes precision values for different `thresholds` on `predictions`.
2237
2238  The `precision_at_thresholds` function creates four local variables,
2239  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
2240  for various values of thresholds. `precision[i]` is defined as the total
2241  weight of values in `predictions` above `thresholds[i]` whose corresponding
2242  entry in `labels` is `True`, divided by the total weight of values in
2243  `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
2244  false_positives[i])`).
2245
2246  For estimation of the metric over a stream of data, the function creates an
2247  `update_op` operation that updates these variables and returns the
2248  `precision`.
2249
2250  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2251
2252  Args:
2253    labels: The ground truth values, a `Tensor` whose dimensions must match
2254      `predictions`. Will be cast to `bool`.
2255    predictions: A floating point `Tensor` of arbitrary shape and whose values
2256      are in the range `[0, 1]`.
2257    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2258    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2259      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2260      be either `1`, or the same as the corresponding `labels` dimension).
2261    metrics_collections: An optional list of collections that `precision`
2262      should be added to.
2263    updates_collections: An optional list of collections that `update_op` should
2264      be added to.
2265    name: An optional variable_scope name.
2266
2267  Returns:
2268    precision: A float `Tensor` of shape `[len(thresholds)]`.
2269    update_op: An operation that increments the `true_positives`,
2270      `true_negatives`, `false_positives` and `false_negatives` variables that
2271      are used in the computation of `precision`.
2272
2273  Raises:
2274    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2275      `weights` is not `None` and its shape doesn't match `predictions`, or if
2276      either `metrics_collections` or `updates_collections` are not a list or
2277      tuple.
2278    RuntimeError: If eager execution is enabled.
2279  """
2280  if context.executing_eagerly():
2281    raise RuntimeError('tf.metrics.precision_at_thresholds is not '
2282                       'supported when eager execution is enabled.')
2283
2284  with variable_scope.variable_scope(name, 'precision_at_thresholds',
2285                                     (predictions, labels, weights)):
2286    values, update_ops = _confusion_matrix_at_thresholds(
2287        labels, predictions, thresholds, weights, includes=('tp', 'fp'))
2288
2289    # Avoid division by zero.
2290    epsilon = 1e-7
2291
2292    def compute_precision(tp, fp, name):
2293      return math_ops.divide(tp, epsilon + tp + fp, name='precision_' + name)
2294
2295    def precision_across_replicas(_, values):
2296      return compute_precision(values['tp'], values['fp'], 'value')
2297
2298    prec = _aggregate_across_replicas(
2299        metrics_collections, precision_across_replicas, values)
2300
2301    update_op = compute_precision(update_ops['tp'], update_ops['fp'],
2302                                  'update_op')
2303    if updates_collections:
2304      ops.add_to_collections(updates_collections, update_op)
2305
2306    return prec, update_op
2307
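# Editor's sketch of `precision_at_thresholds` (illustrative values; the small
# `epsilon` in the denominator nudges results just below the exact ratios):
#
#   labels = tf.constant([True, True, False])
#   predictions = tf.constant([0.9, 0.4, 0.7])
#   prec, update_op = tf.metrics.precision_at_thresholds(
#       labels, predictions, thresholds=[0.3, 0.6])
#   # after sess.run(update_op): sess.run(prec) ~= [0.667, 0.5]
#   # (at 0.3: tp=2, fp=1; at 0.6: tp=1, fp=1)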
2308
2309@tf_export(v1=['metrics.recall'])
2310def recall(labels,
2311           predictions,
2312           weights=None,
2313           metrics_collections=None,
2314           updates_collections=None,
2315           name=None):
2316  """Computes the recall of the predictions with respect to the labels.
2317
2318  The `recall` function creates two local variables, `true_positives`
2319  and `false_negatives`, that are used to compute the recall. This value is
2320  ultimately returned as `recall`, an idempotent operation that simply divides
2321  `true_positives` by the sum of `true_positives` and `false_negatives`.
2322
2323  For estimation of the metric over a stream of data, the function creates an
2324  `update_op` that updates these variables and returns the `recall`. `update_op`
2325  weights each prediction by the corresponding value in `weights`.
2326
2327  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2328
2329  Args:
2330    labels: The ground truth values, a `Tensor` whose dimensions must match
2331      `predictions`. Will be cast to `bool`.
2332    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
2333      be cast to `bool`.
2334    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2335      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2336      be either `1`, or the same as the corresponding `labels` dimension).
2337    metrics_collections: An optional list of collections that `recall` should
2338      be added to.
2339    updates_collections: An optional list of collections that `update_op` should
2340      be added to.
2341    name: An optional variable_scope name.
2342
2343  Returns:
2344    recall: Scalar float `Tensor` with the value of `true_positives` divided
2345      by the sum of `true_positives` and `false_negatives`.
2346    update_op: `Operation` that increments `true_positives` and
2347      `false_negatives` variables appropriately and whose value matches
2348      `recall`.
2349
2350  Raises:
2351    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2352      `weights` is not `None` and its shape doesn't match `predictions`, or if
2353      either `metrics_collections` or `updates_collections` are not a list or
2354      tuple.
2355    RuntimeError: If eager execution is enabled.
2356  """
2357  if context.executing_eagerly():
2358    raise RuntimeError('tf.metrics.recall is not supported '
2359                       'when eager execution is enabled.')
2360
2361  with variable_scope.variable_scope(name, 'recall',
2362                                     (predictions, labels, weights)):
2363    predictions, labels, weights = _remove_squeezable_dimensions(
2364        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
2365        labels=math_ops.cast(labels, dtype=dtypes.bool),
2366        weights=weights)
2367
2368    true_p, true_positives_update_op = true_positives(
2369        labels,
2370        predictions,
2371        weights,
2372        metrics_collections=None,
2373        updates_collections=None,
2374        name=None)
2375    false_n, false_negatives_update_op = false_negatives(
2376        labels,
2377        predictions,
2378        weights,
2379        metrics_collections=None,
2380        updates_collections=None,
2381        name=None)
2382
2383    def compute_recall(true_p, false_n, name):
2384      return array_ops.where(
2385          math_ops.greater(true_p + false_n, 0),
2386          math_ops.divide(true_p, true_p + false_n), 0, name)
2387
2388    def once_across_replicas(_, true_p, false_n):
2389      return compute_recall(true_p, false_n, 'value')
2390
2391    rec = _aggregate_across_replicas(
2392        metrics_collections, once_across_replicas, true_p, false_n)
2393
2394    update_op = compute_recall(true_positives_update_op,
2395                               false_negatives_update_op, 'update_op')
2396    if updates_collections:
2397      ops.add_to_collections(updates_collections, update_op)
2398
2399    return rec, update_op
2400
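# Editor's sketch of `recall` (illustrative values; session boilerplate as
# above):
#
#   labels = tf.constant([True, True, False, False])
#   predictions = tf.constant([True, False, True, True])
#   rec, update_op = tf.metrics.recall(labels, predictions)
#   # after sess.run(update_op): sess.run(rec) == 0.5 (tp=1, fn=1)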
2401
2402def _at_k_name(name, k=None, class_id=None):
2403  if k is not None:
2404    name = '%s_at_%d' % (name, k)
2405  else:
2406    name = '%s_at_k' % name
2407  if class_id is not None:
2408    name = '%s_class%d' % (name, class_id)
2409  return name
2410
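# For reference, `_at_k_name` produces scope names such as:
#   _at_k_name('precision', k=5)           -> 'precision_at_5'
#   _at_k_name('recall', k=5, class_id=2)  -> 'recall_at_5_class2'
#   _at_k_name('recall')                   -> 'recall_at_k'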
2411
2412def _select_class_id(ids, selected_id):
2413  """Filter all but `selected_id` out of `ids`.
2414
2415  Args:
2416    ids: `int64` `Tensor` or `SparseTensor` of IDs.
2417    selected_id: Int id to select.
2418
2419  Returns:
2420    `SparseTensor` of same dimensions as `ids`. This contains only the entries
2421    equal to `selected_id`.
2422  """
2423  ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
2424  if isinstance(ids, sparse_tensor.SparseTensor):
2425    return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
2426                                                        selected_id))
2427
2428  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
2429  # tf.equal and tf.reduce_any?
2430
2431  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
2432  ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
2433  ids_last_dim = array_ops.size(ids_shape) - 1
2434  filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
2435                                                    array_ops.reshape(
2436                                                        ids_last_dim, [1]))
2437
2438  # Intersect `ids` with the selected ID.
2439  filled_selected_id = array_ops.fill(filled_selected_id_shape,
2440                                      math_ops.cast(selected_id, dtypes.int64))
2441  result = sets.set_intersection(filled_selected_id, ids)
2442  return sparse_tensor.SparseTensor(
2443      indices=result.indices, values=result.values, dense_shape=ids_shape)
2444
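# Editor's note (illustrative): for dense `ids` [[1, 2], [2, 3]] with
# `selected_id=2`, the dense branch intersects each row with a filled
# [batch_size, 1] tensor of 2s, returning a `SparseTensor` of dense_shape
# [2, 2] that keeps only the two entries equal to 2.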
2445
2446def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
2447  """If class ID is specified, filter all other classes.
2448
2449  Args:
2450    labels: `int64` `Tensor` or `SparseTensor` with shape
2451      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2452      target classes for the associated prediction. Commonly, N=1 and `labels`
2453      has shape [batch_size, num_labels]. [D1, ... DN] must match
2454      `predictions_idx`.
2455    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
2456      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
2457      [batch size, k].
2458    selected_id: Int id to select.
2459
2460  Returns:
2461    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
2462  """
2463  if selected_id is None:
2464    return labels, predictions_idx
2465  return (_select_class_id(labels, selected_id),
2466          _select_class_id(predictions_idx, selected_id))
2467
2468
2469def _sparse_true_positive_at_k(labels,
2470                               predictions_idx,
2471                               class_id=None,
2472                               weights=None,
2473                               name=None):
2474  """Calculates true positives for recall@k and precision@k.
2475
2476  If `class_id` is specified, calculate binary true positives for `class_id`
2477      only.
2478  If `class_id` is not specified, calculate metrics for `k` predicted vs
2479      `n` label classes, where `n` is the 2nd dimension of `labels`.
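
  For example (dense case, illustrative values), `labels = [[1, 2]]` and
  `predictions_idx = [[1, 3]]` yield a true positive count of [1]: class 1 is
  both predicted and labeled.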
2480
2481  Args:
2482    labels: `int64` `Tensor` or `SparseTensor` with shape
2483      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2484      target classes for the associated prediction. Commonly, N=1 and `labels`
2485      has shape [batch_size, num_labels]. [D1, ... DN] must match
2486      `predictions_idx`.
2487    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2488      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2489      match `labels`.
2490    class_id: Class for which we want binary metrics.
2491    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2492      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2493      dimensions must be either `1`, or the same as the corresponding `labels`
2494      dimension).
2495    name: Name of operation.
2496
2497  Returns:
2498    A [D1, ... DN] `Tensor` of true positive counts.
2499  """
2500  with ops.name_scope(name, 'true_positives',
2501                      (predictions_idx, labels, weights)):
2502    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2503                                                     class_id)
2504    tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
2505    tp = math_ops.cast(tp, dtypes.float64)
2506    if weights is not None:
2507      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2508          weights, tp),)):
2509        weights = math_ops.cast(weights, dtypes.float64)
2510        tp = math_ops.multiply(tp, weights)
2511    return tp
2512
2513
2514def _streaming_sparse_true_positive_at_k(labels,
2515                                         predictions_idx,
2516                                         k=None,
2517                                         class_id=None,
2518                                         weights=None,
2519                                         name=None):
2520  """Calculates weighted per step true positives for recall@k and precision@k.
2521
2522  If `class_id` is specified, calculate binary true positives for `class_id`
2523      only.
2524  If `class_id` is not specified, calculate metrics for `k` predicted vs
2525      `n` label classes, where `n` is the 2nd dimension of `labels`.
2526
2527  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2528
2529  Args:
2530    labels: `int64` `Tensor` or `SparseTensor` with shape
2531      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2532      target classes for the associated prediction. Commonly, N=1 and `labels`
2533      has shape [batch_size, num_labels]. [D1, ... DN] must match
2534      `predictions_idx`.
2535    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2536      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2537      match `labels`.
2538    k: Integer, k for @k metric. This is only used for default op name.
2539    class_id: Class for which we want binary metrics.
2540    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2541      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2542      dimensions must be either `1`, or the same as the corresponding `labels`
2543      dimension).
2544    name: Name of new variable, and namespace for other dependent ops.
2545
2546  Returns:
2547    A tuple of `Variable` and update `Operation`.
2548
2549  Raises:
2550    ValueError: If `weights` is not `None` and has an incompatible shape.
2551  """
2552  with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
2553                      (predictions_idx, labels, weights)) as scope:
2554    tp = _sparse_true_positive_at_k(
2555        predictions_idx=predictions_idx,
2556        labels=labels,
2557        class_id=class_id,
2558        weights=weights)
2559    batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64)
2560
2561    var = metric_variable([], dtypes.float64, name=scope)
2562    return var, state_ops.assign_add(var, batch_total_tp, name='update')
2563
2564
2565def _sparse_false_negative_at_k(labels,
2566                                predictions_idx,
2567                                class_id=None,
2568                                weights=None):
2569  """Calculates false negatives for recall@k.
2570
2571  If `class_id` is specified, calculate binary true positives for `class_id`
2572      only.
2573  If `class_id` is not specified, calculate metrics for `k` predicted vs
2574      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
2575
2576  Args:
2577    labels: `int64` `Tensor` or `SparseTensor` with shape
2578      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2579      target classes for the associated prediction. Commonly, N=1 and `labels`
2580      has shape [batch_size, num_labels]. [D1, ... DN] must match
2581      `predictions_idx`.
2582    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2583      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2584      match `labels`.
2585    class_id: Class for which we want binary metrics.
2586    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2587      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2588      dimensions must be either `1`, or the same as the corresponding `labels`
2589      dimension).
2590
2591  Returns:
2592    A [D1, ... DN] `Tensor` of false negative counts.
2593  """
2594  with ops.name_scope(None, 'false_negatives',
2595                      (predictions_idx, labels, weights)):
2596    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2597                                                     class_id)
2598    fn = sets.set_size(
2599        sets.set_difference(predictions_idx, labels, aminusb=False))
2600    fn = math_ops.cast(fn, dtypes.float64)
2601    if weights is not None:
2602      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2603          weights, fn),)):
2604        weights = math_ops.cast(weights, dtypes.float64)
2605        fn = math_ops.multiply(fn, weights)
2606    return fn
2607
2608
2609def _streaming_sparse_false_negative_at_k(labels,
2610                                          predictions_idx,
2611                                          k,
2612                                          class_id=None,
2613                                          weights=None,
2614                                          name=None):
2615  """Calculates weighted per step false negatives for recall@k.
2616
2617  If `class_id` is specified, calculate binary true positives for `class_id`
2618      only.
2619  If `class_id` is not specified, calculate metrics for `k` predicted vs
2620      `n` label classes, where `n` is the 2nd dimension of `labels`.
2621
2622  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2623
2624  Args:
2625    labels: `int64` `Tensor` or `SparseTensor` with shape
2626      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2627      target classes for the associated prediction. Commonly, N=1 and `labels`
2628      has shape [batch_size, num_labels]. [D1, ... DN] must match
2629      `predictions_idx`.
2630    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2631      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2632      match `labels`.
2633    k: Integer, k for @k metric. This is only used for default op name.
2634    class_id: Class for which we want binary metrics.
2635    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2636      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2637      dimensions must be either `1`, or the same as the corresponding `labels`
2638      dimension).
2639    name: Name of new variable, and namespace for other dependent ops.
2640
2641  Returns:
2642    A tuple of `Variable` and update `Operation`.
2643
2644  Raises:
2645    ValueError: If `weights` is not `None` and has an incompatible shape.
2646  """
2647  with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
2648                      (predictions_idx, labels, weights)) as scope:
2649    fn = _sparse_false_negative_at_k(
2650        predictions_idx=predictions_idx,
2651        labels=labels,
2652        class_id=class_id,
2653        weights=weights)
2654    batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64)
2655
2656    var = metric_variable([], dtypes.float64, name=scope)
2657    return var, state_ops.assign_add(var, batch_total_fn, name='update')
2658
2659
2660@tf_export(v1=['metrics.recall_at_k'])
2661def recall_at_k(labels,
2662                predictions,
2663                k,
2664                class_id=None,
2665                weights=None,
2666                metrics_collections=None,
2667                updates_collections=None,
2668                name=None):
2669  """Computes recall@k of the predictions with respect to sparse labels.
2670
2671  If `class_id` is specified, we calculate recall by considering only the
2672      entries in the batch for which `class_id` is in the label, and computing
2673      the fraction of them for which `class_id` is in the top-k `predictions`.
2674  If `class_id` is not specified, we'll calculate recall as how often on
2675      average a class among the labels of a batch entry is in the top-k
2676      `predictions`.
2677
2678  `sparse_recall_at_k` creates two local variables,
2679  `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
2680  the recall_at_k frequency. This frequency is ultimately returned as
2681  `recall_at_<k>`: an idempotent operation that simply divides
2682  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
2683  `false_negative_at_<k>`).
2684
2685  For estimation of the metric over a stream of data, the function creates an
2686  `update_op` operation that updates these variables and returns the
2687  `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
2688  indicating the top `k` `predictions`. Set operations applied to `top_k` and
2689  `labels` calculate the true positives and false negatives weighted by
2690  `weights`. Then `update_op` increments `true_positive_at_<k>` and
2691  `false_negative_at_<k>` using these values.
2692
2693  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2694
2695  Args:
2696    labels: `int64` `Tensor` or `SparseTensor` with shape
2697      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2698      num_labels=1. N >= 1 and num_labels is the number of target classes for
2699      the associated prediction. Commonly, N=1 and `labels` has shape
2700      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2701      should be in range [0, num_classes), where num_classes is the last
2702      dimension of `predictions`. Values outside this range always count
2703      towards `false_negative_at_<k>`.
2704    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
2705      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
2706      The final dimension contains the logit values for each class. [D1, ... DN]
2707      must match `labels`.
2708    k: Integer, k for @k metric.
2709    class_id: Integer class ID for which we want binary metrics. This should be
2710      in range [0, num_classes), where num_classes is the last dimension of
2711      `predictions`. If class_id is outside this range, the method returns NAN.
2712    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2713      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2714      dimensions must be either `1`, or the same as the corresponding `labels`
2715      dimension).
2716    metrics_collections: An optional list of collections that values should
2717      be added to.
2718    updates_collections: An optional list of collections that updates should
2719      be added to.
2720    name: Name of new update operation, and namespace for other dependent ops.
2721
2722  Returns:
2723    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2724      by the sum of `true_positives` and `false_negatives`.
2725    update_op: `Operation` that increments `true_positives` and
2726      `false_negatives` variables appropriately, and whose value matches
2727      `recall`.
2728
2729  Raises:
2730    ValueError: If `weights` is not `None` and its shape doesn't match
2731      `predictions`, or if either `metrics_collections` or `updates_collections`
2732      are not a list or tuple.
2733    RuntimeError: If eager execution is enabled.
2734  """
2735  if context.executing_eagerly():
2736    raise RuntimeError('tf.metrics.recall_at_k is not '
2737                       'supported when eager execution is enabled.')
2738
2739  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2740                      (predictions, labels, weights)) as scope:
2741    _, top_k_idx = nn.top_k(predictions, k)
2742    return recall_at_top_k(
2743        labels=labels,
2744        predictions_idx=top_k_idx,
2745        k=k,
2746        class_id=class_id,
2747        weights=weights,
2748        metrics_collections=metrics_collections,
2749        updates_collections=updates_collections,
2750        name=scope)
2751
2752
2753@tf_export(v1=['metrics.recall_at_top_k'])
2754def recall_at_top_k(labels,
2755                    predictions_idx,
2756                    k=None,
2757                    class_id=None,
2758                    weights=None,
2759                    metrics_collections=None,
2760                    updates_collections=None,
2761                    name=None):
2762  """Computes recall@k of top-k predictions with respect to sparse labels.
2763
2764  Differs from `recall_at_k` in that predictions must be in the form of top `k`
2765  class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
2766  for more details.
2767
2768  Args:
2769    labels: `int64` `Tensor` or `SparseTensor` with shape
2770      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
2771      num_labels=1. N >= 1 and num_labels is the number of target classes for
2772      the associated prediction. Commonly, N=1 and `labels` has shape
2773      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
2774      should be in range [0, num_classes), where num_classes is the last
2775      dimension of `predictions`. Values outside this range always count
2776      towards `false_negative_at_<k>`.
2777    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
2778      Commonly, N=1 and predictions has shape [batch size, k]. The final
2779      dimension contains the top `k` predicted class indices. [D1, ... DN] must
2780      match `labels`.
2781    k: Integer, k for @k metric. Only used for the default op name.
2782    class_id: Integer class ID for which we want binary metrics. This should be
2783      in range [0, num_classes), where num_classes is the last dimension of
2784      `predictions`. If class_id is outside this range, the method returns NAN.
2785    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2786      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2787      dimensions must be either `1`, or the same as the corresponding `labels`
2788      dimension).
2789    metrics_collections: An optional list of collections that values should
2790      be added to.
2791    updates_collections: An optional list of collections that updates should
2792      be added to.
2793    name: Name of new update operation, and namespace for other dependent ops.
2794
2795  Returns:
2796    recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
2797      by the sum of `true_positives` and `false_negatives`.
2798    update_op: `Operation` that increments `true_positives` and
2799      `false_negatives` variables appropriately, and whose value matches
2800      `recall`.
2801
2802  Raises:
2803    ValueError: If `weights` is not `None` and its shape doesn't match
2804      `predictions`, or if either `metrics_collections` or `updates_collections`
2805      are not a list or tuple.
2806  """
2807  with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
2808                      (predictions_idx, labels, weights)) as scope:
2809    labels = _maybe_expand_labels(labels, predictions_idx)
2810    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
2811    tp, tp_update = _streaming_sparse_true_positive_at_k(
2812        predictions_idx=top_k_idx,
2813        labels=labels,
2814        k=k,
2815        class_id=class_id,
2816        weights=weights)
2817    fn, fn_update = _streaming_sparse_false_negative_at_k(
2818        predictions_idx=top_k_idx,
2819        labels=labels,
2820        k=k,
2821        class_id=class_id,
2822        weights=weights)
2823
2824    def compute_recall(_, tp, fn):
2825      return math_ops.divide(tp, math_ops.add(tp, fn), name=scope)
2826
2827    metric = _aggregate_across_replicas(
2828        metrics_collections, compute_recall, tp, fn)
2829
2830    update = math_ops.divide(
2831        tp_update, math_ops.add(tp_update, fn_update), name='update')
2832    if updates_collections:
2833      ops.add_to_collections(updates_collections, update)
2834    return metric, update
2835
2836
2837@tf_export(v1=['metrics.recall_at_thresholds'])
2838def recall_at_thresholds(labels,
2839                         predictions,
2840                         thresholds,
2841                         weights=None,
2842                         metrics_collections=None,
2843                         updates_collections=None,
2844                         name=None):
2845  """Computes various recall values for different `thresholds` on `predictions`.
2846
2847  The `recall_at_thresholds` function creates four local variables,
2848  `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
2849  for various values of thresholds. `recall[i]` is defined as the total weight
2850  of values in `predictions` above `thresholds[i]` whose corresponding entry in
2851  `labels` is `True`, divided by the total weight of `True` values in `labels`
2852  (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
2853
2854  For estimation of the metric over a stream of data, the function creates an
2855  `update_op` operation that updates these variables and returns the `recall`.
2856
2857  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
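
  For example, a sketch with two thresholds (illustrative values):

    labels = tf.constant([True, False, True])
    predictions = tf.constant([0.9, 0.6, 0.7])
    rec, update_op = tf.compat.v1.metrics.recall_at_thresholds(
        labels, predictions, thresholds=[0.5, 0.8])
    # After running `update_op` once, `rec` evaluates to roughly [1.0, 0.5]:
    # both positives exceed 0.5, but only the 0.9 one exceeds 0.8.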
2858
2859  Args:
2860    labels: The ground truth values, a `Tensor` whose dimensions must match
2861      `predictions`. Will be cast to `bool`.
2862    predictions: A floating point `Tensor` of arbitrary shape and whose values
2863      are in the range `[0, 1]`.
2864    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
2865    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2866      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2867      be either `1`, or the same as the corresponding `labels` dimension).
2868    metrics_collections: An optional list of collections that `recall` should be
2869      added to.
2870    updates_collections: An optional list of collections that `update_op` should
2871      be added to.
2872    name: An optional variable_scope name.
2873
2874  Returns:
2875    recall: A float `Tensor` of shape `[len(thresholds)]`.
2876    update_op: An operation that increments the `true_positives`,
2877      `true_negatives`, `false_positives` and `false_negatives` variables that
2878      are used in the computation of `recall`.
2879
2880  Raises:
2881    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2882      `weights` is not `None` and its shape doesn't match `predictions`, or if
2883      either `metrics_collections` or `updates_collections` are not a list or
2884      tuple.
2885    RuntimeError: If eager execution is enabled.
2886  """
2887  if context.executing_eagerly():
2888    raise RuntimeError('tf.metrics.recall_at_thresholds is not '
2889                       'supported when eager execution is enabled.')
2890
2891  with variable_scope.variable_scope(name, 'recall_at_thresholds',
2892                                     (predictions, labels, weights)):
2893    values, update_ops = _confusion_matrix_at_thresholds(
2894        labels, predictions, thresholds, weights, includes=('tp', 'fn'))
2895
2896    # Avoid division by zero.
2897    epsilon = 1e-7
2898
2899    def compute_recall(tp, fn, name):
2900      return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name)
2901
2902    def recall_across_replicas(_, values):
2903      return compute_recall(values['tp'], values['fn'], 'value')
2904
2905    rec = _aggregate_across_replicas(
2906        metrics_collections, recall_across_replicas, values)
2907
2908    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
2909    if updates_collections:
2910      ops.add_to_collections(updates_collections, update_op)
2911
2912    return rec, update_op
2913
2914
2915@tf_export(v1=['metrics.root_mean_squared_error'])
2916def root_mean_squared_error(labels,
2917                            predictions,
2918                            weights=None,
2919                            metrics_collections=None,
2920                            updates_collections=None,
2921                            name=None):
2922  """Computes the root mean squared error between the labels and predictions.
2923
2924  The `root_mean_squared_error` function creates two local variables,
2925  `total` and `count` that are used to compute the root mean squared error.
2926  This average is weighted by `weights`, and it is ultimately returned as
2927  `root_mean_squared_error`: an idempotent operation that takes the square root
2928  of the division of `total` by `count`.
2929
2930  For estimation of the metric over a stream of data, the function creates an
2931  `update_op` operation that updates these variables and returns the
2932  `root_mean_squared_error`. Internally, a `squared_error` operation computes
2933  the element-wise square of the difference between `predictions` and `labels`.
2934  Then `update_op` increments `total` with the reduced sum of the product of
2935  `weights` and `squared_error`, and it increments `count` with the reduced sum
2936  of `weights`.
2937
2938  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
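
  For example, a minimal sketch (illustrative values):

    labels = tf.constant([1.0, 2.0, 3.0])
    predictions = tf.constant([1.0, 2.0, 5.0])
    rmse, update_op = tf.compat.v1.metrics.root_mean_squared_error(
        labels, predictions)
    # After running `update_op` once, `rmse` evaluates to
    # sqrt(mean([0., 0., 4.])) = sqrt(4/3), roughly 1.155.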
2939
2940  Args:
2941    labels: A `Tensor` of the same shape as `predictions`.
2942    predictions: A `Tensor` of arbitrary shape.
2943    weights: Optional `Tensor` whose rank is either 0, or the same rank as
2944      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2945      be either `1`, or the same as the corresponding `labels` dimension).
2946    metrics_collections: An optional list of collections that
2947      `root_mean_squared_error` should be added to.
2948    updates_collections: An optional list of collections that `update_op` should
2949      be added to.
2950    name: An optional variable_scope name.
2951
2952  Returns:
2953    root_mean_squared_error: A `Tensor` representing the current mean, the value
2954      of `total` divided by `count`.
2955    update_op: An operation that increments the `total` and `count` variables
2956      appropriately and whose value matches `root_mean_squared_error`.
2957
2958  Raises:
2959    ValueError: If `predictions` and `labels` have mismatched shapes, or if
2960      `weights` is not `None` and its shape doesn't match `predictions`, or if
2961      either `metrics_collections` or `updates_collections` are not a list or
2962      tuple.
2963    RuntimeError: If eager execution is enabled.
2964  """
2965  if context.executing_eagerly():
2966    raise RuntimeError('tf.metrics.root_mean_squared_error is not '
2967                       'supported when eager execution is enabled.')
2968
2969  predictions, labels, weights = _remove_squeezable_dimensions(
2970      predictions=predictions, labels=labels, weights=weights)
2971  mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
2972                                          None, name or
2973                                          'root_mean_squared_error')
2974
2975  once_across_replicas = lambda _, mse: math_ops.sqrt(mse)
2976  rmse = _aggregate_across_replicas(
2977      metrics_collections, once_across_replicas, mse)
2978
2979  update_rmse_op = math_ops.sqrt(update_mse_op)
2980  if updates_collections:
2981    ops.add_to_collections(updates_collections, update_rmse_op)
2982
2983  return rmse, update_rmse_op
2984
2985
2986@tf_export(v1=['metrics.sensitivity_at_specificity'])
2987def sensitivity_at_specificity(labels,
2988                               predictions,
2989                               specificity,
2990                               weights=None,
2991                               num_thresholds=200,
2992                               metrics_collections=None,
2993                               updates_collections=None,
2994                               name=None):
2995  """Computes the specificity at a given sensitivity.
2996
2997  The `sensitivity_at_specificity` function creates four local
2998  variables, `true_positives`, `true_negatives`, `false_positives` and
2999  `false_negatives` that are used to compute the sensitivity at the given
3000  specificity value. The threshold for the given specificity value is computed
3001  and used to evaluate the corresponding sensitivity.
3002
3003  For estimation of the metric over a stream of data, the function creates an
3004  `update_op` operation that updates these variables and returns the
3005  `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
3006  `false_positives` and `false_negatives` counts with the weight of each case
3007  found in the `predictions` and `labels`.
3008
3009  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3010
3011  For additional information about specificity and sensitivity, see the
3012  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
3013
3014  Args:
3015    labels: The ground truth values, a `Tensor` whose dimensions must match
3016      `predictions`. Will be cast to `bool`.
3017    predictions: A floating point `Tensor` of arbitrary shape and whose values
3018      are in the range `[0, 1]`.
3019    specificity: A scalar value in range `[0, 1]`.
3020    weights: Optional `Tensor` whose rank is either 0, or the same rank as
3021      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
3022      be either `1`, or the same as the corresponding `labels` dimension).
3023    num_thresholds: The number of thresholds to use for matching the given
3024      specificity.
3025    metrics_collections: An optional list of collections that `sensitivity`
3026      should be added to.
3027    updates_collections: An optional list of collections that `update_op` should
3028      be added to.
3029    name: An optional variable_scope name.
3030
3031  Returns:
3032    sensitivity: A scalar `Tensor` representing the sensitivity at the given
3033      `specificity` value.
3034    update_op: An operation that increments the `true_positives`,
3035      `true_negatives`, `false_positives` and `false_negatives` variables
3036      appropriately and whose value matches `sensitivity`.
3037
3038  Raises:
3039    ValueError: If `predictions` and `labels` have mismatched shapes, if
3040      `weights` is not `None` and its shape doesn't match `predictions`, or if
3041      `specificity` is not between 0 and 1, or if either `metrics_collections`
3042      or `updates_collections` are not a list or tuple.
3043    RuntimeError: If eager execution is enabled.
3044  """
3045  if context.executing_eagerly():
3046    raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
3047                       'supported when eager execution is enabled.')
3048
3049  if specificity < 0 or specificity > 1:
3050    raise ValueError('`specificity` must be in the range [0, 1]. '
3051                     f'Received: specificity={specificity}.')
3052
3053  with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
3054                                     (predictions, labels, weights)):
3055    kepsilon = 1e-7  # to account for floating point imprecisions
3056    thresholds = [
3057        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
3058    ]
3059    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
3060
3061    values, update_ops = _confusion_matrix_at_thresholds(
3062        labels, predictions, thresholds, weights)
3063
3064    def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
3065      specificities = math_ops.divide(tn, tn + fp + kepsilon)
3066      tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
3067      tf_index = math_ops.cast(tf_index, dtypes.int32)
3068
3069      # Now, we have the implicit threshold, so compute the sensitivity:
3070      return math_ops.divide(tp[tf_index],
3071                             tp[tf_index] + fn[tf_index] + kepsilon, name)
3072
3073    def sensitivity_across_replicas(_, values):
3074      return compute_sensitivity_at_specificity(
3075          values['tp'], values['tn'], values['fp'], values['fn'], 'value')
3076
3077    sensitivity = _aggregate_across_replicas(
3078        metrics_collections, sensitivity_across_replicas, values)
3079
3080    update_op = compute_sensitivity_at_specificity(
3081        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
3082        'update_op')
3083    if updates_collections:
3084      ops.add_to_collections(updates_collections, update_op)
3085
3086    return sensitivity, update_op
3087
3088
3089def _expand_and_tile(tensor, multiple, dim=0, name=None):
3090  """Slice `tensor` shape in 2, then tile along the sliced dimension.
3091
3092  A new dimension is inserted in shape of `tensor` before `dim`, then values are
3093  tiled `multiple` times along the new dimension.
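
  For example, a dense `tensor` of shape [2, 3] with `multiple=4` and `dim=-1`
  produces a result of shape [2, 4, 3].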
3094
3095  Args:
3096    tensor: Input `Tensor` or `SparseTensor`.
3097    multiple: Integer, number of times to tile.
3098    dim: Integer, dimension along which to tile.
3099    name: Name of operation.
3100
3101  Returns:
3102    `Tensor` result of expanding and tiling `tensor`.
3103
3104  Raises:
3105    ValueError: if `multiple` is less than 1, or `dim` is not in
3106    `[-rank(tensor), rank(tensor)]`.
3107  """
3108  if multiple < 1:
3109    raise ValueError(f'Invalid argument multiple={multiple} for '
3110                     'expand_and_tile call. `multiple` must be an integer > 0.')
3111  with ops.name_scope(name, 'expand_and_tile',
3112                      (tensor, multiple, dim)) as scope:
3113    # Sparse.
3114    tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
3115    if isinstance(tensor, sparse_tensor.SparseTensor):
3116      if dim < 0:
3117        expand_dims = array_ops.reshape(
3118            array_ops.size(tensor.dense_shape) + dim, [1])
3119      else:
3120        expand_dims = [dim]
3121      expanded_shape = array_ops.concat(
3122          (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
3123           array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
3124          0,
3125          name='expanded_shape')
3126      expanded = sparse_ops.sparse_reshape(
3127          tensor, shape=expanded_shape, name='expand')
3128      if multiple == 1:
3129        return expanded
3130      return sparse_ops.sparse_concat(
3131          dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
3132
3133    # Dense.
3134    expanded = array_ops.expand_dims(
3135        tensor, dim if (dim >= 0) else (dim - 1), name='expand')
3136    if multiple == 1:
3137      return expanded
3138    ones = array_ops.ones_like(array_ops.shape(tensor))
3139    tile_multiples = array_ops.concat(
3140        (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
3141    return array_ops.tile(expanded, tile_multiples, name=scope)
3142
3143
3144def _num_relevant(labels, k):
3145  """Computes number of relevant values for each row in labels.
3146
3147  For labels with shape [D1, ... DN, num_labels], this is the minimum of
3148  `num_labels` and `k`.
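
  For example (dense case), `labels = [[1, 2, -1], [3, -1, -1]]` with `k=2`
  yields [2, 1]: row 0 has two non-negative labels and row 1 has one, each
  capped at `k`.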
3149
3150  Args:
3151    labels: `int64` `Tensor` or `SparseTensor` with shape
3152      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3153      target classes for the associated prediction. Commonly, N=1 and `labels`
3154      has shape [batch_size, num_labels].
3155    k: Integer, k for @k metric.
3156
3157  Returns:
3158    Integer `Tensor` of shape [D1, ... DN], where each value is the number of
3159    relevant values for that row.
3160
3161  Raises:
3162    ValueError: if inputs have invalid dtypes or values.
3163  """
3164  if k < 1:
3165    raise ValueError(f'Invalid k={k}. `k` should be >= 1.')
3166  with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
3167    # For SparseTensor, calculate separate count for each row.
3168    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
3169    if isinstance(labels, sparse_tensor.SparseTensor):
3170      return math_ops.minimum(sets.set_size(labels), k, name=scope)
3171
3172    # The relevant values for each (d1, ... dN) is the minimum of k and the
3173    # number of labels along the last dimension that are non-negative.
3174    num_labels = math_ops.reduce_sum(
3175        array_ops.where_v2(math_ops.greater_equal(labels, 0),
3176                           array_ops.ones_like(labels),
3177                           array_ops.zeros_like(labels)),
3178        axis=-1)
3179    return math_ops.minimum(num_labels, k, name=scope)
3180
3181
3182def _sparse_average_precision_at_top_k(labels, predictions_idx):
3183  """Computes average precision@k of predictions with respect to sparse labels.
3184
3185  From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
3186  for each row is:
3187
3188    AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
3189
3190  A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
3191  `labels`, and the result `Tensors`. In the common case, this is [batch_size].
3192  Each row of the results contains the average precision for that row.
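
  For example, a single row with relevant label set {1, 3} and
  `predictions_idx = [3, 0, 1]` (k=3) gives rel = [1, 0, 1] and
  P = [1/1, 1/2, 2/3], so AveP = (1 + 2/3) / 2, roughly 0.833, where
  num_relevant_items = min(num_labels, k) = 2.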
3193
3194  Args:
3195    labels: `int64` `Tensor` or `SparseTensor` with shape
3196      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3197      num_labels=1. N >= 1 and num_labels is the number of target classes for
3198      the associated prediction. Commonly, N=1 and `labels` has shape
3199      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3200      Values should be non-negative. Negative values are ignored.
3201    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3202      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3203      dimension must be set and contains the top `k` predicted class indices.
3204      [D1, ... DN] must match `labels`. Values should be in range
3205      [0, num_classes).
3206
3207  Returns:
3208    `float64` `Tensor` of shape [D1, ... DN], where each value is the average
3209    precision for that row.
3210
3211  Raises:
3212    ValueError: if the last dimension of predictions_idx is not set.
3213  """
3214  with ops.name_scope(None, 'average_precision',
3215                      (predictions_idx, labels)) as scope:
3216    predictions_idx = math_ops.cast(
3217        predictions_idx, dtypes.int64, name='predictions_idx')
3218    if predictions_idx.get_shape().ndims == 0:
3219      raise ValueError('The rank of `predictions_idx` must be at least 1.')
3220    k = predictions_idx.get_shape().as_list()[-1]
3221    if k is None:
3222      raise ValueError('The last dimension of predictions_idx must be set. '
3223                       'Currently, it is None.')
3224    labels = _maybe_expand_labels(labels, predictions_idx)
3225
3226    # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
3227    # prediction for each k, so we can calculate separate true positive values
3228    # for each k.
3229    predictions_idx_per_k = array_ops.expand_dims(
3230        predictions_idx, -1, name='predictions_idx_per_k')
3231
3232    # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
3233    labels_per_k = _expand_and_tile(
3234        labels, multiple=k, dim=-1, name='labels_per_k')
3235
3236    # The following tensors are all of shape [D1, ... DN, k], containing values
3237    # per row, per k value.
3238    # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
3239    #     that k value is correct, 0 otherwise. This is the "rel_{i}" term from
3240    #     the formula above.
3241    # `tp_per_k` (int32) - True positive counts.
3242    # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
3243    #     the precision denominator.
3244    # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
3245    #     term from the formula above.
3246    # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
3247    #     precisions at all k for which relevance indicator is true.
3248    relevant_per_k = _sparse_true_positive_at_k(
3249        labels_per_k, predictions_idx_per_k, name='relevant_per_k')
3250    tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
3251    retrieved_per_k = math_ops.cumsum(
3252        array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
3253    precision_per_k = math_ops.divide(
3254        math_ops.cast(tp_per_k, dtypes.float64),
3255        math_ops.cast(retrieved_per_k, dtypes.float64),
3256        name='precision_per_k')
3257    relevant_precision_per_k = math_ops.multiply(
3258        precision_per_k,
3259        math_ops.cast(relevant_per_k, dtypes.float64),
3260        name='relevant_precision_per_k')
3261
3262    # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
3263    precision_sum = math_ops.reduce_sum(
3264        relevant_precision_per_k, axis=(-1,), name='precision_sum')
3265
3266    # Divide by number of relevant items to get average precision. These are
3267    # the "num_relevant_items" and "AveP" terms from the formula above.
3268    num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64)
3269    return math_ops.divide(precision_sum, num_relevant_items, name=scope)
3270
3271
3272def _streaming_sparse_average_precision_at_top_k(labels,
3273                                                 predictions_idx,
3274                                                 weights=None,
3275                                                 metrics_collections=None,
3276                                                 updates_collections=None,
3277                                                 name=None):
3278  """Computes average precision@k of predictions with respect to sparse labels.
3279
3280  `sparse_average_precision_at_top_k` creates two local variables,
3281  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3282  are used to compute the frequency. This frequency is ultimately returned as
3283  `average_precision_at_<k>`: an idempotent operation that simply divides
3284  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3285
3286  For estimation of the metric over a stream of data, the function creates an
3287  `update_op` operation that updates these variables and returns the
3288  `average_precision_at_<k>`. Internally, per-row average precisions are
3289  computed from `predictions_idx` and `labels` and weighted by `weights`. Then
3290  `update_op` increments `average_precision_at_<k>/total` with their sum and
3291  `average_precision_at_<k>/max` with the batch's maximum possible total.
3292
3293  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3294
3295  Args:
3296    labels: `int64` `Tensor` or `SparseTensor` with shape
3297      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3298      num_labels=1. N >= 1 and num_labels is the number of target classes for
3299      the associated prediction. Commonly, N=1 and `labels` has shape
3300      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
3301      Values should be non-negative. Negative values are ignored.
3302    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
3303      Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
3304      dimension contains the top `k` predicted class indices. [D1, ... DN] must
3305      match `labels`. Values should be in range [0, num_classes).
3306    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3307      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3308      dimensions must be either `1`, or the same as the corresponding `labels`
3309      dimension).
3310    metrics_collections: An optional list of collections that values should
3311      be added to.
3312    updates_collections: An optional list of collections that updates should
3313      be added to.
3314    name: Name of new update operation, and namespace for other dependent ops.
3315
3316  Returns:
3317    mean_average_precision: Scalar `float64` `Tensor` with the mean average
3318      precision values.
3319    update: `Operation` that increments variables appropriately, and whose
3320      value matches `metric`.
3321  """
3322  with ops.name_scope(name, 'average_precision_at_top_k',
3323                      (predictions_idx, labels, weights)) as scope:
3324    # Calculate per-example average precision, and apply weights.
3325    average_precision = _sparse_average_precision_at_top_k(
3326        predictions_idx=predictions_idx, labels=labels)
3327    if weights is not None:
3328      weights = weights_broadcast_ops.broadcast_weights(
3329          math_ops.cast(weights, dtypes.float64), average_precision)
3330      average_precision = math_ops.multiply(average_precision, weights)
3331
3332    # Create accumulation variables and update ops for max average precision and
3333    # total average precision.
3334    with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
3335      # `max` is the max possible precision. Since max for any row is 1.0:
3336      # - For the unweighted case, this is just the number of rows.
3337      # - For the weighted case, it's the sum of the weights broadcast across
3338      #   `average_precision` rows.
3339      max_var = metric_variable([], dtypes.float64, name=max_scope)
3340      if weights is None:
3341        batch_max = math_ops.cast(
3342            array_ops.size(average_precision, name='batch_max'), dtypes.float64)
3343      else:
3344        batch_max = math_ops.reduce_sum(weights, name='batch_max')
3345      max_update = state_ops.assign_add(max_var, batch_max, name='update')
3346    with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
3347      total_var = metric_variable([], dtypes.float64, name=total_scope)
3348      batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
3349      total_update = state_ops.assign_add(total_var, batch_total, name='update')
3350
3351    # Divide total by max to get mean, for both vars and the update ops.
3352    def precision_across_replicas(_, total_var, max_var):
3353      return _safe_scalar_div(total_var, max_var, name='mean')
3354
3355    mean_average_precision = _aggregate_across_replicas(
3356        metrics_collections, precision_across_replicas, total_var, max_var)
3357
3358    update = _safe_scalar_div(total_update, max_update, name=scope)
3359    if updates_collections:
3360      ops.add_to_collections(updates_collections, update)
3361
3362    return mean_average_precision, update
3363
3364
3365def _clean_out_of_range_indices(labels, num_classes):
3366  """Replaces large out-of-range labels by small out-of-range labels.
3367
3368  Replaces any value in `labels` that is greater than or equal to `num_classes`
3369  with -1, conditionally for efficiency in case there are no such values.
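
  For example, with `num_classes = 4`, dense `labels = [[0, 5], [2, 9]]`
  becomes [[0, -1], [2, -1]].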
3370
3371  Args:
3372    labels: `int64` `Tensor` or `SparseTensor`.
3373    num_classes: `int64` scalar `Tensor`.
3374  Returns:
3375    An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater
3376    or equal to num_classes replaced by -1.
3377  """
3378
3379  def _labels_is_sparse():
3380    """Returns true is `labels` is a sparse tensor."""
3381    return isinstance(labels, (sparse_tensor.SparseTensor,
3382                               sparse_tensor.SparseTensorValue))
3383
3384  def _clean_out_of_range(values):
3385    """Replaces by -1 any large out-of-range `values`."""
3386    return array_ops.where_v2(math_ops.greater_equal(values, num_classes),
3387                              -1 * array_ops.ones_like(values), values)
3388
3389  def _clean_labels_out_of_range():
3390    """Replaces by -1 ane large out-of-range values in `labels`."""
3391    if _labels_is_sparse():
3392      return type(labels)(indices=labels.indices,
3393                          values=_clean_out_of_range(labels.values),
3394                          dense_shape=labels.dense_shape)
3395    else:
3396      return _clean_out_of_range(labels)
3397
3398  max_labels = math_ops.reduce_max(
3399      labels.values if _labels_is_sparse() else labels)
3400  return control_flow_ops.cond(
3401      math_ops.greater_equal(max_labels, num_classes),
3402      _clean_labels_out_of_range,
3403      lambda: labels)
3404
3405
3406@tf_export(v1=['metrics.sparse_average_precision_at_k'])
3407@deprecated(None, 'Use average_precision_at_k instead')
3408def sparse_average_precision_at_k(labels,
3409                                  predictions,
3410                                  k,
3411                                  weights=None,
3412                                  metrics_collections=None,
3413                                  updates_collections=None,
3414                                  name=None):
3415  """Renamed to `average_precision_at_k`, please use that method instead."""
3416  return average_precision_at_k(
3417      labels=labels,
3418      predictions=predictions,
3419      k=k,
3420      weights=weights,
3421      metrics_collections=metrics_collections,
3422      updates_collections=updates_collections,
3423      name=name)
3424
3425
3426@tf_export(v1=['metrics.average_precision_at_k'])
3427def average_precision_at_k(labels,
3428                           predictions,
3429                           k,
3430                           weights=None,
3431                           metrics_collections=None,
3432                           updates_collections=None,
3433                           name=None):
3434  """Computes average precision@k of predictions with respect to sparse labels.
3435
3436  `average_precision_at_k` creates two local variables,
3437  `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
3438  are used to compute the frequency. This frequency is ultimately returned as
3439  `average_precision_at_<k>`: an idempotent operation that simply divides
3440  `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
3441
3442  For estimation of the metric over a stream of data, the function creates an
3443  `update_op` operation that updates these variables and returns the
3444  `average_precision_at_<k>`. Internally, a `top_k` operation computes a
3445  `Tensor` indicating the top `k` `predictions`, and per-example average
3446  precisions computed from those indices and `labels` are weighted by
3447  `weights`. Then `update_op` increments `average_precision_at_<k>/total` and
3448  `average_precision_at_<k>/max` using these values.
3449
3450  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
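
  For example, a minimal sketch (illustrative values):

    labels = tf.constant([[1], [0]], dtype=tf.int64)
    predictions = tf.constant([[0.2, 0.7, 0.1],
                               [0.8, 0.1, 0.1]])
    ap, update_op = tf.compat.v1.metrics.average_precision_at_k(
        labels, predictions, k=2)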
3451
3452  Args:
3453    labels: `int64` `Tensor` or `SparseTensor` with shape
3454      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3455      num_labels=1. N >= 1 and num_labels is the number of target classes for
3456      the associated prediction. Commonly, N=1 and `labels` has shape
3457      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3458      should be in range [0, num_classes), where num_classes is the last
3459      dimension of `predictions`. Values outside this range are ignored.
3460    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
3461      N >= 1. Commonly, N=1 and `predictions` has shape
3462      [batch size, num_classes]. The final dimension contains the logit values
3463      for each class. [D1, ... DN] must match `labels`.
3464    k: Integer, k for @k metric. This will calculate an average precision for
3465      range `[1,k]`, as documented above.
3466    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3467      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3468      dimensions must be either `1`, or the same as the corresponding `labels`
3469      dimension).
3470    metrics_collections: An optional list of collections that values should
3471      be added to.
3472    updates_collections: An optional list of collections that updates should
3473      be added to.
3474    name: Name of new update operation, and namespace for other dependent ops.
3475
3476  Returns:
3477    mean_average_precision: Scalar `float64` `Tensor` with the mean average
3478      precision values.
3479    update: `Operation` that increments variables appropriately, and whose
3480      value matches `metric`.
3481
3482  Raises:
3483    ValueError: if k is invalid.
3484    RuntimeError: If eager execution is enabled.
3485  """
3486  if context.executing_eagerly():
3487    raise RuntimeError('tf.metrics.average_precision_at_k is not '
3488                       'supported when eager execution is enabled.')
3489
3490  if k < 1:
3491    raise ValueError(f'Invalid k={k}. `k` should be >= 1.')
3492  with ops.name_scope(name, _at_k_name('average_precision', k),
3493                      (predictions, labels, weights)) as scope:
3494    # Calculate top k indices to produce [D1, ... DN, k] tensor.
3495    _, predictions_idx = nn.top_k(predictions, k)
3496    # The documentation states that labels should be in [0, ..., num_classes),
3497    # but num_classes is lost when predictions_idx replaces predictions.
3498    # For conformity with the documentation, any label >= num_classes, which is
3499    # ignored, is replaced by -1.
3500    labels = _clean_out_of_range_indices(
3501        labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64))
3502    return _streaming_sparse_average_precision_at_top_k(
3503        labels=labels,
3504        predictions_idx=predictions_idx,
3505        weights=weights,
3506        metrics_collections=metrics_collections,
3507        updates_collections=updates_collections,
3508        name=scope)
3509
3510
3511def _sparse_false_positive_at_k(labels,
3512                                predictions_idx,
3513                                class_id=None,
3514                                weights=None):
3515  """Calculates false positives for precision@k.
3516
3517  If `class_id` is specified, calculate binary true positives for `class_id`
3518      only.
3519  If `class_id` is not specified, calculate metrics for `k` predicted vs
3520      `n` label classes, where `n` is the 2nd dimension of `labels_sparse`.
3521
3522  Args:
3523    labels: `int64` `Tensor` or `SparseTensor` with shape
3524      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3525      target classes for the associated prediction. Commonly, N=1 and `labels`
3526      has shape [batch_size, num_labels]. [D1, ... DN] must match
3527      `predictions_idx`.
3528    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3529      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3530      match `labels`.
3531    class_id: Class for which we want binary metrics.
3532    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3533      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3534      dimensions must be either `1`, or the same as the corresponding `labels`
3535      dimension).
3536
3537  Returns:
3538    A [D1, ... DN] `Tensor` of false positive counts.
3539  """
3540  with ops.name_scope(None, 'false_positives',
3541                      (predictions_idx, labels, weights)):
3542    labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
3543                                                     class_id)
3544    fp = sets.set_size(
3545        sets.set_difference(predictions_idx, labels, aminusb=True))
3546    fp = math_ops.cast(fp, dtypes.float64)
3547    if weights is not None:
3548      with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
3549          weights, fp),)):
3550        weights = math_ops.cast(weights, dtypes.float64)
3551        fp = math_ops.multiply(fp, weights)
3552    return fp
3553
3554
3555def _streaming_sparse_false_positive_at_k(labels,
3556                                          predictions_idx,
3557                                          k=None,
3558                                          class_id=None,
3559                                          weights=None,
3560                                          name=None):
3561  """Calculates weighted per step false positives for precision@k.
3562
3563  If `class_id` is specified, calculate binary true positives for `class_id`
3564      only.
3565  If `class_id` is not specified, calculate metrics for `k` predicted vs
3566      `n` label classes, where `n` is the 2nd dimension of `labels`.
3567
3568  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3569
3570  Args:
3571    labels: `int64` `Tensor` or `SparseTensor` with shape
3572      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3573      target classes for the associated prediction. Commonly, N=1 and `labels`
3574      has shape [batch_size, num_labels]. [D1, ... DN] must match
3575      `predictions_idx`.
3576    predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3577      top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3578      match `labels`.
3579    k: Integer, k for @k metric. This is only used for default op name.
3580    class_id: Class for which we want binary metrics.
3581    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3582      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3583      dimensions must be either `1`, or the same as the corresponding `labels`
3584      dimension).
3585    name: Name of new variable, and namespace for other dependent ops.
3586
3587  Returns:
3588    A tuple of `Variable` and update `Operation`.
3589
3590  Raises:
3591    ValueError: If `weights` is not `None` and has an incompatible shape.
3592  """
3593  with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
3594                      (predictions_idx, labels, weights)) as scope:
3595    fp = _sparse_false_positive_at_k(
3596        predictions_idx=predictions_idx,
3597        labels=labels,
3598        class_id=class_id,
3599        weights=weights)
3600    batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64)
3601
3602    var = metric_variable([], dtypes.float64, name=scope)
3603    return var, state_ops.assign_add(var, batch_total_fp, name='update')
3604

@tf_export(v1=['metrics.precision_at_top_k'])
def precision_at_top_k(labels,
                       predictions_idx,
                       k=None,
                       class_id=None,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  Differs from `sparse_precision_at_k` in that predictions must be in the form
  of top `k` class indices, whereas `sparse_precision_at_k` expects logits.
  Refer to `sparse_precision_at_k` for more details.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
      Values should be in range [0, num_classes), where num_classes is the
      number of classes. Values outside this range are ignored.
    predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
      N >= 1. Commonly, N=1 and `predictions_idx` has shape [batch_size, k].
      The final dimension contains the top `k` predicted class indices.
      [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric. Only used for the default op name.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the number of classes.
      If `class_id` is outside this range, the method returns NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions_idx`, or if either `metrics_collections` or
      `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_top_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions_idx, labels, weights)) as scope:
    labels = _maybe_expand_labels(labels, predictions_idx)
    top_k_idx = math_ops.cast(predictions_idx, dtypes.int64)
    tp, tp_update = _streaming_sparse_true_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)
    fp, fp_update = _streaming_sparse_false_positive_at_k(
        predictions_idx=top_k_idx,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    def precision_across_replicas(_, tp, fp):
      return math_ops.divide(tp, math_ops.add(tp, fp), name=scope)

    metric = _aggregate_across_replicas(
        metrics_collections, precision_across_replicas, tp, fp)

    update = math_ops.divide(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update

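
# A usage sketch (editor's addition) for `precision_at_top_k` in graph mode;
# the tensors below are illustrative. Note that the metric variables live in
# the local-variables collection and must be initialized first.
def _demo_precision_at_top_k():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables
  # Top-2 predicted classes for two examples, with one label each.
  predictions_idx = ops.convert_to_tensor([[0, 2], [1, 3]], dtypes.int64)
  labels = ops.convert_to_tensor([[0], [2]], dtypes.int64)
  precision, update_op = precision_at_top_k(labels, predictions_idx, k=2)
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)  # tp=1 (class 0 in example 0), fp=3, so 1/4.
    return sess.run(precision)  # 0.25
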

@tf_export(v1=['metrics.sparse_precision_at_k'])
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
                          predictions,
                          k,
                          class_id=None,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Renamed to `precision_at_k`, please use that method instead."""
  return precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@tf_export(v1=['metrics.precision_at_k'])
def precision_at_k(labels,
                   predictions,
                   k,
                   class_id=None,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
      entries in the batch for which `class_id` is in the top-k highest
      `predictions`, and computing the fraction of them for which `class_id` is
      indeed a correct label.
  If `class_id` is not specified, we'll calculate precision as how often on
      average a class among the top-k classes with the highest predicted values
      of a batch entry is correct and can be found in the label for that entry.

  `precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape
      [batch_size, num_classes]. The final dimension contains the logit values
      for each class. [D1, ... DN] must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.precision_at_k is not '
                       'supported when eager execution is enabled.')

  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    return precision_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)

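
# A companion sketch (editor's addition) for `precision_at_k`, which takes
# raw logits and applies `top_k` internally; values are illustrative.
def _demo_precision_at_k():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables
  # Logits over 4 classes; the top-2 classes are {0, 2} for example 0 and
  # {1, 3} for example 1, matching the `precision_at_top_k` sketch above.
  predictions = ops.convert_to_tensor(
      [[0.9, 0.1, 0.8, 0.0], [0.2, 0.7, 0.1, 0.6]], dtypes.float32)
  labels = ops.convert_to_tensor([[0], [2]], dtypes.int64)
  precision, update_op = precision_at_k(labels, predictions, k=2)
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)
    return sess.run(precision)  # 0.25, as in the top-k example.
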

@tf_export(v1=['metrics.specificity_at_sensitivity'])
def specificity_at_sensitivity(labels,
                               predictions,
                               sensitivity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the specificity at a given sensitivity.

  The `specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
                       'supported when eager execution is enabled.')

  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1]. Currently, '
                     f'`sensitivity` is {sensitivity}.')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
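    # For example, num_thresholds=5 yields [-1e-07, 0.25, 0.5, 0.75,
    # 1.0000001] (editor's note): both endpoints sit just outside [0, 1], so
    # the grid includes thresholds where sensitivity is exactly 1.0 and
    # exactly 0.0.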

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
      """Computes the specificity at the given sensitivity.

      Args:
        tp: True positives.
        tn: True negatives.
        fp: False positives.
        fn: False negatives.
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
      sensitivities = math_ops.divide(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
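      # Worked example (editor's note): for sensitivities [0.2, 0.5, 0.5, 0.9]
      # and target 0.5, the minimal-distance indicator is [0, 1, 1, 0], its
      # cumsum is [0, 1, 2, 2], and argmax returns 2, the last tied index.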
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.divide(tn[tf_index],
                             tn[tf_index] + fp[tf_index] + kepsilon, name)

    def specificity_across_replicas(_, values):
      return compute_specificity_at_sensitivity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')

    specificity = _aggregate_across_replicas(
        metrics_collections, specificity_across_replicas, values)

    update_op = compute_specificity_at_sensitivity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')
    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op

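
# A usage sketch (editor's addition) for `specificity_at_sensitivity` in
# graph mode; scores and labels are illustrative.
def _demo_specificity_at_sensitivity():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import variables
  labels = ops.convert_to_tensor([0, 0, 1, 1], dtypes.int64)
  predictions = ops.convert_to_tensor([0.1, 0.6, 0.4, 0.9], dtypes.float32)
  specificity, update_op = specificity_at_sensitivity(
      labels, predictions, sensitivity=0.5)
  with session_lib.Session() as sess:
    sess.run(variables.local_variables_initializer())
    sess.run(update_op)
    # Reports tn / (tn + fp) at the threshold whose tp / (tp + fn) is
    # closest to the requested sensitivity of 0.5.
    return sess.run(specificity)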