xref: /aosp_15_r20/external/pytorch/test/dynamo/test_functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Worker# flake8: noqa: E731, C405, F811, C418, C417
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport functools
5*da0073e9SAndroid Build Coastguard Workerimport inspect
6*da0073e9SAndroid Build Coastguard Workerimport itertools
7*da0073e9SAndroid Build Coastguard Workerimport math
8*da0073e9SAndroid Build Coastguard Workerimport operator
9*da0073e9SAndroid Build Coastguard Workerimport random
10*da0073e9SAndroid Build Coastguard Workerimport sys
11*da0073e9SAndroid Build Coastguard Workerimport unittest
12*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field
13*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, NamedTuple
14*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerimport numpy as np
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Workerimport torch
19*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
20*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
21*da0073e9SAndroid Build Coastguard Workerfrom torch import sub
22*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import (
23*da0073e9SAndroid Build Coastguard Worker    CompileCounterWithBackend,
24*da0073e9SAndroid Build Coastguard Worker    EagerAndRecordGraphs,
25*da0073e9SAndroid Build Coastguard Worker    normalize_gm,
26*da0073e9SAndroid Build Coastguard Worker)
27*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import ifdynstaticdefault, same
28*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.variables import ConstantVariable
29*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.variables.lists import RangeVariable
30*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F
31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
32*da0073e9SAndroid Build Coastguard Worker    disable_translation_validation_if_dynamic_shapes,
33*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
34*da0073e9SAndroid Build Coastguard Worker    parametrize,
35*da0073e9SAndroid Build Coastguard Worker)
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker# Defines all the kernels for tests
38*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.triton_utils import *  # noqa: F403
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
# Module-level fixtures. None are referenced in this chunk; presumably tests
# later in the file read them as globals (TODO confirm).
d = torch.ones((10, 10))
e = torch.nn.Linear(in_features=10, out_features=10)
flag = True
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
class CustomDictSubclass(collections.OrderedDict):
    """``OrderedDict`` subclass that adds no behavior of its own.

    Not referenced in this chunk; presumably used later in the file to check
    how dict subclasses are handled (TODO confirm).
    """
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker
# torch.clip with the bounds pinned to [0.0, 1.0]; test_functools_partial calls
# it to exercise functools.partial objects.
clip01 = functools.partial(
    torch.clip,
    min=0.0,
    max=1.0,
)
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker
def constant3(a, b):
    """Return ``a - b`` plus the constant expression ``1.0 + 2``.

    The test_methodcall* tests call this positionally and by keyword
    (``a=``/``b=``), so the parameter names are part of its contract.
    """
    offset = 1.0 + 2
    return a - b + offset
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
# Module-level counter; update_global() increments it on every call.
_variable = 0


def update_global(x):
    """Increment the module-level ``_variable``, then return ``x`` times it.

    The global is read back after the store to check that the updated value
    (not a stale one) is picked up.
    """
    global _variable
    _variable = _variable + 1
    return x * _variable
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker
def func_with_default(a, b, some_default_arg=True):
    """Return ``a - b`` when ``some_default_arg`` is truthy, otherwise ``None``."""
    if not some_default_arg:
        return None
    return a - b
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
def make_test(fn=None, expected_frame_count=1):
    """Wrap a plain function ``fn`` into a ``TestCase`` method.

    Works bare (``@make_test``) or with arguments
    (``@make_test(expected_frame_count=2)``). The generated method delegates to
    ``torch._dynamo.testing.standard_test``. It passes ``nargs``, the number of
    parameters in ``fn``'s signature; presumably that is how many inputs
    standard_test generates.
    """
    if fn is None:
        # Called with keyword arguments only: hand back the actual decorator.
        def decorator(inner_fn):
            return make_test(inner_fn, expected_frame_count=expected_frame_count)

        return decorator

    arg_count = len(inspect.signature(fn).parameters)

    def test_fn(self):
        return torch._dynamo.testing.standard_test(
            self,
            fn=fn,
            nargs=arg_count,
            expected_frame_count=expected_frame_count,
        )

    return test_fn
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker
class MyCls:
    """Plain class with a single class attribute ``a`` and no ``__eq__``.

    The comparison and ``hasattr`` tests rely on ``a`` existing, ``b`` not
    existing, and equality falling back to identity.
    """

    a = 1
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
@torch.jit.script_if_tracing
def inline_script_if_tracing(x):
    """Add 1.2 to ``x``; wrapped by ``script_if_tracing`` for the inlining test."""
    shifted = x + 1.2
    return shifted
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker
@torch.jit.ignore
def inline_ignore(x):
    """Add 3.4 to ``x``; marked ``torch.jit.ignore`` for the inlining test."""
    shifted = x + 3.4
    return shifted
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker
@torch.jit.unused
def inline_unused(x):
    """Add 5.6 to ``x``; marked ``torch.jit.unused`` for the inlining test."""
    shifted = x + 5.6
    return shifted
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache
def inline_lru_cache_fn_with_default_args(x, y, _=None):
    """Return ``sin(x * y)``; the trailing ``_`` is accepted and ignored."""
    product = x * y
    return torch.sin(product)
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker
@torch.jit.script_if_tracing
def inline_script_if_tracing_fn_with_default_args(x, y, c=1.2):
    """Return ``cos(x * y) + c``; ``c`` defaults to 1.2."""
    product = x * y
    return torch.cos(product) + c
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Workerclass FunctionTests(torch._dynamo.test_case.TestCase):
119*da0073e9SAndroid Build Coastguard Worker    @make_test
120*da0073e9SAndroid Build Coastguard Worker    def test_inline_jit_annotations(x):
121*da0073e9SAndroid Build Coastguard Worker        x = inline_script_if_tracing(x)
122*da0073e9SAndroid Build Coastguard Worker        x = inline_ignore(x)
123*da0073e9SAndroid Build Coastguard Worker        x = inline_unused(x)
124*da0073e9SAndroid Build Coastguard Worker        return
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker    @make_test
127*da0073e9SAndroid Build Coastguard Worker    def test_inline_script_if_tracing_fn_with_default_args(a, b):
128*da0073e9SAndroid Build Coastguard Worker        return inline_script_if_tracing_fn_with_default_args(a, b)
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker    @make_test
131*da0073e9SAndroid Build Coastguard Worker    def test_inline_lru_cache_fn_with_default_args(a, b):
132*da0073e9SAndroid Build Coastguard Worker        return inline_lru_cache_fn_with_default_args(a, 2, b)
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker    @make_test
135*da0073e9SAndroid Build Coastguard Worker    def test_add(a, b):
136*da0073e9SAndroid Build Coastguard Worker        return a + b
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker    @make_test
139*da0073e9SAndroid Build Coastguard Worker    def test_add_(a, b):
140*da0073e9SAndroid Build Coastguard Worker        a_copy = torch.tensor(a)
141*da0073e9SAndroid Build Coastguard Worker        return a_copy.add_(b, alpha=5.0)
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker    @make_test
144*da0073e9SAndroid Build Coastguard Worker    def test_addcdiv(a, b, c):
145*da0073e9SAndroid Build Coastguard Worker        # dynamo decomposes this to avoid a graph break when
146*da0073e9SAndroid Build Coastguard Worker        # the value kwarg is populated
147*da0073e9SAndroid Build Coastguard Worker        return torch.addcdiv(a, b, c, value=5.0)
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker    @make_test
150*da0073e9SAndroid Build Coastguard Worker    def test_addcdiv_(a, b, c):
151*da0073e9SAndroid Build Coastguard Worker        a_copy = torch.tensor(a)
152*da0073e9SAndroid Build Coastguard Worker        return a_copy.addcdiv_(b, c, value=5.0)
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker    @make_test
155*da0073e9SAndroid Build Coastguard Worker    def test_is_not_null(a, b):
156*da0073e9SAndroid Build Coastguard Worker        if a is not None and b is not None:
157*da0073e9SAndroid Build Coastguard Worker            return a + b
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker    def test_foreach_lerp_(self):
160*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, s):
161*da0073e9SAndroid Build Coastguard Worker            return torch._foreach_lerp_(x, y, s)
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
166*da0073e9SAndroid Build Coastguard Worker        expected = fn(
167*da0073e9SAndroid Build Coastguard Worker            [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14],
168*da0073e9SAndroid Build Coastguard Worker            [torch.ones(2, 2), torch.ones(2, 2)],
169*da0073e9SAndroid Build Coastguard Worker            torch.tensor(0.5),
170*da0073e9SAndroid Build Coastguard Worker        )
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker        actual = fn_opt(
173*da0073e9SAndroid Build Coastguard Worker            [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14],
174*da0073e9SAndroid Build Coastguard Worker            [torch.ones(2, 2), torch.ones(2, 2)],
175*da0073e9SAndroid Build Coastguard Worker            torch.tensor(0.5),
176*da0073e9SAndroid Build Coastguard Worker        )
177*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(expected, actual))
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_foreach_pow(self):
180*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.utils import same
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
183*da0073e9SAndroid Build Coastguard Worker            return torch._foreach_pow(x, y)
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
188*da0073e9SAndroid Build Coastguard Worker        inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)])
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker        actual = fn_opt(*inps)
191*da0073e9SAndroid Build Coastguard Worker        expected = fn(*inps)
192*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(actual, expected))
193*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(cnt.frame_count, 1)
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker    def test_addcmul_(self):
196*da0073e9SAndroid Build Coastguard Worker        from copy import deepcopy
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker        from torch._dynamo.utils import same
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z, s):
201*da0073e9SAndroid Build Coastguard Worker            return x.addcmul_(y, z, value=s)
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
204*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)
205*da0073e9SAndroid Build Coastguard Worker        inps = (
206*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2),
207*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2) + 1,
208*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 2),
209*da0073e9SAndroid Build Coastguard Worker            torch.tensor(0.3),
210*da0073e9SAndroid Build Coastguard Worker        )
211*da0073e9SAndroid Build Coastguard Worker        inps_2 = deepcopy(inps)
212*da0073e9SAndroid Build Coastguard Worker        actual = fn_opt(*inps)
213*da0073e9SAndroid Build Coastguard Worker        expected = fn(*inps_2)
214*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(actual, expected))
215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard Worker    @make_test
218*da0073e9SAndroid Build Coastguard Worker    def test_functools_partial(a, b):
219*da0073e9SAndroid Build Coastguard Worker        return clip01(a + b)
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker    @make_test
222*da0073e9SAndroid Build Coastguard Worker    def test_itertools_product(a, b):
223*da0073e9SAndroid Build Coastguard Worker        v = a
224*da0073e9SAndroid Build Coastguard Worker        for x, i in itertools.product([a, b], [1, 2]):
225*da0073e9SAndroid Build Coastguard Worker            v = v + x * i
226*da0073e9SAndroid Build Coastguard Worker        return v
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Worker    @make_test
229*da0073e9SAndroid Build Coastguard Worker    def test_itertools_chain(a, b):
230*da0073e9SAndroid Build Coastguard Worker        v = a
231*da0073e9SAndroid Build Coastguard Worker        for x in itertools.chain([a, b], [1, 2]):
232*da0073e9SAndroid Build Coastguard Worker            v = v + x
233*da0073e9SAndroid Build Coastguard Worker        return v
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker    @make_test
236*da0073e9SAndroid Build Coastguard Worker    def test_itertools_chain_from_iterable(a, b):
237*da0073e9SAndroid Build Coastguard Worker        v = a
238*da0073e9SAndroid Build Coastguard Worker        for x in itertools.chain.from_iterable([[a, b], [1, 2]]):
239*da0073e9SAndroid Build Coastguard Worker            v = v + x
240*da0073e9SAndroid Build Coastguard Worker        return v
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker    def test_itertools_reconstruct(self):
243*da0073e9SAndroid Build Coastguard Worker        def fn(a):
244*da0073e9SAndroid Build Coastguard Worker            it1 = itertools.repeat(1)
245*da0073e9SAndroid Build Coastguard Worker            it2 = itertools.count(2)
246*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
247*da0073e9SAndroid Build Coastguard Worker                a += next(it1)
248*da0073e9SAndroid Build Coastguard Worker                a += next(it2)
249*da0073e9SAndroid Build Coastguard Worker            return it1, it2, a
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
252*da0073e9SAndroid Build Coastguard Worker        i1, i2, a = fn(torch.ones(3, 3))
253*da0073e9SAndroid Build Coastguard Worker        it1, it2, b = opt_fn(torch.ones(3, 3))
254*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(next(i1), next(it1))
255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(next(i2), next(it2))
256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, b)
257*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_obj_eq(a, b):
        # Exercises == / != on fresh MyCls instances, against None and each
        # other. MyCls defines no __eq__, so equality falls back to identity:
        # the expected path is v.sin() followed by ``v + 1``. The integer
        # returns (-1/-2/-3) mark a wrong comparison result.
        v = a + b
        if MyCls() == None:  # noqa: E711
            return -1
        if MyCls() != None:  # noqa: E711
            v = v.sin()
        # Two separate MyCls() calls produce distinct objects.
        if MyCls() == MyCls():
            return -2
        if MyCls() != MyCls():
            return v + 1
        return -3
270*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_cls_eq(a, b):
        # Exercises == / != on the class object MyCls itself (not instances).
        # The expected path is v.sin() followed by ``v + 1``. The integer
        # returns (-1/-2/-3) mark a wrong comparison result.
        v = a + b
        if MyCls == None:  # noqa: E711
            return -1
        if MyCls != None:  # noqa: E711
            v = v.sin()
        # The same class object always compares equal to itself.
        if MyCls != MyCls:
            return -2
        if MyCls == MyCls:
            return v + 1
        return -3
283*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_obj_is(a, b):
        # Identity checks (is / is not) on fresh MyCls instances. The expected
        # path is v.sin() followed by ``v + 1``. The integer returns
        # (-1/-2/-3) mark a wrong identity result.
        v = a + b
        if MyCls() is None:  # noqa: E711
            return -1
        if MyCls() is not None:  # noqa: E711
            v = v.sin()
        # Two separate MyCls() calls never yield the same object.
        if MyCls() is MyCls():
            return -2
        if MyCls() is not MyCls():
            return v + 1
        return -3
296*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_cls_is(a, b):
        # Identity checks (is / is not) on the class object MyCls itself. The
        # expected path is v.sin() followed by ``v + 1``. The integer returns
        # (-1/-2/-3) mark a wrong identity result.
        v = a + b
        if MyCls is None:  # noqa: E711
            return -1
        if MyCls is not None:  # noqa: E711
            v = v.sin()
        # A class object is always identical to itself.
        if MyCls is not MyCls:
            return -2
        if MyCls is MyCls:
            return v + 1
        return -3
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker    @make_test
311*da0073e9SAndroid Build Coastguard Worker    def test_itertools_combinations(a, b):
312*da0073e9SAndroid Build Coastguard Worker        combs = []
313*da0073e9SAndroid Build Coastguard Worker        for size in itertools.combinations((1, 2, 3, 4), 2):
314*da0073e9SAndroid Build Coastguard Worker            combs.append(torch.ones(size))
315*da0073e9SAndroid Build Coastguard Worker        return combs
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker    @make_test
318*da0073e9SAndroid Build Coastguard Worker    def test_np_iinfo(a):
319*da0073e9SAndroid Build Coastguard Worker        max_dim = np.iinfo(np.int16).max
320*da0073e9SAndroid Build Coastguard Worker        return a + max_dim
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker    @make_test
323*da0073e9SAndroid Build Coastguard Worker    def test_np_finfo(a):
324*da0073e9SAndroid Build Coastguard Worker        min_dim = np.finfo(np.float32).min
325*da0073e9SAndroid Build Coastguard Worker        return a + min_dim
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker    @make_test
328*da0073e9SAndroid Build Coastguard Worker    def test_constant1(a, b, c):
329*da0073e9SAndroid Build Coastguard Worker        return a - b * c + 1.0
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker    @make_test
332*da0073e9SAndroid Build Coastguard Worker    def test_constant2(a, b, c):
333*da0073e9SAndroid Build Coastguard Worker        return a - b * c + 1
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker    @make_test
336*da0073e9SAndroid Build Coastguard Worker    def test_constant3(a):
337*da0073e9SAndroid Build Coastguard Worker        b = 1
338*da0073e9SAndroid Build Coastguard Worker        c = 2
339*da0073e9SAndroid Build Coastguard Worker        d = 3
340*da0073e9SAndroid Build Coastguard Worker        return b + c - d + a
341*da0073e9SAndroid Build Coastguard Worker
342*da0073e9SAndroid Build Coastguard Worker    @make_test
343*da0073e9SAndroid Build Coastguard Worker    def test_constant4(a, b):
344*da0073e9SAndroid Build Coastguard Worker        c = 2
345*da0073e9SAndroid Build Coastguard Worker        d = 3
346*da0073e9SAndroid Build Coastguard Worker        if c > d:
347*da0073e9SAndroid Build Coastguard Worker            return a - b
348*da0073e9SAndroid Build Coastguard Worker        return b - a
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker    @make_test
351*da0073e9SAndroid Build Coastguard Worker    def test_cls_hasattr(self, x):
352*da0073e9SAndroid Build Coastguard Worker        if hasattr(MyCls, "a"):
353*da0073e9SAndroid Build Coastguard Worker            x = x + 1
354*da0073e9SAndroid Build Coastguard Worker        if hasattr(MyCls, "b"):
355*da0073e9SAndroid Build Coastguard Worker            x = x + 2
356*da0073e9SAndroid Build Coastguard Worker        return x
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker    @make_test
359*da0073e9SAndroid Build Coastguard Worker    def test_finfo(a, b):
360*da0073e9SAndroid Build Coastguard Worker        if torch.iinfo(torch.int32).bits == 32:
361*da0073e9SAndroid Build Coastguard Worker            return torch.finfo(a.dtype).min * b
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker    @make_test
364*da0073e9SAndroid Build Coastguard Worker    def test_globalfn(a, b):
365*da0073e9SAndroid Build Coastguard Worker        return sub(a, b)
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker    @make_test
368*da0073e9SAndroid Build Coastguard Worker    def test_viatorch(a, b):
369*da0073e9SAndroid Build Coastguard Worker        return torch.sub(a, b)
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker    @make_test
372*da0073e9SAndroid Build Coastguard Worker    def test_viamethod(a, b):
373*da0073e9SAndroid Build Coastguard Worker        return a.sub(b)
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker    @make_test
376*da0073e9SAndroid Build Coastguard Worker    def test_indirect1(a, b):
377*da0073e9SAndroid Build Coastguard Worker        t = a.sub
378*da0073e9SAndroid Build Coastguard Worker        return t(b)
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker    @make_test
381*da0073e9SAndroid Build Coastguard Worker    def test_indirect2(a, b):
382*da0073e9SAndroid Build Coastguard Worker        t = a.sub
383*da0073e9SAndroid Build Coastguard Worker        args = (b,)
384*da0073e9SAndroid Build Coastguard Worker        return t(*args)
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker    @make_test
387*da0073e9SAndroid Build Coastguard Worker    def test_indirect3(a, b):
388*da0073e9SAndroid Build Coastguard Worker        t = a.sub
389*da0073e9SAndroid Build Coastguard Worker        args = (b,)
390*da0073e9SAndroid Build Coastguard Worker        kwargs = {}
391*da0073e9SAndroid Build Coastguard Worker        return t(*args, **kwargs)
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker    @make_test
394*da0073e9SAndroid Build Coastguard Worker    def test_methodcall1(a, b, c):
395*da0073e9SAndroid Build Coastguard Worker        return constant3(a, b) * c
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker    @make_test
398*da0073e9SAndroid Build Coastguard Worker    def test_methodcall2(a, b):
399*da0073e9SAndroid Build Coastguard Worker        return constant3(a=b, b=a) + 1
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker    @make_test
402*da0073e9SAndroid Build Coastguard Worker    def test_methodcall3(a, b):
403*da0073e9SAndroid Build Coastguard Worker        return constant3(a, b=1.0) + b
404*da0073e9SAndroid Build Coastguard Worker
405*da0073e9SAndroid Build Coastguard Worker    def test_is_integer(self):
406*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
407*da0073e9SAndroid Build Coastguard Worker        def forward(t, m):
408*da0073e9SAndroid Build Coastguard Worker            return 2 * t if m.is_integer() else t
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([1])
411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(forward(t, 1.0).item(), 2)
412*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(forward(t, 1.5).item(), 1)
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker    @parametrize(
415*da0073e9SAndroid Build Coastguard Worker        "method, num_type",
416*da0073e9SAndroid Build Coastguard Worker        (
417*da0073e9SAndroid Build Coastguard Worker            ("as_integer_ratio", int),
418*da0073e9SAndroid Build Coastguard Worker            ("bit_length", int),
419*da0073e9SAndroid Build Coastguard Worker            ("conjugate", int),
420*da0073e9SAndroid Build Coastguard Worker            ("as_integer_ratio", float),
421*da0073e9SAndroid Build Coastguard Worker            ("conjugate", float),
422*da0073e9SAndroid Build Coastguard Worker            ("hex", float),
423*da0073e9SAndroid Build Coastguard Worker            ("is_integer", float),
424*da0073e9SAndroid Build Coastguard Worker        ),
425*da0073e9SAndroid Build Coastguard Worker    )
426*da0073e9SAndroid Build Coastguard Worker    def test_number_method(self, method, num_type):
427*da0073e9SAndroid Build Coastguard Worker        def forward(t, m):
428*da0073e9SAndroid Build Coastguard Worker            return 2 * t if getattr(m, method)() else t
429*da0073e9SAndroid Build Coastguard Worker
430*da0073e9SAndroid Build Coastguard Worker        wrapped = torch.compile(backend="eager", fullgraph=True)(forward)
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker        for i in (0, 1, 2.5):
433*da0073e9SAndroid Build Coastguard Worker            m = num_type(i)
434*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor([1])
435*da0073e9SAndroid Build Coastguard Worker            actual = wrapped(t, m)
436*da0073e9SAndroid Build Coastguard Worker            expected = forward(t, m)
437*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, expected)
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker    @make_test
440*da0073e9SAndroid Build Coastguard Worker    def test_device_constant(a):
441*da0073e9SAndroid Build Coastguard Worker        return a + torch.ones(1, device=torch.device("cpu"))
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker    @make_test
444*da0073e9SAndroid Build Coastguard Worker    def test_tuple1(a, b):
445*da0073e9SAndroid Build Coastguard Worker        args = (a, b)
446*da0073e9SAndroid Build Coastguard Worker        return sub(*args)
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker    @make_test
449*da0073e9SAndroid Build Coastguard Worker    def test_tuple2(a, b):
450*da0073e9SAndroid Build Coastguard Worker        args = [a, b]
451*da0073e9SAndroid Build Coastguard Worker        return sub(*args)
452*da0073e9SAndroid Build Coastguard Worker
453*da0073e9SAndroid Build Coastguard Worker    @make_test
454*da0073e9SAndroid Build Coastguard Worker    def test_is_in_onnx_export(x, y):
455*da0073e9SAndroid Build Coastguard Worker        if torch.onnx.is_in_onnx_export():
456*da0073e9SAndroid Build Coastguard Worker            return x - 1
457*da0073e9SAndroid Build Coastguard Worker        else:
458*da0073e9SAndroid Build Coastguard Worker            return y + 1
459*da0073e9SAndroid Build Coastguard Worker
460*da0073e9SAndroid Build Coastguard Worker    @make_test
461*da0073e9SAndroid Build Coastguard Worker    def test_is_fx_tracing(x, y):
462*da0073e9SAndroid Build Coastguard Worker        if torch.fx._symbolic_trace.is_fx_tracing():
463*da0073e9SAndroid Build Coastguard Worker            return x - 1
464*da0073e9SAndroid Build Coastguard Worker        else:
465*da0073e9SAndroid Build Coastguard Worker            return y + 1
466*da0073e9SAndroid Build Coastguard Worker
467*da0073e9SAndroid Build Coastguard Worker    @make_test
468*da0073e9SAndroid Build Coastguard Worker    def test_listarg1(a, b):
469*da0073e9SAndroid Build Coastguard Worker        return torch.cat([a, b])
470*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_listarg2(a, b):
        # torch.cat with a positional tuple of tensors and an explicit `dim=`.
        return torch.cat((a, b), dim=0)
474*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_listarg3(a, b):
        # torch.cat called purely through **kwargs unpacking, including the
        # `tensors` argument itself.
        kwargs = {"tensors": (a, b), "dim": 0}
        return torch.cat(**kwargs)
479*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_listarg4(a, b):
        # torch.cat with `tensors=` passed as a literal keyword argument.
        return torch.cat(tensors=[a, b], dim=0)
483*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_listarg5(a, b):
        # torch.cat with both *args (a list wrapping a single tuple of tensors)
        # and **kwargs unpacking in the same call.
        args = [(a, b)]
        kwargs = {"dim": 0}
        return torch.cat(*args, **kwargs)
489*da0073e9SAndroid Build Coastguard Worker
    def test_list_slice(self):
        # A compiled *method* rebinds a list attribute to a slice of itself and
        # then appends to it; after a single call the attribute must hold
        # exactly the one appended tensor.
        class Mock:
            def __init__(self):
                self.ets = []
                self.counter = 0  # NOTE(review): not read anywhere in this test

            @torch.compile(backend="eager")
            def run(self, x):
                # Keep at most the last 3 entries, then append the new one.
                self.ets = self.ets[-3:]
                self.ets.append(x)
                return torch.sin(x)

        mock = Mock()
        mock.run(torch.randn(4))
        self.assertEqual(len(mock.ets), 1)
505*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_deque(a, b):
        # Exercises collections.deque construction and mutation methods on a
        # deque mixing tensors and strings. The final content is
        # ["foo", "setitem", b, a + 1, a, b].
        d = collections.deque([a, b])
        d.append(a + 1)
        d.extend([a, b])
        d.insert(0, "foo")
        tmp = d.pop()  # tmp is b

        another_deque = collections.deque([tmp])
        d.extendleft(another_deque)
        another_deque.clear()
        d.extend(another_deque)  # extending with a now-empty deque is a no-op

        d[2] = "setitem"
        d = d.copy()
        d.append(d.popleft())  # rotate the head element to the tail

        empty = collections.deque()
        d.extend(empty)

        return d
527*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_slice1(a):
        # Integer indexing along the first dimension.
        return a[5]
531*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_slice2(a):
        # Slice with only a stop bound.
        return a[:5]
535*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_slice3(a):
        # Slice with only a start bound.
        return a[5:]
539*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_slice4(a):
        # Slice with both start and stop bounds.
        return a[2:5]
543*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_slice5(a):
        # Slice with only a step.
        return a[::2]
547*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_slice6(a):
        # Add a leading dim with unsqueeze, then slice along dim 1 in a
        # multi-dimensional index.
        return torch.unsqueeze(a, 0)[:, 2:]
551*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_range1(a):
        # Builds a tensor from a Python range whose length is a.size(0).
        return torch.tensor(range(a.size(0)))
555*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_range2(x, y):
        # Loop whose trip count depends on x.size(0); the loop index is unused
        # and each iteration divides the running result by y.
        r = x + y
        for i in range(x.size(0) + 2):
            r = r / y
        return r
562*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_unpack1(a):
        # Unpacks an inline tuple of two slices, rebinding the parameter `a`.
        a, b = a[:5], a[5:]
        return a - b
567*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_unpack2(a):
        # Unpacks a pre-built *list* of two slices.
        packed = [a[:5], a[5:]]
        a, b = packed
        return a - b
573*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_unpack3(a):
        # Unpacks a pre-built *tuple* of two slices.
        packed = (a[:5], a[5:])
        a, b = packed
        return a - b
579*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_fn_with_self_set(a, b):
        # `F` is presumably torch.nn.functional, imported outside this view.
        # The outer-product-style broadcast of the unsqueezed inputs feeds a
        # pooling call with keyword arguments.
        # avg_pool2d is an odd one with __self__ set
        return F.avg_pool2d(
            torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
        )
586*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_return_tuple1(a, b):
        # Returns a tuple mixing freshly computed tensors with the raw inputs.
        return (a - b, b - a, a, b)
590*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_globalvar(a, b):
        # Reads the module-level global `d` (defined outside this view).
        return a - b + d
594*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_globalmodule(x):
        # Calls the module-level global `e` (defined outside this view);
        # presumably an nn.Module instance given the test name — verify.
        return e(x)
598*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_inline_with_default(a, b, c):
        # Calls func_with_default (defined outside this view) with two
        # arguments; presumably its remaining parameter(s) take defaults.
        return func_with_default(a, b) * c
602*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_inner_function(x):
        # Defines and immediately calls a nested function whose parameter
        # shadows the outer `x`.
        def fn(x):
            return torch.add(x, x)

        return fn(x)
609*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_transpose_for_scores(x):
        # torch.Size slicing plus tuple concatenation builds a new shape, then
        # view/permute. The view requires x.size(-1) == 10 (2 * 5), and
        # permute(0, 2, 1) requires the reshaped tensor to be 3-D, so x is
        # expected to be 2-D.
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)
615*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_return_tuple2(x):
        # Returns a tuple of a new tensor and the unchanged input.
        return (torch.add(x, x), x)
619*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_load_global_bool(x):
        # Branches on the module-level global `flag` (defined outside this view).
        if flag:
            return torch.add(x, x)
        else:
            return x
626*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_len_tensor(x):
        # len() of a tensor (its first-dimension size) used as a scalar operand.
        z = len(x)
        return torch.add(x, z)
631*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_len_constant_list(x):
        # len() of a constant list literal (3).
        z = len([1, 2, 3])
        return torch.add(x, z)
636*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_len_constant_dict(x):
        # len() of a constant dict literal (1).
        z = len({"foo": "bar"})
        return torch.add(x, z)
641*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_copy(x):
        # dict() applied to a dict literal (a shallow copy) returned as output.
        z = dict({"foo": x + 1})
        return z
646*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_keys(x):
        # A keys() view taken *before* inserting key 4 must still see it, so
        # the result is (True, True, False, True); the last item compares two
        # keys views ({3, 4} == {3, 4}).
        d = {3: x}
        keys = d.keys()
        d[4] = x + 1
        d2 = {3: 2, 4: "aa"}
        return 3 in keys, 4 in keys, 5 in keys, d2.keys() == keys
654*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_values(x):
        # A values() view reflects later updates and inserts: len(values) is 2
        # after overwriting key 3 and adding key 4.
        d = {3: x}
        values = d.values()
        d[3] = x + 1
        d[4] = x + 2
        return len(values)
662*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_setdefault1(x):
        # setdefault on an existing key must not overwrite it, so the `x + 1`
        # branch is expected.
        d = {"a": 1, "b": 2}
        d.setdefault("a", 10)
        if d["a"] == 1:
            return x + 1
        else:
            return x - 1
671*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_setdefault2(x):
        # setdefault on a missing key inserts the given default, so the
        # `x + 1` branch is expected.
        d = {"a": 1, "b": 2}
        d.setdefault("c", 10)
        if d["c"] == 10:
            return x + 1
        else:
            return x - 1
680*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_setdefault3(x):
        # setdefault with no default argument inserts None for a missing key.
        d = {"a": 1, "b": 2}
        d.setdefault("c")
        if d["c"] is None:
            return x + 1
        else:
            return x - 1
689*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_defaultdict_setdefault1(x):
        # Same as test_dict_setdefault1 but on a defaultdict built with the
        # inherited fromkeys classmethod: fromkeys("a", "b") seeds {"a": "b"}.
        # setdefault on the existing key "a" must keep the value 1.
        d = collections.defaultdict.fromkeys("a", "b")
        d["a"] = 1
        d["b"] = 2
        d.setdefault("a", 10)
        if d["a"] == 1:
            return x + 1
        else:
            return x - 1
700*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_defaultdict_setdefault2(x):
        # defaultdict built via fromkeys; setdefault on the missing key "c"
        # inserts the given default 10.
        d = collections.defaultdict.fromkeys("a", "b")
        d["a"] = 1
        d["b"] = 2
        d.setdefault("c", 10)
        if d["c"] == 10:
            return x + 1
        else:
            return x - 1
711*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_defaultdict_setdefault3(x):
        # defaultdict built via fromkeys; setdefault with no default argument
        # inserts None for the missing key "c".
        d = collections.defaultdict.fromkeys("a", "b")
        d["a"] = 1
        d["b"] = 2
        d.setdefault("c")
        if d["c"] is None:
            return x + 1
        else:
            return x - 1
722*da0073e9SAndroid Build Coastguard Worker
    def test_dict_id_guard(self):
        # d2 is an alias of d1 (the same OrderedDict object, not a copy). The
        # compiled fn reads through both closure names and must match eager;
        # presumably this checks guard installation when two sources refer to
        # one dict.
        d1 = collections.OrderedDict({"a": 2})
        d2 = d1

        def fn(x):
            # Iteration forces DictGuardManager
            for k in d1:
                x = x * d1[k] * d2[k]
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))
736*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_callable_lambda(x):
        # callable() on a lambda is always True, so the `x + 1` branch is expected.
        if callable(lambda x: True):
            return x + 1
        else:
            return x - 1
743*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_callable_torch(x):
        # callable() on a torch function object; the `x + 1` branch is expected.
        if callable(torch.abs):
            return x + 1
        else:
            return x - 1
750*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_callable_builtin(x):
        # callable() on a Python builtin; the `x + 1` branch is expected.
        if callable(sum):
            return x + 1
        else:
            return x - 1
757*da0073e9SAndroid Build Coastguard Worker
758*da0073e9SAndroid Build Coastguard Worker    def test_callable_class(self):
759*da0073e9SAndroid Build Coastguard Worker        class CallableClass:
760*da0073e9SAndroid Build Coastguard Worker            def __call__():
761*da0073e9SAndroid Build Coastguard Worker                pass
762*da0073e9SAndroid Build Coastguard Worker
763*da0073e9SAndroid Build Coastguard Worker        class NotCallableClass:
764*da0073e9SAndroid Build Coastguard Worker            pass
765*da0073e9SAndroid Build Coastguard Worker
766*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
767*da0073e9SAndroid Build Coastguard Worker        def fn1(x, arg):
768*da0073e9SAndroid Build Coastguard Worker            if callable(arg):
769*da0073e9SAndroid Build Coastguard Worker                return x
770*da0073e9SAndroid Build Coastguard Worker            return x + 1
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
773*da0073e9SAndroid Build Coastguard Worker        def fn2(x, arg):
774*da0073e9SAndroid Build Coastguard Worker            if callable(arg):
775*da0073e9SAndroid Build Coastguard Worker                return x * 2
776*da0073e9SAndroid Build Coastguard Worker            return x + 1
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(4)
779*da0073e9SAndroid Build Coastguard Worker
780*da0073e9SAndroid Build Coastguard Worker        for f in [fn1, fn2]:
781*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f(input, NotCallableClass()), input + 1)
782*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
783*da0073e9SAndroid Build Coastguard Worker                f(input, CallableClass()), input if f is fn1 else input * 2
784*da0073e9SAndroid Build Coastguard Worker            )
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker            # passing tensor and scalars
787*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f(input, 1), input + 1)
788*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f(input, 1.1), input + 1)
789*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f(input, True), input + 1)
790*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(f(input, input), input + 1)
791*da0073e9SAndroid Build Coastguard Worker
    def test_callable_list(self):
        # Lists and tuples are not callable, so the compiled fn takes the
        # `x + 1` path for both container types.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, arg):
            if callable(arg):
                return x
            return x + 1

        input = torch.randn(4)
        self.assertEqual(fn(input, [1, 2, 3]), input + 1)
        self.assertEqual(fn(input, (1, 2, 3)), input + 1)
802*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_len_constant_misc_iterables(x):
        # len() of a tuple literal (3) and of a str literal (8); adds 11 to x.
        a = len((1, 2, 3))
        b = len("test str")
        c = a + b
        return torch.add(x, c)
809*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_kwargs(x):
        # dict() constructed from keyword arguments with tensor values.
        z = dict(text_embed=x + 1, other=x + 2)
        return z
814*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_ordered_dict_kwargs(x):
        # OrderedDict constructed from keyword arguments; the input x is unused.
        z = collections.OrderedDict(sample=torch.ones(10))
        return z
819*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_custom_dict_kwargs(x):
        # Keyword construction of CustomDictSubclass (defined outside this view;
        # presumably a dict subclass, per its name). The input x is unused.
        z = CustomDictSubclass(sample=torch.ones(10))
        return z
824*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_float(x):
        # float() applied to a float literal and to a numeric string; y ends up
        # as 1.2 + 1.2.
        y = float(1.2)  # noqa: UP018
        y += float("1.2")
        return torch.add(x, y)
830*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_is_floating_point(x):
        # torch.is_floating_point called positionally and via the `input=`
        # keyword; returns both booleans.
        y = x + 1
        return torch.is_floating_point(y), torch.is_floating_point(input=y)
835*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dtype(x):
        # Branches on x.dtype; implicitly returns None when x is not float32.
        if x.dtype == torch.float32:
            return x + 1
840*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_get_default_dtype(x):
        # Compares x.dtype with the global default dtype.
        if x.dtype == torch.get_default_dtype():
            return x + 1
        else:
            return x - 1
847*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_get_autocast_gpu_dtype(x):
        # Casts x to whatever dtype torch.get_autocast_gpu_dtype() reports.
        # NOTE(review): newer PyTorch releases may deprecate this name in favour
        # of torch.get_autocast_dtype("cuda") — verify; the old name is what
        # this test exercises.
        dtype = torch.get_autocast_gpu_dtype()
        return x.type(dtype)
852*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_is_any_autocast_enabled(x):
        # Branches on the private torch._C._is_any_autocast_enabled() query.
        if torch._C._is_any_autocast_enabled():
            return x + 1
        else:
            return x - 1
859*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_is_checkpoint_valid(x):
        # Branches on the private torch.autograd._is_checkpoint_valid() query.
        if torch.autograd._is_checkpoint_valid():
            return x + 1
        else:
            return x - 1
866*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_compare_polyfill(x):
        # Applies all six rich comparisons to pairs of tuples: equal, differing
        # middle element, shorter prefix, longer, and a negative element. Each
        # true comparison adds a distinct power-of-two multiple of c, so the
        # result encodes every comparison outcome.
        for a, b, c in [
            [(1, 2, 3), (1, 2, 3), 7.77],
            [(1, 4, 3), (1, 2, 3), 3.33],
            [(1, 2), (1, 2, 3), 5.55],
            [(1, 2, 3), (1, 2), 11.11],
            [(1, -1, 3), (1, 2, 3), 13.33],
        ]:
            if a != b:
                x += 1 * c
            if a == b:
                x += 2 * c
            if a < b:
                x += 4 * c
            if a > b:
                x += 8 * c
            if a <= b:
                x += 16 * c
            if a >= b:
                x += 32 * c
        return x
889*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_promote_types(x):
        # Compares x.dtype with the promotion of int32 and float32 (expected to
        # be float32).
        if x.dtype == torch.promote_types(torch.int32, torch.float32):
            return x + 1
        else:
            return x - 1
896*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_cublas_allow_tf32(x):
        # Branches on the torch.backends.cuda.matmul.allow_tf32 backend flag.
        if torch.backends.cuda.matmul.allow_tf32:
            return x.sin() + 1

        return x.cos() - 1
903*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_get_calculate_correct_fan(x):
        # Calls the private torch.nn.init._calculate_correct_fan helper in
        # "fan_in" mode and adds the result to x.
        fan_in = torch.nn.init._calculate_correct_fan(x, "fan_in")
        return x + fan_in
908*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_is_complex(x):
        # Functional form torch.is_complex(x) (test_tensor_is_complex covers
        # the method form).
        if torch.is_complex(x):
            return x + 1
        else:
            return x - 1
915*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tensor_is_complex(x):
        # Method form x.is_complex().
        if x.is_complex():
            return x + 1
        else:
            return x - 1
922*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_get_privateuse1_name(x):
        # Compares the private-use-1 backend name with "privateuseone"
        # (presumably its default when no backend has been renamed).
        if torch._C._get_privateuse1_backend_name() == "privateuseone":
            return x + 1
        else:
            return x - 1
929*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_device(x):
        # Reads the Tensor.is_cuda attribute; implicitly returns None for a
        # CUDA tensor.
        if not x.is_cuda:
            return x + 1
934*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_get_device_properties_tensor_device(a):
        # Queries CUDA device properties using a tensor's `.device` and branches
        # on `prop.major` (presumably the compute-capability major version).
        x = a.to("cuda")
        prop = torch.cuda.get_device_properties(x.device)
        if prop.major == 8:
            return x + prop.multi_processor_count
        return x + prop.max_threads_per_multi_processor
943*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tensor_type(a, b):
        """Cast ``b`` using the ``Tensor.type()`` string of a float16 copy of ``a``."""
        m = a.to(torch.float16)
        return b.type(m.type())
948*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type2(a, b):
        """Cast ``b`` to the ``type()`` string of a CUDA copy of ``a``, then add."""
        m = a.to("cuda")
        return m + b.type(m.type())
954*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tensor_type3(a, b):
        """``Tensor.type`` given a tensor-type class (``torch.HalfTensor``)."""
        m = a.type(torch.HalfTensor)
        return b.type(m.type())
959*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tensor_type4(a, b):
        """``Tensor.type`` given a type-name string (``"torch.HalfTensor"``)."""
        m = a.type("torch.HalfTensor")
        return b.type(m.type())
964*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type5(a, b):
        """``Tensor.type`` given a CUDA tensor-type class (``torch.cuda.HalfTensor``)."""
        m = a.type(torch.cuda.HalfTensor)
        return b.type(m.type())
970*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tensor_element_size(a):
        """Use ``element_size()`` both in a branch condition and as a scalar operand."""
        if a.element_size() > 1:
            return (a + a.element_size(), a - a.element_size())
        return (a - a.element_size(), a + a.element_size())
976*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_ndim(x):
        """Check ``ndim``, ``ndimension()`` and ``dim()`` together.

        Returns ``x + 1`` when all three equal 2, otherwise None.
        """
        if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
            return x + 1
981*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_T(x):
        """Access the ``Tensor.T`` property and feed it to ``torch.ones_like``."""
        return torch.ones_like(x.T)
985*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_mT(x):
        """Access the ``Tensor.mT`` property and feed it to ``torch.ones_like``."""
        return torch.ones_like(x.mT)
989*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_is_sparse(x):
        """Read ``Tensor.is_sparse``; returns ``x + 1`` when False, else None."""
        if not x.is_sparse:
            return x + 1
994*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_shape1(x):
        """Compare ``x.shape[0]`` to 10; returns ``x + 1`` on match, else None."""
        if x.shape[0] == 10:
            return x + 1
999*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_shape2(x):
        """Compare ``x.size(1)`` to 10; returns ``x + 1`` on match, else None."""
        if x.size(1) == 10:
            return x + 1
1004*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_del(a, b):
        """Multi-target ``del`` of a temporary and of an argument mid-function."""
        c = a + 1
        d = c + 2
        del c, a
        return b + d
1011*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_chunks1(x):
        """Shape asserts on ``x.shape[0]`` followed by splitting ``x`` into two slices."""
        chunk_size = 5
        # Together these asserts require x.shape[0] == 2 * chunk_size.
        assert x.shape[0] % chunk_size == 0
        assert x.shape[0] // chunk_size == 2
        return x[:chunk_size] - x[chunk_size:]
1018*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_import1(x, y):
        """Function-local ``import`` / ``from ... import`` statements in traced code."""
        import torch
        from torch import sub

        return sub(torch.add(x, y), y)
1025*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_return_dict(x, y):
        """Return a dict in which the same tensor and list appear under several keys.

        The list also holds a non-tensor value (``False``).
        """
        z = [x + y, y, False]
        return {"x": x, "z": z, "a": x, "b": z, "c": x}
1030*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_return_dict2(x, y):
        """Build a dict incrementally and mutate a list stored in it before returning."""
        tmp = {"x": x}
        tmp["z"] = [x + y, y]
        tmp["y"] = y
        tmp["z"].append(False)
        return tmp
1038*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_funcdef_closure(x, y):
        """Nested function that rebinds both enclosing locals via ``nonlocal``."""
        x = x + y + 1.0

        def inner(z):
            # Each call rebinds the enclosing function's x and y.
            nonlocal x, y
            y = x + z + 20.0
            x = y + z + 10.0

        inner(2.0)
        inner(3.0)

        return x, y
1052*da0073e9SAndroid Build Coastguard Worker
1053*da0073e9SAndroid Build Coastguard Worker    @make_test
1054*da0073e9SAndroid Build Coastguard Worker    def test_module_constant(x, y):
1055*da0073e9SAndroid Build Coastguard Worker        r = x + y
1056*da0073e9SAndroid Build Coastguard Worker        for i in range(torch._dynamo.testing.three):
1057*da0073e9SAndroid Build Coastguard Worker            r = r / y
1058*da0073e9SAndroid Build Coastguard Worker        return r
1059*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_inline_softmax(x, y):
        """Construct and immediately call an ``nn.Softmax`` module inside traced code."""
        # This is common in some huggingface models
        return torch.nn.Softmax(dim=-1)(x + y * 2)
1064*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dtype_compare(a, b):
        """Branch on ``a.dtype``.

        Returns None for dtypes other than float16 and float32.
        """
        if a.dtype == torch.float16:
            return a + 10
        if a.dtype == torch.float32:
            return a - b * 32
1071*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_build_list_unpack(a, b):
        """Star-unpack two generator expressions into a list literal.

        The name refers to CPython's BUILD_LIST_UNPACK opcode, which newer
        interpreters replace with LIST_EXTEND.
        """
        it1 = (x + 1 for x in (a, b))
        it2 = (x - 1 for x in (a, b))
        return torch.cat([*it1, *it2], dim=-1)
1077*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tensor_len(a, b):
        """Use ``len(a)`` and ``b.__len__()`` as Python ints in tensor arithmetic."""
        return a + b + len(a) + b.__len__()
1081*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_pop(a, b):
        """Exercise ``list.append``/``extend`` and ``pop`` with -1, 0 and no index."""
        ll = [a, b]
        ll.append(a + 1)
        ll.extend(
            [
                b + 2,
                a + b,
            ]
        )
        # ll is now [a, b, a + 1, b + 2, a + b]
        ll.pop(-1)  # drops a + b
        ll.pop(0)  # drops a
        ll.pop()  # drops b + 2
        v1, v2 = ll  # remaining: [b, a + 1]
        return v1 - v2
1097*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_convert(a, b):
        """Round-trip a list through ``tuple()`` and ``list()`` with unrelated work between."""
        ll = [a + 2, b]
        ll = tuple(ll)
        tmp = b + 3
        ll = list(ll)
        v1, v2 = ll
        return v1 - v2 + tmp
1106*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_add(a, b):
        """Concatenate a tuple of tensors with a constant empty tuple."""
        l1 = (a, b)
        l2 = ()  # being a LOAD_CONST in the bytecode
        l3 = l1 + l2
        return l3[0] + l3[1]
1113*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_index_with_constant_tensor(a, b):
        """Index a Python list with a 0-dim tensor; selects element 2 (``a + 1``)."""
        l1 = [a, b, a + 1, b + 1]
        return l1[torch.as_tensor(2)]
1118*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_startswith(a, b):
        """Evaluate constant string methods and a ``__module__`` lookup in a condition."""
        x = a + b
        # constant3 is a module-level name defined elsewhere in this file.
        if "foobar".startswith("foo") and "test" in constant3.__module__:
            x = x + 1
        return x
1125*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_ops(a, b):
        """Exercise several dict operations on a dict built inside the function.

        Covers ``get``/``pop`` with and without defaults, ``update``, item
        assignment, ``in``/``not in`` and ``len``.
        """
        tmp = {"a": a + 1, "b": b + 2}
        assert tmp.get("zzz") is None
        v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
        tmp.update({"d": 3})
        tmp["c"] = v + tmp["d"]
        # Keys are now "a", "d", "c", so len(tmp) == 3.
        if "c" in tmp and "missing" not in tmp:
            return tmp["c"] - tmp["a"] + len(tmp)
1135*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_inline_jit__unwrap_optional(x):
        """Call ``torch.jit._unwrap_optional`` and compare its result with None."""
        if torch.jit._unwrap_optional(x) is None:
            return torch.ones(2, 2)
        return x.sin()
1141*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_zip_longest(x):
        """Return a tensor together with ``itertools.zip_longest`` over uneven lists.

        Shorter lists are padded with ``fillvalue=None``.
        """
        list1 = [1, 2, 3]
        list2 = ["a", "b"]
        list3 = [True, False, True, False]
        return torch.sin(x + 1), list(
            itertools.zip_longest(list1, list2, list3, fillvalue=None)
        )
1150*da0073e9SAndroid Build Coastguard Worker
1151*da0073e9SAndroid Build Coastguard Worker    def test_torch_size_as_dict_key(self):
1152*da0073e9SAndroid Build Coastguard Worker        def fn(x, cached):
1153*da0073e9SAndroid Build Coastguard Worker            if x.shape not in cached:
1154*da0073e9SAndroid Build Coastguard Worker                cached[x.shape] = x
1155*da0073e9SAndroid Build Coastguard Worker            return x + cached[x.shape]
1156*da0073e9SAndroid Build Coastguard Worker
1157*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1158*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn(2, 3)
1159*da0073e9SAndroid Build Coastguard Worker        x2 = torch.randn(2, 3)
1160*da0073e9SAndroid Build Coastguard Worker        cached = {}
1161*da0073e9SAndroid Build Coastguard Worker        ref1 = fn(x1, cached)
1162*da0073e9SAndroid Build Coastguard Worker        ref2 = fn(x2, cached)
1163*da0073e9SAndroid Build Coastguard Worker        cached = {}
1164*da0073e9SAndroid Build Coastguard Worker        res1 = opt_fn(x1, cached)
1165*da0073e9SAndroid Build Coastguard Worker        res2 = opt_fn(x2, cached)
1166*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref1, res1)
1167*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref2, res2)
1168*da0073e9SAndroid Build Coastguard Worker
1169*da0073e9SAndroid Build Coastguard Worker    def test_dict_param_keys(self):
1170*da0073e9SAndroid Build Coastguard Worker        a_param = torch.nn.Parameter(torch.ones([4, 4]))
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker        def fn(a):
1173*da0073e9SAndroid Build Coastguard Worker            tmp = {"a": a, a_param: 3}
1174*da0073e9SAndroid Build Coastguard Worker            return tmp["a"] + tmp[a_param]
1175*da0073e9SAndroid Build Coastguard Worker
1176*da0073e9SAndroid Build Coastguard Worker        test = make_test(fn)
1177*da0073e9SAndroid Build Coastguard Worker        test(self)
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker    def test_dict_mutable_map(self):
1180*da0073e9SAndroid Build Coastguard Worker        from collections.abc import MutableMapping
1181*da0073e9SAndroid Build Coastguard Worker
1182*da0073e9SAndroid Build Coastguard Worker        class TensorDict(MutableMapping):
1183*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1184*da0073e9SAndroid Build Coastguard Worker                self._dict = {}
1185*da0073e9SAndroid Build Coastguard Worker
1186*da0073e9SAndroid Build Coastguard Worker            def add(self, key, value):
1187*da0073e9SAndroid Build Coastguard Worker                self._dict[key] = value
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker            def items(self):
1190*da0073e9SAndroid Build Coastguard Worker                return self._dict.items()
1191*da0073e9SAndroid Build Coastguard Worker
1192*da0073e9SAndroid Build Coastguard Worker            def __delitem__(self, key):
1193*da0073e9SAndroid Build Coastguard Worker                del self._dict[key]
1194*da0073e9SAndroid Build Coastguard Worker
1195*da0073e9SAndroid Build Coastguard Worker            def __getitem__(self, key):
1196*da0073e9SAndroid Build Coastguard Worker                return self._dict[key]
1197*da0073e9SAndroid Build Coastguard Worker
1198*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
1199*da0073e9SAndroid Build Coastguard Worker                return iter(self._dict)
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
1202*da0073e9SAndroid Build Coastguard Worker                return len(self._dict)
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker            def __setitem__(self, key, value):
1205*da0073e9SAndroid Build Coastguard Worker                self._dict[key] = value
1206*da0073e9SAndroid Build Coastguard Worker
1207*da0073e9SAndroid Build Coastguard Worker        tensor_dict = TensorDict()
1208*da0073e9SAndroid Build Coastguard Worker        tensor_dict.add("a", torch.ones(4) * 2)
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1211*da0073e9SAndroid Build Coastguard Worker            copy_tensordict = dict(tensor_dict)
1212*da0073e9SAndroid Build Coastguard Worker            return x * copy_tensordict["a"]
1213*da0073e9SAndroid Build Coastguard Worker
1214*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1215*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
1216*da0073e9SAndroid Build Coastguard Worker
1217*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
1218*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
1219*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker    def test_unpack_mutable_map(self):
1222*da0073e9SAndroid Build Coastguard Worker        from collections.abc import MutableMapping
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker        class TensorDict(MutableMapping):
1225*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1226*da0073e9SAndroid Build Coastguard Worker                self._dict = {}
1227*da0073e9SAndroid Build Coastguard Worker
1228*da0073e9SAndroid Build Coastguard Worker            def add(self, key, value):
1229*da0073e9SAndroid Build Coastguard Worker                self._dict[key] = value
1230*da0073e9SAndroid Build Coastguard Worker
1231*da0073e9SAndroid Build Coastguard Worker            def items(self):
1232*da0073e9SAndroid Build Coastguard Worker                return self._dict.items()
1233*da0073e9SAndroid Build Coastguard Worker
1234*da0073e9SAndroid Build Coastguard Worker            def __delitem__(self, key):
1235*da0073e9SAndroid Build Coastguard Worker                del self._dict[key]
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker            def __getitem__(self, key):
1238*da0073e9SAndroid Build Coastguard Worker                return self._dict[key]
1239*da0073e9SAndroid Build Coastguard Worker
1240*da0073e9SAndroid Build Coastguard Worker            def __iter__(self):
1241*da0073e9SAndroid Build Coastguard Worker                return iter(self._dict)
1242*da0073e9SAndroid Build Coastguard Worker
1243*da0073e9SAndroid Build Coastguard Worker            def __len__(self):
1244*da0073e9SAndroid Build Coastguard Worker                return len(self._dict)
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker            def __setitem__(self, key, value):
1247*da0073e9SAndroid Build Coastguard Worker                self._dict[key] = value
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker        tensor_dict = TensorDict()
1250*da0073e9SAndroid Build Coastguard Worker        tensor_dict.add("a", torch.ones(4) * 2)
1251*da0073e9SAndroid Build Coastguard Worker
1252*da0073e9SAndroid Build Coastguard Worker        def gn(x, a=1):
1253*da0073e9SAndroid Build Coastguard Worker            return x * a
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1256*da0073e9SAndroid Build Coastguard Worker            return gn(x, **tensor_dict)
1257*da0073e9SAndroid Build Coastguard Worker
1258*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1259*da0073e9SAndroid Build Coastguard Worker
1260*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
1261*da0073e9SAndroid Build Coastguard Worker
1262*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
1263*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
1264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker    def _test_default_dict_helper(self, factory):
1267*da0073e9SAndroid Build Coastguard Worker        dd = collections.defaultdict(factory)
1268*da0073e9SAndroid Build Coastguard Worker        param = torch.nn.Parameter(torch.ones([2, 2]))
1269*da0073e9SAndroid Build Coastguard Worker
1270*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1271*da0073e9SAndroid Build Coastguard Worker            dd["a"] = x + 1
1272*da0073e9SAndroid Build Coastguard Worker            dd[param] = 123
1273*da0073e9SAndroid Build Coastguard Worker            dd["c"] = x * 2
1274*da0073e9SAndroid Build Coastguard Worker            return dd["b"], dd
1275*da0073e9SAndroid Build Coastguard Worker
1276*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
1277*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
1278*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize_assert("eager")(fn)
1279*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[0], res[0]))
1282*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1]["a"], res[1]["a"]))
1283*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1]["c"], res[1]["c"]))
1284*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1][param], res[1][param]))
1285*da0073e9SAndroid Build Coastguard Worker
    def test_default_dict_dict(self):
        """defaultdict with the builtin ``dict`` as its factory."""
        self._test_default_dict_helper(dict)
1288*da0073e9SAndroid Build Coastguard Worker
    def test_default_dict_list(self):
        """defaultdict with the builtin ``list`` as its factory."""
        self._test_default_dict_helper(list)
1291*da0073e9SAndroid Build Coastguard Worker
    def test_default_dict_tuple(self):
        """defaultdict with the builtin ``tuple`` as its factory."""
        self._test_default_dict_helper(tuple)
1294*da0073e9SAndroid Build Coastguard Worker
    def test_default_dict_set(self):
        """defaultdict with the builtin ``set`` as its factory."""
        self._test_default_dict_helper(set)
1297*da0073e9SAndroid Build Coastguard Worker
    def test_default_dict_lambda(self):
        """defaultdict whose factory is a lambda calling ``dict()``."""
        self._test_default_dict_helper(lambda: dict())  # noqa: C408
1300*da0073e9SAndroid Build Coastguard Worker
    def test_default_dict_closure(self):
        """defaultdict whose factory is a locally defined (nested) function."""

        def factory():
            return dict()  # noqa: C408

        self._test_default_dict_helper(factory)
1306*da0073e9SAndroid Build Coastguard Worker
1307*da0073e9SAndroid Build Coastguard Worker    def test_class_dict(self):
1308*da0073e9SAndroid Build Coastguard Worker        class A:
1309*da0073e9SAndroid Build Coastguard Worker            x = 4
1310*da0073e9SAndroid Build Coastguard Worker            y = 5
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1313*da0073e9SAndroid Build Coastguard Worker                self.a = 6
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker        a = A()
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1318*da0073e9SAndroid Build Coastguard Worker            if "x" in type(a).__dict__:
1319*da0073e9SAndroid Build Coastguard Worker                return x + 1
1320*da0073e9SAndroid Build Coastguard Worker            return x + 2
1321*da0073e9SAndroid Build Coastguard Worker
1322*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1323*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
1324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), opt_fn(x))
1325*da0073e9SAndroid Build Coastguard Worker
1326*da0073e9SAndroid Build Coastguard Worker    def test_default_dict_constr(self):
1327*da0073e9SAndroid Build Coastguard Worker        param = torch.nn.Parameter(torch.ones([2, 2]))
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1330*da0073e9SAndroid Build Coastguard Worker            dd = collections.defaultdict(lambda: dict())  # noqa: C408
1331*da0073e9SAndroid Build Coastguard Worker            dd["a"] = x + 1
1332*da0073e9SAndroid Build Coastguard Worker            dd[param] = 123
1333*da0073e9SAndroid Build Coastguard Worker            dd["c"] = x * 2
1334*da0073e9SAndroid Build Coastguard Worker            dd.update({"b": x * 3})
1335*da0073e9SAndroid Build Coastguard Worker            dd.update([["d", x - 2], ("e", x + 2)])
1336*da0073e9SAndroid Build Coastguard Worker            dd.update(zip("ab", [x + 3, x + 4]))
1337*da0073e9SAndroid Build Coastguard Worker            return dd["b"], dd
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
1340*da0073e9SAndroid Build Coastguard Worker        ref = fn(x)
1341*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize_assert("eager")(fn)
1342*da0073e9SAndroid Build Coastguard Worker        res = opt_fn(x)
1343*da0073e9SAndroid Build Coastguard Worker
1344*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[0], res[0]))
1345*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1]["a"], res[1]["a"]))
1346*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1]["b"], res[1]["b"]))
1347*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1]["c"], res[1]["c"]))
1348*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1]["d"], res[1]["d"]))
1349*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1]["e"], res[1]["e"]))
1350*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(ref[1][param], res[1][param]))
1351*da0073e9SAndroid Build Coastguard Worker
    def test_dict_tuple_lazy_guard(self):
        """Changing entries of ``y`` that ``fn`` never reads must not recompile.

        ``fn`` only reads ``y[1]``. ``error_on_recompile`` turns any recompile
        inside the patched blocks into an error.
        """

        @torch.compile(backend="eager")
        def fn(x, y):
            return torch.sin(x) * y[1]

        fn(torch.randn(3), {1: 1, 2: 2})
        # Changing the value of another key should not cause recompilation
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(3), {1: 1, 2: 3})

        # Switching y from a dict to a tuple is presumably a recompile, so this
        # call stays outside the patch.
        fn(torch.randn(3), (1, 2, 3))
        # Changing the value of index 0, 2 (not 1) should not cause recompilation
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(3), (11, 2, 13))
1366*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_call_dict1(x):
        """Create empty containers via ``dict()`` and ``collections.OrderedDict()`` calls."""
        d1 = dict()  # noqa: C408
        d1["x"] = x + 1
        d2 = collections.OrderedDict()
        d2["x"] = x + 2
        return d1["x"] + d2["x"] + 1
1374*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_call_dict2(x):
        """Build an ``OrderedDict`` from a dict, then branch on ``isinstance``."""
        d1 = dict()  # noqa: C408
        d1["x"] = x
        d2 = collections.OrderedDict(d1)
        if isinstance(d2, collections.OrderedDict):
            return x + 1
        else:
            return x - 1
1384*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_call_dict3(x):
        """Build ``dict`` and ``OrderedDict`` from a list of key/value tuples."""
        my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = collections.OrderedDict(my_list)
        d2["c"] = x + 20
        return d1["a"] + d2["c"] + 1
1393*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_call_dict4(x):
        # dict() and OrderedDict() built from the same *tuple* of (key, value)
        # pairs; each container is then mutated independently.
        my_list = (("a", x), ("b", x + 1), ("c", x + 2))
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = collections.OrderedDict(my_list)
        d2["c"] = x + 20
        return d1["a"] + d2["c"] + 1
1402*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_call_dict5(x):
        # ``my_list`` is a single iterator, not a list: dict() consumes it
        # completely, so the OrderedDict built from it afterwards starts empty
        # and only holds the "c" entry assigned below.
        my_list = iter([("a", x), ("b", x + 1), ("c", x + 2)])
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = collections.OrderedDict(my_list)
        d2["c"] = x + 20
        return d1["a"] + d2["c"] + 1
1411*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_fromkeys(x, y):
        # Exercises the ``fromkeys`` classmethod with several key sources:
        # a list (values default to None), an existing dict (iterated for its
        # keys), an iterator (on defaultdict), and a tuple with ``value``
        # passed by keyword (on OrderedDict).
        lst = ["a", "b"]
        d = dict.fromkeys(lst)
        d1 = dict.fromkeys(d, x + 1)
        d2 = collections.defaultdict.fromkeys(iter(d1), x - 2)
        d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
        return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1
1420*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_copy(x):
        # ``.copy()`` on a dict and on an OrderedDict must yield independent
        # containers: writes to the copies (d2, d4) must not show up in the
        # originals (d1, d3), and vice versa.
        my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
        d1 = dict(my_list)
        d1["a"] = x + 10
        d2 = d1.copy()
        d2["a"] = x - 5
        d2["b"] = x + 3
        d3 = collections.OrderedDict(my_list)
        d3["c"] = x + 20
        d4 = d3.copy()
        d4["c"] = x - 10
        return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1
1434*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_update(x, y, z):
        # ``update`` on dict and OrderedDict with three argument kinds: a
        # mapping, a list of 2-element pairs (mixing tuples and lists), and a
        # zip iterator. Later updates overwrite earlier values for "a" and "b".
        d = {"a": x, "b": y}
        d.update({"a": y - 1})
        d.update([("b", z + 1), ["c", z]])
        d.update(zip("ab", [z + 3, y + 2]))

        # Same sequence on an OrderedDict that was built from keyword arguments.
        od = collections.OrderedDict(a=x * 3, b=y + 2)
        od.update({"a": y + 5})
        od.update([["b", z + 6], ("c", z - 7)])
        od.update(zip("ab", [z - 3, x + 2]))
        return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"]
1447*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_min_max(a, b):
        # Builtin min()/max() mixing reductions of the inputs with int
        # constants, in both argument orders. Assuming make_test passes
        # tensors, ``.sum()`` yields 0-d tensors that are then clamped to [0, 1].
        c = a + b
        a = a.sum()
        b = b.sum()
        a = min(max(a, 0), 1)
        b = max(0, min(1, b))
        return max(a, b) - min(a, b) + c
1456*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_symbool_to_int(x):
        # this is roughly the pattern found in einops.unpack()
        # sum() over the booleans from comparing each dimension to -1; no
        # tensor size is -1, so the sum is 0 and the ``x + 1`` branch runs.
        # Presumably, with symbolic sizes each comparison is a SymBool that
        # sum() must turn into an int (hence the test name) — TODO confirm.
        if sum(s == -1 for s in x.size()) == 0:
            return x + 1
        else:
            return x - 1
1464*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_sum(a, b, c, d):
        # Builtin sum() consuming a map() object driven by a lambda.
        return sum(map(lambda x: x + 1, [a, b, c, d]))
1468*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_sum(a, b, c, d):
        # Builtin sum() over a list of inputs with the implicit start value 0.
        return sum([a, b, c, d])
1472*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_sum_with_start_arg(a, b, c, d):
        # Builtin sum() with the start value passed positionally.
        return sum([b, c, d], a)
1476*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_sum_with_start_kwarg(a, b, c, d):
        # Builtin sum() with ``start`` passed as a keyword argument.
        return sum([b, c, d], start=a)
1480*da0073e9SAndroid Build Coastguard Worker
    @make_test(expected_frame_count=0)
    def test_sum_shortcut():
        # sum() over mixed int/float constants only; no tensor work happens,
        # and expected_frame_count=0 presumably asserts that no frame is
        # compiled for it.
        return sum([0, 1.0, 2, 3.0])
1484*da0073e9SAndroid Build Coastguard Worker
    @make_test(expected_frame_count=0)
    def test_sum_shortcut_with_start_arg():
        # Constant-only sum() with a positional start value; no tensor work,
        # so presumably no frame is expected to be compiled.
        return sum([0, 1.0, 2, 3.0], -10)
1488*da0073e9SAndroid Build Coastguard Worker
    @make_test(expected_frame_count=0)
    def test_sum_shortcut_with_start_kwarg():
        # Constant-only sum() with ``start`` passed by keyword; no tensor work,
        # so presumably no frame is expected to be compiled.
        return sum([0, 1.0, 2, 3.0], start=-10)
1492*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_reduce(a, b, c, d):
        # functools.reduce with operator.add and no initializer.
        return functools.reduce(operator.add, [a, b, c, d])
1496*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_reduce_with_initial(a, b, c, d):
        # functools.reduce with operator.add and ``a`` as the initializer.
        return functools.reduce(operator.add, [b, c, d], a)
1500*da0073e9SAndroid Build Coastguard Worker
    @make_test(expected_frame_count=0)
    def test_reduce_with_single(x):
        # A one-element iterable with no initializer: reduce returns that
        # element without ever calling the lambda, so no tensor op happens
        # (hence expected_frame_count=0).
        return functools.reduce(lambda a, b: (a, b), [x])
1504*da0073e9SAndroid Build Coastguard Worker
    @make_test(expected_frame_count=0)
    def test_reduce_with_single_with_initial(x, y):
        # Initializer ``x`` plus one element: the lambda runs once and returns
        # the tuple (x, y). Only a tuple is built, no tensor op.
        return functools.reduce(lambda a, b: (a, b), [y], x)
1508*da0073e9SAndroid Build Coastguard Worker
    @make_test(expected_frame_count=0)
    def test_reduce_with_none_initial(x):
        # An explicit ``None`` initializer is a real starting value (not
        # "no initializer"), so the result is the tuple (None, x).
        return functools.reduce(lambda a, b: (a, b), [x], None)
1512*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tuple_contains(a, b):
        # ``in`` / ``not in`` on tuples of string constants (one tuple built
        # from local variables, one from literals). Both conditions hold, so
        # the ``a + b`` path is taken.
        v1 = "a"
        v2 = "b"
        v3 = "c"
        vals1 = (v1, v2, v3)
        vals2 = ("d", "e", "f")
        if "a" in vals1 and "b" not in vals2:
            return a + b
        return a - b
1523*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        sys.version_info < (3, 9),
        "SET_UPDATE was added at Python 3.9",
    )
    @make_test
    def test_set_update_bytecode(x):
        # This produces bytecode SET_UPDATE since python 3.9
        # The set display must be traced as a real set so isinstance() picks
        # the ``x + 1`` branch.
        var = {"apple", "banana", "cherry"}
        if isinstance(var, set):
            return x + 1
        else:
            return x - 1
1536*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(
        sys.version_info < (3, 9),
        "SET_UPDATE was added at Python 3.9",
    )
    @make_test
    def test_set_update_list_with_duplicated_items(x):
        # Starred unpacking into a set display ({*a, *b}) must deduplicate:
        # the two lists hold "apple", "banana" and "orange", i.e. 3 distinct
        # items, so the ``x + 1`` branch is expected.
        list1 = ["apple", "banana", "apple"]
        list2 = ["orange", "banana"]
        if len({*list1, *list2}) == 3:
            return x + 1
        else:
            return x - 1
1549*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_set_contains(a, b):
        # set() called on a list (deliberately not a set literal; C405 is
        # silenced file-wide), then membership tests for a present key ("a")
        # and an absent key ("d").
        vals = set(["a", "b", "c"])
        if "a" in vals:
            x = a + b
        else:
            x = a - b
        if "d" in vals:
            y = a + b
        else:
            y = a - b
        return x, y
1562*da0073e9SAndroid Build Coastguard Worker
1563*da0073e9SAndroid Build Coastguard Worker    def test_set_isdisjoint(self):
1564*da0073e9SAndroid Build Coastguard Worker        x = {"apple", "banana", "cherry"}
1565*da0073e9SAndroid Build Coastguard Worker        y = {"google", "microsoft", "apple"}
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker        def fn(a):
1568*da0073e9SAndroid Build Coastguard Worker            if x.isdisjoint(y):
1569*da0073e9SAndroid Build Coastguard Worker                return a + 1
1570*da0073e9SAndroid Build Coastguard Worker            else:
1571*da0073e9SAndroid Build Coastguard Worker                return a - 1
1572*da0073e9SAndroid Build Coastguard Worker
1573*da0073e9SAndroid Build Coastguard Worker        test = make_test(fn)
1574*da0073e9SAndroid Build Coastguard Worker        test(self)
1575*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_set_intersection(a, b):
        # set.intersection of two constant string sets: "apple" is in the
        # result and "banana" is not, so x == a + b and y == a - b.
        set1 = {"apple", "banana", "cherry"}
        set2 = {"google", "microsoft", "apple"}
        intersection_set = set1.intersection(set2)
        if "apple" in intersection_set:
            x = a + b
        else:
            x = a - b
        if "banana" in intersection_set:
            y = a + b
        else:
            y = a - b
        return x, y
1590*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_set_union(a, b):
        # set.union of two constant string sets: both "apple" and "banana" are
        # in the result, so x == y == a + b.
        set1 = {"apple", "banana", "cherry"}
        set2 = {"google", "microsoft", "apple"}
        union_set = set1.union(set2)
        if "apple" in union_set:
            x = a + b
        else:
            x = a - b
        if "banana" in union_set:
            y = a + b
        else:
            y = a - b
        return x, y
1605*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_set_difference(a, b):
        # set.difference removes the shared "apple" but keeps "banana", so
        # x == a - b and y == a + b.
        set1 = {"apple", "banana", "cherry"}
        set2 = {"google", "microsoft", "apple"}
        difference_set = set1.difference(set2)
        if "apple" in difference_set:
            x = a + b
        else:
            x = a - b
        if "banana" in difference_set:
            y = a + b
        else:
            y = a - b
        return x, y
1620*da0073e9SAndroid Build Coastguard Worker
    def test_set_keys_view(self):
        # set() applied to an instance of a user-defined KeysView subclass must
        # compile with fullgraph=True and agree with eager mode.
        from collections.abc import KeysView

        class StringKeys(KeysView):
            # Minimal KeysView backed by a plain list (despite the name, the
            # test stores ints). __init__ does not call KeysView.__init__; the
            # list lives in the ``keys`` attribute and every protocol method
            # forwards to it.
            def __init__(self, keys):
                self.keys = keys

            def __getitem__(self, key):
                return self.keys.__getitem__(key)

            def __iter__(self):
                yield from self.keys

            def __repr__(self):
                return f"{type(self).__name__}({self.keys})"

            def __len__(self):
                return len(self.keys)

            def __contains__(self, item):
                return self.keys.__contains__(item)

        # The backing list has a duplicate (3): len(a) is 4, but the set built
        # by iterating it has 3 elements.
        a = StringKeys([1, 2, 3, 3])

        def fn(x):
            set_a = set(a)
            return len(set_a) * x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.rand(4)
        self.assertEqual(fn(x), opt_fn(x))
1652*da0073e9SAndroid Build Coastguard Worker
    def test_constant_set(self):
        # A closed-over set is only read (via len) inside the compiled
        # function. After the set is mutated, the compiled function must not
        # keep returning results based on the old size.
        s = set([1, 2])

        def fn(x):
            return torch.cos(x) * len(s)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.rand(4)
        self.assertEqual(fn(x), opt_fn(x))

        # This should cause recompilation
        s.add(3)
        self.assertEqual(fn(x), opt_fn(x))
1667*da0073e9SAndroid Build Coastguard Worker
1668*da0073e9SAndroid Build Coastguard Worker    def test_set_add(self):
1669*da0073e9SAndroid Build Coastguard Worker        s = set([1, 2])
1670*da0073e9SAndroid Build Coastguard Worker
1671*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1672*da0073e9SAndroid Build Coastguard Worker            s.add(3)
1673*da0073e9SAndroid Build Coastguard Worker            return torch.cos(x) * len(x)
1674*da0073e9SAndroid Build Coastguard Worker
1675*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1676*da0073e9SAndroid Build Coastguard Worker
1677*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(4)
1678*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), opt_fn(x))
1679*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(s), 3)
1680*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tuple_iadd(a, b):
        # ``+=`` on an immutable tuple rebinds ``output`` to a new 4-tuple.
        output = (a, b)
        output += (a + b, a - b)
        return output
1686*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_unpack_ex1(x):
        # Starred assignment target in the last position (a, b, *rest).
        output = (x, x + 1, x + 2, x + 3)
        a, b, *cd = output
        return a - b / cd[0]
1692*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_unpack_ex2(x):
        # Starred assignment target in the first position (*rest, c, d).
        output = (x, x + 1, x + 2, x + 3)
        *ab, c, d = output
        return c - d / ab[0]
1698*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_unpack_ex3(x):
        # Starred assignment target in the middle (a, *rest, d).
        output = (x, x + 1, x + 2, x + 3)
        a, *bc, d = output
        return a - d / bc[0]
1704*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_const_tuple_add1(x):
        # Concatenating empty tuples on both sides leaves the element
        # positions unchanged.
        output = (x, x + 1, x + 2, x + 3)
        output = () + output + ()
        return output[2] + output[3]
1710*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_const_tuple_add2(x):
        # Concatenating (None,) on both sides shifts every element by one, so
        # output[2] and output[3] are x + 1 and x + 2.
        output = (x, x + 1, x + 2, x + 3)
        output = (None,) + output + (None,)
        return output[2] + output[3]
1716*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_truth(a, b):
        # Truthiness of a non-empty constant list: the ``a + b`` branch runs.
        tmp = [1, 2, 3]
        if tmp:
            return a + b
        else:
            return a - b
1724*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_reversed(a, b):
        # reversed() on a list, then iter()/next() to fetch the first item of
        # the reversed sequence, i.e. the last list element (a + 3).
        tmp = [a + 1, a + 2, a + 3]
        return a + b + next(iter(reversed(tmp)))
1729*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_sorted1(x):
        # sorted() on a constant int list, ascending and with reverse=True;
        # the resulting lists are returned alongside the tensor output.
        tmp = [1, 10, 3, 0]
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)
1734*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_sorted2(x):
        # sorted() on a list of tuples: default lexicographic order, then with
        # a ``key`` lambda selecting index 2, with and without reverse=True.
        y = [
            ("john", "A", 8),
            ("jane", "B", 5),
            ("dave", "B", 10),
        ]
        return (
            x + 1,
            sorted(y),
            sorted(y, key=lambda student: student[2]),
            sorted(y, key=lambda student: student[2], reverse=True),
        )
1748*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_tuple_sorted(x):
        # sorted() on a constant tuple returns a new list (ascending and
        # reversed).
        tmp = (1, 10, 3, 0)
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)
1753*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_dict_sorted(x):
        # sorted() on a dict iterates and sorts its keys (the values are
        # ignored).
        tmp = {1: "D", 10: "B", 3: "E", 0: "F"}
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)
1758*da0073e9SAndroid Build Coastguard Worker
1759*da0073e9SAndroid Build Coastguard Worker    def test_dict_hasattr(self):
1760*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1761*da0073e9SAndroid Build Coastguard Worker            if hasattr(x, "to"):
1762*da0073e9SAndroid Build Coastguard Worker                return x.to("cpu")
1763*da0073e9SAndroid Build Coastguard Worker            if hasattr(x, "items"):
1764*da0073e9SAndroid Build Coastguard Worker                return torch.cos(x["a"])
1765*da0073e9SAndroid Build Coastguard Worker            return x
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker        x = dict(a=torch.randn(3))
1770*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), opt_fn(x))
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
1773*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), opt_fn(x))
1774*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_list_clear(a, b):
        # list.clear() empties the list in place; the returned list holds only
        # the element appended afterwards.
        tmp = [a + 1, a + 2]
        tmp.clear()
        tmp.append(a + b)
        return tmp
1781*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_not_list(a):
        # ``not`` on a non-empty list literal evaluates to False; that bool is
        # the return value.
        return not [a + 1]
1785*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_islice_chain(a, b):
        # itertools.islice over itertools.chain, plus next() on an islice whose
        # stop is None. Note that ``a`` and ``b`` are rebound to the sliced
        # values (original a + 2 and a + 3); the incoming ``b`` is never used.
        tmp1 = [a + 1, a + 2]
        tmp2 = [a + 3, a + 4]
        a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
        c = next(itertools.islice(tmp1, 1, None))
        return a - b / c
1793*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_namedtuple(a, b):
        # A namedtuple class created inside the traced function; its fields
        # are read both by attribute name and by index.
        mytuple = collections.namedtuple("mytuple", ["x", "y", "xy"])
        tmp = mytuple(a, b, a + b)
        return mytuple(tmp.x, tmp[1], tmp.xy + b)
1799*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_namedtuple_defaults(a, b):
        # namedtuple with per-field defaults. Here ``y`` is omitted and falls
        # back to its default 1, while ``xy`` is given by keyword, so tmp[1] == 1.
        mytuple = collections.namedtuple(
            "mytuple", ["x", "y", "xy"], defaults=(None, 1, None)
        )
        tmp = mytuple(a, xy=b)
        return mytuple(tmp.x, tmp[1], tmp.xy + b)
1807*da0073e9SAndroid Build Coastguard Worker
    class MyNamedTuple(NamedTuple):
        # typing.NamedTuple with two tensor fields plus one instance method,
        # one staticmethod and one classmethod; the namedtuple tests call all
        # three kinds of user-defined method on an instance.
        first: torch.Tensor
        second: torch.Tensor

        def add(self) -> torch.Tensor:
            # Sum of the two fields.
            return self.first + self.second

        @staticmethod
        def static_method() -> int:
            return 1

        @classmethod
        def class_method(cls) -> str:
            # Returns the class name ("MyNamedTuple" when called on this class).
            return cls.__name__
1822*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_namedtuple_user_methods(a, b):
        # Calls an instance method, a staticmethod and a classmethod through an
        # instance of a user-defined typing.NamedTuple.
        mytuple = FunctionTests.MyNamedTuple(a, b)
        return mytuple.add(), mytuple.static_method(), mytuple.class_method()
1827*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_namedtuple_hasattr(a, b):
        # Duck-typed namedtuple detection (a tuple that has both ``_asdict`` and
        # ``_fields``) applied to a typing.NamedTuple instance; the check
        # succeeds, so the ``a + b`` branch is expected.
        mytuple = FunctionTests.MyNamedTuple(a, b)

        def isinstance_namedtuple(obj) -> bool:
            return (
                isinstance(obj, tuple)
                and hasattr(obj, "_asdict")
                and hasattr(obj, "_fields")
            )

        if isinstance_namedtuple(mytuple):
            return a + b
        else:
            return a - b
1843*da0073e9SAndroid Build Coastguard Worker
1844*da0073e9SAndroid Build Coastguard Worker    @make_test
1845*da0073e9SAndroid Build Coastguard Worker    def test_torch_size_hasattr(x):
1846*da0073e9SAndroid Build Coastguard Worker        if hasattr(x.shape, "_fields"):
1847*da0073e9SAndroid Build Coastguard Worker            return x + 1
1848*da0073e9SAndroid Build Coastguard Worker        else:
1849*da0073e9SAndroid Build Coastguard Worker            return x - 1
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker    @make_test
1852*da0073e9SAndroid Build Coastguard Worker    def test_is_quantized(a, b):
1853*da0073e9SAndroid Build Coastguard Worker        if not a.is_quantized:
1854*da0073e9SAndroid Build Coastguard Worker            return a + b
1855*da0073e9SAndroid Build Coastguard Worker
1856*da0073e9SAndroid Build Coastguard Worker    @make_test
1857*da0073e9SAndroid Build Coastguard Worker    def test_fstrings1(a, b):
1858*da0073e9SAndroid Build Coastguard Worker        x = 1.229
1859*da0073e9SAndroid Build Coastguard Worker        tmp = f"{x:.2f} bar"
1860*da0073e9SAndroid Build Coastguard Worker        if tmp.startswith("1.23"):
1861*da0073e9SAndroid Build Coastguard Worker            return a + b
1862*da0073e9SAndroid Build Coastguard Worker
1863*da0073e9SAndroid Build Coastguard Worker    @make_test
1864*da0073e9SAndroid Build Coastguard Worker    def test_fstrings2(x):
1865*da0073e9SAndroid Build Coastguard Worker        tmp = f"{x.shape[0]} bar"
1866*da0073e9SAndroid Build Coastguard Worker        if tmp.startswith("10"):
1867*da0073e9SAndroid Build Coastguard Worker            return x + 1
1868*da0073e9SAndroid Build Coastguard Worker
1869*da0073e9SAndroid Build Coastguard Worker    @make_test
1870*da0073e9SAndroid Build Coastguard Worker    def test_fstrings3(x):
1871*da0073e9SAndroid Build Coastguard Worker        tmp = f"{x.__class__.__name__} foo"
1872*da0073e9SAndroid Build Coastguard Worker        if tmp.startswith("Tensor"):
1873*da0073e9SAndroid Build Coastguard Worker            return x + 1
1874*da0073e9SAndroid Build Coastguard Worker
1875*da0073e9SAndroid Build Coastguard Worker    @make_test
1876*da0073e9SAndroid Build Coastguard Worker    def test_fstrings4(x):
1877*da0073e9SAndroid Build Coastguard Worker        tmp = f"{x.shape[0]} bar"
1878*da0073e9SAndroid Build Coastguard Worker        if "10" in tmp:
1879*da0073e9SAndroid Build Coastguard Worker            return x + 1
1880*da0073e9SAndroid Build Coastguard Worker
1881*da0073e9SAndroid Build Coastguard Worker    @make_test
1882*da0073e9SAndroid Build Coastguard Worker    def test_fstrings5(x):
1883*da0073e9SAndroid Build Coastguard Worker        tmp = f"{x.shape[0]} bar"
1884*da0073e9SAndroid Build Coastguard Worker        if "10" in (tmp + "haha"):
1885*da0073e9SAndroid Build Coastguard Worker            return x + 1
1886*da0073e9SAndroid Build Coastguard Worker
1887*da0073e9SAndroid Build Coastguard Worker    @make_test
1888*da0073e9SAndroid Build Coastguard Worker    def test_fstrings6(x):
1889*da0073e9SAndroid Build Coastguard Worker        tmp = f"{x.shape[0] + x.shape[1]}"
1890*da0073e9SAndroid Build Coastguard Worker        if "20" in tmp:
1891*da0073e9SAndroid Build Coastguard Worker            return x + 1
1892*da0073e9SAndroid Build Coastguard Worker
1893*da0073e9SAndroid Build Coastguard Worker    @make_test
1894*da0073e9SAndroid Build Coastguard Worker    def test_tensor_new_with_size(x):
1895*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(5, 8)
1896*da0073e9SAndroid Build Coastguard Worker        z = x.new(y.size())
1897*da0073e9SAndroid Build Coastguard Worker        assert z.size() == y.size()
1898*da0073e9SAndroid Build Coastguard Worker
1899*da0073e9SAndroid Build Coastguard Worker    @make_test
1900*da0073e9SAndroid Build Coastguard Worker    def test_tensor_new_with_shape(x):
1901*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(5, 8)
1902*da0073e9SAndroid Build Coastguard Worker        z = x.new(y.shape)
1903*da0073e9SAndroid Build Coastguard Worker        assert z.size() == y.size()
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker    @make_test
1906*da0073e9SAndroid Build Coastguard Worker    def test_jit_annotate(x):
1907*da0073e9SAndroid Build Coastguard Worker        y = torch.jit.annotate(Any, x + 1)
1908*da0073e9SAndroid Build Coastguard Worker        return y + 2
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker    @make_test
1911*da0073e9SAndroid Build Coastguard Worker    def test_is_contiguous_memory_format(tensor):
1912*da0073e9SAndroid Build Coastguard Worker        if torch.jit.is_scripting():
1913*da0073e9SAndroid Build Coastguard Worker            return None
1914*da0073e9SAndroid Build Coastguard Worker        elif tensor.is_contiguous(memory_format=torch.contiguous_format):
1915*da0073e9SAndroid Build Coastguard Worker            return tensor + 1
1916*da0073e9SAndroid Build Coastguard Worker
1917*da0073e9SAndroid Build Coastguard Worker    def test_is_contiguous_frame_counts(self):
1918*da0073e9SAndroid Build Coastguard Worker        data = [
1919*da0073e9SAndroid Build Coastguard Worker            torch.rand(10),
1920*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, 32, 32),
1921*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, 32, 32).contiguous(memory_format=torch.channels_last),
1922*da0073e9SAndroid Build Coastguard Worker            torch.rand(10)[::2],
1923*da0073e9SAndroid Build Coastguard Worker            torch.rand(12),
1924*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, 24, 24).contiguous(memory_format=torch.channels_last),
1925*da0073e9SAndroid Build Coastguard Worker            torch.rand(50)[::2],
1926*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 3, 32, 32)[:, :, 2:-2, 3:-3],
1927*da0073e9SAndroid Build Coastguard Worker        ]
1928*da0073e9SAndroid Build Coastguard Worker        # dynamo should recompile for all inputs in static shapes mode
1929*da0073e9SAndroid Build Coastguard Worker        expected_frame_counts_static = [1, 2, 3, 4, 5, 6, 7, 8]
1930*da0073e9SAndroid Build Coastguard Worker        # dynamo should recompile for items 0, 1, 2, 6 in dynamic shapes mode
1931*da0073e9SAndroid Build Coastguard Worker        expected_frame_counts_dynamic = [1, 2, 3, 4, 4, 4, 4, 5]
1932*da0073e9SAndroid Build Coastguard Worker        expected_frame_counts = ifdynstaticdefault(
1933*da0073e9SAndroid Build Coastguard Worker            expected_frame_counts_static, expected_frame_counts_dynamic
1934*da0073e9SAndroid Build Coastguard Worker        )
1935*da0073e9SAndroid Build Coastguard Worker        dynamic = ifdynstaticdefault(False, True)
1936*da0073e9SAndroid Build Coastguard Worker
1937*da0073e9SAndroid Build Coastguard Worker        def func(x):
1938*da0073e9SAndroid Build Coastguard Worker            if x.is_contiguous():
1939*da0073e9SAndroid Build Coastguard Worker                return x + 1
1940*da0073e9SAndroid Build Coastguard Worker            elif x.is_contiguous(memory_format=torch.channels_last):
1941*da0073e9SAndroid Build Coastguard Worker                return x + 2
1942*da0073e9SAndroid Build Coastguard Worker            else:
1943*da0073e9SAndroid Build Coastguard Worker                return x + 3
1944*da0073e9SAndroid Build Coastguard Worker
1945*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
1946*da0073e9SAndroid Build Coastguard Worker        cfunc = torch._dynamo.optimize_assert(cnt, dynamic=dynamic)(func)
1947*da0073e9SAndroid Build Coastguard Worker
1948*da0073e9SAndroid Build Coastguard Worker        assert cnt.frame_count == 0
1949*da0073e9SAndroid Build Coastguard Worker        for i, x in enumerate(data):
1950*da0073e9SAndroid Build Coastguard Worker            expected = func(x)
1951*da0073e9SAndroid Build Coastguard Worker            output = cfunc(x)
1952*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(output, expected))
1953*da0073e9SAndroid Build Coastguard Worker            assert cnt.frame_count == expected_frame_counts[i]
1954*da0073e9SAndroid Build Coastguard Worker
1955*da0073e9SAndroid Build Coastguard Worker    @make_test
1956*da0073e9SAndroid Build Coastguard Worker    def test_list_slice_assignment(x):
1957*da0073e9SAndroid Build Coastguard Worker        m = [1, 2, 3, 4]
1958*da0073e9SAndroid Build Coastguard Worker        m[1:] = [6] * (len(m) - 1)
1959*da0073e9SAndroid Build Coastguard Worker        return x + 1
1960*da0073e9SAndroid Build Coastguard Worker
1961*da0073e9SAndroid Build Coastguard Worker    @make_test
1962*da0073e9SAndroid Build Coastguard Worker    def test_distributed_is_available(x):
1963*da0073e9SAndroid Build Coastguard Worker        if torch.distributed.is_available():
1964*da0073e9SAndroid Build Coastguard Worker            return x + 1
1965*da0073e9SAndroid Build Coastguard Worker        else:
1966*da0073e9SAndroid Build Coastguard Worker            return x - 1
1967*da0073e9SAndroid Build Coastguard Worker
1968*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
1969*da0073e9SAndroid Build Coastguard Worker        not torch.distributed.is_available(), "requires distributed package"
1970*da0073e9SAndroid Build Coastguard Worker    )
1971*da0073e9SAndroid Build Coastguard Worker    @make_test
1972*da0073e9SAndroid Build Coastguard Worker    def test_distributed_is_initialized(x):
1973*da0073e9SAndroid Build Coastguard Worker        if torch.distributed.is_initialized():
1974*da0073e9SAndroid Build Coastguard Worker            return x + 1
1975*da0073e9SAndroid Build Coastguard Worker        else:
1976*da0073e9SAndroid Build Coastguard Worker            return x - 1
1977*da0073e9SAndroid Build Coastguard Worker
1978*da0073e9SAndroid Build Coastguard Worker    @disable_translation_validation_if_dynamic_shapes
1979*da0073e9SAndroid Build Coastguard Worker    @make_test
1980*da0073e9SAndroid Build Coastguard Worker    def test_torch_distributions_functions(x):
1981*da0073e9SAndroid Build Coastguard Worker        normal = torch.distributions.Normal(x, torch.tensor(1))
1982*da0073e9SAndroid Build Coastguard Worker        independent = torch.distributions.Independent(normal, 1)
1983*da0073e9SAndroid Build Coastguard Worker        return independent.log_prob(x)
1984*da0073e9SAndroid Build Coastguard Worker
1985*da0073e9SAndroid Build Coastguard Worker    @make_test
1986*da0073e9SAndroid Build Coastguard Worker    def test_context_wrapping_nested_functions_no_closure(x):
1987*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
1988*da0073e9SAndroid Build Coastguard Worker        def augment(x: torch.Tensor) -> torch.Tensor:
1989*da0073e9SAndroid Build Coastguard Worker            return (x + 1) * 2
1990*da0073e9SAndroid Build Coastguard Worker
1991*da0073e9SAndroid Build Coastguard Worker        return augment(x)
1992*da0073e9SAndroid Build Coastguard Worker
1993*da0073e9SAndroid Build Coastguard Worker    # # This is to test the new syntax for pattern matching
1994*da0073e9SAndroid Build Coastguard Worker    # # ("match ... case ...") added on python 3.10.
1995*da0073e9SAndroid Build Coastguard Worker    # # Uncomment these test cases if you run on 3.10+
1996*da0073e9SAndroid Build Coastguard Worker    # @make_test
1997*da0073e9SAndroid Build Coastguard Worker    # def test_match_sequence(a):
1998*da0073e9SAndroid Build Coastguard Worker    #     point = (5, 8)
1999*da0073e9SAndroid Build Coastguard Worker    #     match point:
2000*da0073e9SAndroid Build Coastguard Worker    #         case (0, 0):
2001*da0073e9SAndroid Build Coastguard Worker    #             return a
2002*da0073e9SAndroid Build Coastguard Worker    #         case (0, y):
2003*da0073e9SAndroid Build Coastguard Worker    #             return a - y
2004*da0073e9SAndroid Build Coastguard Worker    #         case (x, 0):
2005*da0073e9SAndroid Build Coastguard Worker    #             return a + x
2006*da0073e9SAndroid Build Coastguard Worker    #         case (x, y):
2007*da0073e9SAndroid Build Coastguard Worker    #             return a + x - y
2008*da0073e9SAndroid Build Coastguard Worker
2009*da0073e9SAndroid Build Coastguard Worker    # @make_test
2010*da0073e9SAndroid Build Coastguard Worker    # def test_match_mapping_and_match_keys(x):
2011*da0073e9SAndroid Build Coastguard Worker    #     param = {"a": 0.5}
2012*da0073e9SAndroid Build Coastguard Worker    #     match param:
2013*da0073e9SAndroid Build Coastguard Worker    #         case {"a": param}:
2014*da0073e9SAndroid Build Coastguard Worker    #             return x * param
2015*da0073e9SAndroid Build Coastguard Worker    #         case {"b": param}:
2016*da0073e9SAndroid Build Coastguard Worker    #             return x / param
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Worker    def test_math_radians(self):
2019*da0073e9SAndroid Build Coastguard Worker        def func(x, a):
2020*da0073e9SAndroid Build Coastguard Worker            return x + math.radians(a)
2021*da0073e9SAndroid Build Coastguard Worker
2022*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
2023*da0073e9SAndroid Build Coastguard Worker        cfunc = torch._dynamo.optimize_assert(cnt)(func)
2024*da0073e9SAndroid Build Coastguard Worker
2025*da0073e9SAndroid Build Coastguard Worker        assert cnt.frame_count == 0
2026*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10)
2027*da0073e9SAndroid Build Coastguard Worker        expected = func(x, 12)
2028*da0073e9SAndroid Build Coastguard Worker        output = cfunc(x, 12)
2029*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(output, expected))
2030*da0073e9SAndroid Build Coastguard Worker        assert cnt.frame_count == 1
2031*da0073e9SAndroid Build Coastguard Worker
2032*da0073e9SAndroid Build Coastguard Worker    @make_test
2033*da0073e9SAndroid Build Coastguard Worker    def test_numpy_meshgrid(x, y):
2034*da0073e9SAndroid Build Coastguard Worker        r1, r2 = np.meshgrid(x.numpy(), y.numpy())
2035*da0073e9SAndroid Build Coastguard Worker        return torch.from_numpy(r1), torch.from_numpy(r2)
2036*da0073e9SAndroid Build Coastguard Worker
2037*da0073e9SAndroid Build Coastguard Worker    @make_test
2038*da0073e9SAndroid Build Coastguard Worker    def test_torch_from_numpy(x):
2039*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2040*da0073e9SAndroid Build Coastguard Worker        b = torch.from_numpy(a)
2041*da0073e9SAndroid Build Coastguard Worker        if b.size(0) == 1:
2042*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(True)
2043*da0073e9SAndroid Build Coastguard Worker        else:
2044*da0073e9SAndroid Build Coastguard Worker            return torch.tensor(False)
2045*da0073e9SAndroid Build Coastguard Worker
2046*da0073e9SAndroid Build Coastguard Worker    @make_test
2047*da0073e9SAndroid Build Coastguard Worker    def test_numpy_size(x):
2048*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2049*da0073e9SAndroid Build Coastguard Worker        return a.size
2050*da0073e9SAndroid Build Coastguard Worker
2051*da0073e9SAndroid Build Coastguard Worker    @make_test
2052*da0073e9SAndroid Build Coastguard Worker    def test_numpy_attributes(x):
2053*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2054*da0073e9SAndroid Build Coastguard Worker        return (
2055*da0073e9SAndroid Build Coastguard Worker            a.itemsize,
2056*da0073e9SAndroid Build Coastguard Worker            a.strides,
2057*da0073e9SAndroid Build Coastguard Worker            a.shape,
2058*da0073e9SAndroid Build Coastguard Worker            a.ndim,
2059*da0073e9SAndroid Build Coastguard Worker            a.size,
2060*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(a.T),
2061*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(a.real),
2062*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(a.imag),
2063*da0073e9SAndroid Build Coastguard Worker        )
2064*da0073e9SAndroid Build Coastguard Worker
2065*da0073e9SAndroid Build Coastguard Worker    @make_test
2066*da0073e9SAndroid Build Coastguard Worker    def test_mean_sum_np(x: torch.Tensor):
2067*da0073e9SAndroid Build Coastguard Worker        x_mean = np.mean(x.numpy(), 1)
2068*da0073e9SAndroid Build Coastguard Worker        x_sum = np.sum(x_mean)
2069*da0073e9SAndroid Build Coastguard Worker        x_sum_array = np.asarray(x_sum)
2070*da0073e9SAndroid Build Coastguard Worker        return torch.from_numpy(x_sum_array)
2071*da0073e9SAndroid Build Coastguard Worker
2072*da0073e9SAndroid Build Coastguard Worker    @make_test
2073*da0073e9SAndroid Build Coastguard Worker    def test_return_numpy_ndarray(x):
2074*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2075*da0073e9SAndroid Build Coastguard Worker        return a.T
2076*da0073e9SAndroid Build Coastguard Worker
2077*da0073e9SAndroid Build Coastguard Worker    @make_test
2078*da0073e9SAndroid Build Coastguard Worker    def test_return_multiple_numpy_ndarray(x):
2079*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2080*da0073e9SAndroid Build Coastguard Worker        return a.T, a.imag, a.real
2081*da0073e9SAndroid Build Coastguard Worker
2082*da0073e9SAndroid Build Coastguard Worker    @make_test
2083*da0073e9SAndroid Build Coastguard Worker    def test_ndarray_method(x):
2084*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2085*da0073e9SAndroid Build Coastguard Worker        return a.copy()
2086*da0073e9SAndroid Build Coastguard Worker
2087*da0073e9SAndroid Build Coastguard Worker    @make_test
2088*da0073e9SAndroid Build Coastguard Worker    def test_ndarray_transpose(x):
2089*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2090*da0073e9SAndroid Build Coastguard Worker        return a.transpose(0, 1)
2091*da0073e9SAndroid Build Coastguard Worker
2092*da0073e9SAndroid Build Coastguard Worker    @make_test
2093*da0073e9SAndroid Build Coastguard Worker    def test_ndarray_reshape(x):
2094*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2095*da0073e9SAndroid Build Coastguard Worker        return a.reshape([1, a.size])
2096*da0073e9SAndroid Build Coastguard Worker
2097*da0073e9SAndroid Build Coastguard Worker    @make_test
2098*da0073e9SAndroid Build Coastguard Worker    def test_ndarray_methods_returning_scalar(x):
2099*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2100*da0073e9SAndroid Build Coastguard Worker        return a.max(axis=0), a.all(axis=0)
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker    @make_test
2103*da0073e9SAndroid Build Coastguard Worker    def test_ndarray_builtin_functions(x):
2104*da0073e9SAndroid Build Coastguard Worker        a = x.numpy()
2105*da0073e9SAndroid Build Coastguard Worker        return a + a, a - a
2106*da0073e9SAndroid Build Coastguard Worker
2107*da0073e9SAndroid Build Coastguard Worker    @make_test
2108*da0073e9SAndroid Build Coastguard Worker    def test_numpy_dtype_argument_to_function(x):
2109*da0073e9SAndroid Build Coastguard Worker        return np.ones_like(x, dtype=np.float64)
2110*da0073e9SAndroid Build Coastguard Worker
2111*da0073e9SAndroid Build Coastguard Worker    @make_test
2112*da0073e9SAndroid Build Coastguard Worker    def test_numpy_dtype_call_in_function(x):
2113*da0073e9SAndroid Build Coastguard Worker        dt = np.dtype("float")
2114*da0073e9SAndroid Build Coastguard Worker        return np.full_like(x, 2.4, dtype=dt)
2115*da0073e9SAndroid Build Coastguard Worker
2116*da0073e9SAndroid Build Coastguard Worker    @make_test
2117*da0073e9SAndroid Build Coastguard Worker    def test_numpy_linalg(x):
2118*da0073e9SAndroid Build Coastguard Worker        return np.linalg.norm(x.numpy(), axis=0)
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker    @make_test
2121*da0073e9SAndroid Build Coastguard Worker    def test_numpy_fft(x):
2122*da0073e9SAndroid Build Coastguard Worker        return np.fft.fftshift(x.numpy())
2123*da0073e9SAndroid Build Coastguard Worker
2124*da0073e9SAndroid Build Coastguard Worker    @make_test
2125*da0073e9SAndroid Build Coastguard Worker    def test_numpy_random():
2126*da0073e9SAndroid Build Coastguard Worker        x = np.random.randn(2, 2)
2127*da0073e9SAndroid Build Coastguard Worker        return x - x
2128*da0073e9SAndroid Build Coastguard Worker
2129*da0073e9SAndroid Build Coastguard Worker    @make_test
2130*da0073e9SAndroid Build Coastguard Worker    def test_partials_torch_op_kwarg(x):
2131*da0073e9SAndroid Build Coastguard Worker        par_mul = functools.partial(torch.mul, other=torch.ones(10, 10))
2132*da0073e9SAndroid Build Coastguard Worker        return par_mul(x)
2133*da0073e9SAndroid Build Coastguard Worker
2134*da0073e9SAndroid Build Coastguard Worker    @make_test
2135*da0073e9SAndroid Build Coastguard Worker    def test_partials_torch_op_arg(x):
2136*da0073e9SAndroid Build Coastguard Worker        par_mul = functools.partial(torch.mul, torch.ones(10, 10))
2137*da0073e9SAndroid Build Coastguard Worker        return par_mul(x)
2138*da0073e9SAndroid Build Coastguard Worker
2139*da0073e9SAndroid Build Coastguard Worker    @make_test
2140*da0073e9SAndroid Build Coastguard Worker    def test_partials_udf_arg(x):
2141*da0073e9SAndroid Build Coastguard Worker        par_mul = functools.partial(udf_mul, torch.ones(10, 10))
2142*da0073e9SAndroid Build Coastguard Worker        return par_mul(x)
2143*da0073e9SAndroid Build Coastguard Worker
2144*da0073e9SAndroid Build Coastguard Worker    @make_test
2145*da0073e9SAndroid Build Coastguard Worker    def test_list_add_then_mutate(x):
2146*da0073e9SAndroid Build Coastguard Worker        my_list = [1, x]
2147*da0073e9SAndroid Build Coastguard Worker        y = x / 4.0
2148*da0073e9SAndroid Build Coastguard Worker        my_list = my_list + [x / 2.0, 4]
2149*da0073e9SAndroid Build Coastguard Worker        my_list.append(y)
2150*da0073e9SAndroid Build Coastguard Worker        return sum(my_list)
2151*da0073e9SAndroid Build Coastguard Worker
2152*da0073e9SAndroid Build Coastguard Worker    @make_test
2153*da0073e9SAndroid Build Coastguard Worker    def test_list_expand_lhs(x):
2154*da0073e9SAndroid Build Coastguard Worker        return sum(4 * [x])
2155*da0073e9SAndroid Build Coastguard Worker
2156*da0073e9SAndroid Build Coastguard Worker    @make_test
2157*da0073e9SAndroid Build Coastguard Worker    def test_in_not_in(x):
2158*da0073e9SAndroid Build Coastguard Worker        mylist = [1, 2, 3, 4, 5, x]
2159*da0073e9SAndroid Build Coastguard Worker        myotherlist = [1, 2, 3, 4, 5]
2160*da0073e9SAndroid Build Coastguard Worker        assert 3 in mylist
2161*da0073e9SAndroid Build Coastguard Worker        assert 6 not in myotherlist
2162*da0073e9SAndroid Build Coastguard Worker        return sum(mylist)
2163*da0073e9SAndroid Build Coastguard Worker
2164*da0073e9SAndroid Build Coastguard Worker    @make_test
2165*da0073e9SAndroid Build Coastguard Worker    def test_are_functorch_transforms_active(x):
2166*da0073e9SAndroid Build Coastguard Worker        if torch._C._are_functorch_transforms_active():
2167*da0073e9SAndroid Build Coastguard Worker            return x + 1
2168*da0073e9SAndroid Build Coastguard Worker        else:
2169*da0073e9SAndroid Build Coastguard Worker            return x - 1
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Worker    @make_test
2172*da0073e9SAndroid Build Coastguard Worker    def test_partials_udf_kwarg(x):
2173*da0073e9SAndroid Build Coastguard Worker        par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
2174*da0073e9SAndroid Build Coastguard Worker        return par_mul(x)
2175*da0073e9SAndroid Build Coastguard Worker
2176*da0073e9SAndroid Build Coastguard Worker    @make_test
2177*da0073e9SAndroid Build Coastguard Worker    def test_partials_udf_kwarg_module(x, y):
2178*da0073e9SAndroid Build Coastguard Worker        par_mod = functools.partial(udf_module, mod=SmallNN())
2179*da0073e9SAndroid Build Coastguard Worker        return par_mod(x=x, y=y)
2180*da0073e9SAndroid Build Coastguard Worker
2181*da0073e9SAndroid Build Coastguard Worker    @make_test
2182*da0073e9SAndroid Build Coastguard Worker    def test_partials_udf_kwarg_method(x, y):
2183*da0073e9SAndroid Build Coastguard Worker        par_mod = functools.partial(udf_module, mod=SmallNN().forward)
2184*da0073e9SAndroid Build Coastguard Worker        return par_mod(x=x, y=y)
2185*da0073e9SAndroid Build Coastguard Worker
2186*da0073e9SAndroid Build Coastguard Worker    @make_test
2187*da0073e9SAndroid Build Coastguard Worker    def test_partials_lambda(x):
2188*da0073e9SAndroid Build Coastguard Worker        multiply = lambda x, y: x * y
2189*da0073e9SAndroid Build Coastguard Worker        triple = functools.partial(multiply, y=3)
2190*da0073e9SAndroid Build Coastguard Worker        return triple(x)
2191*da0073e9SAndroid Build Coastguard Worker
2192*da0073e9SAndroid Build Coastguard Worker    @unittest.skipUnless(torch.distributed.is_available(), "requires torch.distributed")
2193*da0073e9SAndroid Build Coastguard Worker    @make_test
2194*da0073e9SAndroid Build Coastguard Worker    def test_flat_param_same_storage_size(x, y):
2195*da0073e9SAndroid Build Coastguard Worker        import torch.distributed.fsdp._flat_param as flat_param
2196*da0073e9SAndroid Build Coastguard Worker
2197*da0073e9SAndroid Build Coastguard Worker        if flat_param._same_storage_size(x, 100):
2198*da0073e9SAndroid Build Coastguard Worker            x = x + 1
2199*da0073e9SAndroid Build Coastguard Worker        else:
2200*da0073e9SAndroid Build Coastguard Worker            x = x - 1
2201*da0073e9SAndroid Build Coastguard Worker        if flat_param._same_storage_size(y, 123):
2202*da0073e9SAndroid Build Coastguard Worker            y = y + 1
2203*da0073e9SAndroid Build Coastguard Worker        else:
2204*da0073e9SAndroid Build Coastguard Worker            y = y - 1
2205*da0073e9SAndroid Build Coastguard Worker        return x, y
2206*da0073e9SAndroid Build Coastguard Worker
2207*da0073e9SAndroid Build Coastguard Worker    @parametrize(
2208*da0073e9SAndroid Build Coastguard Worker        "attr",
2209*da0073e9SAndroid Build Coastguard Worker        (
2210*da0073e9SAndroid Build Coastguard Worker            # True
2211*da0073e9SAndroid Build Coastguard Worker            "__subclasshook__",
2212*da0073e9SAndroid Build Coastguard Worker            "__lt__",
2213*da0073e9SAndroid Build Coastguard Worker            "__hash__",
2214*da0073e9SAndroid Build Coastguard Worker            "__ge__",
2215*da0073e9SAndroid Build Coastguard Worker            "__le__",
2216*da0073e9SAndroid Build Coastguard Worker            "__gt__",
2217*da0073e9SAndroid Build Coastguard Worker            "__dict__",
2218*da0073e9SAndroid Build Coastguard Worker            "__getattribute__",
2219*da0073e9SAndroid Build Coastguard Worker            "__setattr__",
2220*da0073e9SAndroid Build Coastguard Worker            "__doc__",
2221*da0073e9SAndroid Build Coastguard Worker            "__repr__",
2222*da0073e9SAndroid Build Coastguard Worker            "__dir__",
2223*da0073e9SAndroid Build Coastguard Worker            "__init__",
2224*da0073e9SAndroid Build Coastguard Worker            "__new__",
2225*da0073e9SAndroid Build Coastguard Worker            "__class__",
2226*da0073e9SAndroid Build Coastguard Worker            "__eq__",
2227*da0073e9SAndroid Build Coastguard Worker            "__delattr__",
2228*da0073e9SAndroid Build Coastguard Worker            "__reduce__",
2229*da0073e9SAndroid Build Coastguard Worker            "__module__",
2230*da0073e9SAndroid Build Coastguard Worker            "__format__",
2231*da0073e9SAndroid Build Coastguard Worker            "__str__",
2232*da0073e9SAndroid Build Coastguard Worker            "__sizeof__",
2233*da0073e9SAndroid Build Coastguard Worker            "__ne__",
2234*da0073e9SAndroid Build Coastguard Worker            "__call__",
2235*da0073e9SAndroid Build Coastguard Worker            "__reduce_ex__",
2236*da0073e9SAndroid Build Coastguard Worker            "__init_subclass__",
2237*da0073e9SAndroid Build Coastguard Worker            "args",
2238*da0073e9SAndroid Build Coastguard Worker            "keywords",
2239*da0073e9SAndroid Build Coastguard Worker            "func",
2240*da0073e9SAndroid Build Coastguard Worker            # False
2241*da0073e9SAndroid Build Coastguard Worker            "__code__",
2242*da0073e9SAndroid Build Coastguard Worker            "__kwdefaults__",
2243*da0073e9SAndroid Build Coastguard Worker            "__defaults__",
2244*da0073e9SAndroid Build Coastguard Worker            "__name__",
2245*da0073e9SAndroid Build Coastguard Worker            "__annotations__",
2246*da0073e9SAndroid Build Coastguard Worker            "__get__",
2247*da0073e9SAndroid Build Coastguard Worker            "__builtins__",
2248*da0073e9SAndroid Build Coastguard Worker            "__qualname__",
2249*da0073e9SAndroid Build Coastguard Worker            "__globals__",
2250*da0073e9SAndroid Build Coastguard Worker            "__closure__",
2251*da0073e9SAndroid Build Coastguard Worker        ),
2252*da0073e9SAndroid Build Coastguard Worker    )
2253*da0073e9SAndroid Build Coastguard Worker    def test_partials_hasattr(self, attr):
2254*da0073e9SAndroid Build Coastguard Worker        def fn(t):
2255*da0073e9SAndroid Build Coastguard Worker            f = lambda x, y: torch.sin(x) + torch.cos(y)
2256*da0073e9SAndroid Build Coastguard Worker            p = functools.partial(f, y=t)
2257*da0073e9SAndroid Build Coastguard Worker            if hasattr(p, attr):
2258*da0073e9SAndroid Build Coastguard Worker                return p(t)
2259*da0073e9SAndroid Build Coastguard Worker            else:
2260*da0073e9SAndroid Build Coastguard Worker                return torch.zeros_like(t)
2261*da0073e9SAndroid Build Coastguard Worker
2262*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 4)
2263*da0073e9SAndroid Build Coastguard Worker        counter = torch._dynamo.testing.CompileCounter()
2264*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True, backend=counter)(fn)
2265*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(t), fn(t))
2266*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(counter.frame_count, 0)
2267*da0073e9SAndroid Build Coastguard Worker
2268*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
2269*da0073e9SAndroid Build Coastguard Worker    def test_partials_hasattr_set_attr(self):
2270*da0073e9SAndroid Build Coastguard Worker        def fn(t):
2271*da0073e9SAndroid Build Coastguard Worker            f = lambda x, y: torch.sin(x) + torch.cos(y)
2272*da0073e9SAndroid Build Coastguard Worker            p = functools.partial(f, y=t)
2273*da0073e9SAndroid Build Coastguard Worker            p.__name__ = "test"
2274*da0073e9SAndroid Build Coastguard Worker            if hasattr(p, "__name__"):
2275*da0073e9SAndroid Build Coastguard Worker                return p(t)
2276*da0073e9SAndroid Build Coastguard Worker            else:
2277*da0073e9SAndroid Build Coastguard Worker                return torch.zeros_like(t)
2278*da0073e9SAndroid Build Coastguard Worker
2279*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 4)
2280*da0073e9SAndroid Build Coastguard Worker        counter = torch._dynamo.testing.CompileCounter()
2281*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True, backend=counter)(fn)
2282*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(t), fn(t))
2283*da0073e9SAndroid Build Coastguard Worker
2284*da0073e9SAndroid Build Coastguard Worker    def test_filter(self):
2285*da0073e9SAndroid Build Coastguard Worker        def fn(inputs):
2286*da0073e9SAndroid Build Coastguard Worker            out = inputs[0]
2287*da0073e9SAndroid Build Coastguard Worker            for inp in filter(lambda x: (x.requires_grad), inputs):
2288*da0073e9SAndroid Build Coastguard Worker                out = out * inp
2289*da0073e9SAndroid Build Coastguard Worker            return out
2290*da0073e9SAndroid Build Coastguard Worker
2291*da0073e9SAndroid Build Coastguard Worker        input1 = torch.arange(2, dtype=torch.bfloat16)
2292*da0073e9SAndroid Build Coastguard Worker        input2 = torch.arange(2, dtype=torch.bfloat16).requires_grad_(True)
2293*da0073e9SAndroid Build Coastguard Worker        inputs = [input1, input2]
2294*da0073e9SAndroid Build Coastguard Worker
2295*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True)(fn)
2296*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(inputs), fn(inputs))
2297*da0073e9SAndroid Build Coastguard Worker
2298*da0073e9SAndroid Build Coastguard Worker    def test_filter_fallback(self):
2299*da0073e9SAndroid Build Coastguard Worker        def fn(inputs):
2300*da0073e9SAndroid Build Coastguard Worker            out = inputs[0]
2301*da0073e9SAndroid Build Coastguard Worker            for inp in filter(lambda x: x[0] == 1, inputs):
2302*da0073e9SAndroid Build Coastguard Worker                out = out * inp
2303*da0073e9SAndroid Build Coastguard Worker            return out
2304*da0073e9SAndroid Build Coastguard Worker
2305*da0073e9SAndroid Build Coastguard Worker        input1 = torch.ones(2, dtype=torch.bfloat16)
2306*da0073e9SAndroid Build Coastguard Worker        input2 = torch.arange(2, dtype=torch.bfloat16)
2307*da0073e9SAndroid Build Coastguard Worker        inputs = [input1, input2]
2308*da0073e9SAndroid Build Coastguard Worker
2309*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile()(fn)
2310*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(inputs), fn(inputs))
2311*da0073e9SAndroid Build Coastguard Worker
2312*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
2313*da0073e9SAndroid Build Coastguard Worker
2314*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(torch._dynamo.exc.Unsupported):
2315*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch.compile(fullgraph=True)(fn)
2316*da0073e9SAndroid Build Coastguard Worker            opt_fn(inputs)
2317*da0073e9SAndroid Build Coastguard Worker
2318*da0073e9SAndroid Build Coastguard Worker    def test_pow_int(self):
2319*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
2320*da0073e9SAndroid Build Coastguard Worker            return torch.pow(a, b)
2321*da0073e9SAndroid Build Coastguard Worker
2322*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2)
2323*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
2324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, 2), fn(x, 2))
2325*da0073e9SAndroid Build Coastguard Worker
2326*da0073e9SAndroid Build Coastguard Worker    def test_tensor_size_indexed_by_symint(self):
2327*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
2328*da0073e9SAndroid Build Coastguard Worker            index = x.shape[-1]
2329*da0073e9SAndroid Build Coastguard Worker            return x + y.shape[index]
2330*da0073e9SAndroid Build Coastguard Worker
2331*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10, 2)
2332*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(10, 8, 6)
2333*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
2334*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, y), fn(x, y))
2335*da0073e9SAndroid Build Coastguard Worker
2336*da0073e9SAndroid Build Coastguard Worker    def test_partials_as_input_partials_lambda(self):
2337*da0073e9SAndroid Build Coastguard Worker        def fn(f0, f1, x):
2338*da0073e9SAndroid Build Coastguard Worker            return f0(x) * f1(x)
2339*da0073e9SAndroid Build Coastguard Worker
2340*da0073e9SAndroid Build Coastguard Worker        multiply = lambda x, y: x * y
2341*da0073e9SAndroid Build Coastguard Worker        lambda0 = functools.partial(multiply, y=3)
2342*da0073e9SAndroid Build Coastguard Worker        lambda1 = functools.partial(multiply, y=2)
2343*da0073e9SAndroid Build Coastguard Worker
2344*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2345*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.optimize(cnts, nopython=True)(fn)(
2346*da0073e9SAndroid Build Coastguard Worker            lambda0, lambda1, torch.randn(2, 2)
2347*da0073e9SAndroid Build Coastguard Worker        )
2348*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker    def test_partials_as_input_partials_mod(self):
2351*da0073e9SAndroid Build Coastguard Worker        def fn(f0, f1, x):
2352*da0073e9SAndroid Build Coastguard Worker            return f0(x) * f1(x)
2353*da0073e9SAndroid Build Coastguard Worker
2354*da0073e9SAndroid Build Coastguard Worker        lambda0 = functools.partial(SmallNN(), y=torch.randn(2, 2))
2355*da0073e9SAndroid Build Coastguard Worker        lambda1 = functools.partial(SmallNN(), y=torch.randn(2, 2))
2356*da0073e9SAndroid Build Coastguard Worker
2357*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2358*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2359*da0073e9SAndroid Build Coastguard Worker        dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
2360*da0073e9SAndroid Build Coastguard Worker            lambda0, lambda1, x
2361*da0073e9SAndroid Build Coastguard Worker        )
2362*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2363*da0073e9SAndroid Build Coastguard Worker
2364*da0073e9SAndroid Build Coastguard Worker        eager_result = fn(lambda0, lambda1, x)
2365*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result, dynamo_result)
2366*da0073e9SAndroid Build Coastguard Worker
2367*da0073e9SAndroid Build Coastguard Worker    def test_partials_as_input_UDF(self):
2368*da0073e9SAndroid Build Coastguard Worker        def fn(f0, f1, x):
2369*da0073e9SAndroid Build Coastguard Worker            return f0(x) * f1(x)
2370*da0073e9SAndroid Build Coastguard Worker
2371*da0073e9SAndroid Build Coastguard Worker        lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2))
2372*da0073e9SAndroid Build Coastguard Worker        lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2))
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2375*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2376*da0073e9SAndroid Build Coastguard Worker        dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
2377*da0073e9SAndroid Build Coastguard Worker            lambda0, lambda1, x
2378*da0073e9SAndroid Build Coastguard Worker        )
2379*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2380*da0073e9SAndroid Build Coastguard Worker
2381*da0073e9SAndroid Build Coastguard Worker        eager_result = fn(lambda0, lambda1, x)
2382*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result, dynamo_result)
2383*da0073e9SAndroid Build Coastguard Worker
2384*da0073e9SAndroid Build Coastguard Worker    def test_partials_graph_break_reconstruct(self):
2385*da0073e9SAndroid Build Coastguard Worker        def fn(udf_mul_0, udf_mul_1, x):
2386*da0073e9SAndroid Build Coastguard Worker            lambda0 = functools.partial(udf_mul_0, y=x)
2387*da0073e9SAndroid Build Coastguard Worker            lambda1 = functools.partial(udf_mul_1, y=x)
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Worker            print("break")
2390*da0073e9SAndroid Build Coastguard Worker            return torch.mul(lambda0(x), lambda1(x))
2391*da0073e9SAndroid Build Coastguard Worker
2392*da0073e9SAndroid Build Coastguard Worker        backend = EagerAndRecordGraphs()
2393*da0073e9SAndroid Build Coastguard Worker        cnts = CompileCounterWithBackend(backend)
2394*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2395*da0073e9SAndroid Build Coastguard Worker        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_mul, x)
2396*da0073e9SAndroid Build Coastguard Worker
2397*da0073e9SAndroid Build Coastguard Worker        eager_result = fn(udf_mul, udf_mul, x)
2398*da0073e9SAndroid Build Coastguard Worker        gm = backend.graphs[0]
2399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result, dynamo_result)
2400*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
2401*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2402*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2403*da0073e9SAndroid Build Coastguard Worker                """\
2404*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2405*da0073e9SAndroid Build Coastguard Worker    def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
2406*da0073e9SAndroid Build Coastguard Worker        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker        mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2409*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2410*da0073e9SAndroid Build Coastguard Worker
2411*da0073e9SAndroid Build Coastguard Worker        mul_2: "f32[2, 2]" = torch.mul(mul, mul_1);  mul = mul_1 = None
2412*da0073e9SAndroid Build Coastguard Worker        return (mul_2,)
2413*da0073e9SAndroid Build Coastguard Worker""",
2414*da0073e9SAndroid Build Coastguard Worker            )
2415*da0073e9SAndroid Build Coastguard Worker        else:
2416*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2417*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2418*da0073e9SAndroid Build Coastguard Worker                """\
2419*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2420*da0073e9SAndroid Build Coastguard Worker    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
2421*da0073e9SAndroid Build Coastguard Worker        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2422*da0073e9SAndroid Build Coastguard Worker
2423*da0073e9SAndroid Build Coastguard Worker        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2424*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2425*da0073e9SAndroid Build Coastguard Worker
2426*da0073e9SAndroid Build Coastguard Worker        mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1);  mul = mul_1 = None
2427*da0073e9SAndroid Build Coastguard Worker        return (mul_2,)
2428*da0073e9SAndroid Build Coastguard Worker""",
2429*da0073e9SAndroid Build Coastguard Worker            )
2430*da0073e9SAndroid Build Coastguard Worker
2431*da0073e9SAndroid Build Coastguard Worker    def test_partials_graph_break_reconstruct_mix(self):
2432*da0073e9SAndroid Build Coastguard Worker        def fn(udf_mul_0, udf_add_1, x):
2433*da0073e9SAndroid Build Coastguard Worker            lambda0 = functools.partial(udf_mul_0, y=x)
2434*da0073e9SAndroid Build Coastguard Worker            lambda1 = functools.partial(udf_add_1, x)
2435*da0073e9SAndroid Build Coastguard Worker
2436*da0073e9SAndroid Build Coastguard Worker            print("break")
2437*da0073e9SAndroid Build Coastguard Worker            return torch.mul(lambda0(x), lambda1(x))
2438*da0073e9SAndroid Build Coastguard Worker
2439*da0073e9SAndroid Build Coastguard Worker        backend = EagerAndRecordGraphs()
2440*da0073e9SAndroid Build Coastguard Worker        cnts = CompileCounterWithBackend(backend)
2441*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2442*da0073e9SAndroid Build Coastguard Worker        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_add, x)
2443*da0073e9SAndroid Build Coastguard Worker
2444*da0073e9SAndroid Build Coastguard Worker        eager_result = fn(udf_mul, udf_add, x)
2445*da0073e9SAndroid Build Coastguard Worker        gm = backend.graphs[0]
2446*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result, dynamo_result)
2447*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
2448*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2449*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2450*da0073e9SAndroid Build Coastguard Worker                """\
2451*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2452*da0073e9SAndroid Build Coastguard Worker    def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
2453*da0073e9SAndroid Build Coastguard Worker        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2454*da0073e9SAndroid Build Coastguard Worker
2455*da0073e9SAndroid Build Coastguard Worker        mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2456*da0073e9SAndroid Build Coastguard Worker
2457*da0073e9SAndroid Build Coastguard Worker        add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2458*da0073e9SAndroid Build Coastguard Worker
2459*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[2, 2]" = torch.mul(mul, add);  mul = add = None
2460*da0073e9SAndroid Build Coastguard Worker        return (mul_1,)
2461*da0073e9SAndroid Build Coastguard Worker""",
2462*da0073e9SAndroid Build Coastguard Worker            )
2463*da0073e9SAndroid Build Coastguard Worker        else:
2464*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2465*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2466*da0073e9SAndroid Build Coastguard Worker                """\
2467*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2468*da0073e9SAndroid Build Coastguard Worker    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
2469*da0073e9SAndroid Build Coastguard Worker        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2470*da0073e9SAndroid Build Coastguard Worker
2471*da0073e9SAndroid Build Coastguard Worker        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2472*da0073e9SAndroid Build Coastguard Worker
2473*da0073e9SAndroid Build Coastguard Worker        add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2474*da0073e9SAndroid Build Coastguard Worker
2475*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[s0, s0]" = torch.mul(mul, add);  mul = add = None
2476*da0073e9SAndroid Build Coastguard Worker        return (mul_1,)
2477*da0073e9SAndroid Build Coastguard Worker""",
2478*da0073e9SAndroid Build Coastguard Worker            )
2479*da0073e9SAndroid Build Coastguard Worker
2480*da0073e9SAndroid Build Coastguard Worker    def test_partials_graph_break_reconstruct_mix_no_source(self):
2481*da0073e9SAndroid Build Coastguard Worker        def fn(udf_mul_0, x):
2482*da0073e9SAndroid Build Coastguard Worker            udf_add_1 = lambda x, y: x + y
2483*da0073e9SAndroid Build Coastguard Worker
2484*da0073e9SAndroid Build Coastguard Worker            lambda0 = functools.partial(udf_mul_0, y=x)
2485*da0073e9SAndroid Build Coastguard Worker            lambda1 = functools.partial(udf_add_1, x)
2486*da0073e9SAndroid Build Coastguard Worker
2487*da0073e9SAndroid Build Coastguard Worker            print("break")
2488*da0073e9SAndroid Build Coastguard Worker            return torch.mul(lambda0(x), lambda1(x))
2489*da0073e9SAndroid Build Coastguard Worker
2490*da0073e9SAndroid Build Coastguard Worker        backend = EagerAndRecordGraphs()
2491*da0073e9SAndroid Build Coastguard Worker        cnts = CompileCounterWithBackend(backend)
2492*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2493*da0073e9SAndroid Build Coastguard Worker        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, x)
2494*da0073e9SAndroid Build Coastguard Worker
2495*da0073e9SAndroid Build Coastguard Worker        eager_result = fn(udf_mul, x)
2496*da0073e9SAndroid Build Coastguard Worker        gm = backend.graphs[0]
2497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result, dynamo_result)
2498*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
2499*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2500*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2501*da0073e9SAndroid Build Coastguard Worker                """\
2502*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2503*da0073e9SAndroid Build Coastguard Worker    def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
2504*da0073e9SAndroid Build Coastguard Worker        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2505*da0073e9SAndroid Build Coastguard Worker
2506*da0073e9SAndroid Build Coastguard Worker        mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2507*da0073e9SAndroid Build Coastguard Worker
2508*da0073e9SAndroid Build Coastguard Worker        add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2509*da0073e9SAndroid Build Coastguard Worker
2510*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[2, 2]" = torch.mul(mul, add);  mul = add = None
2511*da0073e9SAndroid Build Coastguard Worker        return (mul_1,)
2512*da0073e9SAndroid Build Coastguard Worker""",
2513*da0073e9SAndroid Build Coastguard Worker            )
2514*da0073e9SAndroid Build Coastguard Worker        else:
2515*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2516*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2517*da0073e9SAndroid Build Coastguard Worker                """\
2518*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2519*da0073e9SAndroid Build Coastguard Worker    def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
2520*da0073e9SAndroid Build Coastguard Worker        l_lambda0_keywords_y_ = L_lambda0_keywords_y_
2521*da0073e9SAndroid Build Coastguard Worker
2522*da0073e9SAndroid Build Coastguard Worker        mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
2523*da0073e9SAndroid Build Coastguard Worker
2524*da0073e9SAndroid Build Coastguard Worker        add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_;  l_lambda0_keywords_y_ = None
2525*da0073e9SAndroid Build Coastguard Worker
2526*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[s0, s0]" = torch.mul(mul, add);  mul = add = None
2527*da0073e9SAndroid Build Coastguard Worker        return (mul_1,)
2528*da0073e9SAndroid Build Coastguard Worker""",
2529*da0073e9SAndroid Build Coastguard Worker            )
2530*da0073e9SAndroid Build Coastguard Worker
2531*da0073e9SAndroid Build Coastguard Worker    def test_partials_graph_break_reconstruct_args_and_kwargs(self):
2532*da0073e9SAndroid Build Coastguard Worker        def fn(udf_mul_0, x):
2533*da0073e9SAndroid Build Coastguard Worker            lambda0 = functools.partial(udf_mul_0, x, 4, z=x)
2534*da0073e9SAndroid Build Coastguard Worker            lambda1 = functools.partial(udf_mul_0, 4, z=x)
2535*da0073e9SAndroid Build Coastguard Worker
2536*da0073e9SAndroid Build Coastguard Worker            return torch.mul(lambda0(), lambda1(5))
2537*da0073e9SAndroid Build Coastguard Worker
2538*da0073e9SAndroid Build Coastguard Worker        backend = EagerAndRecordGraphs()
2539*da0073e9SAndroid Build Coastguard Worker        cnts = CompileCounterWithBackend(backend)
2540*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2541*da0073e9SAndroid Build Coastguard Worker        dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul2, x)
2542*da0073e9SAndroid Build Coastguard Worker
2543*da0073e9SAndroid Build Coastguard Worker        eager_result = fn(udf_mul2, x)
2544*da0073e9SAndroid Build Coastguard Worker        gm = backend.graphs[0]
2545*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result, dynamo_result)
2546*da0073e9SAndroid Build Coastguard Worker        if torch._dynamo.config.assume_static_by_default:
2547*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2548*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2549*da0073e9SAndroid Build Coastguard Worker                """\
2550*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2551*da0073e9SAndroid Build Coastguard Worker    def forward(self, L_x_: "f32[2, 2]"):
2552*da0073e9SAndroid Build Coastguard Worker        l_x_ = L_x_
2553*da0073e9SAndroid Build Coastguard Worker
2554*da0073e9SAndroid Build Coastguard Worker        mul: "f32[2, 2]" = l_x_ * 4
2555*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[2, 2]" = mul * l_x_;  mul = None
2556*da0073e9SAndroid Build Coastguard Worker        mul_2: "f32[2, 2]" = 20 * l_x_;  l_x_ = None
2557*da0073e9SAndroid Build Coastguard Worker
2558*da0073e9SAndroid Build Coastguard Worker        mul_3: "f32[2, 2]" = torch.mul(mul_1, mul_2);  mul_1 = mul_2 = None
2559*da0073e9SAndroid Build Coastguard Worker        return (mul_3,)
2560*da0073e9SAndroid Build Coastguard Worker""",
2561*da0073e9SAndroid Build Coastguard Worker            )
2562*da0073e9SAndroid Build Coastguard Worker        else:
2563*da0073e9SAndroid Build Coastguard Worker            self.assertExpectedInline(
2564*da0073e9SAndroid Build Coastguard Worker                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
2565*da0073e9SAndroid Build Coastguard Worker                """\
2566*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module):
2567*da0073e9SAndroid Build Coastguard Worker    def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
2568*da0073e9SAndroid Build Coastguard Worker        l_x_ = L_x_
2569*da0073e9SAndroid Build Coastguard Worker
2570*da0073e9SAndroid Build Coastguard Worker        mul: "f32[s0, s0]" = l_x_ * 4
2571*da0073e9SAndroid Build Coastguard Worker        mul_1: "f32[s0, s0]" = mul * l_x_;  mul = None
2572*da0073e9SAndroid Build Coastguard Worker        mul_2: "f32[s0, s0]" = 20 * l_x_;  l_x_ = None
2573*da0073e9SAndroid Build Coastguard Worker
2574*da0073e9SAndroid Build Coastguard Worker        mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2);  mul_1 = mul_2 = None
2575*da0073e9SAndroid Build Coastguard Worker        return (mul_3,)
2576*da0073e9SAndroid Build Coastguard Worker""",
2577*da0073e9SAndroid Build Coastguard Worker            )
2578*da0073e9SAndroid Build Coastguard Worker
2579*da0073e9SAndroid Build Coastguard Worker    def test_partials_recompilation(self):
2580*da0073e9SAndroid Build Coastguard Worker        def fn(f0, f1, x):
2581*da0073e9SAndroid Build Coastguard Worker            return f0(x) * f1(x)
2582*da0073e9SAndroid Build Coastguard Worker
2583*da0073e9SAndroid Build Coastguard Worker        lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2))
2584*da0073e9SAndroid Build Coastguard Worker        lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2))
2585*da0073e9SAndroid Build Coastguard Worker
2586*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2587*da0073e9SAndroid Build Coastguard Worker
2588*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2589*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
2590*da0073e9SAndroid Build Coastguard Worker        dynamo_result = fn(lambda0, lambda1, x)
2591*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2592*da0073e9SAndroid Build Coastguard Worker
2593*da0073e9SAndroid Build Coastguard Worker        fn(lambda1, lambda0, x)
2594*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2595*da0073e9SAndroid Build Coastguard Worker            cnts.frame_count, 1
2596*da0073e9SAndroid Build Coastguard Worker        )  # No recompile! Tensor and udf_mul guarded
2597*da0073e9SAndroid Build Coastguard Worker
2598*da0073e9SAndroid Build Coastguard Worker        lambda2 = functools.partial(udf_mul, y=torch.randn(3, 3))
2599*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
2600*da0073e9SAndroid Build Coastguard Worker        fn(lambda2, lambda2, x)
2601*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)  # Recompile! Tensor size changed
2602*da0073e9SAndroid Build Coastguard Worker
2603*da0073e9SAndroid Build Coastguard Worker        multiply = lambda x, y: x * y
2604*da0073e9SAndroid Build Coastguard Worker        lambda3 = functools.partial(multiply, y=torch.randn(3, 3))
2605*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
2606*da0073e9SAndroid Build Coastguard Worker        fn(lambda3, lambda3, x)
2607*da0073e9SAndroid Build Coastguard Worker
2608*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)  # Recompile! func id changed
2609*da0073e9SAndroid Build Coastguard Worker
2610*da0073e9SAndroid Build Coastguard Worker        def fn2(f0, f1, args):
2611*da0073e9SAndroid Build Coastguard Worker            return f0(*args) * f1(*args)
2612*da0073e9SAndroid Build Coastguard Worker
2613*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2614*da0073e9SAndroid Build Coastguard Worker
2615*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
2616*da0073e9SAndroid Build Coastguard Worker        fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
2617*da0073e9SAndroid Build Coastguard Worker        dynamo_result = fn2(lambda0, lambda1, [x])
2618*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)  # start over
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker        lambda4 = functools.partial(multiply, y=3, x=torch.randn(3, 3))
2621*da0073e9SAndroid Build Coastguard Worker        fn2(lambda4, lambda4, [])
2622*da0073e9SAndroid Build Coastguard Worker
2623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)  # Recompile! Different kwarg keys
2624*da0073e9SAndroid Build Coastguard Worker
2625*da0073e9SAndroid Build Coastguard Worker        lambda5 = functools.partial(multiply, 1)
2626*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
2627*da0073e9SAndroid Build Coastguard Worker        fn2(lambda5, lambda5, [x])
2628*da0073e9SAndroid Build Coastguard Worker
2629*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)  # Recompile! Different arg keys
2630*da0073e9SAndroid Build Coastguard Worker
2631*da0073e9SAndroid Build Coastguard Worker        lambda6 = lambda x: x + x
2632*da0073e9SAndroid Build Coastguard Worker        fn2(lambda6, lambda6, [x])
2633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2634*da0073e9SAndroid Build Coastguard Worker            cnts.frame_count, 4
2635*da0073e9SAndroid Build Coastguard Worker        )  # Recompile! input is no longer a functools partial
2636*da0073e9SAndroid Build Coastguard Worker
2637*da0073e9SAndroid Build Coastguard Worker    def test_manual_seed(self):
2638*da0073e9SAndroid Build Coastguard Worker        @torch.compile
2639*da0073e9SAndroid Build Coastguard Worker        def foo():
2640*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(3)
2641*da0073e9SAndroid Build Coastguard Worker            return torch.randint(0, 5, (5,))
2642*da0073e9SAndroid Build Coastguard Worker
2643*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), foo())
2644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), foo())
2645*da0073e9SAndroid Build Coastguard Worker
2646*da0073e9SAndroid Build Coastguard Worker    def test_partial_across_graph_break_uninvoked(self):
2647*da0073e9SAndroid Build Coastguard Worker        from functools import partial
2648*da0073e9SAndroid Build Coastguard Worker
2649*da0073e9SAndroid Build Coastguard Worker        def bar(x, **kwargs):
2650*da0073e9SAndroid Build Coastguard Worker            return x + x
2651*da0073e9SAndroid Build Coastguard Worker
2652*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", dynamic=True)
2653*da0073e9SAndroid Build Coastguard Worker        def foo(x, i):
2654*da0073e9SAndroid Build Coastguard Worker            def inner():
2655*da0073e9SAndroid Build Coastguard Worker                print("this is a graph_break")
2656*da0073e9SAndroid Build Coastguard Worker                return op(x)
2657*da0073e9SAndroid Build Coastguard Worker
2658*da0073e9SAndroid Build Coastguard Worker            op = partial(bar, dim=10)
2659*da0073e9SAndroid Build Coastguard Worker            x = inner()
2660*da0073e9SAndroid Build Coastguard Worker            op = partial(bar, other=10)
2661*da0073e9SAndroid Build Coastguard Worker            return inner() + x
2662*da0073e9SAndroid Build Coastguard Worker
2663*da0073e9SAndroid Build Coastguard Worker        foo(torch.rand(1), 10)
2664*da0073e9SAndroid Build Coastguard Worker
2665*da0073e9SAndroid Build Coastguard Worker    def test_no_recompile_inner_function(self):
2666*da0073e9SAndroid Build Coastguard Worker        def forward(inp):
2667*da0073e9SAndroid Build Coastguard Worker            def g(y):
2668*da0073e9SAndroid Build Coastguard Worker                return inp + y
2669*da0073e9SAndroid Build Coastguard Worker
2670*da0073e9SAndroid Build Coastguard Worker            print("graph break")
2671*da0073e9SAndroid Build Coastguard Worker            return g(torch.rand([1]))
2672*da0073e9SAndroid Build Coastguard Worker
2673*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2674*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(forward)
2675*da0073e9SAndroid Build Coastguard Worker
2676*da0073e9SAndroid Build Coastguard Worker        input = torch.rand([2])
2677*da0073e9SAndroid Build Coastguard Worker        _ = opt_fn(input)
2678*da0073e9SAndroid Build Coastguard Worker        _ = opt_fn(input)
2679*da0073e9SAndroid Build Coastguard Worker        _ = opt_fn(input)
2680*da0073e9SAndroid Build Coastguard Worker        # Should not have recompiled
2681*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2682*da0073e9SAndroid Build Coastguard Worker
2683*da0073e9SAndroid Build Coastguard Worker    def test_no_recompile_inner_lambda(self):
2684*da0073e9SAndroid Build Coastguard Worker        def forward(inp):
2685*da0073e9SAndroid Build Coastguard Worker            g = lambda y: inp + y
2686*da0073e9SAndroid Build Coastguard Worker            print("graph break")
2687*da0073e9SAndroid Build Coastguard Worker            return g(torch.rand([1]))
2688*da0073e9SAndroid Build Coastguard Worker
2689*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
2690*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts)(forward)
2691*da0073e9SAndroid Build Coastguard Worker
2692*da0073e9SAndroid Build Coastguard Worker        input = torch.rand([2])
2693*da0073e9SAndroid Build Coastguard Worker        _ = opt_fn(input)
2694*da0073e9SAndroid Build Coastguard Worker        _ = opt_fn(input)
2695*da0073e9SAndroid Build Coastguard Worker        _ = opt_fn(input)
2696*da0073e9SAndroid Build Coastguard Worker        # Should not have recompiled
2697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
2698*da0073e9SAndroid Build Coastguard Worker
2699*da0073e9SAndroid Build Coastguard Worker    def test_complex_closure(self):
2700*da0073e9SAndroid Build Coastguard Worker        @torch.compile
2701*da0073e9SAndroid Build Coastguard Worker        def forward(y):
2702*da0073e9SAndroid Build Coastguard Worker            def a():
2703*da0073e9SAndroid Build Coastguard Worker                def x(z):
2704*da0073e9SAndroid Build Coastguard Worker                    return y + z
2705*da0073e9SAndroid Build Coastguard Worker
2706*da0073e9SAndroid Build Coastguard Worker                return x
2707*da0073e9SAndroid Build Coastguard Worker
2708*da0073e9SAndroid Build Coastguard Worker            return a()
2709*da0073e9SAndroid Build Coastguard Worker
2710*da0073e9SAndroid Build Coastguard Worker        input1 = torch.rand([2])
2711*da0073e9SAndroid Build Coastguard Worker        input2 = torch.rand([2])
2712*da0073e9SAndroid Build Coastguard Worker        res = forward(input1)(input2)
2713*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(res, input1 + input2))
2714*da0073e9SAndroid Build Coastguard Worker
2715*da0073e9SAndroid Build Coastguard Worker    def test_non_inlined_closure(self):
2716*da0073e9SAndroid Build Coastguard Worker        @torch.compile()
2717*da0073e9SAndroid Build Coastguard Worker        def program(x, y):
2718*da0073e9SAndroid Build Coastguard Worker            one = lambda x, y: x + y
2719*da0073e9SAndroid Build Coastguard Worker
2720*da0073e9SAndroid Build Coastguard Worker            def inner():
2721*da0073e9SAndroid Build Coastguard Worker                # Force no inlining
2722*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.graph_break()
2723*da0073e9SAndroid Build Coastguard Worker                return one(x, y)
2724*da0073e9SAndroid Build Coastguard Worker
2725*da0073e9SAndroid Build Coastguard Worker            res = inner()
2726*da0073e9SAndroid Build Coastguard Worker            one = lambda x, y: x - y
2727*da0073e9SAndroid Build Coastguard Worker            res += inner()
2728*da0073e9SAndroid Build Coastguard Worker            return res
2729*da0073e9SAndroid Build Coastguard Worker
2730*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(1)
2731*da0073e9SAndroid Build Coastguard Worker        input2 = torch.randn(1)
2732*da0073e9SAndroid Build Coastguard Worker
2733*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(same(program(input1, input2), input1 + input1))
2734*da0073e9SAndroid Build Coastguard Worker
2735*da0073e9SAndroid Build Coastguard Worker    @parametrize("int_or_float", ("int", "float"))
2736*da0073e9SAndroid Build Coastguard Worker    def test_np_constant_collections_as_input(self, int_or_float):
2737*da0073e9SAndroid Build Coastguard Worker        info_func = getattr(np, f"{int_or_float[0]}info")
2738*da0073e9SAndroid Build Coastguard Worker        dt_string_arg = f"{int_or_float}16"
2739*da0073e9SAndroid Build Coastguard Worker        np_dt_attr = getattr(np, dt_string_arg)
2740*da0073e9SAndroid Build Coastguard Worker
2741*da0073e9SAndroid Build Coastguard Worker        dt_args = [dt_string_arg, np_dt_attr]
2742*da0073e9SAndroid Build Coastguard Worker        arg_variants_iter = itertools.chain(
2743*da0073e9SAndroid Build Coastguard Worker            dt_args, map(np.dtype, dt_args), map(info_func, dt_args)
2744*da0073e9SAndroid Build Coastguard Worker        )
2745*da0073e9SAndroid Build Coastguard Worker
2746*da0073e9SAndroid Build Coastguard Worker        def func(a, b, info_or_dt):
2747*da0073e9SAndroid Build Coastguard Worker            return a + info_func(info_or_dt).max
2748*da0073e9SAndroid Build Coastguard Worker
2749*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(func)
2750*da0073e9SAndroid Build Coastguard Worker
2751*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2)
2752*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2)
2753*da0073e9SAndroid Build Coastguard Worker        eager_result = func(a, b, dt_args[0])
2754*da0073e9SAndroid Build Coastguard Worker
2755*da0073e9SAndroid Build Coastguard Worker        for arg in arg_variants_iter:
2756*da0073e9SAndroid Build Coastguard Worker            opt_result = opt_fn(a, b, arg)
2757*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(opt_result, eager_result))
2758*da0073e9SAndroid Build Coastguard Worker
2759*da0073e9SAndroid Build Coastguard Worker    @parametrize(
2760*da0073e9SAndroid Build Coastguard Worker        "typ, info_func",
2761*da0073e9SAndroid Build Coastguard Worker        [
2762*da0073e9SAndroid Build Coastguard Worker            (int, np.iinfo),
2763*da0073e9SAndroid Build Coastguard Worker            (float, np.finfo),
2764*da0073e9SAndroid Build Coastguard Worker        ],
2765*da0073e9SAndroid Build Coastguard Worker        name_fn=lambda t, _: t.__name__,
2766*da0073e9SAndroid Build Coastguard Worker    )
2767*da0073e9SAndroid Build Coastguard Worker    def test_np_constant_collections_guards(self, typ, info_func):
2768*da0073e9SAndroid Build Coastguard Worker        def func_info(a, info):
2769*da0073e9SAndroid Build Coastguard Worker            return a + info.max
2770*da0073e9SAndroid Build Coastguard Worker
2771*da0073e9SAndroid Build Coastguard Worker        def func_dtype(a, dt):
2772*da0073e9SAndroid Build Coastguard Worker            return a + info_func(dt).max
2773*da0073e9SAndroid Build Coastguard Worker
2774*da0073e9SAndroid Build Coastguard Worker        dt_args = [
2775*da0073e9SAndroid Build Coastguard Worker            np.dtype(typ),
2776*da0073e9SAndroid Build Coastguard Worker            np.ones((1,), dtype=typ).dtype,
2777*da0073e9SAndroid Build Coastguard Worker            np.dtype(np.dtype(typ).name),
2778*da0073e9SAndroid Build Coastguard Worker            np.dtype(typ.__name__),
2779*da0073e9SAndroid Build Coastguard Worker        ]
2780*da0073e9SAndroid Build Coastguard Worker        cnts_1 = torch._dynamo.testing.CompileCounter()
2781*da0073e9SAndroid Build Coastguard Worker        opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype)
2782*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(3, dtype=typ)
2783*da0073e9SAndroid Build Coastguard Worker        for arg in dt_args:
2784*da0073e9SAndroid Build Coastguard Worker            r = opt_fn_dtype(a, arg)
2785*da0073e9SAndroid Build Coastguard Worker        # each should produce an identical arg
2786*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts_1.frame_count, 1)
2787*da0073e9SAndroid Build Coastguard Worker
2788*da0073e9SAndroid Build Coastguard Worker        cnts_2 = torch._dynamo.testing.CompileCounter()
2789*da0073e9SAndroid Build Coastguard Worker        opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info)
2790*da0073e9SAndroid Build Coastguard Worker        info_args = [info_func(dt) for dt in dt_args]
2791*da0073e9SAndroid Build Coastguard Worker        for arg in info_args:
2792*da0073e9SAndroid Build Coastguard Worker            r = opt_fn_info(a, arg)
2793*da0073e9SAndroid Build Coastguard Worker
2794*da0073e9SAndroid Build Coastguard Worker        # each should produce an identical arg
2795*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts_2.frame_count, 1)
2796*da0073e9SAndroid Build Coastguard Worker
2797*da0073e9SAndroid Build Coastguard Worker        if typ is float:
2798*da0073e9SAndroid Build Coastguard Worker            dt_extra = np.dtype(np.float16)
2799*da0073e9SAndroid Build Coastguard Worker        else:
2800*da0073e9SAndroid Build Coastguard Worker            dt_extra = np.dtype(np.int16)
2801*da0073e9SAndroid Build Coastguard Worker        info_extra = info_func(dt_extra)
2802*da0073e9SAndroid Build Coastguard Worker
2803*da0073e9SAndroid Build Coastguard Worker        eager_result_dtype = func_dtype(a, dt_extra)
2804*da0073e9SAndroid Build Coastguard Worker        compile_result_dtype = opt_fn_dtype(a, dt_extra)
2805*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts_1.frame_count, 2)
2806*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result_dtype, compile_result_dtype)
2807*da0073e9SAndroid Build Coastguard Worker
2808*da0073e9SAndroid Build Coastguard Worker        eager_result_info = func_info(a, info_extra)
2809*da0073e9SAndroid Build Coastguard Worker        compile_result_info = opt_fn_info(a, info_extra)
2810*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts_2.frame_count, 2)
2811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_result_info, compile_result_info)
2812*da0073e9SAndroid Build Coastguard Worker
2813*da0073e9SAndroid Build Coastguard Worker    def test_compare_constant_and_tensor(self):
2814*da0073e9SAndroid Build Coastguard Worker        for op in [
2815*da0073e9SAndroid Build Coastguard Worker            operator.lt,
2816*da0073e9SAndroid Build Coastguard Worker            operator.le,
2817*da0073e9SAndroid Build Coastguard Worker            operator.gt,
2818*da0073e9SAndroid Build Coastguard Worker            operator.ge,
2819*da0073e9SAndroid Build Coastguard Worker            operator.ne,
2820*da0073e9SAndroid Build Coastguard Worker            operator.eq,
2821*da0073e9SAndroid Build Coastguard Worker            operator.is_,
2822*da0073e9SAndroid Build Coastguard Worker            operator.is_not,
2823*da0073e9SAndroid Build Coastguard Worker        ]:
2824*da0073e9SAndroid Build Coastguard Worker            with self.subTest(op=op):
2825*da0073e9SAndroid Build Coastguard Worker
2826*da0073e9SAndroid Build Coastguard Worker                def fn(x):
2827*da0073e9SAndroid Build Coastguard Worker                    return op(-10, x)
2828*da0073e9SAndroid Build Coastguard Worker
2829*da0073e9SAndroid Build Coastguard Worker                opt_fn = torch.compile(fullgraph=True)(fn)
2830*da0073e9SAndroid Build Coastguard Worker
2831*da0073e9SAndroid Build Coastguard Worker                x = torch.randn(10)
2832*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(opt_fn(x), fn(x))
2833*da0073e9SAndroid Build Coastguard Worker
2834*da0073e9SAndroid Build Coastguard Worker    def test_pos(self):
2835*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
2836*da0073e9SAndroid Build Coastguard Worker            return operator.pos(x) * +y
2837*da0073e9SAndroid Build Coastguard Worker
2838*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True, dynamic=True)(fn)
2839*da0073e9SAndroid Build Coastguard Worker
2840*da0073e9SAndroid Build Coastguard Worker        def test(x, y):
2841*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(opt_fn(x, y), fn(x, y))
2842*da0073e9SAndroid Build Coastguard Worker
2843*da0073e9SAndroid Build Coastguard Worker        test(torch.ones(4), 1)
2844*da0073e9SAndroid Build Coastguard Worker        test(1, torch.ones(4))
2845*da0073e9SAndroid Build Coastguard Worker        test(-1, -1)
2846*da0073e9SAndroid Build Coastguard Worker        test(-1.1, 1.1)
2847*da0073e9SAndroid Build Coastguard Worker        test(True, False)
2848*da0073e9SAndroid Build Coastguard Worker        test(torch.ones(4, dtype=torch.float32), 1.1)
2849*da0073e9SAndroid Build Coastguard Worker
2850*da0073e9SAndroid Build Coastguard Worker    def test_index(self):
2851*da0073e9SAndroid Build Coastguard Worker        def fn(x, t):
2852*da0073e9SAndroid Build Coastguard Worker            v = operator.index(x)
2853*da0073e9SAndroid Build Coastguard Worker            torch.mul(t, v)
2854*da0073e9SAndroid Build Coastguard Worker
2855*da0073e9SAndroid Build Coastguard Worker        def test(a, b):
2856*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(opt_fn(a, b), fn(a, b))
2857*da0073e9SAndroid Build Coastguard Worker
2858*da0073e9SAndroid Build Coastguard Worker        for dynamic in [True, False]:
2859*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
2860*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn)
2861*da0073e9SAndroid Build Coastguard Worker            t = torch.ones(1)
2862*da0073e9SAndroid Build Coastguard Worker            test(10, t)
2863*da0073e9SAndroid Build Coastguard Worker            test(-100, t)
2864*da0073e9SAndroid Build Coastguard Worker            test(10, t)
2865*da0073e9SAndroid Build Coastguard Worker            test(False, t)
2866*da0073e9SAndroid Build Coastguard Worker            test(True, t)
2867*da0073e9SAndroid Build Coastguard Worker
2868*da0073e9SAndroid Build Coastguard Worker    def test_truth(self):
2869*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
2870*da0073e9SAndroid Build Coastguard Worker            return operator.truth(x) and bool(y)
2871*da0073e9SAndroid Build Coastguard Worker
2872*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True, dynamic=False)(fn)
2873*da0073e9SAndroid Build Coastguard Worker
2874*da0073e9SAndroid Build Coastguard Worker        def test(x, y):
2875*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(opt_fn(x, y), fn(x, y))
2876*da0073e9SAndroid Build Coastguard Worker
2877*da0073e9SAndroid Build Coastguard Worker        test(1, 100)
2878*da0073e9SAndroid Build Coastguard Worker        test(-1.1, True)
2879*da0073e9SAndroid Build Coastguard Worker        test(-1.1, 1.1)
2880*da0073e9SAndroid Build Coastguard Worker        test(True, False)
2881*da0073e9SAndroid Build Coastguard Worker        test(torch.ones(1), 1)
2882*da0073e9SAndroid Build Coastguard Worker        test(torch.zeros(1), 1)
2883*da0073e9SAndroid Build Coastguard Worker        test(torch.ones(1), torch.ones(1))
2884*da0073e9SAndroid Build Coastguard Worker
2885*da0073e9SAndroid Build Coastguard Worker    def test_unary_fold_op(self):
2886*da0073e9SAndroid Build Coastguard Worker        for op in (operator.abs, abs, operator.neg, operator.pos, operator.truth):
2887*da0073e9SAndroid Build Coastguard Worker            with self.subTest(op=op):
2888*da0073e9SAndroid Build Coastguard Worker
2889*da0073e9SAndroid Build Coastguard Worker                def fn():
2890*da0073e9SAndroid Build Coastguard Worker                    a = range(-10, 10)
2891*da0073e9SAndroid Build Coastguard Worker                    return list(map(op, a))
2892*da0073e9SAndroid Build Coastguard Worker
2893*da0073e9SAndroid Build Coastguard Worker                opt_fn = torch._dynamo.optimize(nopython=True)(fn)
2894*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(opt_fn(), fn())
2895*da0073e9SAndroid Build Coastguard Worker
2896*da0073e9SAndroid Build Coastguard Worker    def test_unary_fold_op_seq(self):
2897*da0073e9SAndroid Build Coastguard Worker        for op in (operator.length_hint,):
2898*da0073e9SAndroid Build Coastguard Worker            with self.subTest(op=op):
2899*da0073e9SAndroid Build Coastguard Worker
2900*da0073e9SAndroid Build Coastguard Worker                def fn():
2901*da0073e9SAndroid Build Coastguard Worker                    a = [tuple(range(-10, i)) for i in range(10)]
2902*da0073e9SAndroid Build Coastguard Worker                    return tuple(map(op, a))
2903*da0073e9SAndroid Build Coastguard Worker
2904*da0073e9SAndroid Build Coastguard Worker                opt_fn = torch._dynamo.optimize(nopython=True)(fn)
2905*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(opt_fn(), fn())
2906*da0073e9SAndroid Build Coastguard Worker
2907*da0073e9SAndroid Build Coastguard Worker    def gen_random_range_args(self):
2908*da0073e9SAndroid Build Coastguard Worker        args_count = random.randint(1, 3)
2909*da0073e9SAndroid Build Coastguard Worker        args = [random.randint(-10, 10) for _ in range(args_count)]
2910*da0073e9SAndroid Build Coastguard Worker        if args_count == 3 and args[2] == 0:
2911*da0073e9SAndroid Build Coastguard Worker            args[2] = 1
2912*da0073e9SAndroid Build Coastguard Worker        return args
2913*da0073e9SAndroid Build Coastguard Worker
2914*da0073e9SAndroid Build Coastguard Worker    def test_range_length(self):
2915*da0073e9SAndroid Build Coastguard Worker        def test(*args, expected=None):
2916*da0073e9SAndroid Build Coastguard Worker            r = range(*args)
2917*da0073e9SAndroid Build Coastguard Worker            range_variable = RangeVariable([ConstantVariable.create(v) for v in args])
2918*da0073e9SAndroid Build Coastguard Worker
2919*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(r), range_variable.range_length())
2920*da0073e9SAndroid Build Coastguard Worker
2921*da0073e9SAndroid Build Coastguard Worker            if expected is not None:
2922*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(r), expected)
2923*da0073e9SAndroid Build Coastguard Worker
2924*da0073e9SAndroid Build Coastguard Worker        test(1, 1, 1, expected=0)
2925*da0073e9SAndroid Build Coastguard Worker        test(1, 0, expected=0)
2926*da0073e9SAndroid Build Coastguard Worker        test(-10, expected=0)
2927*da0073e9SAndroid Build Coastguard Worker
2928*da0073e9SAndroid Build Coastguard Worker        test(4, expected=4)
2929*da0073e9SAndroid Build Coastguard Worker        test(10, expected=10)
2930*da0073e9SAndroid Build Coastguard Worker
2931*da0073e9SAndroid Build Coastguard Worker        # step >1
2932*da0073e9SAndroid Build Coastguard Worker        test(1, 10, 2, expected=5)
2933*da0073e9SAndroid Build Coastguard Worker
2934*da0073e9SAndroid Build Coastguard Worker        # negative step
2935*da0073e9SAndroid Build Coastguard Worker        test(10, 1, -1, expected=9)
2936*da0073e9SAndroid Build Coastguard Worker        test(10, 1, -3)
2937*da0073e9SAndroid Build Coastguard Worker
2938*da0073e9SAndroid Build Coastguard Worker        # Fuzz testing
2939*da0073e9SAndroid Build Coastguard Worker        for i in range(100):
2940*da0073e9SAndroid Build Coastguard Worker            args = self.gen_random_range_args()
2941*da0073e9SAndroid Build Coastguard Worker            print("testing :", args)
2942*da0073e9SAndroid Build Coastguard Worker            test(*args)
2943*da0073e9SAndroid Build Coastguard Worker
2944*da0073e9SAndroid Build Coastguard Worker    def test_indexed_range(self):
2945*da0073e9SAndroid Build Coastguard Worker        def test(range, index, expected=None):
2946*da0073e9SAndroid Build Coastguard Worker            range_variable = RangeVariable(
2947*da0073e9SAndroid Build Coastguard Worker                [
2948*da0073e9SAndroid Build Coastguard Worker                    ConstantVariable.create(v)
2949*da0073e9SAndroid Build Coastguard Worker                    for v in [range.start, range.stop, range.step]
2950*da0073e9SAndroid Build Coastguard Worker                ]
2951*da0073e9SAndroid Build Coastguard Worker            )
2952*da0073e9SAndroid Build Coastguard Worker
2953*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2954*da0073e9SAndroid Build Coastguard Worker                range[index],
2955*da0073e9SAndroid Build Coastguard Worker                range_variable.apply_index(index).as_python_constant(),
2956*da0073e9SAndroid Build Coastguard Worker            )
2957*da0073e9SAndroid Build Coastguard Worker
2958*da0073e9SAndroid Build Coastguard Worker            if expected is not None:
2959*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(range[index], expected)
2960*da0073e9SAndroid Build Coastguard Worker
2961*da0073e9SAndroid Build Coastguard Worker        test(range(10), 1, expected=1)
2962*da0073e9SAndroid Build Coastguard Worker        test(range(10, 20, 2), 1, expected=12)
2963*da0073e9SAndroid Build Coastguard Worker
2964*da0073e9SAndroid Build Coastguard Worker        # Fuzz testing
2965*da0073e9SAndroid Build Coastguard Worker        for i in range(100):
2966*da0073e9SAndroid Build Coastguard Worker            range_args = self.gen_random_range_args()
2967*da0073e9SAndroid Build Coastguard Worker            r = range(*range_args)
2968*da0073e9SAndroid Build Coastguard Worker
2969*da0073e9SAndroid Build Coastguard Worker            if len(r) == 0:
2970*da0073e9SAndroid Build Coastguard Worker                continue
2971*da0073e9SAndroid Build Coastguard Worker
2972*da0073e9SAndroid Build Coastguard Worker            index = random.randint(0, len(r) - 1)
2973*da0073e9SAndroid Build Coastguard Worker
2974*da0073e9SAndroid Build Coastguard Worker            print("testing:", r, index)
2975*da0073e9SAndroid Build Coastguard Worker            test(r, index)
2976*da0073e9SAndroid Build Coastguard Worker
2977*da0073e9SAndroid Build Coastguard Worker    def test_sliced_range(self):
2978*da0073e9SAndroid Build Coastguard Worker        def test(range, slice, expected=None):
2979*da0073e9SAndroid Build Coastguard Worker            range_variable = RangeVariable(
2980*da0073e9SAndroid Build Coastguard Worker                [
2981*da0073e9SAndroid Build Coastguard Worker                    ConstantVariable.create(v)
2982*da0073e9SAndroid Build Coastguard Worker                    for v in [range.start, range.stop, range.step]
2983*da0073e9SAndroid Build Coastguard Worker                ]
2984*da0073e9SAndroid Build Coastguard Worker            )
2985*da0073e9SAndroid Build Coastguard Worker
2986*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
2987*da0073e9SAndroid Build Coastguard Worker                range[slice],
2988*da0073e9SAndroid Build Coastguard Worker                range_variable.apply_slice(slice).as_python_constant(),
2989*da0073e9SAndroid Build Coastguard Worker            )
2990*da0073e9SAndroid Build Coastguard Worker
2991*da0073e9SAndroid Build Coastguard Worker            if expected is not None:
2992*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
2993*da0073e9SAndroid Build Coastguard Worker                    range[slice],
2994*da0073e9SAndroid Build Coastguard Worker                    expected,
2995*da0073e9SAndroid Build Coastguard Worker                )
2996*da0073e9SAndroid Build Coastguard Worker
2997*da0073e9SAndroid Build Coastguard Worker        test(range(10), slice(1, 10, 2), expected=range(1, 10, 2))
2998*da0073e9SAndroid Build Coastguard Worker        test(range(10), slice(None, 10, None), expected=range(0, 10))
2999*da0073e9SAndroid Build Coastguard Worker        test(range(10), slice(-1, 7, None), expected=range(9, 7))
3000*da0073e9SAndroid Build Coastguard Worker        test(range(10), slice(-1, 7, 2), expected=range(9, 7, 2))
3001*da0073e9SAndroid Build Coastguard Worker        test(range(1, 10, 2), slice(3, 7, 2), expected=range(7, 11, 4))
3002*da0073e9SAndroid Build Coastguard Worker        test(range(1, 10, 2), slice(-3, 7, 2), expected=range(5, 11, 4))
3003*da0073e9SAndroid Build Coastguard Worker        test(range(-1, -5, -3), slice(5, None, -3), expected=range(-4, 2, 9))
3004*da0073e9SAndroid Build Coastguard Worker
3005*da0073e9SAndroid Build Coastguard Worker        def rand_slice():
3006*da0073e9SAndroid Build Coastguard Worker            def flip_coin():
3007*da0073e9SAndroid Build Coastguard Worker                # 1 out of 10
3008*da0073e9SAndroid Build Coastguard Worker                return random.randint(1, 10) == 5
3009*da0073e9SAndroid Build Coastguard Worker
3010*da0073e9SAndroid Build Coastguard Worker            def r_item(allow_zero=True):
3011*da0073e9SAndroid Build Coastguard Worker                i = random.randint(-10, 10)
3012*da0073e9SAndroid Build Coastguard Worker                if not allow_zero and i == 0:
3013*da0073e9SAndroid Build Coastguard Worker                    i = 1
3014*da0073e9SAndroid Build Coastguard Worker                if flip_coin():
3015*da0073e9SAndroid Build Coastguard Worker                    i = None
3016*da0073e9SAndroid Build Coastguard Worker                return i
3017*da0073e9SAndroid Build Coastguard Worker
3018*da0073e9SAndroid Build Coastguard Worker            arg_count = random.randint(1, 3)
3019*da0073e9SAndroid Build Coastguard Worker
3020*da0073e9SAndroid Build Coastguard Worker            if arg_count == 1:
3021*da0073e9SAndroid Build Coastguard Worker                return slice(r_item())
3022*da0073e9SAndroid Build Coastguard Worker            elif arg_count == 2:
3023*da0073e9SAndroid Build Coastguard Worker                return slice(r_item(), r_item())
3024*da0073e9SAndroid Build Coastguard Worker            else:
3025*da0073e9SAndroid Build Coastguard Worker                return slice(r_item(), r_item(), r_item(False))
3026*da0073e9SAndroid Build Coastguard Worker
3027*da0073e9SAndroid Build Coastguard Worker        # Fuzz testing
3028*da0073e9SAndroid Build Coastguard Worker        for i in range(100):
3029*da0073e9SAndroid Build Coastguard Worker            range_args = self.gen_random_range_args()
3030*da0073e9SAndroid Build Coastguard Worker            r = range(*range_args)
3031*da0073e9SAndroid Build Coastguard Worker            # generate random slice
3032*da0073e9SAndroid Build Coastguard Worker            s = rand_slice()
3033*da0073e9SAndroid Build Coastguard Worker
3034*da0073e9SAndroid Build Coastguard Worker            print("testing:", r, s)
3035*da0073e9SAndroid Build Coastguard Worker            test(r, s)
3036*da0073e9SAndroid Build Coastguard Worker
3037*da0073e9SAndroid Build Coastguard Worker    def test_range_with_slice_index(self):
3038*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3039*da0073e9SAndroid Build Coastguard Worker            acc = 1
3040*da0073e9SAndroid Build Coastguard Worker            for k in range(2)[1::2]:
3041*da0073e9SAndroid Build Coastguard Worker                acc *= acc * k
3042*da0073e9SAndroid Build Coastguard Worker            return x * acc
3043*da0073e9SAndroid Build Coastguard Worker
3044*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True)(fn)
3045*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
3046*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x), fn(x))
3047*da0073e9SAndroid Build Coastguard Worker
3048*da0073e9SAndroid Build Coastguard Worker    def test_range_with_index(self):
3049*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3050*da0073e9SAndroid Build Coastguard Worker            acc = 1
3051*da0073e9SAndroid Build Coastguard Worker            acc *= acc * range(10, 20, 2)[2]
3052*da0073e9SAndroid Build Coastguard Worker            return x * acc
3053*da0073e9SAndroid Build Coastguard Worker
3054*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fullgraph=True)(fn)
3055*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
3056*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x), fn(x))
3057*da0073e9SAndroid Build Coastguard Worker
3058*da0073e9SAndroid Build Coastguard Worker    def test_rand_inlined(self):
3059*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", dynamic=True)
3060*da0073e9SAndroid Build Coastguard Worker        def fn():
3061*da0073e9SAndroid Build Coastguard Worker            idx_size = [10]
3062*da0073e9SAndroid Build Coastguard Worker            idx_size[random.randint(0, 0)] = random.randint(1, 8)
3063*da0073e9SAndroid Build Coastguard Worker            t = tuple(idx_size)
3064*da0073e9SAndroid Build Coastguard Worker            src_size = [random.randint(1, 5) + s for s in idx_size]
3065*da0073e9SAndroid Build Coastguard Worker            idx = torch.empty(t)
3066*da0073e9SAndroid Build Coastguard Worker
3067*da0073e9SAndroid Build Coastguard Worker        fn()
3068*da0073e9SAndroid Build Coastguard Worker
3069*da0073e9SAndroid Build Coastguard Worker    def test_rand_tensor_partial(self):
3070*da0073e9SAndroid Build Coastguard Worker        from collections import namedtuple
3071*da0073e9SAndroid Build Coastguard Worker        from functools import partial
3072*da0073e9SAndroid Build Coastguard Worker
3073*da0073e9SAndroid Build Coastguard Worker        SdpaShape = namedtuple(
3074*da0073e9SAndroid Build Coastguard Worker            "Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"]
3075*da0073e9SAndroid Build Coastguard Worker        )
3076*da0073e9SAndroid Build Coastguard Worker
3077*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
3078*da0073e9SAndroid Build Coastguard Worker        def func():
3079*da0073e9SAndroid Build Coastguard Worker            make_tensor = partial(
3080*da0073e9SAndroid Build Coastguard Worker                torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
3081*da0073e9SAndroid Build Coastguard Worker            )
3082*da0073e9SAndroid Build Coastguard Worker
3083*da0073e9SAndroid Build Coastguard Worker            bsz, num_heads, seq_len_q, seq_len_kv, head_dim = (16, 16, 128, 128, 16)
3084*da0073e9SAndroid Build Coastguard Worker            make_q_tensor = partial(
3085*da0073e9SAndroid Build Coastguard Worker                make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)
3086*da0073e9SAndroid Build Coastguard Worker            )
3087*da0073e9SAndroid Build Coastguard Worker            make_kv_tensor = partial(
3088*da0073e9SAndroid Build Coastguard Worker                make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim)
3089*da0073e9SAndroid Build Coastguard Worker            )
3090*da0073e9SAndroid Build Coastguard Worker            t1 = make_q_tensor()
3091*da0073e9SAndroid Build Coastguard Worker            t2 = make_kv_tensor()
3092*da0073e9SAndroid Build Coastguard Worker            t3 = t1 + t2
3093*da0073e9SAndroid Build Coastguard Worker
3094*da0073e9SAndroid Build Coastguard Worker        func()
3095*da0073e9SAndroid Build Coastguard Worker
3096*da0073e9SAndroid Build Coastguard Worker    def test_to(self):
3097*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager")
3098*da0073e9SAndroid Build Coastguard Worker        def fn():
3099*da0073e9SAndroid Build Coastguard Worker            t = torch.ones(2)
3100*da0073e9SAndroid Build Coastguard Worker            y = t.to("meta")
3101*da0073e9SAndroid Build Coastguard Worker
3102*da0073e9SAndroid Build Coastguard Worker        fn()
3103*da0073e9SAndroid Build Coastguard Worker
    def test_elipsis(self):
        # In-place setitem on numpy arrays using Ellipsis-containing indices
        # (built with np.s_) inside a fullgraph-compiled function.
        # NOTE: the test name misspells "ellipsis"; kept to preserve the test id.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(a, ind, val):
            # Mutates ``a`` in place and returns the same array.
            a[ind] = val
            return a

        # Bare Ellipsis overwrites the whole 1-D array.
        arr = np.zeros(4)
        self.assertEqual(fn(arr, np.s_[...], np.ones(4)), np.ones(4))

        # Leading integer + Ellipsis selects a row. ``arr`` is rebuilt before
        # every case because ``fn`` mutates its input.
        arr = np.array([[1, 1], [2, 2]])
        self.assertEqual(
            fn(arr, np.s_[0, ...], np.zeros(2)), np.array([[0, 0], [2, 2]])
        )

        arr = np.array([[1, 1], [2, 2]])
        self.assertEqual(
            fn(arr, np.s_[1, ...], np.zeros(2)), np.array([[1, 1], [0, 0]])
        )

        # Ellipsis + trailing integer selects a column.
        arr = np.array([[1, 1], [2, 2]])
        self.assertEqual(
            fn(arr, np.s_[..., 0], np.array([3, 3])), np.array([[3, 1], [3, 2]])
        )

        arr = np.array([[1, 1], [2, 2]])
        self.assertEqual(
            fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]])
        )
3132*da0073e9SAndroid Build Coastguard Worker
3133*da0073e9SAndroid Build Coastguard Worker    def test_map_return(self):
3134*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
3135*da0073e9SAndroid Build Coastguard Worker            return map(lambda x: x + 1, [a, b])
3136*da0073e9SAndroid Build Coastguard Worker
3137*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3138*da0073e9SAndroid Build Coastguard Worker        m = opt_fn(torch.randn(3, 3), torch.randn(3, 3))
3139*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(m, map)
3140*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_max(a, b):
        # max() over a map producing one sum tensor per input.
        return max(map(lambda x: x.sum(), [a, b]))
3144*da0073e9SAndroid Build Coastguard Worker
3145*da0073e9SAndroid Build Coastguard Worker    # max(map(...)) graph breaks
    @unittest.expectedFailure
    @make_test
    def test_map_max_const(a):
        # max() over a map of Python int constants; marked expectedFailure
        # because this pattern currently graph breaks under make_test.
        return max(map(lambda x: x, [1, 2, 3])), a + 1
3150*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_list(a, b):
        # Materialize a map over tensors with list(): [a + 1, b + 1].
        return list(map(lambda x: x + 1, [a, b]))
3154*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_tuple(a, b):
        # Materialize a map over tensors with tuple(): (a + 1, b + 1).
        return tuple(map(lambda x: x + 1, [a, b]))
3158*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_iter(a, b):
        # iter() on a map returns the map itself; only the first item (a + 1)
        # is consumed.
        it = iter(map(lambda x: x + 1, [a, b]))
        return next(it)
3163*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_zip_dict(a):
        # A map supplies the dict keys (1, 2, 3) through zip, and each value is
        # itself a lazy map over a one-element list ([y] is evaluated eagerly,
        # so d[3] maps over [5]).
        d = dict(
            zip(
                map(lambda x: x + 1, [0, 1, 2]),
                [map(lambda x: x - 1, [y]) for y in [3, 4, 5]],
            )
        )
        # Consuming d[3] yields 5 - 1 == 4.
        return list(d[3])[0], a + 1  # noqa: RUF015
3173*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_dict_fromkeys(a):
        # dict.fromkeys over a map: {1: None, 2: None}.
        return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1
3177*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_set(a):
        # set() over a map of constants: {1, 2}.
        return set(map(lambda x: x + 1, [0, 1])), a + 1
3181*da0073e9SAndroid Build Coastguard Worker
3182*da0073e9SAndroid Build Coastguard Worker    # test_map_sum defined earlier
3183*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_reduce(a, b):
        # functools.reduce consuming a map: (a + 1) + (b + 1).
        return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b]))
3187*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_sorted(a):
        # sorted() over a map of constants: [1, 2, 3, 4, 5].
        return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1
3191*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_list_extend(a, b, c):
        # list.extend with a map argument: [a, b + 1, c + 1].
        l = [a]
        l.extend(map(lambda x: x + 1, [b, c]))
        return l
3197*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_list_slice_assign(a, b, c, d, e):
        # Slice assignment from a map replaces one element with two:
        # [a, d + 1, e + 1, c].
        l = [a, b, c]
        l[1:2] = map(lambda x: x + 1, [d, e])
        return l
3203*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_deque_extendleft(a, b, c):
        # deque.extendleft inserts items one at a time on the left, reversing
        # them: deque([c + 1, b + 1, a]).
        d = collections.deque([a])
        d.extendleft(map(lambda x: x + 1, [b, c]))
        return d
3209*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_str_join(a):
        # str.join over a map of strings: "abc".
        return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1
3213*da0073e9SAndroid Build Coastguard Worker
    def test_map_with_graph_break(self):
        # A map whose function has a side effect (it mutates the enclosing
        # ``a`` through ``nonlocal``) is partially consumed on both sides of an
        # explicit graph break; eager and compiled results must match.
        def f(a):
            a += 1

            def g(x):
                nonlocal a
                a += 1
                return x + 1

            m = map(g, [1, 2, 3, 4, 5])
            a += next(m)  # won't graph break
            torch._dynamo.graph_break()
            a += next(m)  # will graph break
            return a

        cnts = torch._dynamo.testing.CompileCounter()
        opt_f = torch.compile(f, backend=cnts)
        self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3)))
        # Number of compiled frames expected given the breaks noted above.
        self.assertEqual(cnts.frame_count, 3)
3233*da0073e9SAndroid Build Coastguard Worker
    def test_map_reconstruct(self):
        # A map over a zip, returned from a fullgraph-compiled function, must
        # be reconstructed as a real ``map`` yielding the eager values
        # ([2, 4, 6]).
        def fn(a):
            return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        m = opt_fn(torch.ones(3, 3))[0]
        self.assertIsInstance(m, map)
        self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
3242*da0073e9SAndroid Build Coastguard Worker
    def test_zip_reconstruct(self):
        # A zip wrapping a map, returned from a fullgraph-compiled function,
        # must be reconstructed as a real ``zip`` yielding the eager pairs
        # ([(1, 2), (2, 3), (3, 4)]).
        def fn(a):
            return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        m = opt_fn(torch.ones(3, 3))[0]
        self.assertIsInstance(m, zip)
        self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0]))
3251*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_partial_unpack(a, b):
        # Counts how many times the map's function runs when zip only
        # partially consumes the map.
        y = 1

        def f(x):
            nonlocal y
            y += 1
            return x

        # zip stops as soon as [a, b] is exhausted, before pulling a third item
        # from the map, so f runs twice and y ends at 3. ``l`` itself is unused.
        l = list(zip([a, b], map(f, [1, 2, 3, 4])))
        return a + y
3263*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_call_function_ex(a, b):
        # Star-unpacking a map into call arguments (CALL_FUNCTION_EX):
        # (a + 1) + (b + 1).
        def f(x, y):
            return x + y

        return f(*map(lambda x: x + 1, [a, b]))
3270*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_unpack_twice(a, b):
        # A map is single-pass: the first list() drains it, so l2 is [].
        m = map(lambda x: x + 1, [a, b])
        l1 = list(m)
        l2 = list(m)
        return l1, l2
3277*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_enumerate(a, b):
        # enumerate with a custom start: [(1, a), (2, b)].
        return list(enumerate([a, b], start=1)), a + 1
3281*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_enumerate(a, b):
        # enumerate over a map: [(1, a + 1), (2, b + 1)].
        return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1
3285*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_infinite(a, b):
        # A multi-iterable map stops at the shortest input, so the infinite
        # itertools.count(3) is only consumed twice: [a + 3, b + 4].
        return list(map(lambda x, y: x + y, [a, b], itertools.count(3)))
3289*da0073e9SAndroid Build Coastguard Worker
    @make_test
    def test_map_unpack_vars(a, b):
        # Sequence-unpacking a map into two names: (a + 1) + (b + 1).
        x, y = map(lambda x: x + 1, [a, b])
        return x + y
3294*da0073e9SAndroid Build Coastguard Worker
    def test_enumerate_custom(self):
        # enumerate() over a user-defined iterator under fullgraph compile.
        class MyClass:
            # __iter__ resets the counter; __next__ yields 2, 3, 4 and then
            # raises StopIteration.
            def __iter__(self):
                self.a = 1
                return self

            def __next__(self):
                if self.a > 3:
                    raise StopIteration
                self.a += 1
                return self.a

        def fn(x):
            # Adds (0 + 2) + (1 + 3) + (2 + 4) == 12 to x in total.
            for i, it in enumerate(MyClass()):
                x += i + it
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3)))
3314*da0073e9SAndroid Build Coastguard Worker
    def test_enumerate_reconstruct(self):
        # An enumerate object returned from a fullgraph-compiled function must
        # be reconstructed as a real ``enumerate`` yielding the eager items.
        def fn(a, b):
            return enumerate([a, b], start=1)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        inps = (torch.randn(3, 3), torch.randn(3, 3))
        it1 = fn(*inps)
        it2 = opt_fn(*inps)
        self.assertIsInstance(it2, enumerate)
        self.assertEqual(list(it1), list(it2))
3325*da0073e9SAndroid Build Coastguard Worker
3326*da0073e9SAndroid Build Coastguard Worker
def udf_mul(x, y):
    """Return the product ``x * y`` (a plain user-defined helper for tests)."""
    product = x * y
    return product
3329*da0073e9SAndroid Build Coastguard Worker
3330*da0073e9SAndroid Build Coastguard Worker
def udf_mul2(x, y, z):
    """Return ``x * y * z``, multiplying left to right."""
    partial_product = x * y
    return partial_product * z
3333*da0073e9SAndroid Build Coastguard Worker
3334*da0073e9SAndroid Build Coastguard Worker
def udf_add(x, y):
    """Return the sum ``x + y`` (a plain user-defined helper for tests)."""
    total = x + y
    return total
3337*da0073e9SAndroid Build Coastguard Worker
3338*da0073e9SAndroid Build Coastguard Worker
class SmallNN(torch.nn.Module):
    """Tiny module: concatenate two inputs along dim 1, then apply ReLU twice."""

    def forward(self, x, y):
        stacked = torch.cat((x, y), dim=1)
        activated = torch.nn.ReLU()(stacked)
        return torch.nn.ReLU()(activated)
3345*da0073e9SAndroid Build Coastguard Worker
3346*da0073e9SAndroid Build Coastguard Worker
def udf_module(mod, x, y):
    """Call ``mod`` (e.g. an ``nn.Module``) with ``(x, y)`` and return its output."""
    output = mod(x, y)
    return output
3349*da0073e9SAndroid Build Coastguard Worker
3350*da0073e9SAndroid Build Coastguard Worker
def global_func_with_default_tensor_args(
    x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
):
    # Module-level function whose positional and keyword-only defaults are
    # tensors created once, when this ``def`` executes. Each call that relies
    # on the defaults therefore increments the same two tensors in place.
    # DefaultsTests.test_func_default_tensor_args depends on that shared state
    # and on the exact two add_ ops for its op-count assertions.
    x.add_(1)
    kw_x.add_(1)
    return x, kw_x
3357*da0073e9SAndroid Build Coastguard Worker
3358*da0073e9SAndroid Build Coastguard Worker
class ModuleWithDefaultTensorArgsMethod(torch.nn.Module):
    # Module whose ``forward`` has tensor defaults (positional and
    # keyword-only). They are created once at class-definition time and live
    # on ``forward.__defaults__`` / ``forward.__kwdefaults__``, so every call
    # using them mutates the same tensors in place.
    def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))):
        x.add_(1)
        kw_x.add_(1)
        return x, kw_x
3364*da0073e9SAndroid Build Coastguard Worker
3365*da0073e9SAndroid Build Coastguard Worker
class WrapperModule(torch.nn.Module):
    """Module that owns a ``ModuleWithDefaultTensorArgsMethod`` child.

    ``forward`` calls the child with no arguments, so the child's default
    tensors are the ones used (and mutated).
    """

    def __init__(self) -> None:
        super().__init__()
        child = ModuleWithDefaultTensorArgsMethod()
        self.m = child

    def forward(self):
        inner = self.m
        return inner()
3373*da0073e9SAndroid Build Coastguard Worker
3374*da0073e9SAndroid Build Coastguard Worker
3375*da0073e9SAndroid Build Coastguard Workerclass DefaultsTests(torch._dynamo.test_case.TestCase):
3376*da0073e9SAndroid Build Coastguard Worker    def test_func_default_tensor_args(self):
3377*da0073e9SAndroid Build Coastguard Worker        """
3378*da0073e9SAndroid Build Coastguard Worker        Tests that we indeed reference (and mutate) "the one" default tensor arg
3379*da0073e9SAndroid Build Coastguard Worker        stored on the globally allocated function object, both from the orig and
3380*da0073e9SAndroid Build Coastguard Worker        compiled function
3381*da0073e9SAndroid Build Coastguard Worker        """
3382*da0073e9SAndroid Build Coastguard Worker
3383*da0073e9SAndroid Build Coastguard Worker        def func():
3384*da0073e9SAndroid Build Coastguard Worker            return global_func_with_default_tensor_args()
3385*da0073e9SAndroid Build Coastguard Worker
3386*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3387*da0073e9SAndroid Build Coastguard Worker        compiled_func = torch.compile(func, backend=cnts)
3388*da0073e9SAndroid Build Coastguard Worker        for i in range(4):
3389*da0073e9SAndroid Build Coastguard Worker            if i % 2 == 0:
3390*da0073e9SAndroid Build Coastguard Worker                x, kw_x = func()
3391*da0073e9SAndroid Build Coastguard Worker            else:
3392*da0073e9SAndroid Build Coastguard Worker                x, kw_x = compiled_func()
3393*da0073e9SAndroid Build Coastguard Worker            # the inner func mutates += 1 each call
3394*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(x, torch.ones_like(x) + i))
3395*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
3396*da0073e9SAndroid Build Coastguard Worker        # Calling compiled_func twice does not recompile
3397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
3399*da0073e9SAndroid Build Coastguard Worker
3400*da0073e9SAndroid Build Coastguard Worker        # But with a change to the guarded default tensor, we do recompile
3401*da0073e9SAndroid Build Coastguard Worker        with patch.object(
3402*da0073e9SAndroid Build Coastguard Worker            global_func_with_default_tensor_args,
3403*da0073e9SAndroid Build Coastguard Worker            "__defaults__",
3404*da0073e9SAndroid Build Coastguard Worker            (torch.ones((3, 4, 5)),),
3405*da0073e9SAndroid Build Coastguard Worker        ):
3406*da0073e9SAndroid Build Coastguard Worker            x, kw_x = compiled_func()
3407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
3408*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 4)
3409*da0073e9SAndroid Build Coastguard Worker
3410*da0073e9SAndroid Build Coastguard Worker        with patch.object(
3411*da0073e9SAndroid Build Coastguard Worker            global_func_with_default_tensor_args,
3412*da0073e9SAndroid Build Coastguard Worker            "__kwdefaults__",
3413*da0073e9SAndroid Build Coastguard Worker            {"kw_x": torch.ones((3, 4, 5))},
3414*da0073e9SAndroid Build Coastguard Worker        ):
3415*da0073e9SAndroid Build Coastguard Worker            x, kw_x = compiled_func()
3416*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)
3417*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 6)
3418*da0073e9SAndroid Build Coastguard Worker
3419*da0073e9SAndroid Build Coastguard Worker    def test_meth_default_tensor_args(self):
3420*da0073e9SAndroid Build Coastguard Worker        """
3421*da0073e9SAndroid Build Coastguard Worker        Tests that we indeed reference (and mutate) "the one" default tensor arg
3422*da0073e9SAndroid Build Coastguard Worker        stored on the globally allocated function object, both from the orig and
3423*da0073e9SAndroid Build Coastguard Worker        compiled function
3424*da0073e9SAndroid Build Coastguard Worker        """
3425*da0073e9SAndroid Build Coastguard Worker        mod = WrapperModule()
3426*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3427*da0073e9SAndroid Build Coastguard Worker        compiled_mod = torch.compile(mod, backend=cnts)
3428*da0073e9SAndroid Build Coastguard Worker        for i in range(4):
3429*da0073e9SAndroid Build Coastguard Worker            if i % 2 == 0:
3430*da0073e9SAndroid Build Coastguard Worker                x, kw_x = mod()
3431*da0073e9SAndroid Build Coastguard Worker            else:
3432*da0073e9SAndroid Build Coastguard Worker                x, kw_x = compiled_mod()
3433*da0073e9SAndroid Build Coastguard Worker            # the inner func mutates += 1 each call
3434*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(x, torch.ones_like(x) + i))
3435*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
3436*da0073e9SAndroid Build Coastguard Worker        # Calling compiled_func twice does not recompile
3437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 2)
3439*da0073e9SAndroid Build Coastguard Worker
3440*da0073e9SAndroid Build Coastguard Worker        # But with a change to the guarded default tensor, we do recompile
3441*da0073e9SAndroid Build Coastguard Worker        with patch.object(
3442*da0073e9SAndroid Build Coastguard Worker            ModuleWithDefaultTensorArgsMethod.forward,
3443*da0073e9SAndroid Build Coastguard Worker            "__defaults__",
3444*da0073e9SAndroid Build Coastguard Worker            (torch.ones((3, 4, 5)),),
3445*da0073e9SAndroid Build Coastguard Worker        ):
3446*da0073e9SAndroid Build Coastguard Worker            x, kw_x = compiled_mod()
3447*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
3448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 4)
3449*da0073e9SAndroid Build Coastguard Worker
3450*da0073e9SAndroid Build Coastguard Worker        with patch.object(
3451*da0073e9SAndroid Build Coastguard Worker            ModuleWithDefaultTensorArgsMethod.forward,
3452*da0073e9SAndroid Build Coastguard Worker            "__kwdefaults__",
3453*da0073e9SAndroid Build Coastguard Worker            {"kw_x": torch.ones((3, 4, 5))},
3454*da0073e9SAndroid Build Coastguard Worker        ):
3455*da0073e9SAndroid Build Coastguard Worker            x, kw_x = compiled_mod()
3456*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 3)
3457*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 6)
    def test_func_default_torch_args(self):
        """
        Tests other types of torch types as function default (size, dtype, device)
        """

        # Defaults are a dtype, a torch.Size and a device (not tensors).
        def func_with_default_torch_args(
            dt=torch.float16, ds=torch.Size((1, 2, 3)), dd=torch.device("cpu")
        ):
            return torch.ones(ds, dtype=dt, device=dd)

        def func():
            return func_with_default_torch_args()

        cnts = torch._dynamo.testing.CompileCounter()
        compiled_func = torch.compile(func, backend=cnts)
        out = func()
        compiled_out = compiled_func()
        # The compiled call must honor all three defaults.
        self.assertEqual(out.dtype, compiled_out.dtype)
        self.assertEqual(out.device, compiled_out.device)
        self.assertEqual(out.size(), compiled_out.size())
        # One frame containing the single torch.ones op.
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 1)
3480*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 1)
    def test_dataclass_factory(self):
        # Dataclass with an int default and two default_factory fields,
        # constructed, mutated and returned from a fullgraph-compiled function.
        @dataclass
        class Output:
            scalar: int = 2
            named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict)
            lists: List[torch.Tensor] = field(default_factory=list)

            def scale(self):
                return self.scalar * 2

        def fn(x):
            # Check default dict assignment
            a = Output(1)
            # Check that dataclass methods can be inlined
            scaled_value = a.scale()

            # Check that normal assignment works
            b = Output(5, named_tensors={"x": x})

            # Check default int assignment
            c = Output()

            # Check that the default members are properly initialized
            if isinstance(a.named_tensors, dict):
                x = torch.sin(x)

            # Change dataclass
            c.scalar = 6
            c.named_tensors["x"] = x

            # Return dataclass as well to check reconstruction
            return c, torch.cos(x) * scaled_value + b.named_tensors["x"] + c.scalar

        cnts = torch._dynamo.testing.CompileCounter()
        compiled_fn = torch.compile(fn, backend=cnts, fullgraph=True)
        x = torch.randn(4)
        eager_dataclass, out = fn(x)
        compiled_dataclass, compiled_out = compiled_fn(x)
        self.assertEqual(eager_dataclass.scalar, compiled_dataclass.scalar)
        self.assertEqual(
            eager_dataclass.named_tensors["x"], compiled_dataclass.named_tensors["x"]
        )
        self.assertTrue(same(out, compiled_out))
        self.assertEqual(cnts.frame_count, 1)
        # sin, cos, mul, add, add
        self.assertEqual(cnts.op_count, 5)
3526*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.op_count, 5)
    def test_dataclass_nested(self):
        # Constructing a dataclass that inherits from another dataclass (and
        # adds a default_factory field) inside a fullgraph-compiled function.
        @dataclass
        class Base:
            outer_a: int
            outer_b: int

        @dataclass
        class Derived(Base):
            inner_a: Any = field(default_factory=list)

        def fn(x):
            l = Derived(1, 2)
            return l.outer_a * x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        # NOTE: ``res`` is the eager result and ``ref`` the compiled one (the
        # names are swapped relative to the usual convention; equality is
        # symmetric, so the assertion is unaffected).
        res = fn(x)
        ref = opt_fn(x)
        self.assertEqual(ref, res)
3546*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
3547*da0073e9SAndroid Build Coastguard Worker
3548*da0073e9SAndroid Build Coastguard Worker    def test_listlike_of_tensors_contains_constant(self):
3549*da0073e9SAndroid Build Coastguard Worker        for listlike in [set, list]:
3550*da0073e9SAndroid Build Coastguard Worker
3551*da0073e9SAndroid Build Coastguard Worker            def fn(x):
3552*da0073e9SAndroid Build Coastguard Worker                x.add_(1)
3553*da0073e9SAndroid Build Coastguard Worker                s = listlike([x])
3554*da0073e9SAndroid Build Coastguard Worker                res = 1 in s
3555*da0073e9SAndroid Build Coastguard Worker                return res
3556*da0073e9SAndroid Build Coastguard Worker
3557*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3558*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1)
3559*da0073e9SAndroid Build Coastguard Worker            ref = opt_fn(x)
3560*da0073e9SAndroid Build Coastguard Worker            res = fn(x)
3561*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, res)
3562*da0073e9SAndroid Build Coastguard Worker
3563*da0073e9SAndroid Build Coastguard Worker    def test_cast_tensor_single_elem(self):
3564*da0073e9SAndroid Build Coastguard Worker        with torch._dynamo.config.patch({"capture_scalar_outputs": True}):
3565*da0073e9SAndroid Build Coastguard Worker            for t, val in [
3566*da0073e9SAndroid Build Coastguard Worker                (float, 1.0),
3567*da0073e9SAndroid Build Coastguard Worker                (float, 1),
3568*da0073e9SAndroid Build Coastguard Worker                (float, True),
3569*da0073e9SAndroid Build Coastguard Worker                (int, 1),
3570*da0073e9SAndroid Build Coastguard Worker                (int, False),
3571*da0073e9SAndroid Build Coastguard Worker                # (int, 1.0), # fails due to a >= 0 comparison in sym_int
3572*da0073e9SAndroid Build Coastguard Worker            ]:  # , bool, complex]: no casting for sym_bool, no sym_complex
3573*da0073e9SAndroid Build Coastguard Worker
3574*da0073e9SAndroid Build Coastguard Worker                def fn(x):
3575*da0073e9SAndroid Build Coastguard Worker                    x = x + 1
3576*da0073e9SAndroid Build Coastguard Worker                    return t(x)
3577*da0073e9SAndroid Build Coastguard Worker
3578*da0073e9SAndroid Build Coastguard Worker                opt_fn = torch.compile(
3579*da0073e9SAndroid Build Coastguard Worker                    fn, backend="eager", fullgraph=True, dynamic=False
3580*da0073e9SAndroid Build Coastguard Worker                )
3581*da0073e9SAndroid Build Coastguard Worker                x = torch.tensor([val])
3582*da0073e9SAndroid Build Coastguard Worker                res = fn(x)
3583*da0073e9SAndroid Build Coastguard Worker                ref = opt_fn(x)
3584*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ref, res)
3585*da0073e9SAndroid Build Coastguard Worker
3586*da0073e9SAndroid Build Coastguard Worker                # Cannot handle non single-elem
3587*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(ValueError):
3588*da0073e9SAndroid Build Coastguard Worker                    fn(torch.tensor([val] * 2))
3589*da0073e9SAndroid Build Coastguard Worker                with self.assertRaises(torch._dynamo.exc.TorchRuntimeError):
3590*da0073e9SAndroid Build Coastguard Worker                    opt_fn(torch.tensor([val] * 2))
3591*da0073e9SAndroid Build Coastguard Worker
3592*da0073e9SAndroid Build Coastguard Worker    def test_set_construction(self):
3593*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3594*da0073e9SAndroid Build Coastguard Worker            y = x.add_(1)
3595*da0073e9SAndroid Build Coastguard Worker            s = set({x})
3596*da0073e9SAndroid Build Coastguard Worker            s.add(y)
3597*da0073e9SAndroid Build Coastguard Worker            return len(s)
3598*da0073e9SAndroid Build Coastguard Worker
3599*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3600*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
3601*da0073e9SAndroid Build Coastguard Worker        res = fn(x)
3602*da0073e9SAndroid Build Coastguard Worker        ref = opt_fn(x)
3603*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
3604*da0073e9SAndroid Build Coastguard Worker
3605*da0073e9SAndroid Build Coastguard Worker    def test_frozenset_construction(self):
3606*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3607*da0073e9SAndroid Build Coastguard Worker            s = frozenset({x})
3608*da0073e9SAndroid Build Coastguard Worker            t = frozenset(s)
3609*da0073e9SAndroid Build Coastguard Worker            return len(t)
3610*da0073e9SAndroid Build Coastguard Worker
3611*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3612*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
3613*da0073e9SAndroid Build Coastguard Worker        res = fn(x)
3614*da0073e9SAndroid Build Coastguard Worker        ref = opt_fn(x)
3615*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
3616*da0073e9SAndroid Build Coastguard Worker
3617*da0073e9SAndroid Build Coastguard Worker    def test_frozenset_reconstruction(self):
3618*da0073e9SAndroid Build Coastguard Worker        d = {}
3619*da0073e9SAndroid Build Coastguard Worker        f = frozenset()
3620*da0073e9SAndroid Build Coastguard Worker        d[f] = torch.randn(4)
3621*da0073e9SAndroid Build Coastguard Worker
3622*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3623*da0073e9SAndroid Build Coastguard Worker            k = frozenset()
3624*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
3625*da0073e9SAndroid Build Coastguard Worker            return d[k] * x
3626*da0073e9SAndroid Build Coastguard Worker
3627*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager")
3628*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
3629*da0073e9SAndroid Build Coastguard Worker        res = fn(x)
3630*da0073e9SAndroid Build Coastguard Worker        ref = opt_fn(x)
3631*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, res)
3632*da0073e9SAndroid Build Coastguard Worker
3633*da0073e9SAndroid Build Coastguard Worker    def test_frozenset_illegal_call_method(self):
3634*da0073e9SAndroid Build Coastguard Worker        def fn_add():
3635*da0073e9SAndroid Build Coastguard Worker            s = frozenset((1, 2, 3))
3636*da0073e9SAndroid Build Coastguard Worker            s.add({2})
3637*da0073e9SAndroid Build Coastguard Worker            return len(s)
3638*da0073e9SAndroid Build Coastguard Worker
3639*da0073e9SAndroid Build Coastguard Worker        def fn_pop():
3640*da0073e9SAndroid Build Coastguard Worker            s = frozenset((1, 2, 3))
3641*da0073e9SAndroid Build Coastguard Worker            s.pop()
3642*da0073e9SAndroid Build Coastguard Worker            return len(s)
3643*da0073e9SAndroid Build Coastguard Worker
3644*da0073e9SAndroid Build Coastguard Worker        def fn_update():
3645*da0073e9SAndroid Build Coastguard Worker            s = frozenset((1, 2, 3))
3646*da0073e9SAndroid Build Coastguard Worker            s.update({4, 5, 6})
3647*da0073e9SAndroid Build Coastguard Worker            return len(s)
3648*da0073e9SAndroid Build Coastguard Worker
3649*da0073e9SAndroid Build Coastguard Worker        def fn_remove():
3650*da0073e9SAndroid Build Coastguard Worker            s = frozenset((1, 2, 3))
3651*da0073e9SAndroid Build Coastguard Worker            s.remove(2)
3652*da0073e9SAndroid Build Coastguard Worker            return len(s)
3653*da0073e9SAndroid Build Coastguard Worker
3654*da0073e9SAndroid Build Coastguard Worker        def fn_discard():
3655*da0073e9SAndroid Build Coastguard Worker            s = frozenset((1, 2, 3))
3656*da0073e9SAndroid Build Coastguard Worker            s.discard(2)
3657*da0073e9SAndroid Build Coastguard Worker            return len(s)
3658*da0073e9SAndroid Build Coastguard Worker
3659*da0073e9SAndroid Build Coastguard Worker        def fn_clear():
3660*da0073e9SAndroid Build Coastguard Worker            s = frozenset((1, 2, 3))
3661*da0073e9SAndroid Build Coastguard Worker            s.clear()
3662*da0073e9SAndroid Build Coastguard Worker            return len(s)
3663*da0073e9SAndroid Build Coastguard Worker
3664*da0073e9SAndroid Build Coastguard Worker        for fn in [fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear]:
3665*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
3666*da0073e9SAndroid Build Coastguard Worker            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3667*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError):
3668*da0073e9SAndroid Build Coastguard Worker                opt_fn()
3669*da0073e9SAndroid Build Coastguard Worker
3670*da0073e9SAndroid Build Coastguard Worker    def test_is_tensor_tensor(self):
3671*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
3672*da0073e9SAndroid Build Coastguard Worker            if x is y:
3673*da0073e9SAndroid Build Coastguard Worker                return x * 2
3674*da0073e9SAndroid Build Coastguard Worker            else:
3675*da0073e9SAndroid Build Coastguard Worker                return x + y
3676*da0073e9SAndroid Build Coastguard Worker
3677*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3678*da0073e9SAndroid Build Coastguard Worker
3679*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(2)
3680*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(2)
3681*da0073e9SAndroid Build Coastguard Worker
3682*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x, y), fn_opt(x, y))
3683*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x, x), fn_opt(x, x))
3684*da0073e9SAndroid Build Coastguard Worker
3685*da0073e9SAndroid Build Coastguard Worker    def test_is_not_tensor_tensor(self):
3686*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
3687*da0073e9SAndroid Build Coastguard Worker            if x is not y:
3688*da0073e9SAndroid Build Coastguard Worker                return x * 2
3689*da0073e9SAndroid Build Coastguard Worker            else:
3690*da0073e9SAndroid Build Coastguard Worker                return x + y
3691*da0073e9SAndroid Build Coastguard Worker
3692*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3693*da0073e9SAndroid Build Coastguard Worker
3694*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(2)
3695*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(2)
3696*da0073e9SAndroid Build Coastguard Worker
3697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x, y), fn_opt(x, y))
3698*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x, x), fn_opt(x, x))
3699*da0073e9SAndroid Build Coastguard Worker
3700*da0073e9SAndroid Build Coastguard Worker    def test_is_mutated_tensor_tensor(self):
3701*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3702*da0073e9SAndroid Build Coastguard Worker            y = x.add_(1)
3703*da0073e9SAndroid Build Coastguard Worker            return x is y
3704*da0073e9SAndroid Build Coastguard Worker
3705*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3706*da0073e9SAndroid Build Coastguard Worker
3707*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(4)
3708*da0073e9SAndroid Build Coastguard Worker
3709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(z), fn_opt(z))
3710*da0073e9SAndroid Build Coastguard Worker
3711*da0073e9SAndroid Build Coastguard Worker    def test_is_mutated_tensor_tensor_across_graph_break(self):
3712*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3713*da0073e9SAndroid Build Coastguard Worker            y = x.add_(1)
3714*da0073e9SAndroid Build Coastguard Worker            cond = x is y
3715*da0073e9SAndroid Build Coastguard Worker            x.add_(1)
3716*da0073e9SAndroid Build Coastguard Worker            # The real tensor values are recovered when graph breaking.
3717*da0073e9SAndroid Build Coastguard Worker            # Hence we recover the invariant.
3718*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.graph_break()
3719*da0073e9SAndroid Build Coastguard Worker            x.add_(1)
3720*da0073e9SAndroid Build Coastguard Worker            return x is y, cond
3721*da0073e9SAndroid Build Coastguard Worker
3722*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", dynamic=True)(fn)
3723*da0073e9SAndroid Build Coastguard Worker
3724*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(4)
3725*da0073e9SAndroid Build Coastguard Worker
3726*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(z), fn_opt(z))
3727*da0073e9SAndroid Build Coastguard Worker
3728*da0073e9SAndroid Build Coastguard Worker    def test_is_mutated_tensor_tensor(self):
3729*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3730*da0073e9SAndroid Build Coastguard Worker            y = x.add_(1)
3731*da0073e9SAndroid Build Coastguard Worker            return y is x
3732*da0073e9SAndroid Build Coastguard Worker
3733*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3734*da0073e9SAndroid Build Coastguard Worker
3735*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(4, 1)
3736*da0073e9SAndroid Build Coastguard Worker
3737*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(z), fn_opt(z))
3738*da0073e9SAndroid Build Coastguard Worker
3739*da0073e9SAndroid Build Coastguard Worker    def test_is_init_in_compile_mutated_tensor_tensor(self):
3740*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3741*da0073e9SAndroid Build Coastguard Worker            z = x.clone()
3742*da0073e9SAndroid Build Coastguard Worker            y = z.add_(1)
3743*da0073e9SAndroid Build Coastguard Worker            return y is z
3744*da0073e9SAndroid Build Coastguard Worker
3745*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3746*da0073e9SAndroid Build Coastguard Worker
3747*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(4, 1)
3748*da0073e9SAndroid Build Coastguard Worker
3749*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(z), fn_opt(z))
3750*da0073e9SAndroid Build Coastguard Worker
3751*da0073e9SAndroid Build Coastguard Worker    def test_is_init_in_compile_vmapped_mutated_tensor_tensor(self):
3752*da0073e9SAndroid Build Coastguard Worker        def fn(z):
3753*da0073e9SAndroid Build Coastguard Worker            x = z.clone()
3754*da0073e9SAndroid Build Coastguard Worker            y = torch.vmap(torch.Tensor.acos_)(x)
3755*da0073e9SAndroid Build Coastguard Worker            _ = y is z
3756*da0073e9SAndroid Build Coastguard Worker            return y is x
3757*da0073e9SAndroid Build Coastguard Worker
3758*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3759*da0073e9SAndroid Build Coastguard Worker
3760*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(4, 1)
3761*da0073e9SAndroid Build Coastguard Worker
3762*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(z), fn_opt(z))
3763*da0073e9SAndroid Build Coastguard Worker
3764*da0073e9SAndroid Build Coastguard Worker    def test_is_vmapped_mutated_tensor_tensor(self):
3765*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3766*da0073e9SAndroid Build Coastguard Worker            y = torch.vmap(torch.Tensor.acos_)(x)
3767*da0073e9SAndroid Build Coastguard Worker            return y is x
3768*da0073e9SAndroid Build Coastguard Worker
3769*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3770*da0073e9SAndroid Build Coastguard Worker
3771*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(4, 1)
3772*da0073e9SAndroid Build Coastguard Worker
3773*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(z), fn_opt(z))
3774*da0073e9SAndroid Build Coastguard Worker
3775*da0073e9SAndroid Build Coastguard Worker    def test_is_init_in_compile_vmapped_mutated_tensor_tensor_multi_arg(self):
3776*da0073e9SAndroid Build Coastguard Worker        def fn(y, z):
3777*da0073e9SAndroid Build Coastguard Worker            a = y.clone()
3778*da0073e9SAndroid Build Coastguard Worker            b = z.clone()
3779*da0073e9SAndroid Build Coastguard Worker
3780*da0073e9SAndroid Build Coastguard Worker            def g(a, b):
3781*da0073e9SAndroid Build Coastguard Worker                return a.acos_(), b.acos_()
3782*da0073e9SAndroid Build Coastguard Worker
3783*da0073e9SAndroid Build Coastguard Worker            c, d = torch.vmap(g)(a, b)
3784*da0073e9SAndroid Build Coastguard Worker            return a is c is b is d
3785*da0073e9SAndroid Build Coastguard Worker
3786*da0073e9SAndroid Build Coastguard Worker        fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)
3787*da0073e9SAndroid Build Coastguard Worker
3788*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(4, 2)
3789*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(4, 10)
3790*da0073e9SAndroid Build Coastguard Worker
3791*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(y, z), fn_opt(y, z))
3792*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(y, y), fn_opt(y, y))
3793*da0073e9SAndroid Build Coastguard Worker
3794*da0073e9SAndroid Build Coastguard Worker    def test_in_set_would_fail_broadcast(self):
3795*da0073e9SAndroid Build Coastguard Worker        param = torch.zeros(5)
3796*da0073e9SAndroid Build Coastguard Worker        param2 = torch.zeros(5, 10)
3797*da0073e9SAndroid Build Coastguard Worker
3798*da0073e9SAndroid Build Coastguard Worker        tensor_list = set()
3799*da0073e9SAndroid Build Coastguard Worker        tensor_list.add(param2)
3800*da0073e9SAndroid Build Coastguard Worker        assert param not in tensor_list
3801*da0073e9SAndroid Build Coastguard Worker
3802*da0073e9SAndroid Build Coastguard Worker        def fn(param, param2):
3803*da0073e9SAndroid Build Coastguard Worker            param.add_(1)
3804*da0073e9SAndroid Build Coastguard Worker            tensor_list = set([param2])
3805*da0073e9SAndroid Build Coastguard Worker            return param in tensor_list
3806*da0073e9SAndroid Build Coastguard Worker
3807*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3808*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3809*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(param, param2), fn(param, param2))
3810*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3811*da0073e9SAndroid Build Coastguard Worker        # Test aliased
3812*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(param, param), fn(param, param))
3813*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)  # Recompiles
3814*da0073e9SAndroid Build Coastguard Worker
3815*da0073e9SAndroid Build Coastguard Worker    def test_in_set_inplace(self):
3816*da0073e9SAndroid Build Coastguard Worker        param = torch.zeros(5)
3817*da0073e9SAndroid Build Coastguard Worker        param2 = torch.zeros(5, 10)
3818*da0073e9SAndroid Build Coastguard Worker
3819*da0073e9SAndroid Build Coastguard Worker        tensor_list = set()
3820*da0073e9SAndroid Build Coastguard Worker        tensor_list.add(param2)
3821*da0073e9SAndroid Build Coastguard Worker        assert param not in tensor_list
3822*da0073e9SAndroid Build Coastguard Worker
3823*da0073e9SAndroid Build Coastguard Worker        def fn(param, param2):
3824*da0073e9SAndroid Build Coastguard Worker            y = param.add_(1)  # Tensor method
3825*da0073e9SAndroid Build Coastguard Worker            z = torch.Tensor.add_(y, 1)  # torch function
3826*da0073e9SAndroid Build Coastguard Worker            tensor_list = set([param2])
3827*da0073e9SAndroid Build Coastguard Worker            return y in tensor_list and z in tensor_list
3828*da0073e9SAndroid Build Coastguard Worker
3829*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
3830*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
3831*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(param, param2), fn(param, param2))
3832*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
3833*da0073e9SAndroid Build Coastguard Worker        # Test aliased
3834*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(param, param), fn(param, param))
3835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)  # Recompiles
3836*da0073e9SAndroid Build Coastguard Worker
3837*da0073e9SAndroid Build Coastguard Worker    def test_reconstructed_name(self):
3838*da0073e9SAndroid Build Coastguard Worker        lst = []
3839*da0073e9SAndroid Build Coastguard Worker
3840*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.disable
3841*da0073e9SAndroid Build Coastguard Worker        def disallowed(g):
3842*da0073e9SAndroid Build Coastguard Worker            lst.append(g.__name__)
3843*da0073e9SAndroid Build Coastguard Worker
3844*da0073e9SAndroid Build Coastguard Worker        def f():
3845*da0073e9SAndroid Build Coastguard Worker            def g():
3846*da0073e9SAndroid Build Coastguard Worker                return ()
3847*da0073e9SAndroid Build Coastguard Worker
3848*da0073e9SAndroid Build Coastguard Worker            disallowed(g)
3849*da0073e9SAndroid Build Coastguard Worker
3850*da0073e9SAndroid Build Coastguard Worker        f_opt = torch._dynamo
3851*da0073e9SAndroid Build Coastguard Worker        opt_f = torch._dynamo.optimize(backend="eager")(f)
3852*da0073e9SAndroid Build Coastguard Worker        opt_f()
3853*da0073e9SAndroid Build Coastguard Worker        f()
3854*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(lst), 2)
3855*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(lst[0], lst[1])
3856*da0073e9SAndroid Build Coastguard Worker
3857*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
3858*da0073e9SAndroid Build Coastguard Worker        sys.version_info < (3, 10),
3859*da0073e9SAndroid Build Coastguard Worker        "zip strict kwargs not implemented for Python < 3.10",
3860*da0073e9SAndroid Build Coastguard Worker    )
3861*da0073e9SAndroid Build Coastguard Worker    def test_zip_strict(self):
3862*da0073e9SAndroid Build Coastguard Worker        def fn(x, ys, zs):
3863*da0073e9SAndroid Build Coastguard Worker            x = x.clone()
3864*da0073e9SAndroid Build Coastguard Worker            for y, z in zip(ys, zs, strict=True):
3865*da0073e9SAndroid Build Coastguard Worker                x += y * z
3866*da0073e9SAndroid Build Coastguard Worker            return x
3867*da0073e9SAndroid Build Coastguard Worker
3868*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch._dynamo.optimize(backend="eager")(fn)
3869*da0073e9SAndroid Build Coastguard Worker        nopython_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
3870*da0073e9SAndroid Build Coastguard Worker
3871*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(3)
3872*da0073e9SAndroid Build Coastguard Worker        ys = [1.0, 2.0, 3.0]
3873*da0073e9SAndroid Build Coastguard Worker        zs = [2.0, 5.0, 8.0]
3874*da0073e9SAndroid Build Coastguard Worker
3875*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs))
3876*da0073e9SAndroid Build Coastguard Worker
3877*da0073e9SAndroid Build Coastguard Worker        # If nopython, should raise UserError
3878*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
3879*da0073e9SAndroid Build Coastguard Worker            nopython_fn(x, ys[:1], zs)
3880*da0073e9SAndroid Build Coastguard Worker
3881*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"):
3882*da0073e9SAndroid Build Coastguard Worker            nopython_fn(x, ys, zs[:1])
3883*da0073e9SAndroid Build Coastguard Worker
3884*da0073e9SAndroid Build Coastguard Worker        # Should cause fallback if allow graph break
3885*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "zip()"):
3886*da0073e9SAndroid Build Coastguard Worker            opt_fn(x, ys[:1], zs)
3887*da0073e9SAndroid Build Coastguard Worker
3888*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "zip()"):
3889*da0073e9SAndroid Build Coastguard Worker            opt_fn(x, ys, zs[:1])
3890*da0073e9SAndroid Build Coastguard Worker
3891*da0073e9SAndroid Build Coastguard Worker    def test_fn_with_attr(self):
3892*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3893*da0073e9SAndroid Build Coastguard Worker            if fn.pred:
3894*da0073e9SAndroid Build Coastguard Worker                return torch.relu(x * 2)
3895*da0073e9SAndroid Build Coastguard Worker            else:
3896*da0073e9SAndroid Build Coastguard Worker                return torch.abs(x + 3)
3897*da0073e9SAndroid Build Coastguard Worker
3898*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(3)
3899*da0073e9SAndroid Build Coastguard Worker        counter = torch._dynamo.testing.CompileCounter()
3900*da0073e9SAndroid Build Coastguard Worker        fn.pred = True
3901*da0073e9SAndroid Build Coastguard Worker        opt_fn_0 = torch.compile(fullgraph=True, backend=counter)(fn)
3902*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn_0(t), fn(t))
3903*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 1)
3904*da0073e9SAndroid Build Coastguard Worker        fn.pred = False
3905*da0073e9SAndroid Build Coastguard Worker        opt_fn_1 = torch.compile(fullgraph=True, backend=counter)(fn)
3906*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn_1(t), fn(t))
3907*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter.frame_count, 2)
3908*da0073e9SAndroid Build Coastguard Worker
3909*da0073e9SAndroid Build Coastguard Worker    def test_str_handler_for_user_defined_object(self):
3910*da0073e9SAndroid Build Coastguard Worker        """
3911*da0073e9SAndroid Build Coastguard Worker        Confirms handler behaviour for `str` is the same between eager and dynamo.
3912*da0073e9SAndroid Build Coastguard Worker        Compares a user defined object with custom `__str__` method and without.
3913*da0073e9SAndroid Build Coastguard Worker        """
3914*da0073e9SAndroid Build Coastguard Worker
3915*da0073e9SAndroid Build Coastguard Worker        class CustomStr:
3916*da0073e9SAndroid Build Coastguard Worker            def __str__(self):
3917*da0073e9SAndroid Build Coastguard Worker                return "ok"
3918*da0073e9SAndroid Build Coastguard Worker
3919*da0073e9SAndroid Build Coastguard Worker        def foo_custom_str(x):
3920*da0073e9SAndroid Build Coastguard Worker            a = CustomStr()
3921*da0073e9SAndroid Build Coastguard Worker            return x, str(a)
3922*da0073e9SAndroid Build Coastguard Worker
3923*da0073e9SAndroid Build Coastguard Worker        eager_custom_str = foo_custom_str(torch.ones(4))
3924*da0073e9SAndroid Build Coastguard Worker        dynamo_custom_str = torch.compile(foo_custom_str, fullgraph=True)(torch.ones(4))
3925*da0073e9SAndroid Build Coastguard Worker
3926*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_custom_str[1], dynamo_custom_str[1])
3927*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_custom_str[1], "ok")
3928*da0073e9SAndroid Build Coastguard Worker
3929*da0073e9SAndroid Build Coastguard Worker        class DefaultStr:
3930*da0073e9SAndroid Build Coastguard Worker            pass
3931*da0073e9SAndroid Build Coastguard Worker
3932*da0073e9SAndroid Build Coastguard Worker        def foo_default_str(x):
3933*da0073e9SAndroid Build Coastguard Worker            a = DefaultStr()
3934*da0073e9SAndroid Build Coastguard Worker            return x, str(a)
3935*da0073e9SAndroid Build Coastguard Worker
3936*da0073e9SAndroid Build Coastguard Worker        eager_default_str = foo_default_str(torch.ones(4))
3937*da0073e9SAndroid Build Coastguard Worker        dynamo_default_str = torch.compile(foo_default_str, fullgraph=True)(
3938*da0073e9SAndroid Build Coastguard Worker            torch.ones(4)
3939*da0073e9SAndroid Build Coastguard Worker        )
3940*da0073e9SAndroid Build Coastguard Worker
3941*da0073e9SAndroid Build Coastguard Worker        # Check that the tensor output from eager and dynamo modes are the same
3942*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_default_str[0], dynamo_default_str[0])
3943*da0073e9SAndroid Build Coastguard Worker
3944*da0073e9SAndroid Build Coastguard Worker        # Check that the class name (without memory address) is the same in both modes
3945*da0073e9SAndroid Build Coastguard Worker        eager_class_name = eager_default_str[1].split(" object at")[0]
3946*da0073e9SAndroid Build Coastguard Worker        dynamo_class_name = dynamo_default_str[1].split(" object at")[0]
3947*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_class_name, dynamo_class_name)
3948*da0073e9SAndroid Build Coastguard Worker
3949*da0073e9SAndroid Build Coastguard Worker    def test_pybind_object(self):
3950*da0073e9SAndroid Build Coastguard Worker        def fn(x, pybind_obj):
3951*da0073e9SAndroid Build Coastguard Worker            if pybind_obj.result:
3952*da0073e9SAndroid Build Coastguard Worker                return torch.cos(x)
3953*da0073e9SAndroid Build Coastguard Worker            return torch.sin(x)
3954*da0073e9SAndroid Build Coastguard Worker
3955*da0073e9SAndroid Build Coastguard Worker        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3956*da0073e9SAndroid Build Coastguard Worker
3957*da0073e9SAndroid Build Coastguard Worker        pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0)
3958*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
3959*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))
3960*da0073e9SAndroid Build Coastguard Worker
3961*da0073e9SAndroid Build Coastguard Worker        pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1)
3962*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4)
3963*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj))
3964*da0073e9SAndroid Build Coastguard Worker
3965*da0073e9SAndroid Build Coastguard Worker
# Turn any parametrized test templates on FunctionTests into concrete test
# methods. NOTE(review): instantiate_parametrized_tests is presumably imported
# from torch.testing._internal.common_utils elsewhere in this file (the import
# is not visible in this chunk) -- confirm.
instantiate_parametrized_tests(FunctionTests)

if __name__ == "__main__":
    # Imported only here, so the dynamo test runner is pulled in only when
    # this file is executed directly as a script.
    from torch._dynamo.test_case import run_tests

    run_tests()
3972