1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree.abs 9import abc 10import unittest.mock as mock 11 12from torch.distributed.elastic.metrics.api import ( 13 _get_metric_name, 14 MetricData, 15 MetricHandler, 16 MetricStream, 17 prof, 18) 19from torch.testing._internal.common_utils import run_tests, TestCase 20 21 22def foo_1(): 23 pass 24 25 26class TestMetricsHandler(MetricHandler): 27 def __init__(self) -> None: 28 self.metric_data = {} 29 30 def emit(self, metric_data: MetricData): 31 self.metric_data[metric_data.name] = metric_data 32 33 34class Parent(abc.ABC): 35 @abc.abstractmethod 36 def func(self): 37 raise NotImplementedError 38 39 def base_func(self): 40 self.func() 41 42 43class Child(Parent): 44 # need to decorate the implementation not the abstract method! 45 @prof 46 def func(self): 47 pass 48 49 50class MetricsApiTest(TestCase): 51 def foo_2(self): 52 pass 53 54 @prof 55 def bar(self): 56 pass 57 58 @prof 59 def throw(self): 60 raise RuntimeError 61 62 @prof(group="torchelastic") 63 def bar2(self): 64 pass 65 66 def test_get_metric_name(self): 67 # Note: since pytorch uses main method to launch tests, 68 # the module will be different between fb and oss, this 69 # allows keeping the module name consistent. 70 foo_1.__module__ = "api_test" 71 self.assertEqual("api_test.foo_1", _get_metric_name(foo_1)) 72 self.assertEqual("MetricsApiTest.foo_2", _get_metric_name(self.foo_2)) 73 74 def test_profile(self): 75 handler = TestMetricsHandler() 76 stream = MetricStream("torchelastic", handler) 77 # patch instead of configure to avoid conflicts when running tests in parallel 78 with mock.patch( 79 "torch.distributed.elastic.metrics.api.getStream", return_value=stream 80 ): 81 self.bar() 82 83 self.assertEqual(1, handler.metric_data["MetricsApiTest.bar.success"].value) 84 self.assertNotIn("MetricsApiTest.bar.failure", handler.metric_data) 85 self.assertIn("MetricsApiTest.bar.duration.ms", handler.metric_data) 86 87 with self.assertRaises(RuntimeError): 88 self.throw() 89 90 self.assertEqual( 91 1, handler.metric_data["MetricsApiTest.throw.failure"].value 92 ) 93 self.assertNotIn("MetricsApiTest.bar_raise.success", handler.metric_data) 94 self.assertIn("MetricsApiTest.throw.duration.ms", handler.metric_data) 95 96 self.bar2() 97 self.assertEqual( 98 "torchelastic", 99 handler.metric_data["MetricsApiTest.bar2.success"].group_name, 100 ) 101 102 def test_inheritance(self): 103 handler = TestMetricsHandler() 104 stream = MetricStream("torchelastic", handler) 105 # patch instead of configure to avoid conflicts when running tests in parallel 106 with mock.patch( 107 "torch.distributed.elastic.metrics.api.getStream", return_value=stream 108 ): 109 c = Child() 110 c.base_func() 111 112 self.assertEqual(1, handler.metric_data["Child.func.success"].value) 113 self.assertIn("Child.func.duration.ms", handler.metric_data) 114 115 116if __name__ == "__main__": 117 run_tests() 118