# Owner(s): ["oncall: jit"]

# Tests for Profile-Directed Typing (PDT) in TorchScript: functions, classes and
# nn.Modules are scripted with `example_inputs`, and the resulting scripted
# objects are compared against eager Python execution.

import os
import sys
from typing import Any, Dict, List, NamedTuple, Optional, Tuple  # noqa: F401

import torch
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
from torch.testing._internal.common_utils import NoTest
from torch.testing._internal.jit_utils import JitTestCase, make_global


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# PDT relies on MonkeyType; without it, swap the base class for NoTest so the
# suite below is skipped instead of failing.
if not _IS_MONKEYTYPE_INSTALLED:
    print(
        "monkeytype is not installed. Skipping tests for Profile-Directed Typing",
        file=sys.stderr,
    )
    JitTestCase = NoTest  # type: ignore[misc, assignment] # noqa: F811

# This module is collected through test/test_jit.py; running it directly is an error.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestPDT(JitTestCase):
    """
    A suite of tests for profile directed typing in TorchScript.

    Each test builds an unannotated function, class or nn.Module, registers it
    with ``make_global`` so TorchScript can resolve it, scripts it with
    ``example_inputs`` so argument types come from profiling, and then checks
    the scripted result against the eager Python result.
    """

    def test_nn_module(self):
        """Script an nn.Module whose untyped ``forward`` branches on int/float."""

        class TestPDTModel(torch.nn.Module):
            def forward(self, x) -> Any:
                if isinstance(x, int):
                    return x + 1
                elif isinstance(x, float):
                    return x - 1
                else:
                    return x

        make_global(TestPDTModel)
        pdt_model = TestPDTModel()
        # Profiled samples for `forward`: an int, a float and a bool.
        inp: List[Tuple[Any, ...]] = [
            (20,),
            (2.7,),
            (False,),
        ]
        scripted_pdt_model = torch.jit.script(
            pdt_model, example_inputs={pdt_model: inp}
        )
        self.assertEqual(scripted_pdt_model(50), pdt_model(50))
        self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
        # NOTE(review): assertTrue treats its second argument as the failure
        # message, so this only checks that the scripted output is truthy; it
        # does not compare with the eager result. In eager Python True is an
        # int, so pdt_model(True) takes the `x + 1` branch. Confirm the scripted
        # behavior before converting this to assertEqual.
        self.assertTrue(scripted_pdt_model(True), pdt_model(True))

    def test_nested_nn_module_class(self):
        """Script a wrapper module whose submodule's types come from the outer profile."""

        class NestedPDTInner(torch.nn.Module):
            def forward(self, x):
                if isinstance(x, int):
                    return x * 10
                return x

        class NestedModulePDTWrapper(torch.nn.Module):
            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, x):
                return self.inner(x)

        make_global(NestedPDTInner, NestedModulePDTWrapper)
        inner_pdt_model = NestedPDTInner()
        wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
        # Only the outer module gets example inputs; the inner module is
        # exercised through the wrapper's forward during profiling.
        inp: List[Tuple[Any, ...]] = [(20,), (False,)]
        scripted_pdt_model = torch.jit.script(
            wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}
        )
        self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
        self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
        # NOTE(review): truthiness check only; the second argument is the
        # assertion message, not an expected value.
        self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))

    def test_nested_nn_module_class_with_args(self):
        """Script nested modules, supplying separate example inputs for each."""

        class NestedModulePDTInner(torch.nn.Module):
            def forward(self, x, y):
                if isinstance(x, int):
                    return x * 10 + y
                return x

        class NestedModulePDTOuter(torch.nn.Module):
            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, x):
                return self.inner(x, 20)

        make_global(NestedModulePDTInner, NestedModulePDTOuter)
        inner_pdt_model = NestedModulePDTInner()
        outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
        # Two-argument samples for the inner module's forward(x, y).
        inner_input: List[Tuple[Any, ...]] = [
            (10, 10),
            (1.9, 20),
        ]
        # One-argument samples for the outer module's forward(x).
        outer_input: List[Tuple[Any, ...]] = [(20,), (False,)]
        scripted_pdt_model = torch.jit.script(
            outer_pdt_model,
            example_inputs={
                inner_pdt_model: inner_input,
                outer_pdt_model: outer_input,
            },
        )
        self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
        self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
        # NOTE(review): truthiness check only; the second argument is the
        # assertion message, not an expected value.
        self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))

    def test_nested_function_in_forward(self):
        """Script a module whose forward calls another untyped method on self."""

        class NestedFunctionInForward(torch.nn.Module):
            def forward(self, x):
                return self.fun(x) + 10

            def fun(self, x):
                # bool is checked before int because bool is an int subclass in Python.
                if isinstance(x, bool):
                    return 0
                elif isinstance(x, int):
                    return x + 1
                return 0

        make_global(NestedFunctionInForward)
        pdt_model = NestedFunctionInForward()
        inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
        scripted_pdt_model = torch.jit.script(
            pdt_model, example_inputs={pdt_model: inp}
        )
        self.assertEqual(scripted_pdt_model(30), pdt_model(30))
        self.assertEqual(scripted_pdt_model(True), pdt_model(True))

    def test_nn_module_with_export_function(self):
        """Profile an ``@torch.jit.export`` method by keying example inputs on it."""

        class TestModelWithExport(torch.nn.Module):
            @torch.jit.export
            def fn(self, x, y) -> Any:
                assert not (isinstance(x, bool) and isinstance(y, bool))
                if isinstance(x, int) and isinstance(y, int):
                    return x + y
                elif isinstance(x, float) and isinstance(y, float):
                    return x - y
                else:
                    return -1

        make_global(TestModelWithExport)
        pdt_model = TestModelWithExport()
        inp: List[Tuple[Any, ...]] = [
            (
                20,
                10,
            ),
            (
                2.7,
                8.9,
            ),
        ]
        # The dict key is the bound method, not the module itself.
        scripted_pdt_model = torch.jit.script(
            pdt_model, example_inputs={pdt_model.fn: inp}
        )
        self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
        self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
        # NOTE(review): truthiness check only. In eager mode a (Tensor, int)
        # pair falls through to `return -1`, which is truthy.
        self.assertTrue(
            scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)
        )

    def test_class_methods(self):
        """Script a plain (non-Module) class using a profiled method."""

        class PDTModel:
            def test_sum(self, a):
                return sum(a)

        make_global(PDTModel)
        pdt_model = PDTModel()
        inp: List[Tuple[Any, ...]] = [
            (
                [
                    10,
                    20,
                ],
            ),
        ]
        # The class itself is scripted; profiling is keyed on an instance's
        # bound method.
        scripted_pdt_model = torch.jit.script(
            PDTModel, example_inputs={pdt_model.test_sum: inp}
        )
        script_model = scripted_pdt_model()
        self.assertEqual(
            script_model.test_sum(
                [
                    10,
                    20,
                    30,
                ],
            ),
            pdt_model.test_sum(
                [
                    10,
                    20,
                    30,
                ],
            ),
        )

    def test_class_with_multiple_methods(self):
        """Script a class where several methods each get their own example inputs."""

        class PDTModelWithManyMethods:
            def test_list_to_dict(self, a):
                new_dictionary: Dict[float, bool] = {}
                for element in a:
                    new_dictionary[element] = True
                return new_dictionary

            def test_substring(self, a, b):
                return b in a

        make_global(PDTModelWithManyMethods)
        pdt_model = PDTModelWithManyMethods()
        list_inp: List[Tuple[Any, ...]] = [
            (
                [
                    1.2,
                    2.3,
                ],
            ),
        ]
        str_inp: List[Tuple[Any, ...]] = [
            (
                "abc",
                "b",
            ),
        ]
        scripted_pdt_model = torch.jit.script(
            PDTModelWithManyMethods,
            example_inputs={
                pdt_model.test_list_to_dict: list_inp,
                pdt_model.test_substring: str_inp,
            },
        )
        script_model = scripted_pdt_model()
        self.assertEqual(
            script_model.test_list_to_dict(
                [
                    1.1,
                    2.2,
                    3.3,
                ],
            ),
            pdt_model.test_list_to_dict(
                [
                    1.1,
                    2.2,
                    3.3,
                ],
            ),
        )
        # Substring present ...
        self.assertEqual(
            script_model.test_substring(
                "helloworld",
                "world",
            ),
            pdt_model.test_substring(
                "helloworld",
                "world",
            ),
        )
        # ... and substring absent.
        self.assertEqual(
            script_model.test_substring(
                "helloworld",
                "def",
            ),
            pdt_model.test_substring(
                "helloworld",
                "def",
            ),
        )

    def test_multiple_class_with_same_method(self):
        """Script two classes sharing a method name, each profiled with different types."""

        class PDTModelOne:
            def test_find(self, a, b):
                return b in a.keys()

        class PDTModelTwo:
            def test_find(self, a, b):
                return b in a

        make_global(PDTModelOne, PDTModelTwo)
        pdt_model_one = PDTModelOne()
        pdt_model_two = PDTModelTwo()
        # (dict[float, bool], float) samples for PDTModelOne.test_find.
        dict_inp: List[Tuple[Any, ...]] = [
            (
                {
                    1.2: True,
                    2.3: False,
                },
                1.2,
            ),
        ]
        # (list[str], str) samples for PDTModelTwo.test_find.
        list_inp: List[Tuple[Any, ...]] = [
            (
                [
                    "abc",
                    "b",
                ],
                "c",
            ),
        ]
        scripted_pdt_model_one = torch.jit.script(
            PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}
        )
        scripted_pdt_model_two = torch.jit.script(
            PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}
        )

        script_model_one, script_model_two = (
            scripted_pdt_model_one(),
            scripted_pdt_model_two(),
        )
        self.assertEqual(
            script_model_one.test_find(
                {
                    1.1: True,
                    2.2: True,
                    3.3: False,
                },
                4.4,
            ),
            pdt_model_one.test_find(
                {
                    1.1: True,
                    2.2: True,
                    3.3: False,
                },
                4.4,
            ),
        )
        self.assertEqual(
            script_model_two.test_find(
                [
                    "hello",
                    "world",
                ],
                "world",
            ),
            pdt_model_two.test_find(
                [
                    "hello",
                    "world",
                ],
                "world",
            ),
        )

    def test_pdt(self):
        """Script free functions whose argument types come from a single sample each."""

        def test_sum(a, b):
            return a + b

        make_global(test_sum)
        # For free functions, example_inputs is a list of argument tuples.
        scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
        self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))

        def test_sub(a, b):
            return a - b

        make_global(test_sub)
        scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
        self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))

        def test_mul(a, b):
            return a * b

        make_global(test_mul)
        scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
        self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))

        def test_args_complex(real, img):
            return torch.complex(real, img)

        make_global(test_args_complex)
        scripted_fn_complex = torch.jit.script(
            test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]
        )
        arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
        self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))

        def test_bool(a):
            if a:
                return -1
            else:
                return 0

        make_global(test_bool)
        scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
        self.assertEqual(scripted_fn_bool(True), test_bool(True))

        def test_str(a):
            if a == "":
                return False
            else:
                return True

        make_global(test_str)
        scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
        self.assertEqual(scripted_fn_str("abc"), test_str("abc"))

    def test_pdt_list_and_tuple(self):
        """Script the same function separately for list and tuple samples of float/bool/int."""

        def test_list_and_tuple(a):
            return sum(a)

        make_global(test_list_and_tuple)

        # Each torch.jit.script call below profiles the same Python function
        # with a different container/element type.
        scripted_fn_float_list_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[([4.9, 8.9],)]
        )
        self.assertEqual(
            scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])
        )

        scripted_fn_bool_list_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[([True, False, True],)]
        )
        self.assertEqual(
            scripted_fn_bool_list_input([True, True, True]),
            test_list_and_tuple([True, True, True]),
        )

        scripted_fn_int_list_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[([3, 4, 5],)]
        )
        self.assertEqual(
            scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])
        )

        scripted_fn_float_tuple_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[((4.9, 8.9),)]
        )
        self.assertEqual(
            scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))
        )

        scripted_fn_bool_tuple_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[((True, False, True),)]
        )
        self.assertEqual(
            scripted_fn_bool_tuple_input((True, True, True)),
            test_list_and_tuple((True, True, True)),
        )

        scripted_fn_int_tuple_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[((3, 4, 5),)]
        )
        self.assertEqual(
            scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))
        )

    def test_nested_list_and_tuple(self):
        """Script functions over nested containers (list-of-lists, tuple-of-lists, etc.)."""

        def test_nested_list(inp):
            return [sum(v) for v in inp]

        def test_nested_tuple(inp):
            # NOTE(review): `ans` starts at 0.0 and is only ever multiplied, so
            # the result stays 0.0 for finite inputs. The assertions still
            # compare scripted vs eager, but possibly `+=` was intended; confirm.
            ans = 0.0
            for tup in inp:
                for val in tup:
                    if val > 0:
                        ans *= val
            return ans

        make_global(test_nested_list, test_nested_tuple)

        # Case 1: list of int lists.
        list_inp = [
            [
                1,
                2,
                3,
            ],
            [
                5,
                6,
                7,
            ],
        ]
        scripted_fn = torch.jit.script(
            test_nested_list,
            example_inputs=[
                (list_inp,),
            ],
        )
        # Inner lists intentionally have different lengths.
        inp = [
            [
                0,
                4,
                7,
            ],
            [
                8,
                11,
            ],
            [
                6,
                -1,
                -20,
            ],
        ]
        self.assertEqual(
            scripted_fn(
                inp,
            ),
            test_nested_list(
                inp,
            ),
        )

        # Case 2: tuple of int lists.
        list_inp = (
            [
                1,
                2,
                3,
            ],
            [
                5,
                6,
                7,
            ],
        )
        scripted_fn = torch.jit.script(
            test_nested_list,
            example_inputs=[
                (list_inp,),
            ],
        )
        # NOTE(review): this call passes a 3-element tuple, while the profiled
        # sample was a 2-element tuple.
        inp = (
            [
                0,
                4,
                7,
            ],
            [
                8,
                11,
            ],
            [
                6,
                -1,
                -20,
            ],
        )
        self.assertEqual(
            scripted_fn(
                inp,
            ),
            test_nested_list(
                inp,
            ),
        )

        # Case 3: list of float tuples.
        tup_inp = [
            (
                1.0,
                2.6,
                3.7,
            ),
            (
                5.7,
                6.1,
                1.7,
            ),
        ]
        scripted_fn = torch.jit.script(
            test_nested_tuple,
            example_inputs=[
                (tup_inp,),
            ],
        )
        inp = [
            (
                1.0,
                4.1,
                7.4,
            ),
            (
                4.8,
                1.1,
                -1.2,
            ),
            (
                6.3,
                -1.3,
                -2.0,
            ),
        ]
        self.assertEqual(
            scripted_fn(
                inp,
            ),
            test_nested_tuple(
                inp,
            ),
        )

        # Case 4: tuple of bool tuples.
        tup_inp = (
            (
                True,
                False,
                True,
            ),
            (
                False,
                False,
                False,
            ),
        )
        scripted_fn = torch.jit.script(
            test_nested_tuple,
            example_inputs=[
                (tup_inp,),
            ],
        )
        inp = (
            (
                True,
                True,
                True,
            ),
            (
                False,
                False,
                True,
            ),
        )
        self.assertEqual(
            scripted_fn(
                inp,
            ),
            test_nested_tuple(
                inp,
            ),
        )

    def test_pdt_dict(self):
        """Script functions that index into dict arguments (str keys and int keys)."""

        def test_dict(a):
            return a["foo"]

        def test_dict_int_list(a):
            return a[1]

        make_global(test_dict, test_dict_int_list)

        # dict[str, bool] sample.
        str_bool_inp = {"foo": True, "bar": False}
        scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
        self.assertEqual(
            scripted_fn(
                {"foo": False, "bar": True},
            ),
            test_dict(
                {"foo": False, "bar": True},
            ),
        )

        # dict[int, list[bool]] sample. The keys are ints despite the `str_`
        # prefix on the variable name.
        str_list_inp = {0: [True, False], 1: [False, True]}
        scripted_fn = torch.jit.script(
            test_dict_int_list, example_inputs=[(str_list_inp,)]
        )
        self.assertEqual(
            scripted_fn(
                {0: [False, False], 1: [True, True]},
            ),
            test_dict_int_list(
                {0: [False, False], 1: [True, True]},
            ),
        )

    def test_any(self):
        """Script functions profiled with heterogeneous samples for one argument."""

        def test_multiple_types(a):
            assert not isinstance(a, bool)
            return a

        def test_multiple_type_refinement(a):
            # bool is checked before int because bool is an int subclass in Python.
            if isinstance(a, bool):
                return 1
            elif isinstance(a, int):
                return 1 + a
            elif isinstance(a, float):
                return 1 + int(a)
            else:
                return -1

        make_global(test_multiple_types, test_multiple_type_refinement)

        # Samples for `a`: int, str, float and list[int].
        scripted_fn = torch.jit.script(
            test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)]
        )
        self.assertEqual(scripted_fn(10), test_multiple_types(10))
        self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
        self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
        self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))

        # Same idea with bool and dict samples added, exercising isinstance
        # refinement in the scripted body.
        scripted_fn = torch.jit.script(
            test_multiple_type_refinement,
            example_inputs=[
                (1,),
                ("abc",),
                (8.9,),
                ([3, 4, 5],),
                (True,),
                ({"a": True},),
            ],
        )
        self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
        self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
        self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
        self.assertEqual(
            scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14])
        )
        self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False))
        self.assertEqual(
            scripted_fn({"abc": True, "def": False}),
            test_multiple_type_refinement({"abc": True, "def": False}),
        )

    def test_class_as_profiled_types(self):
        """Script a function whose profiled arguments include a user-defined class instance."""

        class UserDefinedClass:
            def fn(self, b) -> Any:
                assert b is not None
                if isinstance(b, int):
                    return b if b > 0 else -1
                elif isinstance(b, float):
                    return b if b > 0.0 else -1.0
                return 0

        def test_model(a, m):
            assert not isinstance(a, bool)
            return m.fn(a)

        make_global(UserDefinedClass, test_model)

        user_class = UserDefinedClass()
        # `a` is sampled as int and float; `m` is always a UserDefinedClass.
        scripted_fn = torch.jit.script(
            test_model,
            example_inputs=[
                (
                    10,
                    user_class,
                ),
                (
                    10.9,
                    user_class,
                ),
            ],
        )
        self.assertEqual(
            scripted_fn(
                100,
                user_class,
            ),
            test_model(100, user_class),
        )
        self.assertEqual(
            scripted_fn(
                1.9,
                user_class,
            ),
            test_model(1.9, user_class),
        )

    def test_class_with_args_as_profiled_types(self):
        """Script a function taking an instance of a class with constructor arguments."""

        class ClassWithArgs:
            def __init__(self, a: bool):
                self.a = a

            def fn(self, b):
                if self.a:
                    return b
                else:
                    return -1

        def test_model_with_args(a, m):
            assert not isinstance(a, bool)
            return m.fn(a)

        make_global(ClassWithArgs, test_model_with_args)

        # Profiling uses an instance built with a=False (fn returns -1) ...
        user_class = ClassWithArgs(False)
        scripted_fn = torch.jit.script(
            test_model_with_args,
            example_inputs=[
                (
                    10,
                    user_class,
                ),
                (
                    10.9,
                    user_class,
                ),
            ],
        )
        # ... while the check uses a fresh instance built with a=True.
        self.assertEqual(
            scripted_fn(
                100,
                ClassWithArgs(True),
            ),
            test_model_with_args(100, ClassWithArgs(True)),
        )

    def test_nn_parameter_as_arg(self):
        """Script a module that passes its nn.Parameter plus a profiled int to a helper."""

        class TestNNParameter(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inp = torch.nn.Parameter(torch.ones(2, 3))

            def add_nn_parameter_with_int(self, x, y):
                return torch.add(x, y)

            def forward(self, y):
                # x is the module's 2x3 Parameter; y comes from the caller.
                return self.add_nn_parameter_with_int(self.inp, y)

        make_global(TestNNParameter)
        pdt_model = TestNNParameter()
        scripted_fn = torch.jit.script(
            pdt_model,
            example_inputs={
                pdt_model: [
                    (10,),
                ],
            },
        )
        self.assertEqual(scripted_fn(20), pdt_model(20))

    def test_fx_tracing_with_typing(self):
        """Script a module returning a typed NamedTuple built from a profiled argument."""

        class FXModelOutput(NamedTuple):
            result: List[int]

        class FXModel(torch.nn.Module):
            def forward(self, a) -> FXModelOutput:
                result = FXModelOutput(result=a)
                return result

        make_global(FXModel, FXModelOutput)
        pdt_model = FXModel()
        scripted_fn = torch.jit.script(
            pdt_model,
            example_inputs={
                pdt_model: [
                    (
                        [
                            10,
                            20,
                        ],
                    ),
                ],
            },
        )
        self.assertEqual(scripted_fn([20]), pdt_model([20]))

    def test_nonetype_as_optional_of_type(self):
        """Script a function profiled with ``None`` plus one concrete type per case."""

        def test_none(a) -> Any:
            if a is None:
                return 0
            else:
                return a + torch.ones(1)

        make_global(test_none)

        # Each case pairs a None sample with one concrete type (float, int,
        # Tensor), presumably so PDT infers an Optional of that type.
        scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)])
        self.assertEqual(
            scripted_fn(
                30.9,
            ),
            test_none(
                30.9,
            ),
        )

        scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)])
        self.assertEqual(
            scripted_fn(
                2,
            ),
            test_none(
                2,
            ),
        )

        scripted_fn = torch.jit.script(
            test_none, example_inputs=[(None,), (torch.Tensor(1),)]
        )
        self.assertEqual(
            scripted_fn(
                torch.ones(1),
            ),
            test_none(
                torch.ones(1),
            ),
        )