xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/metrics/CUDAMetric.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3from .MetricBase import MetricBase
4
5
class CUDAMetric(MetricBase):
    """Time a region of GPU work on CUDA device ``rank`` with a pair of events.

    Usage: call :meth:`record_start` before the work and :meth:`record_end`
    after it, then :meth:`synchronize` before reading :meth:`elapsed_time`
    (which refuses to report a time for events that have not completed).
    """

    def __init__(self, rank: int, name: str):
        # ``rank`` is used as the CUDA device index the events are recorded on.
        self.rank = rank
        self.name = name
        # Populated by record_start() / record_end(); None until then.
        self.start = None
        self.end = None

    def _record_event(self):
        """Create a timing-enabled CUDA event, record it on device ``rank``, and return it."""
        event = torch.cuda.Event(enable_timing=True)
        # record() with no stream argument uses the current stream of the
        # current device, so switch to our device for the call.
        with torch.cuda.device(self.rank):
            event.record()
        return event

    def _require_recorded(self):
        """Raise RuntimeError unless both the start and end events were recorded."""
        if self.start is None:
            raise RuntimeError(
                "start event was never recorded; call record_start() first"
            )
        if self.end is None:
            raise RuntimeError(
                "end event was never recorded; call record_end() first"
            )

    def record_start(self):
        """Record (and store) the start event on device ``rank``."""
        self.start = self._record_event()

    def record_end(self):
        """Record (and store) the end event on device ``rank``."""
        self.end = self._record_event()

    def elapsed_time(self):
        """Return the time between the start and end events.

        The value comes from ``torch.cuda.Event.elapsed_time`` (milliseconds
        per the PyTorch documentation).

        Raises:
            RuntimeError: if either event was never recorded, or if either
                event has not completed yet (call :meth:`synchronize` first).
        """
        self._require_recorded()
        # query() is non-blocking: it only reports whether the GPU has
        # reached the event, so fail loudly instead of timing incomplete work.
        if not self.start.query():
            raise RuntimeError("start event did not complete")
        if not self.end.query():
            raise RuntimeError("end event did not complete")
        return self.start.elapsed_time(self.end)

    def synchronize(self):
        """Block the host until both recorded events have completed.

        Raises:
            RuntimeError: if either event was never recorded.
        """
        self._require_recorded()
        self.start.synchronize()
        self.end.synchronize()
33