# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.parallel as dp


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestDataParallel(JitTestCase):
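    # Helper modules used as replication targets below: the Mpy* classes are
    # plain Python modules whose forward is hidden from the JIT via
    # @torch.jit.ignore, while the Msm* classes are ScriptModules with
    # scripted forwards.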
    class Mpy(torch.nn.Module):
        def __init__(self) -> None:
            super(TestDataParallel.Mpy, self).__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.ignore
        def forward(self, input):
            return self.m(input)

    class Mpy1(torch.nn.Module):
        def __init__(self, block):
            super(TestDataParallel.Mpy1, self).__init__()
            self.m = block

        @torch.jit.ignore
        def forward(self, input):
            return self.m.forward(input)

    class Mpy2(torch.nn.Module):
        def __init__(self, block1, block2):
            super(TestDataParallel.Mpy2, self).__init__()
            self.m1 = block1
            self.m2 = block2

        @torch.jit.ignore
        def forward(self, input):
            x = self.m1.forward(input)
            return self.m2(x)

    class Msm(torch.jit.ScriptModule):
        __constants__ = ["m"]

        def __init__(self) -> None:
            super(TestDataParallel.Msm, self).__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.script_method
        def forward(self, input):
            return self.m(input)

    class Msm1(torch.jit.ScriptModule):
        def __init__(self, block):
            super(TestDataParallel.Msm1, self).__init__()
            self.block = block

        @torch.jit.script_method
        def forward(self, input):
            x = self.block(input)
            return x

    def check_replicas(self, module, replicas, input_shape=(2, 2)):
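        # Every parameter and buffer of replica i should live on device i, and
        # each replica should produce the same output as the original module.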
        input = torch.randn(input_shape).cuda()
        expected_output = module(input).data
        for i, replica in enumerate(replicas):
            for p in replica.parameters():
                self.assertEqual(p.get_device(), i)
            for b in replica.buffers():
                self.assertEqual(b.get_device(), i)
            replica_input = input.cuda(i)
            self.assertEqual(replica(replica_input).data, expected_output)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_python_submodule_script(self):
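        # A Python module wrapping a ScriptModule submodule should replicate
        # across GPUs like any other module.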
        module = self.Mpy1(self.Msm()).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_shared_module(self):
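        # A single ScriptModule instance reachable through two different
        # parent modules should still replicate correctly.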
        s = self.Msm()
        p1 = self.Mpy1(s)
        module = self.Mpy2(p1, s).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_traced_module(self):
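        # A traced module should replicate like an ordinary module.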
        module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing(self):
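        # The replica on the module's own device shares storage with the
        # original parameters and buffers; the replica on the other device
        # gets its own copies.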
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})

        def assert_share_data(t1, t2):
            # Only checks that they point to the same memory on the same device.
            return (
                t1.device == t2.device
                and t1.storage().data_ptr() == t2.storage().data_ptr()
            )

        for p1, p2 in zip(module.parameters(), replica[0].parameters()):
            self.assertTrue(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[0].buffers()):
            self.assertTrue(assert_share_data(p1, p2))

        for p1, p2 in zip(module.parameters(), replica[1].parameters()):
            self.assertFalse(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[1].buffers()):
            self.assertFalse(assert_share_data(p1, p2))

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing_with_forward(self):
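        # After updating the original parameters in place, the replica on the
        # same GPU (shallow copy) should see the new values, while the replica
        # on the other GPU (deep copy) should still match the pre-update
        # forward.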
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})
        x = torch.ones(2, 2, requires_grad=True).cuda()
        first_forward = module(x)
        first_forward.sum().backward()
        with torch.no_grad():
            for p in module.parameters():
                # Use .data here to avoid a version counter bump.
                # The graph created by the following forward will be wrong, but
                # we never call backward through it, so it's fine.
                p.data -= 1.0 * p.grad
        second_forward = module(x)

        # The replica on the same GPU holds a shallow copy of the original
        # params and buffers, so it sees the updated values.
        r0_forward = replica[0](x)
        self.assertEqual(second_forward, r0_forward)

        # The replica on a different GPU holds a deep copy of the original
        # params and buffers, so it still matches the pre-update output.
        x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
        r1_forward = replica[1](x1)
        self.assertEqual(first_forward, r1_forward)
161