# Owner(s): ["oncall: distributed"]

import sys
import time
from statistics import mean
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Layer(nn.Module):
    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        self.optional_param = None
        if has_params:
            self.optional_param = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # Create two CUDA events to time this layer's forward work.
        self.e1 = Event(enable_timing=True)
        self.e2 = Event(enable_timing=True)

        # Record the fake forward compute time.
        self.e1.record()
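        # `torch.cuda._sleep` enqueues a kernel that spins the GPU for the given
        # number of cycles, simulating compute without blocking the CPU.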
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        if self.optional_param is not None:
            x = x + self.optional_param  # force the param to be part of the graph
        self.e2.record()
        return x

    def get_time(self):
        # Return the recorded duration in milliseconds.
        return self.e1.elapsed_time(self.e2)


def _create_model(compute_cycles, has_params: bool):
    # Use `limit_all_gathers=False` since the timing being tested relies on the
    # CPU running ahead of the GPU.
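    # Each Layer gets its own nested FSDP wrapper, so its parameters are
    # all-gathered separately; this is what lets communication for one layer
    # overlap with compute from another.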
    model = FSDP(
        nn.Sequential(
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
            FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
        ),
        limit_all_gathers=False,
    ).cuda()
    return model


class Min10:
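    """Track the 10 smallest samples added; `avg()` returns their mean."""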
    def __init__(self) -> None:
        self.data = []

    def add(self, new_data):
        if len(self.data) < 10:
            self.data.append(new_data)
        else:
            self.data = sorted(self.data)
            if new_data < self.data[-1]:
                self.data[-1] = new_data

    def avg(self):
        return mean(self.data)


class TestForwardOverlapWorldSizeOne(FSDPTest):
    @property
    def world_size(self):
        return 1

    def _dist_train(self):
        rank = self.rank
        world_size = self.world_size
        # Save the original torch.distributed.all_gather_into_tensor function since we will
        # patch it to include an artificial delay.
        orig_all_gather = torch.distributed.all_gather_into_tensor

        def run(compute_cycles, all_gather_cycles):
            has_params = all_gather_cycles > 0
            model = _create_model(compute_cycles, has_params)

            # Create the input and set its requires_grad to True because we
            # only have fake compute in the forward pass.
            batch = torch.rand(1).cuda()
            batch.requires_grad = True

            # Run one dummy iteration to trigger the execution order validation
            # all-gathers so that they do not affect the timed iterations below.
            out = model(batch)
            out.backward()
            model.zero_grad(set_to_none=True)

            # Run 20 iterations but only keep the 10 smallest timings, since
            # nondeterministic system events can disturb any single measurement.
            cpu_iter = Min10()
            cpu_wait = Min10()
            gpu_compute = Min10()
            gpu_total = Min10()
            for _ in range(20):
                # Get two events for measuring the overall time.
                e1 = Event(enable_timing=True)
                e2 = Event(enable_timing=True)

                cpu_start = time.process_time()

                all_gather_called = False

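                # Wrap the real all-gather so that every call first spins the GPU
                # for `all_gather_cycles`, making communication measurably slow.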
                def _delayed_all_gather(*args, **kwargs):
                    nonlocal all_gather_called
                    all_gather_called = True
                    torch.cuda._sleep(all_gather_cycles)
                    assert orig_all_gather
                    return orig_all_gather(*args, **kwargs)

                # forward pass
                #
                # Even though both e1 & e2 are on the compute stream, since
                # compute depends on all_gather, e2-e1 includes all_gather time.
                e1.record()
                with patch(
                    "torch.distributed.all_gather_into_tensor", _delayed_all_gather
                ):
                    out = model(batch)
                    if has_params and world_size > 1:
                        self.assertTrue(all_gather_called)
                    else:
                        self.assertFalse(all_gather_called)
                e2.record()

                # backward pass
                out.backward()
                model.zero_grad(set_to_none=True)

                cpu_iter_time = time.process_time() - cpu_start

                # Wait for the GPU: `.item()` synchronizes with all queued GPU work.
                out.item()
                cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

                # Get the sum of the per-layer compute times.
                times = []
                for mod in model.modules():
                    if not isinstance(mod, Layer):
                        continue
                    times.append(mod.get_time())

                # Get the GPU compute + all-gather time.
                overall_gpu_time = e1.elapsed_time(e2)

                cpu_iter.add(cpu_iter_time)
                cpu_wait.add(cpu_wait_for_gpu_time)
                gpu_compute.add(sum(times))
                gpu_total.add(overall_gpu_time)

            del model
            return {
                "cpu_iter": cpu_iter.avg(),
                "cpu_wait": cpu_wait.avg(),
                "gpu_compute": gpu_compute.avg(),
                "gpu_total": gpu_total.avg(),
            }

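        # Roughly 100 ms worth of GPU spin cycles: `get_cycles_per_ms` estimates
        # how many cycles `torch.cuda._sleep` burns per millisecond.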
        sleep_cycles = int(100 * get_cycles_per_ms())

        e1 = run(0, 0)  # no compute, no all-gather
        e2 = run(0, sleep_cycles)  # no compute, only all-gather
        e3 = run(sleep_cycles, 0)  # only compute, no all-gather
        e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
        debug_string = f"\nrank{rank}:\n  e1: {e1}\n  e2: {e2}\n  e3: {e3}\n  e4: {e4}"
        print(debug_string)

        # Check the cpu/gpu timing. The CPU should run ahead of the GPU. Therefore,
        # the cpu-gpu wait should be long, except when there is no real work on the GPU.
        #
        # If the assertions below fail, we likely have an unwanted cpu-gpu sync in the
        # forward/backward pass. e4["cpu_iter"] may not be short, since the CPU may
        # take some time to queue both compute and all-gather.
        short = [
            e1["cpu_iter"],
            e2["cpu_iter"],
            e3["cpu_iter"],
            e1["cpu_wait"],
        ]
        long = [e3["cpu_wait"], e4["cpu_wait"]]
        if world_size == 1:
            short.append(e2["cpu_wait"])  # all-gather should not be happening.
        else:
            long.append(
                e2["cpu_wait"]
            )  # all-gather should happen and prolong the cpu-gpu wait.
        for s in short:
            for l in long:
                # A 10x margin is safe, since the GPU work takes around 100x
                # longer than the CPU work.
                self.assertTrue(s * 10 < l)

        # Check the GPU timing.
        short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
        long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
        if world_size == 1:
            short.append(e2["gpu_total"])  # all-gather should not be happening.
        else:
            long.append(
                e2["gpu_total"]
            )  # all-gather should happen and prolong the overall GPU time.
        for s in short:
            for l in long:
                # A 10x margin is safe, since the time is around 100x longer
                # when there is GPU work vs. no work.
                self.assertTrue(s * 10 < l)

        # Check the GPU overlapping when there is all-gather.
        if world_size > 1:
            compute_only = e3["gpu_compute"]
            all_gather_only = e2["gpu_total"]
            both = e4["gpu_total"]
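            # With overlap, running compute and all-gather together should take
            # noticeably less time than the sum of running each alone.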
            self.assertTrue(compute_only + all_gather_only > 1.1 * both)

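    # Requires at least 2 GPUs; the same test body is reused by the
    # world-size-2 subclass below.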
    @skip_if_lt_x_gpu(2)
    def test_forward_overlap(self):
        self._dist_train()


class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
    @property
    def world_size(self):
        return 2


if __name__ == "__main__":
    run_tests()