# Owner(s): ["oncall: jit"] import io import math import unittest import torch from torch.nn import init from torch.testing._internal.common_utils import skipIfLegacyJitExecutor from torch.testing._internal.jit_utils import JitTestCase if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestGenerator(JitTestCase): # torch.jit.trace does not properly capture the generator manual seed # and thus is non deterministic even if the generator is manually seeded @skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type") @unittest.expectedFailure def test_trace(self): def f(): generator = torch.Generator() generator.seed() generator.manual_seed(2023) generator.initial_seed() tensor = torch.empty(2, 2) tensor.uniform_(0, 1, generator=generator) return tensor traced_f = torch.jit.trace(f, ()) # Run this 3 times to ensure that the generator is being manually seeded # each time the traced function is run for i in range(3): torch.manual_seed(1) eager_tensor = f() # Change the seed of the default generator to # check that we're using the generator from the # trace torch.manual_seed(2) traced_tensor = traced_f() self.assertEqual(eager_tensor, traced_tensor) def test_script(self): def f(): generator = torch.Generator() generator.seed() generator.manual_seed(2023) generator.initial_seed() tensor = torch.empty(2, 2) tensor.normal_(-1.0, 1.0, generator=generator) return tensor script_f = torch.jit.script(f, ()) # Run this 3 times to ensure that the generator is being manually seeded # each time the traced function is run for i in range(3): torch.manual_seed(1) eager_tensor = f() # Change the seed of the default generator to # check that we're using the generator from the # trace torch.manual_seed(2) script_tensor = script_f() self.assertEqual(eager_tensor, script_tensor) def test_default_generator(self): def f(): # check that calling manual seed for the default generator works torch.manual_seed(2023) tensor = torch.empty(2, 2) tensor.normal_(-1.0, 1.0) return tensor torch.manual_seed(1) eager_tensor = f() torch.manual_seed(2) script_f = torch.jit.script(f, ()) script_tensor = script_f() self.assertEqual(eager_tensor, script_tensor) def test_generator_arg(self): def f(generator: torch.Generator): tensor = torch.empty(2, 2) tensor.normal_(-1.0, 1.0, generator=generator) return tensor generator = torch.Generator() generator.manual_seed(2023) script_f = torch.jit.script(f, (generator,)) for i in range(3): generator = torch.Generator() generator.manual_seed(2023 + i) torch.manual_seed(1 + i) eager_tensor = f(generator) generator = torch.Generator() generator.manual_seed(2023 + i) torch.manual_seed(1 + i) script_tensor = script_f(generator) self.assertEqual(eager_tensor, script_tensor) def test_save_load(self): class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.foo = torch.nn.Linear(2, 2, bias=False) self.bar = torch.nn.Linear(2, 2, bias=False) self.reset_parameters() def reset_linear(self, module, generator): init.kaiming_uniform_( module.weight, a=math.sqrt(5), generator=generator ) def reset_parameters(self): generator = torch.Generator() generator.manual_seed(1) self.reset_linear(self.foo, generator) generator = torch.Generator() generator.manual_seed(2) self.reset_linear(self.bar, generator) def forward(self, x): x = self.foo(x) x = self.bar(x) generator = torch.Generator() generator.manual_seed(3) r = torch.empty_like(x) r.normal_(0.0, 1.0, generator=generator) return x, r eager_foo = Foo() script_module = torch.jit.script(Foo()) saved_module = io.BytesIO() torch.jit.save(script_module, saved_module) saved_module.seek(0) loaded_module = torch.jit.load(saved_module) self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight) self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight) try: # Run this 3 times so make sure that the generator seed is being set # every time forward is called for i in range(3): x = torch.ones(2, 2) out1, r1 = eager_foo(x) out2, r2 = loaded_module(x) try: self.assertEqual(out1, out2) except: # noqa: B001, E722 print(f"Iteration {i}:\n{out1=}\n{out2=}") raise try: self.assertEqual(r1, r2) except: # noqa: B001, E722 print(f"Iteration {i}:\n{r1=}\n{r2=}") raise except: # noqa: B001, E722 print(loaded_module.forward.code) raise