1# Owner(s): ["oncall: jit"] 2 3import io 4import math 5import unittest 6 7import torch 8from torch.nn import init 9from torch.testing._internal.common_utils import skipIfLegacyJitExecutor 10from torch.testing._internal.jit_utils import JitTestCase 11 12 13if __name__ == "__main__": 14 raise RuntimeError( 15 "This test file is not meant to be run directly, use:\n\n" 16 "\tpython test/test_jit.py TESTNAME\n\n" 17 "instead." 18 ) 19 20 21class TestGenerator(JitTestCase): 22 # torch.jit.trace does not properly capture the generator manual seed 23 # and thus is non deterministic even if the generator is manually seeded 24 @skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type") 25 @unittest.expectedFailure 26 def test_trace(self): 27 def f(): 28 generator = torch.Generator() 29 generator.seed() 30 generator.manual_seed(2023) 31 generator.initial_seed() 32 tensor = torch.empty(2, 2) 33 tensor.uniform_(0, 1, generator=generator) 34 return tensor 35 36 traced_f = torch.jit.trace(f, ()) 37 38 # Run this 3 times to ensure that the generator is being manually seeded 39 # each time the traced function is run 40 for i in range(3): 41 torch.manual_seed(1) 42 43 eager_tensor = f() 44 45 # Change the seed of the default generator to 46 # check that we're using the generator from the 47 # trace 48 torch.manual_seed(2) 49 traced_tensor = traced_f() 50 51 self.assertEqual(eager_tensor, traced_tensor) 52 53 def test_script(self): 54 def f(): 55 generator = torch.Generator() 56 generator.seed() 57 generator.manual_seed(2023) 58 generator.initial_seed() 59 tensor = torch.empty(2, 2) 60 tensor.normal_(-1.0, 1.0, generator=generator) 61 return tensor 62 63 script_f = torch.jit.script(f, ()) 64 65 # Run this 3 times to ensure that the generator is being manually seeded 66 # each time the traced function is run 67 for i in range(3): 68 torch.manual_seed(1) 69 70 eager_tensor = f() 71 72 # Change the seed of the default generator to 73 # check that we're using the generator from the 74 # trace 75 torch.manual_seed(2) 76 77 script_tensor = script_f() 78 79 self.assertEqual(eager_tensor, script_tensor) 80 81 def test_default_generator(self): 82 def f(): 83 # check that calling manual seed for the default generator works 84 torch.manual_seed(2023) 85 tensor = torch.empty(2, 2) 86 tensor.normal_(-1.0, 1.0) 87 return tensor 88 89 torch.manual_seed(1) 90 91 eager_tensor = f() 92 93 torch.manual_seed(2) 94 95 script_f = torch.jit.script(f, ()) 96 script_tensor = script_f() 97 98 self.assertEqual(eager_tensor, script_tensor) 99 100 def test_generator_arg(self): 101 def f(generator: torch.Generator): 102 tensor = torch.empty(2, 2) 103 tensor.normal_(-1.0, 1.0, generator=generator) 104 return tensor 105 106 generator = torch.Generator() 107 generator.manual_seed(2023) 108 109 script_f = torch.jit.script(f, (generator,)) 110 111 for i in range(3): 112 generator = torch.Generator() 113 generator.manual_seed(2023 + i) 114 115 torch.manual_seed(1 + i) 116 117 eager_tensor = f(generator) 118 119 generator = torch.Generator() 120 generator.manual_seed(2023 + i) 121 122 torch.manual_seed(1 + i) 123 124 script_tensor = script_f(generator) 125 126 self.assertEqual(eager_tensor, script_tensor) 127 128 def test_save_load(self): 129 class Foo(torch.nn.Module): 130 def __init__(self) -> None: 131 super().__init__() 132 self.foo = torch.nn.Linear(2, 2, bias=False) 133 self.bar = torch.nn.Linear(2, 2, bias=False) 134 135 self.reset_parameters() 136 137 def reset_linear(self, module, generator): 138 init.kaiming_uniform_( 139 module.weight, a=math.sqrt(5), generator=generator 140 ) 141 142 def reset_parameters(self): 143 generator = torch.Generator() 144 generator.manual_seed(1) 145 self.reset_linear(self.foo, generator) 146 147 generator = torch.Generator() 148 generator.manual_seed(2) 149 self.reset_linear(self.bar, generator) 150 151 def forward(self, x): 152 x = self.foo(x) 153 x = self.bar(x) 154 155 generator = torch.Generator() 156 generator.manual_seed(3) 157 r = torch.empty_like(x) 158 r.normal_(0.0, 1.0, generator=generator) 159 160 return x, r 161 162 eager_foo = Foo() 163 164 script_module = torch.jit.script(Foo()) 165 saved_module = io.BytesIO() 166 torch.jit.save(script_module, saved_module) 167 saved_module.seek(0) 168 169 loaded_module = torch.jit.load(saved_module) 170 171 self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight) 172 self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight) 173 174 try: 175 # Run this 3 times so make sure that the generator seed is being set 176 # every time forward is called 177 for i in range(3): 178 x = torch.ones(2, 2) 179 out1, r1 = eager_foo(x) 180 out2, r2 = loaded_module(x) 181 182 try: 183 self.assertEqual(out1, out2) 184 except: # noqa: B001, E722 185 print(f"Iteration {i}:\n{out1=}\n{out2=}") 186 raise 187 188 try: 189 self.assertEqual(r1, r2) 190 except: # noqa: B001, E722 191 print(f"Iteration {i}:\n{r1=}\n{r2=}") 192 raise 193 except: # noqa: B001, E722 194 print(loaded_module.forward.code) 195 raise 196