xref: /aosp_15_r20/external/pytorch/test/jit/test_dataclasses.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker# flake8: noqa
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field, InitVar
7*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum
8*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Optional
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerfrom hypothesis import given, settings, strategies as st
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerimport torch
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker# Example jittable dataclass
@dataclass(order=True)
class Point:
    """2-D point whose ``norm`` is computed right after construction.

    ``norm`` defaults to ``None`` and is overwritten in ``__post_init__``
    with a tensor derived from ``x`` and ``y``. ``order=True`` asks
    ``dataclass`` to generate the ordering comparison methods.
    """

    x: float
    y: float
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Square each coordinate as a tensor, then take the square root of the sum.
        x_squared = torch.tensor(self.x) ** 2
        y_squared = torch.tensor(self.y) ** 2
        self.norm = (x_squared + y_squared) ** 0.5
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
class MixupScheme(Enum):
    """Enum whose member values are lists of strings rather than scalars.

    Used as the default ``scheme`` of ``MixupParams``; ``test_no_source``
    expects ``torch.jit.script(MixupParams)`` to raise ``RuntimeError``.
    """

    INPUT = ["input"]

    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
# Dataclass-decorated class that declares no annotated class-level fields;
# its attributes are assigned by the hand-written __init__ below.
# The default ``scheme`` is a MixupScheme member, whose value is a list.
# test_no_source expects torch.jit.script(MixupParams) to raise RuntimeError.
@dataclass
class MixupParams:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker
class MixupScheme2(Enum):
    """Enum with plain integer values.

    Used as the default ``scheme`` of ``MixupParams2`` and ``MixupParams3``.
    """

    A = 1
    B = 2
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker
# Same layout as MixupParams (hand-written __init__, no annotated fields),
# but the default ``scheme`` comes from the integer-valued MixupScheme2.
# test_no_source expects torch.jit.script(MixupParams2) to succeed.
@dataclass
class MixupParams2:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
# Same definition as MixupParams2 under a different name. Within this file it
# is never passed to torch.jit.script directly; it only appears as a parameter
# annotation in test_use_unregistered_dataclass_raises.
@dataclass
class MixupParams3:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker
# Hypothesis strategy for floats limited to [-1e4, 1e4] with NaN disallowed.
# The bound makes sure the Meta internal tooling doesn't raise an overflow error.
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker
class TestDataclasses(JitTestCase):
    """Tests for applying ``torch.jit.script`` to ``dataclass``-decorated classes."""

    @classmethod
    def tearDownClass(cls):
        """Call ``torch._C._jit_clear_class_registry()`` after the class's tests finish."""
        torch._C._jit_clear_class_registry()

    def test_init_vars(self):
        """Script a dataclass whose ``InitVar`` field is consumed by ``__post_init__``."""

        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            # InitVar: accepted by the constructor and handed to __post_init__.
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                # p-norm of (x, y) using the exponent supplied at construction.
                self.norm = (
                    torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
                ) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Sort of tests both __post_init__ and optional fields
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        """Script ``Point`` and read back the ``norm`` set by its ``__post_init__``.

        ``x`` and ``y`` are drawn by Hypothesis from ``NonHugeFloats``.
        """
        P = torch.jit.script(Point)

        def fn(x: float, y: float):
            pt = P(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

    @settings(deadline=None)
    @given(
        st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
    )
    def test_comparators(self, pt1, pt2):
        """Exercise ``==``, ``<``, ``<=``, ``>`` and ``>=`` on two scripted ``Point`` objects.

        ``pt1`` and ``pt2`` are ``(x, y)`` tuples drawn by Hypothesis.
        """
        x1, y1 = pt1
        x2, y2 = pt2
        P = torch.jit.script(Point)

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = P(x1, y1)
            pt2 = P(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2,   # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    def test_default_factories(self):
        """Expect ``NotImplementedError`` for a dataclass field using ``default_factory``."""

        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        # NOTE(review): the assertRaises body holds several statements, so
        # nothing after the first one that raises is executed. Presumably
        # torch.jit.script(Foo) is the call expected to raise - confirm, and
        # consider narrowing the block to that single call.
        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)

            def fn():
                foo = Foo()
                return foo.x

            torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    def test_custom__eq__(self):
        """Script a dataclass that defines its own ``__eq__`` and compare two instances."""

        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: "CustomEq") -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        # Both instances share ``a`` and differ only in ``b``.
        self.checkScript(fn, [1, 2, 3])

    def test_no_source(self):
        """Script dataclasses that have a hand-written ``__init__`` and no fields."""
        with self.assertRaises(RuntimeError):
            # Using a list as an Enum value is not supported
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2)  # don't throw

    def test_use_unregistered_dataclass_raises(self):
        """Expect ``OSError`` when scripting a function annotated with an unscripted dataclass."""

        def f(a: MixupParams3):
            return 0

        # NOTE(review): MixupParams3 is never scripted on its own in this file.
        # The OSError presumably comes from a failed source lookup for the
        # dataclass - confirm against torch.jit internals.
        with self.assertRaises(OSError):
            torch.jit.script(f)
174