# Owner(s): ["oncall: jit"]

import io
import os
import sys

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestTypeSharing(JitTestCase):
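    # Helpers: script the inputs if they are not already ScriptModules, then
    # compare the TorchScript class types of the underlying C++ modules.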
    def assertSameType(self, m1, m2):
        if not isinstance(m1, torch.jit.ScriptModule):
            m1 = torch.jit.script(m1)
        if not isinstance(m2, torch.jit.ScriptModule):
            m2 = torch.jit.script(m2)
        self.assertEqual(m1._c._type(), m2._c._type())

    def assertDifferentType(self, m1, m2):
        if not isinstance(m1, torch.jit.ScriptModule):
            m1 = torch.jit.script(m1)
        if not isinstance(m2, torch.jit.ScriptModule):
            m2 = torch.jit.script(m2)
        self.assertNotEqual(m1._c._type(), m2._c._type())

    def test_basic(self):
        class M(torch.nn.Module):
            def __init__(self, a, b, c):
                super().__init__()
                self.a = a
                self.b = b
                self.c = c

            def forward(self, x):
                return x

        a = torch.rand(2, 3)
        b = torch.rand(2, 3)
        c = torch.rand(2, 3)
        m1 = M(a, b, c)
        m2 = M(a, b, c)
        self.assertSameType(m1, m2)

    def test_diff_attr_values(self):
        """
        Types should be shared even if attribute values differ
        """

        class M(torch.nn.Module):
            def __init__(self, a, b, c):
                super().__init__()
                self.a = a
                self.b = b
                self.c = c

            def forward(self, x):
                return x

        a = torch.rand(2, 3)
        b = torch.rand(2, 3)
        c = torch.rand(2, 3)
        m1 = M(a, b, c)
        m2 = M(a * 2, b * 3, c * 4)
        self.assertSameType(m1, m2)

    def test_constants(self):
        """
        Types should be shared for identical constant values, and different for different constant values
        """

        class M(torch.nn.Module):
            __constants__ = ["const"]

            def __init__(self, attr, const):
                super().__init__()
                self.attr = attr
                self.const = const

            def forward(self):
                return self.const

        attr = torch.rand(2, 3)
        m1 = M(attr, 1)
        m2 = M(attr, 1)
        self.assertSameType(m1, m2)

        # a different constant value
        m3 = M(attr, 2)
        self.assertDifferentType(m1, m3)

    def test_linear(self):
        """
        Simple example with a real nn Module
        """
        a = torch.nn.Linear(5, 5)
        b = torch.nn.Linear(5, 5)
        c = torch.nn.Linear(10, 10)
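        # nn.Linear declares in_features/out_features as constants, so
        # differently sized Linears are expected to get different types.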
        a = torch.jit.script(a)
        b = torch.jit.script(b)
        c = torch.jit.script(c)

        self.assertSameType(a, b)
        self.assertDifferentType(a, c)

    def test_submodules(self):
        """
        If submodules differ, the types should differ.
        """

        class M(torch.nn.Module):
            def __init__(self, in1, out1, in2, out2):
                super().__init__()
                self.submod1 = torch.nn.Linear(in1, out1)
                self.submod2 = torch.nn.Linear(in2, out2)

            def forward(self, x):
                x = self.submod1(x)
                x = self.submod2(x)
                return x

        a = M(1, 1, 2, 2)
        b = M(1, 1, 2, 2)
        self.assertSameType(a, b)
        self.assertSameType(a.submod1, b.submod1)
        c = M(2, 2, 2, 2)
        self.assertDifferentType(a, c)

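        # Submodule types are compared independently of their parents:
        # b.submod2 and c.submod1 are both Linear(2, 2) and should share a type,
        # while a.submod1 is Linear(1, 1) and should not share one with b.submod2.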
        self.assertSameType(b.submod2, c.submod1)
        self.assertDifferentType(a.submod1, b.submod2)

    def test_param_vs_attribute(self):
        """
        The same module with `foo` as a parameter vs. as an attribute
        shouldn't share types
        """

        class M(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def forward(self, x):
                return x + self.foo

        as_param = torch.nn.Parameter(torch.ones(2, 2))
        as_attr = torch.ones(2, 2)
        param_mod = M(as_param)
        attr_mod = M(as_attr)
        self.assertDifferentType(attr_mod, param_mod)

    def test_same_but_different_classes(self):
        """
        Even if everything about the module is the same, different originating
        classes should prevent type sharing.
        """

        class A(torch.nn.Module):
            __constants__ = ["const"]

            def __init__(self, in1, out1, in2, out2):
                super().__init__()
                self.submod1 = torch.nn.Linear(in1, out1)
                self.submod2 = torch.nn.Linear(in2, out2)
                self.const = 5

            def forward(self, x):
                x = self.submod1(x)
                x = self.submod2(x)
                return x * self.const

        class B(torch.nn.Module):
            __constants__ = ["const"]

            def __init__(self, in1, out1, in2, out2):
                super().__init__()
                self.submod1 = torch.nn.Linear(in1, out1)
                self.submod2 = torch.nn.Linear(in2, out2)
                self.const = 5

            def forward(self, x):
                x = self.submod1(x)
                x = self.submod2(x)
                return x * self.const

        a = A(1, 1, 2, 2)
        b = B(1, 1, 2, 2)
        self.assertDifferentType(a, b)

    def test_mutate_attr_value(self):
        """
        Mutating the value of an attribute should not change type sharing
        """

        class M(torch.nn.Module):
            def __init__(self, in1, out1, in2, out2):
                super().__init__()
                self.submod1 = torch.nn.Linear(in1, out1)
                self.submod2 = torch.nn.Linear(in2, out2)
                self.foo = torch.ones(in1, in1)

            def forward(self, x):
                x = self.submod1(x)
                x = self.submod2(x)
                return x + self.foo

        a = M(1, 1, 2, 2)
        b = M(1, 1, 2, 2)
        a.foo = torch.ones(2, 2)
        b.foo = torch.rand(2, 2)
        self.assertSameType(a, b)

    def test_assign_python_attr(self):
        """
        Assigning a new (Python-only) attribute should not change type sharing
        """

        class M(torch.nn.Module):
            def __init__(self, in1, out1, in2, out2):
                super().__init__()
                self.submod1 = torch.nn.Linear(in1, out1)
                self.submod2 = torch.nn.Linear(in2, out2)
                self.foo = torch.ones(in1, in1)

            def forward(self, x):
                x = self.submod1(x)
                x = self.submod2(x)
                return x + self.foo

        # explicitly call script() to freeze the type
        a = torch.jit.script(M(1, 1, 2, 2))
        b = torch.jit.script(M(1, 1, 2, 2))
        a.new_attr = "foo bar baz"
        self.assertSameType(a, b)

        # but if we assign attributes *before* calling script(), the types
        # should be different, since `new_attr` should be turned into a Script
        # attribute
        a = M(1, 1, 2, 2)
        b = M(1, 1, 2, 2)
        a.new_attr = "foo bar baz"
        self.assertDifferentType(a, b)

    def test_failed_attribute_compilation(self):
        """
        Attributes whose type cannot be inferred should fail cleanly with nice hints
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # assign a type we know can't be converted to TorchScript
                self.foo = object

            def forward(self):
                # try to use it in forward
                return self.foo

        m = M()
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "failed to convert Python type", "self.foo"
        ):
            torch.jit.script(m)

    def test_script_function_attribute_different(self):
        """
        Different functions passed in should lead to different types
        """

        @torch.jit.script
        def fn1(x):
            return x + x

        @torch.jit.script
        def fn2(x):
            return x - x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn1)
        fn2_mod = M(fn2)

        self.assertDifferentType(fn1_mod, fn2_mod)

    def test_builtin_function_same(self):
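        """
        The same builtin function passed in should lead to the same type
        """
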
        class Caller(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, input):
                return self.fn(input, input)

        c1 = Caller(torch.add)
        c2 = Caller(torch.add)

        self.assertSameType(c1, c2)

    def test_builtin_function_different(self):
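        """
        Different builtin functions passed in should lead to different types
        """
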
        class Caller(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, input):
                return self.fn(input, input)

        c1 = Caller(torch.add)
        c2 = Caller(torch.sub)

        self.assertDifferentType(c1, c2)

    def test_script_function_attribute_same(self):
        """
        Same functions passed in should lead to same types
        """

        @torch.jit.script
        def fn(x):
            return x + x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn)
        fn2_mod = M(fn)

        self.assertSameType(fn1_mod, fn2_mod)

    def test_python_function_attribute_different(self):
        """
        Different functions passed in should lead to different types
        """

        def fn1(x):
            return x + x

        def fn2(x):
            return x - x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn1)
        fn2_mod = M(fn2)

        self.assertDifferentType(fn1_mod, fn2_mod)

    def test_python_function_attribute_same(self):
        """
        Same functions passed in should lead to same types
        """

        def fn(x):
            return x + x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn)
        fn2_mod = M(fn)

        self.assertSameType(fn1_mod, fn2_mod)

    @suppress_warnings
    def test_tracing_gives_different_types(self):
        """
        Since we can't guarantee that methods are the same between different
        trace runs, tracing must always generate a unique type.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > y.sum():
                    return x
                else:
                    return y

        a = torch.jit.trace(M(), (torch.zeros(1, 1), torch.ones(1, 1)))
        b = torch.jit.trace(M(), (torch.ones(1, 1), torch.zeros(1, 1)))
        self.assertDifferentType(a, b)

    def test_ignored_fns(self):
        class M(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            @torch.jit.ignore
            def ignored(self):
                return self.foo

            def forward(self):
                return self.ignored()

        a = torch.jit.script(M(torch.ones(1)))
        b = torch.jit.script(M(torch.ones(2)))
        self.assertSameType(a, b)
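        # The ignored method still runs in Python, so each instance returns its
        # own `foo` value even though the two modules share a type.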
        self.assertNotEqual(a(), b())

    @suppress_warnings
    def test_script_module_containing_traced_module(self):
        class Traced(torch.nn.Module):
            def forward(self, x):
                if x.sum() > 0:
                    return x
                else:
                    return x + x

        class M(torch.nn.Module):
            def __init__(self, input):
                super().__init__()
                self.traced = torch.jit.trace(Traced(), input)

            def forward(self, x):
                return self.traced(x)

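        # Each call to trace() produces a fresh type for the traced submodule,
        # so the two wrapping modules should end up with different types as well.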
        a = M((torch.ones(1),))
        b = M((torch.zeros(1),))
        self.assertDifferentType(a, b)

    def test_loaded_modules_work(self):
        class AB(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1
                self.b = 1

            def forward(self):
                return self.a + self.b

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1

            def forward(self):
                return self.a

        class Wrapper(torch.nn.Module):
            def __init__(self, sub):
                super().__init__()
                self.sub = sub

            def forward(self):
                return self.sub()

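        # Round-trip a module through save/load, then wrap the loaded module
        # and script the wrapper to check that loaded modules still compose.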
        def package(x):
            buffer = io.BytesIO()
            torch.jit.save(torch.jit.script(x), buffer)
            buffer.seek(0)
            return torch.jit.script(Wrapper(torch.jit.load(buffer)))

        a = package(AB())
        a()
        b = package(A())
        b()

    def test_module_dict_same_type_different_name(self):
        """
        We should be able to differentiate between two ModuleDict instances
        that have different keys but the same value types.
        """

        class A(torch.nn.Module):
            def forward(self, x):
                return x

        class Foo(torch.nn.Module):
            def __init__(self, s):
                super().__init__()
                self.dict = torch.nn.ModuleDict(s)

            def forward(self, x):
                return x

        a = Foo({"foo": A()})
        b = Foo({"bar": A()})
        c = Foo({"bar": A()})
        self.assertDifferentType(a, b)
        self.assertSameType(b, c)

    def test_type_sharing_define_in_init(self):
        """
        Tests that types are not shared between instances of a ScriptModule
        subclass that defines methods in its __init__.
        """

        class A(torch.jit.ScriptModule):
            def __init__(self, val):
                super().__init__()
                self.define(
                    f"""
                def forward(self) -> int:
                    return {val}
                """
                )

        one = A(1)
        two = A(2)

        self.assertEqual(one(), 1)
        self.assertEqual(two(), 2)

    def test_type_sharing_disabled(self):
        """
        Test that type sharing can be disabled.
        """

        class A(torch.nn.Module):
            def __init__(self, sub):
                super().__init__()
                self.sub = sub

            def forward(self, x):
                return x

        class B(torch.nn.Module):
            def forward(self, x):
                return x

        top1 = A(A(B()))
        top2 = A(A(B()))

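        # Use the internal recursive scripting API directly so that type
        # sharing can be turned off via share_types=False.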
        top1_s = torch.jit._recursive.create_script_module(
            top1,
            torch.jit._recursive.infer_methods_to_compile,
            share_types=False,
        )
        top2_s = torch.jit._recursive.create_script_module(
            top2,
            torch.jit._recursive.infer_methods_to_compile,
            share_types=False,
        )

        self.assertDifferentType(top1_s, top2_s)
        self.assertDifferentType(top1_s, top1_s.sub)
        self.assertDifferentType(top1_s, top2_s.sub)
        self.assertDifferentType(top2_s, top2_s.sub)
        self.assertDifferentType(top2_s, top1_s.sub)

    def test_type_shared_ignored_attributes(self):
        """
        Test that types are shared if the exclusion of their
        ignored attributes makes them equal.
        """

        class A(torch.nn.Module):
            __jit_ignored_attributes__ = ["a"]

            def __init__(self, a, b):
                super().__init__()
                self.a = a
                self.b = b

            def forward(self, x):
                return x

        a_with_linear = A(torch.nn.Linear(5, 5), 5)
        a_with_string = A("string", 10)

        # Both should have the same type because the attribute
        # that differs in type is ignored and the common attribute
        # has the same type.
        self.assertSameType(a_with_linear, a_with_string)

    def test_type_not_shared_ignored_attributes(self):
        """
        Test that types are not shared if the exclusion of their
        ignored attributes makes them not equal.
        """

        class A(torch.nn.Module):
            __jit_ignored_attributes__ = ["a"]

            def __init__(self, a, b, c):
                super().__init__()
                self.a = a
                self.b = b
                self.c = c

            def forward(self, x):
                return x

        mod = A(torch.nn.Linear(5, 5), 5, "string")
        s1 = torch.jit.script(mod)
        A.__jit_ignored_attributes__ = ["a", "b"]
        s2 = torch.jit.script(mod)

        # The types of s1 and s2 should differ. Although they are instances
        # of A, __jit_ignored_attributes__ was modified before scripting s2,
        # so the set of ignored attributes is different between s1 and s2.
        self.assertDifferentType(s1, s2)
629