1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum 6*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, List 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 13*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 14*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 19*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 20*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 21*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 22*da0073e9SAndroid Build Coastguard Worker "instead." 23*da0073e9SAndroid Build Coastguard Worker ) 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Workerclass TestEnum(JitTestCase): 27*da0073e9SAndroid Build Coastguard Worker def test_enum_value_types(self): 28*da0073e9SAndroid Build Coastguard Worker class IntEnum(Enum): 29*da0073e9SAndroid Build Coastguard Worker FOO = 1 30*da0073e9SAndroid Build Coastguard Worker BAR = 2 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker class FloatEnum(Enum): 33*da0073e9SAndroid Build Coastguard Worker FOO = 1.2 34*da0073e9SAndroid Build Coastguard Worker BAR = 2.3 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker class StringEnum(Enum): 37*da0073e9SAndroid Build Coastguard Worker FOO = "foo as in foo bar" 38*da0073e9SAndroid Build Coastguard Worker BAR = "bar as in foo bar" 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker make_global(IntEnum, FloatEnum, StringEnum) 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 43*da0073e9SAndroid Build Coastguard Worker def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum): 44*da0073e9SAndroid Build Coastguard Worker return (a.name, b.name, c.name) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run( 47*da0073e9SAndroid Build Coastguard Worker str(supported_enum_types.graph) 48*da0073e9SAndroid Build Coastguard Worker ) 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker class TensorEnum(Enum): 51*da0073e9SAndroid Build Coastguard Worker FOO = torch.tensor(0) 52*da0073e9SAndroid Build Coastguard Worker BAR = torch.tensor(1) 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker make_global(TensorEnum) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker def unsupported_enum_types(a: TensorEnum): 57*da0073e9SAndroid Build Coastguard Worker return a.name 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker # TODO: rewrite code so that the highlight is not empty. 60*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 61*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Cannot create Enum with value type 'Tensor'", "" 62*da0073e9SAndroid Build Coastguard Worker ): 63*da0073e9SAndroid Build Coastguard Worker torch.jit.script(unsupported_enum_types) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker def test_enum_comp(self): 66*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 67*da0073e9SAndroid Build Coastguard Worker RED = 1 68*da0073e9SAndroid Build Coastguard Worker GREEN = 2 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker make_global(Color) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 73*da0073e9SAndroid Build Coastguard Worker def enum_comp(x: Color, y: Color) -> bool: 74*da0073e9SAndroid Build Coastguard Worker return x == y 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::eq").run(str(enum_comp.graph)) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_comp(Color.RED, Color.RED), True) 79*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_comp(Color.RED, Color.GREEN), False) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker def test_enum_comp_diff_classes(self): 82*da0073e9SAndroid Build Coastguard Worker class Foo(Enum): 83*da0073e9SAndroid Build Coastguard Worker ITEM1 = 1 84*da0073e9SAndroid Build Coastguard Worker ITEM2 = 2 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker class Bar(Enum): 87*da0073e9SAndroid Build Coastguard Worker ITEM1 = 1 88*da0073e9SAndroid Build Coastguard Worker ITEM2 = 2 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker make_global(Foo, Bar) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 93*da0073e9SAndroid Build Coastguard Worker def enum_comp(x: Foo) -> bool: 94*da0073e9SAndroid Build Coastguard Worker return x == Bar.ITEM1 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check( 97*da0073e9SAndroid Build Coastguard Worker "aten::eq" 98*da0073e9SAndroid Build Coastguard Worker ).run(str(enum_comp.graph)) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_comp(Foo.ITEM1), False) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def test_heterogenous_value_type_enum_error(self): 103*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 104*da0073e9SAndroid Build Coastguard Worker RED = 1 105*da0073e9SAndroid Build Coastguard Worker GREEN = "green" 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker make_global(Color) 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker def enum_comp(x: Color, y: Color) -> bool: 110*da0073e9SAndroid Build Coastguard Worker return x == y 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker # TODO: rewrite code so that the highlight is not empty. 113*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 114*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Could not unify type list", "" 115*da0073e9SAndroid Build Coastguard Worker ): 116*da0073e9SAndroid Build Coastguard Worker torch.jit.script(enum_comp) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker def test_enum_name(self): 119*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 120*da0073e9SAndroid Build Coastguard Worker RED = 1 121*da0073e9SAndroid Build Coastguard Worker GREEN = 2 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker make_global(Color) 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 126*da0073e9SAndroid Build Coastguard Worker def enum_name(x: Color) -> str: 127*da0073e9SAndroid Build Coastguard Worker return x.name 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Color").check_next("prim::EnumName").check_next( 130*da0073e9SAndroid Build Coastguard Worker "return" 131*da0073e9SAndroid Build Coastguard Worker ).run(str(enum_name.graph)) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_name(Color.RED), Color.RED.name) 134*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name) 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker def test_enum_value(self): 137*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 138*da0073e9SAndroid Build Coastguard Worker RED = 1 139*da0073e9SAndroid Build Coastguard Worker GREEN = 2 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker make_global(Color) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 144*da0073e9SAndroid Build Coastguard Worker def enum_value(x: Color) -> int: 145*da0073e9SAndroid Build Coastguard Worker return x.value 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Color").check_next("prim::EnumValue").check_next( 148*da0073e9SAndroid Build Coastguard Worker "return" 149*da0073e9SAndroid Build Coastguard Worker ).run(str(enum_value.graph)) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_value(Color.RED), Color.RED.value) 152*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker def test_enum_as_const(self): 155*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 156*da0073e9SAndroid Build Coastguard Worker RED = 1 157*da0073e9SAndroid Build Coastguard Worker GREEN = 2 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker make_global(Color) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 162*da0073e9SAndroid Build Coastguard Worker def enum_const(x: Color) -> bool: 163*da0073e9SAndroid Build Coastguard Worker return x == Color.RED 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker FileCheck().check( 166*da0073e9SAndroid Build Coastguard Worker "prim::Constant[value=__torch__.jit.test_enum.Color.RED]" 167*da0073e9SAndroid Build Coastguard Worker ).check_next("aten::eq").check_next("return").run(str(enum_const.graph)) 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_const(Color.RED), True) 170*da0073e9SAndroid Build Coastguard Worker self.assertEqual(enum_const(Color.GREEN), False) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker def test_non_existent_enum_value(self): 173*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 174*da0073e9SAndroid Build Coastguard Worker RED = 1 175*da0073e9SAndroid Build Coastguard Worker GREEN = 2 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker make_global(Color) 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker def enum_const(x: Color) -> bool: 180*da0073e9SAndroid Build Coastguard Worker if x == Color.PURPLE: 181*da0073e9SAndroid Build Coastguard Worker return True 182*da0073e9SAndroid Build Coastguard Worker else: 183*da0073e9SAndroid Build Coastguard Worker return False 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 186*da0073e9SAndroid Build Coastguard Worker RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE" 187*da0073e9SAndroid Build Coastguard Worker ): 188*da0073e9SAndroid Build Coastguard Worker torch.jit.script(enum_const) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker def test_enum_ivalue_type(self): 191*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 192*da0073e9SAndroid Build Coastguard Worker RED = 1 193*da0073e9SAndroid Build Coastguard Worker GREEN = 2 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker make_global(Color) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 198*da0073e9SAndroid Build Coastguard Worker def is_color_enum(x: Any): 199*da0073e9SAndroid Build Coastguard Worker return isinstance(x, Color) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker FileCheck().check( 202*da0073e9SAndroid Build Coastguard Worker "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]" 203*da0073e9SAndroid Build Coastguard Worker ).check_next("return").run(str(is_color_enum.graph)) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(is_color_enum(Color.RED), True) 206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(is_color_enum(Color.GREEN), True) 207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(is_color_enum(1), False) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker def test_closed_over_enum_constant(self): 210*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 211*da0073e9SAndroid Build Coastguard Worker RED = 1 212*da0073e9SAndroid Build Coastguard Worker GREEN = 2 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker a = Color 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 217*da0073e9SAndroid Build Coastguard Worker def closed_over_aliased_type(): 218*da0073e9SAndroid Build Coastguard Worker return a.RED.value 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::Constant[value={}]".format(a.RED.value)).check_next( 221*da0073e9SAndroid Build Coastguard Worker "return" 222*da0073e9SAndroid Build Coastguard Worker ).run(str(closed_over_aliased_type.graph)) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(closed_over_aliased_type(), Color.RED.value) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker b = Color.RED 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 229*da0073e9SAndroid Build Coastguard Worker def closed_over_aliased_value(): 230*da0073e9SAndroid Build Coastguard Worker return b.value 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::Constant[value={}]".format(b.value)).check_next( 233*da0073e9SAndroid Build Coastguard Worker "return" 234*da0073e9SAndroid Build Coastguard Worker ).run(str(closed_over_aliased_value.graph)) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(closed_over_aliased_value(), Color.RED.value) 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker def test_enum_as_module_attribute(self): 239*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 240*da0073e9SAndroid Build Coastguard Worker RED = 1 241*da0073e9SAndroid Build Coastguard Worker GREEN = 2 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 244*da0073e9SAndroid Build Coastguard Worker def __init__(self, e: Color): 245*da0073e9SAndroid Build Coastguard Worker super().__init__() 246*da0073e9SAndroid Build Coastguard Worker self.e = e 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker def forward(self): 249*da0073e9SAndroid Build Coastguard Worker return self.e.value 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker m = TestModule(Color.RED) 252*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(m) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker FileCheck().check("TestModule").check_next("Color").check_same( 255*da0073e9SAndroid Build Coastguard Worker 'prim::GetAttr[name="e"]' 256*da0073e9SAndroid Build Coastguard Worker ).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph)) 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(), Color.RED.value) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker def test_string_enum_as_module_attribute(self): 261*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 262*da0073e9SAndroid Build Coastguard Worker RED = "red" 263*da0073e9SAndroid Build Coastguard Worker GREEN = "green" 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 266*da0073e9SAndroid Build Coastguard Worker def __init__(self, e: Color): 267*da0073e9SAndroid Build Coastguard Worker super().__init__() 268*da0073e9SAndroid Build Coastguard Worker self.e = e 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker def forward(self): 271*da0073e9SAndroid Build Coastguard Worker return (self.e.name, self.e.value) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker make_global(Color) 274*da0073e9SAndroid Build Coastguard Worker m = TestModule(Color.RED) 275*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(m) 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(), (Color.RED.name, Color.RED.value)) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker def test_enum_return(self): 280*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 281*da0073e9SAndroid Build Coastguard Worker RED = 1 282*da0073e9SAndroid Build Coastguard Worker GREEN = 2 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker make_global(Color) 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 287*da0073e9SAndroid Build Coastguard Worker def return_enum(cond: bool): 288*da0073e9SAndroid Build Coastguard Worker if cond: 289*da0073e9SAndroid Build Coastguard Worker return Color.RED 290*da0073e9SAndroid Build Coastguard Worker else: 291*da0073e9SAndroid Build Coastguard Worker return Color.GREEN 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker self.assertEqual(return_enum(True), Color.RED) 294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(return_enum(False), Color.GREEN) 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker def test_enum_module_return(self): 297*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 298*da0073e9SAndroid Build Coastguard Worker RED = 1 299*da0073e9SAndroid Build Coastguard Worker GREEN = 2 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 302*da0073e9SAndroid Build Coastguard Worker def __init__(self, e: Color): 303*da0073e9SAndroid Build Coastguard Worker super().__init__() 304*da0073e9SAndroid Build Coastguard Worker self.e = e 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker def forward(self): 307*da0073e9SAndroid Build Coastguard Worker return self.e 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker make_global(Color) 310*da0073e9SAndroid Build Coastguard Worker m = TestModule(Color.RED) 311*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(m) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker FileCheck().check("TestModule").check_next("Color").check_same( 314*da0073e9SAndroid Build Coastguard Worker 'prim::GetAttr[name="e"]' 315*da0073e9SAndroid Build Coastguard Worker ).check_next("return").run(str(scripted.graph)) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(), Color.RED) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker def test_enum_iterate(self): 320*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 321*da0073e9SAndroid Build Coastguard Worker RED = 1 322*da0073e9SAndroid Build Coastguard Worker GREEN = 2 323*da0073e9SAndroid Build Coastguard Worker BLUE = 3 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker def iterate_enum(x: Color): 326*da0073e9SAndroid Build Coastguard Worker res: List[int] = [] 327*da0073e9SAndroid Build Coastguard Worker for e in Color: 328*da0073e9SAndroid Build Coastguard Worker if e != x: 329*da0073e9SAndroid Build Coastguard Worker res.append(e.value) 330*da0073e9SAndroid Build Coastguard Worker return res 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker make_global(Color) 333*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(iterate_enum) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same( 336*da0073e9SAndroid Build Coastguard Worker "Color.RED" 337*da0073e9SAndroid Build Coastguard Worker ).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph)) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker # PURPLE always appears last because we follow Python's Enum definition order. 340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value]) 341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value]) 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker # Tests that explicitly and/or repeatedly scripting an Enum class is permitted. 344*da0073e9SAndroid Build Coastguard Worker def test_enum_explicit_script(self): 345*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 346*da0073e9SAndroid Build Coastguard Worker class Color(Enum): 347*da0073e9SAndroid Build Coastguard Worker RED = 1 348*da0073e9SAndroid Build Coastguard Worker GREEN = 2 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker torch.jit.script(Color) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker # Regression test for https://github.com/pytorch/pytorch/issues/108933 353*da0073e9SAndroid Build Coastguard Worker def test_typed_enum(self): 354*da0073e9SAndroid Build Coastguard Worker class Color(int, Enum): 355*da0073e9SAndroid Build Coastguard Worker RED = 1 356*da0073e9SAndroid Build Coastguard Worker GREEN = 2 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 359*da0073e9SAndroid Build Coastguard Worker def is_red(x: Color) -> bool: 360*da0073e9SAndroid Build Coastguard Worker return x == Color.RED 361