1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# Licensed under the Apache License, Version 2.0 (the "License"); 3# you may not use this file except in compliance with the License. 4# You may obtain a copy of the License at 5# 6# http://www.apache.org/licenses/LICENSE-2.0 7# 8# Unless required by applicable law or agreed to in writing, software 9# distributed under the License is distributed on an "AS IS" BASIS, 10# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11# See the License for the specific language governing permissions and 12# limitations under the License. 13# ============================================================================== 14"""Implementation of tf.metrics module.""" 15 16from tensorflow.python.distribute import distribution_strategy_context 17from tensorflow.python.eager import context 18from tensorflow.python.framework import dtypes 19from tensorflow.python.framework import ops 20from tensorflow.python.framework import sparse_tensor 21from tensorflow.python.ops import array_ops 22from tensorflow.python.ops import check_ops 23from tensorflow.python.ops import confusion_matrix 24from tensorflow.python.ops import control_flow_ops 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops import nn 27from tensorflow.python.ops import sets 28from tensorflow.python.ops import sparse_ops 29from tensorflow.python.ops import state_ops 30from tensorflow.python.ops import variable_scope 31from tensorflow.python.ops import weights_broadcast_ops 32from tensorflow.python.platform import tf_logging as logging 33from tensorflow.python.util.deprecation import deprecated 34from tensorflow.python.util.tf_export import tf_export 35 36 37def metric_variable(shape, dtype, validate_shape=True, name=None): 38 """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections. 39 40 If running in a `DistributionStrategy` context, the variable will be 41 "sync on read". This means: 42 43 * The returned object will be a container with separate variables 44 per replica of the model. 45 46 * When writing to the variable, e.g. using `assign_add` in a metric 47 update, the update will be applied to the variable local to the 48 replica. 49 50 * To get a metric's result value, we need to sum the variable values 51 across the replicas before computing the final answer. Furthermore, 52 the final answer should be computed once instead of in every 53 replica. Both of these are accomplished by running the computation 54 of the final result value inside 55 `distribution_strategy_context.get_replica_context().merge_call(fn)`. 56 Inside the `merge_call()`, ops are only added to the graph once 57 and access to a sync on read variable in a computation returns 58 the sum across all replicas. 59 60 Args: 61 shape: Shape of the created variable. 62 dtype: Type of the created variable. 63 validate_shape: (Optional) Whether shape validation is enabled for 64 the created variable. 65 name: (Optional) String name of the created variable. 66 67 Returns: 68 A (non-trainable) variable initialized to zero, or if inside a 69 `DistributionStrategy` scope a sync on read variable container. 70 """ 71 # Note that synchronization "ON_READ" implies trainable=False. 72 return variable_scope.variable( 73 lambda: array_ops.zeros(shape, dtype), 74 trainable=False, 75 collections=[ 76 ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES 77 ], 78 validate_shape=validate_shape, 79 synchronization=variable_scope.VariableSynchronization.ON_READ, 80 aggregation=variable_scope.VariableAggregation.SUM, 81 name=name) 82 83 84def _remove_squeezable_dimensions(predictions, labels, weights): 85 """Squeeze or expand last dim if needed. 86 87 Squeezes last dim of `predictions` or `labels` if their rank differs by 1 88 (using confusion_matrix.remove_squeezable_dimensions). 89 Squeezes or expands last dim of `weights` if its rank differs by 1 from the 90 new rank of `predictions`. 91 92 If `weights` is scalar, it is kept scalar. 93 94 This will use static shape if available. Otherwise, it will add graph 95 operations, which could result in a performance hit. 96 97 Args: 98 predictions: Predicted values, a `Tensor` of arbitrary dimensions. 99 labels: Optional label `Tensor` whose dimensions match `predictions`. 100 weights: Optional weight scalar or `Tensor` whose dimensions match 101 `predictions`. 102 103 Returns: 104 Tuple of `predictions`, `labels` and `weights`. Each of them possibly has 105 the last dimension squeezed, `weights` could be extended by one dimension. 106 """ 107 predictions = ops.convert_to_tensor(predictions) 108 if labels is not None: 109 labels, predictions = confusion_matrix.remove_squeezable_dimensions( 110 labels, predictions) 111 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 112 113 if weights is None: 114 return predictions, labels, None 115 116 weights = ops.convert_to_tensor(weights) 117 weights_shape = weights.get_shape() 118 weights_rank = weights_shape.ndims 119 if weights_rank == 0: 120 return predictions, labels, weights 121 122 predictions_shape = predictions.get_shape() 123 predictions_rank = predictions_shape.ndims 124 if (predictions_rank is not None) and (weights_rank is not None): 125 # Use static rank. 126 if weights_rank - predictions_rank == 1: 127 weights = array_ops.squeeze(weights, [-1]) 128 elif predictions_rank - weights_rank == 1: 129 weights = array_ops.expand_dims(weights, [-1]) 130 else: 131 # Use dynamic rank. 132 weights_rank_tensor = array_ops.rank(weights) 133 rank_diff = weights_rank_tensor - array_ops.rank(predictions) 134 135 def _maybe_expand_weights(): 136 return control_flow_ops.cond( 137 math_ops.equal(rank_diff, -1), 138 lambda: array_ops.expand_dims(weights, [-1]), lambda: weights) 139 140 # Don't attempt squeeze if it will fail based on static check. 141 if ((weights_rank is not None) and 142 (not weights_shape.dims[-1].is_compatible_with(1))): 143 maybe_squeeze_weights = lambda: weights 144 else: 145 maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1]) 146 147 def _maybe_adjust_weights(): 148 return control_flow_ops.cond( 149 math_ops.equal(rank_diff, 1), maybe_squeeze_weights, 150 _maybe_expand_weights) 151 152 # If weights are scalar, do nothing. Otherwise, try to add or remove a 153 # dimension to match predictions. 154 weights = control_flow_ops.cond( 155 math_ops.equal(weights_rank_tensor, 0), lambda: weights, 156 _maybe_adjust_weights) 157 return predictions, labels, weights 158 159 160def _maybe_expand_labels(labels, predictions): 161 """If necessary, expand `labels` along last dimension to match `predictions`. 162 163 Args: 164 labels: `Tensor` or `SparseTensor` with shape 165 [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies 166 num_labels=1, in which case the result is an expanded `labels` with shape 167 [D1, ... DN, 1]. 168 predictions: `Tensor` with shape [D1, ... DN, num_classes]. 169 170 Returns: 171 `labels` with the same rank as `predictions`. 172 173 Raises: 174 ValueError: if `labels` has invalid shape. 175 """ 176 with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope: 177 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 178 179 # If sparse, expand sparse shape. 180 if isinstance(labels, sparse_tensor.SparseTensor): 181 return control_flow_ops.cond( 182 math_ops.equal( 183 array_ops.rank(predictions), 184 array_ops.size(labels.dense_shape) + 1), 185 lambda: sparse_ops.sparse_reshape( # pylint: disable=g-long-lambda 186 labels, 187 shape=array_ops.concat((labels.dense_shape, (1,)), 0), 188 name=scope), 189 lambda: labels) 190 191 # Otherwise, try to use static shape. 192 labels_rank = labels.get_shape().ndims 193 if labels_rank is not None: 194 predictions_rank = predictions.get_shape().ndims 195 if predictions_rank is not None: 196 if predictions_rank == labels_rank: 197 return labels 198 if predictions_rank == labels_rank + 1: 199 return array_ops.expand_dims(labels, -1, name=scope) 200 raise ValueError( 201 f'Unexpected labels shape {labels.get_shape()} for predictions ' 202 f'shape {predictions.get_shape()}. Predictions rank should be the ' 203 'same rank as labels rank or labels rank plus one .') 204 205 # Otherwise, use dynamic shape. 206 return control_flow_ops.cond( 207 math_ops.equal(array_ops.rank(predictions), 208 array_ops.rank(labels) + 1), 209 lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels) 210 211 212def _safe_scalar_div(numerator, denominator, name): 213 """Divides two values, returning 0 if the denominator is 0. 214 215 Args: 216 numerator: A scalar `float64` `Tensor`. 217 denominator: A scalar `float64` `Tensor`. 218 name: Name for the returned op. 219 220 Returns: 221 0 if `denominator` == 0, else `numerator` / `denominator` 222 """ 223 numerator.get_shape().with_rank_at_most(1) 224 denominator.get_shape().with_rank_at_most(1) 225 return math_ops.div_no_nan(numerator, denominator, name=name) 226 227 228def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None): 229 """Calculate a streaming confusion matrix. 230 231 Calculates a confusion matrix. For estimation over a stream of data, 232 the function creates an `update_op` operation. 233 234 Args: 235 labels: A `Tensor` of ground truth labels with shape [batch size] and of 236 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 237 predictions: A `Tensor` of prediction results for semantic labels, whose 238 shape is [batch size] and type `int32` or `int64`. The tensor will be 239 flattened if its rank > 1. 240 num_classes: The possible number of labels the prediction task can 241 have. This value must be provided, since a confusion matrix of 242 dimension = [num_classes, num_classes] will be allocated. 243 weights: Optional `Tensor` whose rank is either 0, or the same rank as 244 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 245 be either `1`, or the same as the corresponding `labels` dimension). 246 247 Returns: 248 total_cm: A `Tensor` representing the confusion matrix. 249 update_op: An operation that increments the confusion matrix. 250 """ 251 # Local variable to accumulate the predictions in the confusion matrix. 252 total_cm = metric_variable( 253 [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix') 254 255 # Cast the type to int64 required by confusion_matrix_ops. 256 predictions = math_ops.cast(predictions, dtypes.int64) 257 labels = math_ops.cast(labels, dtypes.int64) 258 num_classes = math_ops.cast(num_classes, dtypes.int64) 259 260 # Flatten the input if its rank > 1. 261 if predictions.get_shape().ndims > 1: 262 predictions = array_ops.reshape(predictions, [-1]) 263 264 if labels.get_shape().ndims > 1: 265 labels = array_ops.reshape(labels, [-1]) 266 267 if (weights is not None) and (weights.get_shape().ndims > 1): 268 weights = array_ops.reshape(weights, [-1]) 269 270 # Accumulate the prediction to current confusion matrix. 271 current_cm = confusion_matrix.confusion_matrix( 272 labels, predictions, num_classes, weights=weights, dtype=dtypes.float64) 273 update_op = state_ops.assign_add(total_cm, current_cm) 274 return total_cm, update_op 275 276 277def _aggregate_across_replicas(metrics_collections, metric_value_fn, *args): 278 """Aggregate metric value across replicas.""" 279 def fn(distribution, *a): 280 """Call `metric_value_fn` in the correct control flow context.""" 281 if hasattr(distribution.extended, '_outer_control_flow_context'): 282 # If there was an outer context captured before this method was called, 283 # then we enter that context to create the metric value op. If the 284 # captured context is `None`, ops.control_dependencies(None) gives the 285 # desired behavior. Else we use `Enter` and `Exit` to enter and exit the 286 # captured context. 287 # This special handling is needed because sometimes the metric is created 288 # inside a while_loop (and perhaps a TPU rewrite context). But we don't 289 # want the value op to be evaluated every step or on the TPU. So we 290 # create it outside so that it can be evaluated at the end on the host, 291 # once the update ops have been evaluated. 292 293 # pylint: disable=protected-access 294 if distribution.extended._outer_control_flow_context is None: 295 with ops.control_dependencies(None): 296 metric_value = metric_value_fn(distribution, *a) 297 else: 298 distribution.extended._outer_control_flow_context.Enter() 299 metric_value = metric_value_fn(distribution, *a) 300 distribution.extended._outer_control_flow_context.Exit() 301 # pylint: enable=protected-access 302 else: 303 metric_value = metric_value_fn(distribution, *a) 304 if metrics_collections: 305 ops.add_to_collections(metrics_collections, metric_value) 306 return metric_value 307 308 return distribution_strategy_context.get_replica_context().merge_call( 309 fn, args=args) 310 311 312@tf_export(v1=['metrics.mean']) 313def mean(values, 314 weights=None, 315 metrics_collections=None, 316 updates_collections=None, 317 name=None): 318 """Computes the (weighted) mean of the given values. 319 320 The `mean` function creates two local variables, `total` and `count` 321 that are used to compute the average of `values`. This average is ultimately 322 returned as `mean` which is an idempotent operation that simply divides 323 `total` by `count`. 324 325 For estimation of the metric over a stream of data, the function creates an 326 `update_op` operation that updates these variables and returns the `mean`. 327 `update_op` increments `total` with the reduced sum of the product of `values` 328 and `weights`, and it increments `count` with the reduced sum of `weights`. 329 330 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 331 332 Args: 333 values: A `Tensor` of arbitrary dimensions. 334 weights: Optional `Tensor` whose rank is either 0, or the same rank as 335 `values`, and must be broadcastable to `values` (i.e., all dimensions must 336 be either `1`, or the same as the corresponding `values` dimension). 337 metrics_collections: An optional list of collections that `mean` 338 should be added to. 339 updates_collections: An optional list of collections that `update_op` 340 should be added to. 341 name: An optional variable_scope name. 342 343 Returns: 344 mean: A `Tensor` representing the current mean, the value of `total` divided 345 by `count`. 346 update_op: An operation that increments the `total` and `count` variables 347 appropriately and whose value matches `mean_value`. 348 349 Raises: 350 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 351 or if either `metrics_collections` or `updates_collections` are not a list 352 or tuple. 353 RuntimeError: If eager execution is enabled. 354 355 @compatibility(TF2) 356 `tf.compat.v1.metrics.mean` is not compatible with eager 357 execution or `tf.function`. 358 Please use `tf.keras.metrics.Mean` instead for TF2 migration. After 359 instantiating a `tf.keras.metrics.Mean` object, you can first call the 360 `update_state()` method to record the new values, and then call the 361 `result()` method to get the mean eagerly. You can also attach it to a 362 Keras model with the `add_metric` method. Please refer to the [migration 363 guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses) 364 for more details. 365 366 #### Structural Mapping to TF2 367 368 Before: 369 370 ```python 371 mean, update_op = tf.compat.v1.metrics.mean( 372 values=values, 373 weights=weights, 374 metrics_collections=metrics_collections, 375 update_collections=update_collections, 376 name=name) 377 ``` 378 379 After: 380 381 ```python 382 m = tf.keras.metrics.Mean( 383 name=name) 384 385 m.update_state( 386 values=values, 387 sample_weight=weights) 388 389 mean = m.result() 390 ``` 391 392 #### How to Map Arguments 393 394 | TF1 Arg Name | TF2 Arg Name | Note | 395 | :-------------------- | :-------------- | :------------------------- | 396 | `values` | `values` | In `update_state()` method | 397 | `weights` | `sample_weight` | In `update_state()` method | 398 | `metrics_collections` | Not supported | Metrics should be tracked | 399 : : : explicitly or with Keras : 400 : : : APIs, for example, : 401 : : : [add_metric][add_metric], : 402 : : : instead of via collections : 403 | `updates_collections` | Not supported | - | 404 | `name` | `name` | In constructor | 405 406 [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric 407 408 409 #### Before & After Usage Example 410 411 Before: 412 413 >>> g = tf.Graph() 414 >>> with g.as_default(): 415 ... values = [1, 2, 3] 416 ... mean, update_op = tf.compat.v1.metrics.mean(values) 417 ... global_init = tf.compat.v1.global_variables_initializer() 418 ... local_init = tf.compat.v1.local_variables_initializer() 419 >>> sess = tf.compat.v1.Session(graph=g) 420 >>> sess.run([global_init, local_init]) 421 >>> sess.run(update_op) 422 >>> sess.run(mean) 423 2.0 424 425 426 After: 427 428 >>> m = tf.keras.metrics.Mean() 429 >>> m.update_state([1, 2, 3]) 430 >>> m.result().numpy() 431 2.0 432 433 ```python 434 # Used within Keras model 435 model.add_metric(tf.keras.metrics.Mean()(values)) 436 ``` 437 438 @end_compatibility 439 """ 440 if context.executing_eagerly(): 441 raise RuntimeError('tf.metrics.mean is not supported when eager execution ' 442 'is enabled.') 443 444 with variable_scope.variable_scope(name, 'mean', (values, weights)): 445 values = math_ops.cast(values, dtypes.float32) 446 447 total = metric_variable([], dtypes.float32, name='total') 448 count = metric_variable([], dtypes.float32, name='count') 449 450 if weights is None: 451 num_values = math_ops.cast(array_ops.size(values), dtypes.float32) 452 else: 453 values, _, weights = _remove_squeezable_dimensions( 454 predictions=values, labels=None, weights=weights) 455 weights = weights_broadcast_ops.broadcast_weights( 456 math_ops.cast(weights, dtypes.float32), values) 457 values = math_ops.multiply(values, weights) 458 num_values = math_ops.reduce_sum(weights) 459 460 update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values)) 461 with ops.control_dependencies([values]): 462 update_count_op = state_ops.assign_add(count, num_values) 463 464 def compute_mean(_, t, c): 465 return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value') 466 467 mean_t = _aggregate_across_replicas( 468 metrics_collections, compute_mean, total, count) 469 update_op = math_ops.div_no_nan( 470 update_total_op, math_ops.maximum(update_count_op, 0), name='update_op') 471 472 if updates_collections: 473 ops.add_to_collections(updates_collections, update_op) 474 475 return mean_t, update_op 476 477 478@tf_export(v1=['metrics.accuracy']) 479def accuracy(labels, 480 predictions, 481 weights=None, 482 metrics_collections=None, 483 updates_collections=None, 484 name=None): 485 """Calculates how often `predictions` matches `labels`. 486 487 The `accuracy` function creates two local variables, `total` and 488 `count` that are used to compute the frequency with which `predictions` 489 matches `labels`. This frequency is ultimately returned as `accuracy`: an 490 idempotent operation that simply divides `total` by `count`. 491 492 For estimation of the metric over a stream of data, the function creates an 493 `update_op` operation that updates these variables and returns the `accuracy`. 494 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 495 where the corresponding elements of `predictions` and `labels` match and 0.0 496 otherwise. Then `update_op` increments `total` with the reduced sum of the 497 product of `weights` and `is_correct`, and it increments `count` with the 498 reduced sum of `weights`. 499 500 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 501 502 Args: 503 labels: The ground truth values, a `Tensor` whose shape matches 504 `predictions`. 505 predictions: The predicted values, a `Tensor` of any shape. 506 weights: Optional `Tensor` whose rank is either 0, or the same rank as 507 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 508 be either `1`, or the same as the corresponding `labels` dimension). 509 metrics_collections: An optional list of collections that `accuracy` should 510 be added to. 511 updates_collections: An optional list of collections that `update_op` should 512 be added to. 513 name: An optional variable_scope name. 514 515 Returns: 516 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 517 by `count`. 518 update_op: An operation that increments the `total` and `count` variables 519 appropriately and whose value matches `accuracy`. 520 521 Raises: 522 ValueError: If `predictions` and `labels` have mismatched shapes, or if 523 `weights` is not `None` and its shape doesn't match `predictions`, or if 524 either `metrics_collections` or `updates_collections` are not a list or 525 tuple. 526 RuntimeError: If eager execution is enabled. 527 528 @compatibility(TF2) 529 `tf.compat.v1.metrics.accuracy` is not compatible with eager 530 execution or `tf.function`. 531 Please use `tf.keras.metrics.Accuracy` instead for TF2 migration. After 532 instantiating a `tf.keras.metrics.Accuracy` object, you can first call the 533 `update_state()` method to record the prediction/labels, and then call the 534 `result()` method to get the accuracy eagerly. You can also attach it to a 535 Keras model when calling the `compile` method. Please refer to [this 536 guide](https://www.tensorflow.org/guide/migrate#new-style_metrics_and_losses) 537 for more details. 538 539 #### Structural Mapping to Native TF2 540 541 Before: 542 543 ```python 544 accuracy, update_op = tf.compat.v1.metrics.accuracy( 545 labels=labels, 546 predictions=predictions, 547 weights=weights, 548 metrics_collections=metrics_collections, 549 update_collections=update_collections, 550 name=name) 551 ``` 552 553 After: 554 555 ```python 556 m = tf.keras.metrics.Accuracy( 557 name=name, 558 dtype=None) 559 560 m.update_state( 561 y_true=labels, 562 y_pred=predictions, 563 sample_weight=weights) 564 565 accuracy = m.result() 566 ``` 567 568 #### How to Map Arguments 569 570 | TF1 Arg Name | TF2 Arg Name | Note | 571 | :-------------------- | :-------------- | :------------------------- | 572 | `label` | `y_true` | In `update_state()` method | 573 | `predictions` | `y_true` | In `update_state()` method | 574 | `weights` | `sample_weight` | In `update_state()` method | 575 | `metrics_collections` | Not supported | Metrics should be tracked | 576 : : : explicitly or with Keras : 577 : : : APIs, for example, : 578 : : : [add_metric][add_metric], : 579 : : : instead of via collections : 580 | `updates_collections` | Not supported | - | 581 | `name` | `name` | In constructor | 582 583 [add_metric]:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#add_metric 584 585 586 #### Before & After Usage Example 587 588 Before: 589 590 >>> g = tf.Graph() 591 >>> with g.as_default(): 592 ... logits = [1, 2, 3] 593 ... labels = [0, 2, 3] 594 ... acc, acc_op = tf.compat.v1.metrics.accuracy(logits, labels) 595 ... global_init = tf.compat.v1.global_variables_initializer() 596 ... local_init = tf.compat.v1.local_variables_initializer() 597 >>> sess = tf.compat.v1.Session(graph=g) 598 >>> sess.run([global_init, local_init]) 599 >>> print(sess.run([acc, acc_op])) 600 [0.0, 0.66667] 601 602 603 After: 604 605 >>> m = tf.keras.metrics.Accuracy() 606 >>> m.update_state([1, 2, 3], [0, 2, 3]) 607 >>> m.result().numpy() 608 0.66667 609 610 ```python 611 # Used within Keras model 612 model.compile(optimizer='sgd', 613 loss='mse', 614 metrics=[tf.keras.metrics.Accuracy()]) 615 ``` 616 617 @end_compatibility 618 """ 619 if context.executing_eagerly(): 620 raise RuntimeError('tf.metrics.accuracy is not supported when eager ' 621 'execution is enabled.') 622 623 predictions, labels, weights = _remove_squeezable_dimensions( 624 predictions=predictions, labels=labels, weights=weights) 625 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 626 if labels.dtype != predictions.dtype: 627 predictions = math_ops.cast(predictions, labels.dtype) 628 is_correct = math_ops.cast( 629 math_ops.equal(predictions, labels), dtypes.float32) 630 return mean(is_correct, weights, metrics_collections, updates_collections, 631 name or 'accuracy') 632 633 634def _confusion_matrix_at_thresholds(labels, 635 predictions, 636 thresholds, 637 weights=None, 638 includes=None): 639 """Computes true_positives, false_negatives, true_negatives, false_positives. 640 641 This function creates up to four local variables, `true_positives`, 642 `true_negatives`, `false_positives` and `false_negatives`. 643 `true_positive[i]` is defined as the total weight of values in `predictions` 644 above `thresholds[i]` whose corresponding entry in `labels` is `True`. 645 `false_negatives[i]` is defined as the total weight of values in `predictions` 646 at most `thresholds[i]` whose corresponding entry in `labels` is `True`. 647 `true_negatives[i]` is defined as the total weight of values in `predictions` 648 at most `thresholds[i]` whose corresponding entry in `labels` is `False`. 649 `false_positives[i]` is defined as the total weight of values in `predictions` 650 above `thresholds[i]` whose corresponding entry in `labels` is `False`. 651 652 For estimation of these metrics over a stream of data, for each metric the 653 function respectively creates an `update_op` operation that updates the 654 variable and returns its value. 655 656 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 657 658 Args: 659 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 660 `bool`. 661 predictions: A floating point `Tensor` of arbitrary shape and whose values 662 are in the range `[0, 1]`. 663 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 664 weights: Optional `Tensor` whose rank is either 0, or the same rank as 665 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 666 be either `1`, or the same as the corresponding `labels` dimension). 667 includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`, 668 default to all four. 669 670 Returns: 671 values: Dict of variables of shape `[len(thresholds)]`. Keys are from 672 `includes`. 673 update_ops: Dict of operations that increments the `values`. Keys are from 674 `includes`. 675 676 Raises: 677 ValueError: If `predictions` and `labels` have mismatched shapes, or if 678 `weights` is not `None` and its shape doesn't match `predictions`, or if 679 `includes` contains invalid keys. 680 """ 681 all_includes = ('tp', 'fn', 'tn', 'fp') 682 if includes is None: 683 includes = all_includes 684 else: 685 for include in includes: 686 if include not in all_includes: 687 raise ValueError(f'Invalid key: {include}') 688 689 with ops.control_dependencies([ 690 check_ops.assert_greater_equal( 691 predictions, 692 math_ops.cast(0.0, dtype=predictions.dtype), 693 message='predictions must be in [0, 1]'), 694 check_ops.assert_less_equal( 695 predictions, 696 math_ops.cast(1.0, dtype=predictions.dtype), 697 message='predictions must be in [0, 1]') 698 ]): 699 predictions, labels, weights = _remove_squeezable_dimensions( 700 predictions=math_ops.cast(predictions, dtypes.float32), 701 labels=math_ops.cast(labels, dtype=dtypes.bool), 702 weights=weights) 703 704 num_thresholds = len(thresholds) 705 706 # Reshape predictions and labels. 707 predictions_2d = array_ops.reshape(predictions, [-1, 1]) 708 labels_2d = array_ops.reshape( 709 math_ops.cast(labels, dtype=dtypes.bool), [1, -1]) 710 711 # Use static shape if known. 712 num_predictions = predictions_2d.get_shape().as_list()[0] 713 714 # Otherwise use dynamic shape. 715 if num_predictions is None: 716 num_predictions = array_ops.shape(predictions_2d)[0] 717 thresh_tiled = array_ops.tile( 718 array_ops.expand_dims(array_ops.constant(thresholds), [1]), 719 array_ops.stack([1, num_predictions])) 720 721 # Tile the predictions after thresholding them across different thresholds. 722 pred_is_pos = math_ops.greater( 723 array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]), 724 thresh_tiled) 725 if ('fn' in includes) or ('tn' in includes): 726 pred_is_neg = math_ops.logical_not(pred_is_pos) 727 728 # Tile labels by number of thresholds 729 label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1]) 730 if ('fp' in includes) or ('tn' in includes): 731 label_is_neg = math_ops.logical_not(label_is_pos) 732 733 if weights is not None: 734 weights = weights_broadcast_ops.broadcast_weights( 735 math_ops.cast(weights, dtypes.float32), predictions) 736 weights_tiled = array_ops.tile( 737 array_ops.reshape(weights, [1, -1]), [num_thresholds, 1]) 738 thresh_tiled.get_shape().assert_is_compatible_with( 739 weights_tiled.get_shape()) 740 else: 741 weights_tiled = None 742 743 values = {} 744 update_ops = {} 745 746 if 'tp' in includes: 747 true_p = metric_variable( 748 [num_thresholds], dtypes.float32, name='true_positives') 749 is_true_positive = math_ops.cast( 750 math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32) 751 if weights_tiled is not None: 752 is_true_positive *= weights_tiled 753 update_ops['tp'] = state_ops.assign_add(true_p, 754 math_ops.reduce_sum( 755 is_true_positive, 1)) 756 values['tp'] = true_p 757 758 if 'fn' in includes: 759 false_n = metric_variable( 760 [num_thresholds], dtypes.float32, name='false_negatives') 761 is_false_negative = math_ops.cast( 762 math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32) 763 if weights_tiled is not None: 764 is_false_negative *= weights_tiled 765 update_ops['fn'] = state_ops.assign_add(false_n, 766 math_ops.reduce_sum( 767 is_false_negative, 1)) 768 values['fn'] = false_n 769 770 if 'tn' in includes: 771 true_n = metric_variable( 772 [num_thresholds], dtypes.float32, name='true_negatives') 773 is_true_negative = math_ops.cast( 774 math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32) 775 if weights_tiled is not None: 776 is_true_negative *= weights_tiled 777 update_ops['tn'] = state_ops.assign_add(true_n, 778 math_ops.reduce_sum( 779 is_true_negative, 1)) 780 values['tn'] = true_n 781 782 if 'fp' in includes: 783 false_p = metric_variable( 784 [num_thresholds], dtypes.float32, name='false_positives') 785 is_false_positive = math_ops.cast( 786 math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32) 787 if weights_tiled is not None: 788 is_false_positive *= weights_tiled 789 update_ops['fp'] = state_ops.assign_add(false_p, 790 math_ops.reduce_sum( 791 is_false_positive, 1)) 792 values['fp'] = false_p 793 794 return values, update_ops 795 796 797def _aggregate_variable(v, collections): 798 f = lambda distribution, value: distribution.extended.read_var(value) 799 return _aggregate_across_replicas(collections, f, v) 800 801 802@tf_export(v1=['metrics.auc']) 803@deprecated(None, 804 'The value of AUC returned by this may race with the update so ' 805 'this is deprecated. Please use tf.keras.metrics.AUC instead.') 806def auc(labels, 807 predictions, 808 weights=None, 809 num_thresholds=200, 810 metrics_collections=None, 811 updates_collections=None, 812 curve='ROC', 813 name=None, 814 summation_method='trapezoidal', 815 thresholds=None): 816 """Computes the approximate AUC via a Riemann sum. 817 818 The `auc` function creates four local variables, `true_positives`, 819 `true_negatives`, `false_positives` and `false_negatives` that are used to 820 compute the AUC. To discretize the AUC curve, a linearly spaced set of 821 thresholds is used to compute pairs of recall and precision values. The area 822 under the ROC-curve is therefore computed using the height of the recall 823 values by the false positive rate, while the area under the PR-curve is the 824 computed using the height of the precision values by the recall. 825 826 This value is ultimately returned as `auc`, an idempotent operation that 827 computes the area under a discretized curve of precision versus recall values 828 (computed using the aforementioned variables). The `num_thresholds` variable 829 controls the degree of discretization with larger numbers of thresholds more 830 closely approximating the true AUC. The quality of the approximation may vary 831 dramatically depending on `num_thresholds`. 832 833 For best results, `predictions` should be distributed approximately uniformly 834 in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC 835 approximation may be poor if this is not the case. Setting `summation_method` 836 to 'minoring' or 'majoring' can help quantify the error in the approximation 837 by providing lower or upper bound estimate of the AUC. The `thresholds` 838 parameter can be used to manually specify thresholds which split the 839 predictions more evenly. 840 841 For estimation of the metric over a stream of data, the function creates an 842 `update_op` operation that updates these variables and returns the `auc`. 843 844 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 845 846 Args: 847 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 848 `bool`. 849 predictions: A floating point `Tensor` of arbitrary shape and whose values 850 are in the range `[0, 1]`. 851 weights: Optional `Tensor` whose rank is either 0, or the same rank as 852 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 853 be either `1`, or the same as the corresponding `labels` dimension). 854 num_thresholds: The number of thresholds to use when discretizing the roc 855 curve. 856 metrics_collections: An optional list of collections that `auc` should be 857 added to. 858 updates_collections: An optional list of collections that `update_op` should 859 be added to. 860 curve: Specifies the name of the curve to be computed, 'ROC' [default] or 861 'PR' for the Precision-Recall-curve. 862 name: An optional variable_scope name. 863 summation_method: Specifies the Riemann summation method used 864 (https://en.wikipedia.org/wiki/Riemann_sum): 'trapezoidal' [default] that 865 applies the trapezoidal rule; 'careful_interpolation', a variant of it 866 differing only by a more correct interpolation scheme for PR-AUC - 867 interpolating (true/false) positives but not the ratio that is precision; 868 'minoring' that applies left summation for increasing intervals and right 869 summation for decreasing intervals; 'majoring' that does the opposite. 870 Note that 'careful_interpolation' is strictly preferred to 'trapezoidal' 871 (to be deprecated soon) as it applies the same method for ROC, and a 872 better one (see Davis & Goadrich 2006 for details) for the PR curve. 873 thresholds: An optional list of floating point values to use as the 874 thresholds for discretizing the curve. If set, the `num_thresholds` 875 parameter is ignored. Values should be in [0, 1]. Endpoint thresholds 876 equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be 877 automatically included with these to correctly handle predictions equal to 878 exactly 0 or 1. 879 880 Returns: 881 auc: A scalar `Tensor` representing the current area-under-curve. 882 update_op: An operation that increments the `true_positives`, 883 `true_negatives`, `false_positives` and `false_negatives` variables 884 appropriately and whose value matches `auc`. 885 886 Raises: 887 ValueError: If `predictions` and `labels` have mismatched shapes, or if 888 `weights` is not `None` and its shape doesn't match `predictions`, or if 889 either `metrics_collections` or `updates_collections` are not a list or 890 tuple. 891 RuntimeError: If eager execution is enabled. 892 """ 893 if context.executing_eagerly(): 894 raise RuntimeError('tf.metrics.auc is not supported when eager execution ' 895 'is enabled.') 896 897 with variable_scope.variable_scope(name, 'auc', 898 (labels, predictions, weights)): 899 if curve != 'ROC' and curve != 'PR': 900 raise ValueError(f'Curve must be either ROC or PR. Curve {curve} is ' 901 'unknown.') 902 903 kepsilon = 1e-7 # To account for floating point imprecisions. 904 if thresholds is not None: 905 # If specified, use the supplied thresholds. 906 thresholds = sorted(thresholds) 907 num_thresholds = len(thresholds) + 2 908 else: 909 # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in 910 # (0, 1). 911 thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) 912 for i in range(num_thresholds - 2)] 913 914 # Add an endpoint "threshold" below zero and above one for either threshold 915 # method. 916 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 917 918 values, update_ops = _confusion_matrix_at_thresholds( 919 labels, predictions, thresholds, weights) 920 921 # Add epsilons to avoid dividing by 0. 922 epsilon = 1.0e-6 923 924 def interpolate_pr_auc(tp, fp, fn): 925 """Interpolation formula inspired by section 4 of (Davis et al., 2006). 926 927 Note here we derive & use a closed formula not present in the paper 928 - as follows: 929 Modeling all of TP (true positive weight), 930 FP (false positive weight) and their sum P = TP + FP (positive weight) 931 as varying linearly within each interval [A, B] between successive 932 thresholds, we get 933 Precision = (TP_A + slope * (P - P_A)) / P 934 with slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A). 935 The area within the interval is thus (slope / total_pos_weight) times 936 int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P} 937 int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P} 938 where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in 939 int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A) 940 Bringing back the factor (slope / total_pos_weight) we'd put aside, we get 941 slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight 942 where dTP == TP_B - TP_A. 943 Note that when P_A == 0 the above calculation simplifies into 944 int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A) 945 which is really equivalent to imputing constant precision throughout the 946 first bucket having >0 true positives. 947 948 Args: 949 tp: true positive counts 950 fp: false positive counts 951 fn: false negative counts 952 953 Returns: 954 pr_auc: an approximation of the area under the P-R curve. 955 956 References: 957 The Relationship Between Precision-Recall and ROC Curves: 958 [Davis et al., 2006](https://dl.acm.org/citation.cfm?id=1143874) 959 ([pdf](https://www.biostat.wisc.edu/~page/rocpr.pdf)) 960 """ 961 dtp = tp[:num_thresholds - 1] - tp[1:] 962 p = tp + fp 963 prec_slope = math_ops.div_no_nan( 964 dtp, 965 math_ops.maximum(p[:num_thresholds - 1] - p[1:], 0), 966 name='prec_slope') 967 intercept = tp[1:] - math_ops.multiply(prec_slope, p[1:]) 968 safe_p_ratio = array_ops.where( 969 math_ops.logical_and(p[:num_thresholds - 1] > 0, p[1:] > 0), 970 math_ops.div_no_nan( 971 p[:num_thresholds - 1], 972 math_ops.maximum(p[1:], 0), 973 name='recall_relative_ratio'), array_ops.ones_like(p[1:])) 974 return math_ops.reduce_sum( 975 math_ops.div_no_nan( 976 prec_slope * (dtp + intercept * math_ops.log(safe_p_ratio)), 977 math_ops.maximum(tp[1:] + fn[1:], 0), 978 name='pr_auc_increment'), 979 name='interpolate_pr_auc') 980 981 def compute_auc(tp, fn, tn, fp, name): 982 """Computes the roc-auc or pr-auc based on confusion counts.""" 983 if curve == 'PR': 984 if summation_method == 'trapezoidal': 985 logging.warning( 986 'Trapezoidal rule is known to produce incorrect PR-AUCs; ' 987 'please switch to "careful_interpolation" instead.') 988 elif summation_method == 'careful_interpolation': 989 # This one is a bit tricky and is handled separately. 990 return interpolate_pr_auc(tp, fp, fn) 991 rec = math_ops.divide(tp + epsilon, tp + fn + epsilon) 992 if curve == 'ROC': 993 fp_rate = math_ops.divide(fp, fp + tn + epsilon) 994 x = fp_rate 995 y = rec 996 else: # curve == 'PR'. 997 prec = math_ops.divide(tp + epsilon, tp + fp + epsilon) 998 x = rec 999 y = prec 1000 if summation_method in ('trapezoidal', 'careful_interpolation'): 1001 # Note that the case ('PR', 'careful_interpolation') has been handled 1002 # above. 1003 return math_ops.reduce_sum( 1004 math_ops.multiply(x[:num_thresholds - 1] - x[1:], 1005 (y[:num_thresholds - 1] + y[1:]) / 2.), 1006 name=name) 1007 elif summation_method == 'minoring': 1008 return math_ops.reduce_sum( 1009 math_ops.multiply(x[:num_thresholds - 1] - x[1:], 1010 math_ops.minimum(y[:num_thresholds - 1], y[1:])), 1011 name=name) 1012 elif summation_method == 'majoring': 1013 return math_ops.reduce_sum( 1014 math_ops.multiply(x[:num_thresholds - 1] - x[1:], 1015 math_ops.maximum(y[:num_thresholds - 1], y[1:])), 1016 name=name) 1017 else: 1018 raise ValueError(f'Invalid summation_method: {summation_method} ' 1019 'summation_method should be \'trapezoidal\', ' 1020 '\'careful_interpolation\', \'minoring\', or ' 1021 '\'majoring\'.') 1022 1023 # sum up the areas of all the trapeziums 1024 def compute_auc_value(_, values): 1025 return compute_auc(values['tp'], values['fn'], values['tn'], values['fp'], 1026 'value') 1027 1028 auc_value = _aggregate_across_replicas( 1029 metrics_collections, compute_auc_value, values) 1030 update_op = compute_auc(update_ops['tp'], update_ops['fn'], 1031 update_ops['tn'], update_ops['fp'], 'update_op') 1032 1033 if updates_collections: 1034 ops.add_to_collections(updates_collections, update_op) 1035 1036 return auc_value, update_op 1037 1038 1039@tf_export(v1=['metrics.mean_absolute_error']) 1040def mean_absolute_error(labels, 1041 predictions, 1042 weights=None, 1043 metrics_collections=None, 1044 updates_collections=None, 1045 name=None): 1046 """Computes the mean absolute error between the labels and predictions. 1047 1048 The `mean_absolute_error` function creates two local variables, 1049 `total` and `count` that are used to compute the mean absolute error. This 1050 average is weighted by `weights`, and it is ultimately returned as 1051 `mean_absolute_error`: an idempotent operation that simply divides `total` by 1052 `count`. 1053 1054 For estimation of the metric over a stream of data, the function creates an 1055 `update_op` operation that updates these variables and returns the 1056 `mean_absolute_error`. Internally, an `absolute_errors` operation computes the 1057 absolute value of the differences between `predictions` and `labels`. Then 1058 `update_op` increments `total` with the reduced sum of the product of 1059 `weights` and `absolute_errors`, and it increments `count` with the reduced 1060 sum of `weights` 1061 1062 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1063 1064 Args: 1065 labels: A `Tensor` of the same shape as `predictions`. 1066 predictions: A `Tensor` of arbitrary shape. 1067 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1068 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1069 be either `1`, or the same as the corresponding `labels` dimension). 1070 metrics_collections: An optional list of collections that 1071 `mean_absolute_error` should be added to. 1072 updates_collections: An optional list of collections that `update_op` should 1073 be added to. 1074 name: An optional variable_scope name. 1075 1076 Returns: 1077 mean_absolute_error: A `Tensor` representing the current mean, the value of 1078 `total` divided by `count`. 1079 update_op: An operation that increments the `total` and `count` variables 1080 appropriately and whose value matches `mean_absolute_error`. 1081 1082 Raises: 1083 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1084 `weights` is not `None` and its shape doesn't match `predictions`, or if 1085 either `metrics_collections` or `updates_collections` are not a list or 1086 tuple. 1087 RuntimeError: If eager execution is enabled. 1088 """ 1089 if context.executing_eagerly(): 1090 raise RuntimeError('tf.metrics.mean_absolute_error is not supported ' 1091 'when eager execution is enabled.') 1092 1093 predictions, labels, weights = _remove_squeezable_dimensions( 1094 predictions=predictions, labels=labels, weights=weights) 1095 absolute_errors = math_ops.abs(predictions - labels) 1096 return mean(absolute_errors, weights, metrics_collections, 1097 updates_collections, name or 'mean_absolute_error') 1098 1099 1100@tf_export(v1=['metrics.mean_cosine_distance']) 1101def mean_cosine_distance(labels, 1102 predictions, 1103 dim, 1104 weights=None, 1105 metrics_collections=None, 1106 updates_collections=None, 1107 name=None): 1108 """Computes the cosine distance between the labels and predictions. 1109 1110 The `mean_cosine_distance` function creates two local variables, 1111 `total` and `count` that are used to compute the average cosine distance 1112 between `predictions` and `labels`. This average is weighted by `weights`, 1113 and it is ultimately returned as `mean_distance`, which is an idempotent 1114 operation that simply divides `total` by `count`. 1115 1116 For estimation of the metric over a stream of data, the function creates an 1117 `update_op` operation that updates these variables and returns the 1118 `mean_distance`. 1119 1120 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1121 1122 Args: 1123 labels: A `Tensor` of arbitrary shape. 1124 predictions: A `Tensor` of the same shape as `labels`. 1125 dim: The dimension along which the cosine distance is computed. 1126 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1127 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1128 be either `1`, or the same as the corresponding `labels` dimension). Also, 1129 dimension `dim` must be `1`. 1130 metrics_collections: An optional list of collections that the metric 1131 value variable should be added to. 1132 updates_collections: An optional list of collections that the metric update 1133 ops should be added to. 1134 name: An optional variable_scope name. 1135 1136 Returns: 1137 mean_distance: A `Tensor` representing the current mean, the value of 1138 `total` divided by `count`. 1139 update_op: An operation that increments the `total` and `count` variables 1140 appropriately. 1141 1142 Raises: 1143 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1144 `weights` is not `None` and its shape doesn't match `predictions`, or if 1145 either `metrics_collections` or `updates_collections` are not a list or 1146 tuple. 1147 RuntimeError: If eager execution is enabled. 1148 """ 1149 if context.executing_eagerly(): 1150 raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when ' 1151 'eager execution is enabled.') 1152 1153 predictions, labels, weights = _remove_squeezable_dimensions( 1154 predictions=predictions, labels=labels, weights=weights) 1155 radial_diffs = math_ops.multiply(predictions, labels) 1156 radial_diffs = math_ops.reduce_sum( 1157 radial_diffs, axis=[ 1158 dim, 1159 ], keepdims=True) 1160 mean_distance, update_op = mean(radial_diffs, weights, None, None, name or 1161 'mean_cosine_distance') 1162 mean_distance = math_ops.subtract(1.0, mean_distance) 1163 update_op = math_ops.subtract(1.0, update_op) 1164 1165 if metrics_collections: 1166 ops.add_to_collections(metrics_collections, mean_distance) 1167 1168 if updates_collections: 1169 ops.add_to_collections(updates_collections, update_op) 1170 1171 return mean_distance, update_op 1172 1173 1174@tf_export(v1=['metrics.mean_per_class_accuracy']) 1175def mean_per_class_accuracy(labels, 1176 predictions, 1177 num_classes, 1178 weights=None, 1179 metrics_collections=None, 1180 updates_collections=None, 1181 name=None): 1182 """Calculates the mean of the per-class accuracies. 1183 1184 Calculates the accuracy for each class, then takes the mean of that. 1185 1186 For estimation of the metric over a stream of data, the function creates an 1187 `update_op` operation that updates the accuracy of each class and returns 1188 them. 1189 1190 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1191 1192 Args: 1193 labels: A `Tensor` of ground truth labels with shape [batch size] and of 1194 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 1195 predictions: A `Tensor` of prediction results for semantic labels, whose 1196 shape is [batch size] and type `int32` or `int64`. The tensor will be 1197 flattened if its rank > 1. 1198 num_classes: The possible number of labels the prediction task can 1199 have. This value must be provided, since two variables with shape = 1200 [num_classes] will be allocated. 1201 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1202 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1203 be either `1`, or the same as the corresponding `labels` dimension). 1204 metrics_collections: An optional list of collections that 1205 `mean_per_class_accuracy' 1206 should be added to. 1207 updates_collections: An optional list of collections `update_op` should be 1208 added to. 1209 name: An optional variable_scope name. 1210 1211 Returns: 1212 mean_accuracy: A `Tensor` representing the mean per class accuracy. 1213 update_op: An operation that updates the accuracy tensor. 1214 1215 Raises: 1216 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1217 `weights` is not `None` and its shape doesn't match `predictions`, or if 1218 either `metrics_collections` or `updates_collections` are not a list or 1219 tuple. 1220 RuntimeError: If eager execution is enabled. 1221 """ 1222 if context.executing_eagerly(): 1223 raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported ' 1224 'when eager execution is enabled.') 1225 1226 with variable_scope.variable_scope(name, 'mean_accuracy', 1227 (predictions, labels, weights)): 1228 labels = math_ops.cast(labels, dtypes.int64) 1229 1230 # Flatten the input if its rank > 1. 1231 if labels.get_shape().ndims > 1: 1232 labels = array_ops.reshape(labels, [-1]) 1233 1234 if predictions.get_shape().ndims > 1: 1235 predictions = array_ops.reshape(predictions, [-1]) 1236 1237 # Check if shape is compatible. 1238 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1239 1240 total = metric_variable([num_classes], dtypes.float32, name='total') 1241 count = metric_variable([num_classes], dtypes.float32, name='count') 1242 1243 ones = array_ops.ones([array_ops.size(labels)], dtypes.float32) 1244 1245 if labels.dtype != predictions.dtype: 1246 predictions = math_ops.cast(predictions, labels.dtype) 1247 is_correct = math_ops.cast( 1248 math_ops.equal(predictions, labels), dtypes.float32) 1249 1250 if weights is not None: 1251 if weights.get_shape().ndims > 1: 1252 weights = array_ops.reshape(weights, [-1]) 1253 weights = math_ops.cast(weights, dtypes.float32) 1254 1255 is_correct *= weights 1256 ones *= weights 1257 1258 update_total_op = state_ops.scatter_add(total, labels, ones) 1259 update_count_op = state_ops.scatter_add(count, labels, is_correct) 1260 1261 def compute_mean_accuracy(_, count, total): 1262 per_class_accuracy = math_ops.div_no_nan( 1263 count, math_ops.maximum(total, 0), name=None) 1264 mean_accuracy_v = math_ops.reduce_mean( 1265 per_class_accuracy, name='mean_accuracy') 1266 return mean_accuracy_v 1267 1268 mean_accuracy_v = _aggregate_across_replicas( 1269 metrics_collections, compute_mean_accuracy, count, total) 1270 1271 update_op = math_ops.div_no_nan( 1272 update_count_op, math_ops.maximum(update_total_op, 0), name='update_op') 1273 if updates_collections: 1274 ops.add_to_collections(updates_collections, update_op) 1275 1276 return mean_accuracy_v, update_op 1277 1278 1279@tf_export(v1=['metrics.mean_iou']) 1280def mean_iou(labels, 1281 predictions, 1282 num_classes, 1283 weights=None, 1284 metrics_collections=None, 1285 updates_collections=None, 1286 name=None): 1287 """Calculate per-step mean Intersection-Over-Union (mIOU). 1288 1289 Mean Intersection-Over-Union is a common evaluation metric for 1290 semantic image segmentation, which first computes the IOU for each 1291 semantic class and then computes the average over classes. 1292 IOU is defined as follows: 1293 IOU = true_positive / (true_positive + false_positive + false_negative). 1294 The predictions are accumulated in a confusion matrix, weighted by `weights`, 1295 and mIOU is then calculated from it. 1296 1297 For estimation of the metric over a stream of data, the function creates an 1298 `update_op` operation that updates these variables and returns the `mean_iou`. 1299 1300 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1301 1302 Args: 1303 labels: A `Tensor` of ground truth labels with shape [batch size] and of 1304 type `int32` or `int64`. The tensor will be flattened if its rank > 1. 1305 predictions: A `Tensor` of prediction results for semantic labels, whose 1306 shape is [batch size] and type `int32` or `int64`. The tensor will be 1307 flattened if its rank > 1. 1308 num_classes: The possible number of labels the prediction task can 1309 have. This value must be provided, since a confusion matrix of 1310 dimension = [num_classes, num_classes] will be allocated. 1311 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1312 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1313 be either `1`, or the same as the corresponding `labels` dimension). 1314 metrics_collections: An optional list of collections that `mean_iou` 1315 should be added to. 1316 updates_collections: An optional list of collections `update_op` should be 1317 added to. 1318 name: An optional variable_scope name. 1319 1320 Returns: 1321 mean_iou: A `Tensor` representing the mean intersection-over-union. 1322 update_op: An operation that increments the confusion matrix. 1323 1324 Raises: 1325 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1326 `weights` is not `None` and its shape doesn't match `predictions`, or if 1327 either `metrics_collections` or `updates_collections` are not a list or 1328 tuple. 1329 RuntimeError: If eager execution is enabled. 1330 """ 1331 if context.executing_eagerly(): 1332 raise RuntimeError('tf.metrics.mean_iou is not supported when ' 1333 'eager execution is enabled.') 1334 1335 with variable_scope.variable_scope(name, 'mean_iou', 1336 (predictions, labels, weights)): 1337 # Check if shape is compatible. 1338 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 1339 1340 total_cm, update_op = _streaming_confusion_matrix(labels, predictions, 1341 num_classes, weights) 1342 1343 def compute_mean_iou(_, total_cm): 1344 """Compute the mean intersection-over-union via the confusion matrix.""" 1345 sum_over_row = math_ops.cast( 1346 math_ops.reduce_sum(total_cm, 0), dtypes.float32) 1347 sum_over_col = math_ops.cast( 1348 math_ops.reduce_sum(total_cm, 1), dtypes.float32) 1349 cm_diag = math_ops.cast(array_ops.diag_part(total_cm), dtypes.float32) 1350 denominator = sum_over_row + sum_over_col - cm_diag 1351 1352 # The mean is only computed over classes that appear in the 1353 # label or prediction tensor. If the denominator is 0, we need to 1354 # ignore the class. 1355 num_valid_entries = math_ops.reduce_sum( 1356 math_ops.cast( 1357 math_ops.not_equal(denominator, 0), dtype=dtypes.float32)) 1358 1359 # If the value of the denominator is 0, set it to 1 to avoid 1360 # zero division. 1361 denominator = array_ops.where( 1362 math_ops.greater(denominator, 0), denominator, 1363 array_ops.ones_like(denominator)) 1364 iou = math_ops.divide(cm_diag, denominator) 1365 1366 # If the number of valid entries is 0 (no classes) we return 0. 1367 result = array_ops.where( 1368 math_ops.greater(num_valid_entries, 0), 1369 math_ops.reduce_sum(iou, name='mean_iou') / num_valid_entries, 0) 1370 return result 1371 1372 # TODO(priyag): Use outside_compilation if in TPU context. 1373 mean_iou_v = _aggregate_across_replicas( 1374 metrics_collections, compute_mean_iou, total_cm) 1375 1376 if updates_collections: 1377 ops.add_to_collections(updates_collections, update_op) 1378 1379 return mean_iou_v, update_op 1380 1381 1382@tf_export(v1=['metrics.mean_relative_error']) 1383def mean_relative_error(labels, 1384 predictions, 1385 normalizer, 1386 weights=None, 1387 metrics_collections=None, 1388 updates_collections=None, 1389 name=None): 1390 """Computes the mean relative error by normalizing with the given values. 1391 1392 The `mean_relative_error` function creates two local variables, 1393 `total` and `count` that are used to compute the mean relative absolute error. 1394 This average is weighted by `weights`, and it is ultimately returned as 1395 `mean_relative_error`: an idempotent operation that simply divides `total` by 1396 `count`. 1397 1398 For estimation of the metric over a stream of data, the function creates an 1399 `update_op` operation that updates these variables and returns the 1400 `mean_reative_error`. Internally, a `relative_errors` operation divides the 1401 absolute value of the differences between `predictions` and `labels` by the 1402 `normalizer`. Then `update_op` increments `total` with the reduced sum of the 1403 product of `weights` and `relative_errors`, and it increments `count` with the 1404 reduced sum of `weights`. 1405 1406 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1407 1408 Args: 1409 labels: A `Tensor` of the same shape as `predictions`. 1410 predictions: A `Tensor` of arbitrary shape. 1411 normalizer: A `Tensor` of the same shape as `predictions`. 1412 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1413 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1414 be either `1`, or the same as the corresponding `labels` dimension). 1415 metrics_collections: An optional list of collections that 1416 `mean_relative_error` should be added to. 1417 updates_collections: An optional list of collections that `update_op` should 1418 be added to. 1419 name: An optional variable_scope name. 1420 1421 Returns: 1422 mean_relative_error: A `Tensor` representing the current mean, the value of 1423 `total` divided by `count`. 1424 update_op: An operation that increments the `total` and `count` variables 1425 appropriately and whose value matches `mean_relative_error`. 1426 1427 Raises: 1428 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1429 `weights` is not `None` and its shape doesn't match `predictions`, or if 1430 either `metrics_collections` or `updates_collections` are not a list or 1431 tuple. 1432 RuntimeError: If eager execution is enabled. 1433 """ 1434 if context.executing_eagerly(): 1435 raise RuntimeError('tf.metrics.mean_relative_error is not supported when ' 1436 'eager execution is enabled.') 1437 1438 predictions, labels, weights = _remove_squeezable_dimensions( 1439 predictions=predictions, labels=labels, weights=weights) 1440 1441 predictions, normalizer = confusion_matrix.remove_squeezable_dimensions( 1442 predictions, normalizer) 1443 predictions.get_shape().assert_is_compatible_with(normalizer.get_shape()) 1444 relative_errors = array_ops.where( 1445 math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels), 1446 math_ops.divide(math_ops.abs(labels - predictions), normalizer)) 1447 return mean(relative_errors, weights, metrics_collections, 1448 updates_collections, name or 'mean_relative_error') 1449 1450 1451@tf_export(v1=['metrics.mean_squared_error']) 1452def mean_squared_error(labels, 1453 predictions, 1454 weights=None, 1455 metrics_collections=None, 1456 updates_collections=None, 1457 name=None): 1458 """Computes the mean squared error between the labels and predictions. 1459 1460 The `mean_squared_error` function creates two local variables, 1461 `total` and `count` that are used to compute the mean squared error. 1462 This average is weighted by `weights`, and it is ultimately returned as 1463 `mean_squared_error`: an idempotent operation that simply divides `total` by 1464 `count`. 1465 1466 For estimation of the metric over a stream of data, the function creates an 1467 `update_op` operation that updates these variables and returns the 1468 `mean_squared_error`. Internally, a `squared_error` operation computes the 1469 element-wise square of the difference between `predictions` and `labels`. Then 1470 `update_op` increments `total` with the reduced sum of the product of 1471 `weights` and `squared_error`, and it increments `count` with the reduced sum 1472 of `weights`. 1473 1474 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1475 1476 Args: 1477 labels: A `Tensor` of the same shape as `predictions`. 1478 predictions: A `Tensor` of arbitrary shape. 1479 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1480 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1481 be either `1`, or the same as the corresponding `labels` dimension). 1482 metrics_collections: An optional list of collections that 1483 `mean_squared_error` should be added to. 1484 updates_collections: An optional list of collections that `update_op` should 1485 be added to. 1486 name: An optional variable_scope name. 1487 1488 Returns: 1489 mean_squared_error: A `Tensor` representing the current mean, the value of 1490 `total` divided by `count`. 1491 update_op: An operation that increments the `total` and `count` variables 1492 appropriately and whose value matches `mean_squared_error`. 1493 1494 Raises: 1495 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1496 `weights` is not `None` and its shape doesn't match `predictions`, or if 1497 either `metrics_collections` or `updates_collections` are not a list or 1498 tuple. 1499 RuntimeError: If eager execution is enabled. 1500 """ 1501 if context.executing_eagerly(): 1502 raise RuntimeError('tf.metrics.mean_squared_error is not supported when ' 1503 'eager execution is enabled.') 1504 1505 predictions, labels, weights = _remove_squeezable_dimensions( 1506 predictions=predictions, labels=labels, weights=weights) 1507 squared_error = math_ops.squared_difference(labels, predictions) 1508 return mean(squared_error, weights, metrics_collections, updates_collections, 1509 name or 'mean_squared_error') 1510 1511 1512@tf_export(v1=['metrics.mean_tensor']) 1513def mean_tensor(values, 1514 weights=None, 1515 metrics_collections=None, 1516 updates_collections=None, 1517 name=None): 1518 """Computes the element-wise (weighted) mean of the given tensors. 1519 1520 In contrast to the `mean` function which returns a scalar with the 1521 mean, this function returns an average tensor with the same shape as the 1522 input tensors. 1523 1524 The `mean_tensor` function creates two local variables, 1525 `total_tensor` and `count_tensor` that are used to compute the average of 1526 `values`. This average is ultimately returned as `mean` which is an idempotent 1527 operation that simply divides `total` by `count`. 1528 1529 For estimation of the metric over a stream of data, the function creates an 1530 `update_op` operation that updates these variables and returns the `mean`. 1531 `update_op` increments `total` with the reduced sum of the product of `values` 1532 and `weights`, and it increments `count` with the reduced sum of `weights`. 1533 1534 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1535 1536 Args: 1537 values: A `Tensor` of arbitrary dimensions. 1538 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1539 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1540 be either `1`, or the same as the corresponding `values` dimension). 1541 metrics_collections: An optional list of collections that `mean` 1542 should be added to. 1543 updates_collections: An optional list of collections that `update_op` 1544 should be added to. 1545 name: An optional variable_scope name. 1546 1547 Returns: 1548 mean: A float `Tensor` representing the current mean, the value of `total` 1549 divided by `count`. 1550 update_op: An operation that increments the `total` and `count` variables 1551 appropriately and whose value matches `mean_value`. 1552 1553 Raises: 1554 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1555 or if either `metrics_collections` or `updates_collections` are not a list 1556 or tuple. 1557 RuntimeError: If eager execution is enabled. 1558 """ 1559 if context.executing_eagerly(): 1560 raise RuntimeError('tf.metrics.mean_tensor is not supported when ' 1561 'eager execution is enabled.') 1562 1563 with variable_scope.variable_scope(name, 'mean', (values, weights)): 1564 values = math_ops.cast(values, dtypes.float32) 1565 total = metric_variable( 1566 values.get_shape(), dtypes.float32, name='total_tensor') 1567 count = metric_variable( 1568 values.get_shape(), dtypes.float32, name='count_tensor') 1569 1570 num_values = array_ops.ones_like(values) 1571 if weights is not None: 1572 values, _, weights = _remove_squeezable_dimensions( 1573 predictions=values, labels=None, weights=weights) 1574 weights = weights_broadcast_ops.broadcast_weights( 1575 math_ops.cast(weights, dtypes.float32), values) 1576 values = math_ops.multiply(values, weights) 1577 num_values = math_ops.multiply(num_values, weights) 1578 1579 update_total_op = state_ops.assign_add(total, values) 1580 with ops.control_dependencies([values]): 1581 update_count_op = state_ops.assign_add(count, num_values) 1582 1583 compute_mean = lambda _, t, c: math_ops.div_no_nan( # pylint: disable=g-long-lambda 1584 t, math_ops.maximum(c, 0), name='value') 1585 1586 mean_t = _aggregate_across_replicas( 1587 metrics_collections, compute_mean, total, count) 1588 1589 update_op = math_ops.div_no_nan( 1590 update_total_op, math_ops.maximum(update_count_op, 0), name='update_op') 1591 if updates_collections: 1592 ops.add_to_collections(updates_collections, update_op) 1593 1594 return mean_t, update_op 1595 1596 1597@tf_export(v1=['metrics.percentage_below']) 1598def percentage_below(values, 1599 threshold, 1600 weights=None, 1601 metrics_collections=None, 1602 updates_collections=None, 1603 name=None): 1604 """Computes the percentage of values less than the given threshold. 1605 1606 The `percentage_below` function creates two local variables, 1607 `total` and `count` that are used to compute the percentage of `values` that 1608 fall below `threshold`. This rate is weighted by `weights`, and it is 1609 ultimately returned as `percentage` which is an idempotent operation that 1610 simply divides `total` by `count`. 1611 1612 For estimation of the metric over a stream of data, the function creates an 1613 `update_op` operation that updates these variables and returns the 1614 `percentage`. 1615 1616 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1617 1618 Args: 1619 values: A numeric `Tensor` of arbitrary size. 1620 threshold: A scalar threshold. 1621 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1622 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1623 be either `1`, or the same as the corresponding `values` dimension). 1624 metrics_collections: An optional list of collections that the metric 1625 value variable should be added to. 1626 updates_collections: An optional list of collections that the metric update 1627 ops should be added to. 1628 name: An optional variable_scope name. 1629 1630 Returns: 1631 percentage: A `Tensor` representing the current mean, the value of `total` 1632 divided by `count`. 1633 update_op: An operation that increments the `total` and `count` variables 1634 appropriately. 1635 1636 Raises: 1637 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1638 or if either `metrics_collections` or `updates_collections` are not a list 1639 or tuple. 1640 RuntimeError: If eager execution is enabled. 1641 """ 1642 if context.executing_eagerly(): 1643 raise RuntimeError('tf.metrics.percentage_below is not supported when ' 1644 'eager execution is enabled.') 1645 1646 is_below_threshold = math_ops.cast( 1647 math_ops.less(values, threshold), dtypes.float32) 1648 return mean(is_below_threshold, weights, metrics_collections, 1649 updates_collections, name or 'percentage_below_threshold') 1650 1651 1652def _count_condition(values, 1653 weights=None, 1654 metrics_collections=None, 1655 updates_collections=None): 1656 """Sums the weights of cases where the given values are True. 1657 1658 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1659 1660 Args: 1661 values: A `bool` `Tensor` of arbitrary size. 1662 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1663 `values`, and must be broadcastable to `values` (i.e., all dimensions must 1664 be either `1`, or the same as the corresponding `values` dimension). 1665 metrics_collections: An optional list of collections that the metric 1666 value variable should be added to. 1667 updates_collections: An optional list of collections that the metric update 1668 ops should be added to. 1669 1670 Returns: 1671 value_tensor: A `Tensor` representing the current value of the metric. 1672 update_op: An operation that accumulates the error from a batch of data. 1673 1674 Raises: 1675 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1676 or if either `metrics_collections` or `updates_collections` are not a list 1677 or tuple. 1678 """ 1679 check_ops.assert_type(values, dtypes.bool) 1680 count = metric_variable([], dtypes.float32, name='count') 1681 1682 values = math_ops.cast(values, dtypes.float32) 1683 if weights is not None: 1684 with ops.control_dependencies((check_ops.assert_rank_in( 1685 weights, (0, array_ops.rank(values))),)): 1686 weights = math_ops.cast(weights, dtypes.float32) 1687 values = math_ops.multiply(values, weights) 1688 1689 value_tensor = _aggregate_variable(count, metrics_collections) 1690 1691 update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) 1692 if updates_collections: 1693 ops.add_to_collections(updates_collections, update_op) 1694 1695 return value_tensor, update_op 1696 1697 1698@tf_export(v1=['metrics.false_negatives']) 1699def false_negatives(labels, 1700 predictions, 1701 weights=None, 1702 metrics_collections=None, 1703 updates_collections=None, 1704 name=None): 1705 """Computes the total number of false negatives. 1706 1707 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1708 1709 Args: 1710 labels: The ground truth values, a `Tensor` whose dimensions must match 1711 `predictions`. Will be cast to `bool`. 1712 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1713 be cast to `bool`. 1714 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1715 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1716 be either `1`, or the same as the corresponding `labels` dimension). 1717 metrics_collections: An optional list of collections that the metric 1718 value variable should be added to. 1719 updates_collections: An optional list of collections that the metric update 1720 ops should be added to. 1721 name: An optional variable_scope name. 1722 1723 Returns: 1724 value_tensor: A `Tensor` representing the current value of the metric. 1725 update_op: An operation that accumulates the error from a batch of data. 1726 1727 Raises: 1728 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1729 or if either `metrics_collections` or `updates_collections` are not a list 1730 or tuple. 1731 RuntimeError: If eager execution is enabled. 1732 """ 1733 if context.executing_eagerly(): 1734 raise RuntimeError('tf.metrics.false_negatives is not supported when ' 1735 'eager execution is enabled.') 1736 1737 with variable_scope.variable_scope(name, 'false_negatives', 1738 (predictions, labels, weights)): 1739 1740 predictions, labels, weights = _remove_squeezable_dimensions( 1741 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1742 labels=math_ops.cast(labels, dtype=dtypes.bool), 1743 weights=weights) 1744 is_false_negative = math_ops.logical_and( 1745 math_ops.equal(labels, True), math_ops.equal(predictions, False)) 1746 return _count_condition(is_false_negative, weights, metrics_collections, 1747 updates_collections) 1748 1749 1750@tf_export(v1=['metrics.false_negatives_at_thresholds']) 1751def false_negatives_at_thresholds(labels, 1752 predictions, 1753 thresholds, 1754 weights=None, 1755 metrics_collections=None, 1756 updates_collections=None, 1757 name=None): 1758 """Computes false negatives at provided threshold values. 1759 1760 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1761 1762 Args: 1763 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1764 `bool`. 1765 predictions: A floating point `Tensor` of arbitrary shape and whose values 1766 are in the range `[0, 1]`. 1767 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1768 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1769 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1770 be either `1`, or the same as the corresponding `labels` dimension). 1771 metrics_collections: An optional list of collections that `false_negatives` 1772 should be added to. 1773 updates_collections: An optional list of collections that `update_op` should 1774 be added to. 1775 name: An optional variable_scope name. 1776 1777 Returns: 1778 false_negatives: A float `Tensor` of shape `[len(thresholds)]`. 1779 update_op: An operation that updates the `false_negatives` variable and 1780 returns its current value. 1781 1782 Raises: 1783 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1784 `weights` is not `None` and its shape doesn't match `predictions`, or if 1785 either `metrics_collections` or `updates_collections` are not a list or 1786 tuple. 1787 RuntimeError: If eager execution is enabled. 1788 """ 1789 if context.executing_eagerly(): 1790 raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not ' 1791 'supported when eager execution is enabled.') 1792 1793 with variable_scope.variable_scope(name, 'false_negatives', 1794 (predictions, labels, weights)): 1795 values, update_ops = _confusion_matrix_at_thresholds( 1796 labels, predictions, thresholds, weights=weights, includes=('fn',)) 1797 1798 fn_value = _aggregate_variable(values['fn'], metrics_collections) 1799 1800 if updates_collections: 1801 ops.add_to_collections(updates_collections, update_ops['fn']) 1802 1803 return fn_value, update_ops['fn'] 1804 1805 1806@tf_export(v1=['metrics.false_positives']) 1807def false_positives(labels, 1808 predictions, 1809 weights=None, 1810 metrics_collections=None, 1811 updates_collections=None, 1812 name=None): 1813 """Sum the weights of false positives. 1814 1815 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1816 1817 Args: 1818 labels: The ground truth values, a `Tensor` whose dimensions must match 1819 `predictions`. Will be cast to `bool`. 1820 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1821 be cast to `bool`. 1822 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1823 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1824 be either `1`, or the same as the corresponding `labels` dimension). 1825 metrics_collections: An optional list of collections that the metric 1826 value variable should be added to. 1827 updates_collections: An optional list of collections that the metric update 1828 ops should be added to. 1829 name: An optional variable_scope name. 1830 1831 Returns: 1832 value_tensor: A `Tensor` representing the current value of the metric. 1833 update_op: An operation that accumulates the error from a batch of data. 1834 1835 Raises: 1836 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1837 `weights` is not `None` and its shape doesn't match `predictions`, or if 1838 either `metrics_collections` or `updates_collections` are not a list or 1839 tuple. 1840 RuntimeError: If eager execution is enabled. 1841 """ 1842 if context.executing_eagerly(): 1843 raise RuntimeError('tf.metrics.false_positives is not supported when ' 1844 'eager execution is enabled.') 1845 1846 with variable_scope.variable_scope(name, 'false_positives', 1847 (predictions, labels, weights)): 1848 1849 predictions, labels, weights = _remove_squeezable_dimensions( 1850 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1851 labels=math_ops.cast(labels, dtype=dtypes.bool), 1852 weights=weights) 1853 is_false_positive = math_ops.logical_and( 1854 math_ops.equal(labels, False), math_ops.equal(predictions, True)) 1855 return _count_condition(is_false_positive, weights, metrics_collections, 1856 updates_collections) 1857 1858 1859@tf_export(v1=['metrics.false_positives_at_thresholds']) 1860def false_positives_at_thresholds(labels, 1861 predictions, 1862 thresholds, 1863 weights=None, 1864 metrics_collections=None, 1865 updates_collections=None, 1866 name=None): 1867 """Computes false positives at provided threshold values. 1868 1869 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1870 1871 Args: 1872 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1873 `bool`. 1874 predictions: A floating point `Tensor` of arbitrary shape and whose values 1875 are in the range `[0, 1]`. 1876 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1877 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1878 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1879 be either `1`, or the same as the corresponding `labels` dimension). 1880 metrics_collections: An optional list of collections that `false_positives` 1881 should be added to. 1882 updates_collections: An optional list of collections that `update_op` should 1883 be added to. 1884 name: An optional variable_scope name. 1885 1886 Returns: 1887 false_positives: A float `Tensor` of shape `[len(thresholds)]`. 1888 update_op: An operation that updates the `false_positives` variable and 1889 returns its current value. 1890 1891 Raises: 1892 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1893 `weights` is not `None` and its shape doesn't match `predictions`, or if 1894 either `metrics_collections` or `updates_collections` are not a list or 1895 tuple. 1896 RuntimeError: If eager execution is enabled. 1897 """ 1898 if context.executing_eagerly(): 1899 raise RuntimeError('tf.metrics.false_positives_at_thresholds is not ' 1900 'supported when eager execution is enabled.') 1901 1902 with variable_scope.variable_scope(name, 'false_positives', 1903 (predictions, labels, weights)): 1904 values, update_ops = _confusion_matrix_at_thresholds( 1905 labels, predictions, thresholds, weights=weights, includes=('fp',)) 1906 1907 fp_value = _aggregate_variable(values['fp'], metrics_collections) 1908 1909 if updates_collections: 1910 ops.add_to_collections(updates_collections, update_ops['fp']) 1911 1912 return fp_value, update_ops['fp'] 1913 1914 1915@tf_export(v1=['metrics.true_negatives']) 1916def true_negatives(labels, 1917 predictions, 1918 weights=None, 1919 metrics_collections=None, 1920 updates_collections=None, 1921 name=None): 1922 """Sum the weights of true_negatives. 1923 1924 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1925 1926 Args: 1927 labels: The ground truth values, a `Tensor` whose dimensions must match 1928 `predictions`. Will be cast to `bool`. 1929 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1930 be cast to `bool`. 1931 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1932 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1933 be either `1`, or the same as the corresponding `labels` dimension). 1934 metrics_collections: An optional list of collections that the metric 1935 value variable should be added to. 1936 updates_collections: An optional list of collections that the metric update 1937 ops should be added to. 1938 name: An optional variable_scope name. 1939 1940 Returns: 1941 value_tensor: A `Tensor` representing the current value of the metric. 1942 update_op: An operation that accumulates the error from a batch of data. 1943 1944 Raises: 1945 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1946 `weights` is not `None` and its shape doesn't match `predictions`, or if 1947 either `metrics_collections` or `updates_collections` are not a list or 1948 tuple. 1949 RuntimeError: If eager execution is enabled. 1950 """ 1951 if context.executing_eagerly(): 1952 raise RuntimeError('tf.metrics.true_negatives is not ' 1953 'supported when eager execution is enabled.') 1954 1955 with variable_scope.variable_scope(name, 'true_negatives', 1956 (predictions, labels, weights)): 1957 1958 predictions, labels, weights = _remove_squeezable_dimensions( 1959 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1960 labels=math_ops.cast(labels, dtype=dtypes.bool), 1961 weights=weights) 1962 is_true_negative = math_ops.logical_and( 1963 math_ops.equal(labels, False), math_ops.equal(predictions, False)) 1964 return _count_condition(is_true_negative, weights, metrics_collections, 1965 updates_collections) 1966 1967 1968@tf_export(v1=['metrics.true_negatives_at_thresholds']) 1969def true_negatives_at_thresholds(labels, 1970 predictions, 1971 thresholds, 1972 weights=None, 1973 metrics_collections=None, 1974 updates_collections=None, 1975 name=None): 1976 """Computes true negatives at provided threshold values. 1977 1978 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1979 1980 Args: 1981 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1982 `bool`. 1983 predictions: A floating point `Tensor` of arbitrary shape and whose values 1984 are in the range `[0, 1]`. 1985 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1986 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1987 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1988 be either `1`, or the same as the corresponding `labels` dimension). 1989 metrics_collections: An optional list of collections that `true_negatives` 1990 should be added to. 1991 updates_collections: An optional list of collections that `update_op` should 1992 be added to. 1993 name: An optional variable_scope name. 1994 1995 Returns: 1996 true_negatives: A float `Tensor` of shape `[len(thresholds)]`. 1997 update_op: An operation that updates the `true_negatives` variable and 1998 returns its current value. 1999 2000 Raises: 2001 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2002 `weights` is not `None` and its shape doesn't match `predictions`, or if 2003 either `metrics_collections` or `updates_collections` are not a list or 2004 tuple. 2005 RuntimeError: If eager execution is enabled. 2006 """ 2007 if context.executing_eagerly(): 2008 raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not ' 2009 'supported when eager execution is enabled.') 2010 2011 with variable_scope.variable_scope(name, 'true_negatives', 2012 (predictions, labels, weights)): 2013 values, update_ops = _confusion_matrix_at_thresholds( 2014 labels, predictions, thresholds, weights=weights, includes=('tn',)) 2015 2016 tn_value = _aggregate_variable(values['tn'], metrics_collections) 2017 2018 if updates_collections: 2019 ops.add_to_collections(updates_collections, update_ops['tn']) 2020 2021 return tn_value, update_ops['tn'] 2022 2023 2024@tf_export(v1=['metrics.true_positives']) 2025def true_positives(labels, 2026 predictions, 2027 weights=None, 2028 metrics_collections=None, 2029 updates_collections=None, 2030 name=None): 2031 """Sum the weights of true_positives. 2032 2033 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2034 2035 Args: 2036 labels: The ground truth values, a `Tensor` whose dimensions must match 2037 `predictions`. Will be cast to `bool`. 2038 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 2039 be cast to `bool`. 2040 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2041 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2042 be either `1`, or the same as the corresponding `labels` dimension). 2043 metrics_collections: An optional list of collections that the metric 2044 value variable should be added to. 2045 updates_collections: An optional list of collections that the metric update 2046 ops should be added to. 2047 name: An optional variable_scope name. 2048 2049 Returns: 2050 value_tensor: A `Tensor` representing the current value of the metric. 2051 update_op: An operation that accumulates the error from a batch of data. 2052 2053 Raises: 2054 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2055 `weights` is not `None` and its shape doesn't match `predictions`, or if 2056 either `metrics_collections` or `updates_collections` are not a list or 2057 tuple. 2058 RuntimeError: If eager execution is enabled. 2059 """ 2060 if context.executing_eagerly(): 2061 raise RuntimeError('tf.metrics.true_positives is not ' 2062 'supported when eager execution is enabled.') 2063 2064 with variable_scope.variable_scope(name, 'true_positives', 2065 (predictions, labels, weights)): 2066 2067 predictions, labels, weights = _remove_squeezable_dimensions( 2068 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 2069 labels=math_ops.cast(labels, dtype=dtypes.bool), 2070 weights=weights) 2071 is_true_positive = math_ops.logical_and( 2072 math_ops.equal(labels, True), math_ops.equal(predictions, True)) 2073 return _count_condition(is_true_positive, weights, metrics_collections, 2074 updates_collections) 2075 2076 2077@tf_export(v1=['metrics.true_positives_at_thresholds']) 2078def true_positives_at_thresholds(labels, 2079 predictions, 2080 thresholds, 2081 weights=None, 2082 metrics_collections=None, 2083 updates_collections=None, 2084 name=None): 2085 """Computes true positives at provided threshold values. 2086 2087 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2088 2089 Args: 2090 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 2091 `bool`. 2092 predictions: A floating point `Tensor` of arbitrary shape and whose values 2093 are in the range `[0, 1]`. 2094 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2095 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2096 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2097 be either `1`, or the same as the corresponding `labels` dimension). 2098 metrics_collections: An optional list of collections that `true_positives` 2099 should be added to. 2100 updates_collections: An optional list of collections that `update_op` should 2101 be added to. 2102 name: An optional variable_scope name. 2103 2104 Returns: 2105 true_positives: A float `Tensor` of shape `[len(thresholds)]`. 2106 update_op: An operation that updates the `true_positives` variable and 2107 returns its current value. 2108 2109 Raises: 2110 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2111 `weights` is not `None` and its shape doesn't match `predictions`, or if 2112 either `metrics_collections` or `updates_collections` are not a list or 2113 tuple. 2114 RuntimeError: If eager execution is enabled. 2115 """ 2116 if context.executing_eagerly(): 2117 raise RuntimeError('tf.metrics.true_positives_at_thresholds is not ' 2118 'supported when eager execution is enabled.') 2119 2120 with variable_scope.variable_scope(name, 'true_positives', 2121 (predictions, labels, weights)): 2122 values, update_ops = _confusion_matrix_at_thresholds( 2123 labels, predictions, thresholds, weights=weights, includes=('tp',)) 2124 2125 tp_value = _aggregate_variable(values['tp'], metrics_collections) 2126 2127 if updates_collections: 2128 ops.add_to_collections(updates_collections, update_ops['tp']) 2129 2130 return tp_value, update_ops['tp'] 2131 2132 2133@tf_export(v1=['metrics.precision']) 2134def precision(labels, 2135 predictions, 2136 weights=None, 2137 metrics_collections=None, 2138 updates_collections=None, 2139 name=None): 2140 """Computes the precision of the predictions with respect to the labels. 2141 2142 The `precision` function creates two local variables, 2143 `true_positives` and `false_positives`, that are used to compute the 2144 precision. This value is ultimately returned as `precision`, an idempotent 2145 operation that simply divides `true_positives` by the sum of `true_positives` 2146 and `false_positives`. 2147 2148 For estimation of the metric over a stream of data, the function creates an 2149 `update_op` operation that updates these variables and returns the 2150 `precision`. `update_op` weights each prediction by the corresponding value in 2151 `weights`. 2152 2153 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2154 2155 Args: 2156 labels: The ground truth values, a `Tensor` whose dimensions must match 2157 `predictions`. Will be cast to `bool`. 2158 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 2159 be cast to `bool`. 2160 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2161 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2162 be either `1`, or the same as the corresponding `labels` dimension). 2163 metrics_collections: An optional list of collections that `precision` should 2164 be added to. 2165 updates_collections: An optional list of collections that `update_op` should 2166 be added to. 2167 name: An optional variable_scope name. 2168 2169 Returns: 2170 precision: Scalar float `Tensor` with the value of `true_positives` 2171 divided by the sum of `true_positives` and `false_positives`. 2172 update_op: `Operation` that increments `true_positives` and 2173 `false_positives` variables appropriately and whose value matches 2174 `precision`. 2175 2176 Raises: 2177 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2178 `weights` is not `None` and its shape doesn't match `predictions`, or if 2179 either `metrics_collections` or `updates_collections` are not a list or 2180 tuple. 2181 RuntimeError: If eager execution is enabled. 2182 """ 2183 if context.executing_eagerly(): 2184 raise RuntimeError('tf.metrics.precision is not ' 2185 'supported when eager execution is enabled.') 2186 2187 with variable_scope.variable_scope(name, 'precision', 2188 (predictions, labels, weights)): 2189 2190 predictions, labels, weights = _remove_squeezable_dimensions( 2191 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 2192 labels=math_ops.cast(labels, dtype=dtypes.bool), 2193 weights=weights) 2194 2195 true_p, true_positives_update_op = true_positives( 2196 labels, 2197 predictions, 2198 weights, 2199 metrics_collections=None, 2200 updates_collections=None, 2201 name=None) 2202 false_p, false_positives_update_op = false_positives( 2203 labels, 2204 predictions, 2205 weights, 2206 metrics_collections=None, 2207 updates_collections=None, 2208 name=None) 2209 2210 def compute_precision(tp, fp, name): 2211 return array_ops.where( 2212 math_ops.greater(tp + fp, 0), math_ops.divide(tp, tp + fp), 0, name) 2213 2214 def once_across_replicas(_, true_p, false_p): 2215 return compute_precision(true_p, false_p, 'value') 2216 2217 p = _aggregate_across_replicas(metrics_collections, once_across_replicas, 2218 true_p, false_p) 2219 2220 update_op = compute_precision(true_positives_update_op, 2221 false_positives_update_op, 'update_op') 2222 if updates_collections: 2223 ops.add_to_collections(updates_collections, update_op) 2224 2225 return p, update_op 2226 2227 2228@tf_export(v1=['metrics.precision_at_thresholds']) 2229def precision_at_thresholds(labels, 2230 predictions, 2231 thresholds, 2232 weights=None, 2233 metrics_collections=None, 2234 updates_collections=None, 2235 name=None): 2236 """Computes precision values for different `thresholds` on `predictions`. 2237 2238 The `precision_at_thresholds` function creates four local variables, 2239 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 2240 for various values of thresholds. `precision[i]` is defined as the total 2241 weight of values in `predictions` above `thresholds[i]` whose corresponding 2242 entry in `labels` is `True`, divided by the total weight of values in 2243 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 2244 false_positives[i])`). 2245 2246 For estimation of the metric over a stream of data, the function creates an 2247 `update_op` operation that updates these variables and returns the 2248 `precision`. 2249 2250 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2251 2252 Args: 2253 labels: The ground truth values, a `Tensor` whose dimensions must match 2254 `predictions`. Will be cast to `bool`. 2255 predictions: A floating point `Tensor` of arbitrary shape and whose values 2256 are in the range `[0, 1]`. 2257 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2258 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2259 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2260 be either `1`, or the same as the corresponding `labels` dimension). 2261 metrics_collections: An optional list of collections that `auc` should be 2262 added to. 2263 updates_collections: An optional list of collections that `update_op` should 2264 be added to. 2265 name: An optional variable_scope name. 2266 2267 Returns: 2268 precision: A float `Tensor` of shape `[len(thresholds)]`. 2269 update_op: An operation that increments the `true_positives`, 2270 `true_negatives`, `false_positives` and `false_negatives` variables that 2271 are used in the computation of `precision`. 2272 2273 Raises: 2274 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2275 `weights` is not `None` and its shape doesn't match `predictions`, or if 2276 either `metrics_collections` or `updates_collections` are not a list or 2277 tuple. 2278 RuntimeError: If eager execution is enabled. 2279 """ 2280 if context.executing_eagerly(): 2281 raise RuntimeError('tf.metrics.precision_at_thresholds is not ' 2282 'supported when eager execution is enabled.') 2283 2284 with variable_scope.variable_scope(name, 'precision_at_thresholds', 2285 (predictions, labels, weights)): 2286 values, update_ops = _confusion_matrix_at_thresholds( 2287 labels, predictions, thresholds, weights, includes=('tp', 'fp')) 2288 2289 # Avoid division by zero. 2290 epsilon = 1e-7 2291 2292 def compute_precision(tp, fp, name): 2293 return math_ops.divide(tp, epsilon + tp + fp, name='precision_' + name) 2294 2295 def precision_across_replicas(_, values): 2296 return compute_precision(values['tp'], values['fp'], 'value') 2297 2298 prec = _aggregate_across_replicas( 2299 metrics_collections, precision_across_replicas, values) 2300 2301 update_op = compute_precision(update_ops['tp'], update_ops['fp'], 2302 'update_op') 2303 if updates_collections: 2304 ops.add_to_collections(updates_collections, update_op) 2305 2306 return prec, update_op 2307 2308 2309@tf_export(v1=['metrics.recall']) 2310def recall(labels, 2311 predictions, 2312 weights=None, 2313 metrics_collections=None, 2314 updates_collections=None, 2315 name=None): 2316 """Computes the recall of the predictions with respect to the labels. 2317 2318 The `recall` function creates two local variables, `true_positives` 2319 and `false_negatives`, that are used to compute the recall. This value is 2320 ultimately returned as `recall`, an idempotent operation that simply divides 2321 `true_positives` by the sum of `true_positives` and `false_negatives`. 2322 2323 For estimation of the metric over a stream of data, the function creates an 2324 `update_op` that updates these variables and returns the `recall`. `update_op` 2325 weights each prediction by the corresponding value in `weights`. 2326 2327 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2328 2329 Args: 2330 labels: The ground truth values, a `Tensor` whose dimensions must match 2331 `predictions`. Will be cast to `bool`. 2332 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 2333 be cast to `bool`. 2334 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2335 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2336 be either `1`, or the same as the corresponding `labels` dimension). 2337 metrics_collections: An optional list of collections that `recall` should 2338 be added to. 2339 updates_collections: An optional list of collections that `update_op` should 2340 be added to. 2341 name: An optional variable_scope name. 2342 2343 Returns: 2344 recall: Scalar float `Tensor` with the value of `true_positives` divided 2345 by the sum of `true_positives` and `false_negatives`. 2346 update_op: `Operation` that increments `true_positives` and 2347 `false_negatives` variables appropriately and whose value matches 2348 `recall`. 2349 2350 Raises: 2351 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2352 `weights` is not `None` and its shape doesn't match `predictions`, or if 2353 either `metrics_collections` or `updates_collections` are not a list or 2354 tuple. 2355 RuntimeError: If eager execution is enabled. 2356 """ 2357 if context.executing_eagerly(): 2358 raise RuntimeError('tf.metrics.recall is not supported is not ' 2359 'supported when eager execution is enabled.') 2360 2361 with variable_scope.variable_scope(name, 'recall', 2362 (predictions, labels, weights)): 2363 predictions, labels, weights = _remove_squeezable_dimensions( 2364 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 2365 labels=math_ops.cast(labels, dtype=dtypes.bool), 2366 weights=weights) 2367 2368 true_p, true_positives_update_op = true_positives( 2369 labels, 2370 predictions, 2371 weights, 2372 metrics_collections=None, 2373 updates_collections=None, 2374 name=None) 2375 false_n, false_negatives_update_op = false_negatives( 2376 labels, 2377 predictions, 2378 weights, 2379 metrics_collections=None, 2380 updates_collections=None, 2381 name=None) 2382 2383 def compute_recall(true_p, false_n, name): 2384 return array_ops.where( 2385 math_ops.greater(true_p + false_n, 0), 2386 math_ops.divide(true_p, true_p + false_n), 0, name) 2387 2388 def once_across_replicas(_, true_p, false_n): 2389 return compute_recall(true_p, false_n, 'value') 2390 2391 rec = _aggregate_across_replicas( 2392 metrics_collections, once_across_replicas, true_p, false_n) 2393 2394 update_op = compute_recall(true_positives_update_op, 2395 false_negatives_update_op, 'update_op') 2396 if updates_collections: 2397 ops.add_to_collections(updates_collections, update_op) 2398 2399 return rec, update_op 2400 2401 2402def _at_k_name(name, k=None, class_id=None): 2403 if k is not None: 2404 name = '%s_at_%d' % (name, k) 2405 else: 2406 name = '%s_at_k' % (name) 2407 if class_id is not None: 2408 name = '%s_class%d' % (name, class_id) 2409 return name 2410 2411 2412def _select_class_id(ids, selected_id): 2413 """Filter all but `selected_id` out of `ids`. 2414 2415 Args: 2416 ids: `int64` `Tensor` or `SparseTensor` of IDs. 2417 selected_id: Int id to select. 2418 2419 Returns: 2420 `SparseTensor` of same dimensions as `ids`. This contains only the entries 2421 equal to `selected_id`. 2422 """ 2423 ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids) 2424 if isinstance(ids, sparse_tensor.SparseTensor): 2425 return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values, 2426 selected_id)) 2427 2428 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 2429 # tf.equal and tf.reduce_any? 2430 2431 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 2432 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 2433 ids_last_dim = array_ops.size(ids_shape) - 1 2434 filled_selected_id_shape = math_ops.reduced_shape(ids_shape, 2435 array_ops.reshape( 2436 ids_last_dim, [1])) 2437 2438 # Intersect `ids` with the selected ID. 2439 filled_selected_id = array_ops.fill(filled_selected_id_shape, 2440 math_ops.cast(selected_id, dtypes.int64)) 2441 result = sets.set_intersection(filled_selected_id, ids) 2442 return sparse_tensor.SparseTensor( 2443 indices=result.indices, values=result.values, dense_shape=ids_shape) 2444 2445 2446def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 2447 """If class ID is specified, filter all other classes. 2448 2449 Args: 2450 labels: `int64` `Tensor` or `SparseTensor` with shape 2451 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2452 target classes for the associated prediction. Commonly, N=1 and `labels` 2453 has shape [batch_size, num_labels]. [D1, ... DN] must match 2454 `predictions_idx`. 2455 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 2456 where N >= 1. Commonly, N=1 and `predictions_idx` has shape 2457 [batch size, k]. 2458 selected_id: Int id to select. 2459 2460 Returns: 2461 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 2462 """ 2463 if selected_id is None: 2464 return labels, predictions_idx 2465 return (_select_class_id(labels, selected_id), 2466 _select_class_id(predictions_idx, selected_id)) 2467 2468 2469def _sparse_true_positive_at_k(labels, 2470 predictions_idx, 2471 class_id=None, 2472 weights=None, 2473 name=None): 2474 """Calculates true positives for recall@k and precision@k. 2475 2476 If `class_id` is specified, calculate binary true positives for `class_id` 2477 only. 2478 If `class_id` is not specified, calculate metrics for `k` predicted vs 2479 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2480 2481 Args: 2482 labels: `int64` `Tensor` or `SparseTensor` with shape 2483 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2484 target classes for the associated prediction. Commonly, N=1 and `labels` 2485 has shape [batch_size, num_labels]. [D1, ... DN] must match 2486 `predictions_idx`. 2487 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2488 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2489 match `labels`. 2490 class_id: Class for which we want binary metrics. 2491 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2492 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2493 dimensions must be either `1`, or the same as the corresponding `labels` 2494 dimension). 2495 name: Name of operation. 2496 2497 Returns: 2498 A [D1, ... DN] `Tensor` of true positive counts. 2499 """ 2500 with ops.name_scope(name, 'true_positives', 2501 (predictions_idx, labels, weights)): 2502 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 2503 class_id) 2504 tp = sets.set_size(sets.set_intersection(predictions_idx, labels)) 2505 tp = math_ops.cast(tp, dtypes.float64) 2506 if weights is not None: 2507 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 2508 weights, tp),)): 2509 weights = math_ops.cast(weights, dtypes.float64) 2510 tp = math_ops.multiply(tp, weights) 2511 return tp 2512 2513 2514def _streaming_sparse_true_positive_at_k(labels, 2515 predictions_idx, 2516 k=None, 2517 class_id=None, 2518 weights=None, 2519 name=None): 2520 """Calculates weighted per step true positives for recall@k and precision@k. 2521 2522 If `class_id` is specified, calculate binary true positives for `class_id` 2523 only. 2524 If `class_id` is not specified, calculate metrics for `k` predicted vs 2525 `n` label classes, where `n` is the 2nd dimension of `labels`. 2526 2527 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2528 2529 Args: 2530 labels: `int64` `Tensor` or `SparseTensor` with shape 2531 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2532 target classes for the associated prediction. Commonly, N=1 and `labels` 2533 has shape [batch_size, num_labels]. [D1, ... DN] must match 2534 `predictions_idx`. 2535 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2536 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2537 match `labels`. 2538 k: Integer, k for @k metric. This is only used for default op name. 2539 class_id: Class for which we want binary metrics. 2540 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2541 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2542 dimensions must be either `1`, or the same as the corresponding `labels` 2543 dimension). 2544 name: Name of new variable, and namespace for other dependent ops. 2545 2546 Returns: 2547 A tuple of `Variable` and update `Operation`. 2548 2549 Raises: 2550 ValueError: If `weights` is not `None` and has an incompatible shape. 2551 """ 2552 with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id), 2553 (predictions_idx, labels, weights)) as scope: 2554 tp = _sparse_true_positive_at_k( 2555 predictions_idx=predictions_idx, 2556 labels=labels, 2557 class_id=class_id, 2558 weights=weights) 2559 batch_total_tp = math_ops.cast(math_ops.reduce_sum(tp), dtypes.float64) 2560 2561 var = metric_variable([], dtypes.float64, name=scope) 2562 return var, state_ops.assign_add(var, batch_total_tp, name='update') 2563 2564 2565def _sparse_false_negative_at_k(labels, 2566 predictions_idx, 2567 class_id=None, 2568 weights=None): 2569 """Calculates false negatives for recall@k. 2570 2571 If `class_id` is specified, calculate binary true positives for `class_id` 2572 only. 2573 If `class_id` is not specified, calculate metrics for `k` predicted vs 2574 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2575 2576 Args: 2577 labels: `int64` `Tensor` or `SparseTensor` with shape 2578 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2579 target classes for the associated prediction. Commonly, N=1 and `labels` 2580 has shape [batch_size, num_labels]. [D1, ... DN] must match 2581 `predictions_idx`. 2582 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2583 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2584 match `labels`. 2585 class_id: Class for which we want binary metrics. 2586 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2587 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2588 dimensions must be either `1`, or the same as the corresponding `labels` 2589 dimension). 2590 2591 Returns: 2592 A [D1, ... DN] `Tensor` of false negative counts. 2593 """ 2594 with ops.name_scope(None, 'false_negatives', 2595 (predictions_idx, labels, weights)): 2596 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 2597 class_id) 2598 fn = sets.set_size( 2599 sets.set_difference(predictions_idx, labels, aminusb=False)) 2600 fn = math_ops.cast(fn, dtypes.float64) 2601 if weights is not None: 2602 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 2603 weights, fn),)): 2604 weights = math_ops.cast(weights, dtypes.float64) 2605 fn = math_ops.multiply(fn, weights) 2606 return fn 2607 2608 2609def _streaming_sparse_false_negative_at_k(labels, 2610 predictions_idx, 2611 k, 2612 class_id=None, 2613 weights=None, 2614 name=None): 2615 """Calculates weighted per step false negatives for recall@k. 2616 2617 If `class_id` is specified, calculate binary true positives for `class_id` 2618 only. 2619 If `class_id` is not specified, calculate metrics for `k` predicted vs 2620 `n` label classes, where `n` is the 2nd dimension of `labels`. 2621 2622 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2623 2624 Args: 2625 labels: `int64` `Tensor` or `SparseTensor` with shape 2626 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2627 target classes for the associated prediction. Commonly, N=1 and `labels` 2628 has shape [batch_size, num_labels]. [D1, ... DN] must match 2629 `predictions_idx`. 2630 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2631 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2632 match `labels`. 2633 k: Integer, k for @k metric. This is only used for default op name. 2634 class_id: Class for which we want binary metrics. 2635 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2636 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2637 dimensions must be either `1`, or the same as the corresponding `labels` 2638 dimension). 2639 name: Name of new variable, and namespace for other dependent ops. 2640 2641 Returns: 2642 A tuple of `Variable` and update `Operation`. 2643 2644 Raises: 2645 ValueError: If `weights` is not `None` and has an incompatible shape. 2646 """ 2647 with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id), 2648 (predictions_idx, labels, weights)) as scope: 2649 fn = _sparse_false_negative_at_k( 2650 predictions_idx=predictions_idx, 2651 labels=labels, 2652 class_id=class_id, 2653 weights=weights) 2654 batch_total_fn = math_ops.cast(math_ops.reduce_sum(fn), dtypes.float64) 2655 2656 var = metric_variable([], dtypes.float64, name=scope) 2657 return var, state_ops.assign_add(var, batch_total_fn, name='update') 2658 2659 2660@tf_export(v1=['metrics.recall_at_k']) 2661def recall_at_k(labels, 2662 predictions, 2663 k, 2664 class_id=None, 2665 weights=None, 2666 metrics_collections=None, 2667 updates_collections=None, 2668 name=None): 2669 """Computes recall@k of the predictions with respect to sparse labels. 2670 2671 If `class_id` is specified, we calculate recall by considering only the 2672 entries in the batch for which `class_id` is in the label, and computing 2673 the fraction of them for which `class_id` is in the top-k `predictions`. 2674 If `class_id` is not specified, we'll calculate recall as how often on 2675 average a class among the labels of a batch entry is in the top-k 2676 `predictions`. 2677 2678 `sparse_recall_at_k` creates two local variables, 2679 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 2680 the recall_at_k frequency. This frequency is ultimately returned as 2681 `recall_at_<k>`: an idempotent operation that simply divides 2682 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 2683 `false_negative_at_<k>`). 2684 2685 For estimation of the metric over a stream of data, the function creates an 2686 `update_op` operation that updates these variables and returns the 2687 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2688 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2689 `labels` calculate the true positives and false negatives weighted by 2690 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2691 `false_negative_at_<k>` using these values. 2692 2693 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2694 2695 Args: 2696 labels: `int64` `Tensor` or `SparseTensor` with shape 2697 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2698 num_labels=1. N >= 1 and num_labels is the number of target classes for 2699 the associated prediction. Commonly, N=1 and `labels` has shape 2700 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2701 should be in range [0, num_classes), where num_classes is the last 2702 dimension of `predictions`. Values outside this range always count 2703 towards `false_negative_at_<k>`. 2704 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2705 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 2706 The final dimension contains the logit values for each class. [D1, ... DN] 2707 must match `labels`. 2708 k: Integer, k for @k metric. 2709 class_id: Integer class ID for which we want binary metrics. This should be 2710 in range [0, num_classes), where num_classes is the last dimension of 2711 `predictions`. If class_id is outside this range, the method returns NAN. 2712 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2713 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2714 dimensions must be either `1`, or the same as the corresponding `labels` 2715 dimension). 2716 metrics_collections: An optional list of collections that values should 2717 be added to. 2718 updates_collections: An optional list of collections that updates should 2719 be added to. 2720 name: Name of new update operation, and namespace for other dependent ops. 2721 2722 Returns: 2723 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2724 by the sum of `true_positives` and `false_negatives`. 2725 update_op: `Operation` that increments `true_positives` and 2726 `false_negatives` variables appropriately, and whose value matches 2727 `recall`. 2728 2729 Raises: 2730 ValueError: If `weights` is not `None` and its shape doesn't match 2731 `predictions`, or if either `metrics_collections` or `updates_collections` 2732 are not a list or tuple. 2733 RuntimeError: If eager execution is enabled. 2734 """ 2735 if context.executing_eagerly(): 2736 raise RuntimeError('tf.metrics.recall_at_k is not ' 2737 'supported when eager execution is enabled.') 2738 2739 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2740 (predictions, labels, weights)) as scope: 2741 _, top_k_idx = nn.top_k(predictions, k) 2742 return recall_at_top_k( 2743 labels=labels, 2744 predictions_idx=top_k_idx, 2745 k=k, 2746 class_id=class_id, 2747 weights=weights, 2748 metrics_collections=metrics_collections, 2749 updates_collections=updates_collections, 2750 name=scope) 2751 2752 2753@tf_export(v1=['metrics.recall_at_top_k']) 2754def recall_at_top_k(labels, 2755 predictions_idx, 2756 k=None, 2757 class_id=None, 2758 weights=None, 2759 metrics_collections=None, 2760 updates_collections=None, 2761 name=None): 2762 """Computes recall@k of top-k predictions with respect to sparse labels. 2763 2764 Differs from `recall_at_k` in that predictions must be in the form of top `k` 2765 class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k` 2766 for more details. 2767 2768 Args: 2769 labels: `int64` `Tensor` or `SparseTensor` with shape 2770 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2771 num_labels=1. N >= 1 and num_labels is the number of target classes for 2772 the associated prediction. Commonly, N=1 and `labels` has shape 2773 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2774 should be in range [0, num_classes), where num_classes is the last 2775 dimension of `predictions`. Values outside this range always count 2776 towards `false_negative_at_<k>`. 2777 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2778 Commonly, N=1 and predictions has shape [batch size, k]. The final 2779 dimension contains the top `k` predicted class indices. [D1, ... DN] must 2780 match `labels`. 2781 k: Integer, k for @k metric. Only used for the default op name. 2782 class_id: Integer class ID for which we want binary metrics. This should be 2783 in range [0, num_classes), where num_classes is the last dimension of 2784 `predictions`. If class_id is outside this range, the method returns NAN. 2785 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2786 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2787 dimensions must be either `1`, or the same as the corresponding `labels` 2788 dimension). 2789 metrics_collections: An optional list of collections that values should 2790 be added to. 2791 updates_collections: An optional list of collections that updates should 2792 be added to. 2793 name: Name of new update operation, and namespace for other dependent ops. 2794 2795 Returns: 2796 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2797 by the sum of `true_positives` and `false_negatives`. 2798 update_op: `Operation` that increments `true_positives` and 2799 `false_negatives` variables appropriately, and whose value matches 2800 `recall`. 2801 2802 Raises: 2803 ValueError: If `weights` is not `None` and its shape doesn't match 2804 `predictions`, or if either `metrics_collections` or `updates_collections` 2805 are not a list or tuple. 2806 """ 2807 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2808 (predictions_idx, labels, weights)) as scope: 2809 labels = _maybe_expand_labels(labels, predictions_idx) 2810 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 2811 tp, tp_update = _streaming_sparse_true_positive_at_k( 2812 predictions_idx=top_k_idx, 2813 labels=labels, 2814 k=k, 2815 class_id=class_id, 2816 weights=weights) 2817 fn, fn_update = _streaming_sparse_false_negative_at_k( 2818 predictions_idx=top_k_idx, 2819 labels=labels, 2820 k=k, 2821 class_id=class_id, 2822 weights=weights) 2823 2824 def compute_recall(_, tp, fn): 2825 return math_ops.divide(tp, math_ops.add(tp, fn), name=scope) 2826 2827 metric = _aggregate_across_replicas( 2828 metrics_collections, compute_recall, tp, fn) 2829 2830 update = math_ops.divide( 2831 tp_update, math_ops.add(tp_update, fn_update), name='update') 2832 if updates_collections: 2833 ops.add_to_collections(updates_collections, update) 2834 return metric, update 2835 2836 2837@tf_export(v1=['metrics.recall_at_thresholds']) 2838def recall_at_thresholds(labels, 2839 predictions, 2840 thresholds, 2841 weights=None, 2842 metrics_collections=None, 2843 updates_collections=None, 2844 name=None): 2845 """Computes various recall values for different `thresholds` on `predictions`. 2846 2847 The `recall_at_thresholds` function creates four local variables, 2848 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 2849 for various values of thresholds. `recall[i]` is defined as the total weight 2850 of values in `predictions` above `thresholds[i]` whose corresponding entry in 2851 `labels` is `True`, divided by the total weight of `True` values in `labels` 2852 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 2853 2854 For estimation of the metric over a stream of data, the function creates an 2855 `update_op` operation that updates these variables and returns the `recall`. 2856 2857 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2858 2859 Args: 2860 labels: The ground truth values, a `Tensor` whose dimensions must match 2861 `predictions`. Will be cast to `bool`. 2862 predictions: A floating point `Tensor` of arbitrary shape and whose values 2863 are in the range `[0, 1]`. 2864 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2865 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2866 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2867 be either `1`, or the same as the corresponding `labels` dimension). 2868 metrics_collections: An optional list of collections that `recall` should be 2869 added to. 2870 updates_collections: An optional list of collections that `update_op` should 2871 be added to. 2872 name: An optional variable_scope name. 2873 2874 Returns: 2875 recall: A float `Tensor` of shape `[len(thresholds)]`. 2876 update_op: An operation that increments the `true_positives`, 2877 `true_negatives`, `false_positives` and `false_negatives` variables that 2878 are used in the computation of `recall`. 2879 2880 Raises: 2881 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2882 `weights` is not `None` and its shape doesn't match `predictions`, or if 2883 either `metrics_collections` or `updates_collections` are not a list or 2884 tuple. 2885 RuntimeError: If eager execution is enabled. 2886 """ 2887 if context.executing_eagerly(): 2888 raise RuntimeError('tf.metrics.recall_at_thresholds is not ' 2889 'supported when eager execution is enabled.') 2890 2891 with variable_scope.variable_scope(name, 'recall_at_thresholds', 2892 (predictions, labels, weights)): 2893 values, update_ops = _confusion_matrix_at_thresholds( 2894 labels, predictions, thresholds, weights, includes=('tp', 'fn')) 2895 2896 # Avoid division by zero. 2897 epsilon = 1e-7 2898 2899 def compute_recall(tp, fn, name): 2900 return math_ops.divide(tp, epsilon + tp + fn, name='recall_' + name) 2901 2902 def recall_across_replicas(_, values): 2903 return compute_recall(values['tp'], values['fn'], 'value') 2904 2905 rec = _aggregate_across_replicas( 2906 metrics_collections, recall_across_replicas, values) 2907 2908 update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') 2909 if updates_collections: 2910 ops.add_to_collections(updates_collections, update_op) 2911 2912 return rec, update_op 2913 2914 2915@tf_export(v1=['metrics.root_mean_squared_error']) 2916def root_mean_squared_error(labels, 2917 predictions, 2918 weights=None, 2919 metrics_collections=None, 2920 updates_collections=None, 2921 name=None): 2922 """Computes the root mean squared error between the labels and predictions. 2923 2924 The `root_mean_squared_error` function creates two local variables, 2925 `total` and `count` that are used to compute the root mean squared error. 2926 This average is weighted by `weights`, and it is ultimately returned as 2927 `root_mean_squared_error`: an idempotent operation that takes the square root 2928 of the division of `total` by `count`. 2929 2930 For estimation of the metric over a stream of data, the function creates an 2931 `update_op` operation that updates these variables and returns the 2932 `root_mean_squared_error`. Internally, a `squared_error` operation computes 2933 the element-wise square of the difference between `predictions` and `labels`. 2934 Then `update_op` increments `total` with the reduced sum of the product of 2935 `weights` and `squared_error`, and it increments `count` with the reduced sum 2936 of `weights`. 2937 2938 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2939 2940 Args: 2941 labels: A `Tensor` of the same shape as `predictions`. 2942 predictions: A `Tensor` of arbitrary shape. 2943 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2944 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2945 be either `1`, or the same as the corresponding `labels` dimension). 2946 metrics_collections: An optional list of collections that 2947 `root_mean_squared_error` should be added to. 2948 updates_collections: An optional list of collections that `update_op` should 2949 be added to. 2950 name: An optional variable_scope name. 2951 2952 Returns: 2953 root_mean_squared_error: A `Tensor` representing the current mean, the value 2954 of `total` divided by `count`. 2955 update_op: An operation that increments the `total` and `count` variables 2956 appropriately and whose value matches `root_mean_squared_error`. 2957 2958 Raises: 2959 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2960 `weights` is not `None` and its shape doesn't match `predictions`, or if 2961 either `metrics_collections` or `updates_collections` are not a list or 2962 tuple. 2963 RuntimeError: If eager execution is enabled. 2964 """ 2965 if context.executing_eagerly(): 2966 raise RuntimeError('tf.metrics.root_mean_squared_error is not ' 2967 'supported when eager execution is enabled.') 2968 2969 predictions, labels, weights = _remove_squeezable_dimensions( 2970 predictions=predictions, labels=labels, weights=weights) 2971 mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, 2972 None, name or 2973 'root_mean_squared_error') 2974 2975 once_across_replicas = lambda _, mse: math_ops.sqrt(mse) 2976 rmse = _aggregate_across_replicas( 2977 metrics_collections, once_across_replicas, mse) 2978 2979 update_rmse_op = math_ops.sqrt(update_mse_op) 2980 if updates_collections: 2981 ops.add_to_collections(updates_collections, update_rmse_op) 2982 2983 return rmse, update_rmse_op 2984 2985 2986@tf_export(v1=['metrics.sensitivity_at_specificity']) 2987def sensitivity_at_specificity(labels, 2988 predictions, 2989 specificity, 2990 weights=None, 2991 num_thresholds=200, 2992 metrics_collections=None, 2993 updates_collections=None, 2994 name=None): 2995 """Computes the specificity at a given sensitivity. 2996 2997 The `sensitivity_at_specificity` function creates four local 2998 variables, `true_positives`, `true_negatives`, `false_positives` and 2999 `false_negatives` that are used to compute the sensitivity at the given 3000 specificity value. The threshold for the given specificity value is computed 3001 and used to evaluate the corresponding sensitivity. 3002 3003 For estimation of the metric over a stream of data, the function creates an 3004 `update_op` operation that updates these variables and returns the 3005 `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`, 3006 `false_positives` and `false_negatives` counts with the weight of each case 3007 found in the `predictions` and `labels`. 3008 3009 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3010 3011 For additional information about specificity and sensitivity, see the 3012 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 3013 3014 Args: 3015 labels: The ground truth values, a `Tensor` whose dimensions must match 3016 `predictions`. Will be cast to `bool`. 3017 predictions: A floating point `Tensor` of arbitrary shape and whose values 3018 are in the range `[0, 1]`. 3019 specificity: A scalar value in range `[0, 1]`. 3020 weights: Optional `Tensor` whose rank is either 0, or the same rank as 3021 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 3022 be either `1`, or the same as the corresponding `labels` dimension). 3023 num_thresholds: The number of thresholds to use for matching the given 3024 specificity. 3025 metrics_collections: An optional list of collections that `sensitivity` 3026 should be added to. 3027 updates_collections: An optional list of collections that `update_op` should 3028 be added to. 3029 name: An optional variable_scope name. 3030 3031 Returns: 3032 sensitivity: A scalar `Tensor` representing the sensitivity at the given 3033 `specificity` value. 3034 update_op: An operation that increments the `true_positives`, 3035 `true_negatives`, `false_positives` and `false_negatives` variables 3036 appropriately and whose value matches `sensitivity`. 3037 3038 Raises: 3039 ValueError: If `predictions` and `labels` have mismatched shapes, if 3040 `weights` is not `None` and its shape doesn't match `predictions`, or if 3041 `specificity` is not between 0 and 1, or if either `metrics_collections` 3042 or `updates_collections` are not a list or tuple. 3043 RuntimeError: If eager execution is enabled. 3044 """ 3045 if context.executing_eagerly(): 3046 raise RuntimeError('tf.metrics.sensitivity_at_specificity is not ' 3047 'supported when eager execution is enabled.') 3048 3049 if specificity < 0 or specificity > 1: 3050 raise ValueError('`specificity` must be in the range [0, 1]. Currently, ' 3051 f'`specificity` got {specificity}.') 3052 3053 with variable_scope.variable_scope(name, 'sensitivity_at_specificity', 3054 (predictions, labels, weights)): 3055 kepsilon = 1e-7 # to account for floating point imprecisions 3056 thresholds = [ 3057 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 3058 ] 3059 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 3060 3061 values, update_ops = _confusion_matrix_at_thresholds( 3062 labels, predictions, thresholds, weights) 3063 3064 def compute_sensitivity_at_specificity(tp, tn, fp, fn, name): 3065 specificities = math_ops.divide(tn, tn + fp + kepsilon) 3066 tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0) 3067 tf_index = math_ops.cast(tf_index, dtypes.int32) 3068 3069 # Now, we have the implicit threshold, so compute the sensitivity: 3070 return math_ops.divide(tp[tf_index], 3071 tp[tf_index] + fn[tf_index] + kepsilon, name) 3072 3073 def sensitivity_across_replicas(_, values): 3074 return compute_sensitivity_at_specificity( 3075 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 3076 3077 sensitivity = _aggregate_across_replicas( 3078 metrics_collections, sensitivity_across_replicas, values) 3079 3080 update_op = compute_sensitivity_at_specificity( 3081 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 3082 'update_op') 3083 if updates_collections: 3084 ops.add_to_collections(updates_collections, update_op) 3085 3086 return sensitivity, update_op 3087 3088 3089def _expand_and_tile(tensor, multiple, dim=0, name=None): 3090 """Slice `tensor` shape in 2, then tile along the sliced dimension. 3091 3092 A new dimension is inserted in shape of `tensor` before `dim`, then values are 3093 tiled `multiple` times along the new dimension. 3094 3095 Args: 3096 tensor: Input `Tensor` or `SparseTensor`. 3097 multiple: Integer, number of times to tile. 3098 dim: Integer, dimension along which to tile. 3099 name: Name of operation. 3100 3101 Returns: 3102 `Tensor` result of expanding and tiling `tensor`. 3103 3104 Raises: 3105 ValueError: if `multiple` is less than 1, or `dim` is not in 3106 `[-rank(tensor), rank(tensor)]`. 3107 """ 3108 if multiple < 1: 3109 raise ValueError(f'Invalid argument multiple={multiple} for ' 3110 'expand_and_tile call. `multiple` must be an integer > 0') 3111 with ops.name_scope(name, 'expand_and_tile', 3112 (tensor, multiple, dim)) as scope: 3113 # Sparse. 3114 tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor) 3115 if isinstance(tensor, sparse_tensor.SparseTensor): 3116 if dim < 0: 3117 expand_dims = array_ops.reshape( 3118 array_ops.size(tensor.dense_shape) + dim, [1]) 3119 else: 3120 expand_dims = [dim] 3121 expanded_shape = array_ops.concat( 3122 (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1], 3123 array_ops.slice(tensor.dense_shape, expand_dims, [-1])), 3124 0, 3125 name='expanded_shape') 3126 expanded = sparse_ops.sparse_reshape( 3127 tensor, shape=expanded_shape, name='expand') 3128 if multiple == 1: 3129 return expanded 3130 return sparse_ops.sparse_concat( 3131 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope) 3132 3133 # Dense. 3134 expanded = array_ops.expand_dims( 3135 tensor, dim if (dim >= 0) else (dim - 1), name='expand') 3136 if multiple == 1: 3137 return expanded 3138 ones = array_ops.ones_like(array_ops.shape(tensor)) 3139 tile_multiples = array_ops.concat( 3140 (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples') 3141 return array_ops.tile(expanded, tile_multiples, name=scope) 3142 3143 3144def _num_relevant(labels, k): 3145 """Computes number of relevant values for each row in labels. 3146 3147 For labels with shape [D1, ... DN, num_labels], this is the minimum of 3148 `num_labels` and `k`. 3149 3150 Args: 3151 labels: `int64` `Tensor` or `SparseTensor` with shape 3152 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 3153 target classes for the associated prediction. Commonly, N=1 and `labels` 3154 has shape [batch_size, num_labels]. 3155 k: Integer, k for @k metric. 3156 3157 Returns: 3158 Integer `Tensor` of shape [D1, ... DN], where each value is the number of 3159 relevant values for that row. 3160 3161 Raises: 3162 ValueError: if inputs have invalid dtypes or values. 3163 """ 3164 if k < 1: 3165 raise ValueError(f'Invalid k={k}') 3166 with ops.name_scope(None, 'num_relevant', (labels,)) as scope: 3167 # For SparseTensor, calculate separate count for each row. 3168 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 3169 if isinstance(labels, sparse_tensor.SparseTensor): 3170 return math_ops.minimum(sets.set_size(labels), k, name=scope) 3171 3172 # The relevant values for each (d1, ... dN) is the minimum of k and the 3173 # number of labels along the last dimension that are non-negative. 3174 num_labels = math_ops.reduce_sum( 3175 array_ops.where_v2(math_ops.greater_equal(labels, 0), 3176 array_ops.ones_like(labels), 3177 array_ops.zeros_like(labels)), 3178 axis=-1) 3179 return math_ops.minimum(num_labels, k, name=scope) 3180 3181 3182def _sparse_average_precision_at_top_k(labels, predictions_idx): 3183 """Computes average precision@k of predictions with respect to sparse labels. 3184 3185 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula 3186 for each row is: 3187 3188 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items 3189 3190 A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`, 3191 `labels`, and the result `Tensors`. In the common case, this is [batch_size]. 3192 Each row of the results contains the average precision for that row. 3193 3194 Args: 3195 labels: `int64` `Tensor` or `SparseTensor` with shape 3196 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3197 num_labels=1. N >= 1 and num_labels is the number of target classes for 3198 the associated prediction. Commonly, N=1 and `labels` has shape 3199 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 3200 Values should be non-negative. Negative values are ignored. 3201 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 3202 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 3203 dimension must be set and contains the top `k` predicted class indices. 3204 [D1, ... DN] must match `labels`. Values should be in range 3205 [0, num_classes). 3206 3207 Returns: 3208 `float64` `Tensor` of shape [D1, ... DN], where each value is the average 3209 precision for that row. 3210 3211 Raises: 3212 ValueError: if the last dimension of predictions_idx is not set. 3213 """ 3214 with ops.name_scope(None, 'average_precision', 3215 (predictions_idx, labels)) as scope: 3216 predictions_idx = math_ops.cast( 3217 predictions_idx, dtypes.int64, name='predictions_idx') 3218 if predictions_idx.get_shape().ndims == 0: 3219 raise ValueError('The rank of `predictions_idx` must be at least 1.') 3220 k = predictions_idx.get_shape().as_list()[-1] 3221 if k is None: 3222 raise ValueError('The last dimension of predictions_idx must be set. ' 3223 'Currently, it is None.') 3224 labels = _maybe_expand_labels(labels, predictions_idx) 3225 3226 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate 3227 # prediction for each k, so we can calculate separate true positive values 3228 # for each k. 3229 predictions_idx_per_k = array_ops.expand_dims( 3230 predictions_idx, -1, name='predictions_idx_per_k') 3231 3232 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor. 3233 labels_per_k = _expand_and_tile( 3234 labels, multiple=k, dim=-1, name='labels_per_k') 3235 3236 # The following tensors are all of shape [D1, ... DN, k], containing values 3237 # per row, per k value. 3238 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at 3239 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from 3240 # the formula above. 3241 # `tp_per_k` (int32) - True positive counts. 3242 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is 3243 # the precision denominator. 3244 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}" 3245 # term from the formula above. 3246 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e., 3247 # precisions at all k for which relevance indicator is true. 3248 relevant_per_k = _sparse_true_positive_at_k( 3249 labels_per_k, predictions_idx_per_k, name='relevant_per_k') 3250 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k') 3251 retrieved_per_k = math_ops.cumsum( 3252 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') 3253 precision_per_k = math_ops.divide( 3254 math_ops.cast(tp_per_k, dtypes.float64), 3255 math_ops.cast(retrieved_per_k, dtypes.float64), 3256 name='precision_per_k') 3257 relevant_precision_per_k = math_ops.multiply( 3258 precision_per_k, 3259 math_ops.cast(relevant_per_k, dtypes.float64), 3260 name='relevant_precision_per_k') 3261 3262 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. 3263 precision_sum = math_ops.reduce_sum( 3264 relevant_precision_per_k, axis=(-1,), name='precision_sum') 3265 3266 # Divide by number of relevant items to get average precision. These are 3267 # the "num_relevant_items" and "AveP" terms from the formula above. 3268 num_relevant_items = math_ops.cast(_num_relevant(labels, k), dtypes.float64) 3269 return math_ops.divide(precision_sum, num_relevant_items, name=scope) 3270 3271 3272def _streaming_sparse_average_precision_at_top_k(labels, 3273 predictions_idx, 3274 weights=None, 3275 metrics_collections=None, 3276 updates_collections=None, 3277 name=None): 3278 """Computes average precision@k of predictions with respect to sparse labels. 3279 3280 `sparse_average_precision_at_top_k` creates two local variables, 3281 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 3282 are used to compute the frequency. This frequency is ultimately returned as 3283 `average_precision_at_<k>`: an idempotent operation that simply divides 3284 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 3285 3286 For estimation of the metric over a stream of data, the function creates an 3287 `update_op` operation that updates these variables and returns the 3288 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate 3289 the true positives and false positives weighted by `weights`. Then `update_op` 3290 increments `true_positive_at_<k>` and `false_positive_at_<k>` using these 3291 values. 3292 3293 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3294 3295 Args: 3296 labels: `int64` `Tensor` or `SparseTensor` with shape 3297 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3298 num_labels=1. N >= 1 and num_labels is the number of target classes for 3299 the associated prediction. Commonly, N=1 and `labels` has shape 3300 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 3301 Values should be non-negative. Negative values are ignored. 3302 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 3303 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 3304 dimension contains the top `k` predicted class indices. [D1, ... DN] must 3305 match `labels`. Values should be in range [0, num_classes). 3306 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3307 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3308 dimensions must be either `1`, or the same as the corresponding `labels` 3309 dimension). 3310 metrics_collections: An optional list of collections that values should 3311 be added to. 3312 updates_collections: An optional list of collections that updates should 3313 be added to. 3314 name: Name of new update operation, and namespace for other dependent ops. 3315 3316 Returns: 3317 mean_average_precision: Scalar `float64` `Tensor` with the mean average 3318 precision values. 3319 update: `Operation` that increments variables appropriately, and whose 3320 value matches `metric`. 3321 """ 3322 with ops.name_scope(name, 'average_precision_at_top_k', 3323 (predictions_idx, labels, weights)) as scope: 3324 # Calculate per-example average precision, and apply weights. 3325 average_precision = _sparse_average_precision_at_top_k( 3326 predictions_idx=predictions_idx, labels=labels) 3327 if weights is not None: 3328 weights = weights_broadcast_ops.broadcast_weights( 3329 math_ops.cast(weights, dtypes.float64), average_precision) 3330 average_precision = math_ops.multiply(average_precision, weights) 3331 3332 # Create accumulation variables and update ops for max average precision and 3333 # total average precision. 3334 with ops.name_scope(None, 'max', (average_precision,)) as max_scope: 3335 # `max` is the max possible precision. Since max for any row is 1.0: 3336 # - For the unweighted case, this is just the number of rows. 3337 # - For the weighted case, it's the sum of the weights broadcast across 3338 # `average_precision` rows. 3339 max_var = metric_variable([], dtypes.float64, name=max_scope) 3340 if weights is None: 3341 batch_max = math_ops.cast( 3342 array_ops.size(average_precision, name='batch_max'), dtypes.float64) 3343 else: 3344 batch_max = math_ops.reduce_sum(weights, name='batch_max') 3345 max_update = state_ops.assign_add(max_var, batch_max, name='update') 3346 with ops.name_scope(None, 'total', (average_precision,)) as total_scope: 3347 total_var = metric_variable([], dtypes.float64, name=total_scope) 3348 batch_total = math_ops.reduce_sum(average_precision, name='batch_total') 3349 total_update = state_ops.assign_add(total_var, batch_total, name='update') 3350 3351 # Divide total by max to get mean, for both vars and the update ops. 3352 def precision_across_replicas(_, total_var, max_var): 3353 return _safe_scalar_div(total_var, max_var, name='mean') 3354 3355 mean_average_precision = _aggregate_across_replicas( 3356 metrics_collections, precision_across_replicas, total_var, max_var) 3357 3358 update = _safe_scalar_div(total_update, max_update, name=scope) 3359 if updates_collections: 3360 ops.add_to_collections(updates_collections, update) 3361 3362 return mean_average_precision, update 3363 3364 3365def _clean_out_of_range_indices(labels, num_classes): 3366 """Replaces large out-of-range labels by small out-of-range labels. 3367 3368 Replaces any value in `labels` that is greater or equal to `num_classes` by 3369 -1. Do this conditionally for efficiency in case there are no such values. 3370 3371 Args: 3372 labels: `int64` `Tensor` or `SparseTensor`. 3373 num_classes: `int64` scalar `Tensor`. 3374 Returns: 3375 An `int64` `Tensor` or `SparseTensor` as `labels` with indices greater 3376 or equal to num_classes replaced by -1. 3377 """ 3378 3379 def _labels_is_sparse(): 3380 """Returns true is `labels` is a sparse tensor.""" 3381 return isinstance(labels, (sparse_tensor.SparseTensor, 3382 sparse_tensor.SparseTensorValue)) 3383 3384 def _clean_out_of_range(values): 3385 """Replaces by -1 any large out-of-range `values`.""" 3386 return array_ops.where_v2(math_ops.greater_equal(values, num_classes), 3387 -1 * array_ops.ones_like(values), values) 3388 3389 def _clean_labels_out_of_range(): 3390 """Replaces by -1 ane large out-of-range values in `labels`.""" 3391 if _labels_is_sparse(): 3392 return type(labels)(indices=labels.indices, 3393 values=_clean_out_of_range(labels.values), 3394 dense_shape=labels.dense_shape) 3395 else: 3396 return _clean_out_of_range(labels) 3397 3398 max_labels = math_ops.reduce_max( 3399 labels.values if _labels_is_sparse() else labels) 3400 return control_flow_ops.cond( 3401 math_ops.greater_equal(max_labels, num_classes), 3402 _clean_labels_out_of_range, 3403 lambda: labels) 3404 3405 3406@tf_export(v1=['metrics.sparse_average_precision_at_k']) 3407@deprecated(None, 'Use average_precision_at_k instead') 3408def sparse_average_precision_at_k(labels, 3409 predictions, 3410 k, 3411 weights=None, 3412 metrics_collections=None, 3413 updates_collections=None, 3414 name=None): 3415 """Renamed to `average_precision_at_k`, please use that method instead.""" 3416 return average_precision_at_k( 3417 labels=labels, 3418 predictions=predictions, 3419 k=k, 3420 weights=weights, 3421 metrics_collections=metrics_collections, 3422 updates_collections=updates_collections, 3423 name=name) 3424 3425 3426@tf_export(v1=['metrics.average_precision_at_k']) 3427def average_precision_at_k(labels, 3428 predictions, 3429 k, 3430 weights=None, 3431 metrics_collections=None, 3432 updates_collections=None, 3433 name=None): 3434 """Computes average precision@k of predictions with respect to sparse labels. 3435 3436 `average_precision_at_k` creates two local variables, 3437 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 3438 are used to compute the frequency. This frequency is ultimately returned as 3439 `average_precision_at_<k>`: an idempotent operation that simply divides 3440 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 3441 3442 For estimation of the metric over a stream of data, the function creates an 3443 `update_op` operation that updates these variables and returns the 3444 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3445 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3446 `labels` calculate the true positives and false positives weighted by 3447 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3448 `false_positive_at_<k>` using these values. 3449 3450 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3451 3452 Args: 3453 labels: `int64` `Tensor` or `SparseTensor` with shape 3454 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3455 num_labels=1. N >= 1 and num_labels is the number of target classes for 3456 the associated prediction. Commonly, N=1 and `labels` has shape 3457 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3458 should be in range [0, num_classes), where num_classes is the last 3459 dimension of `predictions`. Values outside this range are ignored. 3460 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 3461 N >= 1. Commonly, N=1 and `predictions` has shape 3462 [batch size, num_classes]. The final dimension contains the logit values 3463 for each class. [D1, ... DN] must match `labels`. 3464 k: Integer, k for @k metric. This will calculate an average precision for 3465 range `[1,k]`, as documented above. 3466 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3467 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3468 dimensions must be either `1`, or the same as the corresponding `labels` 3469 dimension). 3470 metrics_collections: An optional list of collections that values should 3471 be added to. 3472 updates_collections: An optional list of collections that updates should 3473 be added to. 3474 name: Name of new update operation, and namespace for other dependent ops. 3475 3476 Returns: 3477 mean_average_precision: Scalar `float64` `Tensor` with the mean average 3478 precision values. 3479 update: `Operation` that increments variables appropriately, and whose 3480 value matches `metric`. 3481 3482 Raises: 3483 ValueError: if k is invalid. 3484 RuntimeError: If eager execution is enabled. 3485 """ 3486 if context.executing_eagerly(): 3487 raise RuntimeError('tf.metrics.sparse_average_precision_at_k is not ' 3488 'supported when eager execution is enabled.') 3489 3490 if k < 1: 3491 raise ValueError(f'Invalid k={k}. `k` should be >= 1.') 3492 with ops.name_scope(name, _at_k_name('average_precision', k), 3493 (predictions, labels, weights)) as scope: 3494 # Calculate top k indices to produce [D1, ... DN, k] tensor. 3495 _, predictions_idx = nn.top_k(predictions, k) 3496 # The documentation states that labels should be in [0, ..., num_classes), 3497 # but num_classes is lost when predictions_idx replaces predictions. 3498 # For conformity with the documentation, any label >= num_classes, which is 3499 # ignored, is replaced by -1. 3500 labels = _clean_out_of_range_indices( 3501 labels, math_ops.cast(array_ops.shape(predictions)[-1], dtypes.int64)) 3502 return _streaming_sparse_average_precision_at_top_k( 3503 labels=labels, 3504 predictions_idx=predictions_idx, 3505 weights=weights, 3506 metrics_collections=metrics_collections, 3507 updates_collections=updates_collections, 3508 name=scope) 3509 3510 3511def _sparse_false_positive_at_k(labels, 3512 predictions_idx, 3513 class_id=None, 3514 weights=None): 3515 """Calculates false positives for precision@k. 3516 3517 If `class_id` is specified, calculate binary true positives for `class_id` 3518 only. 3519 If `class_id` is not specified, calculate metrics for `k` predicted vs 3520 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 3521 3522 Args: 3523 labels: `int64` `Tensor` or `SparseTensor` with shape 3524 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 3525 target classes for the associated prediction. Commonly, N=1 and `labels` 3526 has shape [batch_size, num_labels]. [D1, ... DN] must match 3527 `predictions_idx`. 3528 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 3529 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 3530 match `labels`. 3531 class_id: Class for which we want binary metrics. 3532 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3533 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3534 dimensions must be either `1`, or the same as the corresponding `labels` 3535 dimension). 3536 3537 Returns: 3538 A [D1, ... DN] `Tensor` of false positive counts. 3539 """ 3540 with ops.name_scope(None, 'false_positives', 3541 (predictions_idx, labels, weights)): 3542 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 3543 class_id) 3544 fp = sets.set_size( 3545 sets.set_difference(predictions_idx, labels, aminusb=True)) 3546 fp = math_ops.cast(fp, dtypes.float64) 3547 if weights is not None: 3548 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 3549 weights, fp),)): 3550 weights = math_ops.cast(weights, dtypes.float64) 3551 fp = math_ops.multiply(fp, weights) 3552 return fp 3553 3554 3555def _streaming_sparse_false_positive_at_k(labels, 3556 predictions_idx, 3557 k=None, 3558 class_id=None, 3559 weights=None, 3560 name=None): 3561 """Calculates weighted per step false positives for precision@k. 3562 3563 If `class_id` is specified, calculate binary true positives for `class_id` 3564 only. 3565 If `class_id` is not specified, calculate metrics for `k` predicted vs 3566 `n` label classes, where `n` is the 2nd dimension of `labels`. 3567 3568 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3569 3570 Args: 3571 labels: `int64` `Tensor` or `SparseTensor` with shape 3572 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 3573 target classes for the associated prediction. Commonly, N=1 and `labels` 3574 has shape [batch_size, num_labels]. [D1, ... DN] must match 3575 `predictions_idx`. 3576 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 3577 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 3578 match `labels`. 3579 k: Integer, k for @k metric. This is only used for default op name. 3580 class_id: Class for which we want binary metrics. 3581 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3582 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3583 dimensions must be either `1`, or the same as the corresponding `labels` 3584 dimension). 3585 name: Name of new variable, and namespace for other dependent ops. 3586 3587 Returns: 3588 A tuple of `Variable` and update `Operation`. 3589 3590 Raises: 3591 ValueError: If `weights` is not `None` and has an incompatible shape. 3592 """ 3593 with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id), 3594 (predictions_idx, labels, weights)) as scope: 3595 fp = _sparse_false_positive_at_k( 3596 predictions_idx=predictions_idx, 3597 labels=labels, 3598 class_id=class_id, 3599 weights=weights) 3600 batch_total_fp = math_ops.cast(math_ops.reduce_sum(fp), dtypes.float64) 3601 3602 var = metric_variable([], dtypes.float64, name=scope) 3603 return var, state_ops.assign_add(var, batch_total_fp, name='update') 3604 3605 3606@tf_export(v1=['metrics.precision_at_top_k']) 3607def precision_at_top_k(labels, 3608 predictions_idx, 3609 k=None, 3610 class_id=None, 3611 weights=None, 3612 metrics_collections=None, 3613 updates_collections=None, 3614 name=None): 3615 """Computes precision@k of the predictions with respect to sparse labels. 3616 3617 Differs from `sparse_precision_at_k` in that predictions must be in the form 3618 of top `k` class indices, whereas `sparse_precision_at_k` expects logits. 3619 Refer to `sparse_precision_at_k` for more details. 3620 3621 Args: 3622 labels: `int64` `Tensor` or `SparseTensor` with shape 3623 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3624 num_labels=1. N >= 1 and num_labels is the number of target classes for 3625 the associated prediction. Commonly, N=1 and `labels` has shape 3626 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3627 should be in range [0, num_classes), where num_classes is the last 3628 dimension of `predictions`. Values outside this range are ignored. 3629 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where 3630 N >= 1. Commonly, N=1 and predictions has shape [batch size, k]. 3631 The final dimension contains the top `k` predicted class indices. 3632 [D1, ... DN] must match `labels`. 3633 k: Integer, k for @k metric. Only used for the default op name. 3634 class_id: Integer class ID for which we want binary metrics. This should be 3635 in range [0, num_classes], where num_classes is the last dimension of 3636 `predictions`. If `class_id` is outside this range, the method returns 3637 NAN. 3638 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3639 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3640 dimensions must be either `1`, or the same as the corresponding `labels` 3641 dimension). 3642 metrics_collections: An optional list of collections that values should 3643 be added to. 3644 updates_collections: An optional list of collections that updates should 3645 be added to. 3646 name: Name of new update operation, and namespace for other dependent ops. 3647 3648 Returns: 3649 precision: Scalar `float64` `Tensor` with the value of `true_positives` 3650 divided by the sum of `true_positives` and `false_positives`. 3651 update_op: `Operation` that increments `true_positives` and 3652 `false_positives` variables appropriately, and whose value matches 3653 `precision`. 3654 3655 Raises: 3656 ValueError: If `weights` is not `None` and its shape doesn't match 3657 `predictions`, or if either `metrics_collections` or `updates_collections` 3658 are not a list or tuple. 3659 RuntimeError: If eager execution is enabled. 3660 """ 3661 if context.executing_eagerly(): 3662 raise RuntimeError('tf.metrics.precision_at_top_k is not ' 3663 'supported when eager execution is enabled.') 3664 3665 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 3666 (predictions_idx, labels, weights)) as scope: 3667 labels = _maybe_expand_labels(labels, predictions_idx) 3668 top_k_idx = math_ops.cast(predictions_idx, dtypes.int64) 3669 tp, tp_update = _streaming_sparse_true_positive_at_k( 3670 predictions_idx=top_k_idx, 3671 labels=labels, 3672 k=k, 3673 class_id=class_id, 3674 weights=weights) 3675 fp, fp_update = _streaming_sparse_false_positive_at_k( 3676 predictions_idx=top_k_idx, 3677 labels=labels, 3678 k=k, 3679 class_id=class_id, 3680 weights=weights) 3681 3682 def precision_across_replicas(_, tp, fp): 3683 return math_ops.divide(tp, math_ops.add(tp, fp), name=scope) 3684 3685 metric = _aggregate_across_replicas( 3686 metrics_collections, precision_across_replicas, tp, fp) 3687 3688 update = math_ops.divide( 3689 tp_update, math_ops.add(tp_update, fp_update), name='update') 3690 if updates_collections: 3691 ops.add_to_collections(updates_collections, update) 3692 return metric, update 3693 3694 3695@tf_export(v1=['metrics.sparse_precision_at_k']) 3696@deprecated(None, 'Use precision_at_k instead') 3697def sparse_precision_at_k(labels, 3698 predictions, 3699 k, 3700 class_id=None, 3701 weights=None, 3702 metrics_collections=None, 3703 updates_collections=None, 3704 name=None): 3705 """Renamed to `precision_at_k`, please use that method instead.""" 3706 return precision_at_k( 3707 labels=labels, 3708 predictions=predictions, 3709 k=k, 3710 class_id=class_id, 3711 weights=weights, 3712 metrics_collections=metrics_collections, 3713 updates_collections=updates_collections, 3714 name=name) 3715 3716 3717@tf_export(v1=['metrics.precision_at_k']) 3718def precision_at_k(labels, 3719 predictions, 3720 k, 3721 class_id=None, 3722 weights=None, 3723 metrics_collections=None, 3724 updates_collections=None, 3725 name=None): 3726 """Computes precision@k of the predictions with respect to sparse labels. 3727 3728 If `class_id` is specified, we calculate precision by considering only the 3729 entries in the batch for which `class_id` is in the top-k highest 3730 `predictions`, and computing the fraction of them for which `class_id` is 3731 indeed a correct label. 3732 If `class_id` is not specified, we'll calculate precision as how often on 3733 average a class among the top-k classes with the highest predicted values 3734 of a batch entry is correct and can be found in the label for that entry. 3735 3736 `precision_at_k` creates two local variables, 3737 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 3738 the precision@k frequency. This frequency is ultimately returned as 3739 `precision_at_<k>`: an idempotent operation that simply divides 3740 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 3741 `false_positive_at_<k>`). 3742 3743 For estimation of the metric over a stream of data, the function creates an 3744 `update_op` operation that updates these variables and returns the 3745 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3746 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3747 `labels` calculate the true positives and false positives weighted by 3748 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3749 `false_positive_at_<k>` using these values. 3750 3751 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3752 3753 Args: 3754 labels: `int64` `Tensor` or `SparseTensor` with shape 3755 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3756 num_labels=1. N >= 1 and num_labels is the number of target classes for 3757 the associated prediction. Commonly, N=1 and `labels` has shape 3758 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 3759 should be in range [0, num_classes), where num_classes is the last 3760 dimension of `predictions`. Values outside this range are ignored. 3761 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 3762 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 3763 The final dimension contains the logit values for each class. [D1, ... DN] 3764 must match `labels`. 3765 k: Integer, k for @k metric. 3766 class_id: Integer class ID for which we want binary metrics. This should be 3767 in range [0, num_classes], where num_classes is the last dimension of 3768 `predictions`. If `class_id` is outside this range, the method returns 3769 NAN. 3770 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 3771 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 3772 dimensions must be either `1`, or the same as the corresponding `labels` 3773 dimension). 3774 metrics_collections: An optional list of collections that values should 3775 be added to. 3776 updates_collections: An optional list of collections that updates should 3777 be added to. 3778 name: Name of new update operation, and namespace for other dependent ops. 3779 3780 Returns: 3781 precision: Scalar `float64` `Tensor` with the value of `true_positives` 3782 divided by the sum of `true_positives` and `false_positives`. 3783 update_op: `Operation` that increments `true_positives` and 3784 `false_positives` variables appropriately, and whose value matches 3785 `precision`. 3786 3787 Raises: 3788 ValueError: If `weights` is not `None` and its shape doesn't match 3789 `predictions`, or if either `metrics_collections` or `updates_collections` 3790 are not a list or tuple. 3791 RuntimeError: If eager execution is enabled. 3792 """ 3793 if context.executing_eagerly(): 3794 raise RuntimeError('tf.metrics.sparse_precision_at_k is not ' 3795 'supported when eager execution is enabled.') 3796 3797 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 3798 (predictions, labels, weights)) as scope: 3799 _, top_k_idx = nn.top_k(predictions, k) 3800 return precision_at_top_k( 3801 labels=labels, 3802 predictions_idx=top_k_idx, 3803 k=k, 3804 class_id=class_id, 3805 weights=weights, 3806 metrics_collections=metrics_collections, 3807 updates_collections=updates_collections, 3808 name=scope) 3809 3810 3811@tf_export(v1=['metrics.specificity_at_sensitivity']) 3812def specificity_at_sensitivity(labels, 3813 predictions, 3814 sensitivity, 3815 weights=None, 3816 num_thresholds=200, 3817 metrics_collections=None, 3818 updates_collections=None, 3819 name=None): 3820 """Computes the specificity at a given sensitivity. 3821 3822 The `specificity_at_sensitivity` function creates four local 3823 variables, `true_positives`, `true_negatives`, `false_positives` and 3824 `false_negatives` that are used to compute the specificity at the given 3825 sensitivity value. The threshold for the given sensitivity value is computed 3826 and used to evaluate the corresponding specificity. 3827 3828 For estimation of the metric over a stream of data, the function creates an 3829 `update_op` operation that updates these variables and returns the 3830 `specificity`. `update_op` increments the `true_positives`, `true_negatives`, 3831 `false_positives` and `false_negatives` counts with the weight of each case 3832 found in the `predictions` and `labels`. 3833 3834 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3835 3836 For additional information about specificity and sensitivity, see the 3837 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity 3838 3839 Args: 3840 labels: The ground truth values, a `Tensor` whose dimensions must match 3841 `predictions`. Will be cast to `bool`. 3842 predictions: A floating point `Tensor` of arbitrary shape and whose values 3843 are in the range `[0, 1]`. 3844 sensitivity: A scalar value in range `[0, 1]`. 3845 weights: Optional `Tensor` whose rank is either 0, or the same rank as 3846 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 3847 be either `1`, or the same as the corresponding `labels` dimension). 3848 num_thresholds: The number of thresholds to use for matching the given 3849 sensitivity. 3850 metrics_collections: An optional list of collections that `specificity` 3851 should be added to. 3852 updates_collections: An optional list of collections that `update_op` should 3853 be added to. 3854 name: An optional variable_scope name. 3855 3856 Returns: 3857 specificity: A scalar `Tensor` representing the specificity at the given 3858 `sensitivity` value. 3859 update_op: An operation that increments the `true_positives`, 3860 `true_negatives`, `false_positives` and `false_negatives` variables 3861 appropriately and whose value matches `specificity`. 3862 3863 Raises: 3864 ValueError: If `predictions` and `labels` have mismatched shapes, if 3865 `weights` is not `None` and its shape doesn't match `predictions`, or if 3866 `sensitivity` is not between 0 and 1, or if either `metrics_collections` 3867 or `updates_collections` are not a list or tuple. 3868 RuntimeError: If eager execution is enabled. 3869 """ 3870 if context.executing_eagerly(): 3871 raise RuntimeError('tf.metrics.specificity_at_sensitivity is not ' 3872 'supported when eager execution is enabled.') 3873 3874 if sensitivity < 0 or sensitivity > 1: 3875 raise ValueError('`sensitivity` must be in the range [0, 1]. Currently, ' 3876 f'`sensitivity` is {sensitivity}.') 3877 3878 with variable_scope.variable_scope(name, 'specificity_at_sensitivity', 3879 (predictions, labels, weights)): 3880 kepsilon = 1e-7 # to account for floating point imprecisions 3881 thresholds = [ 3882 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 3883 ] 3884 thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon] 3885 3886 values, update_ops = _confusion_matrix_at_thresholds( 3887 labels, predictions, thresholds, weights) 3888 3889 def compute_specificity_at_sensitivity(tp, tn, fp, fn, name): 3890 """Computes the specificity at the given sensitivity. 3891 3892 Args: 3893 tp: True positives. 3894 tn: True negatives. 3895 fp: False positives. 3896 fn: False negatives. 3897 name: The name of the operation. 3898 3899 Returns: 3900 The specificity using the aggregated values. 3901 """ 3902 sensitivities = math_ops.divide(tp, tp + fn + kepsilon) 3903 3904 # We'll need to use this trick until tf.argmax allows us to specify 3905 # whether we should use the first or last index in case of ties. 3906 min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity)) 3907 indices_at_minval = math_ops.equal( 3908 math_ops.abs(sensitivities - sensitivity), min_val) 3909 indices_at_minval = math_ops.cast(indices_at_minval, dtypes.int64) 3910 indices_at_minval = math_ops.cumsum(indices_at_minval) 3911 tf_index = math_ops.argmax(indices_at_minval, 0) 3912 tf_index = math_ops.cast(tf_index, dtypes.int32) 3913 3914 # Now, we have the implicit threshold, so compute the specificity: 3915 return math_ops.divide(tn[tf_index], 3916 tn[tf_index] + fp[tf_index] + kepsilon, name) 3917 3918 def specificity_across_replicas(_, values): 3919 return compute_specificity_at_sensitivity( 3920 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 3921 3922 specificity = _aggregate_across_replicas( 3923 metrics_collections, specificity_across_replicas, values) 3924 3925 update_op = compute_specificity_at_sensitivity( 3926 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 3927 'update_op') 3928 if updates_collections: 3929 ops.add_to_collections(updates_collections, update_op) 3930 3931 return specificity, update_op 3932