1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Worker# flake8: noqa: E731, C405, F811, C418, C417 3*da0073e9SAndroid Build Coastguard Workerimport collections 4*da0073e9SAndroid Build Coastguard Workerimport functools 5*da0073e9SAndroid Build Coastguard Workerimport inspect 6*da0073e9SAndroid Build Coastguard Workerimport itertools 7*da0073e9SAndroid Build Coastguard Workerimport math 8*da0073e9SAndroid Build Coastguard Workerimport operator 9*da0073e9SAndroid Build Coastguard Workerimport random 10*da0073e9SAndroid Build Coastguard Workerimport sys 11*da0073e9SAndroid Build Coastguard Workerimport unittest 12*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass, field 13*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, NamedTuple 14*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerimport numpy as np 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workerimport torch 19*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 20*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 21*da0073e9SAndroid Build Coastguard Workerfrom torch import sub 22*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import ( 23*da0073e9SAndroid Build Coastguard Worker CompileCounterWithBackend, 24*da0073e9SAndroid Build Coastguard Worker EagerAndRecordGraphs, 25*da0073e9SAndroid Build Coastguard Worker normalize_gm, 26*da0073e9SAndroid Build Coastguard Worker) 27*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import ifdynstaticdefault, same 28*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.variables import ConstantVariable 29*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.variables.lists import RangeVariable 30*da0073e9SAndroid Build Coastguard 
Workerfrom torch.nn import functional as F 31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 32*da0073e9SAndroid Build Coastguard Worker disable_translation_validation_if_dynamic_shapes, 33*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 34*da0073e9SAndroid Build Coastguard Worker parametrize, 35*da0073e9SAndroid Build Coastguard Worker) 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker# Defines all the kernels for tests 38*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.triton_utils import * # noqa: F403 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Workerd = torch.ones(10, 10) 42*da0073e9SAndroid Build Coastguard Workere = torch.nn.Linear(10, 10) 43*da0073e9SAndroid Build Coastguard Workerflag = True 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Workerclass CustomDictSubclass(collections.OrderedDict): 47*da0073e9SAndroid Build Coastguard Worker pass 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Workerclip01 = functools.partial(torch.clip, min=0.0, max=1.0) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Workerdef constant3(a, b): 54*da0073e9SAndroid Build Coastguard Worker return a - b + (1.0 + 2) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker_variable = 0 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Workerdef update_global(x): 61*da0073e9SAndroid Build Coastguard Worker global _variable 62*da0073e9SAndroid Build Coastguard Worker _variable += 1 63*da0073e9SAndroid 
Build Coastguard Worker # Check that updated global variable value is picked up 64*da0073e9SAndroid Build Coastguard Worker return x * _variable 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Workerdef func_with_default(a, b, some_default_arg=True): 68*da0073e9SAndroid Build Coastguard Worker if some_default_arg: 69*da0073e9SAndroid Build Coastguard Worker return a - b 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Workerdef make_test(fn=None, expected_frame_count=1): 73*da0073e9SAndroid Build Coastguard Worker if fn is None: 74*da0073e9SAndroid Build Coastguard Worker return lambda fn: make_test(fn, expected_frame_count=expected_frame_count) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker nargs = len(inspect.signature(fn).parameters) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker def test_fn(self): 79*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.testing.standard_test( 80*da0073e9SAndroid Build Coastguard Worker self, 81*da0073e9SAndroid Build Coastguard Worker fn=fn, 82*da0073e9SAndroid Build Coastguard Worker nargs=nargs, 83*da0073e9SAndroid Build Coastguard Worker expected_frame_count=expected_frame_count, 84*da0073e9SAndroid Build Coastguard Worker ) 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker return test_fn 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Workerclass MyCls: 90*da0073e9SAndroid Build Coastguard Worker a = 1 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker@torch.jit.script_if_tracing 94*da0073e9SAndroid Build Coastguard Workerdef inline_script_if_tracing(x): 95*da0073e9SAndroid Build Coastguard 
Worker return x + 1.2 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker@torch.jit.ignore 99*da0073e9SAndroid Build Coastguard Workerdef inline_ignore(x): 100*da0073e9SAndroid Build Coastguard Worker return x + 3.4 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker@torch.jit.unused 104*da0073e9SAndroid Build Coastguard Workerdef inline_unused(x): 105*da0073e9SAndroid Build Coastguard Worker return x + 5.6 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache 109*da0073e9SAndroid Build Coastguard Workerdef inline_lru_cache_fn_with_default_args(x, y, _=None): 110*da0073e9SAndroid Build Coastguard Worker return torch.sin(x * y) 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker@torch.jit.script_if_tracing 114*da0073e9SAndroid Build Coastguard Workerdef inline_script_if_tracing_fn_with_default_args(x, y, c=1.2): 115*da0073e9SAndroid Build Coastguard Worker return torch.cos(x * y) + c 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Workerclass FunctionTests(torch._dynamo.test_case.TestCase): 119*da0073e9SAndroid Build Coastguard Worker @make_test 120*da0073e9SAndroid Build Coastguard Worker def test_inline_jit_annotations(x): 121*da0073e9SAndroid Build Coastguard Worker x = inline_script_if_tracing(x) 122*da0073e9SAndroid Build Coastguard Worker x = inline_ignore(x) 123*da0073e9SAndroid Build Coastguard Worker x = inline_unused(x) 124*da0073e9SAndroid Build Coastguard Worker return 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker @make_test 127*da0073e9SAndroid Build Coastguard Worker def 
test_inline_script_if_tracing_fn_with_default_args(a, b): 128*da0073e9SAndroid Build Coastguard Worker return inline_script_if_tracing_fn_with_default_args(a, b) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker @make_test 131*da0073e9SAndroid Build Coastguard Worker def test_inline_lru_cache_fn_with_default_args(a, b): 132*da0073e9SAndroid Build Coastguard Worker return inline_lru_cache_fn_with_default_args(a, 2, b) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker @make_test 135*da0073e9SAndroid Build Coastguard Worker def test_add(a, b): 136*da0073e9SAndroid Build Coastguard Worker return a + b 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker @make_test 139*da0073e9SAndroid Build Coastguard Worker def test_add_(a, b): 140*da0073e9SAndroid Build Coastguard Worker a_copy = torch.tensor(a) 141*da0073e9SAndroid Build Coastguard Worker return a_copy.add_(b, alpha=5.0) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker @make_test 144*da0073e9SAndroid Build Coastguard Worker def test_addcdiv(a, b, c): 145*da0073e9SAndroid Build Coastguard Worker # dynamo decomposes this to avoid a graph break when 146*da0073e9SAndroid Build Coastguard Worker # the value kwarg is populated 147*da0073e9SAndroid Build Coastguard Worker return torch.addcdiv(a, b, c, value=5.0) 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker @make_test 150*da0073e9SAndroid Build Coastguard Worker def test_addcdiv_(a, b, c): 151*da0073e9SAndroid Build Coastguard Worker a_copy = torch.tensor(a) 152*da0073e9SAndroid Build Coastguard Worker return a_copy.addcdiv_(b, c, value=5.0) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker @make_test 155*da0073e9SAndroid Build Coastguard Worker def test_is_not_null(a, b): 156*da0073e9SAndroid Build Coastguard Worker if a is not None 
and b is not None: 157*da0073e9SAndroid Build Coastguard Worker return a + b 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker def test_foreach_lerp_(self): 160*da0073e9SAndroid Build Coastguard Worker def fn(x, y, s): 161*da0073e9SAndroid Build Coastguard Worker return torch._foreach_lerp_(x, y, s) 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) 166*da0073e9SAndroid Build Coastguard Worker expected = fn( 167*da0073e9SAndroid Build Coastguard Worker [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], 168*da0073e9SAndroid Build Coastguard Worker [torch.ones(2, 2), torch.ones(2, 2)], 169*da0073e9SAndroid Build Coastguard Worker torch.tensor(0.5), 170*da0073e9SAndroid Build Coastguard Worker ) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker actual = fn_opt( 173*da0073e9SAndroid Build Coastguard Worker [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], 174*da0073e9SAndroid Build Coastguard Worker [torch.ones(2, 2), torch.ones(2, 2)], 175*da0073e9SAndroid Build Coastguard Worker torch.tensor(0.5), 176*da0073e9SAndroid Build Coastguard Worker ) 177*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(expected, actual)) 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker def test_broadcast_foreach_pow(self): 180*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import same 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 183*da0073e9SAndroid Build Coastguard Worker return torch._foreach_pow(x, y) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 
186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) 188*da0073e9SAndroid Build Coastguard Worker inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)]) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker actual = fn_opt(*inps) 191*da0073e9SAndroid Build Coastguard Worker expected = fn(*inps) 192*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(actual, expected)) 193*da0073e9SAndroid Build Coastguard Worker self.assertTrue(cnt.frame_count, 1) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker def test_addcmul_(self): 196*da0073e9SAndroid Build Coastguard Worker from copy import deepcopy 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import same 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z, s): 201*da0073e9SAndroid Build Coastguard Worker return x.addcmul_(y, z, value=s) 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 204*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) 205*da0073e9SAndroid Build Coastguard Worker inps = ( 206*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2), 207*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2) + 1, 208*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 2), 209*da0073e9SAndroid Build Coastguard Worker torch.tensor(0.3), 210*da0073e9SAndroid Build Coastguard Worker ) 211*da0073e9SAndroid Build Coastguard Worker inps_2 = deepcopy(inps) 212*da0073e9SAndroid Build Coastguard Worker actual = fn_opt(*inps) 213*da0073e9SAndroid Build Coastguard Worker expected = fn(*inps_2) 214*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(actual, expected)) 
215*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker @make_test 218*da0073e9SAndroid Build Coastguard Worker def test_functools_partial(a, b): 219*da0073e9SAndroid Build Coastguard Worker return clip01(a + b) 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker @make_test 222*da0073e9SAndroid Build Coastguard Worker def test_itertools_product(a, b): 223*da0073e9SAndroid Build Coastguard Worker v = a 224*da0073e9SAndroid Build Coastguard Worker for x, i in itertools.product([a, b], [1, 2]): 225*da0073e9SAndroid Build Coastguard Worker v = v + x * i 226*da0073e9SAndroid Build Coastguard Worker return v 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker @make_test 229*da0073e9SAndroid Build Coastguard Worker def test_itertools_chain(a, b): 230*da0073e9SAndroid Build Coastguard Worker v = a 231*da0073e9SAndroid Build Coastguard Worker for x in itertools.chain([a, b], [1, 2]): 232*da0073e9SAndroid Build Coastguard Worker v = v + x 233*da0073e9SAndroid Build Coastguard Worker return v 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker @make_test 236*da0073e9SAndroid Build Coastguard Worker def test_itertools_chain_from_iterable(a, b): 237*da0073e9SAndroid Build Coastguard Worker v = a 238*da0073e9SAndroid Build Coastguard Worker for x in itertools.chain.from_iterable([[a, b], [1, 2]]): 239*da0073e9SAndroid Build Coastguard Worker v = v + x 240*da0073e9SAndroid Build Coastguard Worker return v 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker def test_itertools_reconstruct(self): 243*da0073e9SAndroid Build Coastguard Worker def fn(a): 244*da0073e9SAndroid Build Coastguard Worker it1 = itertools.repeat(1) 245*da0073e9SAndroid Build Coastguard Worker it2 = itertools.count(2) 246*da0073e9SAndroid Build 
Coastguard Worker for _ in range(3): 247*da0073e9SAndroid Build Coastguard Worker a += next(it1) 248*da0073e9SAndroid Build Coastguard Worker a += next(it2) 249*da0073e9SAndroid Build Coastguard Worker return it1, it2, a 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 252*da0073e9SAndroid Build Coastguard Worker i1, i2, a = fn(torch.ones(3, 3)) 253*da0073e9SAndroid Build Coastguard Worker it1, it2, b = opt_fn(torch.ones(3, 3)) 254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(next(i1), next(it1)) 255*da0073e9SAndroid Build Coastguard Worker self.assertEqual(next(i2), next(it2)) 256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker @make_test 259*da0073e9SAndroid Build Coastguard Worker def test_obj_eq(a, b): 260*da0073e9SAndroid Build Coastguard Worker v = a + b 261*da0073e9SAndroid Build Coastguard Worker if MyCls() == None: # noqa: E711 262*da0073e9SAndroid Build Coastguard Worker return -1 263*da0073e9SAndroid Build Coastguard Worker if MyCls() != None: # noqa: E711 264*da0073e9SAndroid Build Coastguard Worker v = v.sin() 265*da0073e9SAndroid Build Coastguard Worker if MyCls() == MyCls(): 266*da0073e9SAndroid Build Coastguard Worker return -2 267*da0073e9SAndroid Build Coastguard Worker if MyCls() != MyCls(): 268*da0073e9SAndroid Build Coastguard Worker return v + 1 269*da0073e9SAndroid Build Coastguard Worker return -3 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker @make_test 272*da0073e9SAndroid Build Coastguard Worker def test_cls_eq(a, b): 273*da0073e9SAndroid Build Coastguard Worker v = a + b 274*da0073e9SAndroid Build Coastguard Worker if MyCls == None: # noqa: E711 275*da0073e9SAndroid Build Coastguard Worker return -1 276*da0073e9SAndroid Build Coastguard Worker if MyCls != None: # noqa: E711 
277*da0073e9SAndroid Build Coastguard Worker v = v.sin() 278*da0073e9SAndroid Build Coastguard Worker if MyCls != MyCls: 279*da0073e9SAndroid Build Coastguard Worker return -2 280*da0073e9SAndroid Build Coastguard Worker if MyCls == MyCls: 281*da0073e9SAndroid Build Coastguard Worker return v + 1 282*da0073e9SAndroid Build Coastguard Worker return -3 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker @make_test 285*da0073e9SAndroid Build Coastguard Worker def test_obj_is(a, b): 286*da0073e9SAndroid Build Coastguard Worker v = a + b 287*da0073e9SAndroid Build Coastguard Worker if MyCls() is None: # noqa: E711 288*da0073e9SAndroid Build Coastguard Worker return -1 289*da0073e9SAndroid Build Coastguard Worker if MyCls() is not None: # noqa: E711 290*da0073e9SAndroid Build Coastguard Worker v = v.sin() 291*da0073e9SAndroid Build Coastguard Worker if MyCls() is MyCls(): 292*da0073e9SAndroid Build Coastguard Worker return -2 293*da0073e9SAndroid Build Coastguard Worker if MyCls() is not MyCls(): 294*da0073e9SAndroid Build Coastguard Worker return v + 1 295*da0073e9SAndroid Build Coastguard Worker return -3 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker @make_test 298*da0073e9SAndroid Build Coastguard Worker def test_cls_is(a, b): 299*da0073e9SAndroid Build Coastguard Worker v = a + b 300*da0073e9SAndroid Build Coastguard Worker if MyCls is None: # noqa: E711 301*da0073e9SAndroid Build Coastguard Worker return -1 302*da0073e9SAndroid Build Coastguard Worker if MyCls is not None: # noqa: E711 303*da0073e9SAndroid Build Coastguard Worker v = v.sin() 304*da0073e9SAndroid Build Coastguard Worker if MyCls is not MyCls: 305*da0073e9SAndroid Build Coastguard Worker return -2 306*da0073e9SAndroid Build Coastguard Worker if MyCls is MyCls: 307*da0073e9SAndroid Build Coastguard Worker return v + 1 308*da0073e9SAndroid Build Coastguard Worker return -3 309*da0073e9SAndroid Build Coastguard Worker 
310*da0073e9SAndroid Build Coastguard Worker @make_test 311*da0073e9SAndroid Build Coastguard Worker def test_itertools_combinations(a, b): 312*da0073e9SAndroid Build Coastguard Worker combs = [] 313*da0073e9SAndroid Build Coastguard Worker for size in itertools.combinations((1, 2, 3, 4), 2): 314*da0073e9SAndroid Build Coastguard Worker combs.append(torch.ones(size)) 315*da0073e9SAndroid Build Coastguard Worker return combs 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker @make_test 318*da0073e9SAndroid Build Coastguard Worker def test_np_iinfo(a): 319*da0073e9SAndroid Build Coastguard Worker max_dim = np.iinfo(np.int16).max 320*da0073e9SAndroid Build Coastguard Worker return a + max_dim 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker @make_test 323*da0073e9SAndroid Build Coastguard Worker def test_np_finfo(a): 324*da0073e9SAndroid Build Coastguard Worker min_dim = np.finfo(np.float32).min 325*da0073e9SAndroid Build Coastguard Worker return a + min_dim 326*da0073e9SAndroid Build Coastguard Worker 327*da0073e9SAndroid Build Coastguard Worker @make_test 328*da0073e9SAndroid Build Coastguard Worker def test_constant1(a, b, c): 329*da0073e9SAndroid Build Coastguard Worker return a - b * c + 1.0 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker @make_test 332*da0073e9SAndroid Build Coastguard Worker def test_constant2(a, b, c): 333*da0073e9SAndroid Build Coastguard Worker return a - b * c + 1 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker @make_test 336*da0073e9SAndroid Build Coastguard Worker def test_constant3(a): 337*da0073e9SAndroid Build Coastguard Worker b = 1 338*da0073e9SAndroid Build Coastguard Worker c = 2 339*da0073e9SAndroid Build Coastguard Worker d = 3 340*da0073e9SAndroid Build Coastguard Worker return b + c - d + a 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build 
Coastguard Worker @make_test 343*da0073e9SAndroid Build Coastguard Worker def test_constant4(a, b): 344*da0073e9SAndroid Build Coastguard Worker c = 2 345*da0073e9SAndroid Build Coastguard Worker d = 3 346*da0073e9SAndroid Build Coastguard Worker if c > d: 347*da0073e9SAndroid Build Coastguard Worker return a - b 348*da0073e9SAndroid Build Coastguard Worker return b - a 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker @make_test 351*da0073e9SAndroid Build Coastguard Worker def test_cls_hasattr(self, x): 352*da0073e9SAndroid Build Coastguard Worker if hasattr(MyCls, "a"): 353*da0073e9SAndroid Build Coastguard Worker x = x + 1 354*da0073e9SAndroid Build Coastguard Worker if hasattr(MyCls, "b"): 355*da0073e9SAndroid Build Coastguard Worker x = x + 2 356*da0073e9SAndroid Build Coastguard Worker return x 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker @make_test 359*da0073e9SAndroid Build Coastguard Worker def test_finfo(a, b): 360*da0073e9SAndroid Build Coastguard Worker if torch.iinfo(torch.int32).bits == 32: 361*da0073e9SAndroid Build Coastguard Worker return torch.finfo(a.dtype).min * b 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker @make_test 364*da0073e9SAndroid Build Coastguard Worker def test_globalfn(a, b): 365*da0073e9SAndroid Build Coastguard Worker return sub(a, b) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker @make_test 368*da0073e9SAndroid Build Coastguard Worker def test_viatorch(a, b): 369*da0073e9SAndroid Build Coastguard Worker return torch.sub(a, b) 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker @make_test 372*da0073e9SAndroid Build Coastguard Worker def test_viamethod(a, b): 373*da0073e9SAndroid Build Coastguard Worker return a.sub(b) 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker @make_test 
376*da0073e9SAndroid Build Coastguard Worker def test_indirect1(a, b): 377*da0073e9SAndroid Build Coastguard Worker t = a.sub 378*da0073e9SAndroid Build Coastguard Worker return t(b) 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker @make_test 381*da0073e9SAndroid Build Coastguard Worker def test_indirect2(a, b): 382*da0073e9SAndroid Build Coastguard Worker t = a.sub 383*da0073e9SAndroid Build Coastguard Worker args = (b,) 384*da0073e9SAndroid Build Coastguard Worker return t(*args) 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker @make_test 387*da0073e9SAndroid Build Coastguard Worker def test_indirect3(a, b): 388*da0073e9SAndroid Build Coastguard Worker t = a.sub 389*da0073e9SAndroid Build Coastguard Worker args = (b,) 390*da0073e9SAndroid Build Coastguard Worker kwargs = {} 391*da0073e9SAndroid Build Coastguard Worker return t(*args, **kwargs) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker @make_test 394*da0073e9SAndroid Build Coastguard Worker def test_methodcall1(a, b, c): 395*da0073e9SAndroid Build Coastguard Worker return constant3(a, b) * c 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker @make_test 398*da0073e9SAndroid Build Coastguard Worker def test_methodcall2(a, b): 399*da0073e9SAndroid Build Coastguard Worker return constant3(a=b, b=a) + 1 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker @make_test 402*da0073e9SAndroid Build Coastguard Worker def test_methodcall3(a, b): 403*da0073e9SAndroid Build Coastguard Worker return constant3(a, b=1.0) + b 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Worker def test_is_integer(self): 406*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 407*da0073e9SAndroid Build Coastguard Worker def forward(t, m): 408*da0073e9SAndroid Build 
Coastguard Worker return 2 * t if m.is_integer() else t 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([1]) 411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(forward(t, 1.0).item(), 2) 412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(forward(t, 1.5).item(), 1) 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker @parametrize( 415*da0073e9SAndroid Build Coastguard Worker "method, num_type", 416*da0073e9SAndroid Build Coastguard Worker ( 417*da0073e9SAndroid Build Coastguard Worker ("as_integer_ratio", int), 418*da0073e9SAndroid Build Coastguard Worker ("bit_length", int), 419*da0073e9SAndroid Build Coastguard Worker ("conjugate", int), 420*da0073e9SAndroid Build Coastguard Worker ("as_integer_ratio", float), 421*da0073e9SAndroid Build Coastguard Worker ("conjugate", float), 422*da0073e9SAndroid Build Coastguard Worker ("hex", float), 423*da0073e9SAndroid Build Coastguard Worker ("is_integer", float), 424*da0073e9SAndroid Build Coastguard Worker ), 425*da0073e9SAndroid Build Coastguard Worker ) 426*da0073e9SAndroid Build Coastguard Worker def test_number_method(self, method, num_type): 427*da0073e9SAndroid Build Coastguard Worker def forward(t, m): 428*da0073e9SAndroid Build Coastguard Worker return 2 * t if getattr(m, method)() else t 429*da0073e9SAndroid Build Coastguard Worker 430*da0073e9SAndroid Build Coastguard Worker wrapped = torch.compile(backend="eager", fullgraph=True)(forward) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker for i in (0, 1, 2.5): 433*da0073e9SAndroid Build Coastguard Worker m = num_type(i) 434*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([1]) 435*da0073e9SAndroid Build Coastguard Worker actual = wrapped(t, m) 436*da0073e9SAndroid Build Coastguard Worker expected = forward(t, m) 437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 
438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker @make_test 440*da0073e9SAndroid Build Coastguard Worker def test_device_constant(a): 441*da0073e9SAndroid Build Coastguard Worker return a + torch.ones(1, device=torch.device("cpu")) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker @make_test 444*da0073e9SAndroid Build Coastguard Worker def test_tuple1(a, b): 445*da0073e9SAndroid Build Coastguard Worker args = (a, b) 446*da0073e9SAndroid Build Coastguard Worker return sub(*args) 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker @make_test 449*da0073e9SAndroid Build Coastguard Worker def test_tuple2(a, b): 450*da0073e9SAndroid Build Coastguard Worker args = [a, b] 451*da0073e9SAndroid Build Coastguard Worker return sub(*args) 452*da0073e9SAndroid Build Coastguard Worker 453*da0073e9SAndroid Build Coastguard Worker @make_test 454*da0073e9SAndroid Build Coastguard Worker def test_is_in_onnx_export(x, y): 455*da0073e9SAndroid Build Coastguard Worker if torch.onnx.is_in_onnx_export(): 456*da0073e9SAndroid Build Coastguard Worker return x - 1 457*da0073e9SAndroid Build Coastguard Worker else: 458*da0073e9SAndroid Build Coastguard Worker return y + 1 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker @make_test 461*da0073e9SAndroid Build Coastguard Worker def test_is_fx_tracing(x, y): 462*da0073e9SAndroid Build Coastguard Worker if torch.fx._symbolic_trace.is_fx_tracing(): 463*da0073e9SAndroid Build Coastguard Worker return x - 1 464*da0073e9SAndroid Build Coastguard Worker else: 465*da0073e9SAndroid Build Coastguard Worker return y + 1 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker @make_test 468*da0073e9SAndroid Build Coastguard Worker def test_listarg1(a, b): 469*da0073e9SAndroid Build Coastguard Worker return torch.cat([a, b]) 470*da0073e9SAndroid Build Coastguard 
Worker 471*da0073e9SAndroid Build Coastguard Worker @make_test 472*da0073e9SAndroid Build Coastguard Worker def test_listarg2(a, b): 473*da0073e9SAndroid Build Coastguard Worker return torch.cat((a, b), dim=0) 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker @make_test 476*da0073e9SAndroid Build Coastguard Worker def test_listarg3(a, b): 477*da0073e9SAndroid Build Coastguard Worker kwargs = {"tensors": (a, b), "dim": 0} 478*da0073e9SAndroid Build Coastguard Worker return torch.cat(**kwargs) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker @make_test 481*da0073e9SAndroid Build Coastguard Worker def test_listarg4(a, b): 482*da0073e9SAndroid Build Coastguard Worker return torch.cat(tensors=[a, b], dim=0) 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker @make_test 485*da0073e9SAndroid Build Coastguard Worker def test_listarg5(a, b): 486*da0073e9SAndroid Build Coastguard Worker args = [(a, b)] 487*da0073e9SAndroid Build Coastguard Worker kwargs = {"dim": 0} 488*da0073e9SAndroid Build Coastguard Worker return torch.cat(*args, **kwargs) 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker def test_list_slice(self): 491*da0073e9SAndroid Build Coastguard Worker class Mock: 492*da0073e9SAndroid Build Coastguard Worker def __init__(self): 493*da0073e9SAndroid Build Coastguard Worker self.ets = [] 494*da0073e9SAndroid Build Coastguard Worker self.counter = 0 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 497*da0073e9SAndroid Build Coastguard Worker def run(self, x): 498*da0073e9SAndroid Build Coastguard Worker self.ets = self.ets[-3:] 499*da0073e9SAndroid Build Coastguard Worker self.ets.append(x) 500*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard 
Worker mock = Mock() 503*da0073e9SAndroid Build Coastguard Worker mock.run(torch.randn(4)) 504*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(mock.ets), 1) 505*da0073e9SAndroid Build Coastguard Worker 506*da0073e9SAndroid Build Coastguard Worker @make_test 507*da0073e9SAndroid Build Coastguard Worker def test_deque(a, b): 508*da0073e9SAndroid Build Coastguard Worker d = collections.deque([a, b]) 509*da0073e9SAndroid Build Coastguard Worker d.append(a + 1) 510*da0073e9SAndroid Build Coastguard Worker d.extend([a, b]) 511*da0073e9SAndroid Build Coastguard Worker d.insert(0, "foo") 512*da0073e9SAndroid Build Coastguard Worker tmp = d.pop() 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker another_deque = collections.deque([tmp]) 515*da0073e9SAndroid Build Coastguard Worker d.extendleft(another_deque) 516*da0073e9SAndroid Build Coastguard Worker another_deque.clear() 517*da0073e9SAndroid Build Coastguard Worker d.extend(another_deque) 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker d[2] = "setitem" 520*da0073e9SAndroid Build Coastguard Worker d = d.copy() 521*da0073e9SAndroid Build Coastguard Worker d.append(d.popleft()) 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker empty = collections.deque() 524*da0073e9SAndroid Build Coastguard Worker d.extend(empty) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker return d 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker @make_test 529*da0073e9SAndroid Build Coastguard Worker def test_slice1(a): 530*da0073e9SAndroid Build Coastguard Worker return a[5] 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker @make_test 533*da0073e9SAndroid Build Coastguard Worker def test_slice2(a): 534*da0073e9SAndroid Build Coastguard Worker return a[:5] 535*da0073e9SAndroid Build Coastguard 
# NOTE(review): these are methods of a test class whose header is outside this
# view; re-indent one level if they need to sit inside it. `make_test` is
# defined elsewhere in this file; presumably it runs each decorated function
# eagerly and under dynamo and compares the results -- confirm. The exact
# syntax of each body is what these tests exercise, so keep it as written.

# Open-ended slice from an integer start.
@make_test
def test_slice3(a):
    return a[5:]

# Bounded slice.
@make_test
def test_slice4(a):
    return a[2:5]

# Slice with a step and no bounds.
@make_test
def test_slice5(a):
    return a[::2]

# Multi-dimensional slice applied to the result of a torch call.
@make_test
def test_slice6(a):
    return torch.unsqueeze(a, 0)[:, 2:]

# torch.tensor built from a Python range whose length comes from the input's
# first dimension.
@make_test
def test_range1(a):
    return torch.tensor(range(a.size(0)))

# Loop trip count derived from the input's first dimension (plus 2).
@make_test
def test_range2(x, y):
    r = x + y
    for i in range(x.size(0) + 2):
        r = r / y
    return r

# Tuple assignment of two slices that rebinds the parameter `a`.
@make_test
def test_unpack1(a):
    a, b = a[:5], a[5:]
    return a - b

# Unpacking a list of slices.
@make_test
def test_unpack2(a):
    packed = [a[:5], a[5:]]
    a, b = packed
    return a - b

# Unpacking a tuple of slices.
@make_test
def test_unpack3(a):
    packed = (a[:5], a[5:])
    a, b = packed
    return a - b

@make_test
def test_fn_with_self_set(a, b):
    # avg_pool2d is an odd one with __self__ set
    return F.avg_pool2d(
        torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
    )

# Returns a tuple mixing computed tensors and the unmodified inputs.
@make_test
def test_return_tuple1(a, b):
    return (a - b, b - a, a, b)

# Reads the module-level tensor `d` (torch.ones(10, 10) at the top of the file).
@make_test
def test_globalvar(a, b):
    return a - b + d

# Calls the module-level `e` (torch.nn.Linear(10, 10) at the top of the file).
@make_test
def test_globalmodule(x):
    return e(x)

# Inlines `func_with_default`, defined elsewhere in this file; judging by the
# name, the call relies on one of its default arguments -- confirm.
@make_test
def test_inline_with_default(a, b, c):
    return func_with_default(a, b) * c

# A nested function whose parameter shadows the outer `x`.
@make_test
def test_inner_function(x):
    def fn(x):
        return torch.add(x, x)

    return fn(x)

# Builds a view shape by concatenating a torch.Size slice with a plain tuple.
# For the view and the 3-axis permute to succeed, x must in practice be 2-D
# with a last dimension of 10 elements (2 * 5).
@make_test
def test_transpose_for_scores(x):
    new_x_shape = x.size()[:-1] + (2, 5)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1)

# Tuple return that includes an input tensor unchanged.
@make_test
def test_return_tuple2(x):
    return (torch.add(x, x), x)

# Branches on the module-level `flag` (True at the top of the file).
@make_test
def test_load_global_bool(x):
    if flag:
        return torch.add(x, x)
    else:
        return x

# len() of a tensor (the size of its first dimension) used as a Python number.
@make_test
def test_len_tensor(x):
    z = len(x)
    return torch.add(x, z)

# len() of a constant list literal.
@make_test
def test_len_constant_list(x):
    z = len([1, 2, 3])
    return torch.add(x, z)

# len() of a constant dict literal.
@make_test
def test_len_constant_dict(x):
    z = len({"foo": "bar"})
    return torch.add(x, z)

# dict() applied to a dict literal that holds a tensor value.
@make_test
def test_dict_copy(x):
    z = dict({"foo": x + 1})
    return z

# `keys` is a live view taken before key 4 is inserted, so `4 in keys` is
# True; the view is also compared for equality with another dict's keys view.
@make_test
def test_dict_keys(x):
    d = {3: x}
    keys = d.keys()
    d[4] = x + 1
    d2 = {3: 2, 4: "aa"}
    return 3 in keys, 4 in keys, 5 in keys, d2.keys() == keys

# `values` is a live view, so its len() reflects the later insertion of key 4.
@make_test
def test_dict_values(x):
    d = {3: x}
    values = d.values()
    d[3] = x + 1
    d[4] = x + 2
    return len(values)

# setdefault on an existing key leaves its value unchanged, so the x + 1
# branch is taken.
@make_test
def test_dict_setdefault1(x):
    d = {"a": 1, "b": 2}
    d.setdefault("a", 10)
    if d["a"] == 1:
        return x + 1
    else:
        return x - 1

# setdefault on a missing key inserts the supplied default.
@make_test
def test_dict_setdefault2(x):
    d = {"a": 1, "b": 2}
    d.setdefault("c", 10)
    if d["c"] == 10:
        return x + 1
    else:
        return x - 1

# setdefault on a missing key with no default inserts None.
@make_test
def test_dict_setdefault3(x):
    d = {"a": 1, "b": 2}
    d.setdefault("c")
    if d["c"] is None:
        return x + 1
    else:
        return x - 1

# defaultdict created via the inherited dict.fromkeys, so it has no default
# factory; setdefault on an existing key must leave its value unchanged.
@make_test
def test_defaultdict_setdefault1(x):
    d = collections.defaultdict.fromkeys("a", "b")
    d["a"] = 1
    d["b"] = 2
    d.setdefault("a", 10)
    if d["a"] == 1:
        return x + 1
    else:
        return x - 1
# NOTE(review): functions taking `self` are test-case methods (they call
# `self.assertEqual`); the enclosing class header is outside this view, so
# this chunk may need re-indenting one level to sit inside it. `make_test` is
# defined elsewhere in this file; presumably it runs each decorated function
# eagerly and under dynamo and compares the results -- confirm.

# defaultdict created via the inherited dict.fromkeys, so it has no default
# factory; setdefault on a missing key must insert the supplied value.
@make_test
def test_defaultdict_setdefault2(x):
    d = collections.defaultdict.fromkeys("a", "b")
    d["a"] = 1
    d["b"] = 2
    d.setdefault("c", 10)
    if d["c"] == 10:
        return x + 1
    else:
        return x - 1

# defaultdict without a default factory; setdefault on a missing key with no
# default inserts None.
@make_test
def test_defaultdict_setdefault3(x):
    d = collections.defaultdict.fromkeys("a", "b")
    d["a"] = 1
    d["b"] = 2
    d.setdefault("c")
    if d["c"] is None:
        return x + 1
    else:
        return x - 1

def test_dict_id_guard(self):
    # d2 is an alias of d1 (the same object), so the compiled function reads
    # one dict through two different names.
    d1 = collections.OrderedDict({"a": 2})
    d2 = d1

    def fn(x):
        # Iteration forces DictGuardManager
        for k in d1:
            x = x * d1[k] * d2[k]
        return x

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    x = torch.randn(4)
    self.assertEqual(fn(x), opt_fn(x))

# callable() on a lambda is True, so the x + 1 branch is taken.
@make_test
def test_callable_lambda(x):
    if callable(lambda x: True):
        return x + 1
    else:
        return x - 1

# callable() on a torch op.
@make_test
def test_callable_torch(x):
    if callable(torch.abs):
        return x + 1
    else:
        return x - 1

# callable() on a Python builtin.
@make_test
def test_callable_builtin(x):
    if callable(sum):
        return x + 1
    else:
        return x - 1

def test_callable_class(self):
    # fn1 and fn2 return different results on the callable branch, so each
    # compiled function's output shows whether callable(arg) was evaluated
    # correctly for instances, scalars and tensors.
    class CallableClass:
        # `__call__` must accept `self`: without it callable() is still True,
        # but actually invoking an instance raises TypeError.
        def __call__(self):
            pass

    class NotCallableClass:
        pass

    @torch.compile(backend="eager", fullgraph=True)
    def fn1(x, arg):
        if callable(arg):
            return x
        return x + 1

    @torch.compile(backend="eager", fullgraph=True)
    def fn2(x, arg):
        if callable(arg):
            return x * 2
        return x + 1

    # Named `inp` rather than `input` to avoid shadowing the builtin.
    inp = torch.randn(4)

    for f in [fn1, fn2]:
        self.assertEqual(f(inp, NotCallableClass()), inp + 1)
        self.assertEqual(f(inp, CallableClass()), inp if f is fn1 else inp * 2)

        # passing tensor and scalars
        self.assertEqual(f(inp, 1), inp + 1)
        self.assertEqual(f(inp, 1.1), inp + 1)
        self.assertEqual(f(inp, True), inp + 1)
        self.assertEqual(f(inp, inp), inp + 1)

def test_callable_list(self):
    # Lists and tuples are not callable, so the x + 1 path is expected.
    @torch.compile(backend="eager", fullgraph=True)
    def fn(x, arg):
        if callable(arg):
            return x
        return x + 1

    inp = torch.randn(4)
    self.assertEqual(fn(inp, [1, 2, 3]), inp + 1)
    self.assertEqual(fn(inp, (1, 2, 3)), inp + 1)

# len() of a constant tuple (3) and a constant string (8), summed.
@make_test
def test_len_constant_misc_iterables(x):
    a = len((1, 2, 3))
    b = len("test str")
    c = a + b
    return torch.add(x, c)

# dict built from keyword arguments holding tensors.
@make_test
def test_dict_kwargs(x):
    z = dict(text_embed=x + 1, other=x + 2)
    return z

# OrderedDict built from keyword arguments (value not derived from the input).
@make_test
def test_ordered_dict_kwargs(x):
    z = collections.OrderedDict(sample=torch.ones(10))
    return z

# Module-level OrderedDict subclass built from keyword arguments.
@make_test
def test_custom_dict_kwargs(x):
    z = CustomDictSubclass(sample=torch.ones(10))
    return z
# NOTE(review): these are methods of a test class whose header is outside this
# view; re-indent one level if they need to sit inside it. `make_test` is
# defined elsewhere in this file; presumably it runs each decorated function
# eagerly and under dynamo and compares the results -- confirm. The exact
# syntax of each body is what these tests exercise, so keep it as written.

# float() applied to a float literal and to a numeric string.
@make_test
def test_float(x):
    y = float(1.2)  # noqa: UP018
    y += float("1.2")
    return torch.add(x, y)

# torch.is_floating_point called positionally and with the `input=` keyword.
@make_test
def test_is_floating_point(x):
    y = x + 1
    return torch.is_floating_point(y), torch.is_floating_point(input=y)

# Branches on a dtype comparison; implicitly returns None for any dtype other
# than float32.
@make_test
def test_dtype(x):
    if x.dtype == torch.float32:
        return x + 1

# Compares the input dtype against torch.get_default_dtype().
@make_test
def test_get_default_dtype(x):
    if x.dtype == torch.get_default_dtype():
        return x + 1
    else:
        return x - 1

# Casts the input with Tensor.type() to the dtype from
# torch.get_autocast_gpu_dtype().
# NOTE(review): newer PyTorch releases appear to deprecate this call in favor
# of torch.get_autocast_dtype("cuda"); verify against the installed version.
@make_test
def test_get_autocast_gpu_dtype(x):
    dtype = torch.get_autocast_gpu_dtype()
    return x.type(dtype)

# Branches on the private torch._C._is_any_autocast_enabled() query.
@make_test
def test_is_any_autocast_enabled(x):
    if torch._C._is_any_autocast_enabled():
        return x + 1
    else:
        return x - 1

# Branches on the private torch.autograd._is_checkpoint_valid() query.
@make_test
def test_is_checkpoint_valid(x):
    if torch.autograd._is_checkpoint_valid():
        return x + 1
    else:
        return x - 1

# Sequence comparisons. Each operator adds a different power-of-two multiple
# of `c`, so the result records which of !=, ==, <, >, <=, >= held for every
# pair. The pairs cover: equal tuples, a larger middle element, a proper
# prefix on either side, and a smaller (negative) middle element.
@make_test
def test_list_compare_polyfill(x):
    for a, b, c in [
        [(1, 2, 3), (1, 2, 3), 7.77],
        [(1, 4, 3), (1, 2, 3), 3.33],
        [(1, 2), (1, 2, 3), 5.55],
        [(1, 2, 3), (1, 2), 11.11],
        [(1, -1, 3), (1, 2, 3), 13.33],
    ]:
        if a != b:
            x += 1 * c
        if a == b:
            x += 2 * c
        if a < b:
            x += 4 * c
        if a > b:
            x += 8 * c
        if a <= b:
            x += 16 * c
        if a >= b:
            x += 32 * c
    return x
889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker @make_test 891*da0073e9SAndroid Build Coastguard Worker def test_promote_types(x): 892*da0073e9SAndroid Build Coastguard Worker if x.dtype == torch.promote_types(torch.int32, torch.float32): 893*da0073e9SAndroid Build Coastguard Worker return x + 1 894*da0073e9SAndroid Build Coastguard Worker else: 895*da0073e9SAndroid Build Coastguard Worker return x - 1 896*da0073e9SAndroid Build Coastguard Worker 897*da0073e9SAndroid Build Coastguard Worker @make_test 898*da0073e9SAndroid Build Coastguard Worker def test_cublas_allow_tf32(x): 899*da0073e9SAndroid Build Coastguard Worker if torch.backends.cuda.matmul.allow_tf32: 900*da0073e9SAndroid Build Coastguard Worker return x.sin() + 1 901*da0073e9SAndroid Build Coastguard Worker 902*da0073e9SAndroid Build Coastguard Worker return x.cos() - 1 903*da0073e9SAndroid Build Coastguard Worker 904*da0073e9SAndroid Build Coastguard Worker @make_test 905*da0073e9SAndroid Build Coastguard Worker def test_get_calculate_correct_fan(x): 906*da0073e9SAndroid Build Coastguard Worker fan_in = torch.nn.init._calculate_correct_fan(x, "fan_in") 907*da0073e9SAndroid Build Coastguard Worker return x + fan_in 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker @make_test 910*da0073e9SAndroid Build Coastguard Worker def test_is_complex(x): 911*da0073e9SAndroid Build Coastguard Worker if torch.is_complex(x): 912*da0073e9SAndroid Build Coastguard Worker return x + 1 913*da0073e9SAndroid Build Coastguard Worker else: 914*da0073e9SAndroid Build Coastguard Worker return x - 1 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker @make_test 917*da0073e9SAndroid Build Coastguard Worker def test_tensor_is_complex(x): 918*da0073e9SAndroid Build Coastguard Worker if x.is_complex(): 919*da0073e9SAndroid Build Coastguard Worker return x + 1 920*da0073e9SAndroid Build Coastguard Worker else: 
921*da0073e9SAndroid Build Coastguard Worker return x - 1 922*da0073e9SAndroid Build Coastguard Worker 923*da0073e9SAndroid Build Coastguard Worker @make_test 924*da0073e9SAndroid Build Coastguard Worker def test_get_privateuse1_name(x): 925*da0073e9SAndroid Build Coastguard Worker if torch._C._get_privateuse1_backend_name() == "privateuseone": 926*da0073e9SAndroid Build Coastguard Worker return x + 1 927*da0073e9SAndroid Build Coastguard Worker else: 928*da0073e9SAndroid Build Coastguard Worker return x - 1 929*da0073e9SAndroid Build Coastguard Worker 930*da0073e9SAndroid Build Coastguard Worker @make_test 931*da0073e9SAndroid Build Coastguard Worker def test_device(x): 932*da0073e9SAndroid Build Coastguard Worker if not x.is_cuda: 933*da0073e9SAndroid Build Coastguard Worker return x + 1 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 936*da0073e9SAndroid Build Coastguard Worker @make_test 937*da0073e9SAndroid Build Coastguard Worker def test_get_device_properties_tensor_device(a): 938*da0073e9SAndroid Build Coastguard Worker x = a.to("cuda") 939*da0073e9SAndroid Build Coastguard Worker prop = torch.cuda.get_device_properties(x.device) 940*da0073e9SAndroid Build Coastguard Worker if prop.major == 8: 941*da0073e9SAndroid Build Coastguard Worker return x + prop.multi_processor_count 942*da0073e9SAndroid Build Coastguard Worker return x + prop.max_threads_per_multi_processor 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker @make_test 945*da0073e9SAndroid Build Coastguard Worker def test_tensor_type(a, b): 946*da0073e9SAndroid Build Coastguard Worker m = a.to(torch.float16) 947*da0073e9SAndroid Build Coastguard Worker return b.type(m.type()) 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 950*da0073e9SAndroid Build 
Coastguard Worker @make_test 951*da0073e9SAndroid Build Coastguard Worker def test_tensor_type2(a, b): 952*da0073e9SAndroid Build Coastguard Worker m = a.to("cuda") 953*da0073e9SAndroid Build Coastguard Worker return m + b.type(m.type()) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker @make_test 956*da0073e9SAndroid Build Coastguard Worker def test_tensor_type3(a, b): 957*da0073e9SAndroid Build Coastguard Worker m = a.type(torch.HalfTensor) 958*da0073e9SAndroid Build Coastguard Worker return b.type(m.type()) 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Worker @make_test 961*da0073e9SAndroid Build Coastguard Worker def test_tensor_type4(a, b): 962*da0073e9SAndroid Build Coastguard Worker m = a.type("torch.HalfTensor") 963*da0073e9SAndroid Build Coastguard Worker return b.type(m.type()) 964*da0073e9SAndroid Build Coastguard Worker 965*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 966*da0073e9SAndroid Build Coastguard Worker @make_test 967*da0073e9SAndroid Build Coastguard Worker def test_tensor_type5(a, b): 968*da0073e9SAndroid Build Coastguard Worker m = a.type(torch.cuda.HalfTensor) 969*da0073e9SAndroid Build Coastguard Worker return b.type(m.type()) 970*da0073e9SAndroid Build Coastguard Worker 971*da0073e9SAndroid Build Coastguard Worker @make_test 972*da0073e9SAndroid Build Coastguard Worker def test_tensor_element_size(a): 973*da0073e9SAndroid Build Coastguard Worker if a.element_size() > 1: 974*da0073e9SAndroid Build Coastguard Worker return (a + a.element_size(), a - a.element_size()) 975*da0073e9SAndroid Build Coastguard Worker return (a - a.element_size(), a + a.element_size()) 976*da0073e9SAndroid Build Coastguard Worker 977*da0073e9SAndroid Build Coastguard Worker @make_test 978*da0073e9SAndroid Build Coastguard Worker def test_ndim(x): 979*da0073e9SAndroid Build Coastguard Worker if x.ndim == 2 and x.ndimension() == 2 
and x.dim() == 2: 980*da0073e9SAndroid Build Coastguard Worker return x + 1 981*da0073e9SAndroid Build Coastguard Worker 982*da0073e9SAndroid Build Coastguard Worker @make_test 983*da0073e9SAndroid Build Coastguard Worker def test_T(x): 984*da0073e9SAndroid Build Coastguard Worker return torch.ones_like(x.T) 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker @make_test 987*da0073e9SAndroid Build Coastguard Worker def test_mT(x): 988*da0073e9SAndroid Build Coastguard Worker return torch.ones_like(x.mT) 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker @make_test 991*da0073e9SAndroid Build Coastguard Worker def test_is_sparse(x): 992*da0073e9SAndroid Build Coastguard Worker if not x.is_sparse: 993*da0073e9SAndroid Build Coastguard Worker return x + 1 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker @make_test 996*da0073e9SAndroid Build Coastguard Worker def test_shape1(x): 997*da0073e9SAndroid Build Coastguard Worker if x.shape[0] == 10: 998*da0073e9SAndroid Build Coastguard Worker return x + 1 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Worker @make_test 1001*da0073e9SAndroid Build Coastguard Worker def test_shape2(x): 1002*da0073e9SAndroid Build Coastguard Worker if x.size(1) == 10: 1003*da0073e9SAndroid Build Coastguard Worker return x + 1 1004*da0073e9SAndroid Build Coastguard Worker 1005*da0073e9SAndroid Build Coastguard Worker @make_test 1006*da0073e9SAndroid Build Coastguard Worker def test_del(a, b): 1007*da0073e9SAndroid Build Coastguard Worker c = a + 1 1008*da0073e9SAndroid Build Coastguard Worker d = c + 2 1009*da0073e9SAndroid Build Coastguard Worker del c, a 1010*da0073e9SAndroid Build Coastguard Worker return b + d 1011*da0073e9SAndroid Build Coastguard Worker 1012*da0073e9SAndroid Build Coastguard Worker @make_test 1013*da0073e9SAndroid Build Coastguard Worker def test_chunks1(x): 
1014*da0073e9SAndroid Build Coastguard Worker chunk_size = 5 1015*da0073e9SAndroid Build Coastguard Worker assert x.shape[0] % chunk_size == 0 1016*da0073e9SAndroid Build Coastguard Worker assert x.shape[0] // chunk_size == 2 1017*da0073e9SAndroid Build Coastguard Worker return x[:chunk_size] - x[chunk_size:] 1018*da0073e9SAndroid Build Coastguard Worker 1019*da0073e9SAndroid Build Coastguard Worker @make_test 1020*da0073e9SAndroid Build Coastguard Worker def test_import1(x, y): 1021*da0073e9SAndroid Build Coastguard Worker import torch 1022*da0073e9SAndroid Build Coastguard Worker from torch import sub 1023*da0073e9SAndroid Build Coastguard Worker 1024*da0073e9SAndroid Build Coastguard Worker return sub(torch.add(x, y), y) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker @make_test 1027*da0073e9SAndroid Build Coastguard Worker def test_return_dict(x, y): 1028*da0073e9SAndroid Build Coastguard Worker z = [x + y, y, False] 1029*da0073e9SAndroid Build Coastguard Worker return {"x": x, "z": z, "a": x, "b": z, "c": x} 1030*da0073e9SAndroid Build Coastguard Worker 1031*da0073e9SAndroid Build Coastguard Worker @make_test 1032*da0073e9SAndroid Build Coastguard Worker def test_return_dict2(x, y): 1033*da0073e9SAndroid Build Coastguard Worker tmp = {"x": x} 1034*da0073e9SAndroid Build Coastguard Worker tmp["z"] = [x + y, y] 1035*da0073e9SAndroid Build Coastguard Worker tmp["y"] = y 1036*da0073e9SAndroid Build Coastguard Worker tmp["z"].append(False) 1037*da0073e9SAndroid Build Coastguard Worker return tmp 1038*da0073e9SAndroid Build Coastguard Worker 1039*da0073e9SAndroid Build Coastguard Worker @make_test 1040*da0073e9SAndroid Build Coastguard Worker def test_funcdef_closure(x, y): 1041*da0073e9SAndroid Build Coastguard Worker x = x + y + 1.0 1042*da0073e9SAndroid Build Coastguard Worker 1043*da0073e9SAndroid Build Coastguard Worker def inner(z): 1044*da0073e9SAndroid Build Coastguard Worker nonlocal x, y 1045*da0073e9SAndroid 
Build Coastguard Worker y = x + z + 20.0 1046*da0073e9SAndroid Build Coastguard Worker x = y + z + 10.0 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker inner(2.0) 1049*da0073e9SAndroid Build Coastguard Worker inner(3.0) 1050*da0073e9SAndroid Build Coastguard Worker 1051*da0073e9SAndroid Build Coastguard Worker return x, y 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker @make_test 1054*da0073e9SAndroid Build Coastguard Worker def test_module_constant(x, y): 1055*da0073e9SAndroid Build Coastguard Worker r = x + y 1056*da0073e9SAndroid Build Coastguard Worker for i in range(torch._dynamo.testing.three): 1057*da0073e9SAndroid Build Coastguard Worker r = r / y 1058*da0073e9SAndroid Build Coastguard Worker return r 1059*da0073e9SAndroid Build Coastguard Worker 1060*da0073e9SAndroid Build Coastguard Worker @make_test 1061*da0073e9SAndroid Build Coastguard Worker def test_inline_softmax(x, y): 1062*da0073e9SAndroid Build Coastguard Worker # This is common in sme huggingface models 1063*da0073e9SAndroid Build Coastguard Worker return torch.nn.Softmax(dim=-1)(x + y * 2) 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker @make_test 1066*da0073e9SAndroid Build Coastguard Worker def test_dtype_compare(a, b): 1067*da0073e9SAndroid Build Coastguard Worker if a.dtype == torch.float16: 1068*da0073e9SAndroid Build Coastguard Worker return a + 10 1069*da0073e9SAndroid Build Coastguard Worker if a.dtype == torch.float32: 1070*da0073e9SAndroid Build Coastguard Worker return a - b * 32 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker @make_test 1073*da0073e9SAndroid Build Coastguard Worker def test_build_list_unpack(a, b): 1074*da0073e9SAndroid Build Coastguard Worker it1 = (x + 1 for x in (a, b)) 1075*da0073e9SAndroid Build Coastguard Worker it2 = (x - 1 for x in (a, b)) 1076*da0073e9SAndroid Build Coastguard Worker 
return torch.cat([*it1, *it2], dim=-1) 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker @make_test 1079*da0073e9SAndroid Build Coastguard Worker def test_tensor_len(a, b): 1080*da0073e9SAndroid Build Coastguard Worker return a + b + len(a) + b.__len__() 1081*da0073e9SAndroid Build Coastguard Worker 1082*da0073e9SAndroid Build Coastguard Worker @make_test 1083*da0073e9SAndroid Build Coastguard Worker def test_pop(a, b): 1084*da0073e9SAndroid Build Coastguard Worker ll = [a, b] 1085*da0073e9SAndroid Build Coastguard Worker ll.append(a + 1) 1086*da0073e9SAndroid Build Coastguard Worker ll.extend( 1087*da0073e9SAndroid Build Coastguard Worker [ 1088*da0073e9SAndroid Build Coastguard Worker b + 2, 1089*da0073e9SAndroid Build Coastguard Worker a + b, 1090*da0073e9SAndroid Build Coastguard Worker ] 1091*da0073e9SAndroid Build Coastguard Worker ) 1092*da0073e9SAndroid Build Coastguard Worker ll.pop(-1) 1093*da0073e9SAndroid Build Coastguard Worker ll.pop(0) 1094*da0073e9SAndroid Build Coastguard Worker ll.pop() 1095*da0073e9SAndroid Build Coastguard Worker v1, v2 = ll 1096*da0073e9SAndroid Build Coastguard Worker return v1 - v2 1097*da0073e9SAndroid Build Coastguard Worker 1098*da0073e9SAndroid Build Coastguard Worker @make_test 1099*da0073e9SAndroid Build Coastguard Worker def test_list_convert(a, b): 1100*da0073e9SAndroid Build Coastguard Worker ll = [a + 2, b] 1101*da0073e9SAndroid Build Coastguard Worker ll = tuple(ll) 1102*da0073e9SAndroid Build Coastguard Worker tmp = b + 3 1103*da0073e9SAndroid Build Coastguard Worker ll = list(ll) 1104*da0073e9SAndroid Build Coastguard Worker v1, v2 = ll 1105*da0073e9SAndroid Build Coastguard Worker return v1 - v2 + tmp 1106*da0073e9SAndroid Build Coastguard Worker 1107*da0073e9SAndroid Build Coastguard Worker @make_test 1108*da0073e9SAndroid Build Coastguard Worker def test_list_add(a, b): 1109*da0073e9SAndroid Build Coastguard Worker l1 = (a, b) 1110*da0073e9SAndroid Build Coastguard 
Worker l2 = () # being a LOAD_CONST in the bytecode 1111*da0073e9SAndroid Build Coastguard Worker l3 = l1 + l2 1112*da0073e9SAndroid Build Coastguard Worker return l3[0] + l3[1] 1113*da0073e9SAndroid Build Coastguard Worker 1114*da0073e9SAndroid Build Coastguard Worker @make_test 1115*da0073e9SAndroid Build Coastguard Worker def test_list_index_with_constant_tensor(a, b): 1116*da0073e9SAndroid Build Coastguard Worker l1 = [a, b, a + 1, b + 1] 1117*da0073e9SAndroid Build Coastguard Worker return l1[torch.as_tensor(2)] 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker @make_test 1120*da0073e9SAndroid Build Coastguard Worker def test_startswith(a, b): 1121*da0073e9SAndroid Build Coastguard Worker x = a + b 1122*da0073e9SAndroid Build Coastguard Worker if "foobar".startswith("foo") and "test" in constant3.__module__: 1123*da0073e9SAndroid Build Coastguard Worker x = x + 1 1124*da0073e9SAndroid Build Coastguard Worker return x 1125*da0073e9SAndroid Build Coastguard Worker 1126*da0073e9SAndroid Build Coastguard Worker @make_test 1127*da0073e9SAndroid Build Coastguard Worker def test_dict_ops(a, b): 1128*da0073e9SAndroid Build Coastguard Worker tmp = {"a": a + 1, "b": b + 2} 1129*da0073e9SAndroid Build Coastguard Worker assert tmp.get("zzz") is None 1130*da0073e9SAndroid Build Coastguard Worker v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4) 1131*da0073e9SAndroid Build Coastguard Worker tmp.update({"d": 3}) 1132*da0073e9SAndroid Build Coastguard Worker tmp["c"] = v + tmp["d"] 1133*da0073e9SAndroid Build Coastguard Worker if "c" in tmp and "missing" not in tmp: 1134*da0073e9SAndroid Build Coastguard Worker return tmp["c"] - tmp["a"] + len(tmp) 1135*da0073e9SAndroid Build Coastguard Worker 1136*da0073e9SAndroid Build Coastguard Worker @make_test 1137*da0073e9SAndroid Build Coastguard Worker def test_inline_jit__unwrap_optional(x): 1138*da0073e9SAndroid Build Coastguard Worker if 
torch.jit._unwrap_optional(x) is None: 1139*da0073e9SAndroid Build Coastguard Worker return torch.ones(2, 2) 1140*da0073e9SAndroid Build Coastguard Worker return x.sin() 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker @make_test 1143*da0073e9SAndroid Build Coastguard Worker def test_zip_longest(x): 1144*da0073e9SAndroid Build Coastguard Worker list1 = [1, 2, 3] 1145*da0073e9SAndroid Build Coastguard Worker list2 = ["a", "b"] 1146*da0073e9SAndroid Build Coastguard Worker list3 = [True, False, True, False] 1147*da0073e9SAndroid Build Coastguard Worker return torch.sin(x + 1), list( 1148*da0073e9SAndroid Build Coastguard Worker itertools.zip_longest(list1, list2, list3, fillvalue=None) 1149*da0073e9SAndroid Build Coastguard Worker ) 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Worker def test_torch_size_as_dict_key(self): 1152*da0073e9SAndroid Build Coastguard Worker def fn(x, cached): 1153*da0073e9SAndroid Build Coastguard Worker if x.shape not in cached: 1154*da0073e9SAndroid Build Coastguard Worker cached[x.shape] = x 1155*da0073e9SAndroid Build Coastguard Worker return x + cached[x.shape] 1156*da0073e9SAndroid Build Coastguard Worker 1157*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1158*da0073e9SAndroid Build Coastguard Worker x1 = torch.randn(2, 3) 1159*da0073e9SAndroid Build Coastguard Worker x2 = torch.randn(2, 3) 1160*da0073e9SAndroid Build Coastguard Worker cached = {} 1161*da0073e9SAndroid Build Coastguard Worker ref1 = fn(x1, cached) 1162*da0073e9SAndroid Build Coastguard Worker ref2 = fn(x2, cached) 1163*da0073e9SAndroid Build Coastguard Worker cached = {} 1164*da0073e9SAndroid Build Coastguard Worker res1 = opt_fn(x1, cached) 1165*da0073e9SAndroid Build Coastguard Worker res2 = opt_fn(x2, cached) 1166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref1, res1) 1167*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(ref2, res2) 1168*da0073e9SAndroid Build Coastguard Worker 1169*da0073e9SAndroid Build Coastguard Worker def test_dict_param_keys(self): 1170*da0073e9SAndroid Build Coastguard Worker a_param = torch.nn.Parameter(torch.ones([4, 4])) 1171*da0073e9SAndroid Build Coastguard Worker 1172*da0073e9SAndroid Build Coastguard Worker def fn(a): 1173*da0073e9SAndroid Build Coastguard Worker tmp = {"a": a, a_param: 3} 1174*da0073e9SAndroid Build Coastguard Worker return tmp["a"] + tmp[a_param] 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker test = make_test(fn) 1177*da0073e9SAndroid Build Coastguard Worker test(self) 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build Coastguard Worker def test_dict_mutable_map(self): 1180*da0073e9SAndroid Build Coastguard Worker from collections.abc import MutableMapping 1181*da0073e9SAndroid Build Coastguard Worker 1182*da0073e9SAndroid Build Coastguard Worker class TensorDict(MutableMapping): 1183*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1184*da0073e9SAndroid Build Coastguard Worker self._dict = {} 1185*da0073e9SAndroid Build Coastguard Worker 1186*da0073e9SAndroid Build Coastguard Worker def add(self, key, value): 1187*da0073e9SAndroid Build Coastguard Worker self._dict[key] = value 1188*da0073e9SAndroid Build Coastguard Worker 1189*da0073e9SAndroid Build Coastguard Worker def items(self): 1190*da0073e9SAndroid Build Coastguard Worker return self._dict.items() 1191*da0073e9SAndroid Build Coastguard Worker 1192*da0073e9SAndroid Build Coastguard Worker def __delitem__(self, key): 1193*da0073e9SAndroid Build Coastguard Worker del self._dict[key] 1194*da0073e9SAndroid Build Coastguard Worker 1195*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, key): 1196*da0073e9SAndroid Build Coastguard Worker return self._dict[key] 1197*da0073e9SAndroid Build Coastguard Worker 1198*da0073e9SAndroid Build Coastguard Worker def 
__iter__(self): 1199*da0073e9SAndroid Build Coastguard Worker return iter(self._dict) 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker def __len__(self): 1202*da0073e9SAndroid Build Coastguard Worker return len(self._dict) 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Worker def __setitem__(self, key, value): 1205*da0073e9SAndroid Build Coastguard Worker self._dict[key] = value 1206*da0073e9SAndroid Build Coastguard Worker 1207*da0073e9SAndroid Build Coastguard Worker tensor_dict = TensorDict() 1208*da0073e9SAndroid Build Coastguard Worker tensor_dict.add("a", torch.ones(4) * 2) 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Worker def fn(x): 1211*da0073e9SAndroid Build Coastguard Worker copy_tensordict = dict(tensor_dict) 1212*da0073e9SAndroid Build Coastguard Worker return x * copy_tensordict["a"] 1213*da0073e9SAndroid Build Coastguard Worker 1214*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1215*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 1216*da0073e9SAndroid Build Coastguard Worker 1217*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 1218*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 1219*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker def test_unpack_mutable_map(self): 1222*da0073e9SAndroid Build Coastguard Worker from collections.abc import MutableMapping 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Worker class TensorDict(MutableMapping): 1225*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1226*da0073e9SAndroid Build Coastguard Worker self._dict = {} 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker def add(self, key, value): 
1229*da0073e9SAndroid Build Coastguard Worker self._dict[key] = value 1230*da0073e9SAndroid Build Coastguard Worker 1231*da0073e9SAndroid Build Coastguard Worker def items(self): 1232*da0073e9SAndroid Build Coastguard Worker return self._dict.items() 1233*da0073e9SAndroid Build Coastguard Worker 1234*da0073e9SAndroid Build Coastguard Worker def __delitem__(self, key): 1235*da0073e9SAndroid Build Coastguard Worker del self._dict[key] 1236*da0073e9SAndroid Build Coastguard Worker 1237*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, key): 1238*da0073e9SAndroid Build Coastguard Worker return self._dict[key] 1239*da0073e9SAndroid Build Coastguard Worker 1240*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 1241*da0073e9SAndroid Build Coastguard Worker return iter(self._dict) 1242*da0073e9SAndroid Build Coastguard Worker 1243*da0073e9SAndroid Build Coastguard Worker def __len__(self): 1244*da0073e9SAndroid Build Coastguard Worker return len(self._dict) 1245*da0073e9SAndroid Build Coastguard Worker 1246*da0073e9SAndroid Build Coastguard Worker def __setitem__(self, key, value): 1247*da0073e9SAndroid Build Coastguard Worker self._dict[key] = value 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker tensor_dict = TensorDict() 1250*da0073e9SAndroid Build Coastguard Worker tensor_dict.add("a", torch.ones(4) * 2) 1251*da0073e9SAndroid Build Coastguard Worker 1252*da0073e9SAndroid Build Coastguard Worker def gn(x, a=1): 1253*da0073e9SAndroid Build Coastguard Worker return x * a 1254*da0073e9SAndroid Build Coastguard Worker 1255*da0073e9SAndroid Build Coastguard Worker def fn(x): 1256*da0073e9SAndroid Build Coastguard Worker return gn(x, **tensor_dict) 1257*da0073e9SAndroid Build Coastguard Worker 1258*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1259*da0073e9SAndroid Build Coastguard Worker 1260*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 
1261*da0073e9SAndroid Build Coastguard Worker 1262*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 1263*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 1264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Worker def _test_default_dict_helper(self, factory): 1267*da0073e9SAndroid Build Coastguard Worker dd = collections.defaultdict(factory) 1268*da0073e9SAndroid Build Coastguard Worker param = torch.nn.Parameter(torch.ones([2, 2])) 1269*da0073e9SAndroid Build Coastguard Worker 1270*da0073e9SAndroid Build Coastguard Worker def fn(x): 1271*da0073e9SAndroid Build Coastguard Worker dd["a"] = x + 1 1272*da0073e9SAndroid Build Coastguard Worker dd[param] = 123 1273*da0073e9SAndroid Build Coastguard Worker dd["c"] = x * 2 1274*da0073e9SAndroid Build Coastguard Worker return dd["b"], dd 1275*da0073e9SAndroid Build Coastguard Worker 1276*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 1277*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 1278*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize_assert("eager")(fn) 1279*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 1280*da0073e9SAndroid Build Coastguard Worker 1281*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[0], res[0])) 1282*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1]["a"], res[1]["a"])) 1283*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1]["c"], res[1]["c"])) 1284*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1][param], res[1][param])) 1285*da0073e9SAndroid Build Coastguard Worker 1286*da0073e9SAndroid Build Coastguard Worker def test_default_dict_dict(self): 1287*da0073e9SAndroid Build Coastguard Worker self._test_default_dict_helper(dict) 1288*da0073e9SAndroid Build Coastguard Worker 1289*da0073e9SAndroid Build Coastguard Worker def test_default_dict_list(self): 
1290*da0073e9SAndroid Build Coastguard Worker self._test_default_dict_helper(list) 1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker def test_default_dict_tuple(self): 1293*da0073e9SAndroid Build Coastguard Worker self._test_default_dict_helper(tuple) 1294*da0073e9SAndroid Build Coastguard Worker 1295*da0073e9SAndroid Build Coastguard Worker def test_default_dict_set(self): 1296*da0073e9SAndroid Build Coastguard Worker self._test_default_dict_helper(set) 1297*da0073e9SAndroid Build Coastguard Worker 1298*da0073e9SAndroid Build Coastguard Worker def test_default_dict_lambda(self): 1299*da0073e9SAndroid Build Coastguard Worker self._test_default_dict_helper(lambda: dict()) # noqa: C408 1300*da0073e9SAndroid Build Coastguard Worker 1301*da0073e9SAndroid Build Coastguard Worker def test_default_dict_closure(self): 1302*da0073e9SAndroid Build Coastguard Worker def factory(): 1303*da0073e9SAndroid Build Coastguard Worker return dict() # noqa: C408 1304*da0073e9SAndroid Build Coastguard Worker 1305*da0073e9SAndroid Build Coastguard Worker self._test_default_dict_helper(factory) 1306*da0073e9SAndroid Build Coastguard Worker 1307*da0073e9SAndroid Build Coastguard Worker def test_class_dict(self): 1308*da0073e9SAndroid Build Coastguard Worker class A: 1309*da0073e9SAndroid Build Coastguard Worker x = 4 1310*da0073e9SAndroid Build Coastguard Worker y = 5 1311*da0073e9SAndroid Build Coastguard Worker 1312*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1313*da0073e9SAndroid Build Coastguard Worker self.a = 6 1314*da0073e9SAndroid Build Coastguard Worker 1315*da0073e9SAndroid Build Coastguard Worker a = A() 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker def fn(x): 1318*da0073e9SAndroid Build Coastguard Worker if "x" in type(a).__dict__: 1319*da0073e9SAndroid Build Coastguard Worker return x + 1 1320*da0073e9SAndroid Build Coastguard Worker return x + 2 
1321*da0073e9SAndroid Build Coastguard Worker 1322*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1323*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 1324*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), opt_fn(x)) 1325*da0073e9SAndroid Build Coastguard Worker 1326*da0073e9SAndroid Build Coastguard Worker def test_default_dict_constr(self): 1327*da0073e9SAndroid Build Coastguard Worker param = torch.nn.Parameter(torch.ones([2, 2])) 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker def fn(x): 1330*da0073e9SAndroid Build Coastguard Worker dd = collections.defaultdict(lambda: dict()) # noqa: C408 1331*da0073e9SAndroid Build Coastguard Worker dd["a"] = x + 1 1332*da0073e9SAndroid Build Coastguard Worker dd[param] = 123 1333*da0073e9SAndroid Build Coastguard Worker dd["c"] = x * 2 1334*da0073e9SAndroid Build Coastguard Worker dd.update({"b": x * 3}) 1335*da0073e9SAndroid Build Coastguard Worker dd.update([["d", x - 2], ("e", x + 2)]) 1336*da0073e9SAndroid Build Coastguard Worker dd.update(zip("ab", [x + 3, x + 4])) 1337*da0073e9SAndroid Build Coastguard Worker return dd["b"], dd 1338*da0073e9SAndroid Build Coastguard Worker 1339*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 10) 1340*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 1341*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize_assert("eager")(fn) 1342*da0073e9SAndroid Build Coastguard Worker res = opt_fn(x) 1343*da0073e9SAndroid Build Coastguard Worker 1344*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[0], res[0])) 1345*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1]["a"], res[1]["a"])) 1346*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1]["b"], res[1]["b"])) 1347*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1]["c"], res[1]["c"])) 1348*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue(same(ref[1]["d"], res[1]["d"])) 1349*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1]["e"], res[1]["e"])) 1350*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref[1][param], res[1][param])) 1351*da0073e9SAndroid Build Coastguard Worker 1352*da0073e9SAndroid Build Coastguard Worker def test_dict_tuple_lazy_guard(self): 1353*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 1354*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 1355*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) * y[1] 1356*da0073e9SAndroid Build Coastguard Worker 1357*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(3), {1: 1, 2: 2}) 1358*da0073e9SAndroid Build Coastguard Worker # Changing the value of other key should not causing recompilation 1359*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 1360*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(3), {1: 1, 2: 3}) 1361*da0073e9SAndroid Build Coastguard Worker 1362*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(3), (1, 2, 3)) 1363*da0073e9SAndroid Build Coastguard Worker # Changing the value of index 0, 2 (not 1) should not cause recompilation 1364*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 1365*da0073e9SAndroid Build Coastguard Worker fn(torch.randn(3), (11, 2, 13)) 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker @make_test 1368*da0073e9SAndroid Build Coastguard Worker def test_call_dict1(x): 1369*da0073e9SAndroid Build Coastguard Worker d1 = dict() # noqa: C408 1370*da0073e9SAndroid Build Coastguard Worker d1["x"] = x + 1 1371*da0073e9SAndroid Build Coastguard Worker d2 = collections.OrderedDict() 1372*da0073e9SAndroid Build Coastguard Worker d2["x"] = x + 2 1373*da0073e9SAndroid Build Coastguard Worker return d1["x"] + d2["x"] + 1 
1374*da0073e9SAndroid Build Coastguard Worker 1375*da0073e9SAndroid Build Coastguard Worker @make_test 1376*da0073e9SAndroid Build Coastguard Worker def test_call_dict2(x): 1377*da0073e9SAndroid Build Coastguard Worker d1 = dict() # noqa: C408 1378*da0073e9SAndroid Build Coastguard Worker d1["x"] = x 1379*da0073e9SAndroid Build Coastguard Worker d2 = collections.OrderedDict(d1) 1380*da0073e9SAndroid Build Coastguard Worker if isinstance(d2, collections.OrderedDict): 1381*da0073e9SAndroid Build Coastguard Worker return x + 1 1382*da0073e9SAndroid Build Coastguard Worker else: 1383*da0073e9SAndroid Build Coastguard Worker return x - 1 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker @make_test 1386*da0073e9SAndroid Build Coastguard Worker def test_call_dict3(x): 1387*da0073e9SAndroid Build Coastguard Worker my_list = [("a", x), ("b", x + 1), ("c", x + 2)] 1388*da0073e9SAndroid Build Coastguard Worker d1 = dict(my_list) 1389*da0073e9SAndroid Build Coastguard Worker d1["a"] = x + 10 1390*da0073e9SAndroid Build Coastguard Worker d2 = collections.OrderedDict(my_list) 1391*da0073e9SAndroid Build Coastguard Worker d2["c"] = x + 20 1392*da0073e9SAndroid Build Coastguard Worker return d1["a"] + d2["c"] + 1 1393*da0073e9SAndroid Build Coastguard Worker 1394*da0073e9SAndroid Build Coastguard Worker @make_test 1395*da0073e9SAndroid Build Coastguard Worker def test_call_dict4(x): 1396*da0073e9SAndroid Build Coastguard Worker my_list = (("a", x), ("b", x + 1), ("c", x + 2)) 1397*da0073e9SAndroid Build Coastguard Worker d1 = dict(my_list) 1398*da0073e9SAndroid Build Coastguard Worker d1["a"] = x + 10 1399*da0073e9SAndroid Build Coastguard Worker d2 = collections.OrderedDict(my_list) 1400*da0073e9SAndroid Build Coastguard Worker d2["c"] = x + 20 1401*da0073e9SAndroid Build Coastguard Worker return d1["a"] + d2["c"] + 1 1402*da0073e9SAndroid Build Coastguard Worker 1403*da0073e9SAndroid Build Coastguard Worker @make_test 
1404*da0073e9SAndroid Build Coastguard Worker def test_call_dict5(x): 1405*da0073e9SAndroid Build Coastguard Worker my_list = iter([("a", x), ("b", x + 1), ("c", x + 2)]) 1406*da0073e9SAndroid Build Coastguard Worker d1 = dict(my_list) 1407*da0073e9SAndroid Build Coastguard Worker d1["a"] = x + 10 1408*da0073e9SAndroid Build Coastguard Worker d2 = collections.OrderedDict(my_list) 1409*da0073e9SAndroid Build Coastguard Worker d2["c"] = x + 20 1410*da0073e9SAndroid Build Coastguard Worker return d1["a"] + d2["c"] + 1 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker @make_test 1413*da0073e9SAndroid Build Coastguard Worker def test_dict_fromkeys(x, y): 1414*da0073e9SAndroid Build Coastguard Worker lst = ["a", "b"] 1415*da0073e9SAndroid Build Coastguard Worker d = dict.fromkeys(lst) 1416*da0073e9SAndroid Build Coastguard Worker d1 = dict.fromkeys(d, x + 1) 1417*da0073e9SAndroid Build Coastguard Worker d2 = collections.defaultdict.fromkeys(iter(d1), x - 2) 1418*da0073e9SAndroid Build Coastguard Worker d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y) 1419*da0073e9SAndroid Build Coastguard Worker return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Worker @make_test 1422*da0073e9SAndroid Build Coastguard Worker def test_dict_copy(x): 1423*da0073e9SAndroid Build Coastguard Worker my_list = [("a", x), ("b", x + 1), ("c", x + 2)] 1424*da0073e9SAndroid Build Coastguard Worker d1 = dict(my_list) 1425*da0073e9SAndroid Build Coastguard Worker d1["a"] = x + 10 1426*da0073e9SAndroid Build Coastguard Worker d2 = d1.copy() 1427*da0073e9SAndroid Build Coastguard Worker d2["a"] = x - 5 1428*da0073e9SAndroid Build Coastguard Worker d2["b"] = x + 3 1429*da0073e9SAndroid Build Coastguard Worker d3 = collections.OrderedDict(my_list) 1430*da0073e9SAndroid Build Coastguard Worker d3["c"] = x + 20 1431*da0073e9SAndroid Build Coastguard 
Worker d4 = d3.copy() 1432*da0073e9SAndroid Build Coastguard Worker d4["c"] = x - 10 1433*da0073e9SAndroid Build Coastguard Worker return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1 1434*da0073e9SAndroid Build Coastguard Worker 1435*da0073e9SAndroid Build Coastguard Worker @make_test 1436*da0073e9SAndroid Build Coastguard Worker def test_dict_update(x, y, z): 1437*da0073e9SAndroid Build Coastguard Worker d = {"a": x, "b": y} 1438*da0073e9SAndroid Build Coastguard Worker d.update({"a": y - 1}) 1439*da0073e9SAndroid Build Coastguard Worker d.update([("b", z + 1), ["c", z]]) 1440*da0073e9SAndroid Build Coastguard Worker d.update(zip("ab", [z + 3, y + 2])) 1441*da0073e9SAndroid Build Coastguard Worker 1442*da0073e9SAndroid Build Coastguard Worker od = collections.OrderedDict(a=x * 3, b=y + 2) 1443*da0073e9SAndroid Build Coastguard Worker od.update({"a": y + 5}) 1444*da0073e9SAndroid Build Coastguard Worker od.update([["b", z + 6], ("c", z - 7)]) 1445*da0073e9SAndroid Build Coastguard Worker od.update(zip("ab", [z - 3, x + 2])) 1446*da0073e9SAndroid Build Coastguard Worker return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"] 1447*da0073e9SAndroid Build Coastguard Worker 1448*da0073e9SAndroid Build Coastguard Worker @make_test 1449*da0073e9SAndroid Build Coastguard Worker def test_min_max(a, b): 1450*da0073e9SAndroid Build Coastguard Worker c = a + b 1451*da0073e9SAndroid Build Coastguard Worker a = a.sum() 1452*da0073e9SAndroid Build Coastguard Worker b = b.sum() 1453*da0073e9SAndroid Build Coastguard Worker a = min(max(a, 0), 1) 1454*da0073e9SAndroid Build Coastguard Worker b = max(0, min(1, b)) 1455*da0073e9SAndroid Build Coastguard Worker return max(a, b) - min(a, b) + c 1456*da0073e9SAndroid Build Coastguard Worker 1457*da0073e9SAndroid Build Coastguard Worker @make_test 1458*da0073e9SAndroid Build Coastguard Worker def test_symbool_to_int(x): 1459*da0073e9SAndroid Build Coastguard Worker # this is roughly the pattern found in einops.unpack() 
1460*da0073e9SAndroid Build Coastguard Worker if sum(s == -1 for s in x.size()) == 0: 1461*da0073e9SAndroid Build Coastguard Worker return x + 1 1462*da0073e9SAndroid Build Coastguard Worker else: 1463*da0073e9SAndroid Build Coastguard Worker return x - 1 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker @make_test 1466*da0073e9SAndroid Build Coastguard Worker def test_map_sum(a, b, c, d): 1467*da0073e9SAndroid Build Coastguard Worker return sum(map(lambda x: x + 1, [a, b, c, d])) 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker @make_test 1470*da0073e9SAndroid Build Coastguard Worker def test_sum(a, b, c, d): 1471*da0073e9SAndroid Build Coastguard Worker return sum([a, b, c, d]) 1472*da0073e9SAndroid Build Coastguard Worker 1473*da0073e9SAndroid Build Coastguard Worker @make_test 1474*da0073e9SAndroid Build Coastguard Worker def test_sum_with_start_arg(a, b, c, d): 1475*da0073e9SAndroid Build Coastguard Worker return sum([b, c, d], a) 1476*da0073e9SAndroid Build Coastguard Worker 1477*da0073e9SAndroid Build Coastguard Worker @make_test 1478*da0073e9SAndroid Build Coastguard Worker def test_sum_with_start_kwarg(a, b, c, d): 1479*da0073e9SAndroid Build Coastguard Worker return sum([b, c, d], start=a) 1480*da0073e9SAndroid Build Coastguard Worker 1481*da0073e9SAndroid Build Coastguard Worker @make_test(expected_frame_count=0) 1482*da0073e9SAndroid Build Coastguard Worker def test_sum_shortcut(): 1483*da0073e9SAndroid Build Coastguard Worker return sum([0, 1.0, 2, 3.0]) 1484*da0073e9SAndroid Build Coastguard Worker 1485*da0073e9SAndroid Build Coastguard Worker @make_test(expected_frame_count=0) 1486*da0073e9SAndroid Build Coastguard Worker def test_sum_shortcut_with_start_arg(): 1487*da0073e9SAndroid Build Coastguard Worker return sum([0, 1.0, 2, 3.0], -10) 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Worker @make_test(expected_frame_count=0) 
1490*da0073e9SAndroid Build Coastguard Worker def test_sum_shortcut_with_start_kwarg(): 1491*da0073e9SAndroid Build Coastguard Worker return sum([0, 1.0, 2, 3.0], start=-10) 1492*da0073e9SAndroid Build Coastguard Worker 1493*da0073e9SAndroid Build Coastguard Worker @make_test 1494*da0073e9SAndroid Build Coastguard Worker def test_reduce(a, b, c, d): 1495*da0073e9SAndroid Build Coastguard Worker return functools.reduce(operator.add, [a, b, c, d]) 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker @make_test 1498*da0073e9SAndroid Build Coastguard Worker def test_reduce_with_initial(a, b, c, d): 1499*da0073e9SAndroid Build Coastguard Worker return functools.reduce(operator.add, [b, c, d], a) 1500*da0073e9SAndroid Build Coastguard Worker 1501*da0073e9SAndroid Build Coastguard Worker @make_test(expected_frame_count=0) 1502*da0073e9SAndroid Build Coastguard Worker def test_reduce_with_single(x): 1503*da0073e9SAndroid Build Coastguard Worker return functools.reduce(lambda a, b: (a, b), [x]) 1504*da0073e9SAndroid Build Coastguard Worker 1505*da0073e9SAndroid Build Coastguard Worker @make_test(expected_frame_count=0) 1506*da0073e9SAndroid Build Coastguard Worker def test_reduce_with_single_with_initial(x, y): 1507*da0073e9SAndroid Build Coastguard Worker return functools.reduce(lambda a, b: (a, b), [y], x) 1508*da0073e9SAndroid Build Coastguard Worker 1509*da0073e9SAndroid Build Coastguard Worker @make_test(expected_frame_count=0) 1510*da0073e9SAndroid Build Coastguard Worker def test_reduce_with_none_initial(x): 1511*da0073e9SAndroid Build Coastguard Worker return functools.reduce(lambda a, b: (a, b), [x], None) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker @make_test 1514*da0073e9SAndroid Build Coastguard Worker def test_tuple_contains(a, b): 1515*da0073e9SAndroid Build Coastguard Worker v1 = "a" 1516*da0073e9SAndroid Build Coastguard Worker v2 = "b" 1517*da0073e9SAndroid Build 
Coastguard Worker v3 = "c" 1518*da0073e9SAndroid Build Coastguard Worker vals1 = (v1, v2, v3) 1519*da0073e9SAndroid Build Coastguard Worker vals2 = ("d", "e", "f") 1520*da0073e9SAndroid Build Coastguard Worker if "a" in vals1 and "b" not in vals2: 1521*da0073e9SAndroid Build Coastguard Worker return a + b 1522*da0073e9SAndroid Build Coastguard Worker return a - b 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1525*da0073e9SAndroid Build Coastguard Worker sys.version_info < (3, 9), 1526*da0073e9SAndroid Build Coastguard Worker "SET_UPDATE was added at Python 3.9", 1527*da0073e9SAndroid Build Coastguard Worker ) 1528*da0073e9SAndroid Build Coastguard Worker @make_test 1529*da0073e9SAndroid Build Coastguard Worker def test_set_update_bytecode(x): 1530*da0073e9SAndroid Build Coastguard Worker # This produces bytecode SET_UPDATE since python 3.9 1531*da0073e9SAndroid Build Coastguard Worker var = {"apple", "banana", "cherry"} 1532*da0073e9SAndroid Build Coastguard Worker if isinstance(var, set): 1533*da0073e9SAndroid Build Coastguard Worker return x + 1 1534*da0073e9SAndroid Build Coastguard Worker else: 1535*da0073e9SAndroid Build Coastguard Worker return x - 1 1536*da0073e9SAndroid Build Coastguard Worker 1537*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1538*da0073e9SAndroid Build Coastguard Worker sys.version_info < (3, 9), 1539*da0073e9SAndroid Build Coastguard Worker "SET_UPDATE was added at Python 3.9", 1540*da0073e9SAndroid Build Coastguard Worker ) 1541*da0073e9SAndroid Build Coastguard Worker @make_test 1542*da0073e9SAndroid Build Coastguard Worker def test_set_update_list_with_duplicated_items(x): 1543*da0073e9SAndroid Build Coastguard Worker list1 = ["apple", "banana", "apple"] 1544*da0073e9SAndroid Build Coastguard Worker list2 = ["orange", "banana"] 1545*da0073e9SAndroid Build Coastguard Worker if len({*list1, *list2}) == 3: 1546*da0073e9SAndroid Build Coastguard Worker return x + 
1 1547*da0073e9SAndroid Build Coastguard Worker else: 1548*da0073e9SAndroid Build Coastguard Worker return x - 1 1549*da0073e9SAndroid Build Coastguard Worker 1550*da0073e9SAndroid Build Coastguard Worker @make_test 1551*da0073e9SAndroid Build Coastguard Worker def test_set_contains(a, b): 1552*da0073e9SAndroid Build Coastguard Worker vals = set(["a", "b", "c"]) 1553*da0073e9SAndroid Build Coastguard Worker if "a" in vals: 1554*da0073e9SAndroid Build Coastguard Worker x = a + b 1555*da0073e9SAndroid Build Coastguard Worker else: 1556*da0073e9SAndroid Build Coastguard Worker x = a - b 1557*da0073e9SAndroid Build Coastguard Worker if "d" in vals: 1558*da0073e9SAndroid Build Coastguard Worker y = a + b 1559*da0073e9SAndroid Build Coastguard Worker else: 1560*da0073e9SAndroid Build Coastguard Worker y = a - b 1561*da0073e9SAndroid Build Coastguard Worker return x, y 1562*da0073e9SAndroid Build Coastguard Worker 1563*da0073e9SAndroid Build Coastguard Worker def test_set_isdisjoint(self): 1564*da0073e9SAndroid Build Coastguard Worker x = {"apple", "banana", "cherry"} 1565*da0073e9SAndroid Build Coastguard Worker y = {"google", "microsoft", "apple"} 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker def fn(a): 1568*da0073e9SAndroid Build Coastguard Worker if x.isdisjoint(y): 1569*da0073e9SAndroid Build Coastguard Worker return a + 1 1570*da0073e9SAndroid Build Coastguard Worker else: 1571*da0073e9SAndroid Build Coastguard Worker return a - 1 1572*da0073e9SAndroid Build Coastguard Worker 1573*da0073e9SAndroid Build Coastguard Worker test = make_test(fn) 1574*da0073e9SAndroid Build Coastguard Worker test(self) 1575*da0073e9SAndroid Build Coastguard Worker 1576*da0073e9SAndroid Build Coastguard Worker @make_test 1577*da0073e9SAndroid Build Coastguard Worker def test_set_intersection(a, b): 1578*da0073e9SAndroid Build Coastguard Worker set1 = {"apple", "banana", "cherry"} 1579*da0073e9SAndroid Build Coastguard Worker set2 = {"google", 
"microsoft", "apple"} 1580*da0073e9SAndroid Build Coastguard Worker intersection_set = set1.intersection(set2) 1581*da0073e9SAndroid Build Coastguard Worker if "apple" in intersection_set: 1582*da0073e9SAndroid Build Coastguard Worker x = a + b 1583*da0073e9SAndroid Build Coastguard Worker else: 1584*da0073e9SAndroid Build Coastguard Worker x = a - b 1585*da0073e9SAndroid Build Coastguard Worker if "banana" in intersection_set: 1586*da0073e9SAndroid Build Coastguard Worker y = a + b 1587*da0073e9SAndroid Build Coastguard Worker else: 1588*da0073e9SAndroid Build Coastguard Worker y = a - b 1589*da0073e9SAndroid Build Coastguard Worker return x, y 1590*da0073e9SAndroid Build Coastguard Worker 1591*da0073e9SAndroid Build Coastguard Worker @make_test 1592*da0073e9SAndroid Build Coastguard Worker def test_set_union(a, b): 1593*da0073e9SAndroid Build Coastguard Worker set1 = {"apple", "banana", "cherry"} 1594*da0073e9SAndroid Build Coastguard Worker set2 = {"google", "microsoft", "apple"} 1595*da0073e9SAndroid Build Coastguard Worker union_set = set1.union(set2) 1596*da0073e9SAndroid Build Coastguard Worker if "apple" in union_set: 1597*da0073e9SAndroid Build Coastguard Worker x = a + b 1598*da0073e9SAndroid Build Coastguard Worker else: 1599*da0073e9SAndroid Build Coastguard Worker x = a - b 1600*da0073e9SAndroid Build Coastguard Worker if "banana" in union_set: 1601*da0073e9SAndroid Build Coastguard Worker y = a + b 1602*da0073e9SAndroid Build Coastguard Worker else: 1603*da0073e9SAndroid Build Coastguard Worker y = a - b 1604*da0073e9SAndroid Build Coastguard Worker return x, y 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker @make_test 1607*da0073e9SAndroid Build Coastguard Worker def test_set_difference(a, b): 1608*da0073e9SAndroid Build Coastguard Worker set1 = {"apple", "banana", "cherry"} 1609*da0073e9SAndroid Build Coastguard Worker set2 = {"google", "microsoft", "apple"} 1610*da0073e9SAndroid Build Coastguard Worker 
difference_set = set1.difference(set2) 1611*da0073e9SAndroid Build Coastguard Worker if "apple" in difference_set: 1612*da0073e9SAndroid Build Coastguard Worker x = a + b 1613*da0073e9SAndroid Build Coastguard Worker else: 1614*da0073e9SAndroid Build Coastguard Worker x = a - b 1615*da0073e9SAndroid Build Coastguard Worker if "banana" in difference_set: 1616*da0073e9SAndroid Build Coastguard Worker y = a + b 1617*da0073e9SAndroid Build Coastguard Worker else: 1618*da0073e9SAndroid Build Coastguard Worker y = a - b 1619*da0073e9SAndroid Build Coastguard Worker return x, y 1620*da0073e9SAndroid Build Coastguard Worker 1621*da0073e9SAndroid Build Coastguard Worker def test_set_keys_view(self): 1622*da0073e9SAndroid Build Coastguard Worker from collections.abc import KeysView 1623*da0073e9SAndroid Build Coastguard Worker 1624*da0073e9SAndroid Build Coastguard Worker class StringKeys(KeysView): 1625*da0073e9SAndroid Build Coastguard Worker def __init__(self, keys): 1626*da0073e9SAndroid Build Coastguard Worker self.keys = keys 1627*da0073e9SAndroid Build Coastguard Worker 1628*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, key): 1629*da0073e9SAndroid Build Coastguard Worker return self.keys.__getitem__(key) 1630*da0073e9SAndroid Build Coastguard Worker 1631*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 1632*da0073e9SAndroid Build Coastguard Worker yield from self.keys 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 1635*da0073e9SAndroid Build Coastguard Worker return f"{type(self).__name__}({self.keys})" 1636*da0073e9SAndroid Build Coastguard Worker 1637*da0073e9SAndroid Build Coastguard Worker def __len__(self): 1638*da0073e9SAndroid Build Coastguard Worker return len(self.keys) 1639*da0073e9SAndroid Build Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker def __contains__(self, item): 1641*da0073e9SAndroid Build Coastguard Worker return 
self.keys.__contains__(item) 1642*da0073e9SAndroid Build Coastguard Worker 1643*da0073e9SAndroid Build Coastguard Worker a = StringKeys([1, 2, 3, 3]) 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker def fn(x): 1646*da0073e9SAndroid Build Coastguard Worker set_a = set(a) 1647*da0073e9SAndroid Build Coastguard Worker return len(set_a) * x 1648*da0073e9SAndroid Build Coastguard Worker 1649*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1650*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 1651*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), opt_fn(x)) 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker def test_constant_set(self): 1654*da0073e9SAndroid Build Coastguard Worker s = set([1, 2]) 1655*da0073e9SAndroid Build Coastguard Worker 1656*da0073e9SAndroid Build Coastguard Worker def fn(x): 1657*da0073e9SAndroid Build Coastguard Worker return torch.cos(x) * len(s) 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1660*da0073e9SAndroid Build Coastguard Worker 1661*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 1662*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), opt_fn(x)) 1663*da0073e9SAndroid Build Coastguard Worker 1664*da0073e9SAndroid Build Coastguard Worker # This should cause recompilation 1665*da0073e9SAndroid Build Coastguard Worker s.add(3) 1666*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), opt_fn(x)) 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Worker def test_set_add(self): 1669*da0073e9SAndroid Build Coastguard Worker s = set([1, 2]) 1670*da0073e9SAndroid Build Coastguard Worker 1671*da0073e9SAndroid Build Coastguard Worker def fn(x): 1672*da0073e9SAndroid Build Coastguard Worker s.add(3) 
1673*da0073e9SAndroid Build Coastguard Worker return torch.cos(x) * len(x) 1674*da0073e9SAndroid Build Coastguard Worker 1675*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 1676*da0073e9SAndroid Build Coastguard Worker 1677*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4) 1678*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), opt_fn(x)) 1679*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(s), 3) 1680*da0073e9SAndroid Build Coastguard Worker 1681*da0073e9SAndroid Build Coastguard Worker @make_test 1682*da0073e9SAndroid Build Coastguard Worker def test_tuple_iadd(a, b): 1683*da0073e9SAndroid Build Coastguard Worker output = (a, b) 1684*da0073e9SAndroid Build Coastguard Worker output += (a + b, a - b) 1685*da0073e9SAndroid Build Coastguard Worker return output 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker @make_test 1688*da0073e9SAndroid Build Coastguard Worker def test_unpack_ex1(x): 1689*da0073e9SAndroid Build Coastguard Worker output = (x, x + 1, x + 2, x + 3) 1690*da0073e9SAndroid Build Coastguard Worker a, b, *cd = output 1691*da0073e9SAndroid Build Coastguard Worker return a - b / cd[0] 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid Build Coastguard Worker @make_test 1694*da0073e9SAndroid Build Coastguard Worker def test_unpack_ex2(x): 1695*da0073e9SAndroid Build Coastguard Worker output = (x, x + 1, x + 2, x + 3) 1696*da0073e9SAndroid Build Coastguard Worker *ab, c, d = output 1697*da0073e9SAndroid Build Coastguard Worker return c - d / ab[0] 1698*da0073e9SAndroid Build Coastguard Worker 1699*da0073e9SAndroid Build Coastguard Worker @make_test 1700*da0073e9SAndroid Build Coastguard Worker def test_unpack_ex3(x): 1701*da0073e9SAndroid Build Coastguard Worker output = (x, x + 1, x + 2, x + 3) 1702*da0073e9SAndroid Build Coastguard Worker a, *bc, d = output 1703*da0073e9SAndroid Build Coastguard Worker 
return a - d / bc[0] 1704*da0073e9SAndroid Build Coastguard Worker 1705*da0073e9SAndroid Build Coastguard Worker @make_test 1706*da0073e9SAndroid Build Coastguard Worker def test_const_tuple_add1(x): 1707*da0073e9SAndroid Build Coastguard Worker output = (x, x + 1, x + 2, x + 3) 1708*da0073e9SAndroid Build Coastguard Worker output = () + output + () 1709*da0073e9SAndroid Build Coastguard Worker return output[2] + output[3] 1710*da0073e9SAndroid Build Coastguard Worker 1711*da0073e9SAndroid Build Coastguard Worker @make_test 1712*da0073e9SAndroid Build Coastguard Worker def test_const_tuple_add2(x): 1713*da0073e9SAndroid Build Coastguard Worker output = (x, x + 1, x + 2, x + 3) 1714*da0073e9SAndroid Build Coastguard Worker output = (None,) + output + (None,) 1715*da0073e9SAndroid Build Coastguard Worker return output[2] + output[3] 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker @make_test 1718*da0073e9SAndroid Build Coastguard Worker def test_list_truth(a, b): 1719*da0073e9SAndroid Build Coastguard Worker tmp = [1, 2, 3] 1720*da0073e9SAndroid Build Coastguard Worker if tmp: 1721*da0073e9SAndroid Build Coastguard Worker return a + b 1722*da0073e9SAndroid Build Coastguard Worker else: 1723*da0073e9SAndroid Build Coastguard Worker return a - b 1724*da0073e9SAndroid Build Coastguard Worker 1725*da0073e9SAndroid Build Coastguard Worker @make_test 1726*da0073e9SAndroid Build Coastguard Worker def test_list_reversed(a, b): 1727*da0073e9SAndroid Build Coastguard Worker tmp = [a + 1, a + 2, a + 3] 1728*da0073e9SAndroid Build Coastguard Worker return a + b + next(iter(reversed(tmp))) 1729*da0073e9SAndroid Build Coastguard Worker 1730*da0073e9SAndroid Build Coastguard Worker @make_test 1731*da0073e9SAndroid Build Coastguard Worker def test_list_sorted1(x): 1732*da0073e9SAndroid Build Coastguard Worker tmp = [1, 10, 3, 0] 1733*da0073e9SAndroid Build Coastguard Worker return x + 1, sorted(tmp), sorted(tmp, reverse=True) 

    # Each @make_test function below is turned into a test method by make_test
    # (defined elsewhere in this file); presumably it compiles the function and
    # compares the result with eager execution -- see make_test for the contract.
    @make_test
    def test_list_sorted2(x):
        # sorted() over a list of tuples: natural tuple ordering, a key lambda,
        # and a key lambda combined with reverse=True.
        y = [
            ("john", "A", 8),
            ("jane", "B", 5),
            ("dave", "B", 10),
        ]
        return (
            x + 1,
            sorted(y),
            sorted(y, key=lambda student: student[2]),
            sorted(y, key=lambda student: student[2], reverse=True),
        )

    @make_test
    def test_tuple_sorted(x):
        # sorted() on a tuple returns a new list.
        tmp = (1, 10, 3, 0)
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)

    @make_test
    def test_dict_sorted(x):
        # sorted() on a dict sorts its keys.
        tmp = {1: "D", 10: "B", 3: "E", 0: "F"}
        return x + 1, sorted(tmp), sorted(tmp, reverse=True)

    def test_dict_hasattr(self):
        # The same fn is compiled with fullgraph=True and called first with a
        # dict, then with a tensor. A dict has no `to` attribute, so it reaches
        # the `items` branch; a tensor has `to` and takes the first branch.
        def fn(x):
            if hasattr(x, "to"):
                return x.to("cpu")
            if hasattr(x, "items"):
                return torch.cos(x["a"])
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = dict(a=torch.randn(3))
        self.assertEqual(fn(x), opt_fn(x))

        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    @make_test
    def test_list_clear(a, b):
        # In-place list mutation: clear() then append(); the result is [a + b].
        tmp = [a + 1, a + 2]
        tmp.clear()
        tmp.append(a + b)
        return tmp

    @make_test
    def test_not_list(a):
        # `not` applied to a non-empty list literal evaluates to False.
        return not [a + 1]

    @make_test
    def test_islice_chain(a, b):
        tmp1 = [a + 1, a + 2]
        tmp2 = [a + 3, a + 4]
        # chain(tmp1, tmp2) yields a+1, a+2, a+3, a+4; islice(..., 1, 3) keeps
        # a+2 and a+3. NOTE: this rebinds the parameters a and b; their
        # original values are not used afterwards.
        a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
        # Second element of tmp1.
        c = next(itertools.islice(tmp1, 1, None))
        return a - b / c

    @make_test
    def test_namedtuple(a, b):
        # namedtuple class created inside the function; fields are read both
        # by name (tmp.x, tmp.xy) and by index (tmp[1]).
        mytuple = collections.namedtuple("mytuple", ["x", "y", "xy"])
        tmp = mytuple(a, b, a + b)
        return mytuple(tmp.x, tmp[1], tmp.xy + b)

    @make_test
    def test_namedtuple_defaults(a, b):
        # defaults=(None, 1, None) covers all three fields, so
        # mytuple(a, xy=b) produces x=a, y=1, xy=b.
        mytuple = collections.namedtuple(
            "mytuple", ["x", "y", "xy"], defaults=(None, 1, None)
        )
        tmp = mytuple(a, xy=b)
        return mytuple(tmp.x, tmp[1], tmp.xy + b)

    # typing.NamedTuple carrying an instance method, a staticmethod and a
    # classmethod; referenced as FunctionTests.MyNamedTuple by the tests below.
    class MyNamedTuple(NamedTuple):
        first: torch.Tensor
        second: torch.Tensor

        def add(self) -> torch.Tensor:
            return self.first + self.second

        @staticmethod
        def static_method() -> int:
            return 1

        @classmethod
        def class_method(cls) -> str:
            return cls.__name__

    @make_test
    def test_namedtuple_user_methods(a, b):
        # Calls each kind of user-defined method on a NamedTuple instance.
        mytuple = FunctionTests.MyNamedTuple(a, b)
        return mytuple.add(), mytuple.static_method(), mytuple.class_method()

    @make_test
    def test_namedtuple_hasattr(a, b):
        mytuple = FunctionTests.MyNamedTuple(a, b)

        # Duck-typed namedtuple detection: a tuple that also exposes the
        # namedtuple helpers `_asdict` and `_fields`.
        def isinstance_namedtuple(obj) -> bool:
            return (
                isinstance(obj, tuple)
                and hasattr(obj, "_asdict")
                and hasattr(obj, "_fields")
            )

        if isinstance_namedtuple(mytuple):
            return a + b
        else:
            return a - b

    @make_test
    def test_torch_size_hasattr(x):
        # hasattr() on x.shape (a torch.Size); which branch runs depends on
        # whether torch.Size exposes `_fields`.
        if hasattr(x.shape, "_fields"):
            return x + 1
        else:
            return x - 1

    @make_test
    def test_is_quantized(a, b):
        # Reads the Tensor.is_quantized property; falls through and returns
        # None when `a` is quantized.
        if not a.is_quantized:
            return a + b

    # f-string formatting combined with str predicates (startswith / `in`)
    # inside @make_test functions. Each function returns None when its
    # condition does not hold.
    @make_test
    def test_fstrings1(a, b):
        # Format spec .2f rounds 1.229 to "1.23", so the startswith() check
        # holds and a + b is returned.
        x = 1.229
        tmp = f"{x:.2f} bar"
        if tmp.startswith("1.23"):
            return a + b

    @make_test
    def test_fstrings2(x):
        # f-string built from a shape entry.
        # NOTE(review): presumably relies on inputs with shape[0] == 10.
        tmp = f"{x.shape[0]} bar"
        if tmp.startswith("10"):
            return x + 1

    @make_test
    def test_fstrings3(x):
        # f-string over the class name; for a plain torch.Tensor input the
        # string starts with "Tensor".
        tmp = f"{x.__class__.__name__} foo"
        if tmp.startswith("Tensor"):
            return x + 1

    @make_test
    def test_fstrings4(x):
        # Same f-string as test_fstrings2, checked with `in` instead of
        # startswith().
        tmp = f"{x.shape[0]} bar"
        if "10" in tmp:
            return x + 1

    @make_test
    def test_fstrings5(x):
        # Substring test on the f-string concatenated with another str.
        tmp = f"{x.shape[0]} bar"
        if "10" in (tmp + "haha"):
            return x + 1
@make_test 1888*da0073e9SAndroid Build Coastguard Worker def test_fstrings6(x): 1889*da0073e9SAndroid Build Coastguard Worker tmp = f"{x.shape[0] + x.shape[1]}" 1890*da0073e9SAndroid Build Coastguard Worker if "20" in tmp: 1891*da0073e9SAndroid Build Coastguard Worker return x + 1 1892*da0073e9SAndroid Build Coastguard Worker 1893*da0073e9SAndroid Build Coastguard Worker @make_test 1894*da0073e9SAndroid Build Coastguard Worker def test_tensor_new_with_size(x): 1895*da0073e9SAndroid Build Coastguard Worker y = torch.rand(5, 8) 1896*da0073e9SAndroid Build Coastguard Worker z = x.new(y.size()) 1897*da0073e9SAndroid Build Coastguard Worker assert z.size() == y.size() 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard Worker @make_test 1900*da0073e9SAndroid Build Coastguard Worker def test_tensor_new_with_shape(x): 1901*da0073e9SAndroid Build Coastguard Worker y = torch.rand(5, 8) 1902*da0073e9SAndroid Build Coastguard Worker z = x.new(y.shape) 1903*da0073e9SAndroid Build Coastguard Worker assert z.size() == y.size() 1904*da0073e9SAndroid Build Coastguard Worker 1905*da0073e9SAndroid Build Coastguard Worker @make_test 1906*da0073e9SAndroid Build Coastguard Worker def test_jit_annotate(x): 1907*da0073e9SAndroid Build Coastguard Worker y = torch.jit.annotate(Any, x + 1) 1908*da0073e9SAndroid Build Coastguard Worker return y + 2 1909*da0073e9SAndroid Build Coastguard Worker 1910*da0073e9SAndroid Build Coastguard Worker @make_test 1911*da0073e9SAndroid Build Coastguard Worker def test_is_contiguous_memory_format(tensor): 1912*da0073e9SAndroid Build Coastguard Worker if torch.jit.is_scripting(): 1913*da0073e9SAndroid Build Coastguard Worker return None 1914*da0073e9SAndroid Build Coastguard Worker elif tensor.is_contiguous(memory_format=torch.contiguous_format): 1915*da0073e9SAndroid Build Coastguard Worker return tensor + 1 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker def 
test_is_contiguous_frame_counts(self): 1918*da0073e9SAndroid Build Coastguard Worker data = [ 1919*da0073e9SAndroid Build Coastguard Worker torch.rand(10), 1920*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, 32, 32), 1921*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, 32, 32).contiguous(memory_format=torch.channels_last), 1922*da0073e9SAndroid Build Coastguard Worker torch.rand(10)[::2], 1923*da0073e9SAndroid Build Coastguard Worker torch.rand(12), 1924*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, 24, 24).contiguous(memory_format=torch.channels_last), 1925*da0073e9SAndroid Build Coastguard Worker torch.rand(50)[::2], 1926*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 3, 32, 32)[:, :, 2:-2, 3:-3], 1927*da0073e9SAndroid Build Coastguard Worker ] 1928*da0073e9SAndroid Build Coastguard Worker # dynamo should recompile for all inputs in static shapes mode 1929*da0073e9SAndroid Build Coastguard Worker expected_frame_counts_static = [1, 2, 3, 4, 5, 6, 7, 8] 1930*da0073e9SAndroid Build Coastguard Worker # dynamo should recompile for items 0, 1, 2, 6 in dynamic shapes mode 1931*da0073e9SAndroid Build Coastguard Worker expected_frame_counts_dynamic = [1, 2, 3, 4, 4, 4, 4, 5] 1932*da0073e9SAndroid Build Coastguard Worker expected_frame_counts = ifdynstaticdefault( 1933*da0073e9SAndroid Build Coastguard Worker expected_frame_counts_static, expected_frame_counts_dynamic 1934*da0073e9SAndroid Build Coastguard Worker ) 1935*da0073e9SAndroid Build Coastguard Worker dynamic = ifdynstaticdefault(False, True) 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker def func(x): 1938*da0073e9SAndroid Build Coastguard Worker if x.is_contiguous(): 1939*da0073e9SAndroid Build Coastguard Worker return x + 1 1940*da0073e9SAndroid Build Coastguard Worker elif x.is_contiguous(memory_format=torch.channels_last): 1941*da0073e9SAndroid Build Coastguard Worker return x + 2 1942*da0073e9SAndroid Build Coastguard Worker 
else: 1943*da0073e9SAndroid Build Coastguard Worker return x + 3 1944*da0073e9SAndroid Build Coastguard Worker 1945*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1946*da0073e9SAndroid Build Coastguard Worker cfunc = torch._dynamo.optimize_assert(cnt, dynamic=dynamic)(func) 1947*da0073e9SAndroid Build Coastguard Worker 1948*da0073e9SAndroid Build Coastguard Worker assert cnt.frame_count == 0 1949*da0073e9SAndroid Build Coastguard Worker for i, x in enumerate(data): 1950*da0073e9SAndroid Build Coastguard Worker expected = func(x) 1951*da0073e9SAndroid Build Coastguard Worker output = cfunc(x) 1952*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(output, expected)) 1953*da0073e9SAndroid Build Coastguard Worker assert cnt.frame_count == expected_frame_counts[i] 1954*da0073e9SAndroid Build Coastguard Worker 1955*da0073e9SAndroid Build Coastguard Worker @make_test 1956*da0073e9SAndroid Build Coastguard Worker def test_list_slice_assignment(x): 1957*da0073e9SAndroid Build Coastguard Worker m = [1, 2, 3, 4] 1958*da0073e9SAndroid Build Coastguard Worker m[1:] = [6] * (len(m) - 1) 1959*da0073e9SAndroid Build Coastguard Worker return x + 1 1960*da0073e9SAndroid Build Coastguard Worker 1961*da0073e9SAndroid Build Coastguard Worker @make_test 1962*da0073e9SAndroid Build Coastguard Worker def test_distributed_is_available(x): 1963*da0073e9SAndroid Build Coastguard Worker if torch.distributed.is_available(): 1964*da0073e9SAndroid Build Coastguard Worker return x + 1 1965*da0073e9SAndroid Build Coastguard Worker else: 1966*da0073e9SAndroid Build Coastguard Worker return x - 1 1967*da0073e9SAndroid Build Coastguard Worker 1968*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 1969*da0073e9SAndroid Build Coastguard Worker not torch.distributed.is_available(), "requires distributed package" 1970*da0073e9SAndroid Build Coastguard Worker ) 1971*da0073e9SAndroid Build Coastguard Worker @make_test 1972*da0073e9SAndroid Build 
Coastguard Worker def test_distributed_is_initialized(x): 1973*da0073e9SAndroid Build Coastguard Worker if torch.distributed.is_initialized(): 1974*da0073e9SAndroid Build Coastguard Worker return x + 1 1975*da0073e9SAndroid Build Coastguard Worker else: 1976*da0073e9SAndroid Build Coastguard Worker return x - 1 1977*da0073e9SAndroid Build Coastguard Worker 1978*da0073e9SAndroid Build Coastguard Worker @disable_translation_validation_if_dynamic_shapes 1979*da0073e9SAndroid Build Coastguard Worker @make_test 1980*da0073e9SAndroid Build Coastguard Worker def test_torch_distributions_functions(x): 1981*da0073e9SAndroid Build Coastguard Worker normal = torch.distributions.Normal(x, torch.tensor(1)) 1982*da0073e9SAndroid Build Coastguard Worker independent = torch.distributions.Independent(normal, 1) 1983*da0073e9SAndroid Build Coastguard Worker return independent.log_prob(x) 1984*da0073e9SAndroid Build Coastguard Worker 1985*da0073e9SAndroid Build Coastguard Worker @make_test 1986*da0073e9SAndroid Build Coastguard Worker def test_context_wrapping_nested_functions_no_closure(x): 1987*da0073e9SAndroid Build Coastguard Worker @torch.no_grad() 1988*da0073e9SAndroid Build Coastguard Worker def augment(x: torch.Tensor) -> torch.Tensor: 1989*da0073e9SAndroid Build Coastguard Worker return (x + 1) * 2 1990*da0073e9SAndroid Build Coastguard Worker 1991*da0073e9SAndroid Build Coastguard Worker return augment(x) 1992*da0073e9SAndroid Build Coastguard Worker 1993*da0073e9SAndroid Build Coastguard Worker # # This is to test the new syntax for pattern matching 1994*da0073e9SAndroid Build Coastguard Worker # # ("match ... case ...") added on python 3.10. 
1995*da0073e9SAndroid Build Coastguard Worker # # Uncomment these test cases if you run on 3.10+ 1996*da0073e9SAndroid Build Coastguard Worker # @make_test 1997*da0073e9SAndroid Build Coastguard Worker # def test_match_sequence(a): 1998*da0073e9SAndroid Build Coastguard Worker # point = (5, 8) 1999*da0073e9SAndroid Build Coastguard Worker # match point: 2000*da0073e9SAndroid Build Coastguard Worker # case (0, 0): 2001*da0073e9SAndroid Build Coastguard Worker # return a 2002*da0073e9SAndroid Build Coastguard Worker # case (0, y): 2003*da0073e9SAndroid Build Coastguard Worker # return a - y 2004*da0073e9SAndroid Build Coastguard Worker # case (x, 0): 2005*da0073e9SAndroid Build Coastguard Worker # return a + x 2006*da0073e9SAndroid Build Coastguard Worker # case (x, y): 2007*da0073e9SAndroid Build Coastguard Worker # return a + x - y 2008*da0073e9SAndroid Build Coastguard Worker 2009*da0073e9SAndroid Build Coastguard Worker # @make_test 2010*da0073e9SAndroid Build Coastguard Worker # def test_match_mapping_and_match_keys(x): 2011*da0073e9SAndroid Build Coastguard Worker # param = {"a": 0.5} 2012*da0073e9SAndroid Build Coastguard Worker # match param: 2013*da0073e9SAndroid Build Coastguard Worker # case {"a": param}: 2014*da0073e9SAndroid Build Coastguard Worker # return x * param 2015*da0073e9SAndroid Build Coastguard Worker # case {"b": param}: 2016*da0073e9SAndroid Build Coastguard Worker # return x / param 2017*da0073e9SAndroid Build Coastguard Worker 2018*da0073e9SAndroid Build Coastguard Worker def test_math_radians(self): 2019*da0073e9SAndroid Build Coastguard Worker def func(x, a): 2020*da0073e9SAndroid Build Coastguard Worker return x + math.radians(a) 2021*da0073e9SAndroid Build Coastguard Worker 2022*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 2023*da0073e9SAndroid Build Coastguard Worker cfunc = torch._dynamo.optimize_assert(cnt)(func) 2024*da0073e9SAndroid Build Coastguard Worker 2025*da0073e9SAndroid Build 
Coastguard Worker assert cnt.frame_count == 0 2026*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10) 2027*da0073e9SAndroid Build Coastguard Worker expected = func(x, 12) 2028*da0073e9SAndroid Build Coastguard Worker output = cfunc(x, 12) 2029*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(output, expected)) 2030*da0073e9SAndroid Build Coastguard Worker assert cnt.frame_count == 1 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker @make_test 2033*da0073e9SAndroid Build Coastguard Worker def test_numpy_meshgrid(x, y): 2034*da0073e9SAndroid Build Coastguard Worker r1, r2 = np.meshgrid(x.numpy(), y.numpy()) 2035*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(r1), torch.from_numpy(r2) 2036*da0073e9SAndroid Build Coastguard Worker 2037*da0073e9SAndroid Build Coastguard Worker @make_test 2038*da0073e9SAndroid Build Coastguard Worker def test_torch_from_numpy(x): 2039*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2040*da0073e9SAndroid Build Coastguard Worker b = torch.from_numpy(a) 2041*da0073e9SAndroid Build Coastguard Worker if b.size(0) == 1: 2042*da0073e9SAndroid Build Coastguard Worker return torch.tensor(True) 2043*da0073e9SAndroid Build Coastguard Worker else: 2044*da0073e9SAndroid Build Coastguard Worker return torch.tensor(False) 2045*da0073e9SAndroid Build Coastguard Worker 2046*da0073e9SAndroid Build Coastguard Worker @make_test 2047*da0073e9SAndroid Build Coastguard Worker def test_numpy_size(x): 2048*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2049*da0073e9SAndroid Build Coastguard Worker return a.size 2050*da0073e9SAndroid Build Coastguard Worker 2051*da0073e9SAndroid Build Coastguard Worker @make_test 2052*da0073e9SAndroid Build Coastguard Worker def test_numpy_attributes(x): 2053*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2054*da0073e9SAndroid Build Coastguard Worker return ( 2055*da0073e9SAndroid Build Coastguard Worker a.itemsize, 
2056*da0073e9SAndroid Build Coastguard Worker a.strides, 2057*da0073e9SAndroid Build Coastguard Worker a.shape, 2058*da0073e9SAndroid Build Coastguard Worker a.ndim, 2059*da0073e9SAndroid Build Coastguard Worker a.size, 2060*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(a.T), 2061*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(a.real), 2062*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(a.imag), 2063*da0073e9SAndroid Build Coastguard Worker ) 2064*da0073e9SAndroid Build Coastguard Worker 2065*da0073e9SAndroid Build Coastguard Worker @make_test 2066*da0073e9SAndroid Build Coastguard Worker def test_mean_sum_np(x: torch.Tensor): 2067*da0073e9SAndroid Build Coastguard Worker x_mean = np.mean(x.numpy(), 1) 2068*da0073e9SAndroid Build Coastguard Worker x_sum = np.sum(x_mean) 2069*da0073e9SAndroid Build Coastguard Worker x_sum_array = np.asarray(x_sum) 2070*da0073e9SAndroid Build Coastguard Worker return torch.from_numpy(x_sum_array) 2071*da0073e9SAndroid Build Coastguard Worker 2072*da0073e9SAndroid Build Coastguard Worker @make_test 2073*da0073e9SAndroid Build Coastguard Worker def test_return_numpy_ndarray(x): 2074*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2075*da0073e9SAndroid Build Coastguard Worker return a.T 2076*da0073e9SAndroid Build Coastguard Worker 2077*da0073e9SAndroid Build Coastguard Worker @make_test 2078*da0073e9SAndroid Build Coastguard Worker def test_return_multiple_numpy_ndarray(x): 2079*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2080*da0073e9SAndroid Build Coastguard Worker return a.T, a.imag, a.real 2081*da0073e9SAndroid Build Coastguard Worker 2082*da0073e9SAndroid Build Coastguard Worker @make_test 2083*da0073e9SAndroid Build Coastguard Worker def test_ndarray_method(x): 2084*da0073e9SAndroid Build Coastguard Worker a = x.numpy() 2085*da0073e9SAndroid Build Coastguard Worker return a.copy() 2086*da0073e9SAndroid Build Coastguard Worker 2087*da0073e9SAndroid Build Coastguard Worker 
# Tests below trace numpy / functools.partial / list patterns via make_test.
@make_test
def test_ndarray_transpose(x):
    arr = x.numpy()
    return arr.transpose(0, 1)


@make_test
def test_ndarray_reshape(x):
    arr = x.numpy()
    # Reshape into a single row that holds every element.
    return arr.reshape([1, arr.size])


@make_test
def test_ndarray_methods_returning_scalar(x):
    arr = x.numpy()
    return arr.max(axis=0), arr.all(axis=0)


@make_test
def test_ndarray_builtin_functions(x):
    arr = x.numpy()
    # Binary operators applied directly to ndarrays.
    return arr + arr, arr - arr


@make_test
def test_numpy_dtype_argument_to_function(x):
    # A numpy dtype object passed as a keyword argument.
    return np.ones_like(x, dtype=np.float64)


@make_test
def test_numpy_dtype_call_in_function(x):
    # np.dtype(...) is constructed inside the traced function.
    return np.full_like(x, 2.4, dtype=np.dtype("float"))


@make_test
def test_numpy_linalg(x):
    arr = x.numpy()
    return np.linalg.norm(arr, axis=0)


@make_test
def test_numpy_fft(x):
    arr = x.numpy()
    return np.fft.fftshift(arr)


@make_test
def test_numpy_random():
    # Subtracting the random sample from itself keeps the output deterministic.
    sample = np.random.randn(2, 2)
    return sample - sample


@make_test
def test_partials_torch_op_kwarg(x):
    # functools.partial over a torch op, binding `other` by keyword.
    mul_by_ones = functools.partial(torch.mul, other=torch.ones(10, 10))
    return mul_by_ones(x)


@make_test
def test_partials_torch_op_arg(x):
    # functools.partial over a torch op, binding the tensor positionally.
    mul_by_ones = functools.partial(torch.mul, torch.ones(10, 10))
    return mul_by_ones(x)


@make_test
def test_partials_udf_arg(x):
    # functools.partial over udf_mul, binding the tensor positionally.
    mul_by_ones = functools.partial(udf_mul, torch.ones(10, 10))
    return mul_by_ones(x)


@make_test
def test_list_add_then_mutate(x):
    # Build a new list with `+`, then append to the resulting list.
    items = [1, x]
    quarter = x / 4.0
    items = items + [x / 2.0, 4]
    items.append(quarter)
    return sum(items)


@make_test
def test_list_expand_lhs(x):
    # int * list repetition with the integer on the left-hand side.
    repeated = 4 * [x]
    return sum(repeated)


@make_test
def test_in_not_in(x):
    mixed = [1, 2, 3, 4, 5, x]
    ints_only = [1, 2, 3, 4, 5]
    assert 3 in mixed
    assert 6 not in ints_only
    return sum(mixed)


@make_test
def test_are_functorch_transforms_active(x):
    # Branch on the functorch-transform query.
    return x + 1 if torch._C._are_functorch_transforms_active() else x - 1


@make_test
def test_partials_udf_kwarg(x):
    # functools.partial over udf_mul, binding `y` by keyword.
    mul_by_ones = functools.partial(udf_mul, y=torch.ones(10, 10))
    return mul_by_ones(x)


@make_test
def test_partials_udf_kwarg_module(x, y):
    # Bind a SmallNN instance as the `mod` keyword of udf_module.
    bound = functools.partial(udf_module, mod=SmallNN())
    return bound(x=x, y=y)


@make_test
def test_partials_udf_kwarg_method(x, y):
    # Bind a SmallNN instance's bound `forward` method as the `mod` keyword.
    bound = functools.partial(udf_module, mod=SmallNN().forward)
    return bound(x=x, y=y)


@make_test
def test_partials_lambda(x):
    product = lambda x, y: x * y
    times_three = functools.partial(product, y=3)
    return times_three(x)


@unittest.skipUnless(torch.distributed.is_available(), "requires torch.distributed")
@make_test
def test_flat_param_same_storage_size(x, y):
    import torch.distributed.fsdp._flat_param as flat_param

    # Each input is shifted up or down depending on the storage-size check.
    x = x + 1 if flat_param._same_storage_size(x, 100) else x - 1
    y = y + 1 if flat_param._same_storage_size(y, 123) else y - 1
    return x, y
2206*da0073e9SAndroid Build Coastguard Worker 2207*da0073e9SAndroid Build Coastguard Worker @parametrize( 2208*da0073e9SAndroid Build Coastguard Worker "attr", 2209*da0073e9SAndroid Build Coastguard Worker ( 2210*da0073e9SAndroid Build Coastguard Worker # True 2211*da0073e9SAndroid Build Coastguard Worker "__subclasshook__", 2212*da0073e9SAndroid Build Coastguard Worker "__lt__", 2213*da0073e9SAndroid Build Coastguard Worker "__hash__", 2214*da0073e9SAndroid Build Coastguard Worker "__ge__", 2215*da0073e9SAndroid Build Coastguard Worker "__le__", 2216*da0073e9SAndroid Build Coastguard Worker "__gt__", 2217*da0073e9SAndroid Build Coastguard Worker "__dict__", 2218*da0073e9SAndroid Build Coastguard Worker "__getattribute__", 2219*da0073e9SAndroid Build Coastguard Worker "__setattr__", 2220*da0073e9SAndroid Build Coastguard Worker "__doc__", 2221*da0073e9SAndroid Build Coastguard Worker "__repr__", 2222*da0073e9SAndroid Build Coastguard Worker "__dir__", 2223*da0073e9SAndroid Build Coastguard Worker "__init__", 2224*da0073e9SAndroid Build Coastguard Worker "__new__", 2225*da0073e9SAndroid Build Coastguard Worker "__class__", 2226*da0073e9SAndroid Build Coastguard Worker "__eq__", 2227*da0073e9SAndroid Build Coastguard Worker "__delattr__", 2228*da0073e9SAndroid Build Coastguard Worker "__reduce__", 2229*da0073e9SAndroid Build Coastguard Worker "__module__", 2230*da0073e9SAndroid Build Coastguard Worker "__format__", 2231*da0073e9SAndroid Build Coastguard Worker "__str__", 2232*da0073e9SAndroid Build Coastguard Worker "__sizeof__", 2233*da0073e9SAndroid Build Coastguard Worker "__ne__", 2234*da0073e9SAndroid Build Coastguard Worker "__call__", 2235*da0073e9SAndroid Build Coastguard Worker "__reduce_ex__", 2236*da0073e9SAndroid Build Coastguard Worker "__init_subclass__", 2237*da0073e9SAndroid Build Coastguard Worker "args", 2238*da0073e9SAndroid Build Coastguard Worker "keywords", 2239*da0073e9SAndroid Build Coastguard Worker "func", 2240*da0073e9SAndroid Build 
Coastguard Worker # False 2241*da0073e9SAndroid Build Coastguard Worker "__code__", 2242*da0073e9SAndroid Build Coastguard Worker "__kwdefaults__", 2243*da0073e9SAndroid Build Coastguard Worker "__defaults__", 2244*da0073e9SAndroid Build Coastguard Worker "__name__", 2245*da0073e9SAndroid Build Coastguard Worker "__annotations__", 2246*da0073e9SAndroid Build Coastguard Worker "__get__", 2247*da0073e9SAndroid Build Coastguard Worker "__builtins__", 2248*da0073e9SAndroid Build Coastguard Worker "__qualname__", 2249*da0073e9SAndroid Build Coastguard Worker "__globals__", 2250*da0073e9SAndroid Build Coastguard Worker "__closure__", 2251*da0073e9SAndroid Build Coastguard Worker ), 2252*da0073e9SAndroid Build Coastguard Worker ) 2253*da0073e9SAndroid Build Coastguard Worker def test_partials_hasattr(self, attr): 2254*da0073e9SAndroid Build Coastguard Worker def fn(t): 2255*da0073e9SAndroid Build Coastguard Worker f = lambda x, y: torch.sin(x) + torch.cos(y) 2256*da0073e9SAndroid Build Coastguard Worker p = functools.partial(f, y=t) 2257*da0073e9SAndroid Build Coastguard Worker if hasattr(p, attr): 2258*da0073e9SAndroid Build Coastguard Worker return p(t) 2259*da0073e9SAndroid Build Coastguard Worker else: 2260*da0073e9SAndroid Build Coastguard Worker return torch.zeros_like(t) 2261*da0073e9SAndroid Build Coastguard Worker 2262*da0073e9SAndroid Build Coastguard Worker t = torch.randn(3, 4) 2263*da0073e9SAndroid Build Coastguard Worker counter = torch._dynamo.testing.CompileCounter() 2264*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True, backend=counter)(fn) 2265*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(t), fn(t)) 2266*da0073e9SAndroid Build Coastguard Worker self.assertGreater(counter.frame_count, 0) 2267*da0073e9SAndroid Build Coastguard Worker 2268*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 2269*da0073e9SAndroid Build Coastguard Worker def test_partials_hasattr_set_attr(self): 
2270*da0073e9SAndroid Build Coastguard Worker def fn(t): 2271*da0073e9SAndroid Build Coastguard Worker f = lambda x, y: torch.sin(x) + torch.cos(y) 2272*da0073e9SAndroid Build Coastguard Worker p = functools.partial(f, y=t) 2273*da0073e9SAndroid Build Coastguard Worker p.__name__ = "test" 2274*da0073e9SAndroid Build Coastguard Worker if hasattr(p, "__name__"): 2275*da0073e9SAndroid Build Coastguard Worker return p(t) 2276*da0073e9SAndroid Build Coastguard Worker else: 2277*da0073e9SAndroid Build Coastguard Worker return torch.zeros_like(t) 2278*da0073e9SAndroid Build Coastguard Worker 2279*da0073e9SAndroid Build Coastguard Worker t = torch.randn(3, 4) 2280*da0073e9SAndroid Build Coastguard Worker counter = torch._dynamo.testing.CompileCounter() 2281*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True, backend=counter)(fn) 2282*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(t), fn(t)) 2283*da0073e9SAndroid Build Coastguard Worker 2284*da0073e9SAndroid Build Coastguard Worker def test_filter(self): 2285*da0073e9SAndroid Build Coastguard Worker def fn(inputs): 2286*da0073e9SAndroid Build Coastguard Worker out = inputs[0] 2287*da0073e9SAndroid Build Coastguard Worker for inp in filter(lambda x: (x.requires_grad), inputs): 2288*da0073e9SAndroid Build Coastguard Worker out = out * inp 2289*da0073e9SAndroid Build Coastguard Worker return out 2290*da0073e9SAndroid Build Coastguard Worker 2291*da0073e9SAndroid Build Coastguard Worker input1 = torch.arange(2, dtype=torch.bfloat16) 2292*da0073e9SAndroid Build Coastguard Worker input2 = torch.arange(2, dtype=torch.bfloat16).requires_grad_(True) 2293*da0073e9SAndroid Build Coastguard Worker inputs = [input1, input2] 2294*da0073e9SAndroid Build Coastguard Worker 2295*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True)(fn) 2296*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(inputs), fn(inputs)) 2297*da0073e9SAndroid Build Coastguard Worker 
2298*da0073e9SAndroid Build Coastguard Worker def test_filter_fallback(self): 2299*da0073e9SAndroid Build Coastguard Worker def fn(inputs): 2300*da0073e9SAndroid Build Coastguard Worker out = inputs[0] 2301*da0073e9SAndroid Build Coastguard Worker for inp in filter(lambda x: x[0] == 1, inputs): 2302*da0073e9SAndroid Build Coastguard Worker out = out * inp 2303*da0073e9SAndroid Build Coastguard Worker return out 2304*da0073e9SAndroid Build Coastguard Worker 2305*da0073e9SAndroid Build Coastguard Worker input1 = torch.ones(2, dtype=torch.bfloat16) 2306*da0073e9SAndroid Build Coastguard Worker input2 = torch.arange(2, dtype=torch.bfloat16) 2307*da0073e9SAndroid Build Coastguard Worker inputs = [input1, input2] 2308*da0073e9SAndroid Build Coastguard Worker 2309*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile()(fn) 2310*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(inputs), fn(inputs)) 2311*da0073e9SAndroid Build Coastguard Worker 2312*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2313*da0073e9SAndroid Build Coastguard Worker 2314*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(torch._dynamo.exc.Unsupported): 2315*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True)(fn) 2316*da0073e9SAndroid Build Coastguard Worker opt_fn(inputs) 2317*da0073e9SAndroid Build Coastguard Worker 2318*da0073e9SAndroid Build Coastguard Worker def test_pow_int(self): 2319*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 2320*da0073e9SAndroid Build Coastguard Worker return torch.pow(a, b) 2321*da0073e9SAndroid Build Coastguard Worker 2322*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 2323*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn) 2324*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, 2), fn(x, 2)) 2325*da0073e9SAndroid Build Coastguard Worker 2326*da0073e9SAndroid Build Coastguard Worker 
def test_tensor_size_indexed_by_symint(self): 2327*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2328*da0073e9SAndroid Build Coastguard Worker index = x.shape[-1] 2329*da0073e9SAndroid Build Coastguard Worker return x + y.shape[index] 2330*da0073e9SAndroid Build Coastguard Worker 2331*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, 2) 2332*da0073e9SAndroid Build Coastguard Worker y = torch.rand(10, 8, 6) 2333*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) 2334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, y), fn(x, y)) 2335*da0073e9SAndroid Build Coastguard Worker 2336*da0073e9SAndroid Build Coastguard Worker def test_partials_as_input_partials_lambda(self): 2337*da0073e9SAndroid Build Coastguard Worker def fn(f0, f1, x): 2338*da0073e9SAndroid Build Coastguard Worker return f0(x) * f1(x) 2339*da0073e9SAndroid Build Coastguard Worker 2340*da0073e9SAndroid Build Coastguard Worker multiply = lambda x, y: x * y 2341*da0073e9SAndroid Build Coastguard Worker lambda0 = functools.partial(multiply, y=3) 2342*da0073e9SAndroid Build Coastguard Worker lambda1 = functools.partial(multiply, y=2) 2343*da0073e9SAndroid Build Coastguard Worker 2344*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2345*da0073e9SAndroid Build Coastguard Worker torch._dynamo.optimize(cnts, nopython=True)(fn)( 2346*da0073e9SAndroid Build Coastguard Worker lambda0, lambda1, torch.randn(2, 2) 2347*da0073e9SAndroid Build Coastguard Worker ) 2348*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2349*da0073e9SAndroid Build Coastguard Worker 2350*da0073e9SAndroid Build Coastguard Worker def test_partials_as_input_partials_mod(self): 2351*da0073e9SAndroid Build Coastguard Worker def fn(f0, f1, x): 2352*da0073e9SAndroid Build Coastguard Worker return f0(x) * f1(x) 2353*da0073e9SAndroid Build Coastguard Worker 2354*da0073e9SAndroid Build Coastguard 
Worker lambda0 = functools.partial(SmallNN(), y=torch.randn(2, 2)) 2355*da0073e9SAndroid Build Coastguard Worker lambda1 = functools.partial(SmallNN(), y=torch.randn(2, 2)) 2356*da0073e9SAndroid Build Coastguard Worker 2357*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2358*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2359*da0073e9SAndroid Build Coastguard Worker dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)( 2360*da0073e9SAndroid Build Coastguard Worker lambda0, lambda1, x 2361*da0073e9SAndroid Build Coastguard Worker ) 2362*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2363*da0073e9SAndroid Build Coastguard Worker 2364*da0073e9SAndroid Build Coastguard Worker eager_result = fn(lambda0, lambda1, x) 2365*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result, dynamo_result) 2366*da0073e9SAndroid Build Coastguard Worker 2367*da0073e9SAndroid Build Coastguard Worker def test_partials_as_input_UDF(self): 2368*da0073e9SAndroid Build Coastguard Worker def fn(f0, f1, x): 2369*da0073e9SAndroid Build Coastguard Worker return f0(x) * f1(x) 2370*da0073e9SAndroid Build Coastguard Worker 2371*da0073e9SAndroid Build Coastguard Worker lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2372*da0073e9SAndroid Build Coastguard Worker lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2373*da0073e9SAndroid Build Coastguard Worker 2374*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2375*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2376*da0073e9SAndroid Build Coastguard Worker dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)( 2377*da0073e9SAndroid Build Coastguard Worker lambda0, lambda1, x 2378*da0073e9SAndroid Build Coastguard Worker ) 2379*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2380*da0073e9SAndroid Build Coastguard Worker 
2381*da0073e9SAndroid Build Coastguard Worker eager_result = fn(lambda0, lambda1, x) 2382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result, dynamo_result) 2383*da0073e9SAndroid Build Coastguard Worker 2384*da0073e9SAndroid Build Coastguard Worker def test_partials_graph_break_reconstruct(self): 2385*da0073e9SAndroid Build Coastguard Worker def fn(udf_mul_0, udf_mul_1, x): 2386*da0073e9SAndroid Build Coastguard Worker lambda0 = functools.partial(udf_mul_0, y=x) 2387*da0073e9SAndroid Build Coastguard Worker lambda1 = functools.partial(udf_mul_1, y=x) 2388*da0073e9SAndroid Build Coastguard Worker 2389*da0073e9SAndroid Build Coastguard Worker print("break") 2390*da0073e9SAndroid Build Coastguard Worker return torch.mul(lambda0(x), lambda1(x)) 2391*da0073e9SAndroid Build Coastguard Worker 2392*da0073e9SAndroid Build Coastguard Worker backend = EagerAndRecordGraphs() 2393*da0073e9SAndroid Build Coastguard Worker cnts = CompileCounterWithBackend(backend) 2394*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2395*da0073e9SAndroid Build Coastguard Worker dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_mul, x) 2396*da0073e9SAndroid Build Coastguard Worker 2397*da0073e9SAndroid Build Coastguard Worker eager_result = fn(udf_mul, udf_mul, x) 2398*da0073e9SAndroid Build Coastguard Worker gm = backend.graphs[0] 2399*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result, dynamo_result) 2400*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 2401*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2402*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2403*da0073e9SAndroid Build Coastguard Worker """\ 2404*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2405*da0073e9SAndroid Build Coastguard Worker def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"): 2406*da0073e9SAndroid 
Build Coastguard Worker l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2407*da0073e9SAndroid Build Coastguard Worker 2408*da0073e9SAndroid Build Coastguard Worker mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2409*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2410*da0073e9SAndroid Build Coastguard Worker 2411*da0073e9SAndroid Build Coastguard Worker mul_2: "f32[2, 2]" = torch.mul(mul, mul_1); mul = mul_1 = None 2412*da0073e9SAndroid Build Coastguard Worker return (mul_2,) 2413*da0073e9SAndroid Build Coastguard Worker""", 2414*da0073e9SAndroid Build Coastguard Worker ) 2415*da0073e9SAndroid Build Coastguard Worker else: 2416*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2417*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2418*da0073e9SAndroid Build Coastguard Worker """\ 2419*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2420*da0073e9SAndroid Build Coastguard Worker def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): 2421*da0073e9SAndroid Build Coastguard Worker l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2422*da0073e9SAndroid Build Coastguard Worker 2423*da0073e9SAndroid Build Coastguard Worker mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2424*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2425*da0073e9SAndroid Build Coastguard Worker 2426*da0073e9SAndroid Build Coastguard Worker mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None 2427*da0073e9SAndroid Build Coastguard Worker return (mul_2,) 2428*da0073e9SAndroid Build Coastguard Worker""", 2429*da0073e9SAndroid Build Coastguard Worker ) 2430*da0073e9SAndroid Build Coastguard Worker 2431*da0073e9SAndroid Build Coastguard Worker def 
test_partials_graph_break_reconstruct_mix(self): 2432*da0073e9SAndroid Build Coastguard Worker def fn(udf_mul_0, udf_add_1, x): 2433*da0073e9SAndroid Build Coastguard Worker lambda0 = functools.partial(udf_mul_0, y=x) 2434*da0073e9SAndroid Build Coastguard Worker lambda1 = functools.partial(udf_add_1, x) 2435*da0073e9SAndroid Build Coastguard Worker 2436*da0073e9SAndroid Build Coastguard Worker print("break") 2437*da0073e9SAndroid Build Coastguard Worker return torch.mul(lambda0(x), lambda1(x)) 2438*da0073e9SAndroid Build Coastguard Worker 2439*da0073e9SAndroid Build Coastguard Worker backend = EagerAndRecordGraphs() 2440*da0073e9SAndroid Build Coastguard Worker cnts = CompileCounterWithBackend(backend) 2441*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2442*da0073e9SAndroid Build Coastguard Worker dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_add, x) 2443*da0073e9SAndroid Build Coastguard Worker 2444*da0073e9SAndroid Build Coastguard Worker eager_result = fn(udf_mul, udf_add, x) 2445*da0073e9SAndroid Build Coastguard Worker gm = backend.graphs[0] 2446*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result, dynamo_result) 2447*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 2448*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2449*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2450*da0073e9SAndroid Build Coastguard Worker """\ 2451*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2452*da0073e9SAndroid Build Coastguard Worker def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"): 2453*da0073e9SAndroid Build Coastguard Worker l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2454*da0073e9SAndroid Build Coastguard Worker 2455*da0073e9SAndroid Build Coastguard Worker mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2456*da0073e9SAndroid Build Coastguard Worker 
2457*da0073e9SAndroid Build Coastguard Worker add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2458*da0073e9SAndroid Build Coastguard Worker 2459*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[2, 2]" = torch.mul(mul, add); mul = add = None 2460*da0073e9SAndroid Build Coastguard Worker return (mul_1,) 2461*da0073e9SAndroid Build Coastguard Worker""", 2462*da0073e9SAndroid Build Coastguard Worker ) 2463*da0073e9SAndroid Build Coastguard Worker else: 2464*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2465*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2466*da0073e9SAndroid Build Coastguard Worker """\ 2467*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2468*da0073e9SAndroid Build Coastguard Worker def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): 2469*da0073e9SAndroid Build Coastguard Worker l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2470*da0073e9SAndroid Build Coastguard Worker 2471*da0073e9SAndroid Build Coastguard Worker mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2472*da0073e9SAndroid Build Coastguard Worker 2473*da0073e9SAndroid Build Coastguard Worker add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2474*da0073e9SAndroid Build Coastguard Worker 2475*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None 2476*da0073e9SAndroid Build Coastguard Worker return (mul_1,) 2477*da0073e9SAndroid Build Coastguard Worker""", 2478*da0073e9SAndroid Build Coastguard Worker ) 2479*da0073e9SAndroid Build Coastguard Worker 2480*da0073e9SAndroid Build Coastguard Worker def test_partials_graph_break_reconstruct_mix_no_source(self): 2481*da0073e9SAndroid Build Coastguard Worker def fn(udf_mul_0, x): 2482*da0073e9SAndroid Build Coastguard Worker udf_add_1 = lambda x, y: x + y 
2483*da0073e9SAndroid Build Coastguard Worker 2484*da0073e9SAndroid Build Coastguard Worker lambda0 = functools.partial(udf_mul_0, y=x) 2485*da0073e9SAndroid Build Coastguard Worker lambda1 = functools.partial(udf_add_1, x) 2486*da0073e9SAndroid Build Coastguard Worker 2487*da0073e9SAndroid Build Coastguard Worker print("break") 2488*da0073e9SAndroid Build Coastguard Worker return torch.mul(lambda0(x), lambda1(x)) 2489*da0073e9SAndroid Build Coastguard Worker 2490*da0073e9SAndroid Build Coastguard Worker backend = EagerAndRecordGraphs() 2491*da0073e9SAndroid Build Coastguard Worker cnts = CompileCounterWithBackend(backend) 2492*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2493*da0073e9SAndroid Build Coastguard Worker dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, x) 2494*da0073e9SAndroid Build Coastguard Worker 2495*da0073e9SAndroid Build Coastguard Worker eager_result = fn(udf_mul, x) 2496*da0073e9SAndroid Build Coastguard Worker gm = backend.graphs[0] 2497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result, dynamo_result) 2498*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 2499*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2500*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2501*da0073e9SAndroid Build Coastguard Worker """\ 2502*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2503*da0073e9SAndroid Build Coastguard Worker def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"): 2504*da0073e9SAndroid Build Coastguard Worker l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2505*da0073e9SAndroid Build Coastguard Worker 2506*da0073e9SAndroid Build Coastguard Worker mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2507*da0073e9SAndroid Build Coastguard Worker 2508*da0073e9SAndroid Build Coastguard Worker add: "f32[2, 2]" = l_lambda0_keywords_y_ + 
l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2509*da0073e9SAndroid Build Coastguard Worker 2510*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[2, 2]" = torch.mul(mul, add); mul = add = None 2511*da0073e9SAndroid Build Coastguard Worker return (mul_1,) 2512*da0073e9SAndroid Build Coastguard Worker""", 2513*da0073e9SAndroid Build Coastguard Worker ) 2514*da0073e9SAndroid Build Coastguard Worker else: 2515*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2516*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2517*da0073e9SAndroid Build Coastguard Worker """\ 2518*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2519*da0073e9SAndroid Build Coastguard Worker def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"): 2520*da0073e9SAndroid Build Coastguard Worker l_lambda0_keywords_y_ = L_lambda0_keywords_y_ 2521*da0073e9SAndroid Build Coastguard Worker 2522*da0073e9SAndroid Build Coastguard Worker mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_ 2523*da0073e9SAndroid Build Coastguard Worker 2524*da0073e9SAndroid Build Coastguard Worker add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None 2525*da0073e9SAndroid Build Coastguard Worker 2526*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None 2527*da0073e9SAndroid Build Coastguard Worker return (mul_1,) 2528*da0073e9SAndroid Build Coastguard Worker""", 2529*da0073e9SAndroid Build Coastguard Worker ) 2530*da0073e9SAndroid Build Coastguard Worker 2531*da0073e9SAndroid Build Coastguard Worker def test_partials_graph_break_reconstruct_args_and_kwargs(self): 2532*da0073e9SAndroid Build Coastguard Worker def fn(udf_mul_0, x): 2533*da0073e9SAndroid Build Coastguard Worker lambda0 = functools.partial(udf_mul_0, x, 4, z=x) 2534*da0073e9SAndroid Build Coastguard Worker lambda1 = 
functools.partial(udf_mul_0, 4, z=x) 2535*da0073e9SAndroid Build Coastguard Worker 2536*da0073e9SAndroid Build Coastguard Worker return torch.mul(lambda0(), lambda1(5)) 2537*da0073e9SAndroid Build Coastguard Worker 2538*da0073e9SAndroid Build Coastguard Worker backend = EagerAndRecordGraphs() 2539*da0073e9SAndroid Build Coastguard Worker cnts = CompileCounterWithBackend(backend) 2540*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2541*da0073e9SAndroid Build Coastguard Worker dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul2, x) 2542*da0073e9SAndroid Build Coastguard Worker 2543*da0073e9SAndroid Build Coastguard Worker eager_result = fn(udf_mul2, x) 2544*da0073e9SAndroid Build Coastguard Worker gm = backend.graphs[0] 2545*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result, dynamo_result) 2546*da0073e9SAndroid Build Coastguard Worker if torch._dynamo.config.assume_static_by_default: 2547*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2548*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2549*da0073e9SAndroid Build Coastguard Worker """\ 2550*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2551*da0073e9SAndroid Build Coastguard Worker def forward(self, L_x_: "f32[2, 2]"): 2552*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 2553*da0073e9SAndroid Build Coastguard Worker 2554*da0073e9SAndroid Build Coastguard Worker mul: "f32[2, 2]" = l_x_ * 4 2555*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[2, 2]" = mul * l_x_; mul = None 2556*da0073e9SAndroid Build Coastguard Worker mul_2: "f32[2, 2]" = 20 * l_x_; l_x_ = None 2557*da0073e9SAndroid Build Coastguard Worker 2558*da0073e9SAndroid Build Coastguard Worker mul_3: "f32[2, 2]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None 2559*da0073e9SAndroid Build Coastguard Worker return (mul_3,) 2560*da0073e9SAndroid Build Coastguard Worker""", 2561*da0073e9SAndroid Build 
Coastguard Worker ) 2562*da0073e9SAndroid Build Coastguard Worker else: 2563*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2564*da0073e9SAndroid Build Coastguard Worker normalize_gm(backend.graphs[0].print_readable(print_output=False)), 2565*da0073e9SAndroid Build Coastguard Worker """\ 2566*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 2567*da0073e9SAndroid Build Coastguard Worker def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"): 2568*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 2569*da0073e9SAndroid Build Coastguard Worker 2570*da0073e9SAndroid Build Coastguard Worker mul: "f32[s0, s0]" = l_x_ * 4 2571*da0073e9SAndroid Build Coastguard Worker mul_1: "f32[s0, s0]" = mul * l_x_; mul = None 2572*da0073e9SAndroid Build Coastguard Worker mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None 2573*da0073e9SAndroid Build Coastguard Worker 2574*da0073e9SAndroid Build Coastguard Worker mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None 2575*da0073e9SAndroid Build Coastguard Worker return (mul_3,) 2576*da0073e9SAndroid Build Coastguard Worker""", 2577*da0073e9SAndroid Build Coastguard Worker ) 2578*da0073e9SAndroid Build Coastguard Worker 2579*da0073e9SAndroid Build Coastguard Worker def test_partials_recompilation(self): 2580*da0073e9SAndroid Build Coastguard Worker def fn(f0, f1, x): 2581*da0073e9SAndroid Build Coastguard Worker return f0(x) * f1(x) 2582*da0073e9SAndroid Build Coastguard Worker 2583*da0073e9SAndroid Build Coastguard Worker lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2584*da0073e9SAndroid Build Coastguard Worker lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2)) 2585*da0073e9SAndroid Build Coastguard Worker 2586*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2587*da0073e9SAndroid Build Coastguard Worker 2588*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2589*da0073e9SAndroid Build Coastguard Worker fn = 
torch._dynamo.optimize(cnts, nopython=True)(fn) 2590*da0073e9SAndroid Build Coastguard Worker dynamo_result = fn(lambda0, lambda1, x) 2591*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2592*da0073e9SAndroid Build Coastguard Worker 2593*da0073e9SAndroid Build Coastguard Worker fn(lambda1, lambda0, x) 2594*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2595*da0073e9SAndroid Build Coastguard Worker cnts.frame_count, 1 2596*da0073e9SAndroid Build Coastguard Worker ) # No recompile! Tensor and udf_mul guarded 2597*da0073e9SAndroid Build Coastguard Worker 2598*da0073e9SAndroid Build Coastguard Worker lambda2 = functools.partial(udf_mul, y=torch.randn(3, 3)) 2599*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 2600*da0073e9SAndroid Build Coastguard Worker fn(lambda2, lambda2, x) 2601*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) # Recompile! Tensor size changed 2602*da0073e9SAndroid Build Coastguard Worker 2603*da0073e9SAndroid Build Coastguard Worker multiply = lambda x, y: x * y 2604*da0073e9SAndroid Build Coastguard Worker lambda3 = functools.partial(multiply, y=torch.randn(3, 3)) 2605*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 2606*da0073e9SAndroid Build Coastguard Worker fn(lambda3, lambda3, x) 2607*da0073e9SAndroid Build Coastguard Worker 2608*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) # Recompile! 
func id changed 2609*da0073e9SAndroid Build Coastguard Worker 2610*da0073e9SAndroid Build Coastguard Worker def fn2(f0, f1, args): 2611*da0073e9SAndroid Build Coastguard Worker return f0(*args) * f1(*args) 2612*da0073e9SAndroid Build Coastguard Worker 2613*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2614*da0073e9SAndroid Build Coastguard Worker 2615*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 2616*da0073e9SAndroid Build Coastguard Worker fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2) 2617*da0073e9SAndroid Build Coastguard Worker dynamo_result = fn2(lambda0, lambda1, [x]) 2618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) # start over 2619*da0073e9SAndroid Build Coastguard Worker 2620*da0073e9SAndroid Build Coastguard Worker lambda4 = functools.partial(multiply, y=3, x=torch.randn(3, 3)) 2621*da0073e9SAndroid Build Coastguard Worker fn2(lambda4, lambda4, []) 2622*da0073e9SAndroid Build Coastguard Worker 2623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) # Recompile! Different kwarg keys 2624*da0073e9SAndroid Build Coastguard Worker 2625*da0073e9SAndroid Build Coastguard Worker lambda5 = functools.partial(multiply, 1) 2626*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 2627*da0073e9SAndroid Build Coastguard Worker fn2(lambda5, lambda5, [x]) 2628*da0073e9SAndroid Build Coastguard Worker 2629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) # Recompile! Different arg keys 2630*da0073e9SAndroid Build Coastguard Worker 2631*da0073e9SAndroid Build Coastguard Worker lambda6 = lambda x: x + x 2632*da0073e9SAndroid Build Coastguard Worker fn2(lambda6, lambda6, [x]) 2633*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2634*da0073e9SAndroid Build Coastguard Worker cnts.frame_count, 4 2635*da0073e9SAndroid Build Coastguard Worker ) # Recompile! 
input is no longer a functools partial 2636*da0073e9SAndroid Build Coastguard Worker 2637*da0073e9SAndroid Build Coastguard Worker def test_manual_seed(self): 2638*da0073e9SAndroid Build Coastguard Worker @torch.compile 2639*da0073e9SAndroid Build Coastguard Worker def foo(): 2640*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(3) 2641*da0073e9SAndroid Build Coastguard Worker return torch.randint(0, 5, (5,)) 2642*da0073e9SAndroid Build Coastguard Worker 2643*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), foo()) 2644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), foo()) 2645*da0073e9SAndroid Build Coastguard Worker 2646*da0073e9SAndroid Build Coastguard Worker def test_partial_across_graph_break_uninvoked(self): 2647*da0073e9SAndroid Build Coastguard Worker from functools import partial 2648*da0073e9SAndroid Build Coastguard Worker 2649*da0073e9SAndroid Build Coastguard Worker def bar(x, **kwargs): 2650*da0073e9SAndroid Build Coastguard Worker return x + x 2651*da0073e9SAndroid Build Coastguard Worker 2652*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", dynamic=True) 2653*da0073e9SAndroid Build Coastguard Worker def foo(x, i): 2654*da0073e9SAndroid Build Coastguard Worker def inner(): 2655*da0073e9SAndroid Build Coastguard Worker print("this is a graph_break") 2656*da0073e9SAndroid Build Coastguard Worker return op(x) 2657*da0073e9SAndroid Build Coastguard Worker 2658*da0073e9SAndroid Build Coastguard Worker op = partial(bar, dim=10) 2659*da0073e9SAndroid Build Coastguard Worker x = inner() 2660*da0073e9SAndroid Build Coastguard Worker op = partial(bar, other=10) 2661*da0073e9SAndroid Build Coastguard Worker return inner() + x 2662*da0073e9SAndroid Build Coastguard Worker 2663*da0073e9SAndroid Build Coastguard Worker foo(torch.rand(1), 10) 2664*da0073e9SAndroid Build Coastguard Worker 2665*da0073e9SAndroid Build Coastguard Worker def test_no_recompile_inner_function(self): 2666*da0073e9SAndroid 
Build Coastguard Worker def forward(inp): 2667*da0073e9SAndroid Build Coastguard Worker def g(y): 2668*da0073e9SAndroid Build Coastguard Worker return inp + y 2669*da0073e9SAndroid Build Coastguard Worker 2670*da0073e9SAndroid Build Coastguard Worker print("graph break") 2671*da0073e9SAndroid Build Coastguard Worker return g(torch.rand([1])) 2672*da0073e9SAndroid Build Coastguard Worker 2673*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2674*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(forward) 2675*da0073e9SAndroid Build Coastguard Worker 2676*da0073e9SAndroid Build Coastguard Worker input = torch.rand([2]) 2677*da0073e9SAndroid Build Coastguard Worker _ = opt_fn(input) 2678*da0073e9SAndroid Build Coastguard Worker _ = opt_fn(input) 2679*da0073e9SAndroid Build Coastguard Worker _ = opt_fn(input) 2680*da0073e9SAndroid Build Coastguard Worker # Should not have recompiled 2681*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2682*da0073e9SAndroid Build Coastguard Worker 2683*da0073e9SAndroid Build Coastguard Worker def test_no_recompile_inner_lambda(self): 2684*da0073e9SAndroid Build Coastguard Worker def forward(inp): 2685*da0073e9SAndroid Build Coastguard Worker g = lambda y: inp + y 2686*da0073e9SAndroid Build Coastguard Worker print("graph break") 2687*da0073e9SAndroid Build Coastguard Worker return g(torch.rand([1])) 2688*da0073e9SAndroid Build Coastguard Worker 2689*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 2690*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts)(forward) 2691*da0073e9SAndroid Build Coastguard Worker 2692*da0073e9SAndroid Build Coastguard Worker input = torch.rand([2]) 2693*da0073e9SAndroid Build Coastguard Worker _ = opt_fn(input) 2694*da0073e9SAndroid Build Coastguard Worker _ = opt_fn(input) 2695*da0073e9SAndroid Build Coastguard Worker _ = opt_fn(input) 
2696*da0073e9SAndroid Build Coastguard Worker # Should not have recompiled 2697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 2698*da0073e9SAndroid Build Coastguard Worker 2699*da0073e9SAndroid Build Coastguard Worker def test_complex_closure(self): 2700*da0073e9SAndroid Build Coastguard Worker @torch.compile 2701*da0073e9SAndroid Build Coastguard Worker def forward(y): 2702*da0073e9SAndroid Build Coastguard Worker def a(): 2703*da0073e9SAndroid Build Coastguard Worker def x(z): 2704*da0073e9SAndroid Build Coastguard Worker return y + z 2705*da0073e9SAndroid Build Coastguard Worker 2706*da0073e9SAndroid Build Coastguard Worker return x 2707*da0073e9SAndroid Build Coastguard Worker 2708*da0073e9SAndroid Build Coastguard Worker return a() 2709*da0073e9SAndroid Build Coastguard Worker 2710*da0073e9SAndroid Build Coastguard Worker input1 = torch.rand([2]) 2711*da0073e9SAndroid Build Coastguard Worker input2 = torch.rand([2]) 2712*da0073e9SAndroid Build Coastguard Worker res = forward(input1)(input2) 2713*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(res, input1 + input2)) 2714*da0073e9SAndroid Build Coastguard Worker 2715*da0073e9SAndroid Build Coastguard Worker def test_non_inlined_closure(self): 2716*da0073e9SAndroid Build Coastguard Worker @torch.compile() 2717*da0073e9SAndroid Build Coastguard Worker def program(x, y): 2718*da0073e9SAndroid Build Coastguard Worker one = lambda x, y: x + y 2719*da0073e9SAndroid Build Coastguard Worker 2720*da0073e9SAndroid Build Coastguard Worker def inner(): 2721*da0073e9SAndroid Build Coastguard Worker # Force no inlining 2722*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 2723*da0073e9SAndroid Build Coastguard Worker return one(x, y) 2724*da0073e9SAndroid Build Coastguard Worker 2725*da0073e9SAndroid Build Coastguard Worker res = inner() 2726*da0073e9SAndroid Build Coastguard Worker one = lambda x, y: x - y 2727*da0073e9SAndroid Build Coastguard Worker res += 
inner() 2728*da0073e9SAndroid Build Coastguard Worker return res 2729*da0073e9SAndroid Build Coastguard Worker 2730*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(1) 2731*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(1) 2732*da0073e9SAndroid Build Coastguard Worker 2733*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(program(input1, input2), input1 + input1)) 2734*da0073e9SAndroid Build Coastguard Worker 2735*da0073e9SAndroid Build Coastguard Worker @parametrize("int_or_float", ("int", "float")) 2736*da0073e9SAndroid Build Coastguard Worker def test_np_constant_collections_as_input(self, int_or_float): 2737*da0073e9SAndroid Build Coastguard Worker info_func = getattr(np, f"{int_or_float[0]}info") 2738*da0073e9SAndroid Build Coastguard Worker dt_string_arg = f"{int_or_float}16" 2739*da0073e9SAndroid Build Coastguard Worker np_dt_attr = getattr(np, dt_string_arg) 2740*da0073e9SAndroid Build Coastguard Worker 2741*da0073e9SAndroid Build Coastguard Worker dt_args = [dt_string_arg, np_dt_attr] 2742*da0073e9SAndroid Build Coastguard Worker arg_variants_iter = itertools.chain( 2743*da0073e9SAndroid Build Coastguard Worker dt_args, map(np.dtype, dt_args), map(info_func, dt_args) 2744*da0073e9SAndroid Build Coastguard Worker ) 2745*da0073e9SAndroid Build Coastguard Worker 2746*da0073e9SAndroid Build Coastguard Worker def func(a, b, info_or_dt): 2747*da0073e9SAndroid Build Coastguard Worker return a + info_func(info_or_dt).max 2748*da0073e9SAndroid Build Coastguard Worker 2749*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(func) 2750*da0073e9SAndroid Build Coastguard Worker 2751*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2) 2752*da0073e9SAndroid Build Coastguard Worker b = torch.randn(2) 2753*da0073e9SAndroid Build Coastguard Worker eager_result = func(a, b, dt_args[0]) 2754*da0073e9SAndroid Build Coastguard Worker 2755*da0073e9SAndroid Build Coastguard Worker for arg in arg_variants_iter: 
2756*da0073e9SAndroid Build Coastguard Worker opt_result = opt_fn(a, b, arg) 2757*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(opt_result, eager_result)) 2758*da0073e9SAndroid Build Coastguard Worker 2759*da0073e9SAndroid Build Coastguard Worker @parametrize( 2760*da0073e9SAndroid Build Coastguard Worker "typ, info_func", 2761*da0073e9SAndroid Build Coastguard Worker [ 2762*da0073e9SAndroid Build Coastguard Worker (int, np.iinfo), 2763*da0073e9SAndroid Build Coastguard Worker (float, np.finfo), 2764*da0073e9SAndroid Build Coastguard Worker ], 2765*da0073e9SAndroid Build Coastguard Worker name_fn=lambda t, _: t.__name__, 2766*da0073e9SAndroid Build Coastguard Worker ) 2767*da0073e9SAndroid Build Coastguard Worker def test_np_constant_collections_guards(self, typ, info_func): 2768*da0073e9SAndroid Build Coastguard Worker def func_info(a, info): 2769*da0073e9SAndroid Build Coastguard Worker return a + info.max 2770*da0073e9SAndroid Build Coastguard Worker 2771*da0073e9SAndroid Build Coastguard Worker def func_dtype(a, dt): 2772*da0073e9SAndroid Build Coastguard Worker return a + info_func(dt).max 2773*da0073e9SAndroid Build Coastguard Worker 2774*da0073e9SAndroid Build Coastguard Worker dt_args = [ 2775*da0073e9SAndroid Build Coastguard Worker np.dtype(typ), 2776*da0073e9SAndroid Build Coastguard Worker np.ones((1,), dtype=typ).dtype, 2777*da0073e9SAndroid Build Coastguard Worker np.dtype(np.dtype(typ).name), 2778*da0073e9SAndroid Build Coastguard Worker np.dtype(typ.__name__), 2779*da0073e9SAndroid Build Coastguard Worker ] 2780*da0073e9SAndroid Build Coastguard Worker cnts_1 = torch._dynamo.testing.CompileCounter() 2781*da0073e9SAndroid Build Coastguard Worker opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype) 2782*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(3, dtype=typ) 2783*da0073e9SAndroid Build Coastguard Worker for arg in dt_args: 2784*da0073e9SAndroid Build Coastguard Worker r = opt_fn_dtype(a, arg) 2785*da0073e9SAndroid 
Build Coastguard Worker # each should produce an identical arg 2786*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts_1.frame_count, 1) 2787*da0073e9SAndroid Build Coastguard Worker 2788*da0073e9SAndroid Build Coastguard Worker cnts_2 = torch._dynamo.testing.CompileCounter() 2789*da0073e9SAndroid Build Coastguard Worker opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info) 2790*da0073e9SAndroid Build Coastguard Worker info_args = [info_func(dt) for dt in dt_args] 2791*da0073e9SAndroid Build Coastguard Worker for arg in info_args: 2792*da0073e9SAndroid Build Coastguard Worker r = opt_fn_info(a, arg) 2793*da0073e9SAndroid Build Coastguard Worker 2794*da0073e9SAndroid Build Coastguard Worker # each should produce an identical arg 2795*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts_2.frame_count, 1) 2796*da0073e9SAndroid Build Coastguard Worker 2797*da0073e9SAndroid Build Coastguard Worker if typ is float: 2798*da0073e9SAndroid Build Coastguard Worker dt_extra = np.dtype(np.float16) 2799*da0073e9SAndroid Build Coastguard Worker else: 2800*da0073e9SAndroid Build Coastguard Worker dt_extra = np.dtype(np.int16) 2801*da0073e9SAndroid Build Coastguard Worker info_extra = info_func(dt_extra) 2802*da0073e9SAndroid Build Coastguard Worker 2803*da0073e9SAndroid Build Coastguard Worker eager_result_dtype = func_dtype(a, dt_extra) 2804*da0073e9SAndroid Build Coastguard Worker compile_result_dtype = opt_fn_dtype(a, dt_extra) 2805*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts_1.frame_count, 2) 2806*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result_dtype, compile_result_dtype) 2807*da0073e9SAndroid Build Coastguard Worker 2808*da0073e9SAndroid Build Coastguard Worker eager_result_info = func_info(a, info_extra) 2809*da0073e9SAndroid Build Coastguard Worker compile_result_info = opt_fn_info(a, info_extra) 2810*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts_2.frame_count, 2) 2811*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(eager_result_info, compile_result_info) 2812*da0073e9SAndroid Build Coastguard Worker 2813*da0073e9SAndroid Build Coastguard Worker def test_compare_constant_and_tensor(self): 2814*da0073e9SAndroid Build Coastguard Worker for op in [ 2815*da0073e9SAndroid Build Coastguard Worker operator.lt, 2816*da0073e9SAndroid Build Coastguard Worker operator.le, 2817*da0073e9SAndroid Build Coastguard Worker operator.gt, 2818*da0073e9SAndroid Build Coastguard Worker operator.ge, 2819*da0073e9SAndroid Build Coastguard Worker operator.ne, 2820*da0073e9SAndroid Build Coastguard Worker operator.eq, 2821*da0073e9SAndroid Build Coastguard Worker operator.is_, 2822*da0073e9SAndroid Build Coastguard Worker operator.is_not, 2823*da0073e9SAndroid Build Coastguard Worker ]: 2824*da0073e9SAndroid Build Coastguard Worker with self.subTest(op=op): 2825*da0073e9SAndroid Build Coastguard Worker 2826*da0073e9SAndroid Build Coastguard Worker def fn(x): 2827*da0073e9SAndroid Build Coastguard Worker return op(-10, x) 2828*da0073e9SAndroid Build Coastguard Worker 2829*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True)(fn) 2830*da0073e9SAndroid Build Coastguard Worker 2831*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 2832*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x), fn(x)) 2833*da0073e9SAndroid Build Coastguard Worker 2834*da0073e9SAndroid Build Coastguard Worker def test_pos(self): 2835*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2836*da0073e9SAndroid Build Coastguard Worker return operator.pos(x) * +y 2837*da0073e9SAndroid Build Coastguard Worker 2838*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True, dynamic=True)(fn) 2839*da0073e9SAndroid Build Coastguard Worker 2840*da0073e9SAndroid Build Coastguard Worker def test(x, y): 2841*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, y), fn(x, y)) 2842*da0073e9SAndroid Build Coastguard 
Worker 2843*da0073e9SAndroid Build Coastguard Worker test(torch.ones(4), 1) 2844*da0073e9SAndroid Build Coastguard Worker test(1, torch.ones(4)) 2845*da0073e9SAndroid Build Coastguard Worker test(-1, -1) 2846*da0073e9SAndroid Build Coastguard Worker test(-1.1, 1.1) 2847*da0073e9SAndroid Build Coastguard Worker test(True, False) 2848*da0073e9SAndroid Build Coastguard Worker test(torch.ones(4, dtype=torch.float32), 1.1) 2849*da0073e9SAndroid Build Coastguard Worker 2850*da0073e9SAndroid Build Coastguard Worker def test_index(self): 2851*da0073e9SAndroid Build Coastguard Worker def fn(x, t): 2852*da0073e9SAndroid Build Coastguard Worker v = operator.index(x) 2853*da0073e9SAndroid Build Coastguard Worker torch.mul(t, v) 2854*da0073e9SAndroid Build Coastguard Worker 2855*da0073e9SAndroid Build Coastguard Worker def test(a, b): 2856*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(a, b), fn(a, b)) 2857*da0073e9SAndroid Build Coastguard Worker 2858*da0073e9SAndroid Build Coastguard Worker for dynamic in [True, False]: 2859*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2860*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn) 2861*da0073e9SAndroid Build Coastguard Worker t = torch.ones(1) 2862*da0073e9SAndroid Build Coastguard Worker test(10, t) 2863*da0073e9SAndroid Build Coastguard Worker test(-100, t) 2864*da0073e9SAndroid Build Coastguard Worker test(10, t) 2865*da0073e9SAndroid Build Coastguard Worker test(False, t) 2866*da0073e9SAndroid Build Coastguard Worker test(True, t) 2867*da0073e9SAndroid Build Coastguard Worker 2868*da0073e9SAndroid Build Coastguard Worker def test_truth(self): 2869*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2870*da0073e9SAndroid Build Coastguard Worker return operator.truth(x) and bool(y) 2871*da0073e9SAndroid Build Coastguard Worker 2872*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True, dynamic=False)(fn) 2873*da0073e9SAndroid 
Build Coastguard Worker 2874*da0073e9SAndroid Build Coastguard Worker def test(x, y): 2875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, y), fn(x, y)) 2876*da0073e9SAndroid Build Coastguard Worker 2877*da0073e9SAndroid Build Coastguard Worker test(1, 100) 2878*da0073e9SAndroid Build Coastguard Worker test(-1.1, True) 2879*da0073e9SAndroid Build Coastguard Worker test(-1.1, 1.1) 2880*da0073e9SAndroid Build Coastguard Worker test(True, False) 2881*da0073e9SAndroid Build Coastguard Worker test(torch.ones(1), 1) 2882*da0073e9SAndroid Build Coastguard Worker test(torch.zeros(1), 1) 2883*da0073e9SAndroid Build Coastguard Worker test(torch.ones(1), torch.ones(1)) 2884*da0073e9SAndroid Build Coastguard Worker 2885*da0073e9SAndroid Build Coastguard Worker def test_unary_fold_op(self): 2886*da0073e9SAndroid Build Coastguard Worker for op in (operator.abs, abs, operator.neg, operator.pos, operator.truth): 2887*da0073e9SAndroid Build Coastguard Worker with self.subTest(op=op): 2888*da0073e9SAndroid Build Coastguard Worker 2889*da0073e9SAndroid Build Coastguard Worker def fn(): 2890*da0073e9SAndroid Build Coastguard Worker a = range(-10, 10) 2891*da0073e9SAndroid Build Coastguard Worker return list(map(op, a)) 2892*da0073e9SAndroid Build Coastguard Worker 2893*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(nopython=True)(fn) 2894*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(), fn()) 2895*da0073e9SAndroid Build Coastguard Worker 2896*da0073e9SAndroid Build Coastguard Worker def test_unary_fold_op_seq(self): 2897*da0073e9SAndroid Build Coastguard Worker for op in (operator.length_hint,): 2898*da0073e9SAndroid Build Coastguard Worker with self.subTest(op=op): 2899*da0073e9SAndroid Build Coastguard Worker 2900*da0073e9SAndroid Build Coastguard Worker def fn(): 2901*da0073e9SAndroid Build Coastguard Worker a = [tuple(range(-10, i)) for i in range(10)] 2902*da0073e9SAndroid Build Coastguard Worker return 
tuple(map(op, a)) 2903*da0073e9SAndroid Build Coastguard Worker 2904*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(nopython=True)(fn) 2905*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(), fn()) 2906*da0073e9SAndroid Build Coastguard Worker 2907*da0073e9SAndroid Build Coastguard Worker def gen_random_range_args(self): 2908*da0073e9SAndroid Build Coastguard Worker args_count = random.randint(1, 3) 2909*da0073e9SAndroid Build Coastguard Worker args = [random.randint(-10, 10) for _ in range(args_count)] 2910*da0073e9SAndroid Build Coastguard Worker if args_count == 3 and args[2] == 0: 2911*da0073e9SAndroid Build Coastguard Worker args[2] = 1 2912*da0073e9SAndroid Build Coastguard Worker return args 2913*da0073e9SAndroid Build Coastguard Worker 2914*da0073e9SAndroid Build Coastguard Worker def test_range_length(self): 2915*da0073e9SAndroid Build Coastguard Worker def test(*args, expected=None): 2916*da0073e9SAndroid Build Coastguard Worker r = range(*args) 2917*da0073e9SAndroid Build Coastguard Worker range_variable = RangeVariable([ConstantVariable.create(v) for v in args]) 2918*da0073e9SAndroid Build Coastguard Worker 2919*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(r), range_variable.range_length()) 2920*da0073e9SAndroid Build Coastguard Worker 2921*da0073e9SAndroid Build Coastguard Worker if expected is not None: 2922*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(r), expected) 2923*da0073e9SAndroid Build Coastguard Worker 2924*da0073e9SAndroid Build Coastguard Worker test(1, 1, 1, expected=0) 2925*da0073e9SAndroid Build Coastguard Worker test(1, 0, expected=0) 2926*da0073e9SAndroid Build Coastguard Worker test(-10, expected=0) 2927*da0073e9SAndroid Build Coastguard Worker 2928*da0073e9SAndroid Build Coastguard Worker test(4, expected=4) 2929*da0073e9SAndroid Build Coastguard Worker test(10, expected=10) 2930*da0073e9SAndroid Build Coastguard Worker 2931*da0073e9SAndroid Build Coastguard 
Worker # step >1 2932*da0073e9SAndroid Build Coastguard Worker test(1, 10, 2, expected=5) 2933*da0073e9SAndroid Build Coastguard Worker 2934*da0073e9SAndroid Build Coastguard Worker # negative step 2935*da0073e9SAndroid Build Coastguard Worker test(10, 1, -1, expected=9) 2936*da0073e9SAndroid Build Coastguard Worker test(10, 1, -3) 2937*da0073e9SAndroid Build Coastguard Worker 2938*da0073e9SAndroid Build Coastguard Worker # Fuzz testing 2939*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2940*da0073e9SAndroid Build Coastguard Worker args = self.gen_random_range_args() 2941*da0073e9SAndroid Build Coastguard Worker print("testing :", args) 2942*da0073e9SAndroid Build Coastguard Worker test(*args) 2943*da0073e9SAndroid Build Coastguard Worker 2944*da0073e9SAndroid Build Coastguard Worker def test_indexed_range(self): 2945*da0073e9SAndroid Build Coastguard Worker def test(range, index, expected=None): 2946*da0073e9SAndroid Build Coastguard Worker range_variable = RangeVariable( 2947*da0073e9SAndroid Build Coastguard Worker [ 2948*da0073e9SAndroid Build Coastguard Worker ConstantVariable.create(v) 2949*da0073e9SAndroid Build Coastguard Worker for v in [range.start, range.stop, range.step] 2950*da0073e9SAndroid Build Coastguard Worker ] 2951*da0073e9SAndroid Build Coastguard Worker ) 2952*da0073e9SAndroid Build Coastguard Worker 2953*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2954*da0073e9SAndroid Build Coastguard Worker range[index], 2955*da0073e9SAndroid Build Coastguard Worker range_variable.apply_index(index).as_python_constant(), 2956*da0073e9SAndroid Build Coastguard Worker ) 2957*da0073e9SAndroid Build Coastguard Worker 2958*da0073e9SAndroid Build Coastguard Worker if expected is not None: 2959*da0073e9SAndroid Build Coastguard Worker self.assertEqual(range[index], expected) 2960*da0073e9SAndroid Build Coastguard Worker 2961*da0073e9SAndroid Build Coastguard Worker test(range(10), 1, expected=1) 2962*da0073e9SAndroid Build Coastguard 
Worker test(range(10, 20, 2), 1, expected=12) 2963*da0073e9SAndroid Build Coastguard Worker 2964*da0073e9SAndroid Build Coastguard Worker # Fuzz testing 2965*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2966*da0073e9SAndroid Build Coastguard Worker range_args = self.gen_random_range_args() 2967*da0073e9SAndroid Build Coastguard Worker r = range(*range_args) 2968*da0073e9SAndroid Build Coastguard Worker 2969*da0073e9SAndroid Build Coastguard Worker if len(r) == 0: 2970*da0073e9SAndroid Build Coastguard Worker continue 2971*da0073e9SAndroid Build Coastguard Worker 2972*da0073e9SAndroid Build Coastguard Worker index = random.randint(0, len(r) - 1) 2973*da0073e9SAndroid Build Coastguard Worker 2974*da0073e9SAndroid Build Coastguard Worker print("testing:", r, index) 2975*da0073e9SAndroid Build Coastguard Worker test(r, index) 2976*da0073e9SAndroid Build Coastguard Worker 2977*da0073e9SAndroid Build Coastguard Worker def test_sliced_range(self): 2978*da0073e9SAndroid Build Coastguard Worker def test(range, slice, expected=None): 2979*da0073e9SAndroid Build Coastguard Worker range_variable = RangeVariable( 2980*da0073e9SAndroid Build Coastguard Worker [ 2981*da0073e9SAndroid Build Coastguard Worker ConstantVariable.create(v) 2982*da0073e9SAndroid Build Coastguard Worker for v in [range.start, range.stop, range.step] 2983*da0073e9SAndroid Build Coastguard Worker ] 2984*da0073e9SAndroid Build Coastguard Worker ) 2985*da0073e9SAndroid Build Coastguard Worker 2986*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2987*da0073e9SAndroid Build Coastguard Worker range[slice], 2988*da0073e9SAndroid Build Coastguard Worker range_variable.apply_slice(slice).as_python_constant(), 2989*da0073e9SAndroid Build Coastguard Worker ) 2990*da0073e9SAndroid Build Coastguard Worker 2991*da0073e9SAndroid Build Coastguard Worker if expected is not None: 2992*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2993*da0073e9SAndroid Build Coastguard Worker 
range[slice], 2994*da0073e9SAndroid Build Coastguard Worker expected, 2995*da0073e9SAndroid Build Coastguard Worker ) 2996*da0073e9SAndroid Build Coastguard Worker 2997*da0073e9SAndroid Build Coastguard Worker test(range(10), slice(1, 10, 2), expected=range(1, 10, 2)) 2998*da0073e9SAndroid Build Coastguard Worker test(range(10), slice(None, 10, None), expected=range(0, 10)) 2999*da0073e9SAndroid Build Coastguard Worker test(range(10), slice(-1, 7, None), expected=range(9, 7)) 3000*da0073e9SAndroid Build Coastguard Worker test(range(10), slice(-1, 7, 2), expected=range(9, 7, 2)) 3001*da0073e9SAndroid Build Coastguard Worker test(range(1, 10, 2), slice(3, 7, 2), expected=range(7, 11, 4)) 3002*da0073e9SAndroid Build Coastguard Worker test(range(1, 10, 2), slice(-3, 7, 2), expected=range(5, 11, 4)) 3003*da0073e9SAndroid Build Coastguard Worker test(range(-1, -5, -3), slice(5, None, -3), expected=range(-4, 2, 9)) 3004*da0073e9SAndroid Build Coastguard Worker 3005*da0073e9SAndroid Build Coastguard Worker def rand_slice(): 3006*da0073e9SAndroid Build Coastguard Worker def flip_coin(): 3007*da0073e9SAndroid Build Coastguard Worker # 1 out of 10 3008*da0073e9SAndroid Build Coastguard Worker return random.randint(1, 10) == 5 3009*da0073e9SAndroid Build Coastguard Worker 3010*da0073e9SAndroid Build Coastguard Worker def r_item(allow_zero=True): 3011*da0073e9SAndroid Build Coastguard Worker i = random.randint(-10, 10) 3012*da0073e9SAndroid Build Coastguard Worker if not allow_zero and i == 0: 3013*da0073e9SAndroid Build Coastguard Worker i = 1 3014*da0073e9SAndroid Build Coastguard Worker if flip_coin(): 3015*da0073e9SAndroid Build Coastguard Worker i = None 3016*da0073e9SAndroid Build Coastguard Worker return i 3017*da0073e9SAndroid Build Coastguard Worker 3018*da0073e9SAndroid Build Coastguard Worker arg_count = random.randint(1, 3) 3019*da0073e9SAndroid Build Coastguard Worker 3020*da0073e9SAndroid Build Coastguard Worker if arg_count == 1: 3021*da0073e9SAndroid Build 
Coastguard Worker return slice(r_item()) 3022*da0073e9SAndroid Build Coastguard Worker elif arg_count == 2: 3023*da0073e9SAndroid Build Coastguard Worker return slice(r_item(), r_item()) 3024*da0073e9SAndroid Build Coastguard Worker else: 3025*da0073e9SAndroid Build Coastguard Worker return slice(r_item(), r_item(), r_item(False)) 3026*da0073e9SAndroid Build Coastguard Worker 3027*da0073e9SAndroid Build Coastguard Worker # Fuzz testing 3028*da0073e9SAndroid Build Coastguard Worker for i in range(100): 3029*da0073e9SAndroid Build Coastguard Worker range_args = self.gen_random_range_args() 3030*da0073e9SAndroid Build Coastguard Worker r = range(*range_args) 3031*da0073e9SAndroid Build Coastguard Worker # generate random slice 3032*da0073e9SAndroid Build Coastguard Worker s = rand_slice() 3033*da0073e9SAndroid Build Coastguard Worker 3034*da0073e9SAndroid Build Coastguard Worker print("testing:", r, s) 3035*da0073e9SAndroid Build Coastguard Worker test(r, s) 3036*da0073e9SAndroid Build Coastguard Worker 3037*da0073e9SAndroid Build Coastguard Worker def test_range_with_slice_index(self): 3038*da0073e9SAndroid Build Coastguard Worker def fn(x): 3039*da0073e9SAndroid Build Coastguard Worker acc = 1 3040*da0073e9SAndroid Build Coastguard Worker for k in range(2)[1::2]: 3041*da0073e9SAndroid Build Coastguard Worker acc *= acc * k 3042*da0073e9SAndroid Build Coastguard Worker return x * acc 3043*da0073e9SAndroid Build Coastguard Worker 3044*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True)(fn) 3045*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 3046*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x), fn(x)) 3047*da0073e9SAndroid Build Coastguard Worker 3048*da0073e9SAndroid Build Coastguard Worker def test_range_with_index(self): 3049*da0073e9SAndroid Build Coastguard Worker def fn(x): 3050*da0073e9SAndroid Build Coastguard Worker acc = 1 3051*da0073e9SAndroid Build Coastguard Worker acc *= acc * range(10, 20, 2)[2] 
3052*da0073e9SAndroid Build Coastguard Worker return x * acc 3053*da0073e9SAndroid Build Coastguard Worker 3054*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fullgraph=True)(fn) 3055*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 3056*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x), fn(x)) 3057*da0073e9SAndroid Build Coastguard Worker 3058*da0073e9SAndroid Build Coastguard Worker def test_rand_inlined(self): 3059*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", dynamic=True) 3060*da0073e9SAndroid Build Coastguard Worker def fn(): 3061*da0073e9SAndroid Build Coastguard Worker idx_size = [10] 3062*da0073e9SAndroid Build Coastguard Worker idx_size[random.randint(0, 0)] = random.randint(1, 8) 3063*da0073e9SAndroid Build Coastguard Worker t = tuple(idx_size) 3064*da0073e9SAndroid Build Coastguard Worker src_size = [random.randint(1, 5) + s for s in idx_size] 3065*da0073e9SAndroid Build Coastguard Worker idx = torch.empty(t) 3066*da0073e9SAndroid Build Coastguard Worker 3067*da0073e9SAndroid Build Coastguard Worker fn() 3068*da0073e9SAndroid Build Coastguard Worker 3069*da0073e9SAndroid Build Coastguard Worker def test_rand_tensor_partial(self): 3070*da0073e9SAndroid Build Coastguard Worker from collections import namedtuple 3071*da0073e9SAndroid Build Coastguard Worker from functools import partial 3072*da0073e9SAndroid Build Coastguard Worker 3073*da0073e9SAndroid Build Coastguard Worker SdpaShape = namedtuple( 3074*da0073e9SAndroid Build Coastguard Worker "Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"] 3075*da0073e9SAndroid Build Coastguard Worker ) 3076*da0073e9SAndroid Build Coastguard Worker 3077*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 3078*da0073e9SAndroid Build Coastguard Worker def func(): 3079*da0073e9SAndroid Build Coastguard Worker make_tensor = partial( 3080*da0073e9SAndroid Build Coastguard Worker torch.rand, device="cpu", 
dtype=torch.float16, requires_grad=True 3081*da0073e9SAndroid Build Coastguard Worker ) 3082*da0073e9SAndroid Build Coastguard Worker 3083*da0073e9SAndroid Build Coastguard Worker bsz, num_heads, seq_len_q, seq_len_kv, head_dim = (16, 16, 128, 128, 16) 3084*da0073e9SAndroid Build Coastguard Worker make_q_tensor = partial( 3085*da0073e9SAndroid Build Coastguard Worker make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim) 3086*da0073e9SAndroid Build Coastguard Worker ) 3087*da0073e9SAndroid Build Coastguard Worker make_kv_tensor = partial( 3088*da0073e9SAndroid Build Coastguard Worker make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim) 3089*da0073e9SAndroid Build Coastguard Worker ) 3090*da0073e9SAndroid Build Coastguard Worker t1 = make_q_tensor() 3091*da0073e9SAndroid Build Coastguard Worker t2 = make_kv_tensor() 3092*da0073e9SAndroid Build Coastguard Worker t3 = t1 + t2 3093*da0073e9SAndroid Build Coastguard Worker 3094*da0073e9SAndroid Build Coastguard Worker func() 3095*da0073e9SAndroid Build Coastguard Worker 3096*da0073e9SAndroid Build Coastguard Worker def test_to(self): 3097*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 3098*da0073e9SAndroid Build Coastguard Worker def fn(): 3099*da0073e9SAndroid Build Coastguard Worker t = torch.ones(2) 3100*da0073e9SAndroid Build Coastguard Worker y = t.to("meta") 3101*da0073e9SAndroid Build Coastguard Worker 3102*da0073e9SAndroid Build Coastguard Worker fn() 3103*da0073e9SAndroid Build Coastguard Worker 3104*da0073e9SAndroid Build Coastguard Worker def test_elipsis(self): 3105*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 3106*da0073e9SAndroid Build Coastguard Worker def fn(a, ind, val): 3107*da0073e9SAndroid Build Coastguard Worker a[ind] = val 3108*da0073e9SAndroid Build Coastguard Worker return a 3109*da0073e9SAndroid Build Coastguard Worker 3110*da0073e9SAndroid Build Coastguard Worker arr = np.zeros(4) 3111*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(fn(arr, np.s_[...], np.ones(4)), np.ones(4)) 3112*da0073e9SAndroid Build Coastguard Worker 3113*da0073e9SAndroid Build Coastguard Worker arr = np.array([[1, 1], [2, 2]]) 3114*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3115*da0073e9SAndroid Build Coastguard Worker fn(arr, np.s_[0, ...], np.zeros(2)), np.array([[0, 0], [2, 2]]) 3116*da0073e9SAndroid Build Coastguard Worker ) 3117*da0073e9SAndroid Build Coastguard Worker 3118*da0073e9SAndroid Build Coastguard Worker arr = np.array([[1, 1], [2, 2]]) 3119*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3120*da0073e9SAndroid Build Coastguard Worker fn(arr, np.s_[1, ...], np.zeros(2)), np.array([[1, 1], [0, 0]]) 3121*da0073e9SAndroid Build Coastguard Worker ) 3122*da0073e9SAndroid Build Coastguard Worker 3123*da0073e9SAndroid Build Coastguard Worker arr = np.array([[1, 1], [2, 2]]) 3124*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3125*da0073e9SAndroid Build Coastguard Worker fn(arr, np.s_[..., 0], np.array([3, 3])), np.array([[3, 1], [3, 2]]) 3126*da0073e9SAndroid Build Coastguard Worker ) 3127*da0073e9SAndroid Build Coastguard Worker 3128*da0073e9SAndroid Build Coastguard Worker arr = np.array([[1, 1], [2, 2]]) 3129*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3130*da0073e9SAndroid Build Coastguard Worker fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) 3131*da0073e9SAndroid Build Coastguard Worker ) 3132*da0073e9SAndroid Build Coastguard Worker 3133*da0073e9SAndroid Build Coastguard Worker def test_map_return(self): 3134*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 3135*da0073e9SAndroid Build Coastguard Worker return map(lambda x: x + 1, [a, b]) 3136*da0073e9SAndroid Build Coastguard Worker 3137*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3138*da0073e9SAndroid Build Coastguard Worker m = opt_fn(torch.randn(3, 3), torch.randn(3, 3)) 
3139*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(m, map) 3140*da0073e9SAndroid Build Coastguard Worker 3141*da0073e9SAndroid Build Coastguard Worker @make_test 3142*da0073e9SAndroid Build Coastguard Worker def test_map_max(a, b): 3143*da0073e9SAndroid Build Coastguard Worker return max(map(lambda x: x.sum(), [a, b])) 3144*da0073e9SAndroid Build Coastguard Worker 3145*da0073e9SAndroid Build Coastguard Worker # max(map(...)) graph breaks 3146*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 3147*da0073e9SAndroid Build Coastguard Worker @make_test 3148*da0073e9SAndroid Build Coastguard Worker def test_map_max_const(a): 3149*da0073e9SAndroid Build Coastguard Worker return max(map(lambda x: x, [1, 2, 3])), a + 1 3150*da0073e9SAndroid Build Coastguard Worker 3151*da0073e9SAndroid Build Coastguard Worker @make_test 3152*da0073e9SAndroid Build Coastguard Worker def test_map_list(a, b): 3153*da0073e9SAndroid Build Coastguard Worker return list(map(lambda x: x + 1, [a, b])) 3154*da0073e9SAndroid Build Coastguard Worker 3155*da0073e9SAndroid Build Coastguard Worker @make_test 3156*da0073e9SAndroid Build Coastguard Worker def test_map_tuple(a, b): 3157*da0073e9SAndroid Build Coastguard Worker return tuple(map(lambda x: x + 1, [a, b])) 3158*da0073e9SAndroid Build Coastguard Worker 3159*da0073e9SAndroid Build Coastguard Worker @make_test 3160*da0073e9SAndroid Build Coastguard Worker def test_map_iter(a, b): 3161*da0073e9SAndroid Build Coastguard Worker it = iter(map(lambda x: x + 1, [a, b])) 3162*da0073e9SAndroid Build Coastguard Worker return next(it) 3163*da0073e9SAndroid Build Coastguard Worker 3164*da0073e9SAndroid Build Coastguard Worker @make_test 3165*da0073e9SAndroid Build Coastguard Worker def test_map_zip_dict(a): 3166*da0073e9SAndroid Build Coastguard Worker d = dict( 3167*da0073e9SAndroid Build Coastguard Worker zip( 3168*da0073e9SAndroid Build Coastguard Worker map(lambda x: x + 1, [0, 1, 2]), 3169*da0073e9SAndroid Build 
Coastguard Worker [map(lambda x: x - 1, [y]) for y in [3, 4, 5]], 3170*da0073e9SAndroid Build Coastguard Worker ) 3171*da0073e9SAndroid Build Coastguard Worker ) 3172*da0073e9SAndroid Build Coastguard Worker return list(d[3])[0], a + 1 # noqa: RUF015 3173*da0073e9SAndroid Build Coastguard Worker 3174*da0073e9SAndroid Build Coastguard Worker @make_test 3175*da0073e9SAndroid Build Coastguard Worker def test_map_dict_fromkeys(a): 3176*da0073e9SAndroid Build Coastguard Worker return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1 3177*da0073e9SAndroid Build Coastguard Worker 3178*da0073e9SAndroid Build Coastguard Worker @make_test 3179*da0073e9SAndroid Build Coastguard Worker def test_map_set(a): 3180*da0073e9SAndroid Build Coastguard Worker return set(map(lambda x: x + 1, [0, 1])), a + 1 3181*da0073e9SAndroid Build Coastguard Worker 3182*da0073e9SAndroid Build Coastguard Worker # test_map_sum defined earlier 3183*da0073e9SAndroid Build Coastguard Worker 3184*da0073e9SAndroid Build Coastguard Worker @make_test 3185*da0073e9SAndroid Build Coastguard Worker def test_map_reduce(a, b): 3186*da0073e9SAndroid Build Coastguard Worker return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b])) 3187*da0073e9SAndroid Build Coastguard Worker 3188*da0073e9SAndroid Build Coastguard Worker @make_test 3189*da0073e9SAndroid Build Coastguard Worker def test_map_sorted(a): 3190*da0073e9SAndroid Build Coastguard Worker return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1 3191*da0073e9SAndroid Build Coastguard Worker 3192*da0073e9SAndroid Build Coastguard Worker @make_test 3193*da0073e9SAndroid Build Coastguard Worker def test_map_list_extend(a, b, c): 3194*da0073e9SAndroid Build Coastguard Worker l = [a] 3195*da0073e9SAndroid Build Coastguard Worker l.extend(map(lambda x: x + 1, [b, c])) 3196*da0073e9SAndroid Build Coastguard Worker return l 3197*da0073e9SAndroid Build Coastguard Worker 3198*da0073e9SAndroid Build Coastguard Worker @make_test 3199*da0073e9SAndroid 
Build Coastguard Worker def test_map_list_slice_assign(a, b, c, d, e): 3200*da0073e9SAndroid Build Coastguard Worker l = [a, b, c] 3201*da0073e9SAndroid Build Coastguard Worker l[1:2] = map(lambda x: x + 1, [d, e]) 3202*da0073e9SAndroid Build Coastguard Worker return l 3203*da0073e9SAndroid Build Coastguard Worker 3204*da0073e9SAndroid Build Coastguard Worker @make_test 3205*da0073e9SAndroid Build Coastguard Worker def test_map_deque_extendleft(a, b, c): 3206*da0073e9SAndroid Build Coastguard Worker d = collections.deque([a]) 3207*da0073e9SAndroid Build Coastguard Worker d.extendleft(map(lambda x: x + 1, [b, c])) 3208*da0073e9SAndroid Build Coastguard Worker return d 3209*da0073e9SAndroid Build Coastguard Worker 3210*da0073e9SAndroid Build Coastguard Worker @make_test 3211*da0073e9SAndroid Build Coastguard Worker def test_map_str_join(a): 3212*da0073e9SAndroid Build Coastguard Worker return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1 3213*da0073e9SAndroid Build Coastguard Worker 3214*da0073e9SAndroid Build Coastguard Worker def test_map_with_graph_break(self): 3215*da0073e9SAndroid Build Coastguard Worker def f(a): 3216*da0073e9SAndroid Build Coastguard Worker a += 1 3217*da0073e9SAndroid Build Coastguard Worker 3218*da0073e9SAndroid Build Coastguard Worker def g(x): 3219*da0073e9SAndroid Build Coastguard Worker nonlocal a 3220*da0073e9SAndroid Build Coastguard Worker a += 1 3221*da0073e9SAndroid Build Coastguard Worker return x + 1 3222*da0073e9SAndroid Build Coastguard Worker 3223*da0073e9SAndroid Build Coastguard Worker m = map(g, [1, 2, 3, 4, 5]) 3224*da0073e9SAndroid Build Coastguard Worker a += next(m) # won't graph break 3225*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 3226*da0073e9SAndroid Build Coastguard Worker a += next(m) # will graph break 3227*da0073e9SAndroid Build Coastguard Worker return a 3228*da0073e9SAndroid Build Coastguard Worker 3229*da0073e9SAndroid Build Coastguard Worker cnts = 
torch._dynamo.testing.CompileCounter() 3230*da0073e9SAndroid Build Coastguard Worker opt_f = torch.compile(f, backend=cnts) 3231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) 3232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) 3233*da0073e9SAndroid Build Coastguard Worker 3234*da0073e9SAndroid Build Coastguard Worker def test_map_reconstruct(self): 3235*da0073e9SAndroid Build Coastguard Worker def fn(a): 3236*da0073e9SAndroid Build Coastguard Worker return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1 3237*da0073e9SAndroid Build Coastguard Worker 3238*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3239*da0073e9SAndroid Build Coastguard Worker m = opt_fn(torch.ones(3, 3))[0] 3240*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(m, map) 3241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) 3242*da0073e9SAndroid Build Coastguard Worker 3243*da0073e9SAndroid Build Coastguard Worker def test_zip_reconstruct(self): 3244*da0073e9SAndroid Build Coastguard Worker def fn(a): 3245*da0073e9SAndroid Build Coastguard Worker return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1 3246*da0073e9SAndroid Build Coastguard Worker 3247*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3248*da0073e9SAndroid Build Coastguard Worker m = opt_fn(torch.ones(3, 3))[0] 3249*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(m, zip) 3250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) 3251*da0073e9SAndroid Build Coastguard Worker 3252*da0073e9SAndroid Build Coastguard Worker @make_test 3253*da0073e9SAndroid Build Coastguard Worker def test_map_partial_unpack(a, b): 3254*da0073e9SAndroid Build Coastguard Worker y = 1 3255*da0073e9SAndroid Build Coastguard 
Worker 3256*da0073e9SAndroid Build Coastguard Worker def f(x): 3257*da0073e9SAndroid Build Coastguard Worker nonlocal y 3258*da0073e9SAndroid Build Coastguard Worker y += 1 3259*da0073e9SAndroid Build Coastguard Worker return x 3260*da0073e9SAndroid Build Coastguard Worker 3261*da0073e9SAndroid Build Coastguard Worker l = list(zip([a, b], map(f, [1, 2, 3, 4]))) 3262*da0073e9SAndroid Build Coastguard Worker return a + y 3263*da0073e9SAndroid Build Coastguard Worker 3264*da0073e9SAndroid Build Coastguard Worker @make_test 3265*da0073e9SAndroid Build Coastguard Worker def test_map_call_function_ex(a, b): 3266*da0073e9SAndroid Build Coastguard Worker def f(x, y): 3267*da0073e9SAndroid Build Coastguard Worker return x + y 3268*da0073e9SAndroid Build Coastguard Worker 3269*da0073e9SAndroid Build Coastguard Worker return f(*map(lambda x: x + 1, [a, b])) 3270*da0073e9SAndroid Build Coastguard Worker 3271*da0073e9SAndroid Build Coastguard Worker @make_test 3272*da0073e9SAndroid Build Coastguard Worker def test_map_unpack_twice(a, b): 3273*da0073e9SAndroid Build Coastguard Worker m = map(lambda x: x + 1, [a, b]) 3274*da0073e9SAndroid Build Coastguard Worker l1 = list(m) 3275*da0073e9SAndroid Build Coastguard Worker l2 = list(m) 3276*da0073e9SAndroid Build Coastguard Worker return l1, l2 3277*da0073e9SAndroid Build Coastguard Worker 3278*da0073e9SAndroid Build Coastguard Worker @make_test 3279*da0073e9SAndroid Build Coastguard Worker def test_enumerate(a, b): 3280*da0073e9SAndroid Build Coastguard Worker return list(enumerate([a, b], start=1)), a + 1 3281*da0073e9SAndroid Build Coastguard Worker 3282*da0073e9SAndroid Build Coastguard Worker @make_test 3283*da0073e9SAndroid Build Coastguard Worker def test_map_enumerate(a, b): 3284*da0073e9SAndroid Build Coastguard Worker return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1 3285*da0073e9SAndroid Build Coastguard Worker 3286*da0073e9SAndroid Build Coastguard Worker @make_test 3287*da0073e9SAndroid Build 
Coastguard Worker def test_map_infinite(a, b): 3288*da0073e9SAndroid Build Coastguard Worker return list(map(lambda x, y: x + y, [a, b], itertools.count(3))) 3289*da0073e9SAndroid Build Coastguard Worker 3290*da0073e9SAndroid Build Coastguard Worker @make_test 3291*da0073e9SAndroid Build Coastguard Worker def test_map_unpack_vars(a, b): 3292*da0073e9SAndroid Build Coastguard Worker x, y = map(lambda x: x + 1, [a, b]) 3293*da0073e9SAndroid Build Coastguard Worker return x + y 3294*da0073e9SAndroid Build Coastguard Worker 3295*da0073e9SAndroid Build Coastguard Worker def test_enumerate_custom(self): 3296*da0073e9SAndroid Build Coastguard Worker class MyClass: 3297*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 3298*da0073e9SAndroid Build Coastguard Worker self.a = 1 3299*da0073e9SAndroid Build Coastguard Worker return self 3300*da0073e9SAndroid Build Coastguard Worker 3301*da0073e9SAndroid Build Coastguard Worker def __next__(self): 3302*da0073e9SAndroid Build Coastguard Worker if self.a > 3: 3303*da0073e9SAndroid Build Coastguard Worker raise StopIteration 3304*da0073e9SAndroid Build Coastguard Worker self.a += 1 3305*da0073e9SAndroid Build Coastguard Worker return self.a 3306*da0073e9SAndroid Build Coastguard Worker 3307*da0073e9SAndroid Build Coastguard Worker def fn(x): 3308*da0073e9SAndroid Build Coastguard Worker for i, it in enumerate(MyClass()): 3309*da0073e9SAndroid Build Coastguard Worker x += i + it 3310*da0073e9SAndroid Build Coastguard Worker return x 3311*da0073e9SAndroid Build Coastguard Worker 3312*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3313*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3))) 3314*da0073e9SAndroid Build Coastguard Worker 3315*da0073e9SAndroid Build Coastguard Worker def test_enumerate_reconstruct(self): 3316*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 3317*da0073e9SAndroid Build Coastguard Worker 
return enumerate([a, b], start=1) 3318*da0073e9SAndroid Build Coastguard Worker 3319*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3320*da0073e9SAndroid Build Coastguard Worker inps = (torch.randn(3, 3), torch.randn(3, 3)) 3321*da0073e9SAndroid Build Coastguard Worker it1 = fn(*inps) 3322*da0073e9SAndroid Build Coastguard Worker it2 = opt_fn(*inps) 3323*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(it2, enumerate) 3324*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(it1), list(it2)) 3325*da0073e9SAndroid Build Coastguard Worker 3326*da0073e9SAndroid Build Coastguard Worker 3327*da0073e9SAndroid Build Coastguard Workerdef udf_mul(x, y): 3328*da0073e9SAndroid Build Coastguard Worker return x * y 3329*da0073e9SAndroid Build Coastguard Worker 3330*da0073e9SAndroid Build Coastguard Worker 3331*da0073e9SAndroid Build Coastguard Workerdef udf_mul2(x, y, z): 3332*da0073e9SAndroid Build Coastguard Worker return x * y * z 3333*da0073e9SAndroid Build Coastguard Worker 3334*da0073e9SAndroid Build Coastguard Worker 3335*da0073e9SAndroid Build Coastguard Workerdef udf_add(x, y): 3336*da0073e9SAndroid Build Coastguard Worker return x + y 3337*da0073e9SAndroid Build Coastguard Worker 3338*da0073e9SAndroid Build Coastguard Worker 3339*da0073e9SAndroid Build Coastguard Workerclass SmallNN(torch.nn.Module): 3340*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 3341*da0073e9SAndroid Build Coastguard Worker combined = torch.cat((x, y), dim=1) 3342*da0073e9SAndroid Build Coastguard Worker out = torch.nn.ReLU()(combined) 3343*da0073e9SAndroid Build Coastguard Worker out = torch.nn.ReLU()(out) 3344*da0073e9SAndroid Build Coastguard Worker return out 3345*da0073e9SAndroid Build Coastguard Worker 3346*da0073e9SAndroid Build Coastguard Worker 3347*da0073e9SAndroid Build Coastguard Workerdef udf_module(mod, x, y): 3348*da0073e9SAndroid Build Coastguard Worker return mod(x, y) 
3349*da0073e9SAndroid Build Coastguard Worker 3350*da0073e9SAndroid Build Coastguard Worker 3351*da0073e9SAndroid Build Coastguard Workerdef global_func_with_default_tensor_args( 3352*da0073e9SAndroid Build Coastguard Worker x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2)) 3353*da0073e9SAndroid Build Coastguard Worker): 3354*da0073e9SAndroid Build Coastguard Worker x.add_(1) 3355*da0073e9SAndroid Build Coastguard Worker kw_x.add_(1) 3356*da0073e9SAndroid Build Coastguard Worker return x, kw_x 3357*da0073e9SAndroid Build Coastguard Worker 3358*da0073e9SAndroid Build Coastguard Worker 3359*da0073e9SAndroid Build Coastguard Workerclass ModuleWithDefaultTensorArgsMethod(torch.nn.Module): 3360*da0073e9SAndroid Build Coastguard Worker def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))): 3361*da0073e9SAndroid Build Coastguard Worker x.add_(1) 3362*da0073e9SAndroid Build Coastguard Worker kw_x.add_(1) 3363*da0073e9SAndroid Build Coastguard Worker return x, kw_x 3364*da0073e9SAndroid Build Coastguard Worker 3365*da0073e9SAndroid Build Coastguard Worker 3366*da0073e9SAndroid Build Coastguard Workerclass WrapperModule(torch.nn.Module): 3367*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3368*da0073e9SAndroid Build Coastguard Worker super().__init__() 3369*da0073e9SAndroid Build Coastguard Worker self.m = ModuleWithDefaultTensorArgsMethod() 3370*da0073e9SAndroid Build Coastguard Worker 3371*da0073e9SAndroid Build Coastguard Worker def forward(self): 3372*da0073e9SAndroid Build Coastguard Worker return self.m() 3373*da0073e9SAndroid Build Coastguard Worker 3374*da0073e9SAndroid Build Coastguard Worker 3375*da0073e9SAndroid Build Coastguard Workerclass DefaultsTests(torch._dynamo.test_case.TestCase): 3376*da0073e9SAndroid Build Coastguard Worker def test_func_default_tensor_args(self): 3377*da0073e9SAndroid Build Coastguard Worker """ 3378*da0073e9SAndroid Build Coastguard Worker Tests that we indeed reference (and mutate) "the one" 
default tensor arg 3379*da0073e9SAndroid Build Coastguard Worker stored on the globally allocated function object, both from the orig and 3380*da0073e9SAndroid Build Coastguard Worker compiled function 3381*da0073e9SAndroid Build Coastguard Worker """ 3382*da0073e9SAndroid Build Coastguard Worker 3383*da0073e9SAndroid Build Coastguard Worker def func(): 3384*da0073e9SAndroid Build Coastguard Worker return global_func_with_default_tensor_args() 3385*da0073e9SAndroid Build Coastguard Worker 3386*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3387*da0073e9SAndroid Build Coastguard Worker compiled_func = torch.compile(func, backend=cnts) 3388*da0073e9SAndroid Build Coastguard Worker for i in range(4): 3389*da0073e9SAndroid Build Coastguard Worker if i % 2 == 0: 3390*da0073e9SAndroid Build Coastguard Worker x, kw_x = func() 3391*da0073e9SAndroid Build Coastguard Worker else: 3392*da0073e9SAndroid Build Coastguard Worker x, kw_x = compiled_func() 3393*da0073e9SAndroid Build Coastguard Worker # the inner func mutates += 1 each call 3394*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(x, torch.ones_like(x) + i)) 3395*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i)) 3396*da0073e9SAndroid Build Coastguard Worker # Calling compiled_func twice does not recompile 3397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 3399*da0073e9SAndroid Build Coastguard Worker 3400*da0073e9SAndroid Build Coastguard Worker # But with a change to the guarded default tensor, we do recompile 3401*da0073e9SAndroid Build Coastguard Worker with patch.object( 3402*da0073e9SAndroid Build Coastguard Worker global_func_with_default_tensor_args, 3403*da0073e9SAndroid Build Coastguard Worker "__defaults__", 3404*da0073e9SAndroid Build Coastguard Worker (torch.ones((3, 4, 5)),), 
3405*da0073e9SAndroid Build Coastguard Worker ): 3406*da0073e9SAndroid Build Coastguard Worker x, kw_x = compiled_func() 3407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 3408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 4) 3409*da0073e9SAndroid Build Coastguard Worker 3410*da0073e9SAndroid Build Coastguard Worker with patch.object( 3411*da0073e9SAndroid Build Coastguard Worker global_func_with_default_tensor_args, 3412*da0073e9SAndroid Build Coastguard Worker "__kwdefaults__", 3413*da0073e9SAndroid Build Coastguard Worker {"kw_x": torch.ones((3, 4, 5))}, 3414*da0073e9SAndroid Build Coastguard Worker ): 3415*da0073e9SAndroid Build Coastguard Worker x, kw_x = compiled_func() 3416*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) 3417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 6) 3418*da0073e9SAndroid Build Coastguard Worker 3419*da0073e9SAndroid Build Coastguard Worker def test_meth_default_tensor_args(self): 3420*da0073e9SAndroid Build Coastguard Worker """ 3421*da0073e9SAndroid Build Coastguard Worker Tests that we indeed reference (and mutate) "the one" default tensor arg 3422*da0073e9SAndroid Build Coastguard Worker stored on the globally allocated function object, both from the orig and 3423*da0073e9SAndroid Build Coastguard Worker compiled function 3424*da0073e9SAndroid Build Coastguard Worker """ 3425*da0073e9SAndroid Build Coastguard Worker mod = WrapperModule() 3426*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3427*da0073e9SAndroid Build Coastguard Worker compiled_mod = torch.compile(mod, backend=cnts) 3428*da0073e9SAndroid Build Coastguard Worker for i in range(4): 3429*da0073e9SAndroid Build Coastguard Worker if i % 2 == 0: 3430*da0073e9SAndroid Build Coastguard Worker x, kw_x = mod() 3431*da0073e9SAndroid Build Coastguard Worker else: 3432*da0073e9SAndroid Build Coastguard Worker x, kw_x = 
compiled_mod() 3433*da0073e9SAndroid Build Coastguard Worker # the inner func mutates += 1 each call 3434*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(x, torch.ones_like(x) + i)) 3435*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i)) 3436*da0073e9SAndroid Build Coastguard Worker # Calling compiled_func twice does not recompile 3437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3438*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 2) 3439*da0073e9SAndroid Build Coastguard Worker 3440*da0073e9SAndroid Build Coastguard Worker # But with a change to the guarded default tensor, we do recompile 3441*da0073e9SAndroid Build Coastguard Worker with patch.object( 3442*da0073e9SAndroid Build Coastguard Worker ModuleWithDefaultTensorArgsMethod.forward, 3443*da0073e9SAndroid Build Coastguard Worker "__defaults__", 3444*da0073e9SAndroid Build Coastguard Worker (torch.ones((3, 4, 5)),), 3445*da0073e9SAndroid Build Coastguard Worker ): 3446*da0073e9SAndroid Build Coastguard Worker x, kw_x = compiled_mod() 3447*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 3448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 4) 3449*da0073e9SAndroid Build Coastguard Worker 3450*da0073e9SAndroid Build Coastguard Worker with patch.object( 3451*da0073e9SAndroid Build Coastguard Worker ModuleWithDefaultTensorArgsMethod.forward, 3452*da0073e9SAndroid Build Coastguard Worker "__kwdefaults__", 3453*da0073e9SAndroid Build Coastguard Worker {"kw_x": torch.ones((3, 4, 5))}, 3454*da0073e9SAndroid Build Coastguard Worker ): 3455*da0073e9SAndroid Build Coastguard Worker x, kw_x = compiled_mod() 3456*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 3) 3457*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 6) 3458*da0073e9SAndroid Build Coastguard Worker 3459*da0073e9SAndroid Build 
Coastguard Worker def test_func_default_torch_args(self): 3460*da0073e9SAndroid Build Coastguard Worker """ 3461*da0073e9SAndroid Build Coastguard Worker Tests other types of torch types as function default (size, dtype, device) 3462*da0073e9SAndroid Build Coastguard Worker """ 3463*da0073e9SAndroid Build Coastguard Worker 3464*da0073e9SAndroid Build Coastguard Worker def func_with_default_torch_args( 3465*da0073e9SAndroid Build Coastguard Worker dt=torch.float16, ds=torch.Size((1, 2, 3)), dd=torch.device("cpu") 3466*da0073e9SAndroid Build Coastguard Worker ): 3467*da0073e9SAndroid Build Coastguard Worker return torch.ones(ds, dtype=dt, device=dd) 3468*da0073e9SAndroid Build Coastguard Worker 3469*da0073e9SAndroid Build Coastguard Worker def func(): 3470*da0073e9SAndroid Build Coastguard Worker return func_with_default_torch_args() 3471*da0073e9SAndroid Build Coastguard Worker 3472*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3473*da0073e9SAndroid Build Coastguard Worker compiled_func = torch.compile(func, backend=cnts) 3474*da0073e9SAndroid Build Coastguard Worker out = func() 3475*da0073e9SAndroid Build Coastguard Worker compiled_out = compiled_func() 3476*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.dtype, compiled_out.dtype) 3477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.device, compiled_out.device) 3478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.size(), compiled_out.size()) 3479*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3480*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 1) 3481*da0073e9SAndroid Build Coastguard Worker 3482*da0073e9SAndroid Build Coastguard Worker def test_dataclass_factory(self): 3483*da0073e9SAndroid Build Coastguard Worker @dataclass 3484*da0073e9SAndroid Build Coastguard Worker class Output: 3485*da0073e9SAndroid Build Coastguard Worker scalar: int = 2 3486*da0073e9SAndroid Build 
Coastguard Worker named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict) 3487*da0073e9SAndroid Build Coastguard Worker lists: List[torch.Tensor] = field(default_factory=list) 3488*da0073e9SAndroid Build Coastguard Worker 3489*da0073e9SAndroid Build Coastguard Worker def scale(self): 3490*da0073e9SAndroid Build Coastguard Worker return self.scalar * 2 3491*da0073e9SAndroid Build Coastguard Worker 3492*da0073e9SAndroid Build Coastguard Worker def fn(x): 3493*da0073e9SAndroid Build Coastguard Worker # Check default dict assignment 3494*da0073e9SAndroid Build Coastguard Worker a = Output(1) 3495*da0073e9SAndroid Build Coastguard Worker # Check that dataclass methods can be inlined 3496*da0073e9SAndroid Build Coastguard Worker scaled_value = a.scale() 3497*da0073e9SAndroid Build Coastguard Worker 3498*da0073e9SAndroid Build Coastguard Worker # Check that normal assignment works 3499*da0073e9SAndroid Build Coastguard Worker b = Output(5, named_tensors={"x": x}) 3500*da0073e9SAndroid Build Coastguard Worker 3501*da0073e9SAndroid Build Coastguard Worker # Check default int assignment 3502*da0073e9SAndroid Build Coastguard Worker c = Output() 3503*da0073e9SAndroid Build Coastguard Worker 3504*da0073e9SAndroid Build Coastguard Worker # Check that the default members are properly initialized 3505*da0073e9SAndroid Build Coastguard Worker if isinstance(a.named_tensors, dict): 3506*da0073e9SAndroid Build Coastguard Worker x = torch.sin(x) 3507*da0073e9SAndroid Build Coastguard Worker 3508*da0073e9SAndroid Build Coastguard Worker # Change dataclass 3509*da0073e9SAndroid Build Coastguard Worker c.scalar = 6 3510*da0073e9SAndroid Build Coastguard Worker c.named_tensors["x"] = x 3511*da0073e9SAndroid Build Coastguard Worker 3512*da0073e9SAndroid Build Coastguard Worker # Return dataclaass as well to check reconstruction 3513*da0073e9SAndroid Build Coastguard Worker return c, torch.cos(x) * scaled_value + b.named_tensors["x"] + c.scalar 3514*da0073e9SAndroid Build 
Coastguard Worker 3515*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3516*da0073e9SAndroid Build Coastguard Worker compiled_fn = torch.compile(fn, backend=cnts, fullgraph=True) 3517*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 3518*da0073e9SAndroid Build Coastguard Worker eager_dataclass, out = fn(x) 3519*da0073e9SAndroid Build Coastguard Worker compiled_dataclass, compiled_out = compiled_fn(x) 3520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_dataclass.scalar, compiled_dataclass.scalar) 3521*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3522*da0073e9SAndroid Build Coastguard Worker eager_dataclass.named_tensors["x"], compiled_dataclass.named_tensors["x"] 3523*da0073e9SAndroid Build Coastguard Worker ) 3524*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(out, compiled_out)) 3525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3526*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.op_count, 5) 3527*da0073e9SAndroid Build Coastguard Worker 3528*da0073e9SAndroid Build Coastguard Worker def test_dataclass_nested(self): 3529*da0073e9SAndroid Build Coastguard Worker @dataclass 3530*da0073e9SAndroid Build Coastguard Worker class Base: 3531*da0073e9SAndroid Build Coastguard Worker outer_a: int 3532*da0073e9SAndroid Build Coastguard Worker outer_b: int 3533*da0073e9SAndroid Build Coastguard Worker 3534*da0073e9SAndroid Build Coastguard Worker @dataclass 3535*da0073e9SAndroid Build Coastguard Worker class Derived(Base): 3536*da0073e9SAndroid Build Coastguard Worker inner_a: Any = field(default_factory=list) 3537*da0073e9SAndroid Build Coastguard Worker 3538*da0073e9SAndroid Build Coastguard Worker def fn(x): 3539*da0073e9SAndroid Build Coastguard Worker l = Derived(1, 2) 3540*da0073e9SAndroid Build Coastguard Worker return l.outer_a * x 3541*da0073e9SAndroid Build Coastguard Worker 3542*da0073e9SAndroid Build Coastguard Worker 
opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3543*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 3544*da0073e9SAndroid Build Coastguard Worker res = fn(x) 3545*da0073e9SAndroid Build Coastguard Worker ref = opt_fn(x) 3546*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3547*da0073e9SAndroid Build Coastguard Worker 3548*da0073e9SAndroid Build Coastguard Worker def test_listlike_of_tensors_contains_constant(self): 3549*da0073e9SAndroid Build Coastguard Worker for listlike in [set, list]: 3550*da0073e9SAndroid Build Coastguard Worker 3551*da0073e9SAndroid Build Coastguard Worker def fn(x): 3552*da0073e9SAndroid Build Coastguard Worker x.add_(1) 3553*da0073e9SAndroid Build Coastguard Worker s = listlike([x]) 3554*da0073e9SAndroid Build Coastguard Worker res = 1 in s 3555*da0073e9SAndroid Build Coastguard Worker return res 3556*da0073e9SAndroid Build Coastguard Worker 3557*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3558*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1) 3559*da0073e9SAndroid Build Coastguard Worker ref = opt_fn(x) 3560*da0073e9SAndroid Build Coastguard Worker res = fn(x) 3561*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3562*da0073e9SAndroid Build Coastguard Worker 3563*da0073e9SAndroid Build Coastguard Worker def test_cast_tensor_single_elem(self): 3564*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch({"capture_scalar_outputs": True}): 3565*da0073e9SAndroid Build Coastguard Worker for t, val in [ 3566*da0073e9SAndroid Build Coastguard Worker (float, 1.0), 3567*da0073e9SAndroid Build Coastguard Worker (float, 1), 3568*da0073e9SAndroid Build Coastguard Worker (float, True), 3569*da0073e9SAndroid Build Coastguard Worker (int, 1), 3570*da0073e9SAndroid Build Coastguard Worker (int, False), 3571*da0073e9SAndroid Build Coastguard Worker # (int, 1.0), # fails due to a >= 0 comparison in sym_int 
3572*da0073e9SAndroid Build Coastguard Worker ]: # , bool, complex]: no casting for sym_bool, no sym_complex 3573*da0073e9SAndroid Build Coastguard Worker 3574*da0073e9SAndroid Build Coastguard Worker def fn(x): 3575*da0073e9SAndroid Build Coastguard Worker x = x + 1 3576*da0073e9SAndroid Build Coastguard Worker return t(x) 3577*da0073e9SAndroid Build Coastguard Worker 3578*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile( 3579*da0073e9SAndroid Build Coastguard Worker fn, backend="eager", fullgraph=True, dynamic=False 3580*da0073e9SAndroid Build Coastguard Worker ) 3581*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([val]) 3582*da0073e9SAndroid Build Coastguard Worker res = fn(x) 3583*da0073e9SAndroid Build Coastguard Worker ref = opt_fn(x) 3584*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3585*da0073e9SAndroid Build Coastguard Worker 3586*da0073e9SAndroid Build Coastguard Worker # Cannot handle non single-elem 3587*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 3588*da0073e9SAndroid Build Coastguard Worker fn(torch.tensor([val] * 2)) 3589*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(torch._dynamo.exc.TorchRuntimeError): 3590*da0073e9SAndroid Build Coastguard Worker opt_fn(torch.tensor([val] * 2)) 3591*da0073e9SAndroid Build Coastguard Worker 3592*da0073e9SAndroid Build Coastguard Worker def test_set_construction(self): 3593*da0073e9SAndroid Build Coastguard Worker def fn(x): 3594*da0073e9SAndroid Build Coastguard Worker y = x.add_(1) 3595*da0073e9SAndroid Build Coastguard Worker s = set({x}) 3596*da0073e9SAndroid Build Coastguard Worker s.add(y) 3597*da0073e9SAndroid Build Coastguard Worker return len(s) 3598*da0073e9SAndroid Build Coastguard Worker 3599*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3600*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 3601*da0073e9SAndroid Build Coastguard Worker res = 
fn(x) 3602*da0073e9SAndroid Build Coastguard Worker ref = opt_fn(x) 3603*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3604*da0073e9SAndroid Build Coastguard Worker 3605*da0073e9SAndroid Build Coastguard Worker def test_frozenset_construction(self): 3606*da0073e9SAndroid Build Coastguard Worker def fn(x): 3607*da0073e9SAndroid Build Coastguard Worker s = frozenset({x}) 3608*da0073e9SAndroid Build Coastguard Worker t = frozenset(s) 3609*da0073e9SAndroid Build Coastguard Worker return len(t) 3610*da0073e9SAndroid Build Coastguard Worker 3611*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3612*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 3613*da0073e9SAndroid Build Coastguard Worker res = fn(x) 3614*da0073e9SAndroid Build Coastguard Worker ref = opt_fn(x) 3615*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3616*da0073e9SAndroid Build Coastguard Worker 3617*da0073e9SAndroid Build Coastguard Worker def test_frozenset_reconstruction(self): 3618*da0073e9SAndroid Build Coastguard Worker d = {} 3619*da0073e9SAndroid Build Coastguard Worker f = frozenset() 3620*da0073e9SAndroid Build Coastguard Worker d[f] = torch.randn(4) 3621*da0073e9SAndroid Build Coastguard Worker 3622*da0073e9SAndroid Build Coastguard Worker def fn(x): 3623*da0073e9SAndroid Build Coastguard Worker k = frozenset() 3624*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 3625*da0073e9SAndroid Build Coastguard Worker return d[k] * x 3626*da0073e9SAndroid Build Coastguard Worker 3627*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager") 3628*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 3629*da0073e9SAndroid Build Coastguard Worker res = fn(x) 3630*da0073e9SAndroid Build Coastguard Worker ref = opt_fn(x) 3631*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3632*da0073e9SAndroid Build Coastguard Worker 
3633*da0073e9SAndroid Build Coastguard Worker def test_frozenset_illegal_call_method(self): 3634*da0073e9SAndroid Build Coastguard Worker def fn_add(): 3635*da0073e9SAndroid Build Coastguard Worker s = frozenset((1, 2, 3)) 3636*da0073e9SAndroid Build Coastguard Worker s.add({2}) 3637*da0073e9SAndroid Build Coastguard Worker return len(s) 3638*da0073e9SAndroid Build Coastguard Worker 3639*da0073e9SAndroid Build Coastguard Worker def fn_pop(): 3640*da0073e9SAndroid Build Coastguard Worker s = frozenset((1, 2, 3)) 3641*da0073e9SAndroid Build Coastguard Worker s.pop() 3642*da0073e9SAndroid Build Coastguard Worker return len(s) 3643*da0073e9SAndroid Build Coastguard Worker 3644*da0073e9SAndroid Build Coastguard Worker def fn_update(): 3645*da0073e9SAndroid Build Coastguard Worker s = frozenset((1, 2, 3)) 3646*da0073e9SAndroid Build Coastguard Worker s.update({4, 5, 6}) 3647*da0073e9SAndroid Build Coastguard Worker return len(s) 3648*da0073e9SAndroid Build Coastguard Worker 3649*da0073e9SAndroid Build Coastguard Worker def fn_remove(): 3650*da0073e9SAndroid Build Coastguard Worker s = frozenset((1, 2, 3)) 3651*da0073e9SAndroid Build Coastguard Worker s.remove(2) 3652*da0073e9SAndroid Build Coastguard Worker return len(s) 3653*da0073e9SAndroid Build Coastguard Worker 3654*da0073e9SAndroid Build Coastguard Worker def fn_discard(): 3655*da0073e9SAndroid Build Coastguard Worker s = frozenset((1, 2, 3)) 3656*da0073e9SAndroid Build Coastguard Worker s.discard(2) 3657*da0073e9SAndroid Build Coastguard Worker return len(s) 3658*da0073e9SAndroid Build Coastguard Worker 3659*da0073e9SAndroid Build Coastguard Worker def fn_clear(): 3660*da0073e9SAndroid Build Coastguard Worker s = frozenset((1, 2, 3)) 3661*da0073e9SAndroid Build Coastguard Worker s.clear() 3662*da0073e9SAndroid Build Coastguard Worker return len(s) 3663*da0073e9SAndroid Build Coastguard Worker 3664*da0073e9SAndroid Build Coastguard Worker for fn in [fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear]: 
3665*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 3666*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3667*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError): 3668*da0073e9SAndroid Build Coastguard Worker opt_fn() 3669*da0073e9SAndroid Build Coastguard Worker 3670*da0073e9SAndroid Build Coastguard Worker def test_is_tensor_tensor(self): 3671*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 3672*da0073e9SAndroid Build Coastguard Worker if x is y: 3673*da0073e9SAndroid Build Coastguard Worker return x * 2 3674*da0073e9SAndroid Build Coastguard Worker else: 3675*da0073e9SAndroid Build Coastguard Worker return x + y 3676*da0073e9SAndroid Build Coastguard Worker 3677*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3678*da0073e9SAndroid Build Coastguard Worker 3679*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(2) 3680*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2) 3681*da0073e9SAndroid Build Coastguard Worker 3682*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, y), fn_opt(x, y)) 3683*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, x), fn_opt(x, x)) 3684*da0073e9SAndroid Build Coastguard Worker 3685*da0073e9SAndroid Build Coastguard Worker def test_is_not_tensor_tensor(self): 3686*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 3687*da0073e9SAndroid Build Coastguard Worker if x is not y: 3688*da0073e9SAndroid Build Coastguard Worker return x * 2 3689*da0073e9SAndroid Build Coastguard Worker else: 3690*da0073e9SAndroid Build Coastguard Worker return x + y 3691*da0073e9SAndroid Build Coastguard Worker 3692*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3693*da0073e9SAndroid Build Coastguard Worker 3694*da0073e9SAndroid Build Coastguard Worker x = 
torch.zeros(2) 3695*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2) 3696*da0073e9SAndroid Build Coastguard Worker 3697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, y), fn_opt(x, y)) 3698*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, x), fn_opt(x, x)) 3699*da0073e9SAndroid Build Coastguard Worker 3700*da0073e9SAndroid Build Coastguard Worker def test_is_mutated_tensor_tensor(self): 3701*da0073e9SAndroid Build Coastguard Worker def fn(x): 3702*da0073e9SAndroid Build Coastguard Worker y = x.add_(1) 3703*da0073e9SAndroid Build Coastguard Worker return x is y 3704*da0073e9SAndroid Build Coastguard Worker 3705*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3706*da0073e9SAndroid Build Coastguard Worker 3707*da0073e9SAndroid Build Coastguard Worker z = torch.ones(4) 3708*da0073e9SAndroid Build Coastguard Worker 3709*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(z), fn_opt(z)) 3710*da0073e9SAndroid Build Coastguard Worker 3711*da0073e9SAndroid Build Coastguard Worker def test_is_mutated_tensor_tensor_across_graph_break(self): 3712*da0073e9SAndroid Build Coastguard Worker def fn(x): 3713*da0073e9SAndroid Build Coastguard Worker y = x.add_(1) 3714*da0073e9SAndroid Build Coastguard Worker cond = x is y 3715*da0073e9SAndroid Build Coastguard Worker x.add_(1) 3716*da0073e9SAndroid Build Coastguard Worker # The real tensor values are recovered when graph breaking. 3717*da0073e9SAndroid Build Coastguard Worker # Hence we recover the invariant. 
3718*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 3719*da0073e9SAndroid Build Coastguard Worker x.add_(1) 3720*da0073e9SAndroid Build Coastguard Worker return x is y, cond 3721*da0073e9SAndroid Build Coastguard Worker 3722*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", dynamic=True)(fn) 3723*da0073e9SAndroid Build Coastguard Worker 3724*da0073e9SAndroid Build Coastguard Worker z = torch.ones(4) 3725*da0073e9SAndroid Build Coastguard Worker 3726*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(z), fn_opt(z)) 3727*da0073e9SAndroid Build Coastguard Worker 3728*da0073e9SAndroid Build Coastguard Worker def test_is_mutated_tensor_tensor(self): 3729*da0073e9SAndroid Build Coastguard Worker def fn(x): 3730*da0073e9SAndroid Build Coastguard Worker y = x.add_(1) 3731*da0073e9SAndroid Build Coastguard Worker return y is x 3732*da0073e9SAndroid Build Coastguard Worker 3733*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3734*da0073e9SAndroid Build Coastguard Worker 3735*da0073e9SAndroid Build Coastguard Worker z = torch.ones(4, 1) 3736*da0073e9SAndroid Build Coastguard Worker 3737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(z), fn_opt(z)) 3738*da0073e9SAndroid Build Coastguard Worker 3739*da0073e9SAndroid Build Coastguard Worker def test_is_init_in_compile_mutated_tensor_tensor(self): 3740*da0073e9SAndroid Build Coastguard Worker def fn(x): 3741*da0073e9SAndroid Build Coastguard Worker z = x.clone() 3742*da0073e9SAndroid Build Coastguard Worker y = z.add_(1) 3743*da0073e9SAndroid Build Coastguard Worker return y is z 3744*da0073e9SAndroid Build Coastguard Worker 3745*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3746*da0073e9SAndroid Build Coastguard Worker 3747*da0073e9SAndroid Build Coastguard Worker z = torch.ones(4, 1) 3748*da0073e9SAndroid Build 
Coastguard Worker 3749*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(z), fn_opt(z)) 3750*da0073e9SAndroid Build Coastguard Worker 3751*da0073e9SAndroid Build Coastguard Worker def test_is_init_in_compile_vmapped_mutated_tensor_tensor(self): 3752*da0073e9SAndroid Build Coastguard Worker def fn(z): 3753*da0073e9SAndroid Build Coastguard Worker x = z.clone() 3754*da0073e9SAndroid Build Coastguard Worker y = torch.vmap(torch.Tensor.acos_)(x) 3755*da0073e9SAndroid Build Coastguard Worker _ = y is z 3756*da0073e9SAndroid Build Coastguard Worker return y is x 3757*da0073e9SAndroid Build Coastguard Worker 3758*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3759*da0073e9SAndroid Build Coastguard Worker 3760*da0073e9SAndroid Build Coastguard Worker z = torch.ones(4, 1) 3761*da0073e9SAndroid Build Coastguard Worker 3762*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(z), fn_opt(z)) 3763*da0073e9SAndroid Build Coastguard Worker 3764*da0073e9SAndroid Build Coastguard Worker def test_is_vmapped_mutated_tensor_tensor(self): 3765*da0073e9SAndroid Build Coastguard Worker def fn(x): 3766*da0073e9SAndroid Build Coastguard Worker y = torch.vmap(torch.Tensor.acos_)(x) 3767*da0073e9SAndroid Build Coastguard Worker return y is x 3768*da0073e9SAndroid Build Coastguard Worker 3769*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3770*da0073e9SAndroid Build Coastguard Worker 3771*da0073e9SAndroid Build Coastguard Worker z = torch.ones(4, 1) 3772*da0073e9SAndroid Build Coastguard Worker 3773*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(z), fn_opt(z)) 3774*da0073e9SAndroid Build Coastguard Worker 3775*da0073e9SAndroid Build Coastguard Worker def test_is_init_in_compile_vmapped_mutated_tensor_tensor_multi_arg(self): 3776*da0073e9SAndroid Build Coastguard Worker def fn(y, z): 3777*da0073e9SAndroid Build Coastguard 
Worker a = y.clone() 3778*da0073e9SAndroid Build Coastguard Worker b = z.clone() 3779*da0073e9SAndroid Build Coastguard Worker 3780*da0073e9SAndroid Build Coastguard Worker def g(a, b): 3781*da0073e9SAndroid Build Coastguard Worker return a.acos_(), b.acos_() 3782*da0073e9SAndroid Build Coastguard Worker 3783*da0073e9SAndroid Build Coastguard Worker c, d = torch.vmap(g)(a, b) 3784*da0073e9SAndroid Build Coastguard Worker return a is c is b is d 3785*da0073e9SAndroid Build Coastguard Worker 3786*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) 3787*da0073e9SAndroid Build Coastguard Worker 3788*da0073e9SAndroid Build Coastguard Worker y = torch.ones(4, 2) 3789*da0073e9SAndroid Build Coastguard Worker z = torch.ones(4, 10) 3790*da0073e9SAndroid Build Coastguard Worker 3791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(y, z), fn_opt(y, z)) 3792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(y, y), fn_opt(y, y)) 3793*da0073e9SAndroid Build Coastguard Worker 3794*da0073e9SAndroid Build Coastguard Worker def test_in_set_would_fail_broadcast(self): 3795*da0073e9SAndroid Build Coastguard Worker param = torch.zeros(5) 3796*da0073e9SAndroid Build Coastguard Worker param2 = torch.zeros(5, 10) 3797*da0073e9SAndroid Build Coastguard Worker 3798*da0073e9SAndroid Build Coastguard Worker tensor_list = set() 3799*da0073e9SAndroid Build Coastguard Worker tensor_list.add(param2) 3800*da0073e9SAndroid Build Coastguard Worker assert param not in tensor_list 3801*da0073e9SAndroid Build Coastguard Worker 3802*da0073e9SAndroid Build Coastguard Worker def fn(param, param2): 3803*da0073e9SAndroid Build Coastguard Worker param.add_(1) 3804*da0073e9SAndroid Build Coastguard Worker tensor_list = set([param2]) 3805*da0073e9SAndroid Build Coastguard Worker return param in tensor_list 3806*da0073e9SAndroid Build Coastguard Worker 3807*da0073e9SAndroid Build Coastguard Worker cnts = 
torch._dynamo.testing.CompileCounter() 3808*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3809*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(param, param2), fn(param, param2)) 3810*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 3811*da0073e9SAndroid Build Coastguard Worker # Test aliased 3812*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(param, param), fn(param, param)) 3813*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) # Recompiles 3814*da0073e9SAndroid Build Coastguard Worker 3815*da0073e9SAndroid Build Coastguard Worker def test_in_set_inplace(self): 3816*da0073e9SAndroid Build Coastguard Worker param = torch.zeros(5) 3817*da0073e9SAndroid Build Coastguard Worker param2 = torch.zeros(5, 10) 3818*da0073e9SAndroid Build Coastguard Worker 3819*da0073e9SAndroid Build Coastguard Worker tensor_list = set() 3820*da0073e9SAndroid Build Coastguard Worker tensor_list.add(param2) 3821*da0073e9SAndroid Build Coastguard Worker assert param not in tensor_list 3822*da0073e9SAndroid Build Coastguard Worker 3823*da0073e9SAndroid Build Coastguard Worker def fn(param, param2): 3824*da0073e9SAndroid Build Coastguard Worker y = param.add_(1) # Tensor method 3825*da0073e9SAndroid Build Coastguard Worker z = torch.Tensor.add_(y, 1) # torch function 3826*da0073e9SAndroid Build Coastguard Worker tensor_list = set([param2]) 3827*da0073e9SAndroid Build Coastguard Worker return y in tensor_list and z in tensor_list 3828*da0073e9SAndroid Build Coastguard Worker 3829*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 3830*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 3831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(param, param2), fn(param, param2)) 3832*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 
3833*da0073e9SAndroid Build Coastguard Worker # Test aliased 3834*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(param, param), fn(param, param)) 3835*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) # Recompiles 3836*da0073e9SAndroid Build Coastguard Worker 3837*da0073e9SAndroid Build Coastguard Worker def test_reconstructed_name(self): 3838*da0073e9SAndroid Build Coastguard Worker lst = [] 3839*da0073e9SAndroid Build Coastguard Worker 3840*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.disable 3841*da0073e9SAndroid Build Coastguard Worker def disallowed(g): 3842*da0073e9SAndroid Build Coastguard Worker lst.append(g.__name__) 3843*da0073e9SAndroid Build Coastguard Worker 3844*da0073e9SAndroid Build Coastguard Worker def f(): 3845*da0073e9SAndroid Build Coastguard Worker def g(): 3846*da0073e9SAndroid Build Coastguard Worker return () 3847*da0073e9SAndroid Build Coastguard Worker 3848*da0073e9SAndroid Build Coastguard Worker disallowed(g) 3849*da0073e9SAndroid Build Coastguard Worker 3850*da0073e9SAndroid Build Coastguard Worker f_opt = torch._dynamo 3851*da0073e9SAndroid Build Coastguard Worker opt_f = torch._dynamo.optimize(backend="eager")(f) 3852*da0073e9SAndroid Build Coastguard Worker opt_f() 3853*da0073e9SAndroid Build Coastguard Worker f() 3854*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(lst), 2) 3855*da0073e9SAndroid Build Coastguard Worker self.assertEqual(lst[0], lst[1]) 3856*da0073e9SAndroid Build Coastguard Worker 3857*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 3858*da0073e9SAndroid Build Coastguard Worker sys.version_info < (3, 10), 3859*da0073e9SAndroid Build Coastguard Worker "zip strict kwargs not implemented for Python < 3.10", 3860*da0073e9SAndroid Build Coastguard Worker ) 3861*da0073e9SAndroid Build Coastguard Worker def test_zip_strict(self): 3862*da0073e9SAndroid Build Coastguard Worker def fn(x, ys, zs): 3863*da0073e9SAndroid Build Coastguard Worker x 
= x.clone() 3864*da0073e9SAndroid Build Coastguard Worker for y, z in zip(ys, zs, strict=True): 3865*da0073e9SAndroid Build Coastguard Worker x += y * z 3866*da0073e9SAndroid Build Coastguard Worker return x 3867*da0073e9SAndroid Build Coastguard Worker 3868*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize(backend="eager")(fn) 3869*da0073e9SAndroid Build Coastguard Worker nopython_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn) 3870*da0073e9SAndroid Build Coastguard Worker 3871*da0073e9SAndroid Build Coastguard Worker x = torch.ones(3) 3872*da0073e9SAndroid Build Coastguard Worker ys = [1.0, 2.0, 3.0] 3873*da0073e9SAndroid Build Coastguard Worker zs = [2.0, 5.0, 8.0] 3874*da0073e9SAndroid Build Coastguard Worker 3875*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs)) 3876*da0073e9SAndroid Build Coastguard Worker 3877*da0073e9SAndroid Build Coastguard Worker # If nopython, should raise UserError 3878*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): 3879*da0073e9SAndroid Build Coastguard Worker nopython_fn(x, ys[:1], zs) 3880*da0073e9SAndroid Build Coastguard Worker 3881*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): 3882*da0073e9SAndroid Build Coastguard Worker nopython_fn(x, ys, zs[:1]) 3883*da0073e9SAndroid Build Coastguard Worker 3884*da0073e9SAndroid Build Coastguard Worker # Should cause fallback if allow graph break 3885*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "zip()"): 3886*da0073e9SAndroid Build Coastguard Worker opt_fn(x, ys[:1], zs) 3887*da0073e9SAndroid Build Coastguard Worker 3888*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "zip()"): 3889*da0073e9SAndroid Build Coastguard Worker opt_fn(x, ys, zs[:1]) 3890*da0073e9SAndroid Build Coastguard Worker 3891*da0073e9SAndroid Build 
Coastguard Worker def test_fn_with_attr(self): 3892*da0073e9SAndroid Build Coastguard Worker def fn(x): 3893*da0073e9SAndroid Build Coastguard Worker if fn.pred: 3894*da0073e9SAndroid Build Coastguard Worker return torch.relu(x * 2) 3895*da0073e9SAndroid Build Coastguard Worker else: 3896*da0073e9SAndroid Build Coastguard Worker return torch.abs(x + 3) 3897*da0073e9SAndroid Build Coastguard Worker 3898*da0073e9SAndroid Build Coastguard Worker t = torch.ones(3) 3899*da0073e9SAndroid Build Coastguard Worker counter = torch._dynamo.testing.CompileCounter() 3900*da0073e9SAndroid Build Coastguard Worker fn.pred = True 3901*da0073e9SAndroid Build Coastguard Worker opt_fn_0 = torch.compile(fullgraph=True, backend=counter)(fn) 3902*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn_0(t), fn(t)) 3903*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 3904*da0073e9SAndroid Build Coastguard Worker fn.pred = False 3905*da0073e9SAndroid Build Coastguard Worker opt_fn_1 = torch.compile(fullgraph=True, backend=counter)(fn) 3906*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn_1(t), fn(t)) 3907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 3908*da0073e9SAndroid Build Coastguard Worker 3909*da0073e9SAndroid Build Coastguard Worker def test_str_handler_for_user_defined_object(self): 3910*da0073e9SAndroid Build Coastguard Worker """ 3911*da0073e9SAndroid Build Coastguard Worker Confirms handler behaviour for `str` is the same between eager and dynamo. 3912*da0073e9SAndroid Build Coastguard Worker Compares a user defined object with custom `__str__` method and without. 
3913*da0073e9SAndroid Build Coastguard Worker """ 3914*da0073e9SAndroid Build Coastguard Worker 3915*da0073e9SAndroid Build Coastguard Worker class CustomStr: 3916*da0073e9SAndroid Build Coastguard Worker def __str__(self): 3917*da0073e9SAndroid Build Coastguard Worker return "ok" 3918*da0073e9SAndroid Build Coastguard Worker 3919*da0073e9SAndroid Build Coastguard Worker def foo_custom_str(x): 3920*da0073e9SAndroid Build Coastguard Worker a = CustomStr() 3921*da0073e9SAndroid Build Coastguard Worker return x, str(a) 3922*da0073e9SAndroid Build Coastguard Worker 3923*da0073e9SAndroid Build Coastguard Worker eager_custom_str = foo_custom_str(torch.ones(4)) 3924*da0073e9SAndroid Build Coastguard Worker dynamo_custom_str = torch.compile(foo_custom_str, fullgraph=True)(torch.ones(4)) 3925*da0073e9SAndroid Build Coastguard Worker 3926*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_custom_str[1], dynamo_custom_str[1]) 3927*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_custom_str[1], "ok") 3928*da0073e9SAndroid Build Coastguard Worker 3929*da0073e9SAndroid Build Coastguard Worker class DefaultStr: 3930*da0073e9SAndroid Build Coastguard Worker pass 3931*da0073e9SAndroid Build Coastguard Worker 3932*da0073e9SAndroid Build Coastguard Worker def foo_default_str(x): 3933*da0073e9SAndroid Build Coastguard Worker a = DefaultStr() 3934*da0073e9SAndroid Build Coastguard Worker return x, str(a) 3935*da0073e9SAndroid Build Coastguard Worker 3936*da0073e9SAndroid Build Coastguard Worker eager_default_str = foo_default_str(torch.ones(4)) 3937*da0073e9SAndroid Build Coastguard Worker dynamo_default_str = torch.compile(foo_default_str, fullgraph=True)( 3938*da0073e9SAndroid Build Coastguard Worker torch.ones(4) 3939*da0073e9SAndroid Build Coastguard Worker ) 3940*da0073e9SAndroid Build Coastguard Worker 3941*da0073e9SAndroid Build Coastguard Worker # Check that the tensor output from eager and dynamo modes are the same 3942*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(eager_default_str[0], dynamo_default_str[0]) 3943*da0073e9SAndroid Build Coastguard Worker 3944*da0073e9SAndroid Build Coastguard Worker # Check that the class name (without memory address) is the same in both modes 3945*da0073e9SAndroid Build Coastguard Worker eager_class_name = eager_default_str[1].split(" object at")[0] 3946*da0073e9SAndroid Build Coastguard Worker dynamo_class_name = dynamo_default_str[1].split(" object at")[0] 3947*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_class_name, dynamo_class_name) 3948*da0073e9SAndroid Build Coastguard Worker 3949*da0073e9SAndroid Build Coastguard Worker def test_pybind_object(self): 3950*da0073e9SAndroid Build Coastguard Worker def fn(x, pybind_obj): 3951*da0073e9SAndroid Build Coastguard Worker if pybind_obj.result: 3952*da0073e9SAndroid Build Coastguard Worker return torch.cos(x) 3953*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 3954*da0073e9SAndroid Build Coastguard Worker 3955*da0073e9SAndroid Build Coastguard Worker opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 3956*da0073e9SAndroid Build Coastguard Worker 3957*da0073e9SAndroid Build Coastguard Worker pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0) 3958*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 3959*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) 3960*da0073e9SAndroid Build Coastguard Worker 3961*da0073e9SAndroid Build Coastguard Worker pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1) 3962*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4) 3963*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) 3964*da0073e9SAndroid Build Coastguard Worker 3965*da0073e9SAndroid Build Coastguard Worker 3966*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(FunctionTests) 3967*da0073e9SAndroid Build 
if __name__ == "__main__":
    # Script entry point: hand control to the dynamo test harness's runner.
    import torch._dynamo.test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()