1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Utils related to keras metrics.""" 17 18import functools 19import weakref 20 21from enum import Enum 22 23import numpy as np 24 25from tensorflow.python.compat import compat 26from tensorflow.python.distribute import distribution_strategy_context 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import ops 29from tensorflow.python.keras import backend 30from tensorflow.python.keras.utils import losses_utils 31from tensorflow.python.keras.utils import tf_utils 32from tensorflow.python.keras.utils.generic_utils import to_list 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import check_ops 35from tensorflow.python.ops import clip_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import gen_math_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import nn_ops 40from tensorflow.python.ops import variables as variables_module 41from tensorflow.python.ops import weights_broadcast_ops 42from tensorflow.python.ops.parallel_for import control_flow_ops as parallel_control_flow_ops 43from tensorflow.python.ops.ragged import ragged_tensor 44from tensorflow.python.util import tf_decorator 45 46NEG_INF = -1e10 47 48 49class Reduction(Enum): 50 """Types of metrics reduction. 51 52 Contains the following values: 53 54 * `SUM`: Scalar sum of weighted values. 55 * `SUM_OVER_BATCH_SIZE`: Scalar sum of weighted values divided by 56 number of elements. 57 * `WEIGHTED_MEAN`: Scalar sum of weighted values divided by sum of weights. 58 """ 59 SUM = 'sum' 60 SUM_OVER_BATCH_SIZE = 'sum_over_batch_size' 61 WEIGHTED_MEAN = 'weighted_mean' 62 63 64def update_state_wrapper(update_state_fn): 65 """Decorator to wrap metric `update_state()` with `add_update()`. 66 67 Args: 68 update_state_fn: function that accumulates metric statistics. 69 70 Returns: 71 Decorated function that wraps `update_state_fn()` with `add_update()`. 72 """ 73 74 def decorated(metric_obj, *args, **kwargs): 75 """Decorated function with `add_update()`.""" 76 strategy = distribution_strategy_context.get_strategy() 77 78 for weight in metric_obj.weights: 79 if (backend.is_tpu_strategy(strategy) and 80 not strategy.extended.variable_created_in_scope(weight) 81 and not distribution_strategy_context.in_cross_replica_context()): 82 raise ValueError( 83 'Trying to run metric.update_state in replica context when ' 84 'the metric was not created in TPUStrategy scope. ' 85 'Make sure the keras Metric is created in TPUstrategy scope. ') 86 87 with tf_utils.graph_context_for_symbolic_tensors(*args, **kwargs): 88 update_op = update_state_fn(*args, **kwargs) 89 if update_op is not None: # update_op will be None in eager execution. 90 metric_obj.add_update(update_op) 91 return update_op 92 93 return tf_decorator.make_decorator(update_state_fn, decorated) 94 95 96def result_wrapper(result_fn): 97 """Decorator to wrap metric `result()` function in `merge_call()`. 98 99 Result computation is an idempotent operation that simply calculates the 100 metric value using the state variables. 101 102 If metric state variables are distributed across replicas/devices and 103 `result()` is requested from the context of one device - This function wraps 104 `result()` in a distribution strategy `merge_call()`. With this, 105 the metric state variables will be aggregated across devices. 106 107 Args: 108 result_fn: function that computes the metric result. 109 110 Returns: 111 Decorated function that wraps `result_fn()` in distribution strategy 112 `merge_call()`. 113 """ 114 115 def decorated(metric_obj, *args): 116 """Decorated function with merge_call.""" 117 has_strategy = distribution_strategy_context.has_strategy() 118 replica_context = distribution_strategy_context.get_replica_context() 119 120 # The purpose of using `merge_call` to call `result()` is to trigger cross 121 # replica aggregation of metric state variables (SyncOnReadVariable). After 122 # we introduced `variable_sync_on_read_context`, in principle there is no 123 # need to use `merge_call` here. However the branch still exists because: 124 # 125 # 1. Keras V1 training code sometimes assumes `result_t` is the same tensor 126 # across replicas (achieved by `merge_call`). With 127 # `variable_sync_on_read_context` each replica gets their own tensors 128 # residing on replica's device, thus breaking the assumption. 129 # 2. Keras c/fit creates a tf.function (a.k.a, train_function) that returns 130 # the metric values of the first replica. With 131 # `variable_sync_on_read_context` since each replica gets their own 132 # tensors, the metric result tensors on the non-first replicas are not in 133 # the return value of train_function, making TF graph optimizer prune the 134 # branch that computes and aggregates those metric results. As a result, 135 # if NCCL is used to do the aggregation, the program will hang because 136 # NCCL ops are only launched on the non-pruned first replica. 137 # 138 # We condition on strategy.extended._use_merge_call() since we know if it is 139 # false, the program uses `jit_compile` to compile replica fn, meaning it is 140 # not V1 training (hence #1 is okay), and no pruning will happen as 141 # compiled functions are not inlined (hence #2 is okay). 142 143 if (not has_strategy or replica_context is None or 144 not distribution_strategy_context.get_strategy( 145 ).extended._use_merge_call()): 146 with distribution_strategy_context.variable_sync_on_read_context(): 147 raw_result = result_fn(*args) 148 # Results need to be wrapped in a `tf.identity` op to ensure 149 # correct execution order. 150 if isinstance(raw_result, 151 (ops.Tensor, variables_module.Variable, float, int)): 152 result_t = array_ops.identity(raw_result) 153 elif isinstance(raw_result, dict): 154 result_t = { 155 key: array_ops.identity(value) 156 for key, value in raw_result.items() 157 } 158 else: 159 try: 160 result_t = array_ops.identity(raw_result) 161 except (ValueError, TypeError): 162 raise RuntimeError( 163 'The output of `metric.result()` can only be a single ' 164 'Tensor/Variable, or a dict of Tensors/Variables. ' 165 'For metric %s, got result %s.' % (metric_obj.name, raw_result)) 166 else: 167 # TODO(psv): Test distribution of metrics using different distribution 168 # strategies. 169 170 # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn 171 # with distribution object as the first parameter. We create a wrapper 172 # here so that the result function need not have that parameter. 173 def merge_fn_wrapper(distribution, merge_fn, *args): 174 # We will get `PerReplica` merge function. Taking the first one as all 175 # are identical copies of the function that we had passed below. 176 result = distribution.experimental_local_results(merge_fn)[0](*args) 177 178 # Wrapping result in identity so that control dependency between 179 # update_op from `update_state` and result works in case result returns 180 # a tensor. 181 return array_ops.identity(result) 182 183 # Wrapping result in merge_call. merge_call is used when we want to leave 184 # replica mode and compute a value in cross replica mode. 185 result_t = replica_context.merge_call( 186 merge_fn_wrapper, args=(result_fn,) + args) 187 188 # We are saving the result op here to be used in train/test execution 189 # functions. This basically gives the result op that was generated with a 190 # control dep to the updates for these workflows. 191 metric_obj._call_result = result_t 192 return result_t 193 194 return tf_decorator.make_decorator(result_fn, decorated) 195 196 197def weakmethod(method): 198 """Creates a weak reference to the bound method.""" 199 200 cls = method.im_class 201 func = method.im_func 202 instance_ref = weakref.ref(method.im_self) 203 204 @functools.wraps(method) 205 def inner(*args, **kwargs): 206 return func.__get__(instance_ref(), cls)(*args, **kwargs) 207 208 del method 209 return inner 210 211 212def assert_thresholds_range(thresholds): 213 if thresholds is not None: 214 invalid_thresholds = [t for t in thresholds if t is None or t < 0 or t > 1] 215 if invalid_thresholds: 216 raise ValueError( 217 'Threshold values must be in [0, 1]. Invalid values: {}'.format( 218 invalid_thresholds)) 219 220 221def parse_init_thresholds(thresholds, default_threshold=0.5): 222 if thresholds is not None: 223 assert_thresholds_range(to_list(thresholds)) 224 thresholds = to_list(default_threshold if thresholds is None else thresholds) 225 return thresholds 226 227 228class ConfusionMatrix(Enum): 229 TRUE_POSITIVES = 'tp' 230 FALSE_POSITIVES = 'fp' 231 TRUE_NEGATIVES = 'tn' 232 FALSE_NEGATIVES = 'fn' 233 234 235class AUCCurve(Enum): 236 """Type of AUC Curve (ROC or PR).""" 237 ROC = 'ROC' 238 PR = 'PR' 239 240 @staticmethod 241 def from_str(key): 242 if key in ('pr', 'PR'): 243 return AUCCurve.PR 244 elif key in ('roc', 'ROC'): 245 return AUCCurve.ROC 246 else: 247 raise ValueError('Invalid AUC curve value "%s".' % key) 248 249 250class AUCSummationMethod(Enum): 251 """Type of AUC summation method. 252 253 https://en.wikipedia.org/wiki/Riemann_sum) 254 255 Contains the following values: 256 * 'interpolation': Applies mid-point summation scheme for `ROC` curve. For 257 `PR` curve, interpolates (true/false) positives but not the ratio that is 258 precision (see Davis & Goadrich 2006 for details). 259 * 'minoring': Applies left summation for increasing intervals and right 260 summation for decreasing intervals. 261 * 'majoring': Applies right summation for increasing intervals and left 262 summation for decreasing intervals. 263 """ 264 INTERPOLATION = 'interpolation' 265 MAJORING = 'majoring' 266 MINORING = 'minoring' 267 268 @staticmethod 269 def from_str(key): 270 if key in ('interpolation', 'Interpolation'): 271 return AUCSummationMethod.INTERPOLATION 272 elif key in ('majoring', 'Majoring'): 273 return AUCSummationMethod.MAJORING 274 elif key in ('minoring', 'Minoring'): 275 return AUCSummationMethod.MINORING 276 else: 277 raise ValueError('Invalid AUC summation method value "%s".' % key) 278 279 280def _update_confusion_matrix_variables_optimized( 281 variables_to_update, 282 y_true, 283 y_pred, 284 thresholds, 285 multi_label=False, 286 sample_weights=None, 287 label_weights=None, 288 thresholds_with_epsilon=False): 289 """Update confusion matrix variables with memory efficient alternative. 290 291 Note that the thresholds need to be evenly distributed within the list, eg, 292 the diff between consecutive elements are the same. 293 294 To compute TP/FP/TN/FN, we are measuring a binary classifier 295 C(t) = (predictions >= t) 296 at each threshold 't'. So we have 297 TP(t) = sum( C(t) * true_labels ) 298 FP(t) = sum( C(t) * false_labels ) 299 300 But, computing C(t) requires computation for each t. To make it fast, 301 observe that C(t) is a cumulative integral, and so if we have 302 thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} 303 where n = num_thresholds, and if we can compute the bucket function 304 B(i) = Sum( (predictions == t), t_i <= t < t{i+1} ) 305 then we get 306 C(t_i) = sum( B(j), j >= i ) 307 which is the reversed cumulative sum in tf.cumsum(). 308 309 We can compute B(i) efficiently by taking advantage of the fact that 310 our thresholds are evenly distributed, in that 311 width = 1.0 / (num_thresholds - 1) 312 thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] 313 Given a prediction value p, we can map it to its bucket by 314 bucket_index(p) = floor( p * (num_thresholds - 1) ) 315 so we can use tf.math.unsorted_segment_sum() to update the buckets in one 316 pass. 317 318 Consider following example: 319 y_true = [0, 0, 1, 1] 320 y_pred = [0.1, 0.5, 0.3, 0.9] 321 thresholds = [0.0, 0.5, 1.0] 322 num_buckets = 2 # [0.0, 1.0], (1.0, 2.0] 323 bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets) 324 = tf.math.floor([0.2, 1.0, 0.6, 1.8]) 325 = [0, 0, 0, 1] 326 # The meaning of this bucket is that if any of the label is true, 327 # then 1 will be added to the corresponding bucket with the index. 328 # Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the 329 # label for 1.8 is true, then 1 will be added to bucket 1. 330 # 331 # Note the second item "1.0" is floored to 0, since the value need to be 332 # strictly larger than the bucket lower bound. 333 # In the implementation, we use tf.math.ceil() - 1 to achieve this. 334 tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices, 335 num_segments=num_thresholds) 336 = [1, 1, 0] 337 # For [1, 1, 0] here, it means there is 1 true value contributed by bucket 0, 338 # and 1 value contributed by bucket 1. When we aggregate them to together, 339 # the result become [a + b + c, b + c, c], since large thresholds will always 340 # contribute to the value for smaller thresholds. 341 true_positive = tf.math.cumsum(tp_bucket_value, reverse=True) 342 = [2, 1, 0] 343 344 This implementation exhibits a run time and space complexity of O(T + N), 345 where T is the number of thresholds and N is the size of predictions. 346 Metrics that rely on standard implementation instead exhibit a complexity of 347 O(T * N). 348 349 Args: 350 variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys 351 and corresponding variables to update as values. 352 y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be cast 353 to `bool`. 354 y_pred: A floating point `Tensor` of arbitrary shape and whose values are in 355 the range `[0, 1]`. 356 thresholds: A sorted floating point `Tensor` with value in `[0, 1]`. 357 It need to be evenly distributed (the diff between each element need to be 358 the same). 359 multi_label: Optional boolean indicating whether multidimensional 360 prediction/labels should be treated as multilabel responses, or flattened 361 into a single label. When True, the valus of `variables_to_update` must 362 have a second dimension equal to the number of labels in y_true and 363 y_pred, and those tensors must not be RaggedTensors. 364 sample_weights: Optional `Tensor` whose rank is either 0, or the same rank 365 as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions 366 must be either `1`, or the same as the corresponding `y_true` dimension). 367 label_weights: Optional tensor of non-negative weights for multilabel 368 data. The weights are applied when calculating TP, FP, FN, and TN without 369 explicit multilabel handling (i.e. when the data is to be flattened). 370 thresholds_with_epsilon: Optional boolean indicating whether the leading and 371 tailing thresholds has any epsilon added for floating point imprecisions. 372 It will change how we handle the leading and tailing bucket. 373 374 Returns: 375 Update op. 376 """ 377 num_thresholds = thresholds.shape.as_list()[0] 378 379 if sample_weights is None: 380 sample_weights = 1.0 381 else: 382 sample_weights = weights_broadcast_ops.broadcast_weights( 383 math_ops.cast(sample_weights, dtype=y_pred.dtype), y_pred) 384 if not multi_label: 385 sample_weights = array_ops.reshape(sample_weights, [-1]) 386 if label_weights is None: 387 label_weights = 1.0 388 else: 389 label_weights = array_ops.expand_dims(label_weights, 0) 390 label_weights = weights_broadcast_ops.broadcast_weights(label_weights, 391 y_pred) 392 if not multi_label: 393 label_weights = array_ops.reshape(label_weights, [-1]) 394 weights = math_ops.multiply(sample_weights, label_weights) 395 396 # We shouldn't need this, but in case there are predict value that is out of 397 # the range of [0.0, 1.0] 398 y_pred = clip_ops.clip_by_value(y_pred, 399 clip_value_min=0.0, clip_value_max=1.0) 400 401 y_true = math_ops.cast(math_ops.cast(y_true, dtypes.bool), y_true.dtype) 402 if not multi_label: 403 y_true = array_ops.reshape(y_true, [-1]) 404 y_pred = array_ops.reshape(y_pred, [-1]) 405 406 true_labels = math_ops.multiply(y_true, weights) 407 false_labels = math_ops.multiply((1.0 - y_true), weights) 408 409 # Compute the bucket indices for each prediction value. 410 # Since the predict value has to be strictly greater than the thresholds, 411 # eg, buckets like [0, 0.5], (0.5, 1], and 0.5 belongs to first bucket. 412 # We have to use math.ceil(val) - 1 for the bucket. 413 bucket_indices = math_ops.ceil(y_pred * (num_thresholds - 1)) - 1 414 415 if thresholds_with_epsilon: 416 # In this case, the first bucket should actually take into account since 417 # the any prediction between [0.0, 1.0] should be larger than the first 418 # threshold. We change the bucket value from -1 to 0. 419 bucket_indices = nn_ops.relu(bucket_indices) 420 421 bucket_indices = math_ops.cast(bucket_indices, dtypes.int32) 422 423 if multi_label: 424 # We need to run bucket segment sum for each of the label class. In the 425 # multi_label case, the rank of the label is 2. We first transpose it so 426 # that the label dim becomes the first and we can parallel run though them. 427 true_labels = array_ops.transpose_v2(true_labels) 428 false_labels = array_ops.transpose_v2(false_labels) 429 bucket_indices = array_ops.transpose_v2(bucket_indices) 430 431 def gather_bucket(label_and_bucket_index): 432 label, bucket_index = label_and_bucket_index[0], label_and_bucket_index[1] 433 return math_ops.unsorted_segment_sum( 434 data=label, segment_ids=bucket_index, num_segments=num_thresholds) 435 tp_bucket_v = parallel_control_flow_ops.vectorized_map( 436 gather_bucket, (true_labels, bucket_indices)) 437 fp_bucket_v = parallel_control_flow_ops.vectorized_map( 438 gather_bucket, (false_labels, bucket_indices)) 439 tp = array_ops.transpose_v2( 440 math_ops.cumsum(tp_bucket_v, reverse=True, axis=1)) 441 fp = array_ops.transpose_v2( 442 math_ops.cumsum(fp_bucket_v, reverse=True, axis=1)) 443 else: 444 tp_bucket_v = math_ops.unsorted_segment_sum( 445 data=true_labels, segment_ids=bucket_indices, 446 num_segments=num_thresholds) 447 fp_bucket_v = math_ops.unsorted_segment_sum( 448 data=false_labels, segment_ids=bucket_indices, 449 num_segments=num_thresholds) 450 tp = math_ops.cumsum(tp_bucket_v, reverse=True) 451 fp = math_ops.cumsum(fp_bucket_v, reverse=True) 452 453 # fn = sum(true_labels) - tp 454 # tn = sum(false_labels) - fp 455 if (ConfusionMatrix.TRUE_NEGATIVES in variables_to_update or 456 ConfusionMatrix.FALSE_NEGATIVES in variables_to_update): 457 if multi_label: 458 total_true_labels = math_ops.reduce_sum(true_labels, axis=1) 459 total_false_labels = math_ops.reduce_sum(false_labels, axis=1) 460 else: 461 total_true_labels = math_ops.reduce_sum(true_labels) 462 total_false_labels = math_ops.reduce_sum(false_labels) 463 464 update_ops = [] 465 if ConfusionMatrix.TRUE_POSITIVES in variables_to_update: 466 variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES] 467 update_ops.append(variable.assign_add(tp)) 468 if ConfusionMatrix.FALSE_POSITIVES in variables_to_update: 469 variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES] 470 update_ops.append(variable.assign_add(fp)) 471 if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update: 472 variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES] 473 tn = total_false_labels - fp 474 update_ops.append(variable.assign_add(tn)) 475 if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update: 476 variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES] 477 fn = total_true_labels - tp 478 update_ops.append(variable.assign_add(fn)) 479 return control_flow_ops.group(update_ops) 480 481 482def is_evenly_distributed_thresholds(thresholds): 483 """Check if the thresholds list is evenly distributed. 484 485 We could leverage evenly distributed thresholds to use less memory when 486 calculate metrcis like AUC where each individual threshold need to be 487 evaluted. 488 489 Args: 490 thresholds: A python list or tuple, or 1D numpy array whose value is ranged 491 in [0, 1]. 492 493 Returns: 494 boolean, whether the values in the inputs are evenly distributed. 495 """ 496 # Check the list value and see if it is evenly distributed. 497 num_thresholds = len(thresholds) 498 if num_thresholds < 3: 499 return False 500 even_thresholds = np.arange(num_thresholds, 501 dtype=np.float32) / (num_thresholds - 1) 502 return np.allclose(thresholds, even_thresholds, atol=backend.epsilon()) 503 504 505def update_confusion_matrix_variables(variables_to_update, 506 y_true, 507 y_pred, 508 thresholds, 509 top_k=None, 510 class_id=None, 511 sample_weight=None, 512 multi_label=False, 513 label_weights=None, 514 thresholds_distributed_evenly=False): 515 """Returns op to update the given confusion matrix variables. 516 517 For every pair of values in y_true and y_pred: 518 519 true_positive: y_true == True and y_pred > thresholds 520 false_negatives: y_true == True and y_pred <= thresholds 521 true_negatives: y_true == False and y_pred <= thresholds 522 false_positive: y_true == False and y_pred > thresholds 523 524 The results will be weighted and added together. When multiple thresholds are 525 provided, we will repeat the same for every threshold. 526 527 For estimation of these metrics over a stream of data, the function creates an 528 `update_op` operation that updates the given variables. 529 530 If `sample_weight` is `None`, weights default to 1. 531 Use weights of 0 to mask values. 532 533 Args: 534 variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys 535 and corresponding variables to update as values. 536 y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`. 537 y_pred: A floating point `Tensor` of arbitrary shape and whose values are in 538 the range `[0, 1]`. 539 thresholds: A float value, float tensor, python list, or tuple of float 540 thresholds in `[0, 1]`, or NEG_INF (used when top_k is set). 541 top_k: Optional int, indicates that the positive labels should be limited to 542 the top k predictions. 543 class_id: Optional int, limits the prediction and labels to the class 544 specified by this argument. 545 sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as 546 `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must 547 be either `1`, or the same as the corresponding `y_true` dimension). 548 multi_label: Optional boolean indicating whether multidimensional 549 prediction/labels should be treated as multilabel responses, or flattened 550 into a single label. When True, the valus of `variables_to_update` must 551 have a second dimension equal to the number of labels in y_true and 552 y_pred, and those tensors must not be RaggedTensors. 553 label_weights: (optional) tensor of non-negative weights for multilabel 554 data. The weights are applied when calculating TP, FP, FN, and TN without 555 explicit multilabel handling (i.e. when the data is to be flattened). 556 thresholds_distributed_evenly: Boolean, whether the thresholds are evenly 557 distributed within the list. An optimized method will be used if this is 558 the case. See _update_confusion_matrix_variables_optimized() for more 559 details. 560 561 Returns: 562 Update op. 563 564 Raises: 565 ValueError: If `y_pred` and `y_true` have mismatched shapes, or if 566 `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if 567 `variables_to_update` contains invalid keys. 568 """ 569 if multi_label and label_weights is not None: 570 raise ValueError('`label_weights` for multilabel data should be handled ' 571 'outside of `update_confusion_matrix_variables` when ' 572 '`multi_label` is True.') 573 if variables_to_update is None: 574 return 575 if not any( 576 key for key in variables_to_update if key in list(ConfusionMatrix)): 577 raise ValueError( 578 'Please provide at least one valid confusion matrix ' 579 'variable to update. Valid variable key options are: "{}". ' 580 'Received: "{}"'.format( 581 list(ConfusionMatrix), variables_to_update.keys())) 582 583 variable_dtype = list(variables_to_update.values())[0].dtype 584 585 y_true = math_ops.cast(y_true, dtype=variable_dtype) 586 y_pred = math_ops.cast(y_pred, dtype=variable_dtype) 587 588 if thresholds_distributed_evenly: 589 # Check whether the thresholds has any leading or tailing epsilon added 590 # for floating point imprecision. The leading and tailing threshold will be 591 # handled bit differently as the corner case. 592 # At this point, thresholds should be a list/array with more than 2 items, 593 # and ranged between [0, 1]. See is_evenly_distributed_thresholds() for more 594 # details. 595 thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0 596 597 thresholds = ops.convert_to_tensor_v2_with_dispatch( 598 thresholds, dtype=variable_dtype) 599 num_thresholds = thresholds.shape.as_list()[0] 600 601 if multi_label: 602 one_thresh = math_ops.equal( 603 math_ops.cast(1, dtype=dtypes.int32), 604 array_ops.rank(thresholds), 605 name='one_set_of_thresholds_cond') 606 else: 607 [y_pred, 608 y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], 609 sample_weight) 610 one_thresh = math_ops.cast(True, dtype=dtypes.bool) 611 612 invalid_keys = [ 613 key for key in variables_to_update if key not in list(ConfusionMatrix) 614 ] 615 if invalid_keys: 616 raise ValueError( 617 'Invalid keys: {}. Valid variable key options are: "{}"'.format( 618 invalid_keys, list(ConfusionMatrix))) 619 620 with ops.control_dependencies([ 621 check_ops.assert_greater_equal( 622 y_pred, 623 math_ops.cast(0.0, dtype=y_pred.dtype), 624 message='predictions must be >= 0'), 625 check_ops.assert_less_equal( 626 y_pred, 627 math_ops.cast(1.0, dtype=y_pred.dtype), 628 message='predictions must be <= 1') 629 ]): 630 if sample_weight is None: 631 y_pred, y_true = losses_utils.squeeze_or_expand_dimensions( 632 y_pred, y_true) 633 else: 634 sample_weight = math_ops.cast(sample_weight, dtype=variable_dtype) 635 y_pred, y_true, sample_weight = ( 636 losses_utils.squeeze_or_expand_dimensions( 637 y_pred, y_true, sample_weight=sample_weight)) 638 y_pred.shape.assert_is_compatible_with(y_true.shape) 639 640 if top_k is not None: 641 y_pred = _filter_top_k(y_pred, top_k) 642 if class_id is not None: 643 y_true = y_true[..., class_id] 644 y_pred = y_pred[..., class_id] 645 646 if thresholds_distributed_evenly and compat.forward_compatible(2021, 6, 8): 647 # The new approach will take effect after 2021/6/8, to give enough time 648 # for Brella release to pick up the new op tf.math.cumsum with float32. 649 return _update_confusion_matrix_variables_optimized( 650 variables_to_update, y_true, y_pred, thresholds, 651 multi_label=multi_label, sample_weights=sample_weight, 652 label_weights=label_weights, 653 thresholds_with_epsilon=thresholds_with_epsilon) 654 655 pred_shape = array_ops.shape(y_pred) 656 num_predictions = pred_shape[0] 657 if y_pred.shape.ndims == 1: 658 num_labels = 1 659 else: 660 num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0) 661 thresh_label_tile = array_ops.where_v2(one_thresh, num_labels, 662 array_ops.ones([], dtype=dtypes.int32)) 663 664 # Reshape predictions and labels, adding a dim for thresholding. 665 if multi_label: 666 predictions_extra_dim = array_ops.expand_dims(y_pred, 0) 667 labels_extra_dim = array_ops.expand_dims( 668 math_ops.cast(y_true, dtype=dtypes.bool), 0) 669 else: 670 # Flatten predictions and labels when not multilabel. 671 predictions_extra_dim = array_ops.reshape(y_pred, [1, -1]) 672 labels_extra_dim = array_ops.reshape( 673 math_ops.cast(y_true, dtype=dtypes.bool), [1, -1]) 674 675 # Tile the thresholds for every prediction. 676 if multi_label: 677 thresh_pretile_shape = [num_thresholds, 1, -1] 678 thresh_tiles = [1, num_predictions, thresh_label_tile] 679 data_tiles = [num_thresholds, 1, 1] 680 else: 681 thresh_pretile_shape = [num_thresholds, -1] 682 thresh_tiles = [1, num_predictions * num_labels] 683 data_tiles = [num_thresholds, 1] 684 685 thresh_tiled = array_ops.tile( 686 array_ops.reshape(thresholds, thresh_pretile_shape), 687 array_ops.stack(thresh_tiles)) 688 689 # Tile the predictions for every threshold. 690 preds_tiled = array_ops.tile(predictions_extra_dim, data_tiles) 691 692 # Compare predictions and threshold. 693 pred_is_pos = math_ops.greater(preds_tiled, thresh_tiled) 694 695 # Tile labels by number of thresholds 696 label_is_pos = array_ops.tile(labels_extra_dim, data_tiles) 697 698 if sample_weight is not None: 699 sample_weight = weights_broadcast_ops.broadcast_weights( 700 math_ops.cast(sample_weight, dtype=variable_dtype), y_pred) 701 weights_tiled = array_ops.tile( 702 array_ops.reshape(sample_weight, thresh_tiles), data_tiles) 703 else: 704 weights_tiled = None 705 706 if label_weights is not None and not multi_label: 707 label_weights = array_ops.expand_dims(label_weights, 0) 708 label_weights = weights_broadcast_ops.broadcast_weights(label_weights, 709 y_pred) 710 label_weights_tiled = array_ops.tile( 711 array_ops.reshape(label_weights, thresh_tiles), data_tiles) 712 if weights_tiled is None: 713 weights_tiled = label_weights_tiled 714 else: 715 weights_tiled = math_ops.multiply(weights_tiled, label_weights_tiled) 716 717 update_ops = [] 718 719 def weighted_assign_add(label, pred, weights, var): 720 label_and_pred = math_ops.cast( 721 math_ops.logical_and(label, pred), dtype=var.dtype) 722 if weights is not None: 723 label_and_pred *= math_ops.cast(weights, dtype=var.dtype) 724 return var.assign_add(math_ops.reduce_sum(label_and_pred, 1)) 725 726 loop_vars = { 727 ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos), 728 } 729 update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update 730 update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update 731 update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update 732 733 if update_fn or update_tn: 734 pred_is_neg = math_ops.logical_not(pred_is_pos) 735 loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg) 736 737 if update_fp or update_tn: 738 label_is_neg = math_ops.logical_not(label_is_pos) 739 loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos) 740 if update_tn: 741 loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg) 742 743 for matrix_cond, (label, pred) in loop_vars.items(): 744 745 if matrix_cond in variables_to_update: 746 update_ops.append( 747 weighted_assign_add(label, pred, weights_tiled, 748 variables_to_update[matrix_cond])) 749 750 return control_flow_ops.group(update_ops) 751 752 753def _filter_top_k(x, k): 754 """Filters top-k values in the last dim of x and set the rest to NEG_INF. 755 756 Used for computing top-k prediction values in dense labels (which has the same 757 shape as predictions) for recall and precision top-k metrics. 758 759 Args: 760 x: tensor with any dimensions. 761 k: the number of values to keep. 762 763 Returns: 764 tensor with same shape and dtype as x. 765 """ 766 _, top_k_idx = nn_ops.top_k(x, k, sorted=False) 767 top_k_mask = math_ops.reduce_sum( 768 array_ops.one_hot(top_k_idx, array_ops.shape(x)[-1], axis=-1), axis=-2) 769 return x * top_k_mask + NEG_INF * (1 - top_k_mask) 770 771 772def ragged_assert_compatible_and_get_flat_values(values, mask=None): 773 """If ragged, it checks the compatibility and then returns the flat_values. 774 775 Note: If two tensors are dense, it does not check their compatibility. 776 Note: Although two ragged tensors with different ragged ranks could have 777 identical overall rank and dimension sizes and hence be compatible, 778 we do not support those cases. 779 Args: 780 values: A list of potentially ragged tensor of the same ragged_rank. 781 mask: A potentially ragged tensor of the same ragged_rank as elements in 782 Values. 783 784 Returns: 785 A tuple in which the first element is the list of tensors and the second 786 is the mask tensor. ([Values], mask). Mask and the element in Values 787 are equal to the flat_values of the input arguments (if they were ragged). 788 """ 789 if isinstance(values, list): 790 is_all_ragged = \ 791 all(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values) 792 is_any_ragged = \ 793 any(isinstance(rt, ragged_tensor.RaggedTensor) for rt in values) 794 else: 795 is_all_ragged = isinstance(values, ragged_tensor.RaggedTensor) 796 is_any_ragged = is_all_ragged 797 if (is_all_ragged and 798 ((mask is None) or isinstance(mask, ragged_tensor.RaggedTensor))): 799 to_be_stripped = False 800 if not isinstance(values, list): 801 values = [values] 802 to_be_stripped = True 803 804 # NOTE: we leave the flat_values compatibility to 805 # tf.TensorShape `assert_is_compatible_with` 806 # check if both dynamic dimensions are equal and then use the flat_values. 807 nested_row_split_list = [rt.nested_row_splits for rt in values] 808 assertion_list = _assert_splits_match(nested_row_split_list) 809 810 # if both are ragged sample_weights also should be ragged with same dims. 811 if isinstance(mask, ragged_tensor.RaggedTensor): 812 assertion_list_for_mask = _assert_splits_match( 813 [nested_row_split_list[0], mask.nested_row_splits]) 814 with ops.control_dependencies(assertion_list_for_mask): 815 mask = array_ops.expand_dims(mask.flat_values, -1) 816 817 # values has at least 1 element. 818 flat_values = [] 819 for value in values: 820 with ops.control_dependencies(assertion_list): 821 flat_values.append(array_ops.expand_dims(value.flat_values, -1)) 822 823 values = flat_values[0] if to_be_stripped else flat_values 824 825 elif is_any_ragged: 826 raise TypeError('One of the inputs does not have acceptable types.') 827 # values are empty or value are not ragged and mask is ragged. 828 elif isinstance(mask, ragged_tensor.RaggedTensor): 829 raise TypeError('Ragged mask is not allowed with non-ragged inputs.') 830 831 return values, mask 832 833 834def _assert_splits_match(nested_splits_lists): 835 """Checks that the given splits lists are identical. 836 837 Performs static tests to ensure that the given splits lists are identical, 838 and returns a list of control dependency op tensors that check that they are 839 fully identical. 840 841 Args: 842 nested_splits_lists: A list of nested_splits_lists, where each split_list is 843 a list of `splits` tensors from a `RaggedTensor`, ordered from outermost 844 ragged dimension to innermost ragged dimension. 845 846 Returns: 847 A list of control dependency op tensors. 848 Raises: 849 ValueError: If the splits are not identical. 850 """ 851 error_msg = 'Inputs must have identical ragged splits' 852 for splits_list in nested_splits_lists: 853 if len(splits_list) != len(nested_splits_lists[0]): 854 raise ValueError(error_msg) 855 return [ 856 check_ops.assert_equal(s1, s2, message=error_msg) # pylint: disable=g-complex-comprehension 857 for splits_list in nested_splits_lists[1:] 858 for (s1, s2) in zip(nested_splits_lists[0], splits_list) 859 ] 860