xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/metrics/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import abc
11import time
12from collections import namedtuple
13from functools import wraps
14from typing import Dict, Optional
15from typing_extensions import deprecated
16
17
# Names exported by ``from ... import *``; everything else in this module is private.
__all__ = [
    # configuration object
    "MetricsConfig",
    # handler interface and the built-in handlers
    "MetricHandler",
    "ConsoleMetricHandler",
    "NullMetricHandler",
    # streams and handler registration
    "MetricStream",
    "configure",
    "getStream",
    # timing decorators (``profile`` is deprecated in favor of ``prof``)
    "prof",
    "profile",
    # publishing (``publish_metric`` is deprecated in favor of ``put_metric``)
    "put_metric",
    "publish_metric",
    # helpers and data types
    "get_elapsed_time_ms",
    "MetricData",
]
33
# A single published data point: the time it was recorded (seconds, as returned by
# ``time.time()``), the group it belongs to, the metric name and the metric value.
MetricData = namedtuple("MetricData", "timestamp group_name name value")
35
36
class MetricsConfig:
    """Container for string key/value parameters.

    Presumably used to configure a metrics backend; nothing in this module reads
    ``params`` itself, so verify the expected keys against the consumer.

    After construction ``params`` is always a dict. ``None`` (the default) yields a
    fresh empty dict per instance, while a dict passed in is stored as-is (not copied).
    """

    __slots__ = ["params"]

    def __init__(self, params: Optional[Dict[str, str]] = None):
        # ``is None`` rather than truthiness so a caller's empty dict is kept by identity.
        self.params = {} if params is None else params
44
45
class MetricHandler(abc.ABC):
    """Abstract sink for metric data points.

    ``MetricStream.add_value`` hands every value it receives to ``emit`` as a
    ``MetricData`` tuple. Concrete subclasses must implement ``emit``.
    """

    @abc.abstractmethod
    def emit(self, metric_data: MetricData):
        """Consume one ``MetricData`` point; what happens to it is up to the subclass."""
50
51
class ConsoleMetricHandler(MetricHandler):
    """Handler that prints every metric data point to stdout via ``print``."""

    def emit(self, metric_data: MetricData):
        # Output shape: "[<timestamp>][<group_name>]: <name>=<value>"
        rendered = f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
        print(rendered)
57
58
class NullMetricHandler(MetricHandler):
    """Handler that silently discards every metric data point."""

    def emit(self, metric_data: MetricData):
        """Drop ``metric_data``; intentionally a no-op."""
62
63
class MetricStream:
    """Binds a metric group name to a handler and forwards timestamped values to it."""

    def __init__(self, group_name: str, handler: MetricHandler):
        self.handler = handler
        self.group_name = group_name

    def add_value(self, metric_name: str, metric_value: int):
        """Emit ``metric_value`` under ``metric_name``, stamped with the current ``time.time()``."""
        data_point = MetricData(
            timestamp=time.time(),
            group_name=self.group_name,
            name=metric_name,
            value=metric_value,
        )
        self.handler.emit(data_point)
73
74
# Handler used by ``getStream`` for any group without its own entry in
# ``_metrics_map``. Starts as a no-op sink; ``configure(handler)`` replaces it.
_default_metrics_handler: MetricHandler = NullMetricHandler()
# Per-group handler overrides, filled in by ``configure(handler, group)``.
_metrics_map: Dict[str, MetricHandler] = {}
77
78
def configure(handler: MetricHandler, group: Optional[str] = None):
    """Register ``handler`` as a metrics sink.

    Args:
        handler: the handler that will receive emitted data points.
        group: when ``None`` (the default), ``handler`` becomes the module-wide
            default used by every group without a dedicated handler. Otherwise it
            is registered for ``group`` only, replacing any handler previously set
            for that group.
    """
    global _default_metrics_handler
    if group is not None:
        _metrics_map[group] = handler
        return
    _default_metrics_handler = handler
88
89
def getStream(group: str):
    """Return a new ``MetricStream`` for ``group``.

    The stream is bound to the handler registered for ``group`` via
    ``configure(handler, group)``, or to the current default handler when the
    group has none.
    """
    handler = _metrics_map.get(group, _default_metrics_handler)
    return MetricStream(group, handler)
96
97
def _get_metric_name(fn):
    """Derive the metric key for callable ``fn``.

    * A dotted ``__qualname__`` (methods, nested functions) is returned unchanged.
    * An undotted one is prefixed with the last component of ``fn.__module__``,
      e.g. ``pkg.mod`` + ``f`` -> ``mod.f``.
    * If ``__module__`` is empty or ``None``, the bare ``__qualname__`` is returned.
    """
    qualified = fn.__qualname__
    if "." in qualified:
        return qualified
    module_name = fn.__module__
    if not module_name:
        return qualified
    leaf_module = module_name.rsplit(".", 1)[-1]
    return f"{leaf_module}.{qualified}"
109
110
def prof(fn=None, group: str = "torchelastic"):
    r"""
    @prof decorator publishes success, failure and duration.ms metrics for the function that it decorates.

    For a decorated function whose metric name is ``key``, every call publishes to ``group``:

    * ``key.success`` = 1 when the function returns normally,
    * ``key.failure`` = 1 when it raises an ``Exception`` (the exception is re-raised),
    * ``key.duration.ms`` = the call's wall-clock duration in milliseconds (published in both cases).

    The metric name defaults to the qualified name (``class_name.def_name``) of the function.
    If the qualified name has no dot (a module-level function), it is prefixed with the leaf
    module name instead (e.g. ``api.my_fn``).

    Usage

    ::

     @metrics.prof
     def x():
         pass

     @metrics.prof(group="agent")
     def y():
         pass

    Args:
        fn: the function to decorate when used bare (``@metrics.prof``). ``None`` when used
            with arguments, in which case the configured decorator is returned.
        group: metric group the data points are published to.
    """

    def wrap(f):
        # The key depends only on ``f``, so compute it once instead of on every call.
        key = _get_metric_name(f)

        @wraps(f)
        def wrapper(*args, **kwargs):
            # Start the clock before ``try`` so ``finally`` can always read ``start``.
            start = time.time()
            try:
                result = f(*args, **kwargs)
            except Exception:
                put_metric(f"{key}.failure", 1, group)
                raise
            else:
                # Reached only when ``f`` itself succeeded. If publishing the success
                # metric raises, that error is no longer miscounted as a call failure.
                put_metric(f"{key}.success", 1, group)
            finally:
                put_metric(f"{key}.duration.ms", get_elapsed_time_ms(start), group)
            return result

        return wrapper

    # Bare ``@prof`` passes the function directly; ``@prof(...)`` leaves ``fn`` as None.
    if fn is not None:
        return wrap(fn)
    return wrap
152
153
@deprecated("Deprecated, use `@prof` instead", category=FutureWarning)
def profile(group=None):
    """
    @profile decorator adds latency and success/failure metrics to any given function.

    Every call publishes to ``group``: ``<func.__name__>.success`` = 1 on normal return,
    ``<func.__name__>.failure`` = 1 when an ``Exception`` is raised (and re-raised), and
    ``<func.__name__>.duration.ms`` with the call's wall-clock duration in milliseconds
    in both cases.

    Usage

    ::

     @metrics.profile("my_metric_group")
     def some_function(<arguments>):
    """

    def wrap(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Start the clock before ``try`` so ``finally`` can always read ``start_time``.
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
            except Exception:
                # ``put_metric`` performs the same getStream(...).add_value(...) as the
                # deprecated ``publish_metric``, without emitting a FutureWarning that
                # points at an API the caller never used.
                put_metric(f"{func.__name__}.failure", 1, group)
                raise
            else:
                # Reached only when ``func`` itself succeeded. If publishing the success
                # metric raises, that error is no longer miscounted as a call failure.
                put_metric(f"{func.__name__}.success", 1, group)
            finally:
                put_metric(
                    f"{func.__name__}.duration.ms",
                    get_elapsed_time_ms(start_time),
                    group,
                )
            return result

        return wrapper

    return wrap
188
189
def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
    """
    Publish a metric data point.

    The value goes to the stream returned by ``getStream(metric_group)``, i.e. to the
    handler registered for that group, or to the default handler if there is none.

    Usage

    ::

     put_metric("metric_name", 1)
     put_metric("metric_name", 1, "metric_group_name")
    """
    stream = getStream(metric_group)
    stream.add_value(metric_name, metric_value)
202
203
@deprecated(
    "Deprecated, use `put_metric(metric_name, metric_value, metric_group)` instead",
    category=FutureWarning,
)
def publish_metric(metric_group: str, metric_name: str, metric_value: int):
    """Publish a metric data point (deprecated argument order).

    Equivalent to ``put_metric(metric_name, metric_value, metric_group)``; note that
    the group comes first here but last in ``put_metric``.
    """
    metric_stream = getStream(metric_group)
    metric_stream.add_value(metric_name, metric_value)
211
212
def get_elapsed_time_ms(start_time_in_seconds: float):
    """Return the elapsed time in millis from the given start time.

    ``start_time_in_seconds`` is subtracted from ``time.time()``, so it should also
    come from ``time.time()``. The result is truncated toward zero by ``int``.
    """
    elapsed_seconds = time.time() - start_time_in_seconds
    return int(elapsed_seconds * 1000)
217