# Owner(s): ["oncall: jit"]
#
# Tests for TorchScript profile-directed typing (PDT): unannotated functions,
# nn.Modules and plain classes are passed to `torch.jit.script` together with
# `example_inputs`, from which their argument types are expected to be inferred.

import os
import sys
from typing import Any, Dict, List, NamedTuple, Optional, Tuple  # noqa: F401

import torch
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
from torch.testing._internal.common_utils import NoTest
from torch.testing._internal.jit_utils import JitTestCase, make_global


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# PDT depends on the optional `monkeytype` package. When it is missing, the
# base class is swapped for `NoTest` so that, as the message says, the tests
# in this file are skipped.
if not _IS_MONKEYTYPE_INSTALLED:
    print(
        "monkeytype is not installed. Skipping tests for Profile-Directed Typing",
        file=sys.stderr,
    )
    JitTestCase = NoTest  # type: ignore[misc, assignment] # noqa: F811

# This file is only meant to be collected through test/test_jit.py.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestPDT(JitTestCase):
    """
    A suite of tests for profile directed typing in TorchScript.

    Each test defines unannotated code, registers it with ``make_global``
    (presumably so the TorchScript frontend can resolve the locally defined
    names -- see jit_utils), scripts it with ``example_inputs``, and checks the
    scripted result against eager execution.
    """

    def test_nn_module(self):
        """Script an nn.Module whose ``forward`` branches on the type of ``x``."""
        class TestPDTModel(torch.nn.Module):
            def forward(self, x) -> Any:
                if isinstance(x, int):
                    return x + 1
                elif isinstance(x, float):
                    return x - 1
                else:
                    return x

        make_global(TestPDTModel)
        pdt_model = TestPDTModel()
        # One single-argument call per example: int, float and bool.
        inp: List[Tuple[Any, ...]] = [
            (20,),
            (2.7,),
            (False,),
        ]
        scripted_pdt_model = torch.jit.script(
            pdt_model, example_inputs={pdt_model: inp}
        )
        self.assertEqual(scripted_pdt_model(50), pdt_model(50))
        self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
        # NOTE(review): assertTrue's second argument is the failure message, so
        # this only checks that the scripted result is truthy; it is not compared
        # with the eager result. In eager mode isinstance(True, int) holds, so
        # pdt_model(True) returns True + 1 == 2. Confirm the intended scripted
        # value before tightening this to assertEqual.
        self.assertTrue(scripted_pdt_model(True), pdt_model(True))

    def test_nested_nn_module_class(self):
        """Script a wrapper module whose ``forward`` delegates to an inner module.

        Only the wrapper is given example inputs; the inner module is reached
        through the wrapper's calls.
        """
        class NestedPDTInner(torch.nn.Module):
            def forward(self, x):
                if isinstance(x, int):
                    return x * 10
                return x

        class NestedModulePDTWrapper(torch.nn.Module):
            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, x):
                return self.inner(x)

        make_global(NestedPDTInner, NestedModulePDTWrapper)
        inner_pdt_model = NestedPDTInner()
        wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
        inp: List[Tuple[Any, ...]] = [(20,), (False,)]
        scripted_pdt_model = torch.jit.script(
            wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}
        )
        self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
        # NOTE(review): 1.9 (float) is not among the example input types
        # (int, bool) above -- confirm which type PDT infers for ``x`` here.
        self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
        # NOTE(review): truthiness check only (second argument is the message);
        # eagerly, True passes isinstance(x, int) and yields True * 10 == 10.
        self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))

    def test_nested_nn_module_class_with_args(self):
        """Script an outer module that calls an inner module with an extra argument.

        Unlike ``test_nested_nn_module_class``, both the inner and the outer
        module get their own example inputs.
        """
        class NestedModulePDTInner(torch.nn.Module):
            def forward(self, x, y):
                if isinstance(x, int):
                    return x * 10 + y
                return x

        class NestedModulePDTOuter(torch.nn.Module):
            def __init__(self, inner):
                super().__init__()
                self.inner = inner

            def forward(self, x):
                return self.inner(x, 20)

        make_global(NestedModulePDTInner, NestedModulePDTOuter)
        inner_pdt_model = NestedModulePDTInner()
        outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
        # Two-argument (x, y) examples for the inner module.
        inner_input: List[Tuple[Any, ...]] = [
            (10, 10),
            (1.9, 20),
        ]
        outer_input: List[Tuple[Any, ...]] = [(20,), (False,)]
        scripted_pdt_model = torch.jit.script(
            outer_pdt_model,
            example_inputs={
                inner_pdt_model: inner_input,
                outer_pdt_model: outer_input,
            },
        )
        self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
        self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
        # NOTE(review): truthiness check only (second argument is the message);
        # eagerly, outer(True) -> inner(True, 20) -> True * 10 + 20 == 30.
        self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))

    def test_nested_function_in_forward(self):
        """Script a module whose ``forward`` calls a helper method, ``fun``.

        Only ``forward``'s example inputs are supplied; ``fun`` is reached
        through it.
        """
        class NestedFunctionInForward(torch.nn.Module):
            def forward(self, x):
                return self.fun(x) + 10

            def fun(self, x):
                # bool is tested before int because isinstance(True, int) is
                # also True in Python.
                if isinstance(x, bool):
                    return 0
                elif isinstance(x, int):
                    return x + 1
                return 0

        make_global(NestedFunctionInForward)
        pdt_model = NestedFunctionInForward()
        inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
        scripted_pdt_model = torch.jit.script(
            pdt_model, example_inputs={pdt_model: inp}
        )
        self.assertEqual(scripted_pdt_model(30), pdt_model(30))
        self.assertEqual(scripted_pdt_model(True), pdt_model(True))

    def test_nn_module_with_export_function(self):
        """Profile a ``@torch.jit.export`` method rather than ``forward``.

        ``example_inputs`` is keyed on the bound method ``pdt_model.fn``.
        """
        class TestModelWithExport(torch.nn.Module):
            @torch.jit.export
            def fn(self, x, y) -> Any:
                # Reject calls where both arguments are bools.
                assert not (isinstance(x, bool) and isinstance(y, bool))
                if isinstance(x, int) and isinstance(y, int):
                    return x + y
                elif isinstance(x, float) and isinstance(y, float):
                    return x - y
                else:
                    return -1

        make_global(TestModelWithExport)
        pdt_model = TestModelWithExport()
        # (x, y) examples: an int pair and a float pair.
        inp: List[Tuple[Any, ...]] = [
            (
                20,
                10,
            ),
            (
                2.7,
                8.9,
            ),
        ]
        scripted_pdt_model = torch.jit.script(
            pdt_model, example_inputs={pdt_model.fn: inp}
        )
        self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
        self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
        # NOTE(review): truthiness check only (second argument is the message).
        # Eagerly, (Tensor, int) matches neither branch and returns -1, which is
        # truthy, so this mainly checks that the scripted call succeeds.
        self.assertTrue(
            scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)
        )

    def test_class_methods(self):
        """Script a plain (non-Module) class.

        Example inputs are keyed on a bound method of an eager instance; the
        class itself is scripted and the scripted class is then instantiated.
        """
        class PDTModel:
            def test_sum(self, a):
                return sum(a)

        make_global(PDTModel)
        pdt_model = PDTModel()
        # A single call whose only argument is a list of ints.
        inp: List[Tuple[Any, ...]] = [
            (
                [
                    10,
                    20,
                ],
            ),
        ]
        scripted_pdt_model = torch.jit.script(
            PDTModel, example_inputs={pdt_model.test_sum: inp}
        )
        script_model = scripted_pdt_model()
        self.assertEqual(
            script_model.test_sum(
                [
                    10,
                    20,
                    30,
                ],
            ),
            pdt_model.test_sum(
                [
                    10,
                    20,
                    30,
                ],
            ),
        )

    def test_class_with_multiple_methods(self):
        """Profile two methods of one scripted class, each with its own examples."""
        class PDTModelWithManyMethods:
            def test_list_to_dict(self, a):
                new_dictionary: Dict[float, bool] = {}
                for element in a:
                    new_dictionary[element] = True
                return new_dictionary

            def test_substring(self, a, b):
                return b in a

        make_global(PDTModelWithManyMethods)
        pdt_model = PDTModelWithManyMethods()
        # test_list_to_dict gets a list of floats.
        list_inp: List[Tuple[Any, ...]] = [
            (
                [
                    1.2,
                    2.3,
                ],
            ),
        ]
        # test_substring gets two strings.
        str_inp: List[Tuple[Any, ...]] = [
            (
                "abc",
                "b",
            ),
        ]
        scripted_pdt_model = torch.jit.script(
            PDTModelWithManyMethods,
            example_inputs={
                pdt_model.test_list_to_dict: list_inp,
                pdt_model.test_substring: str_inp,
            },
        )
        script_model = scripted_pdt_model()
        self.assertEqual(
            script_model.test_list_to_dict(
                [
                    1.1,
                    2.2,
                    3.3,
                ],
            ),
            pdt_model.test_list_to_dict(
                [
                    1.1,
                    2.2,
                    3.3,
                ],
            ),
        )
        # Both a matching and a non-matching substring are checked.
        self.assertEqual(
            script_model.test_substring(
                "helloworld",
                "world",
            ),
            pdt_model.test_substring(
                "helloworld",
                "world",
            ),
        )
        self.assertEqual(
            script_model.test_substring(
                "helloworld",
                "def",
            ),
            pdt_model.test_substring(
                "helloworld",
                "def",
            ),
        )

    def test_multiple_class_with_same_method(self):
        """Script two classes that define a same-named method with different argument types.

        Each class is scripted separately with its own examples (presumably to
        check that profiles for identically named methods do not get mixed up).
        """
        class PDTModelOne:
            def test_find(self, a, b):
                return b in a.keys()

        class PDTModelTwo:
            def test_find(self, a, b):
                return b in a

        make_global(PDTModelOne, PDTModelTwo)
        pdt_model_one = PDTModelOne()
        pdt_model_two = PDTModelTwo()
        # PDTModelOne.test_find: (dict of float -> bool, float key).
        dict_inp: List[Tuple[Any, ...]] = [
            (
                {
                    1.2: True,
                    2.3: False,
                },
                1.2,
            ),
        ]
        # PDTModelTwo.test_find: (list of str, str).
        list_inp: List[Tuple[Any, ...]] = [
            (
                [
                    "abc",
                    "b",
                ],
                "c",
            ),
        ]
        scripted_pdt_model_one = torch.jit.script(
            PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}
        )
        scripted_pdt_model_two = torch.jit.script(
            PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}
        )

        script_model_one, script_model_two = (
            scripted_pdt_model_one(),
            scripted_pdt_model_two(),
        )
        self.assertEqual(
            script_model_one.test_find(
                {
                    1.1: True,
                    2.2: True,
                    3.3: False,
                },
                4.4,
            ),
            pdt_model_one.test_find(
                {
                    1.1: True,
                    2.2: True,
                    3.3: False,
                },
                4.4,
            ),
        )
        self.assertEqual(
            script_model_two.test_find(
                [
                    "hello",
                    "world",
                ],
                "world",
            ),
            pdt_model_two.test_find(
                [
                    "hello",
                    "world",
                ],
                "world",
            ),
        )

    def test_pdt(self):
        """Script free functions; ``example_inputs`` is a list of argument tuples.

        Covers int, float, Tensor, bool and str arguments.
        """
        def test_sum(a, b):
            return a + b

        make_global(test_sum)
        scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
        self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))

        def test_sub(a, b):
            return a - b

        make_global(test_sub)
        scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
        self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))

        def test_mul(a, b):
            return a * b

        make_global(test_mul)
        scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
        self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))

        def test_args_complex(real, img):
            return torch.complex(real, img)

        make_global(test_args_complex)
        scripted_fn_complex = torch.jit.script(
            test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]
        )
        arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
        self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))

        def test_bool(a):
            if a:
                return -1
            else:
                return 0

        make_global(test_bool)
        scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
        self.assertEqual(scripted_fn_bool(True), test_bool(True))

        def test_str(a):
            if a == "":
                return False
            else:
                return True

        make_global(test_str)
        scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
        self.assertEqual(scripted_fn_str("abc"), test_str("abc"))

    def test_pdt_list_and_tuple(self):
        """Script one function several times, each profiled with a different container.

        Profiles use lists and tuples of float, bool and int elements.
        """
        def test_list_and_tuple(a):
            return sum(a)

        make_global(test_list_and_tuple)

        # --- list inputs ---
        scripted_fn_float_list_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[([4.9, 8.9],)]
        )
        self.assertEqual(
            scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])
        )

        scripted_fn_bool_list_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[([True, False, True],)]
        )
        self.assertEqual(
            scripted_fn_bool_list_input([True, True, True]),
            test_list_and_tuple([True, True, True]),
        )

        scripted_fn_int_list_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[([3, 4, 5],)]
        )
        self.assertEqual(
            scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])
        )

        # --- tuple inputs ---
        scripted_fn_float_tuple_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[((4.9, 8.9),)]
        )
        self.assertEqual(
            scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))
        )

        scripted_fn_bool_tuple_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[((True, False, True),)]
        )
        self.assertEqual(
            scripted_fn_bool_tuple_input((True, True, True)),
            test_list_and_tuple((True, True, True)),
        )

        scripted_fn_int_tuple_input = torch.jit.script(
            test_list_and_tuple, example_inputs=[((3, 4, 5),)]
        )
        self.assertEqual(
            scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))
        )

461*da0073e9SAndroid Build Coastguard Worker def test_nested_list_and_tuple(self): 462*da0073e9SAndroid Build Coastguard Worker def test_nested_list(inp): 463*da0073e9SAndroid Build Coastguard Worker return [sum(v) for v in inp] 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker def test_nested_tuple(inp): 466*da0073e9SAndroid Build Coastguard Worker ans = 0.0 467*da0073e9SAndroid Build Coastguard Worker for tup in inp: 468*da0073e9SAndroid Build Coastguard Worker for val in tup: 469*da0073e9SAndroid Build Coastguard Worker if val > 0: 470*da0073e9SAndroid Build Coastguard Worker ans *= val 471*da0073e9SAndroid Build Coastguard Worker return ans 472*da0073e9SAndroid Build Coastguard Worker 473*da0073e9SAndroid Build Coastguard Worker make_global(test_nested_list, test_nested_tuple) 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker list_inp = [ 476*da0073e9SAndroid Build Coastguard Worker [ 477*da0073e9SAndroid Build Coastguard Worker 1, 478*da0073e9SAndroid Build Coastguard Worker 2, 479*da0073e9SAndroid Build Coastguard Worker 3, 480*da0073e9SAndroid Build Coastguard Worker ], 481*da0073e9SAndroid Build Coastguard Worker [ 482*da0073e9SAndroid Build Coastguard Worker 5, 483*da0073e9SAndroid Build Coastguard Worker 6, 484*da0073e9SAndroid Build Coastguard Worker 7, 485*da0073e9SAndroid Build Coastguard Worker ], 486*da0073e9SAndroid Build Coastguard Worker ] 487*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 488*da0073e9SAndroid Build Coastguard Worker test_nested_list, 489*da0073e9SAndroid Build Coastguard Worker example_inputs=[ 490*da0073e9SAndroid Build Coastguard Worker (list_inp,), 491*da0073e9SAndroid Build Coastguard Worker ], 492*da0073e9SAndroid Build Coastguard Worker ) 493*da0073e9SAndroid Build Coastguard Worker inp = [ 494*da0073e9SAndroid Build Coastguard Worker [ 495*da0073e9SAndroid Build Coastguard Worker 0, 496*da0073e9SAndroid Build 
Coastguard Worker 4, 497*da0073e9SAndroid Build Coastguard Worker 7, 498*da0073e9SAndroid Build Coastguard Worker ], 499*da0073e9SAndroid Build Coastguard Worker [ 500*da0073e9SAndroid Build Coastguard Worker 8, 501*da0073e9SAndroid Build Coastguard Worker 11, 502*da0073e9SAndroid Build Coastguard Worker ], 503*da0073e9SAndroid Build Coastguard Worker [ 504*da0073e9SAndroid Build Coastguard Worker 6, 505*da0073e9SAndroid Build Coastguard Worker -1, 506*da0073e9SAndroid Build Coastguard Worker -20, 507*da0073e9SAndroid Build Coastguard Worker ], 508*da0073e9SAndroid Build Coastguard Worker ] 509*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 510*da0073e9SAndroid Build Coastguard Worker scripted_fn( 511*da0073e9SAndroid Build Coastguard Worker inp, 512*da0073e9SAndroid Build Coastguard Worker ), 513*da0073e9SAndroid Build Coastguard Worker test_nested_list( 514*da0073e9SAndroid Build Coastguard Worker inp, 515*da0073e9SAndroid Build Coastguard Worker ), 516*da0073e9SAndroid Build Coastguard Worker ) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker list_inp = ( 519*da0073e9SAndroid Build Coastguard Worker [ 520*da0073e9SAndroid Build Coastguard Worker 1, 521*da0073e9SAndroid Build Coastguard Worker 2, 522*da0073e9SAndroid Build Coastguard Worker 3, 523*da0073e9SAndroid Build Coastguard Worker ], 524*da0073e9SAndroid Build Coastguard Worker [ 525*da0073e9SAndroid Build Coastguard Worker 5, 526*da0073e9SAndroid Build Coastguard Worker 6, 527*da0073e9SAndroid Build Coastguard Worker 7, 528*da0073e9SAndroid Build Coastguard Worker ], 529*da0073e9SAndroid Build Coastguard Worker ) 530*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 531*da0073e9SAndroid Build Coastguard Worker test_nested_list, 532*da0073e9SAndroid Build Coastguard Worker example_inputs=[ 533*da0073e9SAndroid Build Coastguard Worker (list_inp,), 534*da0073e9SAndroid Build Coastguard Worker ], 535*da0073e9SAndroid Build Coastguard 
Worker ) 536*da0073e9SAndroid Build Coastguard Worker inp = ( 537*da0073e9SAndroid Build Coastguard Worker [ 538*da0073e9SAndroid Build Coastguard Worker 0, 539*da0073e9SAndroid Build Coastguard Worker 4, 540*da0073e9SAndroid Build Coastguard Worker 7, 541*da0073e9SAndroid Build Coastguard Worker ], 542*da0073e9SAndroid Build Coastguard Worker [ 543*da0073e9SAndroid Build Coastguard Worker 8, 544*da0073e9SAndroid Build Coastguard Worker 11, 545*da0073e9SAndroid Build Coastguard Worker ], 546*da0073e9SAndroid Build Coastguard Worker [ 547*da0073e9SAndroid Build Coastguard Worker 6, 548*da0073e9SAndroid Build Coastguard Worker -1, 549*da0073e9SAndroid Build Coastguard Worker -20, 550*da0073e9SAndroid Build Coastguard Worker ], 551*da0073e9SAndroid Build Coastguard Worker ) 552*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 553*da0073e9SAndroid Build Coastguard Worker scripted_fn( 554*da0073e9SAndroid Build Coastguard Worker inp, 555*da0073e9SAndroid Build Coastguard Worker ), 556*da0073e9SAndroid Build Coastguard Worker test_nested_list( 557*da0073e9SAndroid Build Coastguard Worker inp, 558*da0073e9SAndroid Build Coastguard Worker ), 559*da0073e9SAndroid Build Coastguard Worker ) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker tup_inp = [ 562*da0073e9SAndroid Build Coastguard Worker ( 563*da0073e9SAndroid Build Coastguard Worker 1.0, 564*da0073e9SAndroid Build Coastguard Worker 2.6, 565*da0073e9SAndroid Build Coastguard Worker 3.7, 566*da0073e9SAndroid Build Coastguard Worker ), 567*da0073e9SAndroid Build Coastguard Worker ( 568*da0073e9SAndroid Build Coastguard Worker 5.7, 569*da0073e9SAndroid Build Coastguard Worker 6.1, 570*da0073e9SAndroid Build Coastguard Worker 1.7, 571*da0073e9SAndroid Build Coastguard Worker ), 572*da0073e9SAndroid Build Coastguard Worker ] 573*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 574*da0073e9SAndroid Build Coastguard Worker test_nested_tuple, 
575*da0073e9SAndroid Build Coastguard Worker example_inputs=[ 576*da0073e9SAndroid Build Coastguard Worker (tup_inp,), 577*da0073e9SAndroid Build Coastguard Worker ], 578*da0073e9SAndroid Build Coastguard Worker ) 579*da0073e9SAndroid Build Coastguard Worker inp = [ 580*da0073e9SAndroid Build Coastguard Worker ( 581*da0073e9SAndroid Build Coastguard Worker 1.0, 582*da0073e9SAndroid Build Coastguard Worker 4.1, 583*da0073e9SAndroid Build Coastguard Worker 7.4, 584*da0073e9SAndroid Build Coastguard Worker ), 585*da0073e9SAndroid Build Coastguard Worker ( 586*da0073e9SAndroid Build Coastguard Worker 4.8, 587*da0073e9SAndroid Build Coastguard Worker 1.1, 588*da0073e9SAndroid Build Coastguard Worker -1.2, 589*da0073e9SAndroid Build Coastguard Worker ), 590*da0073e9SAndroid Build Coastguard Worker ( 591*da0073e9SAndroid Build Coastguard Worker 6.3, 592*da0073e9SAndroid Build Coastguard Worker -1.3, 593*da0073e9SAndroid Build Coastguard Worker -2.0, 594*da0073e9SAndroid Build Coastguard Worker ), 595*da0073e9SAndroid Build Coastguard Worker ] 596*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 597*da0073e9SAndroid Build Coastguard Worker scripted_fn( 598*da0073e9SAndroid Build Coastguard Worker inp, 599*da0073e9SAndroid Build Coastguard Worker ), 600*da0073e9SAndroid Build Coastguard Worker test_nested_tuple( 601*da0073e9SAndroid Build Coastguard Worker inp, 602*da0073e9SAndroid Build Coastguard Worker ), 603*da0073e9SAndroid Build Coastguard Worker ) 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker tup_inp = ( 606*da0073e9SAndroid Build Coastguard Worker ( 607*da0073e9SAndroid Build Coastguard Worker True, 608*da0073e9SAndroid Build Coastguard Worker False, 609*da0073e9SAndroid Build Coastguard Worker True, 610*da0073e9SAndroid Build Coastguard Worker ), 611*da0073e9SAndroid Build Coastguard Worker ( 612*da0073e9SAndroid Build Coastguard Worker False, 613*da0073e9SAndroid Build Coastguard Worker False, 
614*da0073e9SAndroid Build Coastguard Worker False, 615*da0073e9SAndroid Build Coastguard Worker ), 616*da0073e9SAndroid Build Coastguard Worker ) 617*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 618*da0073e9SAndroid Build Coastguard Worker test_nested_tuple, 619*da0073e9SAndroid Build Coastguard Worker example_inputs=[ 620*da0073e9SAndroid Build Coastguard Worker (tup_inp,), 621*da0073e9SAndroid Build Coastguard Worker ], 622*da0073e9SAndroid Build Coastguard Worker ) 623*da0073e9SAndroid Build Coastguard Worker inp = ( 624*da0073e9SAndroid Build Coastguard Worker ( 625*da0073e9SAndroid Build Coastguard Worker True, 626*da0073e9SAndroid Build Coastguard Worker True, 627*da0073e9SAndroid Build Coastguard Worker True, 628*da0073e9SAndroid Build Coastguard Worker ), 629*da0073e9SAndroid Build Coastguard Worker ( 630*da0073e9SAndroid Build Coastguard Worker False, 631*da0073e9SAndroid Build Coastguard Worker False, 632*da0073e9SAndroid Build Coastguard Worker True, 633*da0073e9SAndroid Build Coastguard Worker ), 634*da0073e9SAndroid Build Coastguard Worker ) 635*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 636*da0073e9SAndroid Build Coastguard Worker scripted_fn( 637*da0073e9SAndroid Build Coastguard Worker inp, 638*da0073e9SAndroid Build Coastguard Worker ), 639*da0073e9SAndroid Build Coastguard Worker test_nested_tuple( 640*da0073e9SAndroid Build Coastguard Worker inp, 641*da0073e9SAndroid Build Coastguard Worker ), 642*da0073e9SAndroid Build Coastguard Worker ) 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Worker def test_pdt_dict(self): 645*da0073e9SAndroid Build Coastguard Worker def test_dict(a): 646*da0073e9SAndroid Build Coastguard Worker return a["foo"] 647*da0073e9SAndroid Build Coastguard Worker 648*da0073e9SAndroid Build Coastguard Worker def test_dict_int_list(a): 649*da0073e9SAndroid Build Coastguard Worker return a[1] 650*da0073e9SAndroid Build Coastguard Worker 
651*da0073e9SAndroid Build Coastguard Worker make_global(test_dict, test_dict_int_list) 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker str_bool_inp = {"foo": True, "bar": False} 654*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)]) 655*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 656*da0073e9SAndroid Build Coastguard Worker scripted_fn( 657*da0073e9SAndroid Build Coastguard Worker {"foo": False, "bar": True}, 658*da0073e9SAndroid Build Coastguard Worker ), 659*da0073e9SAndroid Build Coastguard Worker test_dict( 660*da0073e9SAndroid Build Coastguard Worker {"foo": False, "bar": True}, 661*da0073e9SAndroid Build Coastguard Worker ), 662*da0073e9SAndroid Build Coastguard Worker ) 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker str_list_inp = {0: [True, False], 1: [False, True]} 665*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 666*da0073e9SAndroid Build Coastguard Worker test_dict_int_list, example_inputs=[(str_list_inp,)] 667*da0073e9SAndroid Build Coastguard Worker ) 668*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 669*da0073e9SAndroid Build Coastguard Worker scripted_fn( 670*da0073e9SAndroid Build Coastguard Worker {0: [False, False], 1: [True, True]}, 671*da0073e9SAndroid Build Coastguard Worker ), 672*da0073e9SAndroid Build Coastguard Worker test_dict_int_list( 673*da0073e9SAndroid Build Coastguard Worker {0: [False, False], 1: [True, True]}, 674*da0073e9SAndroid Build Coastguard Worker ), 675*da0073e9SAndroid Build Coastguard Worker ) 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker def test_any(self): 678*da0073e9SAndroid Build Coastguard Worker def test_multiple_types(a): 679*da0073e9SAndroid Build Coastguard Worker assert not isinstance(a, bool) 680*da0073e9SAndroid Build Coastguard Worker return a 
681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker def test_multiple_type_refinement(a): 683*da0073e9SAndroid Build Coastguard Worker if isinstance(a, bool): 684*da0073e9SAndroid Build Coastguard Worker return 1 685*da0073e9SAndroid Build Coastguard Worker elif isinstance(a, int): 686*da0073e9SAndroid Build Coastguard Worker return 1 + a 687*da0073e9SAndroid Build Coastguard Worker elif isinstance(a, float): 688*da0073e9SAndroid Build Coastguard Worker return 1 + int(a) 689*da0073e9SAndroid Build Coastguard Worker else: 690*da0073e9SAndroid Build Coastguard Worker return -1 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker make_global(test_multiple_types, test_multiple_type_refinement) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 695*da0073e9SAndroid Build Coastguard Worker test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)] 696*da0073e9SAndroid Build Coastguard Worker ) 697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn(10), test_multiple_types(10)) 698*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn("def"), test_multiple_types("def")) 699*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999)) 700*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14])) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 703*da0073e9SAndroid Build Coastguard Worker test_multiple_type_refinement, 704*da0073e9SAndroid Build Coastguard Worker example_inputs=[ 705*da0073e9SAndroid Build Coastguard Worker (1,), 706*da0073e9SAndroid Build Coastguard Worker ("abc",), 707*da0073e9SAndroid Build Coastguard Worker (8.9,), 708*da0073e9SAndroid Build Coastguard Worker ([3, 4, 5],), 
709*da0073e9SAndroid Build Coastguard Worker (True,), 710*da0073e9SAndroid Build Coastguard Worker ({"a": True},), 711*da0073e9SAndroid Build Coastguard Worker ], 712*da0073e9SAndroid Build Coastguard Worker ) 713*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10)) 714*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def")) 715*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999)) 716*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 717*da0073e9SAndroid Build Coastguard Worker scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14]) 718*da0073e9SAndroid Build Coastguard Worker ) 719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False)) 720*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 721*da0073e9SAndroid Build Coastguard Worker scripted_fn({"abc": True, "def": False}), 722*da0073e9SAndroid Build Coastguard Worker test_multiple_type_refinement({"abc": True, "def": False}), 723*da0073e9SAndroid Build Coastguard Worker ) 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker def test_class_as_profiled_types(self): 726*da0073e9SAndroid Build Coastguard Worker class UserDefinedClass: 727*da0073e9SAndroid Build Coastguard Worker def fn(self, b) -> Any: 728*da0073e9SAndroid Build Coastguard Worker assert b is not None 729*da0073e9SAndroid Build Coastguard Worker if isinstance(b, int): 730*da0073e9SAndroid Build Coastguard Worker return b if b > 0 else -1 731*da0073e9SAndroid Build Coastguard Worker elif isinstance(b, float): 732*da0073e9SAndroid Build Coastguard Worker return b if b > 0.0 else -1.0 733*da0073e9SAndroid Build Coastguard Worker return 0 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Worker def test_model(a, 
m): 736*da0073e9SAndroid Build Coastguard Worker assert not isinstance(a, bool) 737*da0073e9SAndroid Build Coastguard Worker return m.fn(a) 738*da0073e9SAndroid Build Coastguard Worker 739*da0073e9SAndroid Build Coastguard Worker make_global(UserDefinedClass, test_model) 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker user_class = UserDefinedClass() 742*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 743*da0073e9SAndroid Build Coastguard Worker test_model, 744*da0073e9SAndroid Build Coastguard Worker example_inputs=[ 745*da0073e9SAndroid Build Coastguard Worker ( 746*da0073e9SAndroid Build Coastguard Worker 10, 747*da0073e9SAndroid Build Coastguard Worker user_class, 748*da0073e9SAndroid Build Coastguard Worker ), 749*da0073e9SAndroid Build Coastguard Worker ( 750*da0073e9SAndroid Build Coastguard Worker 10.9, 751*da0073e9SAndroid Build Coastguard Worker user_class, 752*da0073e9SAndroid Build Coastguard Worker ), 753*da0073e9SAndroid Build Coastguard Worker ], 754*da0073e9SAndroid Build Coastguard Worker ) 755*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 756*da0073e9SAndroid Build Coastguard Worker scripted_fn( 757*da0073e9SAndroid Build Coastguard Worker 100, 758*da0073e9SAndroid Build Coastguard Worker user_class, 759*da0073e9SAndroid Build Coastguard Worker ), 760*da0073e9SAndroid Build Coastguard Worker test_model(100, user_class), 761*da0073e9SAndroid Build Coastguard Worker ) 762*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 763*da0073e9SAndroid Build Coastguard Worker scripted_fn( 764*da0073e9SAndroid Build Coastguard Worker 1.9, 765*da0073e9SAndroid Build Coastguard Worker user_class, 766*da0073e9SAndroid Build Coastguard Worker ), 767*da0073e9SAndroid Build Coastguard Worker test_model(1.9, user_class), 768*da0073e9SAndroid Build Coastguard Worker ) 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker def 
test_class_with_args_as_profiled_types(self): 771*da0073e9SAndroid Build Coastguard Worker class ClassWithArgs: 772*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: bool): 773*da0073e9SAndroid Build Coastguard Worker self.a = a 774*da0073e9SAndroid Build Coastguard Worker 775*da0073e9SAndroid Build Coastguard Worker def fn(self, b): 776*da0073e9SAndroid Build Coastguard Worker if self.a: 777*da0073e9SAndroid Build Coastguard Worker return b 778*da0073e9SAndroid Build Coastguard Worker else: 779*da0073e9SAndroid Build Coastguard Worker return -1 780*da0073e9SAndroid Build Coastguard Worker 781*da0073e9SAndroid Build Coastguard Worker def test_model_with_args(a, m): 782*da0073e9SAndroid Build Coastguard Worker assert not isinstance(a, bool) 783*da0073e9SAndroid Build Coastguard Worker return m.fn(a) 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker make_global(ClassWithArgs, test_model_with_args) 786*da0073e9SAndroid Build Coastguard Worker 787*da0073e9SAndroid Build Coastguard Worker user_class = ClassWithArgs(False) 788*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 789*da0073e9SAndroid Build Coastguard Worker test_model_with_args, 790*da0073e9SAndroid Build Coastguard Worker example_inputs=[ 791*da0073e9SAndroid Build Coastguard Worker ( 792*da0073e9SAndroid Build Coastguard Worker 10, 793*da0073e9SAndroid Build Coastguard Worker user_class, 794*da0073e9SAndroid Build Coastguard Worker ), 795*da0073e9SAndroid Build Coastguard Worker ( 796*da0073e9SAndroid Build Coastguard Worker 10.9, 797*da0073e9SAndroid Build Coastguard Worker user_class, 798*da0073e9SAndroid Build Coastguard Worker ), 799*da0073e9SAndroid Build Coastguard Worker ], 800*da0073e9SAndroid Build Coastguard Worker ) 801*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 802*da0073e9SAndroid Build Coastguard Worker scripted_fn( 803*da0073e9SAndroid Build Coastguard Worker 100, 804*da0073e9SAndroid Build Coastguard 
Worker ClassWithArgs(True), 805*da0073e9SAndroid Build Coastguard Worker ), 806*da0073e9SAndroid Build Coastguard Worker test_model_with_args(100, ClassWithArgs(True)), 807*da0073e9SAndroid Build Coastguard Worker ) 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Worker def test_nn_parameter_as_arg(self): 810*da0073e9SAndroid Build Coastguard Worker class TestNNParameter(torch.nn.Module): 811*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 812*da0073e9SAndroid Build Coastguard Worker super().__init__() 813*da0073e9SAndroid Build Coastguard Worker self.inp = torch.nn.Parameter(torch.ones(2, 3)) 814*da0073e9SAndroid Build Coastguard Worker 815*da0073e9SAndroid Build Coastguard Worker def add_nn_parameter_with_int(self, x, y): 816*da0073e9SAndroid Build Coastguard Worker return torch.add(x, y) 817*da0073e9SAndroid Build Coastguard Worker 818*da0073e9SAndroid Build Coastguard Worker def forward(self, y): 819*da0073e9SAndroid Build Coastguard Worker return self.add_nn_parameter_with_int(self.inp, y) 820*da0073e9SAndroid Build Coastguard Worker 821*da0073e9SAndroid Build Coastguard Worker make_global(TestNNParameter) 822*da0073e9SAndroid Build Coastguard Worker pdt_model = TestNNParameter() 823*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 824*da0073e9SAndroid Build Coastguard Worker pdt_model, 825*da0073e9SAndroid Build Coastguard Worker example_inputs={ 826*da0073e9SAndroid Build Coastguard Worker pdt_model: [ 827*da0073e9SAndroid Build Coastguard Worker (10,), 828*da0073e9SAndroid Build Coastguard Worker ], 829*da0073e9SAndroid Build Coastguard Worker }, 830*da0073e9SAndroid Build Coastguard Worker ) 831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn(20), pdt_model(20)) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker def test_fx_tracing_with_typing(self): 834*da0073e9SAndroid Build Coastguard Worker class 
FXModelOutput(NamedTuple): 835*da0073e9SAndroid Build Coastguard Worker result: List[int] 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker class FXModel(torch.nn.Module): 838*da0073e9SAndroid Build Coastguard Worker def forward(self, a) -> FXModelOutput: 839*da0073e9SAndroid Build Coastguard Worker result = FXModelOutput(result=a) 840*da0073e9SAndroid Build Coastguard Worker return result 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker make_global(FXModel, FXModelOutput) 843*da0073e9SAndroid Build Coastguard Worker pdt_model = FXModel() 844*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 845*da0073e9SAndroid Build Coastguard Worker pdt_model, 846*da0073e9SAndroid Build Coastguard Worker example_inputs={ 847*da0073e9SAndroid Build Coastguard Worker pdt_model: [ 848*da0073e9SAndroid Build Coastguard Worker ( 849*da0073e9SAndroid Build Coastguard Worker [ 850*da0073e9SAndroid Build Coastguard Worker 10, 851*da0073e9SAndroid Build Coastguard Worker 20, 852*da0073e9SAndroid Build Coastguard Worker ], 853*da0073e9SAndroid Build Coastguard Worker ), 854*da0073e9SAndroid Build Coastguard Worker ], 855*da0073e9SAndroid Build Coastguard Worker }, 856*da0073e9SAndroid Build Coastguard Worker ) 857*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn([20]), pdt_model([20])) 858*da0073e9SAndroid Build Coastguard Worker 859*da0073e9SAndroid Build Coastguard Worker def test_nonetype_as_optional_of_type(self): 860*da0073e9SAndroid Build Coastguard Worker def test_none(a) -> Any: 861*da0073e9SAndroid Build Coastguard Worker if a is None: 862*da0073e9SAndroid Build Coastguard Worker return 0 863*da0073e9SAndroid Build Coastguard Worker else: 864*da0073e9SAndroid Build Coastguard Worker return a + torch.ones(1) 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker make_global(test_none) 867*da0073e9SAndroid Build Coastguard 
Worker 868*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)]) 869*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 870*da0073e9SAndroid Build Coastguard Worker scripted_fn( 871*da0073e9SAndroid Build Coastguard Worker 30.9, 872*da0073e9SAndroid Build Coastguard Worker ), 873*da0073e9SAndroid Build Coastguard Worker test_none( 874*da0073e9SAndroid Build Coastguard Worker 30.9, 875*da0073e9SAndroid Build Coastguard Worker ), 876*da0073e9SAndroid Build Coastguard Worker ) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)]) 879*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 880*da0073e9SAndroid Build Coastguard Worker scripted_fn( 881*da0073e9SAndroid Build Coastguard Worker 2, 882*da0073e9SAndroid Build Coastguard Worker ), 883*da0073e9SAndroid Build Coastguard Worker test_none( 884*da0073e9SAndroid Build Coastguard Worker 2, 885*da0073e9SAndroid Build Coastguard Worker ), 886*da0073e9SAndroid Build Coastguard Worker ) 887*da0073e9SAndroid Build Coastguard Worker 888*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script( 889*da0073e9SAndroid Build Coastguard Worker test_none, example_inputs=[(None,), (torch.Tensor(1),)] 890*da0073e9SAndroid Build Coastguard Worker ) 891*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 892*da0073e9SAndroid Build Coastguard Worker scripted_fn( 893*da0073e9SAndroid Build Coastguard Worker torch.ones(1), 894*da0073e9SAndroid Build Coastguard Worker ), 895*da0073e9SAndroid Build Coastguard Worker test_none( 896*da0073e9SAndroid Build Coastguard Worker torch.ones(1), 897*da0073e9SAndroid Build Coastguard Worker ), 898*da0073e9SAndroid Build Coastguard Worker ) 899