import functools
import time
from abc import ABC, abstractmethod

from metrics.MetricsLogger import MetricsLogger

import torch


class TrainerBase(ABC):
    BATCH_LEVEL_METRIC = "batch_level_metric"
    BATCH_ALL = "batch_all"
    FORWARD_METRIC = "forward_metric"
    FORWARD_PASS = "forward_pass"
    BACKWARD_METRIC = "backward_metric"
    BACKWARD = "backward"

    def __init__(self, rank):
        r"""
        Initializes the TrainerBase class.
        Args:
            rank (int): worker rank
        """
        self.__metrics_logger = MetricsLogger(rank)

    @abstractmethod
    def train(self):
        r"""
        A method to be implemented by a child class that trains a neural network.
        """
        return

    def record_start(self, type, key, name, cuda=True):
        r"""
        A method that records the start event for a metric.
        Args:
            type (str): group id for metric
            key (str): unique id for metric within a group
            name (str): description of the metric
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(type, key, name, cuda)

    def record_end(self, type, key):
        r"""
        A method that records the end event for a metric.
        Args:
            type (str): group id for metric
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(type, key)

    def record_batch_start(self, key, cuda=True):
        r"""
        A helper method that records a batch metric for the
        given key. A user should call this at the start of an
        iteration step during training.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.BATCH_LEVEL_METRIC, key, self.BATCH_ALL, cuda
        )

    def record_batch_end(self, key):
        r"""
        A helper method that records a batch metric for the
        given key. A user should call this at the end of an
        iteration step during training.
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.BATCH_LEVEL_METRIC, key)

    def record_forward_start(self, key, cuda=True):
        r"""
        A helper method that records a forward metric
        for the given key. A user should call this before
        their neural network's forward pass.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.FORWARD_METRIC, key, self.FORWARD_PASS, cuda
        )

    def record_forward_end(self, key):
        r"""
        A helper method that records a forward metric
        for the given key. A user should call this after their
        neural network's forward pass.
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.FORWARD_METRIC, key)

    def record_backward_start(self, key, cuda=True):
        r"""
        A helper method that records a backward metric
        for the given key. A user should call this before
        their .backward() call.
        Args:
            key (str): unique id for metric within a group
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.__metrics_logger.record_start(
            self.BACKWARD_METRIC, key, self.BACKWARD, cuda
        )

    def record_backward_end(self, key):
        r"""
        A helper method that records a backward metric
        for the given key. A user should call this after
        their .backward() call.
        Args:
            key (str): unique id for metric within a group
        """
        self.__metrics_logger.record_end(self.BACKWARD_METRIC, key)

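    # A minimal usage sketch of the start/end helper pairs above; each pair
    # brackets one phase of an iteration step. The ``iteration_step``,
    # ``criterion``, ``optimizer``, ``key``, and ``batch`` names here are
    # hypothetical, not definitions from this module:
    #
    #     def iteration_step(self, model, criterion, optimizer, key, batch):
    #         self.record_batch_start(key)
    #         optimizer.zero_grad()
    #         self.record_forward_start(key)
    #         out = model(batch[0])
    #         self.record_forward_end(key)
    #         loss = criterion(out, batch[1])
    #         self.record_backward_start(key)
    #         loss.backward()
    #         self.record_backward_end(key)
    #         optimizer.step()
    #         self.record_batch_end(key)
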
    @staticmethod
    def methodmetric(name, type="method_metric", cuda=True):
        r"""
        A decorator that records a metric for the decorated method.
        Args:
            name (str): description of the metric
            type (str): group id for metric
            cuda (bool): indicator to determine if this is a CUDA metric
        """

        def decorator(function):
            @functools.wraps(function)
            def wrapper(self, *args, **kwargs):
                # use the call timestamp as a unique key within the metric group
                key = time.time()
                self.__metrics_logger.record_start(type, key, name, cuda)
                result = function(self, *args, **kwargs)
                self.__metrics_logger.record_end(type, key)
                return result

            return wrapper

        return decorator

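    # Usage sketch for ``methodmetric`` (``MyTrainer`` and ``step`` are
    # hypothetical names used only for illustration):
    #
    #     class MyTrainer(TrainerBase):
    #         @TrainerBase.methodmetric("step")
    #         def step(self, batch):
    #             ...
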
    def get_metrics(self):
        r"""
        A method that returns the metrics captured by the __metrics_logger.
        """
        return self.__metrics_logger.get_processed_metrics()

    def clear_metrics(self):
        r"""
        A method that clears the metrics recorded by the __metrics_logger.
        """
        return self.__metrics_logger.clear_metrics()


class DdpTrainer(TrainerBase):
    def __init__(
        self,
        process_group,
        use_cuda_rpc,
        server_rref,
        backend,
        epochs,
        preprocess_data,
        create_criterion,
        create_ddp_model,
        hook_state_class,
        hook,
        iteration_step,
    ):
        r"""
        A trainer that implements a DDP training algorithm using a simple hook
        that performs allreduce with the process_group implementation.
        Args:
            process_group (ProcessGroup): distributed process group
            use_cuda_rpc (bool): indicator for CUDA RPC
            server_rref (RRef): remote reference to the server
            backend (str): distributed communication backend
            epochs (int): epoch count for training
            preprocess_data (function): preprocesses data passed
                to the trainer before starting training
            create_criterion (function): creates a criterion to calculate loss
            create_ddp_model (function): creates a DDP model for the trainer
            hook_state_class (class): class that will be used to keep track of state
                during training
            hook (function): DDP communication hook
            iteration_step (function): performs one step of training
        """
        super().__init__(process_group.rank())
        self.process_group = process_group
        self.use_cuda_rpc = use_cuda_rpc
        self.server_rref = server_rref
        self.backend = backend
        self.epochs = epochs
        self.preprocess_data = preprocess_data
        self.create_criterion = create_criterion
        self.create_ddp_model = create_ddp_model
        self.hook_state_class = hook_state_class
        self.hook = hook
        self.iteration_step = iteration_step

        self.rank = process_group.rank()
        self.trainer_count = process_group.size()

    def epoch_key(self, epoch, index):
        r"""
        A method that returns an encoded key that represents the current epoch and
        iteration index.
        Args:
            epoch (int): epoch index
            index (int): iteration index
        """
        return f"{epoch},{index}"

    def train(self, model, data):
        r"""
        A method that implements the training algorithm.
        Args:
            model (nn.Module): neural network model
            data (list): training examples
        """
        # move the model to this trainer's device and build the training components
        model = model.cuda(self.rank)
        data = self.preprocess_data(self.rank, data)
        criterion = self.create_criterion(self.rank)
        ddp_model, hook_state = self.create_ddp_model(
            self, self.rank, model, self.process_group, self.hook_state_class, self.hook
        )
        optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)

        for epoch in range(self.epochs):
            if epoch % 5 == 0 and self.rank == 0:
                print(f"train epoch={epoch}")
            for index, batch in enumerate(data):
                self.iteration_step(
                    self,
                    ddp_model,
                    criterion,
                    optimizer,
                    hook_state,
                    epoch,
                    index,
                    batch,
                )
        # wait for all queued CUDA work to finish so recorded metrics are complete
        torch.cuda.synchronize(self.rank)
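
# A minimal construction sketch, assuming hypothetical helpers that are not
# defined in this module (``setup_process_group``, ``my_preprocess``,
# ``my_criterion``, ``my_create_ddp_model``, ``MyHookState``, ``my_hook``,
# ``my_step``, ``model``, and ``data``):
#
#     trainer = DdpTrainer(
#         process_group=setup_process_group(),
#         use_cuda_rpc=False,
#         server_rref=None,
#         backend="nccl",
#         epochs=10,
#         preprocess_data=my_preprocess,
#         create_criterion=my_criterion,
#         create_ddp_model=my_create_ddp_model,
#         hook_state_class=MyHookState,
#         hook=my_hook,
#         iteration_step=my_step,
#     )
#     trainer.train(model, data)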