def basic_iteration_step(
    self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
):
    r"""
    Performs a single training iteration: forward pass, backward pass, and
    optimizer step, recording timing metrics for each phase.
    Args:
        ddp_model (nn.Module): distributed data parallel model
        criterion (nn.Module): loss function used to measure the model's error
        optimizer (optim.Optimizer): updates model parameters
        hook_state (object): DDP communication hook state object
        epoch (int): index of the current pass through the data
        index (int): zero-based iteration number within the current epoch
        batch (list): training examples as [inputs, targets]
    """
    key = self.epoch_key(epoch, index)
    hook_state.next_batch()
    self.record_batch_start(key)
    optimizer.zero_grad()
    # forward pass
    self.record_forward_start(key)
    loss = criterion(ddp_model(batch[0]), batch[1])
    self.record_forward_end(key)
    # backward pass
    self.record_backward_start(key)
    loss.backward()
    self.record_backward_end(key)
    # parameter update
    optimizer.step()
    self.record_batch_end(key)
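
# A minimal sketch (not part of the original source) of how basic_iteration_step
# might be driven from an outer epoch loop. The `trainer` object, `data_loader`,
# and the function name below are hypothetical placeholders; each element yielded
# by the loader is assumed to be an (inputs, targets) pair matching the
# [inputs, targets] layout used above.
def example_train_epoch(
    trainer, ddp_model, criterion, optimizer, hook_state, data_loader, epoch
):
    # Iterate over the epoch's batches, delegating each step to basic_iteration_step.
    for index, (inputs, targets) in enumerate(data_loader):
        trainer.basic_iteration_step(
            ddp_model, criterion, optimizer, hook_state, epoch, index, [inputs, targets]
        )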