xref: /aosp_15_r20/external/pytorch/test/test_monitor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: r2p"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport tempfile
4*da0073e9SAndroid Build Coastguard Workerimport time
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom datetime import datetime, timedelta
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom torch.monitor import (
9*da0073e9SAndroid Build Coastguard Worker    Aggregation,
10*da0073e9SAndroid Build Coastguard Worker    Event,
11*da0073e9SAndroid Build Coastguard Worker    log_event,
12*da0073e9SAndroid Build Coastguard Worker    register_event_handler,
13*da0073e9SAndroid Build Coastguard Worker    Stat,
14*da0073e9SAndroid Build Coastguard Worker    TensorboardEventHandler,
15*da0073e9SAndroid Build Coastguard Worker    unregister_event_handler,
16*da0073e9SAndroid Build Coastguard Worker    _WaitCounter,
17*da0073e9SAndroid Build Coastguard Worker)
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Workerclass TestMonitor(TestCase):
21*da0073e9SAndroid Build Coastguard Worker    def test_interval_stat(self) -> None:
22*da0073e9SAndroid Build Coastguard Worker        events = []
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker        def handler(event):
25*da0073e9SAndroid Build Coastguard Worker            events.append(event)
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker        handle = register_event_handler(handler)
28*da0073e9SAndroid Build Coastguard Worker        s = Stat(
29*da0073e9SAndroid Build Coastguard Worker            "asdf",
30*da0073e9SAndroid Build Coastguard Worker            (Aggregation.SUM, Aggregation.COUNT),
31*da0073e9SAndroid Build Coastguard Worker            timedelta(milliseconds=1),
32*da0073e9SAndroid Build Coastguard Worker        )
33*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.name, "asdf")
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker        s.add(2)
36*da0073e9SAndroid Build Coastguard Worker        for _ in range(100):
37*da0073e9SAndroid Build Coastguard Worker            # NOTE: different platforms sleep may be inaccurate so we loop
38*da0073e9SAndroid Build Coastguard Worker            # instead (i.e. win)
39*da0073e9SAndroid Build Coastguard Worker            time.sleep(1 / 1000)  # ms
40*da0073e9SAndroid Build Coastguard Worker            s.add(3)
41*da0073e9SAndroid Build Coastguard Worker            if len(events) >= 1:
42*da0073e9SAndroid Build Coastguard Worker                break
43*da0073e9SAndroid Build Coastguard Worker        self.assertGreaterEqual(len(events), 1)
44*da0073e9SAndroid Build Coastguard Worker        unregister_event_handler(handle)
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker    def test_fixed_count_stat(self) -> None:
47*da0073e9SAndroid Build Coastguard Worker        s = Stat(
48*da0073e9SAndroid Build Coastguard Worker            "asdf",
49*da0073e9SAndroid Build Coastguard Worker            (Aggregation.SUM, Aggregation.COUNT),
50*da0073e9SAndroid Build Coastguard Worker            timedelta(hours=100),
51*da0073e9SAndroid Build Coastguard Worker            3,
52*da0073e9SAndroid Build Coastguard Worker        )
53*da0073e9SAndroid Build Coastguard Worker        s.add(1)
54*da0073e9SAndroid Build Coastguard Worker        s.add(2)
55*da0073e9SAndroid Build Coastguard Worker        name = s.name
56*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(name, "asdf")
57*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.count, 2)
58*da0073e9SAndroid Build Coastguard Worker        s.add(3)
59*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.count, 0)
60*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker    def test_log_event(self) -> None:
63*da0073e9SAndroid Build Coastguard Worker        e = Event(
64*da0073e9SAndroid Build Coastguard Worker            name="torch.monitor.TestEvent",
65*da0073e9SAndroid Build Coastguard Worker            timestamp=datetime.now(),
66*da0073e9SAndroid Build Coastguard Worker            data={
67*da0073e9SAndroid Build Coastguard Worker                "str": "a string",
68*da0073e9SAndroid Build Coastguard Worker                "float": 1234.0,
69*da0073e9SAndroid Build Coastguard Worker                "int": 1234,
70*da0073e9SAndroid Build Coastguard Worker            },
71*da0073e9SAndroid Build Coastguard Worker        )
72*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(e.name, "torch.monitor.TestEvent")
73*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(e.timestamp)
74*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(e.data)
75*da0073e9SAndroid Build Coastguard Worker        log_event(e)
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("Really weird error")
78*da0073e9SAndroid Build Coastguard Worker    def test_event_handler(self) -> None:
79*da0073e9SAndroid Build Coastguard Worker        events = []
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker        def handler(event: Event) -> None:
82*da0073e9SAndroid Build Coastguard Worker            events.append(event)
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker        handle = register_event_handler(handler)
85*da0073e9SAndroid Build Coastguard Worker        e = Event(
86*da0073e9SAndroid Build Coastguard Worker            name="torch.monitor.TestEvent",
87*da0073e9SAndroid Build Coastguard Worker            timestamp=datetime.now(),
88*da0073e9SAndroid Build Coastguard Worker            data={},
89*da0073e9SAndroid Build Coastguard Worker        )
90*da0073e9SAndroid Build Coastguard Worker        log_event(e)
91*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(events), 1)
92*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(events[0], e)
93*da0073e9SAndroid Build Coastguard Worker        log_event(e)
94*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(events), 2)
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker        unregister_event_handler(handle)
97*da0073e9SAndroid Build Coastguard Worker        log_event(e)
98*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(events), 2)
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker    def test_wait_counter(self) -> None:
101*da0073e9SAndroid Build Coastguard Worker        wait_counter = _WaitCounter(
102*da0073e9SAndroid Build Coastguard Worker            "test_wait_counter",
103*da0073e9SAndroid Build Coastguard Worker        )
104*da0073e9SAndroid Build Coastguard Worker        with wait_counter.guard() as wcg:
105*da0073e9SAndroid Build Coastguard Worker            pass
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Really weird error")
109*da0073e9SAndroid Build Coastguard Workerclass TestMonitorTensorboard(TestCase):
110*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
111*da0073e9SAndroid Build Coastguard Worker        global SummaryWriter, event_multiplexer
112*da0073e9SAndroid Build Coastguard Worker        try:
113*da0073e9SAndroid Build Coastguard Worker            from torch.utils.tensorboard import SummaryWriter
114*da0073e9SAndroid Build Coastguard Worker            from tensorboard.backend.event_processing import (
115*da0073e9SAndroid Build Coastguard Worker                plugin_event_multiplexer as event_multiplexer,
116*da0073e9SAndroid Build Coastguard Worker            )
117*da0073e9SAndroid Build Coastguard Worker        except ImportError:
118*da0073e9SAndroid Build Coastguard Worker            return self.skipTest("Skip the test since TensorBoard is not installed")
119*da0073e9SAndroid Build Coastguard Worker        self.temp_dirs = []
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker    def create_summary_writer(self):
122*da0073e9SAndroid Build Coastguard Worker        temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
123*da0073e9SAndroid Build Coastguard Worker        self.temp_dirs.append(temp_dir)
124*da0073e9SAndroid Build Coastguard Worker        return SummaryWriter(temp_dir.name)
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
127*da0073e9SAndroid Build Coastguard Worker        # Remove directories created by SummaryWriter
128*da0073e9SAndroid Build Coastguard Worker        for temp_dir in self.temp_dirs:
129*da0073e9SAndroid Build Coastguard Worker            temp_dir.cleanup()
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker    def test_event_handler(self):
132*da0073e9SAndroid Build Coastguard Worker        with self.create_summary_writer() as w:
133*da0073e9SAndroid Build Coastguard Worker            handle = register_event_handler(TensorboardEventHandler(w))
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker            s = Stat(
136*da0073e9SAndroid Build Coastguard Worker                "asdf",
137*da0073e9SAndroid Build Coastguard Worker                (Aggregation.SUM, Aggregation.COUNT),
138*da0073e9SAndroid Build Coastguard Worker                timedelta(hours=1),
139*da0073e9SAndroid Build Coastguard Worker                5,
140*da0073e9SAndroid Build Coastguard Worker            )
141*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
142*da0073e9SAndroid Build Coastguard Worker                s.add(i)
143*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(s.count, 0)
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker            unregister_event_handler(handle)
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker        mul = event_multiplexer.EventMultiplexer()
148*da0073e9SAndroid Build Coastguard Worker        mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
149*da0073e9SAndroid Build Coastguard Worker        mul.Reload()
150*da0073e9SAndroid Build Coastguard Worker        scalar_dict = mul.PluginRunToTagToContent("scalars")
151*da0073e9SAndroid Build Coastguard Worker        raw_result = {
152*da0073e9SAndroid Build Coastguard Worker            tag: mul.Tensors(run, tag)
153*da0073e9SAndroid Build Coastguard Worker            for run, run_dict in scalar_dict.items()
154*da0073e9SAndroid Build Coastguard Worker            for tag in run_dict
155*da0073e9SAndroid Build Coastguard Worker        }
156*da0073e9SAndroid Build Coastguard Worker        scalars = {
157*da0073e9SAndroid Build Coastguard Worker            tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
158*da0073e9SAndroid Build Coastguard Worker        }
159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scalars, {
160*da0073e9SAndroid Build Coastguard Worker            "asdf.sum": [10],
161*da0073e9SAndroid Build Coastguard Worker            "asdf.count": [5],
162*da0073e9SAndroid Build Coastguard Worker        })
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker
# Script entry point: hand control to the PyTorch test runner when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    run_tests()
167