xref: /aosp_15_r20/external/pytorch/test/typing/pass/disabled_jit.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from enum import Enum
2from typing import Type, TypeVar
3from typing_extensions import assert_never, assert_type, ParamSpec
4
5import pytest
6
7from torch import jit, nn, ScriptDict, ScriptFunction, ScriptList
8
9
10P = ParamSpec("P")
11R = TypeVar("R", covariant=True)
12
13
class Color(Enum):
    """Minimal Enum fixture passed to ``jit.script`` by the "Script Enum" check.

    That check asserts that ``jit.script(Color)`` is statically typed as
    ``Type[Color]``. The members and their integer values are plain fixture
    data.
    """

    RED = 1
    GREEN = 2
    BLUE = 3
18
19
# Each statement below pins the static return type a type checker infers for
# ``jit.script`` on a given kind of input. ``assert_type`` is only enforced by
# the type checker; at runtime it just returns its argument.

# Script Enum: scripting an Enum class is expected to be typed as that class.
assert_type(jit.script(Color), Type[Color])

# ScriptDict: scripting a dict is expected to be typed as ScriptDict.
assert_type(jit.script({1: 1}), ScriptDict)

# ScriptList: scripting a list is expected to be typed as ScriptList.
assert_type(jit.script([0]), ScriptList)

# ScriptModule: scripting an nn.Module *instance* is expected to be typed as
# jit.RecursiveScriptModule.
scripted_module = jit.script(nn.Linear(2, 2))
assert_type(scripted_module, jit.RecursiveScriptModule)

# ScriptFunction: scripting a plain function.
# NOTE: can't use assert_type because of parameter names
# NOTE: Generic usage only possible with Python 3.9
relu: ScriptFunction = jit.script(nn.functional.relu)

# can't script nn.Module class: passing the *class* (not an instance) is
# expected to raise RuntimeError at runtime. ``assert_never`` additionally
# requires the type checker to infer a never/bottom type for this call.
with pytest.raises(RuntimeError):
    assert_never(jit.script(nn.Linear))
41