xref: /aosp_15_r20/external/pytorch/benchmarks/distributed/rpc/parameter_server/trainer/hook_states.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
class BasicHookState:
    def __init__(self, cref, process_group):
        r"""
        Bundle the state a communication hook consults while the training
        algorithm runs.
        Args:
            cref (DdpTrainer): the trainer instance (its own ``self``) that this
                state belongs to
            process_group (ProcessGroup): distributed process group
        """
        self.cref, self.process_group = cref, process_group
        # Counter starts at -1; every next_batch() call adds 1 to it.
        self.batch_number = -1

    def get_key(self, bucket_index):
        r"""
        Encode the current batch number together with a bucket index as a
        single string of the form ``"<batch_number>,<bucket_index>"``.
        Args:
            bucket_index (int): index of the bucket being processed in backward
        """
        key = f"{self.batch_number},{bucket_index}"
        return key

    def next_batch(self):
        r"""
        Advance batch_number by 1. Returns nothing.
        """
        self.batch_number = self.batch_number + 1
28