# Owner(s): ["oncall: jit"]
# flake8: noqa

import sys
import unittest
from dataclasses import dataclass, field, InitVar
from enum import Enum
from typing import List, Optional

from hypothesis import given, settings, strategies as st

import torch
from torch.testing._internal.jit_utils import JitTestCase


# Example jittable dataclass.
# `order=True` asks the dataclass machinery for the ordering comparisons that
# test_comparators exercises under TorchScript. `norm` is always (re)computed in
# __post_init__ as the Euclidean norm of (x, y), so it also covers an Optional
# field that is filled in after construction.
@dataclass(order=True)
class Point:
    x: float
    y: float
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5


# Enum whose members have list values. test_no_source expects scripting
# MixupParams (which defaults to a member of this enum) to raise RuntimeError.
class MixupScheme(Enum):
    INPUT = ["input"]

    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]


# Dataclass with no annotated fields and a hand-written __init__ whose default
# `scheme` is a list-valued enum member (see MixupScheme).
@dataclass
class MixupParams:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme


# Enum with plain int values; used as the scriptable counterpart of MixupScheme.
class MixupScheme2(Enum):
    A = 1
    B = 2


# Same shape as MixupParams but defaulting to an int-valued enum member;
# test_no_source expects this one to script without error.
@dataclass
class MixupParams2:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme


# Identical to MixupParams2 but kept as a distinct class. This file never passes
# it to torch.jit.script directly; test_use_unregistered_dataclass_raises only
# uses it as a parameter annotation.
@dataclass
class MixupParams3:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme


# Bounded, NaN-free floats for the property-based tests.
# Make sure the Meta internal tooling doesn't raise an overflow error.
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)


class TestDataclasses(JitTestCase):
    @classmethod
    def tearDownClass(cls):
        # Drop the TorchScript classes compiled by these tests, then chain to the
        # base class so any class-level teardown it defines still runs.
        torch._C._jit_clear_class_registry()
        super().tearDownClass()

    def test_init_vars(self):
        """An InitVar pseudo-field is forwarded to __post_init__ under TorchScript."""

        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                # p-norm of (x, y), with p supplied through the InitVar.
                self.norm = (
                    torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
                ) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Sort of tests both __post_init__ and optional fields
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        """__post_init__ runs in TorchScript and populates the Optional `norm` field."""
        P = torch.jit.script(Point)

        def fn(x: float, y: float):
            pt = P(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

    @settings(deadline=None)
    @given(
        st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
    )
    def test_comparators(self, pt1, pt2):
        """Comparison operators on an order=True dataclass match eager Python."""
        x1, y1 = pt1
        x2, y2 = pt2
        P = torch.jit.script(Point)

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = P(x1, y1)
            pt2 = P(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2, # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    def test_default_factories(self):
        """Scripting a dataclass whose field uses default_factory raises NotImplementedError."""

        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)

    def test_custom__eq__(self):
        """A user-written __eq__ is kept rather than overridden by the dataclass one."""

        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: "CustomEq") -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            # Equal only under the custom __eq__ (b1 != b2), so a synthesized
            # field-wise __eq__ would give a different answer.
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        self.checkScript(fn, [1, 2, 3])

    def test_no_source(self):
        """MixupParams fails to script (list-valued Enum); MixupParams2 scripts fine."""
        with self.assertRaises(RuntimeError):
            # Using lists as Enum values is not supported.
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2)  # don't throw

    def test_use_unregistered_dataclass_raises(self):
        """Scripting a function annotated with a dataclass not scripted beforehand raises OSError."""

        def f(a: MixupParams3):
            return 0

        with self.assertRaises(OSError):
            torch.jit.script(f)