# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V1 metrics."""
from absl.testing import parameterized

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables


def _labeled_dataset_fn():
  # First four batches of x: labels, predictions -> (labels == predictions)
  #  0: 0, 0 -> True;   1: 1, 1 -> True;   2: 2, 2 -> True;   3: 3, 0 -> False
  #  4: 4, 1 -> False;  5: 0, 2 -> False;  6: 1, 0 -> False;  7: 2, 1 -> False
  #  8: 3, 2 -> False;  9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
  # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
  return dataset_ops.Dataset.range(1000).map(
      lambda x: {"labels": x % 5, "predictions": x % 3}).batch(
          4, drop_remainder=True)


def _boolean_dataset_fn():
  # First four batches of labels, predictions: {TP, FP, TN, FN}
  # with a threshold of 0.5:
  #   T, T -> TP;  F, T -> FP;  T, F -> FN
  #   F, F -> TN;  T, T -> TP;  F, T -> FP
  #   T, F -> FN;  F, F -> TN;  T, T -> TP
  #   F, T -> FP;  T, F -> FN;  F, F -> TN
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [True, False, True, False],
      "predictions": [True, True, False, False]}).repeat().batch(
          3, drop_remainder=True)


def _threshold_dataset_fn():
  # First four batches of labels, predictions: {TP, FP, TN, FN}
  # with a threshold of 0.5:
  #   True, 1.0 -> TP;  False, .75 -> FP;   True, .25 -> FN
  #  False, 0.0 -> TN;   True, 1.0 -> TP;  False, .75 -> FP
  #   True, .25 -> FN;  False, 0.0 -> TN;   True, 1.0 -> TP
  #  False, .75 -> FP;   True, .25 -> FN;  False, 0.0 -> TN
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [True, False, True, False],
      "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(
          3, drop_remainder=True)


def _regression_dataset_fn():
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [1., .5, 1., 0.],
      "predictions": [1., .75, .25, 0.]}).repeat()


def all_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.one_device_strategy,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ],
      mode=["graph"])


def tpu_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.tpu_strategy_one_step,
          strategy_combinations.tpu_strategy
      ],
      mode=["graph"])

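# Illustrative only: a minimal, framework-free sketch of how the hard-coded
# expectations in the accuracy tests below can be derived by hand from
# _labeled_dataset_fn, where x % 5 plays the role of labels and x % 3 of
# predictions. The helper name `_reference_accuracy` is hypothetical and the
# function is not used by the tests.
def _reference_accuracy(num_batches, batch_size=4):
  matches = [x % 5 == x % 3 for x in range(num_batches * batch_size)]
  return sum(matches) / float(len(matches))
# For num_batches = 1..4 this yields 0.75, 0.375, 0.25, 0.25, matching
# testAccuracy's _expected_fn.
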
# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
# metrics.precision_at_k
class MetricsV1Test(test.TestCase, parameterized.TestCase):

  def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
    with ops.Graph().as_default(), distribution.scope():
      iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
      if strategy_test_lib.is_tpu_strategy(distribution):
        def step_fn(ctx, inputs):
          value, update = distribution.extended.call_for_each_replica(
              metric_fn, args=(inputs,))
          ctx.set_non_tensor_output(name="value", output=value)
          return distribution.group(update)

        ctx = distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=distribution.extended.steps_per_run)
        update = ctx.run_op
        value = ctx.non_tensor_outputs["value"]
        # In each run, we run multiple steps, and each step consumes as many
        # batches as the number of replicas.
        batches_per_update = (
            distribution.num_replicas_in_sync *
            distribution.extended.steps_per_run)
      else:
        value, update = distribution.extended.call_for_each_replica(
            metric_fn, args=(iterator.get_next(),))
        update = distribution.group(update)
        # TODO(josh11b): Once we switch to using a global batch size for input,
        # replace "distribution.num_replicas_in_sync" with "1".
        batches_per_update = distribution.num_replicas_in_sync

      self.evaluate(iterator.initializer)
      self.evaluate(variables.local_variables_initializer())

      batches_consumed = 0
      for i in range(4):
        self.evaluate(update)
        batches_consumed += batches_per_update
        self.assertAllClose(expected_fn(batches_consumed),
                            self.evaluate(value),
                            0.001,
                            msg="After update #" + str(i + 1))
        if batches_consumed >= 4:  # Consume 4 input batches in total.
          break

  @combinations.generate(all_combinations() + tpu_combinations())
  def testMean(self, distribution):
    def _dataset_fn():
      return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(
          4, drop_remainder=True)

    def _expected_fn(num_batches):
      # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
      return num_batches * 2 - 0.5

    self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testAccuracy(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.accuracy(labels, predictions)

    def _expected_fn(num_batches):
      return [3./4, 3./8, 3./12, 4./16][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  # TODO(priyag, jhseu): Enable TPU for this test once scatter_add is added
  # for TPUMirroredVariable.
  @combinations.generate(all_combinations())
  def testMeanPerClassAccuracy(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_per_class_accuracy(
          labels, predictions, num_classes=5)

    def _expected_fn(num_batches):
      mean = lambda x: sum(x) / len(x)
      return [mean([1., 1., 1., 0., 0.]),
              mean([0.5, 0.5, 0.5, 0., 0.]),
              mean([1./3, 1./3, 0.5, 0., 0.]),
              mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

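  # Illustrative derivation (hypothetical, not executed): in
  # _labeled_dataset_fn the labels cycle 0..4 while the predictions cycle
  # 0..2, so after one batch (x = 0..3) classes 0-2 are each 1/1 correct,
  # class 3 is 0/1, and class 4 has no examples yet and contributes 0 --
  # hence mean([1., 1., 1., 0., 0.]) above. Later batches dilute each class's
  # running accuracy the same way, giving the remaining expected values.
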
  # NOTE(priyag): This metric doesn't work on TPUs yet.
  @combinations.generate(all_combinations())
  def testMeanIOU(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_iou(
          labels, predictions, num_classes=5)

    def _expected_fn(num_batches):
      mean = lambda x: sum(x) / len(x)
      return [mean([1./2, 1./1, 1./1, 0.]),  # no class 4 in first batch
              mean([1./4, 1./4, 1./3, 0., 0.]),
              mean([1./6, 1./6, 1./5, 0., 0.]),
              mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testMeanTensor(self, distribution):
    def _dataset_fn():
      dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
      # Want to produce a fixed, known shape, so drop remainder when batching.
      dataset = dataset.batch(4, drop_remainder=True)
      return dataset

    def _expected_fn(num_batches):
      # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
      # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
      # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
      # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
      first = 2. * num_batches - 2.
      return [first, first + 1., first + 2., first + 3.]

    self._test_metric(
        distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testAUCROC(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
                         summation_method="careful_interpolation")

    def _expected_fn(num_batches):
      return [0.5, 7./9, 0.8, 0.75][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testAUCPR(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
                         summation_method="careful_interpolation")

    def _expected_fn(num_batches):
      return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testFalseNegatives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_negatives(labels, predictions)

    def _expected_fn(num_batches):
      return [1., 1., 2., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testFalseNegativesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_negatives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[1.], [1.], [2.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

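  # Illustrative only (hypothetical sketch, not used by the tests): the
  # {true,false}_{positives,negatives} expectations in this file follow from
  # the fact that _threshold_dataset_fn (and _boolean_dataset_fn) cycle the
  # outcome pattern (TP, FP, FN, TN) with batch size 3:
  #
  #   counts = {"tp": 0, "fp": 0, "fn": 0, "tn": 0}
  #   for i in range(num_examples):  # num_examples in {3, 6, 9, 12}
  #     counts[("tp", "fp", "fn", "tn")[i % 4]] += 1
  #
  # giving TP = [1, 2, 3, 3], FP = [1, 2, 2, 3], FN = [1, 1, 2, 3] and
  # TN = [0, 1, 2, 3], which also yield the precision (TP / (TP + FP)) and
  # recall (TP / (TP + FN)) expectations further below.
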
predictions = x["predictions"] 278 return metrics.true_negatives(labels, predictions) 279 280 def _expected_fn(num_batches): 281 return [0., 1., 2., 3.][num_batches - 1] 282 283 self._test_metric( 284 distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) 285 286 @combinations.generate(all_combinations() + tpu_combinations()) 287 def testTrueNegativesAtThresholds(self, distribution): 288 def _metric_fn(x): 289 labels = x["labels"] 290 predictions = x["predictions"] 291 return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) 292 293 def _expected_fn(num_batches): 294 return [[0.], [1.], [2.], [3.]][num_batches - 1] 295 296 self._test_metric( 297 distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) 298 299 @combinations.generate(all_combinations() + tpu_combinations()) 300 def testFalsePositives(self, distribution): 301 def _metric_fn(x): 302 labels = x["labels"] 303 predictions = x["predictions"] 304 return metrics.false_positives(labels, predictions) 305 306 def _expected_fn(num_batches): 307 return [1., 2., 2., 3.][num_batches - 1] 308 309 self._test_metric( 310 distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) 311 312 @combinations.generate(all_combinations() + tpu_combinations()) 313 def testFalsePositivesAtThresholds(self, distribution): 314 def _metric_fn(x): 315 labels = x["labels"] 316 predictions = x["predictions"] 317 return metrics.false_positives_at_thresholds(labels, predictions, [.5]) 318 319 def _expected_fn(num_batches): 320 return [[1.], [2.], [2.], [3.]][num_batches - 1] 321 322 self._test_metric( 323 distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) 324 325 @combinations.generate(all_combinations() + tpu_combinations()) 326 def testTruePositives(self, distribution): 327 def _metric_fn(x): 328 labels = x["labels"] 329 predictions = x["predictions"] 330 return metrics.true_positives(labels, predictions) 331 332 def _expected_fn(num_batches): 333 return [1., 2., 3., 3.][num_batches - 1] 334 335 self._test_metric( 336 distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) 337 338 @combinations.generate(all_combinations() + tpu_combinations()) 339 def testTruePositivesAtThresholds(self, distribution): 340 def _metric_fn(x): 341 labels = x["labels"] 342 predictions = x["predictions"] 343 return metrics.true_positives_at_thresholds(labels, predictions, [.5]) 344 345 def _expected_fn(num_batches): 346 return [[1.], [2.], [3.], [3.]][num_batches - 1] 347 348 self._test_metric( 349 distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) 350 351 @combinations.generate(all_combinations() + tpu_combinations()) 352 def testPrecision(self, distribution): 353 def _metric_fn(x): 354 labels = x["labels"] 355 predictions = x["predictions"] 356 return metrics.precision(labels, predictions) 357 358 def _expected_fn(num_batches): 359 return [0.5, 0.5, 0.6, 0.5][num_batches - 1] 360 361 self._test_metric( 362 distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) 363 364 @combinations.generate(all_combinations() + tpu_combinations()) 365 def testPrecisionAtThreshold(self, distribution): 366 def _metric_fn(x): 367 labels = x["labels"] 368 predictions = x["predictions"] 369 return metrics.precision_at_thresholds(labels, predictions, [0.5]) 370 371 def _expected_fn(num_batches): 372 return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] 373 374 self._test_metric( 375 distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) 376 377 @combinations.generate(all_combinations() + tpu_combinations()) 378 def testRecall(self, 
  @combinations.generate(all_combinations() + tpu_combinations())
  def testRecall(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.recall(labels, predictions)

    def _expected_fn(num_batches):
      return [0.5, 2./3, 0.6, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testRecallAtThreshold(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.recall_at_thresholds(labels, predictions, [0.5])

    def _expected_fn(num_batches):
      return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testMeanSquaredError(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_squared_error(labels, predictions)

    def _expected_fn(num_batches):
      return [0., 1./32, 0.208333, 0.15625][num_batches - 1]

    self._test_metric(
        distribution, _regression_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testRootMeanSquaredError(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.root_mean_squared_error(labels, predictions)

    def _expected_fn(num_batches):
      return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]

    self._test_metric(
        distribution, _regression_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testSensitivityAtSpecificity(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.sensitivity_at_specificity(labels, predictions, 0.8)

    def _expected_fn(num_batches):
      return [0.5, 2./3, 0.6, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testSpecificityAtSensitivity(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.specificity_at_sensitivity(labels, predictions, 0.95)

    def _expected_fn(num_batches):
      return [0., 1./3, 0.5, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)


if __name__ == "__main__":
  test.main()