import functools
import time
from abc import ABC, abstractmethod

from metrics.MetricsLogger import MetricsLogger

import torch


class TrainerBase(ABC):
    BATCH_LEVEL_METRIC = "batch_level_metric"
    BATCH_ALL = "batch_all"
    FORWARD_METRIC = "forward_metric"
    FORWARD_PASS = "forward_pass"
    BACKWARD_METRIC = "backward_metric"
    BACKWARD = "backward"

    def __init__(self, rank):
        r"""
        Inits TrainerBase class.
        Args:
            rank (int): worker rank
        """
        self.__metrics_logger = MetricsLogger(rank)

    @abstractmethod
    def train(self):
        r"""
        A method to be implemented by child class that will train a neural network.
        """
        return

    def record_start(self, type, key, name, cuda=True):
        r"""
        A method that records the start event for a metric.
        Args:
            type (str): group id for metric
            key (str): unique id for metric within a group
            name (str): description of the metric
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(type, key, name, cuda)

    def record_end(self, type, key):
        r"""
        A method that records the end event for a metric.
        Args:
            type (str): group id for metric
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(type, key)

    def record_batch_start(self, key, cuda=True):
        r"""
        A helper method that records a batch metric for the
        given key. A user should call this at the start of an
        iteration step during training.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.BATCH_LEVEL_METRIC, key, self.BATCH_ALL, cuda
        )

    def record_batch_end(self, key):
        r"""
        A helper method that records a batch metric for the
        given key. A user should call this at the end of an
        iteration step during training.
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.BATCH_LEVEL_METRIC, key)

    def record_forward_start(self, key, cuda=True):
        r"""
        A helper method that records a forward metric
        for the given key. A user should call this before
        their neural network forward.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.FORWARD_METRIC, key, self.FORWARD_PASS, cuda
        )

    def record_forward_end(self, key):
        r"""
        A helper method that records a forward metric
        for the given key. A user should call this after their
        neural network forward.
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.FORWARD_METRIC, key)

    def record_backward_start(self, key, cuda=True):
        r"""
        A helper method that records a backward metric
        for the given key. A user should call this before
        their .backward() call.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.BACKWARD_METRIC, key, self.BACKWARD, cuda
        )

    def record_backward_end(self, key):
        r"""
        A helper method that records a backward metric
        for the given key. A user should call this after
        .backward().
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.BACKWARD_METRIC, key)
    @staticmethod
    def methodmetric(name, type="method_metric", cuda=True):
        r"""
        A decorator that records a metric for the decorated method.
        Args:
            name (str): description of the metric
            type (str): group id for metric
            cuda (bool): indicator to determine if this is a CUDA metric
        """

        def decorator(function):
            @functools.wraps(function)
            def wrapper(self, *args, **kwargs):
                # Use the call's wall-clock timestamp as a unique key for this metric.
                key = time.time()
                self.__metrics_logger.record_start(type, key, name, cuda)
                result = function(self, *args, **kwargs)
                self.__metrics_logger.record_end(type, key)
                return result

            return wrapper

        return decorator

    def get_metrics(self):
        r"""
        A method that returns metrics captured by the __metrics_logger.
        """
        return self.__metrics_logger.get_processed_metrics()

    def clear_metrics(self):
        r"""
        A method that clears __metrics_logger recorded metrics.
        """
        return self.__metrics_logger.clear_metrics()
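# Illustrative sketch (an assumption, not part of the original benchmark code):
# a minimal TrainerBase subclass showing how the methodmetric decorator can be
# used to time an arbitrary method. The class and method names below are
# hypothetical; cuda=False assumes MetricsLogger supports CPU-side metrics.
class ExampleDecoratedTrainer(TrainerBase):
    def train(self):
        # A real subclass would run its training loop here.
        self.timed_step()

    @TrainerBase.methodmetric("timed_step", cuda=False)
    def timed_step(self):
        # The decorator records a start event before this body runs and an end
        # event after it returns, keyed by the call's timestamp.
        time.sleep(0.01)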
class DdpTrainer(TrainerBase):
    def __init__(
        self,
        process_group,
        use_cuda_rpc,
        server_rref,
        backend,
        epochs,
        preprocess_data,
        create_criterion,
        create_ddp_model,
        hook_state_class,
        hook,
        iteration_step,
    ):
        r"""
        A trainer that implements a DDP training algorithm using a simple hook that performs allreduce
        using the process_group implementation.
        Args:
            process_group (ProcessGroup): distributed process group
            use_cuda_rpc (bool): indicator for CUDA RPC
            server_rref (RRef): remote reference to the server
            backend (str): distributed communication backend
            epochs (int): epoch count for training
            preprocess_data (function): preprocesses data passed
                to the trainer before starting training
            create_criterion (function): creates a criterion to calculate loss
            create_ddp_model (function): creates a ddp model for the trainer
            hook_state_class (class): class used to keep track of state
                during training
            hook (function): ddp communication hook
            iteration_step (function): performs one step of training
        """
        super().__init__(process_group.rank())
        self.process_group = process_group
        self.use_cuda_rpc = use_cuda_rpc
        self.server_rref = server_rref
        self.backend = backend
        self.epochs = epochs
        self.preprocess_data = preprocess_data
        self.create_criterion = create_criterion
        self.create_ddp_model = create_ddp_model
        self.hook_state_class = hook_state_class
        self.hook = hook
        self.iteration_step = iteration_step

        self.rank = process_group.rank()
        self.trainer_count = process_group.size()

    def epoch_key(self, epoch, index):
        r"""
        A method that returns an encoded key that represents the current epoch and
        iteration index.
        Args:
            epoch (int): epoch index
            index (int): iteration index
        """
        return f"{epoch},{index}"

    def train(self, model, data):
        r"""
        A method that implements the training algorithm.
        Args:
            model (nn.Module): neural network model
            data (list): training examples
        """
        # Move the model to this trainer's device and wrap it in DDP with the
        # configured communication hook.
        model = model.cuda(self.rank)
        data = self.preprocess_data(self.rank, data)
        criterion = self.create_criterion(self.rank)
        ddp_model, hook_state = self.create_ddp_model(
            self, self.rank, model, self.process_group, self.hook_state_class, self.hook
        )
        optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)

        for epoch in range(self.epochs):
            if epoch % 5 == 0 and self.rank == 0:
                print(f"train epoch={epoch}")
            for index, batch in enumerate(data):
                self.iteration_step(
                    self,
                    ddp_model,
                    criterion,
                    optimizer,
                    hook_state,
                    epoch,
                    index,
                    batch,
                )
        # Wait for all queued CUDA work so recorded CUDA events are complete.
        torch.cuda.synchronize(self.rank)
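# Illustrative sketch (an assumption, not part of the original benchmark code):
# an iteration_step function of the shape DdpTrainer.train expects, wiring the
# record_* helpers around the forward and backward passes. DdpTrainer.train
# passes the trainer instance explicitly as the first argument, hence the
# `self` parameter on a free function. It assumes each batch is an
# (input, target) pair already placed on the trainer's device by preprocess_data.
def example_iteration_step(
    self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
):
    key = self.epoch_key(epoch, index)
    self.record_batch_start(key)
    optimizer.zero_grad()
    self.record_forward_start(key)
    out = ddp_model(batch[0])
    self.record_forward_end(key)
    loss = criterion(out, batch[1])
    self.record_backward_start(key)
    loss.backward()
    self.record_backward_end(key)
    optimizer.step()
    self.record_batch_end(key)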