xref: /aosp_15_r20/external/pytorch/test/jit/test_enum.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from enum import Enum
6from typing import Any, List
7
8import torch
9from torch.testing import FileCheck
10
11
# Resolve this file's real location, strip the file name and then its
# containing directory, and append the result to sys.path so that the helper
# files in test/ are importable.
pytorch_test_dir = os.path.split(
    os.path.split(os.path.realpath(__file__))[0]
)[0]
sys.path.append(pytorch_test_dir)
15from torch.testing._internal.jit_utils import JitTestCase, make_global
16
17
if __name__ == "__main__":
    # Direct execution is rejected; the message points the user at the
    # test/test_jit.py entry point instead.
    raise RuntimeError(
        "\n\n".join(
            (
                "This test file is not meant to be run directly, use:",
                "\tpython test/test_jit.py TESTNAME",
                "instead.",
            )
        )
    )
24
25
26class TestEnum(JitTestCase):
27    def test_enum_value_types(self):
28        class IntEnum(Enum):
29            FOO = 1
30            BAR = 2
31
32        class FloatEnum(Enum):
33            FOO = 1.2
34            BAR = 2.3
35
36        class StringEnum(Enum):
37            FOO = "foo as in foo bar"
38            BAR = "bar as in foo bar"
39
40        make_global(IntEnum, FloatEnum, StringEnum)
41
42        @torch.jit.script
43        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
44            return (a.name, b.name, c.name)
45
46        FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run(
47            str(supported_enum_types.graph)
48        )
49
50        class TensorEnum(Enum):
51            FOO = torch.tensor(0)
52            BAR = torch.tensor(1)
53
54        make_global(TensorEnum)
55
56        def unsupported_enum_types(a: TensorEnum):
57            return a.name
58
59        # TODO: rewrite code so that the highlight is not empty.
60        with self.assertRaisesRegexWithHighlight(
61            RuntimeError, "Cannot create Enum with value type 'Tensor'", ""
62        ):
63            torch.jit.script(unsupported_enum_types)
64
65    def test_enum_comp(self):
66        class Color(Enum):
67            RED = 1
68            GREEN = 2
69
70        make_global(Color)
71
72        @torch.jit.script
73        def enum_comp(x: Color, y: Color) -> bool:
74            return x == y
75
76        FileCheck().check("aten::eq").run(str(enum_comp.graph))
77
78        self.assertEqual(enum_comp(Color.RED, Color.RED), True)
79        self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)
80
81    def test_enum_comp_diff_classes(self):
82        class Foo(Enum):
83            ITEM1 = 1
84            ITEM2 = 2
85
86        class Bar(Enum):
87            ITEM1 = 1
88            ITEM2 = 2
89
90        make_global(Foo, Bar)
91
92        @torch.jit.script
93        def enum_comp(x: Foo) -> bool:
94            return x == Bar.ITEM1
95
96        FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check(
97            "aten::eq"
98        ).run(str(enum_comp.graph))
99
100        self.assertEqual(enum_comp(Foo.ITEM1), False)
101
102    def test_heterogenous_value_type_enum_error(self):
103        class Color(Enum):
104            RED = 1
105            GREEN = "green"
106
107        make_global(Color)
108
109        def enum_comp(x: Color, y: Color) -> bool:
110            return x == y
111
112        # TODO: rewrite code so that the highlight is not empty.
113        with self.assertRaisesRegexWithHighlight(
114            RuntimeError, "Could not unify type list", ""
115        ):
116            torch.jit.script(enum_comp)
117
118    def test_enum_name(self):
119        class Color(Enum):
120            RED = 1
121            GREEN = 2
122
123        make_global(Color)
124
125        @torch.jit.script
126        def enum_name(x: Color) -> str:
127            return x.name
128
129        FileCheck().check("Color").check_next("prim::EnumName").check_next(
130            "return"
131        ).run(str(enum_name.graph))
132
133        self.assertEqual(enum_name(Color.RED), Color.RED.name)
134        self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
135
136    def test_enum_value(self):
137        class Color(Enum):
138            RED = 1
139            GREEN = 2
140
141        make_global(Color)
142
143        @torch.jit.script
144        def enum_value(x: Color) -> int:
145            return x.value
146
147        FileCheck().check("Color").check_next("prim::EnumValue").check_next(
148            "return"
149        ).run(str(enum_value.graph))
150
151        self.assertEqual(enum_value(Color.RED), Color.RED.value)
152        self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
153
154    def test_enum_as_const(self):
155        class Color(Enum):
156            RED = 1
157            GREEN = 2
158
159        make_global(Color)
160
161        @torch.jit.script
162        def enum_const(x: Color) -> bool:
163            return x == Color.RED
164
165        FileCheck().check(
166            "prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
167        ).check_next("aten::eq").check_next("return").run(str(enum_const.graph))
168
169        self.assertEqual(enum_const(Color.RED), True)
170        self.assertEqual(enum_const(Color.GREEN), False)
171
172    def test_non_existent_enum_value(self):
173        class Color(Enum):
174            RED = 1
175            GREEN = 2
176
177        make_global(Color)
178
179        def enum_const(x: Color) -> bool:
180            if x == Color.PURPLE:
181                return True
182            else:
183                return False
184
185        with self.assertRaisesRegexWithHighlight(
186            RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"
187        ):
188            torch.jit.script(enum_const)
189
190    def test_enum_ivalue_type(self):
191        class Color(Enum):
192            RED = 1
193            GREEN = 2
194
195        make_global(Color)
196
197        @torch.jit.script
198        def is_color_enum(x: Any):
199            return isinstance(x, Color)
200
201        FileCheck().check(
202            "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
203        ).check_next("return").run(str(is_color_enum.graph))
204
205        self.assertEqual(is_color_enum(Color.RED), True)
206        self.assertEqual(is_color_enum(Color.GREEN), True)
207        self.assertEqual(is_color_enum(1), False)
208
209    def test_closed_over_enum_constant(self):
210        class Color(Enum):
211            RED = 1
212            GREEN = 2
213
214        a = Color
215
216        @torch.jit.script
217        def closed_over_aliased_type():
218            return a.RED.value
219
220        FileCheck().check("prim::Constant[value={}]".format(a.RED.value)).check_next(
221            "return"
222        ).run(str(closed_over_aliased_type.graph))
223
224        self.assertEqual(closed_over_aliased_type(), Color.RED.value)
225
226        b = Color.RED
227
228        @torch.jit.script
229        def closed_over_aliased_value():
230            return b.value
231
232        FileCheck().check("prim::Constant[value={}]".format(b.value)).check_next(
233            "return"
234        ).run(str(closed_over_aliased_value.graph))
235
236        self.assertEqual(closed_over_aliased_value(), Color.RED.value)
237
238    def test_enum_as_module_attribute(self):
239        class Color(Enum):
240            RED = 1
241            GREEN = 2
242
243        class TestModule(torch.nn.Module):
244            def __init__(self, e: Color):
245                super().__init__()
246                self.e = e
247
248            def forward(self):
249                return self.e.value
250
251        m = TestModule(Color.RED)
252        scripted = torch.jit.script(m)
253
254        FileCheck().check("TestModule").check_next("Color").check_same(
255            'prim::GetAttr[name="e"]'
256        ).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph))
257
258        self.assertEqual(scripted(), Color.RED.value)
259
260    def test_string_enum_as_module_attribute(self):
261        class Color(Enum):
262            RED = "red"
263            GREEN = "green"
264
265        class TestModule(torch.nn.Module):
266            def __init__(self, e: Color):
267                super().__init__()
268                self.e = e
269
270            def forward(self):
271                return (self.e.name, self.e.value)
272
273        make_global(Color)
274        m = TestModule(Color.RED)
275        scripted = torch.jit.script(m)
276
277        self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))
278
279    def test_enum_return(self):
280        class Color(Enum):
281            RED = 1
282            GREEN = 2
283
284        make_global(Color)
285
286        @torch.jit.script
287        def return_enum(cond: bool):
288            if cond:
289                return Color.RED
290            else:
291                return Color.GREEN
292
293        self.assertEqual(return_enum(True), Color.RED)
294        self.assertEqual(return_enum(False), Color.GREEN)
295
296    def test_enum_module_return(self):
297        class Color(Enum):
298            RED = 1
299            GREEN = 2
300
301        class TestModule(torch.nn.Module):
302            def __init__(self, e: Color):
303                super().__init__()
304                self.e = e
305
306            def forward(self):
307                return self.e
308
309        make_global(Color)
310        m = TestModule(Color.RED)
311        scripted = torch.jit.script(m)
312
313        FileCheck().check("TestModule").check_next("Color").check_same(
314            'prim::GetAttr[name="e"]'
315        ).check_next("return").run(str(scripted.graph))
316
317        self.assertEqual(scripted(), Color.RED)
318
319    def test_enum_iterate(self):
320        class Color(Enum):
321            RED = 1
322            GREEN = 2
323            BLUE = 3
324
325        def iterate_enum(x: Color):
326            res: List[int] = []
327            for e in Color:
328                if e != x:
329                    res.append(e.value)
330            return res
331
332        make_global(Color)
333        scripted = torch.jit.script(iterate_enum)
334
335        FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same(
336            "Color.RED"
337        ).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph))
338
339        # PURPLE always appears last because we follow Python's Enum definition order.
340        self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
341        self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value])
342
343    # Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
344    def test_enum_explicit_script(self):
345        @torch.jit.script
346        class Color(Enum):
347            RED = 1
348            GREEN = 2
349
350        torch.jit.script(Color)
351
352    # Regression test for https://github.com/pytorch/pytorch/issues/108933
353    def test_typed_enum(self):
354        class Color(int, Enum):
355            RED = 1
356            GREEN = 2
357
358        @torch.jit.script
359        def is_red(x: Color) -> bool:
360            return x == Color.RED
361