xref: /aosp_15_r20/external/pytorch/test/jit/test_pdt.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, NamedTuple, Optional, Tuple  # noqa: F401
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import NoTest
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
# The helper modules shared by the tests live in test/, i.e. the parent of the
# directory holding this file. Resolve that directory from this file's real
# path and append it to the import path so those helpers can be imported.
pytorch_test_dir = os.path.normpath(
    os.path.join(os.path.realpath(__file__), os.pardir, os.pardir)
)
sys.path.append(pytorch_test_dir)
16*da0073e9SAndroid Build Coastguard Worker
if not _IS_MONKEYTYPE_INSTALLED:
    # Tell whoever is running the suite why nothing from this file runs.
    print(
        "monkeytype is not installed. "
        "Skipping tests for Profile-Directed Typing",
        file=sys.stderr,
    )
    # TestPDT below subclasses whatever JitTestCase is bound to at class
    # creation time, so rebinding it here changes TestPDT's base class.
    # NOTE(review): NoTest presumably keeps the class from being collected —
    # confirm against torch.testing._internal.common_utils.
    JitTestCase = NoTest  # type: ignore[misc, assignment] # noqa: F811
23*da0073e9SAndroid Build Coastguard Worker
# Refuse direct execution; the error text names test/test_jit.py as the entry
# point to use instead.
if __name__ == "__main__":
    raise RuntimeError(
        "\n\n".join(
            (
                "This test file is not meant to be run directly, use:",
                "\tpython test/test_jit.py TESTNAME",
                "instead.",
            )
        )
    )
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Workerclass TestPDT(JitTestCase):
33*da0073e9SAndroid Build Coastguard Worker    """
34*da0073e9SAndroid Build Coastguard Worker    A suite of tests for profile directed typing in TorchScript.
35*da0073e9SAndroid Build Coastguard Worker    """
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker    def test_nn_module(self):
38*da0073e9SAndroid Build Coastguard Worker        class TestPDTModel(torch.nn.Module):
39*da0073e9SAndroid Build Coastguard Worker            def forward(self, x) -> Any:
40*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, int):
41*da0073e9SAndroid Build Coastguard Worker                    return x + 1
42*da0073e9SAndroid Build Coastguard Worker                elif isinstance(x, float):
43*da0073e9SAndroid Build Coastguard Worker                    return x - 1
44*da0073e9SAndroid Build Coastguard Worker                else:
45*da0073e9SAndroid Build Coastguard Worker                    return x
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker        make_global(TestPDTModel)
48*da0073e9SAndroid Build Coastguard Worker        pdt_model = TestPDTModel()
49*da0073e9SAndroid Build Coastguard Worker        inp: List[Tuple[Any, ...]] = [
50*da0073e9SAndroid Build Coastguard Worker            (20,),
51*da0073e9SAndroid Build Coastguard Worker            (2.7,),
52*da0073e9SAndroid Build Coastguard Worker            (False,),
53*da0073e9SAndroid Build Coastguard Worker        ]
54*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model = torch.jit.script(
55*da0073e9SAndroid Build Coastguard Worker            pdt_model, example_inputs={pdt_model: inp}
56*da0073e9SAndroid Build Coastguard Worker        )
57*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(50), pdt_model(50))
58*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
59*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(scripted_pdt_model(True), pdt_model(True))
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker    def test_nested_nn_module_class(self):
62*da0073e9SAndroid Build Coastguard Worker        class NestedPDTInner(torch.nn.Module):
63*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
64*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, int):
65*da0073e9SAndroid Build Coastguard Worker                    return x * 10
66*da0073e9SAndroid Build Coastguard Worker                return x
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker        class NestedModulePDTWrapper(torch.nn.Module):
69*da0073e9SAndroid Build Coastguard Worker            def __init__(self, inner):
70*da0073e9SAndroid Build Coastguard Worker                super().__init__()
71*da0073e9SAndroid Build Coastguard Worker                self.inner = inner
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
74*da0073e9SAndroid Build Coastguard Worker                return self.inner(x)
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker        make_global(NestedPDTInner, NestedModulePDTWrapper)
77*da0073e9SAndroid Build Coastguard Worker        inner_pdt_model = NestedPDTInner()
78*da0073e9SAndroid Build Coastguard Worker        wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
79*da0073e9SAndroid Build Coastguard Worker        inp: List[Tuple[Any, ...]] = [(20,), (False,)]
80*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model = torch.jit.script(
81*da0073e9SAndroid Build Coastguard Worker            wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}
82*da0073e9SAndroid Build Coastguard Worker        )
83*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
84*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
85*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker    def test_nested_nn_module_class_with_args(self):
88*da0073e9SAndroid Build Coastguard Worker        class NestedModulePDTInner(torch.nn.Module):
89*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
90*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, int):
91*da0073e9SAndroid Build Coastguard Worker                    return x * 10 + y
92*da0073e9SAndroid Build Coastguard Worker                return x
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker        class NestedModulePDTOuter(torch.nn.Module):
95*da0073e9SAndroid Build Coastguard Worker            def __init__(self, inner):
96*da0073e9SAndroid Build Coastguard Worker                super().__init__()
97*da0073e9SAndroid Build Coastguard Worker                self.inner = inner
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
100*da0073e9SAndroid Build Coastguard Worker                return self.inner(x, 20)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker        make_global(NestedModulePDTInner, NestedModulePDTOuter)
103*da0073e9SAndroid Build Coastguard Worker        inner_pdt_model = NestedModulePDTInner()
104*da0073e9SAndroid Build Coastguard Worker        outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
105*da0073e9SAndroid Build Coastguard Worker        inner_input: List[Tuple[Any, ...]] = [
106*da0073e9SAndroid Build Coastguard Worker            (10, 10),
107*da0073e9SAndroid Build Coastguard Worker            (1.9, 20),
108*da0073e9SAndroid Build Coastguard Worker        ]
109*da0073e9SAndroid Build Coastguard Worker        outer_input: List[Tuple[Any, ...]] = [(20,), (False,)]
110*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model = torch.jit.script(
111*da0073e9SAndroid Build Coastguard Worker            outer_pdt_model,
112*da0073e9SAndroid Build Coastguard Worker            example_inputs={
113*da0073e9SAndroid Build Coastguard Worker                inner_pdt_model: inner_input,
114*da0073e9SAndroid Build Coastguard Worker                outer_pdt_model: outer_input,
115*da0073e9SAndroid Build Coastguard Worker            },
116*da0073e9SAndroid Build Coastguard Worker        )
117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
118*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
119*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker    def test_nested_function_in_forward(self):
122*da0073e9SAndroid Build Coastguard Worker        class NestedFunctionInForward(torch.nn.Module):
123*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
124*da0073e9SAndroid Build Coastguard Worker                return self.fun(x) + 10
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker            def fun(self, x):
127*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, bool):
128*da0073e9SAndroid Build Coastguard Worker                    return 0
129*da0073e9SAndroid Build Coastguard Worker                elif isinstance(x, int):
130*da0073e9SAndroid Build Coastguard Worker                    return x + 1
131*da0073e9SAndroid Build Coastguard Worker                return 0
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker        make_global(NestedFunctionInForward)
134*da0073e9SAndroid Build Coastguard Worker        pdt_model = NestedFunctionInForward()
135*da0073e9SAndroid Build Coastguard Worker        inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
136*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model = torch.jit.script(
137*da0073e9SAndroid Build Coastguard Worker            pdt_model, example_inputs={pdt_model: inp}
138*da0073e9SAndroid Build Coastguard Worker        )
139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(30), pdt_model(30))
140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model(True), pdt_model(True))
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker    def test_nn_module_with_export_function(self):
143*da0073e9SAndroid Build Coastguard Worker        class TestModelWithExport(torch.nn.Module):
144*da0073e9SAndroid Build Coastguard Worker            @torch.jit.export
145*da0073e9SAndroid Build Coastguard Worker            def fn(self, x, y) -> Any:
146*da0073e9SAndroid Build Coastguard Worker                assert not (isinstance(x, bool) and isinstance(y, bool))
147*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, int) and isinstance(y, int):
148*da0073e9SAndroid Build Coastguard Worker                    return x + y
149*da0073e9SAndroid Build Coastguard Worker                elif isinstance(x, float) and isinstance(y, float):
150*da0073e9SAndroid Build Coastguard Worker                    return x - y
151*da0073e9SAndroid Build Coastguard Worker                else:
152*da0073e9SAndroid Build Coastguard Worker                    return -1
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker        make_global(TestModelWithExport)
155*da0073e9SAndroid Build Coastguard Worker        pdt_model = TestModelWithExport()
156*da0073e9SAndroid Build Coastguard Worker        inp: List[Tuple[Any, ...]] = [
157*da0073e9SAndroid Build Coastguard Worker            (
158*da0073e9SAndroid Build Coastguard Worker                20,
159*da0073e9SAndroid Build Coastguard Worker                10,
160*da0073e9SAndroid Build Coastguard Worker            ),
161*da0073e9SAndroid Build Coastguard Worker            (
162*da0073e9SAndroid Build Coastguard Worker                2.7,
163*da0073e9SAndroid Build Coastguard Worker                8.9,
164*da0073e9SAndroid Build Coastguard Worker            ),
165*da0073e9SAndroid Build Coastguard Worker        ]
166*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model = torch.jit.script(
167*da0073e9SAndroid Build Coastguard Worker            pdt_model, example_inputs={pdt_model.fn: inp}
168*da0073e9SAndroid Build Coastguard Worker        )
169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
170*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
171*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
172*da0073e9SAndroid Build Coastguard Worker            scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)
173*da0073e9SAndroid Build Coastguard Worker        )
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker    def test_class_methods(self):
176*da0073e9SAndroid Build Coastguard Worker        class PDTModel:
177*da0073e9SAndroid Build Coastguard Worker            def test_sum(self, a):
178*da0073e9SAndroid Build Coastguard Worker                return sum(a)
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker        make_global(PDTModel)
181*da0073e9SAndroid Build Coastguard Worker        pdt_model = PDTModel()
182*da0073e9SAndroid Build Coastguard Worker        inp: List[Tuple[Any, ...]] = [
183*da0073e9SAndroid Build Coastguard Worker            (
184*da0073e9SAndroid Build Coastguard Worker                [
185*da0073e9SAndroid Build Coastguard Worker                    10,
186*da0073e9SAndroid Build Coastguard Worker                    20,
187*da0073e9SAndroid Build Coastguard Worker                ],
188*da0073e9SAndroid Build Coastguard Worker            ),
189*da0073e9SAndroid Build Coastguard Worker        ]
190*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model = torch.jit.script(
191*da0073e9SAndroid Build Coastguard Worker            PDTModel, example_inputs={pdt_model.test_sum: inp}
192*da0073e9SAndroid Build Coastguard Worker        )
193*da0073e9SAndroid Build Coastguard Worker        script_model = scripted_pdt_model()
194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
195*da0073e9SAndroid Build Coastguard Worker            script_model.test_sum(
196*da0073e9SAndroid Build Coastguard Worker                [
197*da0073e9SAndroid Build Coastguard Worker                    10,
198*da0073e9SAndroid Build Coastguard Worker                    20,
199*da0073e9SAndroid Build Coastguard Worker                    30,
200*da0073e9SAndroid Build Coastguard Worker                ],
201*da0073e9SAndroid Build Coastguard Worker            ),
202*da0073e9SAndroid Build Coastguard Worker            pdt_model.test_sum(
203*da0073e9SAndroid Build Coastguard Worker                [
204*da0073e9SAndroid Build Coastguard Worker                    10,
205*da0073e9SAndroid Build Coastguard Worker                    20,
206*da0073e9SAndroid Build Coastguard Worker                    30,
207*da0073e9SAndroid Build Coastguard Worker                ],
208*da0073e9SAndroid Build Coastguard Worker            ),
209*da0073e9SAndroid Build Coastguard Worker        )
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker    def test_class_with_multiple_methods(self):
212*da0073e9SAndroid Build Coastguard Worker        class PDTModelWithManyMethods:
213*da0073e9SAndroid Build Coastguard Worker            def test_list_to_dict(self, a):
214*da0073e9SAndroid Build Coastguard Worker                new_dictionary: Dict[float, bool] = {}
215*da0073e9SAndroid Build Coastguard Worker                for element in a:
216*da0073e9SAndroid Build Coastguard Worker                    new_dictionary[element] = True
217*da0073e9SAndroid Build Coastguard Worker                return new_dictionary
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker            def test_substring(self, a, b):
220*da0073e9SAndroid Build Coastguard Worker                return b in a
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker        make_global(PDTModelWithManyMethods)
223*da0073e9SAndroid Build Coastguard Worker        pdt_model = PDTModelWithManyMethods()
224*da0073e9SAndroid Build Coastguard Worker        list_inp: List[Tuple[Any, ...]] = [
225*da0073e9SAndroid Build Coastguard Worker            (
226*da0073e9SAndroid Build Coastguard Worker                [
227*da0073e9SAndroid Build Coastguard Worker                    1.2,
228*da0073e9SAndroid Build Coastguard Worker                    2.3,
229*da0073e9SAndroid Build Coastguard Worker                ],
230*da0073e9SAndroid Build Coastguard Worker            ),
231*da0073e9SAndroid Build Coastguard Worker        ]
232*da0073e9SAndroid Build Coastguard Worker        str_inp: List[Tuple[Any, ...]] = [
233*da0073e9SAndroid Build Coastguard Worker            (
234*da0073e9SAndroid Build Coastguard Worker                "abc",
235*da0073e9SAndroid Build Coastguard Worker                "b",
236*da0073e9SAndroid Build Coastguard Worker            ),
237*da0073e9SAndroid Build Coastguard Worker        ]
238*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model = torch.jit.script(
239*da0073e9SAndroid Build Coastguard Worker            PDTModelWithManyMethods,
240*da0073e9SAndroid Build Coastguard Worker            example_inputs={
241*da0073e9SAndroid Build Coastguard Worker                pdt_model.test_list_to_dict: list_inp,
242*da0073e9SAndroid Build Coastguard Worker                pdt_model.test_substring: str_inp,
243*da0073e9SAndroid Build Coastguard Worker            },
244*da0073e9SAndroid Build Coastguard Worker        )
245*da0073e9SAndroid Build Coastguard Worker        script_model = scripted_pdt_model()
246*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
247*da0073e9SAndroid Build Coastguard Worker            script_model.test_list_to_dict(
248*da0073e9SAndroid Build Coastguard Worker                [
249*da0073e9SAndroid Build Coastguard Worker                    1.1,
250*da0073e9SAndroid Build Coastguard Worker                    2.2,
251*da0073e9SAndroid Build Coastguard Worker                    3.3,
252*da0073e9SAndroid Build Coastguard Worker                ],
253*da0073e9SAndroid Build Coastguard Worker            ),
254*da0073e9SAndroid Build Coastguard Worker            pdt_model.test_list_to_dict(
255*da0073e9SAndroid Build Coastguard Worker                [
256*da0073e9SAndroid Build Coastguard Worker                    1.1,
257*da0073e9SAndroid Build Coastguard Worker                    2.2,
258*da0073e9SAndroid Build Coastguard Worker                    3.3,
259*da0073e9SAndroid Build Coastguard Worker                ],
260*da0073e9SAndroid Build Coastguard Worker            ),
261*da0073e9SAndroid Build Coastguard Worker        )
262*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
263*da0073e9SAndroid Build Coastguard Worker            script_model.test_substring(
264*da0073e9SAndroid Build Coastguard Worker                "helloworld",
265*da0073e9SAndroid Build Coastguard Worker                "world",
266*da0073e9SAndroid Build Coastguard Worker            ),
267*da0073e9SAndroid Build Coastguard Worker            pdt_model.test_substring(
268*da0073e9SAndroid Build Coastguard Worker                "helloworld",
269*da0073e9SAndroid Build Coastguard Worker                "world",
270*da0073e9SAndroid Build Coastguard Worker            ),
271*da0073e9SAndroid Build Coastguard Worker        )
272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
273*da0073e9SAndroid Build Coastguard Worker            script_model.test_substring(
274*da0073e9SAndroid Build Coastguard Worker                "helloworld",
275*da0073e9SAndroid Build Coastguard Worker                "def",
276*da0073e9SAndroid Build Coastguard Worker            ),
277*da0073e9SAndroid Build Coastguard Worker            pdt_model.test_substring(
278*da0073e9SAndroid Build Coastguard Worker                "helloworld",
279*da0073e9SAndroid Build Coastguard Worker                "def",
280*da0073e9SAndroid Build Coastguard Worker            ),
281*da0073e9SAndroid Build Coastguard Worker        )
282*da0073e9SAndroid Build Coastguard Worker
283*da0073e9SAndroid Build Coastguard Worker    def test_multiple_class_with_same_method(self):
284*da0073e9SAndroid Build Coastguard Worker        class PDTModelOne:
285*da0073e9SAndroid Build Coastguard Worker            def test_find(self, a, b):
286*da0073e9SAndroid Build Coastguard Worker                return b in a.keys()
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        class PDTModelTwo:
289*da0073e9SAndroid Build Coastguard Worker            def test_find(self, a, b):
290*da0073e9SAndroid Build Coastguard Worker                return b in a
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        make_global(PDTModelOne, PDTModelTwo)
293*da0073e9SAndroid Build Coastguard Worker        pdt_model_one = PDTModelOne()
294*da0073e9SAndroid Build Coastguard Worker        pdt_model_two = PDTModelTwo()
295*da0073e9SAndroid Build Coastguard Worker        dict_inp: List[Tuple[Any, ...]] = [
296*da0073e9SAndroid Build Coastguard Worker            (
297*da0073e9SAndroid Build Coastguard Worker                {
298*da0073e9SAndroid Build Coastguard Worker                    1.2: True,
299*da0073e9SAndroid Build Coastguard Worker                    2.3: False,
300*da0073e9SAndroid Build Coastguard Worker                },
301*da0073e9SAndroid Build Coastguard Worker                1.2,
302*da0073e9SAndroid Build Coastguard Worker            ),
303*da0073e9SAndroid Build Coastguard Worker        ]
304*da0073e9SAndroid Build Coastguard Worker        list_inp: List[Tuple[Any, ...]] = [
305*da0073e9SAndroid Build Coastguard Worker            (
306*da0073e9SAndroid Build Coastguard Worker                [
307*da0073e9SAndroid Build Coastguard Worker                    "abc",
308*da0073e9SAndroid Build Coastguard Worker                    "b",
309*da0073e9SAndroid Build Coastguard Worker                ],
310*da0073e9SAndroid Build Coastguard Worker                "c",
311*da0073e9SAndroid Build Coastguard Worker            ),
312*da0073e9SAndroid Build Coastguard Worker        ]
313*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model_one = torch.jit.script(
314*da0073e9SAndroid Build Coastguard Worker            PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}
315*da0073e9SAndroid Build Coastguard Worker        )
316*da0073e9SAndroid Build Coastguard Worker        scripted_pdt_model_two = torch.jit.script(
317*da0073e9SAndroid Build Coastguard Worker            PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}
318*da0073e9SAndroid Build Coastguard Worker        )
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker        script_model_one, script_model_two = (
321*da0073e9SAndroid Build Coastguard Worker            scripted_pdt_model_one(),
322*da0073e9SAndroid Build Coastguard Worker            scripted_pdt_model_two(),
323*da0073e9SAndroid Build Coastguard Worker        )
324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
325*da0073e9SAndroid Build Coastguard Worker            script_model_one.test_find(
326*da0073e9SAndroid Build Coastguard Worker                {
327*da0073e9SAndroid Build Coastguard Worker                    1.1: True,
328*da0073e9SAndroid Build Coastguard Worker                    2.2: True,
329*da0073e9SAndroid Build Coastguard Worker                    3.3: False,
330*da0073e9SAndroid Build Coastguard Worker                },
331*da0073e9SAndroid Build Coastguard Worker                4.4,
332*da0073e9SAndroid Build Coastguard Worker            ),
333*da0073e9SAndroid Build Coastguard Worker            pdt_model_one.test_find(
334*da0073e9SAndroid Build Coastguard Worker                {
335*da0073e9SAndroid Build Coastguard Worker                    1.1: True,
336*da0073e9SAndroid Build Coastguard Worker                    2.2: True,
337*da0073e9SAndroid Build Coastguard Worker                    3.3: False,
338*da0073e9SAndroid Build Coastguard Worker                },
339*da0073e9SAndroid Build Coastguard Worker                4.4,
340*da0073e9SAndroid Build Coastguard Worker            ),
341*da0073e9SAndroid Build Coastguard Worker        )
342*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
343*da0073e9SAndroid Build Coastguard Worker            script_model_two.test_find(
344*da0073e9SAndroid Build Coastguard Worker                [
345*da0073e9SAndroid Build Coastguard Worker                    "hello",
346*da0073e9SAndroid Build Coastguard Worker                    "world",
347*da0073e9SAndroid Build Coastguard Worker                ],
348*da0073e9SAndroid Build Coastguard Worker                "world",
349*da0073e9SAndroid Build Coastguard Worker            ),
350*da0073e9SAndroid Build Coastguard Worker            pdt_model_two.test_find(
351*da0073e9SAndroid Build Coastguard Worker                [
352*da0073e9SAndroid Build Coastguard Worker                    "hello",
353*da0073e9SAndroid Build Coastguard Worker                    "world",
354*da0073e9SAndroid Build Coastguard Worker                ],
355*da0073e9SAndroid Build Coastguard Worker                "world",
356*da0073e9SAndroid Build Coastguard Worker            ),
357*da0073e9SAndroid Build Coastguard Worker        )
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker    def test_pdt(self):
360*da0073e9SAndroid Build Coastguard Worker        def test_sum(a, b):
361*da0073e9SAndroid Build Coastguard Worker            return a + b
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker        make_global(test_sum)
364*da0073e9SAndroid Build Coastguard Worker        scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)])
365*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2))
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker        def test_sub(a, b):
368*da0073e9SAndroid Build Coastguard Worker            return a - b
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker        make_global(test_sub)
371*da0073e9SAndroid Build Coastguard Worker        scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)])
372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9))
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker        def test_mul(a, b):
375*da0073e9SAndroid Build Coastguard Worker            return a * b
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker        make_global(test_mul)
378*da0073e9SAndroid Build Coastguard Worker        scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)])
379*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3))
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker        def test_args_complex(real, img):
382*da0073e9SAndroid Build Coastguard Worker            return torch.complex(real, img)
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker        make_global(test_args_complex)
385*da0073e9SAndroid Build Coastguard Worker        scripted_fn_complex = torch.jit.script(
386*da0073e9SAndroid Build Coastguard Worker            test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]
387*da0073e9SAndroid Build Coastguard Worker        )
388*da0073e9SAndroid Build Coastguard Worker        arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
389*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker        def test_bool(a):
392*da0073e9SAndroid Build Coastguard Worker            if a:
393*da0073e9SAndroid Build Coastguard Worker                return -1
394*da0073e9SAndroid Build Coastguard Worker            else:
395*da0073e9SAndroid Build Coastguard Worker                return 0
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker        make_global(test_bool)
398*da0073e9SAndroid Build Coastguard Worker        scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)])
399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_bool(True), test_bool(True))
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        def test_str(a):
402*da0073e9SAndroid Build Coastguard Worker            if a == "":
403*da0073e9SAndroid Build Coastguard Worker                return False
404*da0073e9SAndroid Build Coastguard Worker            else:
405*da0073e9SAndroid Build Coastguard Worker                return True
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker        make_global(test_str)
408*da0073e9SAndroid Build Coastguard Worker        scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)])
409*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn_str("abc"), test_str("abc"))
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker    def test_pdt_list_and_tuple(self):
412*da0073e9SAndroid Build Coastguard Worker        def test_list_and_tuple(a):
413*da0073e9SAndroid Build Coastguard Worker            return sum(a)
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker        make_global(test_list_and_tuple)
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker        scripted_fn_float_list_input = torch.jit.script(
418*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple, example_inputs=[([4.9, 8.9],)]
419*da0073e9SAndroid Build Coastguard Worker        )
420*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
421*da0073e9SAndroid Build Coastguard Worker            scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])
422*da0073e9SAndroid Build Coastguard Worker        )
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker        scripted_fn_bool_list_input = torch.jit.script(
425*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple, example_inputs=[([True, False, True],)]
426*da0073e9SAndroid Build Coastguard Worker        )
427*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
428*da0073e9SAndroid Build Coastguard Worker            scripted_fn_bool_list_input([True, True, True]),
429*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple([True, True, True]),
430*da0073e9SAndroid Build Coastguard Worker        )
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker        scripted_fn_int_list_input = torch.jit.script(
433*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple, example_inputs=[([3, 4, 5],)]
434*da0073e9SAndroid Build Coastguard Worker        )
435*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
436*da0073e9SAndroid Build Coastguard Worker            scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])
437*da0073e9SAndroid Build Coastguard Worker        )
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker        scripted_fn_float_tuple_input = torch.jit.script(
440*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple, example_inputs=[((4.9, 8.9),)]
441*da0073e9SAndroid Build Coastguard Worker        )
442*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
443*da0073e9SAndroid Build Coastguard Worker            scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))
444*da0073e9SAndroid Build Coastguard Worker        )
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker        scripted_fn_bool_tuple_input = torch.jit.script(
447*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple, example_inputs=[((True, False, True),)]
448*da0073e9SAndroid Build Coastguard Worker        )
449*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
450*da0073e9SAndroid Build Coastguard Worker            scripted_fn_bool_tuple_input((True, True, True)),
451*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple((True, True, True)),
452*da0073e9SAndroid Build Coastguard Worker        )
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker        scripted_fn_int_tuple_input = torch.jit.script(
455*da0073e9SAndroid Build Coastguard Worker            test_list_and_tuple, example_inputs=[((3, 4, 5),)]
456*da0073e9SAndroid Build Coastguard Worker        )
457*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
458*da0073e9SAndroid Build Coastguard Worker            scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))
459*da0073e9SAndroid Build Coastguard Worker        )
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker    def test_nested_list_and_tuple(self):
462*da0073e9SAndroid Build Coastguard Worker        def test_nested_list(inp):
463*da0073e9SAndroid Build Coastguard Worker            return [sum(v) for v in inp]
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker        def test_nested_tuple(inp):
466*da0073e9SAndroid Build Coastguard Worker            ans = 0.0
467*da0073e9SAndroid Build Coastguard Worker            for tup in inp:
468*da0073e9SAndroid Build Coastguard Worker                for val in tup:
469*da0073e9SAndroid Build Coastguard Worker                    if val > 0:
470*da0073e9SAndroid Build Coastguard Worker                        ans *= val
471*da0073e9SAndroid Build Coastguard Worker            return ans
472*da0073e9SAndroid Build Coastguard Worker
473*da0073e9SAndroid Build Coastguard Worker        make_global(test_nested_list, test_nested_tuple)
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker        list_inp = [
476*da0073e9SAndroid Build Coastguard Worker            [
477*da0073e9SAndroid Build Coastguard Worker                1,
478*da0073e9SAndroid Build Coastguard Worker                2,
479*da0073e9SAndroid Build Coastguard Worker                3,
480*da0073e9SAndroid Build Coastguard Worker            ],
481*da0073e9SAndroid Build Coastguard Worker            [
482*da0073e9SAndroid Build Coastguard Worker                5,
483*da0073e9SAndroid Build Coastguard Worker                6,
484*da0073e9SAndroid Build Coastguard Worker                7,
485*da0073e9SAndroid Build Coastguard Worker            ],
486*da0073e9SAndroid Build Coastguard Worker        ]
487*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
488*da0073e9SAndroid Build Coastguard Worker            test_nested_list,
489*da0073e9SAndroid Build Coastguard Worker            example_inputs=[
490*da0073e9SAndroid Build Coastguard Worker                (list_inp,),
491*da0073e9SAndroid Build Coastguard Worker            ],
492*da0073e9SAndroid Build Coastguard Worker        )
493*da0073e9SAndroid Build Coastguard Worker        inp = [
494*da0073e9SAndroid Build Coastguard Worker            [
495*da0073e9SAndroid Build Coastguard Worker                0,
496*da0073e9SAndroid Build Coastguard Worker                4,
497*da0073e9SAndroid Build Coastguard Worker                7,
498*da0073e9SAndroid Build Coastguard Worker            ],
499*da0073e9SAndroid Build Coastguard Worker            [
500*da0073e9SAndroid Build Coastguard Worker                8,
501*da0073e9SAndroid Build Coastguard Worker                11,
502*da0073e9SAndroid Build Coastguard Worker            ],
503*da0073e9SAndroid Build Coastguard Worker            [
504*da0073e9SAndroid Build Coastguard Worker                6,
505*da0073e9SAndroid Build Coastguard Worker                -1,
506*da0073e9SAndroid Build Coastguard Worker                -20,
507*da0073e9SAndroid Build Coastguard Worker            ],
508*da0073e9SAndroid Build Coastguard Worker        ]
509*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
510*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
511*da0073e9SAndroid Build Coastguard Worker                inp,
512*da0073e9SAndroid Build Coastguard Worker            ),
513*da0073e9SAndroid Build Coastguard Worker            test_nested_list(
514*da0073e9SAndroid Build Coastguard Worker                inp,
515*da0073e9SAndroid Build Coastguard Worker            ),
516*da0073e9SAndroid Build Coastguard Worker        )
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker        list_inp = (
519*da0073e9SAndroid Build Coastguard Worker            [
520*da0073e9SAndroid Build Coastguard Worker                1,
521*da0073e9SAndroid Build Coastguard Worker                2,
522*da0073e9SAndroid Build Coastguard Worker                3,
523*da0073e9SAndroid Build Coastguard Worker            ],
524*da0073e9SAndroid Build Coastguard Worker            [
525*da0073e9SAndroid Build Coastguard Worker                5,
526*da0073e9SAndroid Build Coastguard Worker                6,
527*da0073e9SAndroid Build Coastguard Worker                7,
528*da0073e9SAndroid Build Coastguard Worker            ],
529*da0073e9SAndroid Build Coastguard Worker        )
530*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
531*da0073e9SAndroid Build Coastguard Worker            test_nested_list,
532*da0073e9SAndroid Build Coastguard Worker            example_inputs=[
533*da0073e9SAndroid Build Coastguard Worker                (list_inp,),
534*da0073e9SAndroid Build Coastguard Worker            ],
535*da0073e9SAndroid Build Coastguard Worker        )
536*da0073e9SAndroid Build Coastguard Worker        inp = (
537*da0073e9SAndroid Build Coastguard Worker            [
538*da0073e9SAndroid Build Coastguard Worker                0,
539*da0073e9SAndroid Build Coastguard Worker                4,
540*da0073e9SAndroid Build Coastguard Worker                7,
541*da0073e9SAndroid Build Coastguard Worker            ],
542*da0073e9SAndroid Build Coastguard Worker            [
543*da0073e9SAndroid Build Coastguard Worker                8,
544*da0073e9SAndroid Build Coastguard Worker                11,
545*da0073e9SAndroid Build Coastguard Worker            ],
546*da0073e9SAndroid Build Coastguard Worker            [
547*da0073e9SAndroid Build Coastguard Worker                6,
548*da0073e9SAndroid Build Coastguard Worker                -1,
549*da0073e9SAndroid Build Coastguard Worker                -20,
550*da0073e9SAndroid Build Coastguard Worker            ],
551*da0073e9SAndroid Build Coastguard Worker        )
552*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
553*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
554*da0073e9SAndroid Build Coastguard Worker                inp,
555*da0073e9SAndroid Build Coastguard Worker            ),
556*da0073e9SAndroid Build Coastguard Worker            test_nested_list(
557*da0073e9SAndroid Build Coastguard Worker                inp,
558*da0073e9SAndroid Build Coastguard Worker            ),
559*da0073e9SAndroid Build Coastguard Worker        )
560*da0073e9SAndroid Build Coastguard Worker
561*da0073e9SAndroid Build Coastguard Worker        tup_inp = [
562*da0073e9SAndroid Build Coastguard Worker            (
563*da0073e9SAndroid Build Coastguard Worker                1.0,
564*da0073e9SAndroid Build Coastguard Worker                2.6,
565*da0073e9SAndroid Build Coastguard Worker                3.7,
566*da0073e9SAndroid Build Coastguard Worker            ),
567*da0073e9SAndroid Build Coastguard Worker            (
568*da0073e9SAndroid Build Coastguard Worker                5.7,
569*da0073e9SAndroid Build Coastguard Worker                6.1,
570*da0073e9SAndroid Build Coastguard Worker                1.7,
571*da0073e9SAndroid Build Coastguard Worker            ),
572*da0073e9SAndroid Build Coastguard Worker        ]
573*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
574*da0073e9SAndroid Build Coastguard Worker            test_nested_tuple,
575*da0073e9SAndroid Build Coastguard Worker            example_inputs=[
576*da0073e9SAndroid Build Coastguard Worker                (tup_inp,),
577*da0073e9SAndroid Build Coastguard Worker            ],
578*da0073e9SAndroid Build Coastguard Worker        )
579*da0073e9SAndroid Build Coastguard Worker        inp = [
580*da0073e9SAndroid Build Coastguard Worker            (
581*da0073e9SAndroid Build Coastguard Worker                1.0,
582*da0073e9SAndroid Build Coastguard Worker                4.1,
583*da0073e9SAndroid Build Coastguard Worker                7.4,
584*da0073e9SAndroid Build Coastguard Worker            ),
585*da0073e9SAndroid Build Coastguard Worker            (
586*da0073e9SAndroid Build Coastguard Worker                4.8,
587*da0073e9SAndroid Build Coastguard Worker                1.1,
588*da0073e9SAndroid Build Coastguard Worker                -1.2,
589*da0073e9SAndroid Build Coastguard Worker            ),
590*da0073e9SAndroid Build Coastguard Worker            (
591*da0073e9SAndroid Build Coastguard Worker                6.3,
592*da0073e9SAndroid Build Coastguard Worker                -1.3,
593*da0073e9SAndroid Build Coastguard Worker                -2.0,
594*da0073e9SAndroid Build Coastguard Worker            ),
595*da0073e9SAndroid Build Coastguard Worker        ]
596*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
597*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
598*da0073e9SAndroid Build Coastguard Worker                inp,
599*da0073e9SAndroid Build Coastguard Worker            ),
600*da0073e9SAndroid Build Coastguard Worker            test_nested_tuple(
601*da0073e9SAndroid Build Coastguard Worker                inp,
602*da0073e9SAndroid Build Coastguard Worker            ),
603*da0073e9SAndroid Build Coastguard Worker        )
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker        tup_inp = (
606*da0073e9SAndroid Build Coastguard Worker            (
607*da0073e9SAndroid Build Coastguard Worker                True,
608*da0073e9SAndroid Build Coastguard Worker                False,
609*da0073e9SAndroid Build Coastguard Worker                True,
610*da0073e9SAndroid Build Coastguard Worker            ),
611*da0073e9SAndroid Build Coastguard Worker            (
612*da0073e9SAndroid Build Coastguard Worker                False,
613*da0073e9SAndroid Build Coastguard Worker                False,
614*da0073e9SAndroid Build Coastguard Worker                False,
615*da0073e9SAndroid Build Coastguard Worker            ),
616*da0073e9SAndroid Build Coastguard Worker        )
617*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
618*da0073e9SAndroid Build Coastguard Worker            test_nested_tuple,
619*da0073e9SAndroid Build Coastguard Worker            example_inputs=[
620*da0073e9SAndroid Build Coastguard Worker                (tup_inp,),
621*da0073e9SAndroid Build Coastguard Worker            ],
622*da0073e9SAndroid Build Coastguard Worker        )
623*da0073e9SAndroid Build Coastguard Worker        inp = (
624*da0073e9SAndroid Build Coastguard Worker            (
625*da0073e9SAndroid Build Coastguard Worker                True,
626*da0073e9SAndroid Build Coastguard Worker                True,
627*da0073e9SAndroid Build Coastguard Worker                True,
628*da0073e9SAndroid Build Coastguard Worker            ),
629*da0073e9SAndroid Build Coastguard Worker            (
630*da0073e9SAndroid Build Coastguard Worker                False,
631*da0073e9SAndroid Build Coastguard Worker                False,
632*da0073e9SAndroid Build Coastguard Worker                True,
633*da0073e9SAndroid Build Coastguard Worker            ),
634*da0073e9SAndroid Build Coastguard Worker        )
635*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
636*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
637*da0073e9SAndroid Build Coastguard Worker                inp,
638*da0073e9SAndroid Build Coastguard Worker            ),
639*da0073e9SAndroid Build Coastguard Worker            test_nested_tuple(
640*da0073e9SAndroid Build Coastguard Worker                inp,
641*da0073e9SAndroid Build Coastguard Worker            ),
642*da0073e9SAndroid Build Coastguard Worker        )
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Worker    def test_pdt_dict(self):
645*da0073e9SAndroid Build Coastguard Worker        def test_dict(a):
646*da0073e9SAndroid Build Coastguard Worker            return a["foo"]
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker        def test_dict_int_list(a):
649*da0073e9SAndroid Build Coastguard Worker            return a[1]
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker        make_global(test_dict, test_dict_int_list)
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker        str_bool_inp = {"foo": True, "bar": False}
654*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
655*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
656*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
657*da0073e9SAndroid Build Coastguard Worker                {"foo": False, "bar": True},
658*da0073e9SAndroid Build Coastguard Worker            ),
659*da0073e9SAndroid Build Coastguard Worker            test_dict(
660*da0073e9SAndroid Build Coastguard Worker                {"foo": False, "bar": True},
661*da0073e9SAndroid Build Coastguard Worker            ),
662*da0073e9SAndroid Build Coastguard Worker        )
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker        str_list_inp = {0: [True, False], 1: [False, True]}
665*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
666*da0073e9SAndroid Build Coastguard Worker            test_dict_int_list, example_inputs=[(str_list_inp,)]
667*da0073e9SAndroid Build Coastguard Worker        )
668*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
669*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
670*da0073e9SAndroid Build Coastguard Worker                {0: [False, False], 1: [True, True]},
671*da0073e9SAndroid Build Coastguard Worker            ),
672*da0073e9SAndroid Build Coastguard Worker            test_dict_int_list(
673*da0073e9SAndroid Build Coastguard Worker                {0: [False, False], 1: [True, True]},
674*da0073e9SAndroid Build Coastguard Worker            ),
675*da0073e9SAndroid Build Coastguard Worker        )
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker    def test_any(self):
678*da0073e9SAndroid Build Coastguard Worker        def test_multiple_types(a):
679*da0073e9SAndroid Build Coastguard Worker            assert not isinstance(a, bool)
680*da0073e9SAndroid Build Coastguard Worker            return a
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker        def test_multiple_type_refinement(a):
683*da0073e9SAndroid Build Coastguard Worker            if isinstance(a, bool):
684*da0073e9SAndroid Build Coastguard Worker                return 1
685*da0073e9SAndroid Build Coastguard Worker            elif isinstance(a, int):
686*da0073e9SAndroid Build Coastguard Worker                return 1 + a
687*da0073e9SAndroid Build Coastguard Worker            elif isinstance(a, float):
688*da0073e9SAndroid Build Coastguard Worker                return 1 + int(a)
689*da0073e9SAndroid Build Coastguard Worker            else:
690*da0073e9SAndroid Build Coastguard Worker                return -1
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker        make_global(test_multiple_types, test_multiple_type_refinement)
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
695*da0073e9SAndroid Build Coastguard Worker            test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)]
696*da0073e9SAndroid Build Coastguard Worker        )
697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn(10), test_multiple_types(10))
698*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
699*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
700*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
703*da0073e9SAndroid Build Coastguard Worker            test_multiple_type_refinement,
704*da0073e9SAndroid Build Coastguard Worker            example_inputs=[
705*da0073e9SAndroid Build Coastguard Worker                (1,),
706*da0073e9SAndroid Build Coastguard Worker                ("abc",),
707*da0073e9SAndroid Build Coastguard Worker                (8.9,),
708*da0073e9SAndroid Build Coastguard Worker                ([3, 4, 5],),
709*da0073e9SAndroid Build Coastguard Worker                (True,),
710*da0073e9SAndroid Build Coastguard Worker                ({"a": True},),
711*da0073e9SAndroid Build Coastguard Worker            ],
712*da0073e9SAndroid Build Coastguard Worker        )
713*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
714*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
715*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
716*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
717*da0073e9SAndroid Build Coastguard Worker            scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14])
718*da0073e9SAndroid Build Coastguard Worker        )
719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False))
720*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
721*da0073e9SAndroid Build Coastguard Worker            scripted_fn({"abc": True, "def": False}),
722*da0073e9SAndroid Build Coastguard Worker            test_multiple_type_refinement({"abc": True, "def": False}),
723*da0073e9SAndroid Build Coastguard Worker        )
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker    def test_class_as_profiled_types(self):
726*da0073e9SAndroid Build Coastguard Worker        class UserDefinedClass:
727*da0073e9SAndroid Build Coastguard Worker            def fn(self, b) -> Any:
728*da0073e9SAndroid Build Coastguard Worker                assert b is not None
729*da0073e9SAndroid Build Coastguard Worker                if isinstance(b, int):
730*da0073e9SAndroid Build Coastguard Worker                    return b if b > 0 else -1
731*da0073e9SAndroid Build Coastguard Worker                elif isinstance(b, float):
732*da0073e9SAndroid Build Coastguard Worker                    return b if b > 0.0 else -1.0
733*da0073e9SAndroid Build Coastguard Worker                return 0
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker        def test_model(a, m):
736*da0073e9SAndroid Build Coastguard Worker            assert not isinstance(a, bool)
737*da0073e9SAndroid Build Coastguard Worker            return m.fn(a)
738*da0073e9SAndroid Build Coastguard Worker
739*da0073e9SAndroid Build Coastguard Worker        make_global(UserDefinedClass, test_model)
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker        user_class = UserDefinedClass()
742*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
743*da0073e9SAndroid Build Coastguard Worker            test_model,
744*da0073e9SAndroid Build Coastguard Worker            example_inputs=[
745*da0073e9SAndroid Build Coastguard Worker                (
746*da0073e9SAndroid Build Coastguard Worker                    10,
747*da0073e9SAndroid Build Coastguard Worker                    user_class,
748*da0073e9SAndroid Build Coastguard Worker                ),
749*da0073e9SAndroid Build Coastguard Worker                (
750*da0073e9SAndroid Build Coastguard Worker                    10.9,
751*da0073e9SAndroid Build Coastguard Worker                    user_class,
752*da0073e9SAndroid Build Coastguard Worker                ),
753*da0073e9SAndroid Build Coastguard Worker            ],
754*da0073e9SAndroid Build Coastguard Worker        )
755*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
756*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
757*da0073e9SAndroid Build Coastguard Worker                100,
758*da0073e9SAndroid Build Coastguard Worker                user_class,
759*da0073e9SAndroid Build Coastguard Worker            ),
760*da0073e9SAndroid Build Coastguard Worker            test_model(100, user_class),
761*da0073e9SAndroid Build Coastguard Worker        )
762*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
763*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
764*da0073e9SAndroid Build Coastguard Worker                1.9,
765*da0073e9SAndroid Build Coastguard Worker                user_class,
766*da0073e9SAndroid Build Coastguard Worker            ),
767*da0073e9SAndroid Build Coastguard Worker            test_model(1.9, user_class),
768*da0073e9SAndroid Build Coastguard Worker        )
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker    def test_class_with_args_as_profiled_types(self):
771*da0073e9SAndroid Build Coastguard Worker        class ClassWithArgs:
772*da0073e9SAndroid Build Coastguard Worker            def __init__(self, a: bool):
773*da0073e9SAndroid Build Coastguard Worker                self.a = a
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker            def fn(self, b):
776*da0073e9SAndroid Build Coastguard Worker                if self.a:
777*da0073e9SAndroid Build Coastguard Worker                    return b
778*da0073e9SAndroid Build Coastguard Worker                else:
779*da0073e9SAndroid Build Coastguard Worker                    return -1
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker        def test_model_with_args(a, m):
782*da0073e9SAndroid Build Coastguard Worker            assert not isinstance(a, bool)
783*da0073e9SAndroid Build Coastguard Worker            return m.fn(a)
784*da0073e9SAndroid Build Coastguard Worker
785*da0073e9SAndroid Build Coastguard Worker        make_global(ClassWithArgs, test_model_with_args)
786*da0073e9SAndroid Build Coastguard Worker
787*da0073e9SAndroid Build Coastguard Worker        user_class = ClassWithArgs(False)
788*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
789*da0073e9SAndroid Build Coastguard Worker            test_model_with_args,
790*da0073e9SAndroid Build Coastguard Worker            example_inputs=[
791*da0073e9SAndroid Build Coastguard Worker                (
792*da0073e9SAndroid Build Coastguard Worker                    10,
793*da0073e9SAndroid Build Coastguard Worker                    user_class,
794*da0073e9SAndroid Build Coastguard Worker                ),
795*da0073e9SAndroid Build Coastguard Worker                (
796*da0073e9SAndroid Build Coastguard Worker                    10.9,
797*da0073e9SAndroid Build Coastguard Worker                    user_class,
798*da0073e9SAndroid Build Coastguard Worker                ),
799*da0073e9SAndroid Build Coastguard Worker            ],
800*da0073e9SAndroid Build Coastguard Worker        )
801*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
802*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
803*da0073e9SAndroid Build Coastguard Worker                100,
804*da0073e9SAndroid Build Coastguard Worker                ClassWithArgs(True),
805*da0073e9SAndroid Build Coastguard Worker            ),
806*da0073e9SAndroid Build Coastguard Worker            test_model_with_args(100, ClassWithArgs(True)),
807*da0073e9SAndroid Build Coastguard Worker        )
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker    def test_nn_parameter_as_arg(self):
810*da0073e9SAndroid Build Coastguard Worker        class TestNNParameter(torch.nn.Module):
811*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
812*da0073e9SAndroid Build Coastguard Worker                super().__init__()
813*da0073e9SAndroid Build Coastguard Worker                self.inp = torch.nn.Parameter(torch.ones(2, 3))
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Worker            def add_nn_parameter_with_int(self, x, y):
816*da0073e9SAndroid Build Coastguard Worker                return torch.add(x, y)
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker            def forward(self, y):
819*da0073e9SAndroid Build Coastguard Worker                return self.add_nn_parameter_with_int(self.inp, y)
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker        make_global(TestNNParameter)
822*da0073e9SAndroid Build Coastguard Worker        pdt_model = TestNNParameter()
823*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
824*da0073e9SAndroid Build Coastguard Worker            pdt_model,
825*da0073e9SAndroid Build Coastguard Worker            example_inputs={
826*da0073e9SAndroid Build Coastguard Worker                pdt_model: [
827*da0073e9SAndroid Build Coastguard Worker                    (10,),
828*da0073e9SAndroid Build Coastguard Worker                ],
829*da0073e9SAndroid Build Coastguard Worker            },
830*da0073e9SAndroid Build Coastguard Worker        )
831*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn(20), pdt_model(20))
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker    def test_fx_tracing_with_typing(self):
834*da0073e9SAndroid Build Coastguard Worker        class FXModelOutput(NamedTuple):
835*da0073e9SAndroid Build Coastguard Worker            result: List[int]
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker        class FXModel(torch.nn.Module):
838*da0073e9SAndroid Build Coastguard Worker            def forward(self, a) -> FXModelOutput:
839*da0073e9SAndroid Build Coastguard Worker                result = FXModelOutput(result=a)
840*da0073e9SAndroid Build Coastguard Worker                return result
841*da0073e9SAndroid Build Coastguard Worker
842*da0073e9SAndroid Build Coastguard Worker        make_global(FXModel, FXModelOutput)
843*da0073e9SAndroid Build Coastguard Worker        pdt_model = FXModel()
844*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
845*da0073e9SAndroid Build Coastguard Worker            pdt_model,
846*da0073e9SAndroid Build Coastguard Worker            example_inputs={
847*da0073e9SAndroid Build Coastguard Worker                pdt_model: [
848*da0073e9SAndroid Build Coastguard Worker                    (
849*da0073e9SAndroid Build Coastguard Worker                        [
850*da0073e9SAndroid Build Coastguard Worker                            10,
851*da0073e9SAndroid Build Coastguard Worker                            20,
852*da0073e9SAndroid Build Coastguard Worker                        ],
853*da0073e9SAndroid Build Coastguard Worker                    ),
854*da0073e9SAndroid Build Coastguard Worker                ],
855*da0073e9SAndroid Build Coastguard Worker            },
856*da0073e9SAndroid Build Coastguard Worker        )
857*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scripted_fn([20]), pdt_model([20]))
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Worker    def test_nonetype_as_optional_of_type(self):
860*da0073e9SAndroid Build Coastguard Worker        def test_none(a) -> Any:
861*da0073e9SAndroid Build Coastguard Worker            if a is None:
862*da0073e9SAndroid Build Coastguard Worker                return 0
863*da0073e9SAndroid Build Coastguard Worker            else:
864*da0073e9SAndroid Build Coastguard Worker                return a + torch.ones(1)
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker        make_global(test_none)
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)])
869*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
870*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
871*da0073e9SAndroid Build Coastguard Worker                30.9,
872*da0073e9SAndroid Build Coastguard Worker            ),
873*da0073e9SAndroid Build Coastguard Worker            test_none(
874*da0073e9SAndroid Build Coastguard Worker                30.9,
875*da0073e9SAndroid Build Coastguard Worker            ),
876*da0073e9SAndroid Build Coastguard Worker        )
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)])
879*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
880*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
881*da0073e9SAndroid Build Coastguard Worker                2,
882*da0073e9SAndroid Build Coastguard Worker            ),
883*da0073e9SAndroid Build Coastguard Worker            test_none(
884*da0073e9SAndroid Build Coastguard Worker                2,
885*da0073e9SAndroid Build Coastguard Worker            ),
886*da0073e9SAndroid Build Coastguard Worker        )
887*da0073e9SAndroid Build Coastguard Worker
888*da0073e9SAndroid Build Coastguard Worker        scripted_fn = torch.jit.script(
889*da0073e9SAndroid Build Coastguard Worker            test_none, example_inputs=[(None,), (torch.Tensor(1),)]
890*da0073e9SAndroid Build Coastguard Worker        )
891*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
892*da0073e9SAndroid Build Coastguard Worker            scripted_fn(
893*da0073e9SAndroid Build Coastguard Worker                torch.ones(1),
894*da0073e9SAndroid Build Coastguard Worker            ),
895*da0073e9SAndroid Build Coastguard Worker            test_none(
896*da0073e9SAndroid Build Coastguard Worker                torch.ones(1),
897*da0073e9SAndroid Build Coastguard Worker            ),
898*da0073e9SAndroid Build Coastguard Worker        )
899