1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: r2p"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport tempfile 4*da0073e9SAndroid Build Coastguard Workerimport time 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerfrom datetime import datetime, timedelta 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerfrom torch.monitor import ( 9*da0073e9SAndroid Build Coastguard Worker Aggregation, 10*da0073e9SAndroid Build Coastguard Worker Event, 11*da0073e9SAndroid Build Coastguard Worker log_event, 12*da0073e9SAndroid Build Coastguard Worker register_event_handler, 13*da0073e9SAndroid Build Coastguard Worker Stat, 14*da0073e9SAndroid Build Coastguard Worker TensorboardEventHandler, 15*da0073e9SAndroid Build Coastguard Worker unregister_event_handler, 16*da0073e9SAndroid Build Coastguard Worker _WaitCounter, 17*da0073e9SAndroid Build Coastguard Worker) 18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerclass TestMonitor(TestCase): 21*da0073e9SAndroid Build Coastguard Worker def test_interval_stat(self) -> None: 22*da0073e9SAndroid Build Coastguard Worker events = [] 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker def handler(event): 25*da0073e9SAndroid Build Coastguard Worker events.append(event) 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker handle = register_event_handler(handler) 28*da0073e9SAndroid Build Coastguard Worker s = Stat( 29*da0073e9SAndroid Build Coastguard Worker "asdf", 30*da0073e9SAndroid Build Coastguard Worker (Aggregation.SUM, Aggregation.COUNT), 31*da0073e9SAndroid Build Coastguard Worker timedelta(milliseconds=1), 32*da0073e9SAndroid Build Coastguard Worker ) 33*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.name, "asdf") 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker s.add(2) 36*da0073e9SAndroid Build Coastguard Worker for _ in range(100): 37*da0073e9SAndroid Build Coastguard Worker # NOTE: different platforms sleep may be inaccurate so we loop 38*da0073e9SAndroid Build Coastguard Worker # instead (i.e. win) 39*da0073e9SAndroid Build Coastguard Worker time.sleep(1 / 1000) # ms 40*da0073e9SAndroid Build Coastguard Worker s.add(3) 41*da0073e9SAndroid Build Coastguard Worker if len(events) >= 1: 42*da0073e9SAndroid Build Coastguard Worker break 43*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(len(events), 1) 44*da0073e9SAndroid Build Coastguard Worker unregister_event_handler(handle) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker def test_fixed_count_stat(self) -> None: 47*da0073e9SAndroid Build Coastguard Worker s = Stat( 48*da0073e9SAndroid Build Coastguard Worker "asdf", 49*da0073e9SAndroid Build Coastguard Worker (Aggregation.SUM, Aggregation.COUNT), 50*da0073e9SAndroid Build Coastguard Worker timedelta(hours=100), 51*da0073e9SAndroid Build Coastguard Worker 3, 52*da0073e9SAndroid Build Coastguard Worker ) 53*da0073e9SAndroid Build Coastguard Worker s.add(1) 54*da0073e9SAndroid Build Coastguard Worker s.add(2) 55*da0073e9SAndroid Build Coastguard Worker name = s.name 56*da0073e9SAndroid Build Coastguard Worker self.assertEqual(name, "asdf") 57*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.count, 2) 58*da0073e9SAndroid Build Coastguard Worker s.add(3) 59*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.count, 0) 60*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3}) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker def test_log_event(self) -> None: 63*da0073e9SAndroid Build Coastguard Worker e = Event( 64*da0073e9SAndroid Build Coastguard Worker name="torch.monitor.TestEvent", 65*da0073e9SAndroid Build Coastguard Worker timestamp=datetime.now(), 66*da0073e9SAndroid Build Coastguard Worker data={ 67*da0073e9SAndroid Build Coastguard Worker "str": "a string", 68*da0073e9SAndroid Build Coastguard Worker "float": 1234.0, 69*da0073e9SAndroid Build Coastguard Worker "int": 1234, 70*da0073e9SAndroid Build Coastguard Worker }, 71*da0073e9SAndroid Build Coastguard Worker ) 72*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.name, "torch.monitor.TestEvent") 73*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(e.timestamp) 74*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(e.data) 75*da0073e9SAndroid Build Coastguard Worker log_event(e) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Really weird error") 78*da0073e9SAndroid Build Coastguard Worker def test_event_handler(self) -> None: 79*da0073e9SAndroid Build Coastguard Worker events = [] 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker def handler(event: Event) -> None: 82*da0073e9SAndroid Build Coastguard Worker events.append(event) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker handle = register_event_handler(handler) 85*da0073e9SAndroid Build Coastguard Worker e = Event( 86*da0073e9SAndroid Build Coastguard Worker name="torch.monitor.TestEvent", 87*da0073e9SAndroid Build Coastguard Worker timestamp=datetime.now(), 88*da0073e9SAndroid Build Coastguard Worker data={}, 89*da0073e9SAndroid Build Coastguard Worker ) 90*da0073e9SAndroid Build Coastguard Worker log_event(e) 91*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(events), 1) 92*da0073e9SAndroid Build Coastguard Worker self.assertEqual(events[0], e) 93*da0073e9SAndroid Build Coastguard Worker log_event(e) 94*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(events), 2) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker unregister_event_handler(handle) 97*da0073e9SAndroid Build Coastguard Worker log_event(e) 98*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(events), 2) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker def test_wait_counter(self) -> None: 101*da0073e9SAndroid Build Coastguard Worker wait_counter = _WaitCounter( 102*da0073e9SAndroid Build Coastguard Worker "test_wait_counter", 103*da0073e9SAndroid Build Coastguard Worker ) 104*da0073e9SAndroid Build Coastguard Worker with wait_counter.guard() as wcg: 105*da0073e9SAndroid Build Coastguard Worker pass 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo("Really weird error") 109*da0073e9SAndroid Build Coastguard Workerclass TestMonitorTensorboard(TestCase): 110*da0073e9SAndroid Build Coastguard Worker def setUp(self): 111*da0073e9SAndroid Build Coastguard Worker global SummaryWriter, event_multiplexer 112*da0073e9SAndroid Build Coastguard Worker try: 113*da0073e9SAndroid Build Coastguard Worker from torch.utils.tensorboard import SummaryWriter 114*da0073e9SAndroid Build Coastguard Worker from tensorboard.backend.event_processing import ( 115*da0073e9SAndroid Build Coastguard Worker plugin_event_multiplexer as event_multiplexer, 116*da0073e9SAndroid Build Coastguard Worker ) 117*da0073e9SAndroid Build Coastguard Worker except ImportError: 118*da0073e9SAndroid Build Coastguard Worker return self.skipTest("Skip the test since TensorBoard is not installed") 119*da0073e9SAndroid Build Coastguard Worker self.temp_dirs = [] 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker def create_summary_writer(self): 122*da0073e9SAndroid Build Coastguard Worker temp_dir = tempfile.TemporaryDirectory() # noqa: P201 123*da0073e9SAndroid Build Coastguard Worker self.temp_dirs.append(temp_dir) 124*da0073e9SAndroid Build Coastguard Worker return SummaryWriter(temp_dir.name) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 127*da0073e9SAndroid Build Coastguard Worker # Remove directories created by SummaryWriter 128*da0073e9SAndroid Build Coastguard Worker for temp_dir in self.temp_dirs: 129*da0073e9SAndroid Build Coastguard Worker temp_dir.cleanup() 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker def test_event_handler(self): 132*da0073e9SAndroid Build Coastguard Worker with self.create_summary_writer() as w: 133*da0073e9SAndroid Build Coastguard Worker handle = register_event_handler(TensorboardEventHandler(w)) 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker s = Stat( 136*da0073e9SAndroid Build Coastguard Worker "asdf", 137*da0073e9SAndroid Build Coastguard Worker (Aggregation.SUM, Aggregation.COUNT), 138*da0073e9SAndroid Build Coastguard Worker timedelta(hours=1), 139*da0073e9SAndroid Build Coastguard Worker 5, 140*da0073e9SAndroid Build Coastguard Worker ) 141*da0073e9SAndroid Build Coastguard Worker for i in range(10): 142*da0073e9SAndroid Build Coastguard Worker s.add(i) 143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.count, 0) 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker unregister_event_handler(handle) 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker mul = event_multiplexer.EventMultiplexer() 148*da0073e9SAndroid Build Coastguard Worker mul.AddRunsFromDirectory(self.temp_dirs[-1].name) 149*da0073e9SAndroid Build Coastguard Worker mul.Reload() 150*da0073e9SAndroid Build Coastguard Worker scalar_dict = mul.PluginRunToTagToContent("scalars") 151*da0073e9SAndroid Build Coastguard Worker raw_result = { 152*da0073e9SAndroid Build Coastguard Worker tag: mul.Tensors(run, tag) 153*da0073e9SAndroid Build Coastguard Worker for run, run_dict in scalar_dict.items() 154*da0073e9SAndroid Build Coastguard Worker for tag in run_dict 155*da0073e9SAndroid Build Coastguard Worker } 156*da0073e9SAndroid Build Coastguard Worker scalars = { 157*da0073e9SAndroid Build Coastguard Worker tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items() 158*da0073e9SAndroid Build Coastguard Worker } 159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scalars, { 160*da0073e9SAndroid Build Coastguard Worker "asdf.sum": [10], 161*da0073e9SAndroid Build Coastguard Worker "asdf.count": [5], 162*da0073e9SAndroid Build Coastguard Worker }) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 166*da0073e9SAndroid Build Coastguard Worker run_tests() 167