# Owner(s): ["oncall: jit"]

import os
import sys
from enum import Enum
from typing import Any, List

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestEnum(JitTestCase):
    """TorchScript tests for Python ``enum.Enum`` support.

    Each test declares its Enum classes locally. Most tests pass them to
    ``make_global`` (from ``torch.testing._internal.jit_utils``) before
    scripting. Several tests pin the printed IR with ``FileCheck``. Those
    patterns embed the qualified name ``__torch__.jit.test_enum``, which
    presumably derives from this module's import path
    (``test/jit/test_enum.py``). NOTE(review): moving or renaming the file
    would likely require updating those patterns.
    """

    def test_enum_value_types(self):
        # int / float / str valued Enums are accepted as script argument types.
        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        make_global(IntEnum, FloatEnum, StringEnum)

        @torch.jit.script
        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run(
            str(supported_enum_types.graph)
        )

        # Tensor-valued Enums are rejected at compile time.
        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        make_global(TensorEnum)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        # TODO: rewrite code so that the highlight is not empty.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Cannot create Enum with value type 'Tensor'", ""
        ):
            torch.jit.script(unsupported_enum_types)

    def test_enum_comp(self):
        # Equality between two values of the same Enum lowers to aten::eq.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        FileCheck().check("aten::eq").run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Color.RED, Color.RED), True)
        self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)

    def test_enum_comp_diff_classes(self):
        # Members of different Enum classes never compare equal, even when
        # their names and values coincide.
        class Foo(Enum):
            ITEM1 = 1
            ITEM2 = 2

        class Bar(Enum):
            ITEM1 = 1
            ITEM2 = 2

        make_global(Foo, Bar)

        @torch.jit.script
        def enum_comp(x: Foo) -> bool:
            return x == Bar.ITEM1

        FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check(
            "aten::eq"
        ).run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Foo.ITEM1), False)

    def test_heterogenous_value_type_enum_error(self):
        # Members whose values have different types cannot be unified.
        class Color(Enum):
            RED = 1
            GREEN = "green"

        make_global(Color)

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        # TODO: rewrite code so that the highlight is not empty.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Could not unify type list", ""
        ):
            torch.jit.script(enum_comp)

    def test_enum_name(self):
        # `.name` lowers to prim::EnumName.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_name(x: Color) -> str:
            return x.name

        FileCheck().check("Color").check_next("prim::EnumName").check_next(
            "return"
        ).run(str(enum_name.graph))

        self.assertEqual(enum_name(Color.RED), Color.RED.name)
        self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)

    def test_enum_value(self):
        # `.value` lowers to prim::EnumValue.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_value(x: Color) -> int:
            return x.value

        FileCheck().check("Color").check_next("prim::EnumValue").check_next(
            "return"
        ).run(str(enum_value.graph))

        self.assertEqual(enum_value(Color.RED), Color.RED.value)
        self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)

    def test_enum_as_const(self):
        # A referenced Enum member is embedded as a prim::Constant.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_const(x: Color) -> bool:
            return x == Color.RED

        FileCheck().check(
            "prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
        ).check_next("aten::eq").check_next("return").run(str(enum_const.graph))

        self.assertEqual(enum_const(Color.RED), True)
        self.assertEqual(enum_const(Color.GREEN), False)

    def test_non_existent_enum_value(self):
        # Referencing an undefined member is a compile error that highlights
        # the offending attribute access.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        def enum_const(x: Color) -> bool:
            if x == Color.PURPLE:
                return True
            else:
                return False

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"
        ):
            torch.jit.script(enum_const)

    def test_enum_ivalue_type(self):
        # isinstance() against an Enum class works on `Any` inputs.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def is_color_enum(x: Any):
            return isinstance(x, Color)

        FileCheck().check(
            "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
        ).check_next("return").run(str(is_color_enum.graph))

        self.assertEqual(is_color_enum(Color.RED), True)
        self.assertEqual(is_color_enum(Color.GREEN), True)
        self.assertEqual(is_color_enum(1), False)

    def test_closed_over_enum_constant(self):
        # Closed-over aliases of an Enum class or member fold to constants.
        class Color(Enum):
            RED = 1
            GREEN = 2

        a = Color

        @torch.jit.script
        def closed_over_aliased_type():
            return a.RED.value

        FileCheck().check(f"prim::Constant[value={a.RED.value}]").check_next(
            "return"
        ).run(str(closed_over_aliased_type.graph))

        self.assertEqual(closed_over_aliased_type(), Color.RED.value)

        b = Color.RED

        @torch.jit.script
        def closed_over_aliased_value():
            return b.value

        FileCheck().check(f"prim::Constant[value={b.value}]").check_next(
            "return"
        ).run(str(closed_over_aliased_value.graph))

        self.assertEqual(closed_over_aliased_value(), Color.RED.value)

    def test_enum_as_module_attribute(self):
        # An Enum stored on a module is read via prim::GetAttr. Note that
        # this test does not call make_global on Color.
        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e.value

        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        FileCheck().check("TestModule").check_next("Color").check_same(
            'prim::GetAttr[name="e"]'
        ).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED.value)

    def test_string_enum_as_module_attribute(self):
        # String-valued Enum attributes expose both name and value.
        class Color(Enum):
            RED = "red"
            GREEN = "green"

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return (self.e.name, self.e.value)

        make_global(Color)
        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))

    def test_enum_return(self):
        # Enum members can be returned from script functions back to Python.
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def return_enum(cond: bool):
            if cond:
                return Color.RED
            else:
                return Color.GREEN

        self.assertEqual(return_enum(True), Color.RED)
        self.assertEqual(return_enum(False), Color.GREEN)

    def test_enum_module_return(self):
        # A module can return its Enum attribute directly.
        class Color(Enum):
            RED = 1
            GREEN = 2

        class TestModule(torch.nn.Module):
            def __init__(self, e: Color):
                super().__init__()
                self.e = e

            def forward(self):
                return self.e

        make_global(Color)
        m = TestModule(Color.RED)
        scripted = torch.jit.script(m)

        FileCheck().check("TestModule").check_next("Color").check_same(
            'prim::GetAttr[name="e"]'
        ).check_next("return").run(str(scripted.graph))

        self.assertEqual(scripted(), Color.RED)

    def test_enum_iterate(self):
        # Iterating an Enum class is lowered to a constant list of members.
        class Color(Enum):
            RED = 1
            GREEN = 2
            BLUE = 3

        def iterate_enum(x: Color):
            res: List[int] = []
            for e in Color:
                if e != x:
                    res.append(e.value)
            return res

        make_global(Color)
        scripted = torch.jit.script(iterate_enum)

        FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same(
            "Color.RED"
        ).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph))

        # Iteration follows Python's Enum definition order (RED, GREEN, BLUE),
        # so BLUE always appears last and the argument member is skipped.
        self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
        self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value])

    # Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
    def test_enum_explicit_script(self):
        @torch.jit.script
        class Color(Enum):
            RED = 1
            GREEN = 2

        torch.jit.script(Color)

    # Regression test for https://github.com/pytorch/pytorch/issues/108933
    def test_typed_enum(self):
        # An Enum with an `int` mixin base must script, and the scripted
        # comparison must also produce the right result, not just compile.
        class Color(int, Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def is_red(x: Color) -> bool:
            return x == Color.RED

        self.assertEqual(is_red(Color.RED), True)
        self.assertEqual(is_red(Color.GREEN), False)