# xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/metrics/MetricsLogger.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from .CPUMetric import CPUMetric
2from .CUDAMetric import CUDAMetric
3
4
class MetricsLogger:
    r"""
    Records start/end timing metrics, grouped by metric type and key.

    Metrics are stored in ``self.metrics`` as a nested dict of the form
    ``{metric_type: {key: metric}}``. Each ``metric`` is a ``CPUMetric``,
    or a ``CUDAMetric`` bound to ``self.rank`` when ``record_start`` is
    called with a truthy ``cuda`` argument.
    """

    def __init__(self, rank=None) -> None:
        r"""
        Args:
            rank: rank forwarded to ``CUDAMetric``. Only required when CUDA
                metrics are recorded; ``record_start`` raises if it is
                ``None`` and ``cuda`` is requested.
        """
        self.rank = rank
        # {metric_type: {key: CPUMetric | CUDAMetric}}
        self.metrics = {}

    def record_start(self, type, key, name, cuda) -> None:
        r"""
        Create a new metric stored under ``(type, key)`` and record its start.

        Args:
            type: metric type, used as the outer grouping key.
            key: identifier of this metric within ``type``; must be unique.
            name: name passed to the metric object.
            cuda: if truthy, create a ``CUDAMetric`` for ``self.rank``;
                otherwise create a ``CPUMetric``.

        Raises:
            RuntimeError: if a metric already exists under ``(type, key)``,
                or if ``cuda`` is requested while ``self.rank`` is ``None``.
        """
        # NOTE: ``type`` shadows the builtin, but it is part of the public
        # (keyword-callable) interface, so the name is kept.
        if key in self.metrics.get(type, {}):
            raise RuntimeError(f"metric_type={type} with key={key} already exists")
        if cuda:
            if self.rank is None:
                raise RuntimeError("rank is required for cuda")
            metric = CUDAMetric(self.rank, name)
        else:
            metric = CPUMetric(name)
        # Only create the per-type bucket once the metric was built, so a
        # failed request does not leave an empty bucket behind.
        self.metrics.setdefault(type, {})[key] = metric
        metric.record_start()

    def record_end(self, type, key) -> None:
        r"""
        Record the end of the metric stored under ``(type, key)``.

        Raises:
            RuntimeError: if no metric exists under ``(type, key)``, or if
                its end has already been recorded.
        """
        if type not in self.metrics or key not in self.metrics[type]:
            raise RuntimeError(f"metric_type={type} with key={key} not found")
        metric = self.metrics[type][key]
        if metric.get_end() is not None:
            raise RuntimeError(
                f"end for metric_type={type} with key={key} already exists"
            )
        metric.record_end()

    def clear_metrics(self) -> None:
        r"""Remove every recorded metric."""
        self.metrics.clear()

    def get_metrics(self) -> dict:
        r"""Return the raw ``{metric_type: {key: metric}}`` dict (not a copy)."""
        return self.metrics

    def get_processed_metrics(self) -> dict:
        r"""
        A method that processes the metrics recorded during the benchmark.

        Returns::
            A dictionary whose keys are ``"<metric_type>,<metric_name>"``
                strings and whose values are lists of the elapsed times of
                every metric with that type and name, in insertion order.

        Examples::

            >>> instance = MetricsLogger(rank)
            >>> instance.record_start("forward_metric_type", "1", "forward_pass", cuda=True)
            >>> instance.record_end("forward_metric_type", "1")
            >>> instance.record_start("forward_metric_type", "2", "forward_pass", cuda=True)
            >>> instance.record_end("forward_metric_type", "2")
            >>> print(instance.metrics)
            {
                "forward_metric_type": {
                    "1": metric1,
                    "2": metric2
                }
            }

            >>> print(instance.get_processed_metrics())
            {
                "forward_metric_type,forward_pass" : [.0429, .0888]
            }
        """
        processed_metrics = {}
        for metric_type, metrics_by_key in self.metrics.items():
            for metric in metrics_by_key.values():
                if isinstance(metric, CUDAMetric):
                    # NOTE(review): presumably waits for the recorded CUDA
                    # work so elapsed_time() is valid — confirm in CUDAMetric.
                    metric.synchronize()
                processed_metric_name = f"{metric_type},{metric.get_name()}"
                processed_metrics.setdefault(processed_metric_name, []).append(
                    metric.elapsed_time()
                )
        return processed_metrics
80