xref: /aosp_15_r20/external/pytorch/test/jit/test_dataclasses.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2# flake8: noqa
3
4import sys
5import unittest
6from dataclasses import dataclass, field, InitVar
7from enum import Enum
8from typing import List, Optional
9
10from hypothesis import given, settings, strategies as st
11
12import torch
13from torch.testing._internal.jit_utils import JitTestCase
14
15
16# Example jittable dataclass
@dataclass(order=True)
class Point:
    # 2-D point fixture. The tests in this file pass it to torch.jit.script
    # and exercise both the derived `norm` field and the comparison
    # operators requested via order=True.
    x: float
    y: float
    # Placeholder default only: __post_init__ unconditionally overwrites
    # whatever value was supplied for this field.
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Euclidean length sqrt(x**2 + y**2), computed on tensors built from
        # the two float coordinates.
        squared_length = torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2
        self.norm = squared_length ** 0.5
25
26
class MixupScheme(Enum):
    # Enum fixture whose member values are lists of strings rather than
    # scalars. MixupParams defaults its `scheme` argument to
    # MixupScheme.INPUT, and test_no_source expects scripting MixupParams to
    # raise RuntimeError because list-valued Enums are not supported.
    INPUT = ["input"]

    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]
36
37
# Dataclass fixture that declares no annotated fields; its attributes are
# assigned in a hand-written __init__ instead. The `scheme` default is a member
# of MixupScheme (list-valued Enum), and test_no_source expects
# torch.jit.script(MixupParams) to raise RuntimeError.
@dataclass
class MixupParams:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme
43
44
class MixupScheme2(Enum):
    # Enum fixture with plain integer values; it is the `scheme` type of
    # MixupParams2 and MixupParams3.
    A = 1
    B = 2
48
49
# Same shape as MixupParams, but `scheme` uses the integer-valued MixupScheme2.
# test_no_source expects torch.jit.script(MixupParams2) to succeed.
@dataclass
class MixupParams2:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme
55
56
# Same body as MixupParams2. In the visible tests this class is never passed to
# torch.jit.script directly: test_use_unregistered_dataclass_raises only uses it
# as a parameter annotation and expects scripting that function to raise
# OSError.
# NOTE(review): presumably kept separate from MixupParams2 so that it stays
# unscripted regardless of test order -- confirm.
@dataclass
class MixupParams3:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme
62
63
# Make sure the Meta internal tooling doesn't raise an overflow error
# Hypothesis strategy shared by test__post_init__ and test_comparators:
# floats restricted to [-1e4, 1e4] with NaN excluded.
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
66
67
# Tests for passing @dataclass-decorated classes to torch.jit.script.
# Most tests define a small function that constructs a scripted dataclass and
# hand it to self.checkScript (inherited from JitTestCase).
# NOTE(review): checkScript presumably compares the scripted function's result
# with the eager Python result -- confirm against jit_utils.
class TestDataclasses(JitTestCase):
    # Runs once after all tests in this class; calls the TorchScript
    # class-registry clearing hook.
    # NOTE(review): presumably so classes scripted here do not leak into other
    # test classes -- confirm.
    @classmethod
    def tearDownClass(cls):
        torch._C._jit_clear_class_registry()

    # InitVar support: `norm_p` is an init-only pseudo-field (default 2) that
    # is forwarded to __post_init__, which stores the p-norm of (x, y) in
    # `norm`. The check is run with x=1.0, y=2.0, p=3.
    def test_init_vars(self):
        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                self.norm = (
                    torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
                ) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Sort of tests both __post_init__ and optional fields
    # Property-based: x and y are drawn from NonHugeFloats, with the hypothesis
    # per-example deadline disabled (deadline=None).
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        P = torch.jit.script(Point)

        def fn(x: float, y: float):
            pt = P(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

    # Comparison operators on the scripted Point (declared with order=True).
    # Property-based: each argument is an (x, y) tuple drawn from
    # NonHugeFloats. `!=` is not exercised; see the TODO inside `compare`.
    @settings(deadline=None)
    @given(
        st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
    )
    def test_comparators(self, pt1, pt2):
        x1, y1 = pt1
        x2, y2 = pt2
        P = torch.jit.script(Point)

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = P(x1, y1)
            pt2 = P(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2,   # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    # A dataclass field declared with default_factory is expected to make
    # scripting fail with NotImplementedError.
    # NOTE(review): every statement sits inside the assertRaises block, so if
    # torch.jit.script(Foo) itself raises, the definition and scripting of `fn`
    # are never reached -- confirm which call is meant to raise.
    def test_default_factories(self):
        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)

            def fn():
                foo = Foo()
                return foo.x

            torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    # Both instances share `a` but differ in `b` (inputs 1, 2, 3), so the
    # custom __eq__ below reports them equal.
    def test_custom__eq__(self):
        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: "CustomEq") -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        self.checkScript(fn, [1, 2, 3])

    # Scripts the dataclass fixtures that define __init__ by hand:
    # MixupParams must fail, MixupParams2 must compile.
    def test_no_source(self):
        with self.assertRaises(RuntimeError):
            # MixupParams defaults to a MixupScheme member, an Enum whose
            # values are lists, which is not supported.
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2)  # don't throw

    # Scripting a function whose argument is annotated with a dataclass that
    # has not itself been scripted (MixupParams3) is expected to raise OSError.
    def test_use_unregistered_dataclass_raises(self):
        def f(a: MixupParams3):
            return 0

        with self.assertRaises(OSError):
            torch.jit.script(f)
174