# xref: /aosp_15_r20/external/pytorch/test/jit/test_generator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import math
5import unittest
6
7import torch
8from torch.nn import init
9from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
10from torch.testing._internal.jit_utils import JitTestCase
11
12
if __name__ == "__main__":
    # This suite is driven through test_jit.py; refuse direct invocation.
    msg = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(msg)
19
20
21class TestGenerator(JitTestCase):
22    # torch.jit.trace does not properly capture the generator manual seed
23    # and thus is non deterministic even if the generator is manually seeded
24    @skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type")
25    @unittest.expectedFailure
26    def test_trace(self):
27        def f():
28            generator = torch.Generator()
29            generator.seed()
30            generator.manual_seed(2023)
31            generator.initial_seed()
32            tensor = torch.empty(2, 2)
33            tensor.uniform_(0, 1, generator=generator)
34            return tensor
35
36        traced_f = torch.jit.trace(f, ())
37
38        # Run this 3 times to ensure that the generator is being manually seeded
39        # each time the traced function is run
40        for i in range(3):
41            torch.manual_seed(1)
42
43            eager_tensor = f()
44
45            # Change the seed of the default generator to
46            # check that we're using the generator from the
47            # trace
48            torch.manual_seed(2)
49            traced_tensor = traced_f()
50
51            self.assertEqual(eager_tensor, traced_tensor)
52
53    def test_script(self):
54        def f():
55            generator = torch.Generator()
56            generator.seed()
57            generator.manual_seed(2023)
58            generator.initial_seed()
59            tensor = torch.empty(2, 2)
60            tensor.normal_(-1.0, 1.0, generator=generator)
61            return tensor
62
63        script_f = torch.jit.script(f, ())
64
65        # Run this 3 times to ensure that the generator is being manually seeded
66        # each time the traced function is run
67        for i in range(3):
68            torch.manual_seed(1)
69
70            eager_tensor = f()
71
72            # Change the seed of the default generator to
73            # check that we're using the generator from the
74            # trace
75            torch.manual_seed(2)
76
77            script_tensor = script_f()
78
79            self.assertEqual(eager_tensor, script_tensor)
80
81    def test_default_generator(self):
82        def f():
83            # check that calling manual seed for the default generator works
84            torch.manual_seed(2023)
85            tensor = torch.empty(2, 2)
86            tensor.normal_(-1.0, 1.0)
87            return tensor
88
89        torch.manual_seed(1)
90
91        eager_tensor = f()
92
93        torch.manual_seed(2)
94
95        script_f = torch.jit.script(f, ())
96        script_tensor = script_f()
97
98        self.assertEqual(eager_tensor, script_tensor)
99
100    def test_generator_arg(self):
101        def f(generator: torch.Generator):
102            tensor = torch.empty(2, 2)
103            tensor.normal_(-1.0, 1.0, generator=generator)
104            return tensor
105
106        generator = torch.Generator()
107        generator.manual_seed(2023)
108
109        script_f = torch.jit.script(f, (generator,))
110
111        for i in range(3):
112            generator = torch.Generator()
113            generator.manual_seed(2023 + i)
114
115            torch.manual_seed(1 + i)
116
117            eager_tensor = f(generator)
118
119            generator = torch.Generator()
120            generator.manual_seed(2023 + i)
121
122            torch.manual_seed(1 + i)
123
124            script_tensor = script_f(generator)
125
126            self.assertEqual(eager_tensor, script_tensor)
127
128    def test_save_load(self):
129        class Foo(torch.nn.Module):
130            def __init__(self) -> None:
131                super().__init__()
132                self.foo = torch.nn.Linear(2, 2, bias=False)
133                self.bar = torch.nn.Linear(2, 2, bias=False)
134
135                self.reset_parameters()
136
137            def reset_linear(self, module, generator):
138                init.kaiming_uniform_(
139                    module.weight, a=math.sqrt(5), generator=generator
140                )
141
142            def reset_parameters(self):
143                generator = torch.Generator()
144                generator.manual_seed(1)
145                self.reset_linear(self.foo, generator)
146
147                generator = torch.Generator()
148                generator.manual_seed(2)
149                self.reset_linear(self.bar, generator)
150
151            def forward(self, x):
152                x = self.foo(x)
153                x = self.bar(x)
154
155                generator = torch.Generator()
156                generator.manual_seed(3)
157                r = torch.empty_like(x)
158                r.normal_(0.0, 1.0, generator=generator)
159
160                return x, r
161
162        eager_foo = Foo()
163
164        script_module = torch.jit.script(Foo())
165        saved_module = io.BytesIO()
166        torch.jit.save(script_module, saved_module)
167        saved_module.seek(0)
168
169        loaded_module = torch.jit.load(saved_module)
170
171        self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight)
172        self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight)
173
174        try:
175            # Run this 3 times so make sure that the generator seed is being set
176            # every time forward is called
177            for i in range(3):
178                x = torch.ones(2, 2)
179                out1, r1 = eager_foo(x)
180                out2, r2 = loaded_module(x)
181
182                try:
183                    self.assertEqual(out1, out2)
184                except:  # noqa: B001, E722
185                    print(f"Iteration {i}:\n{out1=}\n{out2=}")
186                    raise
187
188                try:
189                    self.assertEqual(r1, r2)
190                except:  # noqa: B001, E722
191                    print(f"Iteration {i}:\n{r1=}\n{r2=}")
192                    raise
193        except:  # noqa: B001, E722
194            print(loaded_module.forward.code)
195            raise
196