import torch

from .MetricBase import MetricBase


class CUDAMetric(MetricBase):
    """Time a region of GPU work with a pair of CUDA timing events.

    Usage: call ``record_start()`` before and ``record_end()`` after the work
    to be measured. Both events are recorded while CUDA device ``rank`` is the
    current device. Once the events have completed (e.g. after
    ``synchronize()``), ``elapsed_time()`` returns the time between them.
    """

    def __init__(self, rank: int, name: str):
        # ``rank`` is used directly as the CUDA device index when recording
        # events. NOTE(review): assumes rank < torch.cuda.device_count() on
        # this host — confirm for multi-node setups.
        self.rank = rank
        self.name = name
        # None means "not recorded yet"; a fresh torch.cuda.Event is created
        # on every record_start()/record_end() call.
        self.start = None
        self.end = None

    def record_start(self):
        """Create a timing-enabled event and record it on device ``rank``."""
        self.start = torch.cuda.Event(enable_timing=True)
        with torch.cuda.device(self.rank):
            self.start.record()

    def record_end(self):
        """Create a timing-enabled event and record it on device ``rank``."""
        self.end = torch.cuda.Event(enable_timing=True)
        with torch.cuda.device(self.rank):
            self.end.record()

    def _require_recorded(self):
        """Raise RuntimeError if either event has not been recorded yet."""
        # Without this guard, using the metric before recording fails with an
        # opaque AttributeError on None.
        if self.start is None:
            raise RuntimeError(
                "start event was not recorded; call record_start() first"
            )
        if self.end is None:
            raise RuntimeError(
                "end event was not recorded; call record_end() first"
            )

    def elapsed_time(self):
        """Return the time between the start and end events.

        The value comes from ``torch.cuda.Event.elapsed_time`` (milliseconds).
        This method does not block: call ``synchronize()`` first if the
        recorded work may still be running.

        Raises:
            RuntimeError: if either event was never recorded, or if either
                event has not yet completed on the device.
        """
        self._require_recorded()
        # query() is non-blocking; it reports whether the event has completed.
        if not self.start.query():
            raise RuntimeError("start event did not complete")
        if not self.end.query():
            raise RuntimeError("end event did not complete")
        return self.start.elapsed_time(self.end)

    def synchronize(self):
        """Block the host until both the start and end events have completed.

        Raises:
            RuntimeError: if either event was never recorded.
        """
        self._require_recorded()
        self.start.synchronize()
        self.end.synchronize()