xref: /aosp_15_r20/external/tensorflow/tensorflow/python/eager/monitoring_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for monitoring."""
16
17import time
18
19from tensorflow.python.eager import monitoring
20from tensorflow.python.eager import test
21from tensorflow.python.framework import errors
22from tensorflow.python.framework import test_util
23
24
class MonitoringTest(test_util.TensorFlowTestCase):
  """Tests for the counter, gauge, sampler and timer helpers in `monitoring`."""

  def test_counter(self):
    """Successive `increase_by` calls add up in an unlabeled counter cell."""
    hits = monitoring.Counter('test/counter', 'test counter')
    running_total = 0
    for step in (1, 5):
      hits.get_cell().increase_by(step)
      running_total += step
      self.assertEqual(hits.get_cell().value(), running_total)

  def test_multiple_counters(self):
    """Counters with one and with two labels keep their own cell values."""
    one_label = monitoring.Counter('test/counter1', 'test counter', 'label1')
    one_label.get_cell('foo').increase_by(1)
    self.assertEqual(one_label.get_cell('foo').value(), 1)

    two_labels = monitoring.Counter(
        'test/counter2', 'test counter', 'label1', 'label2')
    two_labels.get_cell('foo', 'bar').increase_by(5)
    self.assertEqual(two_labels.get_cell('foo', 'bar').value(), 5)

  def test_same_counter(self):
    """Creating a second counter under an existing name must raise."""
    # The first counter has to stay referenced while the duplicate is created.
    original = monitoring.Counter('test/same_counter', 'test counter')  # pylint: disable=unused-variable
    with self.assertRaises(errors.AlreadyExistsError):
      monitoring.Counter('test/same_counter', 'test counter')

  def test_int_gauge(self):
    """An integer gauge cell reports the value most recently set."""
    plain = monitoring.IntGauge('test/gauge', 'test gauge')
    for reading in (1, 5):
      plain.get_cell().set(reading)
      self.assertEqual(plain.get_cell().value(), reading)

    labeled = monitoring.IntGauge('test/gauge1', 'test gauge1', 'label1')
    labeled.get_cell('foo').set(2)
    self.assertEqual(labeled.get_cell('foo').value(), 2)

  def test_string_gauge(self):
    """A string gauge cell reports the value most recently set."""
    plain = monitoring.StringGauge('test/gauge', 'test gauge')
    for reading in ('left', 'right'):
      plain.get_cell().set(reading)
      self.assertEqual(plain.get_cell().value(), reading)

    labeled = monitoring.StringGauge('test/gauge1', 'test gauge1', 'label1')
    labeled.get_cell('foo').set('start')
    self.assertEqual(labeled.get_cell('foo').value(), 'start')

  def test_bool_gauge(self):
    """A boolean gauge cell reports the value most recently set."""
    plain = monitoring.BoolGauge('test/gauge', 'test gauge')
    plain.get_cell().set(True)
    self.assertTrue(plain.get_cell().value())
    plain.get_cell().set(False)
    self.assertFalse(plain.get_cell().value())

    labeled = monitoring.BoolGauge('test/gauge1', 'test gauge1', 'label1')
    labeled.get_cell('foo').set(True)
    self.assertTrue(labeled.get_cell('foo').value())

  def test_sampler(self):
    """Sampler cells expose min/max/num/sum of the samples added to them."""
    bucket_spec = monitoring.ExponentialBuckets(1.0, 2.0, 2)

    plain = monitoring.Sampler('test/sampler', bucket_spec, 'test sampler')
    for sample in (1.0, 5.0):
      plain.get_cell().add(sample)
    plain_histogram = plain.get_cell().value()
    self.assertEqual(plain_histogram.min, 1.0)
    self.assertEqual(plain_histogram.num, 2.0)
    self.assertEqual(plain_histogram.sum, 6.0)

    labeled = monitoring.Sampler(
        'test/sampler1', bucket_spec, 'test sampler', 'label1')
    # Only the two 'foo' samples should show up in the 'foo' histogram below.
    for label, sample in (('foo', 2.0), ('foo', 4.0), ('bar', 8.0)):
      labeled.get_cell(label).add(sample)
    foo_histogram = labeled.get_cell('foo').value()
    self.assertEqual(foo_histogram.max, 4.0)
    self.assertEqual(foo_histogram.num, 2.0)
    self.assertEqual(foo_histogram.sum, 6.0)

  def test_context_manager(self):
    """An outer `MonitoredTimer` records more than the timer nested in it."""
    timings = monitoring.Counter(
        'test/ctxmgr', 'test context manager', 'slot')
    with monitoring.MonitoredTimer(timings.get_cell('long')):
      time.sleep(0.01)
      # The inner timer covers only the second sleep.
      with monitoring.MonitoredTimer(timings.get_cell('short')):
        time.sleep(0.01)

    outer_elapsed = timings.get_cell('long').value()
    inner_elapsed = timings.get_cell('short').value()
    self.assertGreater(outer_elapsed, inner_elapsed)

  def test_function_decorator(self):
    """`monitored_timer` adds a decorated call's duration to the cell."""
    elapsed = monitoring.Counter('test/funcdecorator', 'test func decorator')

    @monitoring.monitored_timer(elapsed.get_cell())
    def sleep_for(duration):
      time.sleep(duration)

    sleep_for(0.001)
    # NOTE(review): the 1000 threshold for a 0.001 s sleep suggests the timer
    # records microseconds -- confirm against the monitoring module.
    self.assertGreater(elapsed.get_cell().value(), 1000)
120
121
if __name__ == '__main__':
  # Delegate to the imported `test` module's entry point when run as a script.
  test.main()
124