# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.parallel as dp


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestDataParallel(JitTestCase):
    class Mpy(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.ignore
        def forward(self, input):
            return self.m(input)

    class Mpy1(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.m = block

        @torch.jit.ignore
        def forward(self, input):
            return self.m.forward(input)

    class Mpy2(torch.nn.Module):
        def __init__(self, block1, block2):
            super().__init__()
            self.m1 = block1
            self.m2 = block2

        @torch.jit.ignore
        def forward(self, input):
            x = self.m1.forward(input)
            return self.m2(x)

    class Msm(torch.jit.ScriptModule):
        __constants__ = ["m"]

        def __init__(self) -> None:
            super().__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.script_method
        def forward(self, input):
            return self.m(input)

    class Msm1(torch.jit.ScriptModule):
        def __init__(self, block):
            super().__init__()
            self.block = block

        @torch.jit.script_method
        def forward(self, input):
            x = self.block(input)
            return x

    def check_replicas(self, module, replicas, input_shape=(2, 2)):
        # Each replica must live entirely on its own device and produce the
        # same output as the original module.
        input = torch.randn(input_shape).cuda()
        expected_output = module(input).data
        for i, replica in enumerate(replicas):
            for p in replica.parameters():
                self.assertEqual(p.get_device(), i)
            for b in replica.buffers():
                self.assertEqual(b.get_device(), i)
            replica_input = input.cuda(i)
            self.assertEqual(replica(replica_input).data, expected_output)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_python_submodule_script(self):
        module = self.Mpy1(self.Msm()).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_shared_module(self):
        s = self.Msm()
        p1 = self.Mpy1(s)
        module = self.Mpy2(p1, s).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_traced_module(self):
        module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing(self):
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})

        def assert_share_data(t1, t2):
            # Only checks that the tensors point to the same memory on the
            # same device.
            return (
                t1.device == t2.device
                and t1.storage().data_ptr() == t2.storage().data_ptr()
            )

        # The replica on the original device shares storage with the module;
        # the replica on the other device gets its own copies.
        for p1, p2 in zip(module.parameters(), replica[0].parameters()):
            self.assertTrue(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[0].buffers()):
            self.assertTrue(assert_share_data(p1, p2))

        for p1, p2 in zip(module.parameters(), replica[1].parameters()):
            self.assertFalse(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[1].buffers()):
            self.assertFalse(assert_share_data(p1, p2))

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing_with_forward(self):
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})
        x = torch.ones(2, 2, requires_grad=True).cuda()
        first_forward = module(x)
        first_forward.sum().backward()
        with torch.no_grad():
            for p in module.parameters():
                # Use .data here to avoid a version counter bump.
                # The graph created by the following forward will be wrong,
                # but we never backward through it, so it's fine.
                p.data -= 1.0 * p.grad
        second_forward = module(x)

        # The replica on the same GPU has a shallow copy of the original
        # params and buffers, so it sees the in-place update.
        r0_forward = replica[0](x)
        self.assertEqual(second_forward, r0_forward)

        # The replica on a different GPU has a deep copy of the original
        # params and buffers, so it still matches the pre-update output.
        x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
        r1_forward = replica[1](x1)
        self.assertEqual(first_forward, r1_forward)