# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V1 metrics."""
from absl.testing import parameterized
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables


def _labeled_dataset_fn():
  # First four batches of x: labels, predictions -> (labels == predictions)
  #  0: 0, 0 -> True;   1: 1, 1 -> True;   2: 2, 2 -> True;   3: 3, 0 -> False
  #  4: 4, 1 -> False;  5: 0, 2 -> False;  6: 1, 0 -> False;  7: 2, 1 -> False
  #  8: 3, 2 -> False;  9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
  # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
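  # (Labels cycle with period 5 and predictions with period 3, so the
  # True/False pattern above repeats every lcm(5, 3) = 15 elements.)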
  return dataset_ops.Dataset.range(1000).map(
      lambda x: {"labels": x % 5, "predictions": x % 3}).batch(
          4, drop_remainder=True)


def _boolean_dataset_fn():
  # First four batches of labels, predictions: {TP, FP, TN, FN}
  # with a threshold of 0.5:
  #   T, T -> TP;  F, T -> FP;   T, F -> FN
  #   F, F -> TN;  T, T -> TP;   F, T -> FP
  #   T, F -> FN;  F, F -> TN;   T, T -> TP
  #   F, T -> FP;  T, F -> FN;   F, F -> TN
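  # (Batching the period-4 example pattern in groups of 3 is what rotates the
  # outcomes across batches; after 4 batches each of TP, FP, TN, and FN has
  # occurred exactly 3 times.)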
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [True, False, True, False],
      "predictions": [True, True, False, False]}).repeat().batch(
          3, drop_remainder=True)


def _threshold_dataset_fn():
  # First four batches of labels, predictions: {TP, FP, TN, FN}
  # with a threshold of 0.5:
  #   True, 1.0 -> TP;  False, .75 -> FP;   True, .25 -> FN
  #  False, 0.0 -> TN;   True, 1.0 -> TP;  False, .75 -> FP
  #   True, .25 -> FN;  False, 0.0 -> TN;   True, 1.0 -> TP
  #  False, .75 -> FP;   True, .25 -> FN;  False, 0.0 -> TN
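  # (Thresholded at 0.5, these scores reproduce _boolean_dataset_fn exactly,
  # so both datasets yield the same cumulative confusion-matrix counts.)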
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [True, False, True, False],
      "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(
          3, drop_remainder=True)


def _regression_dataset_fn():
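  # Note: this dataset is not batched, so each "batch" the test consumes is a
  # single example. Per-example squared errors: [0., 1./16, 9./16, 0.].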
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [1., .5, 1., 0.],
      "predictions": [1., .75, .25, 0.]}).repeat()


def all_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.one_device_strategy,
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus,
      ],
      mode=["graph"])


def tpu_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.tpu_strategy_one_step,
          strategy_combinations.tpu_strategy
      ],
      mode=["graph"])


# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
# metrics.precision_at_k
class MetricsV1Test(test.TestCase, parameterized.TestCase):

  def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
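    """Evaluates `metric_fn` under `distribution` and checks its value.

    `metric_fn` follows the V1 tf.metrics contract: it returns a
    (value_tensor, update_op) pair. `expected_fn` maps the cumulative number
    of batches consumed to the expected metric value at that point.
    """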
    with ops.Graph().as_default(), distribution.scope():
      iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
      if strategy_test_lib.is_tpu_strategy(distribution):
        def step_fn(ctx, inputs):
          value, update = distribution.extended.call_for_each_replica(
              metric_fn, args=(inputs,))
          ctx.set_non_tensor_output(name="value", output=value)
          return distribution.group(update)

        ctx = distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=distribution.extended.steps_per_run)
        update = ctx.run_op
        value = ctx.non_tensor_outputs["value"]
        # In each run, we run multiple steps, and each step consumes as many
        # batches as the number of replicas.
        batches_per_update = (
            distribution.num_replicas_in_sync *
            distribution.extended.steps_per_run)
      else:
        value, update = distribution.extended.call_for_each_replica(
            metric_fn, args=(iterator.get_next(),))
        update = distribution.group(update)
        # TODO(josh11b): Once we switch to using a global batch size for input,
        # replace "distribution.num_replicas_in_sync" with "1".
        batches_per_update = distribution.num_replicas_in_sync

      self.evaluate(iterator.initializer)
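      # V1 tf.metrics keep their accumulator state in local variables, so we
      # initialize local (rather than global) variables before updating.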
      self.evaluate(variables.local_variables_initializer())

      batches_consumed = 0
      for i in range(4):
        self.evaluate(update)
        batches_consumed += batches_per_update
        self.assertAllClose(expected_fn(batches_consumed),
                            self.evaluate(value),
                            0.001,
                            msg="After update #" + str(i+1))
        if batches_consumed >= 4:  # Consume 4 input batches in total.
          break

  @combinations.generate(all_combinations() + tpu_combinations())
  def testMean(self, distribution):
    def _dataset_fn():
      return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(
          4, drop_remainder=True)

    def _expected_fn(num_batches):
      # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
      return num_batches * 2 - 0.5

    self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testAccuracy(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.accuracy(labels, predictions)

    def _expected_fn(num_batches):
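      # From _labeled_dataset_fn: 3 of the first 4, 8, and 12 examples match,
      # and 4 of the first 16.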
      return [3./4, 3./8, 3./12, 4./16][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  # TODO(priyag, jhseu): Enable TPU for this test once scatter_add is added
  # for TPUMirroredVariable.
  @combinations.generate(all_combinations())
  def testMeanPerClassAccuracy(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_per_class_accuracy(
          labels, predictions, num_classes=5)

    def _expected_fn(num_batches):
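      # Per-class accuracy is correct / seen for each label class, averaged
      # over all num_classes classes; classes not seen yet contribute 0.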
      mean = lambda x: sum(x) / len(x)
      return [mean([1., 1., 1., 0., 0.]),
              mean([0.5, 0.5, 0.5, 0., 0.]),
              mean([1./3, 1./3, 0.5, 0., 0.]),
              mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  # NOTE(priyag): This metric doesn't work on TPUs yet.
  @combinations.generate(all_combinations())
  def testMeanIOU(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_iou(
          labels, predictions, num_classes=5)

    def _expected_fn(num_batches):
      mean = lambda x: sum(x) / len(x)
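      # Per-class IoU = TP / (TP + FP + FN), averaged over the classes seen
      # so far in labels or predictions.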
      return [mean([1./2, 1./1, 1./1, 0.]),  # no class 4 in first batch
              mean([1./4, 1./4, 1./3, 0., 0.]),
              mean([1./6, 1./6, 1./5, 0., 0.]),
              mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testMeanTensor(self, distribution):
    def _dataset_fn():
      dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
      # Want to produce a fixed, known shape, so drop remainder when batching.
      dataset = dataset.batch(4, drop_remainder=True)
      return dataset

    def _expected_fn(num_batches):
      # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
      # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
      # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
      # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
      first = 2. * num_batches - 2.
      return [first, first + 1., first + 2., first + 3.]

    self._test_metric(
        distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testAUCROC(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
                         summation_method="careful_interpolation")

    def _expected_fn(num_batches):
      return [0.5, 7./9, 0.8, 0.75][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testAUCPR(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
                         summation_method="careful_interpolation")

    def _expected_fn(num_batches):
      return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testFalseNegatives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_negatives(labels, predictions)

    def _expected_fn(num_batches):
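      # Cumulative FN count after each batch, read off the pattern enumerated
      # in _boolean_dataset_fn; the TN/FP/TP tests below follow the same idea.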
      return [1., 1., 2., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testFalseNegativesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_negatives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[1.], [1.], [2.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testTrueNegatives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_negatives(labels, predictions)

    def _expected_fn(num_batches):
      return [0., 1., 2., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testTrueNegativesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_negatives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[0.], [1.], [2.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testFalsePositives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_positives(labels, predictions)

    def _expected_fn(num_batches):
      return [1., 2., 2., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testFalsePositivesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_positives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[1.], [2.], [2.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testTruePositives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_positives(labels, predictions)

    def _expected_fn(num_batches):
      return [1., 2., 3., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testTruePositivesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_positives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[1.], [2.], [3.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testPrecision(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.precision(labels, predictions)

    def _expected_fn(num_batches):
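      # precision = TP / (TP + FP); the cumulative counts give
      # 1/2, 2/4, 3/5, 3/6.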
      return [0.5, 0.5, 0.6, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testPrecisionAtThreshold(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.precision_at_thresholds(labels, predictions, [0.5])

    def _expected_fn(num_batches):
      return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testRecall(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.recall(labels, predictions)

    def _expected_fn(num_batches):
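      # recall = TP / (TP + FN); the cumulative counts give
      # 1/2, 2/3, 3/5, 3/6.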
      return [0.5, 2./3, 0.6, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testRecallAtThreshold(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.recall_at_thresholds(labels, predictions, [0.5])

    def _expected_fn(num_batches):
      return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testMeanSquaredError(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_squared_error(labels, predictions)

    def _expected_fn(num_batches):
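      # Running mean of the per-example squared errors from
      # _regression_dataset_fn: 0/1, (1/16)/2, (10/16)/3, (10/16)/4.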
      return [0., 1./32, 0.208333, 0.15625][num_batches - 1]

    self._test_metric(
        distribution, _regression_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations() + tpu_combinations())
  def testRootMeanSquaredError(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.root_mean_squared_error(labels, predictions)

    def _expected_fn(num_batches):
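      # Square roots of the cumulative MSE values from the previous test.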
      return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]

    self._test_metric(
        distribution, _regression_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testSensitivityAtSpecificity(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.sensitivity_at_specificity(labels, predictions, 0.8)

    def _expected_fn(num_batches):
      return [0.5, 2./3, 0.6, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testSpecificityAtSensitivity(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.specificity_at_sensitivity(labels, predictions, 0.95)

    def _expected_fn(num_batches):
      return [0., 1./3, 0.5, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)


if __name__ == "__main__":
  test.main()