"""Static-typing checks for the ``torch.jit.script`` overloads.

Each section passes a different kind of object to ``jit.script`` and pins the
type a checker should infer for the result, via ``assert_type`` or an explicit
variable annotation. The calls themselves are also executed when the module
is imported.
"""
from enum import Enum
from typing import Type, TypeVar
from typing_extensions import assert_never, assert_type, ParamSpec

import pytest

from torch import jit, nn, ScriptDict, ScriptFunction, ScriptList


# NOTE(review): P and R are not referenced anywhere below; presumably they are
# kept for a parameterized ScriptFunction annotation (see the Python 3.9 note
# in the ScriptFunction section) -- confirm before removing.
P = ParamSpec("P")
R = TypeVar("R", covariant=True)


class Color(Enum):
    """Minimal enum used to exercise scripting of an ``Enum`` subclass."""

    RED = 1
    GREEN = 2
    BLUE = 3


# Script Enum: scripting an Enum class is typed as the class itself.
assert_type(jit.script(Color), Type[Color])

# ScriptDict: scripting a dict literal is typed as ScriptDict.
assert_type(jit.script({1: 1}), ScriptDict)

# ScriptList: scripting a list literal is typed as ScriptList.
assert_type(jit.script([0]), ScriptList)

# ScriptModule: scripting an nn.Module *instance* is typed as
# jit.RecursiveScriptModule.
scripted_module = jit.script(nn.Linear(2, 2))
assert_type(scripted_module, jit.RecursiveScriptModule)

# ScriptFunction
# NOTE: can't use assert_type here because of the parameter names in the
#       inferred signature; annotate the variable instead.
# NOTE: the generic (parameterized) form of ScriptFunction is only usable on
#       Python 3.9+, so the bare class is used.
relu: ScriptFunction = jit.script(nn.functional.relu)

# Scripting an nn.Module *class* (rather than an instance) is rejected: at
# runtime the call raises RuntimeError. Passing the result to assert_never
# means the checker must infer that call as never returning normally.
with pytest.raises(RuntimeError):
    assert_never(jit.script(nn.Linear))