xref: /aosp_15_r20/external/pytorch/test/jit/test_enum.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum
6*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, List
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport torch
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable by putting the directory two
# levels above this file (test/ for test/jit/test_enum.py) on sys.path.
pytorch_test_dir = os.path.split(
    os.path.split(os.path.realpath(__file__))[0]
)[0]
sys.path.extend([pytorch_test_dir])
15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # This module only defines test cases; they are meant to be run through
    # test/test_jit.py, so direct execution fails with a usage hint.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Workerclass TestEnum(JitTestCase):
27*da0073e9SAndroid Build Coastguard Worker    def test_enum_value_types(self):
28*da0073e9SAndroid Build Coastguard Worker        class IntEnum(Enum):
29*da0073e9SAndroid Build Coastguard Worker            FOO = 1
30*da0073e9SAndroid Build Coastguard Worker            BAR = 2
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker        class FloatEnum(Enum):
33*da0073e9SAndroid Build Coastguard Worker            FOO = 1.2
34*da0073e9SAndroid Build Coastguard Worker            BAR = 2.3
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker        class StringEnum(Enum):
37*da0073e9SAndroid Build Coastguard Worker            FOO = "foo as in foo bar"
38*da0073e9SAndroid Build Coastguard Worker            BAR = "bar as in foo bar"
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker        make_global(IntEnum, FloatEnum, StringEnum)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
43*da0073e9SAndroid Build Coastguard Worker        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
44*da0073e9SAndroid Build Coastguard Worker            return (a.name, b.name, c.name)
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run(
47*da0073e9SAndroid Build Coastguard Worker            str(supported_enum_types.graph)
48*da0073e9SAndroid Build Coastguard Worker        )
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker        class TensorEnum(Enum):
51*da0073e9SAndroid Build Coastguard Worker            FOO = torch.tensor(0)
52*da0073e9SAndroid Build Coastguard Worker            BAR = torch.tensor(1)
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker        make_global(TensorEnum)
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker        def unsupported_enum_types(a: TensorEnum):
57*da0073e9SAndroid Build Coastguard Worker            return a.name
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker        # TODO: rewrite code so that the highlight is not empty.
60*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
61*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Cannot create Enum with value type 'Tensor'", ""
62*da0073e9SAndroid Build Coastguard Worker        ):
63*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(unsupported_enum_types)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker    def test_enum_comp(self):
66*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
67*da0073e9SAndroid Build Coastguard Worker            RED = 1
68*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
73*da0073e9SAndroid Build Coastguard Worker        def enum_comp(x: Color, y: Color) -> bool:
74*da0073e9SAndroid Build Coastguard Worker            return x == y
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::eq").run(str(enum_comp.graph))
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_comp(Color.RED, Color.RED), True)
79*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_comp(Color.RED, Color.GREEN), False)
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker    def test_enum_comp_diff_classes(self):
82*da0073e9SAndroid Build Coastguard Worker        class Foo(Enum):
83*da0073e9SAndroid Build Coastguard Worker            ITEM1 = 1
84*da0073e9SAndroid Build Coastguard Worker            ITEM2 = 2
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker        class Bar(Enum):
87*da0073e9SAndroid Build Coastguard Worker            ITEM1 = 1
88*da0073e9SAndroid Build Coastguard Worker            ITEM2 = 2
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker        make_global(Foo, Bar)
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
93*da0073e9SAndroid Build Coastguard Worker        def enum_comp(x: Foo) -> bool:
94*da0073e9SAndroid Build Coastguard Worker            return x == Bar.ITEM1
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check(
97*da0073e9SAndroid Build Coastguard Worker            "aten::eq"
98*da0073e9SAndroid Build Coastguard Worker        ).run(str(enum_comp.graph))
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_comp(Foo.ITEM1), False)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    def test_heterogenous_value_type_enum_error(self):
103*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
104*da0073e9SAndroid Build Coastguard Worker            RED = 1
105*da0073e9SAndroid Build Coastguard Worker            GREEN = "green"
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker        def enum_comp(x: Color, y: Color) -> bool:
110*da0073e9SAndroid Build Coastguard Worker            return x == y
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker        # TODO: rewrite code so that the highlight is not empty.
113*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
114*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Could not unify type list", ""
115*da0073e9SAndroid Build Coastguard Worker        ):
116*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(enum_comp)
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    def test_enum_name(self):
119*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
120*da0073e9SAndroid Build Coastguard Worker            RED = 1
121*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
126*da0073e9SAndroid Build Coastguard Worker        def enum_name(x: Color) -> str:
127*da0073e9SAndroid Build Coastguard Worker            return x.name
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Color").check_next("prim::EnumName").check_next(
130*da0073e9SAndroid Build Coastguard Worker            "return"
131*da0073e9SAndroid Build Coastguard Worker        ).run(str(enum_name.graph))
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_name(Color.RED), Color.RED.name)
134*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker    def test_enum_value(self):
137*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
138*da0073e9SAndroid Build Coastguard Worker            RED = 1
139*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
144*da0073e9SAndroid Build Coastguard Worker        def enum_value(x: Color) -> int:
145*da0073e9SAndroid Build Coastguard Worker            return x.value
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Color").check_next("prim::EnumValue").check_next(
148*da0073e9SAndroid Build Coastguard Worker            "return"
149*da0073e9SAndroid Build Coastguard Worker        ).run(str(enum_value.graph))
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_value(Color.RED), Color.RED.value)
152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker    def test_enum_as_const(self):
155*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
156*da0073e9SAndroid Build Coastguard Worker            RED = 1
157*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
162*da0073e9SAndroid Build Coastguard Worker        def enum_const(x: Color) -> bool:
163*da0073e9SAndroid Build Coastguard Worker            return x == Color.RED
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(
166*da0073e9SAndroid Build Coastguard Worker            "prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
167*da0073e9SAndroid Build Coastguard Worker        ).check_next("aten::eq").check_next("return").run(str(enum_const.graph))
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_const(Color.RED), True)
170*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(enum_const(Color.GREEN), False)
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker    def test_non_existent_enum_value(self):
173*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
174*da0073e9SAndroid Build Coastguard Worker            RED = 1
175*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker        def enum_const(x: Color) -> bool:
180*da0073e9SAndroid Build Coastguard Worker            if x == Color.PURPLE:
181*da0073e9SAndroid Build Coastguard Worker                return True
182*da0073e9SAndroid Build Coastguard Worker            else:
183*da0073e9SAndroid Build Coastguard Worker                return False
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
186*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"
187*da0073e9SAndroid Build Coastguard Worker        ):
188*da0073e9SAndroid Build Coastguard Worker            torch.jit.script(enum_const)
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker    def test_enum_ivalue_type(self):
191*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
192*da0073e9SAndroid Build Coastguard Worker            RED = 1
193*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
198*da0073e9SAndroid Build Coastguard Worker        def is_color_enum(x: Any):
199*da0073e9SAndroid Build Coastguard Worker            return isinstance(x, Color)
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker        FileCheck().check(
202*da0073e9SAndroid Build Coastguard Worker            "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
203*da0073e9SAndroid Build Coastguard Worker        ).check_next("return").run(str(is_color_enum.graph))
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(is_color_enum(Color.RED), True)
206*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(is_color_enum(Color.GREEN), True)
207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(is_color_enum(1), False)
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker    def test_closed_over_enum_constant(self):
210*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
211*da0073e9SAndroid Build Coastguard Worker            RED = 1
212*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker        a = Color
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
217*da0073e9SAndroid Build Coastguard Worker        def closed_over_aliased_type():
218*da0073e9SAndroid Build Coastguard Worker            return a.RED.value
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Constant[value={}]".format(a.RED.value)).check_next(
221*da0073e9SAndroid Build Coastguard Worker            "return"
222*da0073e9SAndroid Build Coastguard Worker        ).run(str(closed_over_aliased_type.graph))
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(closed_over_aliased_type(), Color.RED.value)
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker        b = Color.RED
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
229*da0073e9SAndroid Build Coastguard Worker        def closed_over_aliased_value():
230*da0073e9SAndroid Build Coastguard Worker            return b.value
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::Constant[value={}]".format(b.value)).check_next(
233*da0073e9SAndroid Build Coastguard Worker            "return"
234*da0073e9SAndroid Build Coastguard Worker        ).run(str(closed_over_aliased_value.graph))
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(closed_over_aliased_value(), Color.RED.value)
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker    def test_enum_as_module_attribute(self):
239*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
240*da0073e9SAndroid Build Coastguard Worker            RED = 1
241*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
244*da0073e9SAndroid Build Coastguard Worker            def __init__(self, e: Color):
245*da0073e9SAndroid Build Coastguard Worker                super().__init__()
246*da0073e9SAndroid Build Coastguard Worker                self.e = e
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker            def forward(self):
249*da0073e9SAndroid Build Coastguard Worker                return self.e.value
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        m = TestModule(Color.RED)
252*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(m)
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("TestModule").check_next("Color").check_same(
255*da0073e9SAndroid Build Coastguard Worker            'prim::GetAttr[name="e"]'
256*da0073e9SAndroid Build Coastguard Worker        ).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph))
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(), Color.RED.value)
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker    def test_string_enum_as_module_attribute(self):
261*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
262*da0073e9SAndroid Build Coastguard Worker            RED = "red"
263*da0073e9SAndroid Build Coastguard Worker            GREEN = "green"
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
266*da0073e9SAndroid Build Coastguard Worker            def __init__(self, e: Color):
267*da0073e9SAndroid Build Coastguard Worker                super().__init__()
268*da0073e9SAndroid Build Coastguard Worker                self.e = e
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker            def forward(self):
271*da0073e9SAndroid Build Coastguard Worker                return (self.e.name, self.e.value)
272*da0073e9SAndroid Build Coastguard Worker
273*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
274*da0073e9SAndroid Build Coastguard Worker        m = TestModule(Color.RED)
275*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(m)
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(), (Color.RED.name, Color.RED.value))
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker    def test_enum_return(self):
280*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
281*da0073e9SAndroid Build Coastguard Worker            RED = 1
282*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
287*da0073e9SAndroid Build Coastguard Worker        def return_enum(cond: bool):
288*da0073e9SAndroid Build Coastguard Worker            if cond:
289*da0073e9SAndroid Build Coastguard Worker                return Color.RED
290*da0073e9SAndroid Build Coastguard Worker            else:
291*da0073e9SAndroid Build Coastguard Worker                return Color.GREEN
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(return_enum(True), Color.RED)
294*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(return_enum(False), Color.GREEN)
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker    def test_enum_module_return(self):
297*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
298*da0073e9SAndroid Build Coastguard Worker            RED = 1
299*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker        class TestModule(torch.nn.Module):
302*da0073e9SAndroid Build Coastguard Worker            def __init__(self, e: Color):
303*da0073e9SAndroid Build Coastguard Worker                super().__init__()
304*da0073e9SAndroid Build Coastguard Worker                self.e = e
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker            def forward(self):
307*da0073e9SAndroid Build Coastguard Worker                return self.e
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
310*da0073e9SAndroid Build Coastguard Worker        m = TestModule(Color.RED)
311*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(m)
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("TestModule").check_next("Color").check_same(
314*da0073e9SAndroid Build Coastguard Worker            'prim::GetAttr[name="e"]'
315*da0073e9SAndroid Build Coastguard Worker        ).check_next("return").run(str(scripted.graph))
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(), Color.RED)
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker    def test_enum_iterate(self):
320*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
321*da0073e9SAndroid Build Coastguard Worker            RED = 1
322*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
323*da0073e9SAndroid Build Coastguard Worker            BLUE = 3
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker        def iterate_enum(x: Color):
326*da0073e9SAndroid Build Coastguard Worker            res: List[int] = []
327*da0073e9SAndroid Build Coastguard Worker            for e in Color:
328*da0073e9SAndroid Build Coastguard Worker                if e != x:
329*da0073e9SAndroid Build Coastguard Worker                    res.append(e.value)
330*da0073e9SAndroid Build Coastguard Worker            return res
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker        make_global(Color)
333*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(iterate_enum)
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same(
336*da0073e9SAndroid Build Coastguard Worker            "Color.RED"
337*da0073e9SAndroid Build Coastguard Worker        ).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph))
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker        # PURPLE always appears last because we follow Python's Enum definition order.
340*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
341*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted(Color.GREEN), [Color.RED.value, Color.BLUE.value])
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker    # Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
344*da0073e9SAndroid Build Coastguard Worker    def test_enum_explicit_script(self):
345*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
346*da0073e9SAndroid Build Coastguard Worker        class Color(Enum):
347*da0073e9SAndroid Build Coastguard Worker            RED = 1
348*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker        torch.jit.script(Color)
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker    # Regression test for https://github.com/pytorch/pytorch/issues/108933
353*da0073e9SAndroid Build Coastguard Worker    def test_typed_enum(self):
354*da0073e9SAndroid Build Coastguard Worker        class Color(int, Enum):
355*da0073e9SAndroid Build Coastguard Worker            RED = 1
356*da0073e9SAndroid Build Coastguard Worker            GREEN = 2
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
359*da0073e9SAndroid Build Coastguard Worker        def is_red(x: Color) -> bool:
360*da0073e9SAndroid Build Coastguard Worker            return x == Color.RED
361