xref: /aosp_15_r20/external/pytorch/test/jit/test_class_type.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import os
5import sys
6import unittest
7from typing import Any
8
9import torch
10import torch.nn as nn
11from torch.testing import FileCheck
12
13
14# Make the helper files in test/ importable
15pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
16sys.path.append(pytorch_test_dir)
17from typing import Dict, Iterable, List, Optional, Tuple
18
19import torch.testing._internal.jit_utils
20from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo
21from torch.testing._internal.jit_utils import JitTestCase, make_global
22
23
if __name__ == "__main__":
    # This suite must be driven through test_jit.py so that shared JIT test
    # setup (registries, env flags) happens; direct execution is a user error.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
30
31
32class TestClassType(JitTestCase):
    def test_reference_semantics(self):
        """
        Test that modifications made to a class instance in TorchScript
        are visible in eager.
        """

        class Foo:
            def __init__(self, a: int):
                self.a = a

            def set_a(self, value: int):
                self.a = value

            def get_a(self) -> int:
                return self.a

            @property
            def attr(self):
                return self.a

        make_global(Foo)  # see [local resolution in python]

        def test_fn(obj: Foo):
            obj.set_a(2)

        scripted_fn = torch.jit.script(test_fn)
        obj = torch.jit.script(Foo(1))
        self.assertEqual(obj.get_a(), 1)
        self.assertEqual(obj.attr, 1)

        # Mutation performed inside the scripted function must be visible on
        # the same object when read back from eager-mode Python.
        scripted_fn(obj)

        self.assertEqual(obj.get_a(), 2)
        self.assertEqual(obj.attr, 2)
67
68    def test_get_with_method(self):
69        class FooTest:
70            def __init__(self, x):
71                self.foo = x
72
73            def getFooTest(self):
74                return self.foo
75
76        def fn(x):
77            foo = FooTest(x)
78            return foo.getFooTest()
79
80        input = torch.ones(2, 3)
81        self.assertEqual(fn(input), input)
82
    def test_get_attr(self):
        """Plain attribute reads on a script class work from scripted code."""

        class FooTest:  # noqa: B903
            def __init__(self, x):
                self.foo = x

        @torch.jit.script
        def fn(x):
            foo = FooTest(x)
            return foo.foo

        input = torch.ones(2, 3)
        self.assertEqual(fn(input), input)
95
    def test_in(self):
        """The `in` operator on a class instance dispatches to __contains__."""

        class FooTest:  # noqa: B903
            def __init__(self) -> None:
                pass

            def __contains__(self, key: str) -> bool:
                return key == "hi"

        @torch.jit.script
        def fn():
            foo = FooTest()
            return "hi" in foo, "no" in foo

        self.assertEqual(fn(), (True, False))
110
    def test_set_attr_in_method(self):
        """Methods other than __init__ may reassign already-declared attributes."""

        class FooTest:
            def __init__(self, x: int) -> None:
                self.foo = x

            def incFooTest(self, y: int) -> None:
                self.foo = self.foo + y

        @torch.jit.script
        def fn(x: int) -> int:
            foo = FooTest(x)
            foo.incFooTest(2)
            return foo.foo

        self.assertEqual(fn(1), 3)
126
    def test_set_attr_type_mismatch(self):
        """Reassigning an attribute with a different type is a compile error."""
        # The third argument is the source highlight: it must match the
        # offending line's text exactly, so that line cannot be reformatted.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Wrong type for attribute assignment", "self.foo = 10"
        ):

            @torch.jit.script
            class FooTest:
                def __init__(self, x):
                    self.foo = x
                    self.foo = 10  # should error since int != Tensor
137
    def test_get_attr_not_initialized(self):
        """Reading an attribute never assigned in __init__ is a compile error."""
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", "self.asdf"
        ):

            @torch.jit.script
            class FooTest:
                def __init__(self, x):
                    self.foo = x

                def get_non_initialized(self):
                    return self.asdf  # asdf isn't an attr
150
    def test_set_attr_non_initialized(self):
        """Assigning to an attribute not declared in __init__ is a compile error."""
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.bar = y"
        ):

            @torch.jit.script
            class FooTest:
                def __init__(self, x):
                    self.foo = x

                def set_non_initialized(self, y):
                    self.bar = y  # can't assign to non-initialized attr
163
    def test_schema_human_readable(self):
        """
        Make sure that the schema is human readable, ie the mode parameter should read "nearest" instead of being displayed in octal
        aten::__interpolate(Tensor input, int? size=None, float[]? scale_factor=None,
        str mode='\156\145\141\162\145\163\164', bool? align_corners=None) -> (Tensor):
        Expected a value of type 'Optional[int]' for argument 'size' but instead found type 'Tensor'.
        """
        # Passing a string where `size` is expected produces a schema-mismatch
        # error; asserting on "nearest" checks the default printed readably.
        with self.assertRaisesRegexWithHighlight(RuntimeError, "nearest", ""):

            @torch.jit.script
            def FooTest(x):
                return torch.nn.functional.interpolate(x, "bad")
176
    def test_type_annotations(self):
        """Annotated constructor parameter types are enforced at call sites."""
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Expected a value of type 'bool", ""
        ):

            @torch.jit.script  # noqa: B903
            class FooTest:  # noqa: B903
                def __init__(self, x: bool) -> None:
                    self.foo = x

            # `x` is unannotated (defaults to Tensor), so passing it to a
            # bool-typed constructor parameter must be rejected
            @torch.jit.script
            def fn(x):
                FooTest(x)

            fn(2)
192
    def test_conditional_set_attr(self):
        """The first assignment to an attribute may not sit inside control flow."""
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "assignment cannot be in a control-flow block", ""
        ):

            @torch.jit.script
            class FooTest:
                def __init__(self, x):
                    if 1 == 1:
                        self.attr = x
203
    def test_class_type_as_param(self):
        """A script class can appear as a parameter annotation of a scripted fn."""

        class FooTest:  # noqa: B903
            def __init__(self, x):
                self.attr = x

        make_global(FooTest)  # see [local resolution in python]

        @torch.jit.script
        def fn(foo: FooTest) -> torch.Tensor:
            return foo.attr

        @torch.jit.script
        def fn2(x):
            foo = FooTest(x)
            return fn(foo)

        input = torch.ones(1)
        self.assertEqual(fn2(input), input)
222
    def test_out_of_order_methods(self):
        """__init__ may call a method defined later in the class body."""

        class FooTest:
            def __init__(self, x):
                self.x = x
                self.x = self.get_stuff(x)  # defined below; compiled on demand

            def get_stuff(self, y):
                return self.x + y

        @torch.jit.script
        def fn(x):
            f = FooTest(x)
            return f.x

        input = torch.ones(1)
        # get_stuff saw self.x == x, so the result is x + x
        self.assertEqual(fn(input), input + input)
239
240    def test_save_load_with_classes(self):
241        class FooTest:
242            def __init__(self, x):
243                self.x = x
244
245            def get_x(self):
246                return self.x
247
248        class MyMod(torch.jit.ScriptModule):
249            @torch.jit.script_method
250            def forward(self, a):
251                foo = FooTest(a)
252                return foo.get_x()
253
254        m = MyMod()
255
256        buffer = io.BytesIO()
257        torch.jit.save(m, buffer)
258
259        # classes are globally registered for now, so we need to clear the JIT
260        # registry to simulate loading a new model
261
262        buffer.seek(0)
263        m_loaded = torch.jit.load(buffer)
264
265        input = torch.rand(2, 3)
266        output = m_loaded(input)
267        self.assertEqual(input, output)
268
    def test_save_load_with_classes_returned(self):
        """Save/load works when a class method returns a new class instance."""

        class FooTest:
            def __init__(self, x):
                self.x = x

            def clone(self):
                clone = FooTest(self.x)
                return clone

        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                foo = FooTest(a)
                foo_clone = foo.clone()
                return foo_clone.x

        m = MyMod()

        buffer = io.BytesIO()
        torch.jit.save(m, buffer)

        # classes are globally registered for now, so we need to clear the JIT
        # registry to simulate loading a new model
        torch.testing._internal.jit_utils.clear_class_registry()

        buffer.seek(0)
        m_loaded = torch.jit.load(buffer)

        input = torch.rand(2, 3)
        output = m_loaded(input)
        # forward is the identity through FooTest.clone()
        self.assertEqual(input, output)
300
    def test_save_load_with_classes_nested(self):
        """Save/load works when script classes hold other script classes."""

        class FooNestedTest:  # noqa: B903
            def __init__(self, y):
                self.y = y

        class FooNestedTest2:
            def __init__(self, y):
                self.y = y
                self.nested = FooNestedTest(y)

        class FooTest:
            def __init__(self, x):
                self.class_attr = FooNestedTest(x)
                self.class_attr2 = FooNestedTest2(x)
                self.x = self.class_attr.y + self.class_attr2.y

        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                foo = FooTest(a)
                return foo.x

        m = MyMod()

        buffer = io.BytesIO()
        torch.jit.save(m, buffer)

        # classes are globally registered for now, so we need to clear the JIT
        # registry to simulate loading a new model
        torch.testing._internal.jit_utils.clear_class_registry()

        buffer.seek(0)
        m_loaded = torch.jit.load(buffer)

        input = torch.rand(2, 3)
        output = m_loaded(input)
        # foo.x = class_attr.y + class_attr2.y = input + input
        self.assertEqual(2 * input, output)
338
    def test_python_interop(self):
        """Script-class instances pass between eager Python and TorchScript."""

        class Foo:  # noqa: B903
            def __init__(self, x, y):
                self.x = x
                self.y = y

        make_global(Foo)  # see [local resolution in python]

        @torch.jit.script
        def use_foo(foo: Foo) -> Foo:
            return foo

        # create from python
        x = torch.ones(2, 3)
        y = torch.zeros(2, 3)
        f = Foo(x, y)

        self.assertEqual(x, f.x)
        self.assertEqual(y, f.y)

        # pass in and out of script
        f2 = use_foo(f)

        self.assertEqual(x, f2.x)
        self.assertEqual(y, f2.y)
364
    def test_class_specialization(self):
        """Class-typed args (including inside tuples) lower to attribute reads."""

        class Foo:  # noqa: B903
            def __init__(self, x, y):
                self.x = x
                self.y = y

        make_global(Foo)  # see [local resolution in python]

        def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor:
            a, b = tup
            return foo.x + foo2.y + a.x + b.y

        # create from python
        x = torch.ones(2, 3)
        y = torch.zeros(2, 3)
        f = Foo(x, y)
        f2 = Foo(x * 2, y * 3)
        f3 = Foo(x * 4, y * 4)

        input = (f, f2, (f, f3))
        sfoo = self.checkScript(use_foo, input)
        graphstr = str(sfoo.graph_for(*input))
        # one prim::GetAttr per field actually read: foo.x, foo2.y, a.x, b.y
        FileCheck().check_count("prim::GetAttr", 4).run(graphstr)
388
    def test_class_sorting(self):
        """sorted() and list.sort() work on script classes defining __lt__."""

        class Foo:  # noqa: B903
            def __init__(self, x: int) -> None:
                self.x = x

            def __lt__(self, other) -> bool:
                # type: (Foo) -> bool
                return self.x < other.x

            def getVal(self):
                return self.x

        make_global(Foo)  # see [local resolution in python]

        def test(li: List[Foo], reverse: bool = False) -> Tuple[List[int], List[int]]:
            li_sorted = sorted(li)
            ret_sorted = torch.jit.annotate(List[int], [])
            for foo in li_sorted:
                ret_sorted.append(foo.getVal())

            li.sort(reverse=reverse)
            ret_sort = torch.jit.annotate(List[int], [])
            for foo in li:
                ret_sort.append(foo.getVal())
            return ret_sorted, ret_sort

        self.checkScript(test, ([Foo(2), Foo(1), Foo(3)],))
        self.checkScript(test, ([Foo(2), Foo(1), Foo(3)], True))
        self.checkScript(test, ([Foo(2)],))  # single element
        self.checkScript(test, ([],))  # empty list

        @torch.jit.script
        def test_list_no_reverse():
            li = [Foo(3), Foo(1)]
            li.sort()
            return li[0].getVal()

        self.assertEqual(test_list_no_reverse(), 1)

        # sorted() must copy: the original list keeps its order
        @torch.jit.script
        def test_sorted_copies():
            li = [Foo(3), Foo(1)]
            li_sorted = sorted(li)
            return li[0].getVal(), li_sorted[0].getVal()

        self.assertEqual(test_sorted_copies(), (3, 1))

        # tuple comparison recurses into the class's __lt__ on the second slot
        @torch.jit.script
        def test_nested_inside_tuple():
            li = [(1, Foo(12)), (1, Foo(11))]
            li.sort()
            return [(li[0][0], li[0][1].getVal()), (li[1][0], li[1][1].getVal())]

        self.assertEqual(test_nested_inside_tuple(), [(1, 11), (1, 12)])

        # passing a list where `reverse: bool` is expected is a schema error
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "bool' for argument 'reverse", ""
        ):

            @torch.jit.script
            def test():
                li = [Foo(1)]
                li.sort(li)
                return li

            test()

        # sorting a class with no __lt__ at all is rejected
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "must define a __lt__", ""
        ):

            @torch.jit.script
            class NoMethod:
                def __init__(self) -> None:
                    pass

            @torch.jit.script
            def test():
                li = [NoMethod(), NoMethod()]
                li.sort()
                return li

            test()

        @torch.jit.script
        class WrongLt:
            def __init__(self) -> None:
                pass

            # lt method defined with the wrong signature
            def __lt__(self, other):
                pass

        # a __lt__ with the wrong signature counts as missing
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "must define a __lt__", ""
        ):

            @torch.jit.script
            def test():
                li = [WrongLt(), WrongLt()]
                li.sort()
                return li

            test()
493
    def test_class_inheritance(self):
        """Deriving from a script class is rejected with a clear error."""

        @torch.jit.script
        class Base:
            def __init__(self) -> None:
                self.b = 2

            def two(self, x):
                return x + self.b

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "does not support inheritance", ""
        ):

            @torch.jit.script
            class Derived(Base):
                def two(self, x):
                    return x + self.b + 2
511
    def test_class_inheritance_implicit(self):
        """
        Test that inheritance is detected in
        implicit scripting codepaths (e.g. try_ann_to_type).
        """

        class A:
            def __init__(self, t):
                self.t = t

            @staticmethod
            def f(a: torch.Tensor):
                return A(a + 1)

        class B(A):
            def __init__(self, t):
                self.t = t + 10

            @staticmethod
            def f(a: torch.Tensor):
                return A(a + 1)

        x = A(torch.tensor([3]))

        def fun(x: Any):
            if isinstance(x, A):
                return A.f(x.t)
            else:
                return B.f(x.t)

        # B inherits from A, so implicitly scripting `fun` (which references
        # both through annotations/isinstance) must fail rather than compile
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", ""
        ):
            sc = torch.jit.script(fun)
546
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    @unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode")
    def test_imported_classes(self):
        """Classes imported from other test modules script and serialize."""
        # These imports rely on sys.path containing pytorch_test_dir (set at
        # the top of this file), hence the sandcastle skip above.
        import jit._imported_class_test.bar
        import jit._imported_class_test.foo
        import jit._imported_class_test.very.very.nested

        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, a):
                foo = jit._imported_class_test.foo.FooSameName(a)
                bar = jit._imported_class_test.bar.FooSameName(a)
                three = jit._imported_class_test.very.very.nested.FooUniqueName(a)
                return foo.x + bar.y + three.y

        m = MyMod()

        buffer = io.BytesIO()
        torch.jit.save(m, buffer)

        # classes are globally registered for now, so we need to clear the JIT
        # registry to simulate loading a new model
        torch.testing._internal.jit_utils.clear_class_registry()

        buffer.seek(0)
        m_loaded = torch.jit.load(buffer)

        input = torch.rand(2, 3)
        output = m_loaded(input)
        self.assertEqual(3 * input, output)
577
    def test_interface(self):
        """@torch.jit.interface: structural membership, widening, and mismatch errors."""

        @torch.jit.script
        class Foo:
            def __init__(self) -> None:
                pass

            def one(self, x, y):
                return x + y

            def two(self, x):
                return 2 * x

        @torch.jit.script
        class Bar:
            def __init__(self) -> None:
                pass

            def one(self, x, y):
                return x * y

            def two(self, x):
                return 2 / x

        @torch.jit.interface
        class OneTwo:
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

        @torch.jit.interface
        class OneTwoThree:
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: torch.Tensor) -> torch.Tensor:
                pass

            def three(self, x: torch.Tensor) -> torch.Tensor:
                pass

        # same names as OneTwo but `two` has an incompatible signature
        @torch.jit.interface
        class OneTwoWrong:
            def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                pass

            def two(self, x: int) -> int:
                pass

        @torch.jit.script
        class NotMember:
            def __init__(self) -> None:
                pass

            def one(self, x, y):
                return x + y

            # missing two

        @torch.jit.script
        class NotMember2:
            def __init__(self) -> None:
                pass

            def one(self, x, y):
                return x + y

            def two(self, x: int) -> int:
                return 3

        make_global(Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2)

        # a heterogeneous list typed by the interface dispatches per element
        def use_them(x):
            a = Foo()
            b = Bar()
            c = torch.jit.annotate(List[OneTwo], [a, b])
            for i in range(len(c)):
                x = c[i].one(x, x)
                x = c[i].two(x)
            return x

        self.checkScript(use_them, (torch.rand(3, 4),))

        @torch.jit.script
        def as_interface(x: OneTwo) -> OneTwo:
            return x

        # a wider interface (OneTwoThree) is accepted where OneTwo is expected
        @torch.jit.script
        def inherit(x: OneTwoThree) -> OneTwo:
            return as_interface(x)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "does not have method", ""
        ):

            @torch.jit.script
            def wrong1():
                return as_interface(NotMember())

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "is not compatible with interface", ""
        ):

            @torch.jit.script
            def wrong2():
                return as_interface(NotMember2())

        # Foo satisfies OneTwo but not OneTwoThree (no `three`)
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "does not have method", ""
        ):

            @torch.jit.script
            def wrong3():
                return inherit(as_interface(Foo()))

        # interfaces are not related by name overlap alone
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "is not compatible with interface", ""
        ):

            @torch.jit.script
            def wrong4(x: OneTwoWrong) -> int:
                return as_interface(x)

        # Test interface/class python assignment
        class TestPyAssign(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.proxy_mod = Foo()

            def forward(self, x):
                return self.proxy_mod.two(x)

        TestPyAssign.__annotations__ = {"proxy_mod": OneTwo}

        input = torch.rand(3, 4)
        scripted_pyassign_mod = torch.jit.script(TestPyAssign())
        imported_mod = self.getExportImportCopy(scripted_pyassign_mod)
        self.assertEqual(scripted_pyassign_mod(input), imported_mod(input))

        class TestPyAssignError(nn.Module):
            def __init__(self, obj):
                super().__init__()
                self.proxy_mod = obj

            def forward(self, x):
                return self.proxy_mod.two(x)

        TestPyAssignError.__annotations__ = {"proxy_mod": OneTwoThree}

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "is not compatible with interface __torch__", ""
        ):
            torch.jit.script(TestPyAssignError(Foo()))

        # test pure python object assignment to interface fails
        class PyClass:
            def __init__(self) -> None:
                pass

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "the value is not a TorchScript compatible type", ""
        ):
            torch.jit.script(TestPyAssignError(PyClass()))
        # TODO test: interface-interface class-interface inheritance errors,
        # NamedTuple inheritance errors
744
    def test_overloaded_fn(self):
        """Operator dunders (__len__/__neg__/__mul__, arithmetic, comparisons,
        indexing, __call__) on script classes behave as in eager."""

        @torch.jit.script
        class Foo:
            def __init__(self, x):
                self.x = x

            def __len__(self) -> int:
                return len(self.x)

            def __neg__(self):
                self.x = -self.x
                return self

            def __mul__(self, other: torch.Tensor) -> torch.Tensor:
                return self.x * other

        def test_overload():
            a = Foo(torch.ones([3, 3]))
            return len(a), -a * torch.zeros([3, 3])

        make_global(Foo)  # see [local resolution in python]

        self.checkScript(test_overload, ())
        # unary ops tested above

        # TODO - support compiling classes from strings in jit.CompilationUnit
        @torch.jit.script
        class MyClass:
            def __init__(self, x: int) -> None:
                self.x = x

            def __add__(self, other: int) -> int:
                return self.x + other

            def __sub__(self, other: int) -> int:
                return self.x - other

            def __mul__(self, other: int) -> int:
                return self.x * other

            def __pow__(self, other: int) -> int:
                return int(self.x**other)

            def __truediv__(self, other: int) -> float:
                return self.x / other

            def __mod__(self, other: int) -> int:
                return self.x % other

            def __ne__(self, other: int) -> bool:
                return self.x != other

            def __eq__(self, other: int) -> bool:
                return self.x == other

            def __lt__(self, other: int) -> bool:
                return self.x < other

            def __gt__(self, other: int) -> bool:
                return self.x > other

            def __le__(self, other: int) -> bool:
                return self.x <= other

            def __ge__(self, other: int) -> bool:
                return self.x >= other

            def __and__(self, other: int) -> int:
                return self.x & other

            def __or__(self, other: int) -> int:
                return self.x | other

            def __xor__(self, other: int) -> int:
                return self.x ^ other

            def __getitem__(self, other: int) -> int:
                return other + 1

            def __setitem__(self, idx: int, val: int) -> None:
                self.x = val * idx

            def __call__(self, val: int) -> int:
                return self.x * val * 3

        # NOTE(review): re-registers Foo; presumably MyClass was intended here,
        # but MyClass is already scripted above so this may be a harmless
        # copy-paste no-op — confirm before changing.
        make_global(Foo)  # see [local resolution in python]

        def add():
            return MyClass(4) + 3

        def sub():  # noqa: E306
            return MyClass(4) - 3

        def mul():  # noqa: E306
            return MyClass(4) * 3

        def pow():  # noqa: E306
            return MyClass(4) ** 3

        def truediv():  # noqa: E306
            return MyClass(4) / 3

        def ne():  # noqa: E306
            return MyClass(4) != 3

        def eq():  # noqa: E306
            return MyClass(4) == 3

        def lt():  # noqa: E306
            return MyClass(4) < 3

        def gt():  # noqa: E306
            return MyClass(4) > 3

        def le():  # noqa: E306
            return MyClass(4) <= 3

        def ge():  # noqa: E306
            return MyClass(4) >= 3

        def _and():  # noqa: E306
            return MyClass(4) & 3

        def _or():  # noqa: E306
            return MyClass(4) | 3

        def _xor():  # noqa: E306
            return MyClass(4) ^ 3

        def getitem():  # noqa: E306
            return MyClass(4)[1]

        def setitem():  # noqa: E306
            a = MyClass(4)
            a[1] = 5
            return a.x

        def call():  # noqa: E306
            a = MyClass(5)
            return a(2)

        ops = [
            add,
            sub,
            mul,
            pow,
            ne,
            eq,
            lt,
            gt,
            le,
            ge,
            _and,
            _or,
            _xor,
            getitem,
            setitem,
            call,
        ]

        ops.append(truediv)
        # each wrapper's scripted result is checked against its eager result
        for func in ops:
            self.checkScript(func, ())

        # Foo.__mul__ takes a Tensor, so Foo + Foo has no matching overload
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "object has no attribute or method", ""
        ):

            @torch.jit.script
            def test():
                return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
916
    def test_cast_overloads(self):
        """int()/float()/bool()/str() dispatch to the class's cast dunders."""

        @torch.jit.script
        class Foo:
            def __init__(self, val: float) -> None:
                self.val = val

            def __int__(self):
                return int(self.val)

            def __float__(self):
                return self.val

            def __bool__(self):
                return bool(self.val)

            def __str__(self):
                return str(self.val)

        make_global(Foo)  # see [local resolution in python]

        def test(foo: Foo) -> Tuple[int, float, bool]:
            if foo:
                pass
            return int(foo), float(foo), bool(foo)

        fn = torch.jit.script(test)
        # eager `test` is deliberately called with a raw float: the builtin
        # int/float/bool casts accept it, so it serves as the expected oracle
        self.assertEqual(fn(Foo(0.5)), test(0.5))
        self.assertEqual(fn(Foo(0.0)), test(0.0))
        # str has slightly different formatting
        self.assertTrue("0.5" in (str(Foo(0.5))))
        self.assertTrue("0." in (str(Foo(0.0))))

        @torch.jit.script
        class BadBool:
            def __init__(self) -> None:
                pass

            def __bool__(self):
                return (1, 2)

        # __bool__ returning a non-bool must be rejected where a condition
        # is required
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "expected a bool expression for condition", ""
        ):

            @torch.jit.script
            def test():
                if BadBool():
                    print(1)
965
    def test_init_compiled_first(self):
        """__init__ is compiled before other methods regardless of source order,
        so attributes it declares are visible to methods defined earlier."""

        @torch.jit.script  # noqa: B903
        class Foo:  # noqa: B903
            def __before_init__(self):
                # accessing this field should not throw, since __init__ should be compiled
                return self.x

            def __init__(self, x, y):
                self.x = x
                self.y = y
976
    def test_class_constructs_itself(self):
        """A script class may construct new instances of itself in a method."""

        @torch.jit.script  # noqa: B903
        class LSTMStateStack:  # noqa: B903
            def __init__(self, num_layers: int, hidden_size: int) -> None:
                self.num_layers = num_layers
                self.hidden_size = hidden_size
                self.last_state = (
                    torch.zeros(num_layers, 1, hidden_size),
                    torch.zeros(num_layers, 1, hidden_size),
                )
                self.stack = [(self.last_state[0][-1], self.last_state[0][-1])]

            def copy(self):
                # should be able to construct a class inside its own methods
                other = LSTMStateStack(self.num_layers, self.hidden_size)
                other.stack = list(self.stack)
                return other
994
    def test_optional_type_promotion(self):
        """An attribute initialized to None may be typed Optional[ScriptClass]
        and later assigned a real instance."""

        @torch.jit.script
        class Leaf:
            def __init__(self) -> None:
                self.x = 1

        # should not throw
        @torch.jit.script  # noqa: B903
        class Tree:  # noqa: B903
            def __init__(self) -> None:
                self.child = torch.jit.annotate(Optional[Leaf], None)

            def add_child(self, child: Leaf) -> None:
                self.child = child
1009
    def test_recursive_class(self):
        """
        Recursive class types not yet supported. We should give a good error message.
        """
        with self.assertRaises(RuntimeError):

            @torch.jit.script  # noqa: B903
            class Tree:  # noqa: B903
                def __init__(self) -> None:
                    # Tree refers to itself through Optional[Tree]
                    self.parent = torch.jit.annotate(Optional[Tree], None)
1020
    def test_class_constant(self):
        """
        Test that class constants declared via __constants__ are accessible
        in methods and survive a save/load round trip, for a variety of
        constant types.
        """

        class M(torch.nn.Module):
            __constants__ = ["w"]

            def __init__(self, w):
                super().__init__()
                self.w = w

            def forward(self, x):
                # Make sure class constant is accessible in method
                y = self.w
                return x, y

        # Test serialization/deserialization of class constant
        for c in (2, 1.0, None, True, "str", (2, 3), [5.9, 7.3]):
            m = torch.jit.script(M(c))
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)

            buffer.seek(0)
            m_loaded = torch.jit.load(buffer)
            input = torch.rand(2, 3)
            # Scripted and loaded modules must agree on the same input.
            self.assertEqual(m(input), m_loaded(input))
            # Make sure class constant is accessible from module
            self.assertEqual(m.w, m_loaded.w)
1046
    def test_py_class_to_ivalue_missing_attribute(self):
        """
        Test the error raised when a Python value passed to a scripted
        function cannot be converted to the declared class type because it
        lacks one of the class's attributes.
        """

        class Foo:
            i: int
            f: float

            def __init__(self, i: int, f: float):
                self.i = i
                self.f = f

        make_global(Foo)  # see [local resolution in python]

        @torch.jit.script
        def test_fn(x: Foo) -> float:
            return x.i + x.f

        # A real Foo instance converts fine.
        test_fn(Foo(3, 4.0))

        # A tensor has no attribute `i`, so conversion to Foo must fail.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "missing attribute i", ""
        ):
            test_fn(torch.rand(3, 4))
1068
    def test_unused_method(self):
        """
        Test unused methods on scripted classes.
        """

        @torch.jit.script
        class Unused:
            def __init__(self) -> None:
                self.count: int = 0
                self.items: List[int] = []

            def used(self):
                self.count += 1
                return self.count

            @torch.jit.unused
            def unused(self, x: int, y: Iterable[int], **kwargs) -> int:
                # This body is never compiled because the method is marked
                # unused, so it need not be valid TorchScript.
                a = next(self.items)
                return a

            def uses_unused(self) -> int:
                # Calling an unused method should compile but throw at runtime.
                return self.unused(y="hi", x=3)

        class ModuleWithUnused(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obj = Unused()

            def forward(self):
                return self.obj.used()

            @torch.jit.export
            def calls_unused(self):
                return self.obj.unused(3, "hi")

            @torch.jit.export
            def calls_unused_indirectly(self):
                return self.obj.uses_unused()

        python_module = ModuleWithUnused()
        script_module = torch.jit.script(ModuleWithUnused())

        # Forward should work because it does not use any methods marked unused.
        self.assertEqual(python_module.forward(), script_module.forward())

        # Calling a method marked unused should throw.
        with self.assertRaises(torch.jit.Error):
            script_module.calls_unused()

        with self.assertRaises(torch.jit.Error):
            script_module.calls_unused_indirectly()
1120
    def test_self_referential_method(self):
        """
        Test that a scripted class can have a method that refers to the class itself
        in its type annotations.
        """

        @torch.jit.script
        class Meta:
            def __init__(self, a: int):
                self.a = a

            def method(self, other: List["Meta"]) -> "Meta":
                # Both the parameter and return annotations name Meta itself
                # (as forward-reference strings).
                return Meta(len(other))

        class ModuleWithMeta(torch.nn.Module):
            def __init__(self, a: int):
                super().__init__()
                self.meta = Meta(a)

            def forward(self):
                new_obj = self.meta.method([self.meta])
                return new_obj.a

        self.checkModule(ModuleWithMeta(5), ())
1145
    def test_type_annotation(self):
        """
        Test that annotating container attributes with types works correctly
        """

        @torch.jit.script
        class CompetitiveLinkingTokenReplacementUtils:
            def __init__(self) -> None:
                self.my_list: List[Tuple[float, int, int]] = []
                self.my_dict: Dict[int, int] = {}

        @torch.jit.script
        def foo():
            y = CompetitiveLinkingTokenReplacementUtils()
            # Assigning containers with matching annotations to the class
            # attributes should compile.
            new_dict: Dict[int, int] = {1: 1, 2: 2}
            y.my_dict = new_dict

            new_list: List[Tuple[float, int, int]] = [(1.0, 1, 1)]
            y.my_list = new_list
            return y
1166
    def test_default_args(self):
        """
        Test that methods on class types can have default arguments.
        """

        @torch.jit.script
        class ClassWithDefaultArgs:
            def __init__(
                self,
                a: int = 1,
                b: Optional[List[int]] = None,
                c: Tuple[int, int, int] = (1, 2, 3),
                d: Optional[Dict[int, int]] = None,
                e: Optional[str] = None,
            ):
                self.int = a
                self.tup = c
                self.str = e

                # Mutable defaults are not allowed (see ClassWithMutableArgs
                # below), so list/dict defaults are emulated via Optional.
                self.list = [1, 2, 3]
                if b is not None:
                    self.list = b

                self.dict = {1: 2, 3: 4}
                if d is not None:
                    self.dict = d

            def add(self, b: int, scale: float = 1.0) -> float:
                return self.int * scale + b

        # Constructor called with every argument defaulted.
        def all_defaults() -> int:
            obj: ClassWithDefaultArgs = ClassWithDefaultArgs()
            return obj.int + obj.list[2] + obj.tup[1]

        # Constructor called with a single keyword argument.
        def some_defaults() -> int:
            obj: ClassWithDefaultArgs = ClassWithDefaultArgs(b=[5, 6, 7])
            return obj.int + obj.list[2] + obj.dict[1]

        # Constructor called with every argument supplied positionally.
        def override_defaults() -> int:
            obj: ClassWithDefaultArgs = ClassWithDefaultArgs(
                3, [9, 10, 11], (12, 13, 14), {3: 4}, "str"
            )
            s: int = obj.int

            for x in obj.list:
                s += x

            for y in obj.tup:
                s += y

            s += obj.dict[3]

            st = obj.str
            if st is not None:
                s += len(st)

            return s

        # Default argument on an ordinary method (not __init__).
        def method_defaults() -> float:
            obj: ClassWithDefaultArgs = ClassWithDefaultArgs()
            return obj.add(3) + obj.add(3, 0.25)

        self.checkScript(all_defaults, ())
        self.checkScript(some_defaults, ())
        self.checkScript(override_defaults, ())
        self.checkScript(method_defaults, ())

        # The constructor of this class below has some arguments without default values.
        class ClassWithSomeDefaultArgs:  # noqa: B903
            def __init__(
                self,
                a: int,
                b: int = 1,
            ):
                self.a = a
                self.b = b

        def default_b() -> int:
            obj: ClassWithSomeDefaultArgs = ClassWithSomeDefaultArgs(1)
            return obj.a + obj.b

        def set_b() -> int:
            obj: ClassWithSomeDefaultArgs = ClassWithSomeDefaultArgs(1, 4)
            return obj.a + obj.b

        self.checkScript(default_b, ())
        self.checkScript(set_b, ())

        # The constructor of this class below has mutable arguments. This should throw
        # an error.
        class ClassWithMutableArgs:  # noqa: B903
            def __init__(
                self,
                a: List[int] = [1, 2, 3],  # noqa: B006
            ):
                self.a = a

        def should_fail():
            obj: ClassWithMutableArgs = ClassWithMutableArgs()

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Mutable default parameters are not supported", ""
        ):
            torch.jit.script(should_fail)
1271
    def test_staticmethod(self):
        """
        Test static methods on class types.
        """

        @torch.jit.script
        class ClassWithStaticMethod:
            def __init__(self, a: int, b: int):
                self.a: int = a
                self.b: int = b

            def get_a(self):
                return self.a

            def get_b(self):
                return self.b

            # __eq__ is needed so checkScript can compare eager and scripted
            # return values of test_function below.
            def __eq__(self, other: "ClassWithStaticMethod"):
                return self.a == other.a and self.b == other.b

            # staticmethod that calls constructor.
            @staticmethod
            def create(args: List["ClassWithStaticMethod"]) -> "ClassWithStaticMethod":
                return ClassWithStaticMethod(args[0].a, args[0].b)

            # staticmethod that calls another staticmethod.
            @staticmethod
            def create_from(a: int, b: int) -> "ClassWithStaticMethod":
                a = ClassWithStaticMethod(a, b)
                return ClassWithStaticMethod.create([a])

        # Script function that calls staticmethod.
        def test_function(a: int, b: int) -> "ClassWithStaticMethod":
            return ClassWithStaticMethod.create_from(a, b)

        make_global(ClassWithStaticMethod)

        self.checkScript(test_function, (1, 2))
1310
    def test_classmethod(self):
        """
        Test classmethods on class types.
        """

        @torch.jit.script
        class ClassWithClassMethod:
            def __init__(self, a: int):
                self.a: int = a

            # __eq__ is needed so checkScript can compare eager and scripted
            # return values of test_function below.
            def __eq__(self, other: "ClassWithClassMethod"):
                return self.a == other.a

            @classmethod
            def create(cls, a: int) -> "ClassWithClassMethod":
                return cls(a)

        make_global(ClassWithClassMethod)

        def test_function(a: int) -> "ClassWithClassMethod":
            x = ClassWithClassMethod(a)
            # Support calling classmethod with an instance
            # Calling with the class is not supported.
            return x.create(a)

        self.checkScript(test_function, (1,))
1337
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_properties(self):
        """
        Test that a scripted class can make use of the @property decorator.
        """

        def free_function(x: int) -> int:
            return x + 1

        @torch.jit.script
        class Properties:
            # Properties listed here are skipped by the compiler; `unsupported`
            # below uses sum(), which would otherwise have to compile.
            __jit_unused_properties__ = ["unsupported"]

            def __init__(self, a: int):
                self.a = a

            @property
            def attr(self) -> int:
                return self.a - 1

            @property
            def unsupported(self) -> int:
                return sum([self.a])

            # Same effect as __jit_unused_properties__, but expressed with the
            # @torch.jit.unused decorator instead.
            @torch.jit.unused
            @property
            def unsupported_2(self) -> int:
                return sum([self.a])

            @unsupported_2.setter
            def unsupported_2(self, value):
                self.a = sum([self.a])

            @attr.setter
            def attr(self, value: int):
                self.a = value + 3

        # A property with a getter but no setter.
        @torch.jit.script
        class NoSetter:
            def __init__(self, a: int):
                self.a = a

            @property
            def attr(self) -> int:
                return free_function(self.a)

        # A property accessed from another method of the same class.
        @torch.jit.script
        class MethodThatUsesProperty:
            def __init__(self, a: int):
                self.a = a

            @property
            def attr(self) -> int:
                return self.a - 2

            @attr.setter
            def attr(self, value: int):
                self.a = value + 4

            def forward(self):
                return self.attr

        class ModuleWithProperties(torch.nn.Module):
            def __init__(self, a: int):
                super().__init__()
                self.props = Properties(a)

            def forward(self, a: int, b: int, c: int, d: int):
                # Exercise property setters on both a module attribute and
                # locally constructed instances.
                self.props.attr = a
                props = Properties(b)
                no_setter = NoSetter(c)
                method_uses_property = MethodThatUsesProperty(a + b)

                props.attr = c
                method_uses_property.attr = d

                return self.props.attr + no_setter.attr + method_uses_property.forward()

        self.checkModule(
            ModuleWithProperties(5),
            (
                5,
                6,
                7,
                8,
            ),
        )
1425
    def test_custom_delete(self):
        """
        Test that del can be called on an instance of a class that
        overrides __delitem__.
        """

        class Example:
            def __init__(self) -> None:
                self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)}

            def check(self, key: str) -> bool:
                return key in self._data

            def __delitem__(self, key: str):
                del self._data[key]

        def fn() -> bool:
            example = Example()
            # `del example["1"]` should dispatch to Example.__delitem__.
            del example["1"]
            return example.check("1")

        self.checkScript(fn, ())

        # Test the case in which the class does not have __delitem__ defined.
        class NoDelItem:
            def __init__(self) -> None:
                self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)}

            def check(self, key: str) -> bool:
                return key in self._data

        def fn() -> bool:
            example = NoDelItem()
            key = "1"
            del example[key]
            return example.check(key)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Class does not define __delitem__", "example[key]"
        ):
            self.checkScript(fn, ())
1467
    def test_recursive_script_builtin_type_resolution(self):
        """
        Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled.
        """
        # A will be implicitly compiled because it is not annotated with @torch.jit.script
        # but is used in g() below.
        # Aliases of built-in torch types, used in annotations below to check
        # that they resolve to the underlying types during compilation.
        tensor_t = torch.Tensor
        device_t = torch.device
        device_ty = torch.device

        class A:
            def __init__(self) -> None:
                pass

            def f(self, x: tensor_t, y: torch.device) -> tensor_t:
                return x.to(device=y)

            def g(self, x: device_t) -> device_ty:
                return x

            # NOTE: h is compiled (it annotates with the class itself) but is
            # never called by the check functions below.
            def h(self, a: "A") -> "A":
                return A()

            def i(self, a: List[int]) -> int:
                return a[0]

            def j(self, l: List[device_t]) -> device_ty:
                return l[0]

        def call_f():
            a = A()
            return a.f(torch.tensor([1]), torch.device("cpu"))

        def call_g():
            a = A()
            return a.g(torch.device("cpu"))

        def call_i():
            a = A()
            return a.i([3])

        def call_j():
            a = A()
            return a.j([torch.device("cpu"), torch.device("cpu")])

        for fn in [call_f, call_g, call_i, call_j]:
            # Check both scripting and a save/load round trip of each function.
            self.checkScript(fn, ())
            s = self.getExportImportCopy(torch.jit.script(fn))
            self.assertEqual(s(), fn())
1517
    def test_recursive_script_module_builtin_type_resolution(self):
        """
        Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled
        when compiling a module.
        """

        class Wrapper:
            def __init__(self, t):
                self.t = t

            # Annotations reference torch.device both inside a container type
            # and as an Optional default.
            def to(self, l: List[torch.device], device: Optional[torch.device] = None):
                return self.t.to(device=device)

        class A(nn.Module):
            def forward(self):
                # Returning Wrapper forces it to be recursively compiled.
                return Wrapper(torch.rand(4, 4))

        scripted = torch.jit.script(A())
        # Save/load round trip should also succeed.
        self.getExportImportCopy(scripted)
1537
    def test_class_attribute_wrong_type(self):
        """
        Test the error message displayed when converting a class type
        to an IValue that has an attribute of the wrong type.
        """

        @torch.jit.script  # noqa: B903
        class ValHolder:  # noqa: B903
            def __init__(self, val):
                self.val = val

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # val is a str here, but scripting infers the unannotated
                # parameter as Tensor, so the conversion below must fail.
                self.mod1 = ValHolder("1")
                self.mod2 = ValHolder("2")

            def forward(self, cond: bool):
                if cond:
                    mod = self.mod1
                else:
                    mod = self.mod2
                return mod.val

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Could not cast attribute 'val' to type Tensor", ""
        ):
            torch.jit.script(Mod())
1566
    def test_recursive_scripting(self):
        """
        Test that class types are recursively scripted when a Python instance of one
        is encountered as a module attribute.
        """

        class Class:
            def __init__(self, a: int):
                self.a = a

            def get_a(self) -> int:
                return self.a

        class M(torch.nn.Module):
            def __init__(self, obj):
                super().__init__()
                # obj is a plain Python instance of Class; scripting M should
                # recursively compile Class.
                self.obj = obj

            def forward(self) -> int:
                return self.obj.get_a()

        self.checkModule(M(Class(4)), ())
1589
    def test_recursive_scripting_failed(self):
        """
        Test that class types module attributes that fail to script
        are added as failed attributes and do not cause compilation itself
        to fail unless they are used in scripted code.
        """

        class UnscriptableClass:
            def __init__(self, a: int):
                self.a = a

            def get_a(self) -> bool:
                # issubclass is not scriptable, making this class fail to compile.
                return issubclass(self.a, int)

        # This Module has an attribute of type UnscriptableClass
        # and tries to use it in scripted code. This should fail.
        class ShouldNotCompile(torch.nn.Module):
            def __init__(self, obj):
                super().__init__()
                self.obj = obj

            def forward(self) -> bool:
                return self.obj.get_a()

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "failed to convert Python type", ""
        ):
            torch.jit.script(ShouldNotCompile(UnscriptableClass(4)))

        # This Module has an attribute of type UnscriptableClass
        # and does not try to use it in scripted code. This should not fail.
        class ShouldCompile(torch.nn.Module):
            def __init__(self, obj):
                super().__init__()
                self.obj = obj

            # The only use of the unscriptable attribute is inside an ignored
            # method, which is not compiled.
            @torch.jit.ignore
            def ignored_method(self) -> bool:
                return self.obj.get_a()

            def forward(self, x: int) -> int:
                return x + x

        self.checkModule(ShouldCompile(UnscriptableClass(4)), (4,))
1634
    def test_unresolved_class_attributes(self):
        """
        Test the error message produced when scripted code accesses a class
        attribute that scripting does not resolve (including attributes bound
        via tuple/list unpacking and annotated class-level assignments).
        """

        class UnresolvedAttrClass:
            def __init__(self) -> None:
                pass

            # Class attributes declared via nested unpacking and via an
            # annotated assignment; none are visible to scripted code.
            (attr_a, attr_b), [attr_c, attr_d] = ("", ""), ["", ""]
            attr_e: int = 0

        def fn_a():
            u = UnresolvedAttrClass()
            return u.attr_a

        def fn_b():
            u = UnresolvedAttrClass()
            return u.attr_b

        def fn_c():
            u = UnresolvedAttrClass()
            return u.attr_c

        def fn_d():
            u = UnresolvedAttrClass()
            return u.attr_d

        def fn_e():
            u = UnresolvedAttrClass()
            return u.attr_e

        # Every access should fail with a message pointing at the class
        # attribute definition.
        error_message_regex = (
            "object has no attribute or method.*is defined as a class attribute"
        )
        for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
            with self.assertRaisesRegex(RuntimeError, error_message_regex):
                torch.jit.script(fn)
1669