# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utils related to keras metrics."""

import functools
import weakref

from enum import Enum

import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as parallel_control_flow_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import tf_decorator

NEG_INF = -1e10


class Reduction(Enum):
  """Types of metrics reduction.

  Contains the following values:

  * `SUM`: Scalar sum of weighted values.
  * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by
        number of elements.
  * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights.
  """
  SUM = 'sum'
  SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
  WEIGHTED_MEAN = 'weighted_mean'
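

# A minimal NumPy sketch (illustrative only, not part of the original module)
# of how the three `Reduction` modes differ for a small batch of weighted
# values. The helper name below is hypothetical.
def _reduction_example():
  """Illustrates SUM, SUM_OVER_BATCH_SIZE and WEIGHTED_MEAN on toy data."""
  values = np.array([1.0, 2.0, 4.0])
  weights = np.array([1.0, 1.0, 0.5])
  weighted = values * weights  # [1.0, 2.0, 2.0]
  return {
      Reduction.SUM: weighted.sum(),  # 5.0
      Reduction.SUM_OVER_BATCH_SIZE: weighted.sum() / values.size,  # ~1.67
      Reduction.WEIGHTED_MEAN: weighted.sum() / weights.sum(),  # 2.0
  }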


def update_state_wrapper(update_state_fn):
  """Decorator to wrap metric `update_state()` with `add_update()`.

  Args:
    update_state_fn: function that accumulates metric statistics.

  Returns:
    Decorated function that wraps `update_state_fn()` with `add_update()`.
  """

  def decorated(metric_obj, *args, **kwargs):
    """Decorated function with `add_update()`."""
    strategy = distribution_strategy_context.get_strategy()

    for weight in metric_obj.weights:
      if (backend.is_tpu_strategy(strategy) and
          not strategy.extended.variable_created_in_scope(weight)
          and not distribution_strategy_context.in_cross_replica_context()):
        raise ValueError(
            'Trying to run metric.update_state in replica context when '
            'the metric was not created in TPUStrategy scope. '
            'Make sure the keras Metric is created in TPUStrategy scope.')

    with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs):
      update_op = update_state_fn(*args, **kwargs)
    if update_op is not None:  # update_op will be None in eager execution.
      metric_obj.add_update(update_op)
    return update_op

  return tf_decorator.make_decorator(update_state_fn, decorated)


def result_wrapper(result_fn):
  """Decorator to wrap metric `result()` function in `merge_call()`.

  Result computation is an idempotent operation that simply calculates the
  metric value using the state variables.

  If metric state variables are distributed across replicas/devices and
  `result()` is requested from the context of one device, this function wraps
  `result()` in a distribution strategy `merge_call()`. With this,
  the metric state variables will be aggregated across devices.

  Args:
    result_fn: function that computes the metric result.

  Returns:
    Decorated function that wraps `result_fn()` in distribution strategy
    `merge_call()`.
  """

  def decorated(metric_obj, *args):
    """Decorated function with merge_call."""
    has_strategy = distribution_strategy_context.has_strategy()
    replica_context = distribution_strategy_context.get_replica_context()

    # The purpose of using `merge_call` to call `result()` is to trigger cross
    # replica aggregation of metric state variables (SyncOnReadVariable). After
    # we introduced `variable_sync_on_read_context`, in principle there is no
    # need to use `merge_call` here. However the branch still exists because:
    #
    # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor
    #    across replicas (achieved by `merge_call`). With
    #    `variable_sync_on_read_context` each replica gets its own tensors
    #    residing on the replica's device, thus breaking the assumption.
    # 2. Keras compile/fit creates a tf.function (a.k.a. train_function) that
    #    returns the metric values of the first replica. With
    #    `variable_sync_on_read_context`, since each replica gets its own
    #    tensors, the metric result tensors on the non-first replicas are not
    #    in the return value of train_function, making the TF graph optimizer
    #    prune the branch that computes and aggregates those metric results.
    #    As a result, if NCCL is used to do the aggregation, the program will
    #    hang because NCCL ops are only launched on the non-pruned first
    #    replica.
    #
    # We condition on strategy.extended._use_merge_call() since we know if it
    # is false, the program uses `jit_compile` to compile the replica fn,
    # meaning it is not V1 training (hence #1 is okay), and no pruning will
    # happen as compiled functions are not inlined (hence #2 is okay).

    if (not has_strategy or replica_context is None or
        not distribution_strategy_context.get_strategy(
        ).extended._use_merge_call()):
      with distribution_strategy_context.variable_sync_on_read_context():
        raw_result = result_fn(*args)
        # Results need to be wrapped in a `tf.identity` op to ensure
        # correct execution order.
        if isinstance(raw_result,
                      (ops.Tensor, variables_module.Variable, float, int)):
          result_t = array_ops.identity(raw_result)
        elif isinstance(raw_result, dict):
          result_t = {
              key: array_ops.identity(value)
              for key, value in raw_result.items()
          }
        else:
          try:
            result_t = array_ops.identity(raw_result)
          except (ValueError, TypeError):
            raise RuntimeError(
                'The output of `metric.result()` can only be a single '
                'Tensor/Variable, or a dict of Tensors/Variables. '
                'For metric %s, got result %s.' % (metric_obj.name, raw_result))
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with the distribution object as the first parameter. We create a
      # wrapper here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get a `PerReplica` merge function. We take the first one, as
        # all are identical copies of the function that we passed below.
        result = distribution.experimental_local_results(merge_fn)[0](*args)

        # Wrapping the result in identity so that the control dependency
        # between the update_op from `update_state` and the result works in
        # case the result returns a tensor.
        return array_ops.identity(result)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)

    # We are saving the result op here to be used in train/test execution
    # functions. This basically gives the result op that was generated with a
    # control dep to the updates for these workflows.
    metric_obj._call_result = result_t
    return result_t

  return tf_decorator.make_decorator(result_fn, decorated)


def weakmethod(method):
  """Creates a weak reference to the bound method."""

  cls = method.__self__.__class__
  func = method.__func__
  instance_ref = weakref.ref(method.__self__)

  @functools.wraps(method)
  def inner(*args, **kwargs):
    return func.__get__(instance_ref(), cls)(*args, **kwargs)

  del method
  return inner
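

# A small, hypothetical sketch (not from the original module) of how
# `weakmethod` can be used: the wrapper keeps only a weak reference to the
# bound instance, so it does not keep the object alive by itself.
def _weakmethod_example():
  """Illustrates `weakmethod` with a throwaway counter class."""

  class _Counter(object):

    def __init__(self):
      self.count = 0

    def increment(self):
      self.count += 1
      return self.count

  counter = _Counter()
  weak_increment = weakmethod(counter.increment)
  return weak_increment()  # Returns 1 while `counter` is still alive.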


def assert_thresholds_range(thresholds):
  if thresholds is not None:
    invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1]
    if invalid_thresholds:
      raise ValueError(
          'Threshold values must be in [0, 1]. Invalid values: {}'.format(
              invalid_thresholds))


def parse_init_thresholds(thresholds, default_threshold=0.5):
  if thresholds is not None:
    assert_thresholds_range(to_list(thresholds))
  thresholds = to_list(default_threshold if thresholds is None else thresholds)
  return thresholds
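

# A brief, hypothetical illustration (not part of the original module) of what
# `parse_init_thresholds` returns for the common input forms; the helper below
# exists only for documentation purposes.
def _parse_init_thresholds_example():
  """Shows how thresholds are normalized to a list."""
  assert parse_init_thresholds(None) == [0.5]  # falls back to the default
  assert parse_init_thresholds(0.3) == [0.3]  # scalar becomes a list
  assert parse_init_thresholds([0.1, 0.9]) == [0.1, 0.9]  # list kept as-is
  return True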


class ConfusionMatrix(Enum):
  TRUE_POSITIVES = 'tp'
  FALSE_POSITIVES = 'fp'
  TRUE_NEGATIVES = 'tn'
  FALSE_NEGATIVES = 'fn'


class AUCCurve(Enum):
  """Type of AUC Curve (ROC or PR)."""
  ROC = 'ROC'
  PR = 'PR'

  @staticmethod
  def from_str(key):
    if key in ('pr', 'PR'):
      return AUCCurve.PR
    elif key in ('roc', 'ROC'):
      return AUCCurve.ROC
    else:
      raise ValueError('Invalid AUC curve value "%s".' % key)


class AUCSummationMethod(Enum):
  """Type of AUC summation method.

  https://en.wikipedia.org/wiki/Riemann_sum

  Contains the following values:
  * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For
    `PR` curve, interpolates (true/false) positives but not the ratio that is
    precision (see Davis & Goadrich 2006 for details).
  * 'minoring': Applies left summation for increasing intervals and right
    summation for decreasing intervals.
  * 'majoring': Applies right summation for increasing intervals and left
    summation for decreasing intervals.
  """
  INTERPOLATION = 'interpolation'
  MAJORING = 'majoring'
  MINORING = 'minoring'

  @staticmethod
  def from_str(key):
    if key in ('interpolation', 'Interpolation'):
      return AUCSummationMethod.INTERPOLATION
    elif key in ('majoring', 'Majoring'):
      return AUCSummationMethod.MAJORING
    elif key in ('minoring', 'Minoring'):
      return AUCSummationMethod.MINORING
    else:
      raise ValueError('Invalid AUC summation method value "%s".' % key)
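

# A minimal NumPy sketch (illustrative only, not from the original module) of
# how 'minoring' and 'majoring' bracket the interpolated area under a curve
# sampled at a few (x, y) points. The helper name is hypothetical.
def _auc_summation_example():
  """Compares the three summation methods on a toy curve."""
  x = np.array([0.0, 0.5, 1.0])
  y = np.array([0.0, 0.8, 1.0])
  widths = np.diff(x)
  left, right = y[:-1], y[1:]
  return {
      AUCSummationMethod.INTERPOLATION: np.sum(widths * (left + right) / 2.0),
      # Minoring uses the smaller height per interval (underestimate).
      AUCSummationMethod.MINORING: np.sum(widths * np.minimum(left, right)),
      # Majoring uses the larger height per interval (overestimate).
      AUCSummationMethod.MAJORING: np.sum(widths * np.maximum(left, right)),
  }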


def _update_confusion_matrix_variables_optimized(
    variables_to_update,
    y_true,
    y_pred,
    thresholds,
    multi_label=False,
    sample_weights=None,
    label_weights=None,
    thresholds_with_epsilon=False):
  """Update confusion matrix variables with memory efficient alternative.

  Note that the thresholds need to be evenly distributed within the list, i.e.,
  the difference between consecutive elements is the same.

  To compute TP/FP/TN/FN, we are measuring a binary classifier
    C(t) = (predictions >= t)
  at each threshold 't'. So we have
    TP(t) = sum( C(t) * true_labels )
    FP(t) = sum( C(t) * false_labels )

  But, computing C(t) requires computation for each t. To make it fast,
  observe that C(t) is a cumulative integral, and so if we have
    thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
  where n = num_thresholds, and if we can compute the bucket function
    B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
  then we get
    C(t_i) = sum( B(j), j >= i )
  which is the reversed cumulative sum in tf.cumsum().

  We can compute B(i) efficiently by taking advantage of the fact that
  our thresholds are evenly distributed, in that
    width = 1.0 / (num_thresholds - 1)
    thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
  Given a prediction value p, we can map it to its bucket by
    bucket_index(p) = floor( p * (num_thresholds - 1) )
  so we can use tf.math.unsorted_segment_sum() to update the buckets in one
  pass.

  Consider the following example:
  y_true = [0, 0, 1, 1]
  y_pred = [0.1, 0.5, 0.3, 0.9]
  thresholds = [0.0, 0.5, 1.0]
  num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
  bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                       = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                       = [0, 0, 0, 1]
  # The meaning of this bucket is that if any of the labels is true,
  # then 1 will be added to the corresponding bucket with that index.
  # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0. If
  # the label for 1.8 is true, then 1 will be added to bucket 1.
  #
  # Note the second item "1.0" is floored to 0, since the value needs to be
  # strictly larger than the bucket lower bound.
  # In the implementation, we use tf.math.ceil() - 1 to achieve this.
  tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices,
                                                 num_segments=num_thresholds)
                  = [1, 1, 0]
  # For [1, 1, 0] here, it means there is 1 true value contributed by bucket 0,
  # and 1 value contributed by bucket 1. When we aggregate them together, the
  # result becomes [a + b + c, b + c, c], since larger thresholds will always
  # contribute to the value for smaller thresholds.
  true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                = [2, 1, 0]

  This implementation exhibits a run time and space complexity of O(T + N),
  where T is the number of thresholds and N is the size of predictions.
  Metrics that rely on the standard implementation instead exhibit a complexity
  of O(T * N).

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be cast
      to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
      It needs to be evenly distributed (the difference between consecutive
      elements needs to be the same).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
      as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `y_true` dimension).
    label_weights: Optional tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).
    thresholds_with_epsilon: Optional boolean indicating whether the leading and
      trailing thresholds have epsilon added for floating point imprecision.
      It changes how the leading and trailing buckets are handled.

  Returns:
    Update op.
  """
  num_thresholds = thresholds.shape.as_list()[0]

  if sample_weights is None:
    sample_weights = 1.0
  else:
    sample_weights = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weights, dtype=y_pred.dtype), y_pred)
    if not multi_label:
      sample_weights = array_ops.reshape(sample_weights, [-1])
  if label_weights is None:
    label_weights = 1.0
  else:
    label_weights = array_ops.expand_dims(label_weights, 0)
    label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
                                                            y_pred)
    if not multi_label:
      label_weights = array_ops.reshape(label_weights, [-1])
  weights = math_ops.multiply(sample_weights, label_weights)

  # We shouldn't need this, but in case there are prediction values that are
  # out of the range [0.0, 1.0].
  y_pred = clip_ops.clip_by_value(y_pred,
                                  clip_value_min=0.0, clip_value_max=1.0)

  y_true = math_ops.cast(math_ops.cast(y_true, dtypes.bool), y_true.dtype)
  if not multi_label:
    y_true = array_ops.reshape(y_true, [-1])
    y_pred = array_ops.reshape(y_pred, [-1])

  true_labels = math_ops.multiply(y_true, weights)
  false_labels = math_ops.multiply((1.0 - y_true), weights)

  # Compute the bucket indices for each prediction value.
  # Since the prediction value has to be strictly greater than the threshold
  # (e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the first
  # bucket), we have to use math.ceil(val) - 1 for the bucket index.
  bucket_indices = math_ops.ceil(y_pred * (num_thresholds - 1)) - 1

  if thresholds_with_epsilon:
    # In this case, the first bucket should actually be taken into account,
    # since any prediction in [0.0, 1.0] should be larger than the first
    # threshold. We change the bucket value from -1 to 0.
    bucket_indices = nn_ops.relu(bucket_indices)

  bucket_indices = math_ops.cast(bucket_indices, dtypes.int32)

  if multi_label:
    # We need to run the bucket segment sum for each of the label classes. In
    # the multi_label case, the rank of the label is 2. We first transpose it
    # so that the label dim becomes the first one and we can run the labels in
    # parallel.
    true_labels = array_ops.transpose_v2(true_labels)
    false_labels = array_ops.transpose_v2(false_labels)
    bucket_indices = array_ops.transpose_v2(bucket_indices)

    def gather_bucket(label_and_bucket_index):
      label, bucket_index = label_and_bucket_index[0], label_and_bucket_index[1]
      return math_ops.unsorted_segment_sum(
          data=label, segment_ids=bucket_index, num_segments=num_thresholds)
    tp_bucket_v = parallel_control_flow_ops.vectorized_map(
        gather_bucket, (true_labels, bucket_indices))
    fp_bucket_v = parallel_control_flow_ops.vectorized_map(
        gather_bucket, (false_labels, bucket_indices))
    tp = array_ops.transpose_v2(
        math_ops.cumsum(tp_bucket_v, reverse=True, axis=1))
    fp = array_ops.transpose_v2(
        math_ops.cumsum(fp_bucket_v, reverse=True, axis=1))
  else:
    tp_bucket_v = math_ops.unsorted_segment_sum(
        data=true_labels, segment_ids=bucket_indices,
        num_segments=num_thresholds)
    fp_bucket_v = math_ops.unsorted_segment_sum(
        data=false_labels, segment_ids=bucket_indices,
        num_segments=num_thresholds)
    tp = math_ops.cumsum(tp_bucket_v, reverse=True)
    fp = math_ops.cumsum(fp_bucket_v, reverse=True)

  # fn = sum(true_labels) - tp
  # tn = sum(false_labels) - fp
  if (ConfusionMatrix.TRUE_NEGATIVES in variables_to_update or
      ConfusionMatrix.FALSE_NEGATIVES in variables_to_update):
    if multi_label:
      total_true_labels = math_ops.reduce_sum(true_labels, axis=1)
      total_false_labels = math_ops.reduce_sum(false_labels, axis=1)
    else:
      total_true_labels = math_ops.reduce_sum(true_labels)
      total_false_labels = math_ops.reduce_sum(false_labels)

  update_ops = []
  if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
    update_ops.append(variable.assign_add(tp))
  if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
    update_ops.append(variable.assign_add(fp))
  if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
    tn = total_false_labels - fp
    update_ops.append(variable.assign_add(tn))
  if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
    variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
    fn = total_true_labels - tp
    update_ops.append(variable.assign_add(fn))
  return control_flow_ops.group(update_ops)
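

# A standalone NumPy sketch (illustrative only, not part of the original
# module) of the bucketing + reversed-cumsum trick described in the docstring
# of `_update_confusion_matrix_variables_optimized`. The helper name is
# hypothetical.
def _bucketed_confusion_matrix_example():
  """Reproduces the docstring example with plain NumPy."""
  y_true = np.array([0., 0., 1., 1.])
  y_pred = np.array([0.1, 0.5, 0.3, 0.9])
  num_thresholds = 3  # thresholds = [0.0, 0.5, 1.0]

  # ceil(p * (num_thresholds - 1)) - 1 maps a prediction to its bucket, where
  # a value equal to a threshold falls into the lower bucket.
  bucket_indices = np.ceil(y_pred * (num_thresholds - 1)).astype(np.int64) - 1
  bucket_indices = np.maximum(bucket_indices, 0)  # [0, 0, 0, 1]

  # Per-bucket counts of true and false labels (unsorted_segment_sum analogue).
  tp_bucket = np.bincount(bucket_indices, weights=y_true,
                          minlength=num_thresholds)  # [1., 1., 0.]
  fp_bucket = np.bincount(bucket_indices, weights=1. - y_true,
                          minlength=num_thresholds)  # [2., 0., 0.]

  # Reversed cumulative sum: predictions in higher buckets also exceed every
  # lower threshold.
  tp = np.cumsum(tp_bucket[::-1])[::-1]  # [2., 1., 0.]
  fp = np.cumsum(fp_bucket[::-1])[::-1]  # [2., 0., 0.]
  return tp, fp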


def is_evenly_distributed_thresholds(thresholds):
  """Check if the thresholds list is evenly distributed.

  We could leverage evenly distributed thresholds to use less memory when
  calculating metrics like AUC, where each individual threshold needs to be
  evaluated.

  Args:
    thresholds: A python list or tuple, or 1D numpy array whose values are in
      [0, 1].

  Returns:
    boolean, whether the values in the inputs are evenly distributed.
  """
  # Check the list value and see if it is evenly distributed.
  num_thresholds = len(thresholds)
  if num_thresholds < 3:
    return False
  even_thresholds = np.arange(num_thresholds,
                              dtype=np.float32) / (num_thresholds - 1)
  return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())
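

# A short, hypothetical usage sketch (not from the original module) showing
# which threshold lists qualify for the optimized code path.
def _is_evenly_distributed_thresholds_example():
  """Checks a few threshold lists against the even-spacing requirement."""
  assert is_evenly_distributed_thresholds(np.array([0.0, 0.5, 1.0]))
  assert not is_evenly_distributed_thresholds(np.array([0.0, 0.1, 1.0]))
  # Fewer than 3 thresholds never uses the optimized path.
  assert not is_evenly_distributed_thresholds(np.array([0.0, 1.0]))
  return True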


def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None,
                                      thresholds_distributed_evenly=False):
  """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positives: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positives: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value, float tensor, python list, or tuple of float
      thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).
    thresholds_distributed_evenly: Boolean, whether the thresholds are evenly
      distributed within the list. An optimized method will be used if this is
      the case. See _update_confusion_matrix_variables_optimized() for more
      details.

  Returns:
    Update op.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
      `variables_to_update` contains invalid keys.
  """
  if multi_label and label_weights is not None:
    raise ValueError('`label_weights` for multilabel data should be handled '
                     'outside of `update_confusion_matrix_variables` when '
                     '`multi_label` is True.')
  if variables_to_update is None:
    return
  if not any(
      key for key in variables_to_update if key in list(ConfusionMatrix)):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(
            list(ConfusionMatrix), variables_to_update.keys()))

  variable_dtype = list(variables_to_update.values())[0].dtype

  y_true = math_ops.cast(y_true, dtype=variable_dtype)
  y_pred = math_ops.cast(y_pred, dtype=variable_dtype)

  if thresholds_distributed_evenly:
    # Check whether the thresholds have any leading or trailing epsilon added
    # for floating point imprecision. The leading and trailing thresholds will
    # be handled a bit differently as corner cases.
    # At this point, thresholds should be a list/array with more than 2 items,
    # with values in [0, 1]. See is_evenly_distributed_thresholds() for more
    # details.
    thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0

  thresholds = ops.convert_to_tensor_v2_with_dispatch(
      thresholds, dtype=variable_dtype)
  num_thresholds = thresholds.shape.as_list()[0]

  if multi_label:
    one_thresh = math_ops.equal(
        math_ops.cast(1, dtype=dtypes.int32),
        array_ops.rank(thresholds),
        name='one_set_of_thresholds_cond')
  else:
    [y_pred,
     y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
                                                               sample_weight)
    one_thresh = math_ops.cast(True, dtype=dtypes.bool)

  invalid_keys = [
      key for key in variables_to_update if key not in list(ConfusionMatrix)
  ]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, list(ConfusionMatrix)))

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          y_pred,
          math_ops.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      check_ops.assert_less_equal(
          y_pred,
          math_ops.cast(1.0, dtype=y_pred.dtype),
          message='predictions must be <= 1')
  ]):
    if sample_weight is None:
      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
          y_pred, y_true)
    else:
      sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype)
      y_pred, y_true, sample_weight = (
          losses_utils.squeeze_or_expand_dimensions(
              y_pred, y_true, sample_weight=sample_weight))
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  if top_k is not None:
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  if thresholds_distributed_evenly and compat.forward_compatible(2021, 6, 8):
    # The new approach will take effect after 2021/6/8, to give enough time
    # for the Brella release to pick up the new op tf.math.cumsum with float32.
    return _update_confusion_matrix_variables_optimized(
        variables_to_update, y_true, y_pred, thresholds,
        multi_label=multi_label, sample_weights=sample_weight,
        label_weights=label_weights,
        thresholds_with_epsilon=thresholds_with_epsilon)

  pred_shape = array_ops.shape(y_pred)
  num_predictions = pred_shape[0]
  if y_pred.shape.ndims == 1:
    num_labels = 1
  else:
    num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0)
  thresh_label_tile = array_ops.where_v2(one_thresh, num_labels,
                                         array_ops.ones([], dtype=dtypes.int32))

  # Reshape predictions and labels, adding a dim for thresholding.
  if multi_label:
    predictions_extra_dim = array_ops.expand_dims(y_pred, 0)
    labels_extra_dim = array_ops.expand_dims(
        math_ops.cast(y_true, dtype=dtypes.bool), 0)
  else:
    # Flatten predictions and labels when not multilabel.
    predictions_extra_dim = array_ops.reshape(y_pred, [1, -1])
    labels_extra_dim = array_ops.reshape(
        math_ops.cast(y_true, dtype=dtypes.bool), [1, -1])

  # Tile the thresholds for every prediction.
  if multi_label:
    thresh_pretile_shape = [num_thresholds, 1, -1]
    thresh_tiles = [1, num_predictions, thresh_label_tile]
    data_tiles = [num_thresholds, 1, 1]
  else:
    thresh_pretile_shape = [num_thresholds, -1]
    thresh_tiles = [1, num_predictions * num_labels]
    data_tiles = [num_thresholds, 1]

  thresh_tiled = array_ops.tile(
      array_ops.reshape(thresholds, thresh_pretile_shape),
      array_ops.stack(thresh_tiles))

  # Tile the predictions for every threshold.
  preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles)

  # Compare predictions and threshold.
  pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled)

  # Tile labels by number of thresholds.
  label_is_pos = array_ops.tile(labels_extra_dim, data_tiles)

  if sample_weight is not None:
    sample_weight = weights_broadcast_ops.broadcast_weights(
        math_ops.cast(sample_weight, dtype=variable_dtype), y_pred)
    weights_tiled = array_ops.tile(
        array_ops.reshape(sample_weight, thresh_tiles), data_tiles)
  else:
    weights_tiled = None

  if label_weights is not None and not multi_label:
    label_weights = array_ops.expand_dims(label_weights, 0)
    label_weights = weights_broadcast_ops.broadcast_weights(label_weights,
                                                            y_pred)
    label_weights_tiled = array_ops.tile(
        array_ops.reshape(label_weights, thresh_tiles), data_tiles)
    if weights_tiled is None:
      weights_tiled = label_weights_tiled
    else:
      weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled)

  update_ops = []

  def weighted_assign_add(label, pred, weights, var):
    label_and_pred = math_ops.cast(
        math_ops.logical_and(label, pred), dtype=var.dtype)
    if weights is not None:
      label_and_pred *= math_ops.cast(weights, dtype=var.dtype)
    return var.assign_add(math_ops.reduce_sum(label_and_pred, 1))

  loop_vars = {
      ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
  }
  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  if update_fn or update_tn:
    pred_is_neg = math_ops.logical_not(pred_is_pos)
    loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

  if update_fp or update_tn:
    label_is_neg = math_ops.logical_not(label_is_pos)
    loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
    if update_tn:
      loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)

  for matrix_cond, (label, pred) in loop_vars.items():

    if matrix_cond in variables_to_update:
      update_ops.append(
          weighted_assign_add(label, pred, weights_tiled,
                              variables_to_update[matrix_cond]))

  return control_flow_ops.group(update_ops)
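

# A tiny NumPy sketch (illustrative only, not part of the original module) of
# the per-threshold definitions listed in the docstring of
# `update_confusion_matrix_variables`, for a single threshold and unit weights.
def _confusion_matrix_counts_example(threshold=0.5):
  """Counts tp/fp/tn/fn for a toy batch at one threshold."""
  y_true = np.array([True, True, False, False])
  y_pred = np.array([0.9, 0.4, 0.6, 0.1])
  pred_is_pos = y_pred > threshold
  return {
      ConfusionMatrix.TRUE_POSITIVES: np.sum(y_true & pred_is_pos),    # 1
      ConfusionMatrix.FALSE_NEGATIVES: np.sum(y_true & ~pred_is_pos),  # 1
      ConfusionMatrix.TRUE_NEGATIVES: np.sum(~y_true & ~pred_is_pos),  # 1
      ConfusionMatrix.FALSE_POSITIVES: np.sum(~y_true & pred_is_pos),  # 1
  }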


def _filter_top_k(x, k):
  """Filters top-k values in the last dim of x and sets the rest to NEG_INF.

  Used for computing top-k prediction values in dense labels (which have the
  same shape as predictions) for recall and precision top-k metrics.

  Args:
    x: tensor with any dimensions.
    k: the number of values to keep.

  Returns:
    tensor with same shape and dtype as x.
  """
  _, top_k_idx = nn_ops.top_k(x, k, sorted=False)
  top_k_mask = math_ops.reduce_sum(
      array_ops.one_hot(top_k_idx, array_ops.shape(x)[-1], axis=-1), axis=-2)
  return x * top_k_mask + NEG_INF * (1 - top_k_mask)
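

# A NumPy sketch (illustrative only, not from the original module) of the
# one-hot masking trick used by `_filter_top_k`: keep the top-k entries of the
# last axis and push everything else down to NEG_INF.
def _filter_top_k_example(k=2):
  """Masks all but the top-k values of a toy score vector."""
  x = np.array([0.1, 0.9, 0.3, 0.8])
  top_k_idx = np.argsort(x)[-k:]  # indices of the top-k values
  top_k_mask = np.sum(np.eye(x.shape[-1])[top_k_idx], axis=0)  # [0, 1, 0, 1]
  return x * top_k_mask + NEG_INF * (1 - top_k_mask)  # keeps 0.9 and 0.8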


def ragged_assert_compatible_and_get_flat_values(values, mask=None):
  """If ragged, it checks the compatibility and then returns the flat_values.

     Note: If two tensors are dense, it does not check their compatibility.
     Note: Although two ragged tensors with different ragged ranks could have
           identical overall rank and dimension sizes and hence be compatible,
           we do not support those cases.
  Args:
     values: A list of potentially ragged tensors, all of the same ragged_rank.
     mask: A potentially ragged tensor of the same ragged_rank as the elements
       in `values`.

  Returns:
     A tuple in which the first element is the list of tensors and the second
     is the mask tensor. ([values], mask). The mask and the elements in values
     are equal to the flat_values of the input arguments (if they were ragged).
  """
  if isinstance(values, list):
    is_all_ragged = \
        all(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
    is_any_ragged = \
        any(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values)
  else:
    is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor)
    is_any_ragged = is_all_ragged
  if (is_all_ragged and
      ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))):
    to_be_stripped = False
    if not isinstance(values, list):
      values = [values]
      to_be_stripped = True

    # NOTE: we leave the flat_values compatibility to
    # tf.TensorShape `assert_is_compatible_with`.
    # Check if both dynamic dimensions are equal and then use the flat_values.
    nested_row_split_list = [rt.nested_row_splits for rt in values]
    assertion_list = _assert_splits_match(nested_row_split_list)

    # If both are ragged, sample_weights should also be ragged with the same
    # dims.
    if isinstance(mask, ragged_tensor.RaggedTensor):
      assertion_list_for_mask = _assert_splits_match(
          [nested_row_split_list[0], mask.nested_row_splits])
      with ops.control_dependencies(assertion_list_for_mask):
        mask = array_ops.expand_dims(mask.flat_values, -1)

    # values has at least 1 element.
    flat_values = []
    for value in values:
      with ops.control_dependencies(assertion_list):
        flat_values.append(array_ops.expand_dims(value.flat_values, -1))

    values = flat_values[0] if to_be_stripped else flat_values

  elif is_any_ragged:
    raise TypeError('One of the inputs does not have acceptable types.')
  # values is empty or the values are not ragged and mask is ragged.
  elif isinstance(mask, ragged_tensor.RaggedTensor):
    raise TypeError('Ragged mask is not allowed with non-ragged inputs.')

  return values, mask
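

# A small, hypothetical sketch (not part of the original module) of how two
# compatible ragged tensors are reduced to their flat_values by
# `ragged_assert_compatible_and_get_flat_values`.
def _ragged_flat_values_example():
  """Builds two ragged tensors with matching splits and flattens them."""
  y_pred = ragged_tensor.RaggedTensor.from_row_splits(
      values=[0.9, 0.1, 0.8], row_splits=[0, 2, 3])  # [[0.9, 0.1], [0.8]]
  y_true = ragged_tensor.RaggedTensor.from_row_splits(
      values=[1.0, 0.0, 1.0], row_splits=[0, 2, 3])  # [[1.0, 0.0], [1.0]]
  # Both inputs share the same nested_row_splits, so their flat_values (with
  # an extra trailing dim) are returned; incompatible splits would raise.
  [flat_pred, flat_true], _ = ragged_assert_compatible_and_get_flat_values(
      [y_pred, y_true], mask=None)
  return flat_pred, flat_true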


def _assert_splits_match(nested_splits_lists):
  """Checks that the given splits lists are identical.

  Performs static tests to ensure that the given splits lists are identical,
  and returns a list of control dependency op tensors that check that they are
  fully identical.

  Args:
    nested_splits_lists: A list of nested_splits_lists, where each split_list is
      a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
      ragged dimension to innermost ragged dimension.

  Returns:
    A list of control dependency op tensors.
  Raises:
    ValueError: If the splits are not identical.
  """
  error_msg = 'Inputs must have identical ragged splits'
  for splits_list in nested_splits_lists:
    if len(splits_list) != len(nested_splits_lists[0]):
      raise ValueError(error_msg)
  return [
      check_ops.assert_equal(s1, s2, message=error_msg)  # pylint: disable=g-complex-comprehension
      for splits_list in nested_splits_lists[1:]
      for (s1, s2) in zip(nested_splits_lists[0], splits_list)
  ]