from .CPUMetric import CPUMetric
from .CUDAMetric import CUDAMetric


class MetricsLogger:
    r"""
    Registry of start/end metrics, grouped by metric type and then by key.

    ``self.metrics`` is a two-level dictionary of the form
    ``{metric_type: {key: metric}}`` where each ``metric`` is either a
    ``CPUMetric`` or a ``CUDAMetric`` instance.

    Args:
        rank: value forwarded to ``CUDAMetric`` when a CUDA metric is
            created. It may be left as ``None`` if only CPU metrics are
            recorded.
    """

    def __init__(self, rank=None):
        self.rank = rank
        self.metrics = {}

    # NOTE: the parameter name ``type`` shadows the builtin, but it is part of
    # the public signature (callers may pass it by keyword), so it is kept.
    def record_start(self, type, key, name, cuda):
        r"""
        Create a new metric, register it under ``(type, key)`` and start it.

        Args:
            type: metric type used as the first-level key in ``self.metrics``.
            key: identifier of this metric within its metric type.
            name: name handed to the metric constructor.
            cuda: if truthy a ``CUDAMetric`` is created, otherwise a
                ``CPUMetric``.

        Raises:
            RuntimeError: if a metric already exists for ``(type, key)``, or
                if ``cuda`` is truthy while ``self.rank`` is ``None``.
        """
        if key in self.metrics.get(type, {}):
            raise RuntimeError(f"metric_type={type} with key={key} already exists")
        if cuda:
            if self.rank is None:
                raise RuntimeError("rank is required for cuda")
            metric = CUDAMetric(self.rank, name)
        else:
            metric = CPUMetric(name)
        # The per-type dictionary is only created once validation has passed,
        # so a failed call leaves ``self.metrics`` untouched.
        self.metrics.setdefault(type, {})[key] = metric
        metric.record_start()

    def record_end(self, type, key):
        r"""
        Record the end of the metric registered under ``(type, key)``.

        Args:
            type: metric type the metric was registered under.
            key: identifier of the metric within its metric type.

        Raises:
            RuntimeError: if no metric exists for ``(type, key)``, or if the
                metric's ``get_end()`` is already not ``None``.
        """
        metrics_for_type = self.metrics.get(type, {})
        if key not in metrics_for_type:
            raise RuntimeError(f"metric_type={type} with key={key} not found")
        metric = metrics_for_type[key]
        if metric.get_end() is not None:
            raise RuntimeError(
                f"end for metric_type={type} with key={key} already exists"
            )
        metric.record_end()

    def clear_metrics(self):
        r"""Remove every recorded metric (clears ``self.metrics`` in place)."""
        self.metrics.clear()

    def get_metrics(self):
        r"""
        Return the internal ``{metric_type: {key: metric}}`` dictionary.

        The dictionary itself is returned, not a copy.
        """
        return self.metrics

    def get_processed_metrics(self):
        r"""
        A method that processes the metrics recorded during the benchmark.

        Every ``CUDAMetric`` has ``synchronize()`` called on it before its
        elapsed time is read.

        Returns:
            It returns a dictionary containing keys of the form
            ``"{metric_type},{metric_name}"`` and, as values, the list of
            elapsed times (as returned by each metric's ``elapsed_time()``)
            of all metrics sharing that type and name.

        Examples::

            >>> instance = MetricsLogger(rank)
            >>> instance.record_start("forward_metric_type", "1", "forward_pass", True)
            >>> instance.record_end("forward_metric_type", "1")
            >>> instance.record_start("forward_metric_type", "2", "forward_pass", True)
            >>> instance.record_end("forward_metric_type", "2")
            >>> print(instance.metrics)
            {
                "forward_metric_type": {
                    "1": metric1,
                    "2": metric2
                }
            }

            >>> print(instance.get_processed_metrics())
            {
                "forward_metric_type,forward_pass" : [.0429, .0888]
            }
        """
        processed_metrics = {}
        for metric_type, metrics_by_key in self.metrics.items():
            for metric in metrics_by_key.values():
                if isinstance(metric, CUDAMetric):
                    metric.synchronize()
                metric_name = metric.get_name()
                elapsed_time = metric.elapsed_time()
                processed_metric_name = f"{metric_type},{metric_name}"
                processed_metrics.setdefault(processed_metric_name, []).append(
                    elapsed_time
                )
        return processed_metrics