# external/pytorch/benchmarks/distributed/rpc/parameter_server/trainer/iteration_steps.py
def basic_iteration_step(
    self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
):
    r"""
    Performs a single iteration of training.
    Args:
        ddp_model (nn.Module): distributed data parallel model
        criterion (nn.Module): loss function used to measure the model's error
        optimizer (optim.Optimizer): updates the model parameters
        hook_state (object): DDP communication hook state object
        epoch (int): index of the current pass through the data
        index (int): zero-based iteration number within the current epoch
        batch (list): training examples (inputs and targets)
    """
    key = self.epoch_key(epoch, index)
    # notify the communication hook state that a new batch is starting
    hook_state.next_batch()
    self.record_batch_start(key)
    optimizer.zero_grad()
    # forward pass
    self.record_forward_start(key)
    loss = criterion(ddp_model(batch[0]), batch[1])
    self.record_forward_end(key)
    # backward pass
    self.record_backward_start(key)
    loss.backward()
    self.record_backward_end(key)
    # parameter update
    optimizer.step()
    self.record_batch_end(key)
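

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows how
# basic_iteration_step might be bound to a trainer-like object and driven by a
# training loop. The _SketchTrainer and _NoopHookState classes and the plain
# nn.Linear model are hypothetical stand-ins for the benchmark's real trainer,
# DDP-wrapped model, and communication-hook state.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    import torch
    import torch.nn as nn

    class _NoopHookState:
        # stand-in for the DDP communication hook state object
        def next_batch(self):
            pass

    class _SketchTrainer:
        # stand-in for the benchmark trainer; only satisfies the interface
        # that basic_iteration_step relies on
        def epoch_key(self, epoch, index):
            return f"{epoch},{index}"

        def record_batch_start(self, key):
            pass

        def record_forward_start(self, key):
            pass

        def record_forward_end(self, key):
            pass

        def record_backward_start(self, key):
            pass

        def record_backward_end(self, key):
            pass

        def record_batch_end(self, key):
            pass

    trainer = _SketchTrainer()
    # bind the module-level function as a method, mirroring how the benchmark
    # attaches a configurable iteration step to its trainer
    trainer.iteration_step = types.MethodType(basic_iteration_step, trainer)

    model = nn.Linear(8, 2)  # plain module in place of a DDP-wrapped model
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    hook_state = _NoopHookState()

    for epoch in range(2):
        for index in range(3):
            batch = [torch.randn(4, 8), torch.randint(0, 2, (4,))]
            trainer.iteration_step(
                model, criterion, optimizer, hook_state, epoch, index, batch
            )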