# Owner(s): ["oncall: jit"]
# flake8: noqa

import sys
import unittest
from dataclasses import dataclass, field, InitVar
from enum import Enum
from typing import List, Optional

from hypothesis import given, settings, strategies as st

import torch
from torch.testing._internal.jit_utils import JitTestCase


# Example jittable dataclass
@dataclass(order=True)
class Point:
    """2-D point fixture for the TorchScript dataclass tests.

    ``order=True`` makes ``dataclass`` generate ``<``, ``<=``, ``>`` and ``>=``
    in addition to ``__eq__``, so the comparator test can exercise the
    TorchScript versions of all of them.
    """

    x: float
    y: float
    # Optional with a None default so callers may omit it; __post_init__
    # always overwrites it with the computed norm.
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Euclidean norm sqrt(x**2 + y**2), computed with tensor ops so the
        # result is a (0-d) tensor rather than a Python float.
        self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5


class MixupScheme(Enum):
    """Enum whose member values are lists.

    ``test_no_source`` expects scripting ``MixupParams``, which uses this enum,
    to fail with a RuntimeError because list-valued enums are not supported.
    """

    INPUT = ["input"]

    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]


@dataclass
class MixupParams:
    """Dataclass fixture with a hand-written ``__init__`` and a list-valued enum.

    Nothing is declared with a class-level annotation, so ``dataclass`` sees
    zero fields. It therefore keeps this ``__init__``, and its generated
    ``__eq__``/``__repr__`` do not look at ``alpha`` or ``scheme``.
    """

    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme


class MixupScheme2(Enum):
    """Int-valued counterpart of ``MixupScheme`` (a supported enum kind)."""

    A = 1
    B = 2


@dataclass
class MixupParams2:
    """Same shape as ``MixupParams`` but using the int-valued ``MixupScheme2``.

    ``test_no_source`` expects this one to script without raising. As with
    ``MixupParams``, no dataclass fields are declared, so the hand-written
    ``__init__`` is kept.
    """

    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme


@dataclass
class MixupParams3:
    """Copy of ``MixupParams2`` kept as a separate class.

    It is never passed to ``torch.jit.script`` directly. This lets
    ``test_use_unregistered_dataclass_raises`` reference a dataclass that no
    other test has scripted.
    """

    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme
# Make sure the Meta internal tooling doesn't raise an overflow error
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)


class TestDataclasses(JitTestCase):
    """TorchScript support for ``@dataclass`` classes.

    Covers synthesized comparison methods, ``InitVar`` with ``__post_init__``,
    user-defined ``__eq__``, and the unsupported cases: ``default_factory``,
    list-valued enums, and unscripted dataclasses used as annotations.
    """

    @classmethod
    def tearDownClass(cls):
        # Clear the TorchScript class registry that the torch.jit.script calls
        # in these tests populate, so the scripted classes do not leak into
        # other test classes.
        torch._C._jit_clear_class_registry()
        super().tearDownClass()

    def test_init_vars(self):
        """An ``InitVar`` constructor argument is forwarded to ``__post_init__``."""

        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                # (x**p + y**p) ** (1/p), with p taken from the InitVar.
                self.norm = (
                    torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
                ) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Exercises both __post_init__ and an Optional field (Point.norm).
    # deadline=None turns off Hypothesis's per-example time limit, since every
    # example calls torch.jit.script.
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        P = torch.jit.script(Point)

        def fn(x: float, y: float):
            pt = P(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

    @settings(deadline=None)
    @given(
        st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
    )
    def test_comparators(self, pt1, pt2):
        """The ``order=True`` comparison methods behave the same when scripted."""
        x1, y1 = pt1
        x2, y2 = pt2
        P = torch.jit.script(Point)

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = P(x1, y1)
            pt2 = P(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2,  # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    def test_default_factories(self):
        """Scripting a dataclass with a ``default_factory`` field is rejected."""

        @dataclass
        class Foo:
            x: List[int] = field(default_factory=list)

        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)

            # NOTE(review): the code below never runs, because the
            # torch.jit.script(Foo) call above raises. It looks intended to
            # check that scripting a function that *uses* Foo fails as well.
            # TODO confirm the expected exception type and move it into its
            # own assertRaises block.
            def fn():
                foo = Foo()
                return foo.x

            torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    def test_custom__eq__(self):
        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: "CustomEq") -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        self.checkScript(fn, [1, 2, 3])

    def test_no_source(self):
        """Scripting dataclasses whose ``dataclass``-generated methods have no source."""
        with self.assertRaises(RuntimeError):
            # MixupParams uses MixupScheme, whose members have list values;
            # enums with list values are not supported by TorchScript.
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2)  # don't throw

    def test_use_unregistered_dataclass_raises(self):
        """Using a never-scripted dataclass as an annotation raises ``OSError``."""

        # MixupParams3 only appears as a parameter annotation here and is not
        # passed to torch.jit.script anywhere. The OSError presumably comes
        # from a failed source lookup; TODO confirm.
        def f(a: MixupParams3):
            return 0

        with self.assertRaises(OSError):
            torch.jit.script(f)