1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: decompositions"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 4*da0073e9SAndroid Build Coastguard Workerfrom itertools import product 5*da0073e9SAndroid Build Coastguard Workerimport unittest 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY, 10*da0073e9SAndroid Build Coastguard Worker set_default_dtype) 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 12*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 13*da0073e9SAndroid Build Coastguard Worker onlyCUDA, 14*da0073e9SAndroid Build Coastguard Worker dtypes, 15*da0073e9SAndroid Build Coastguard Worker OpDTypes, 16*da0073e9SAndroid Build Coastguard Worker) 17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import ( 18*da0073e9SAndroid Build Coastguard Worker op_db, 19*da0073e9SAndroid Build Coastguard Worker) 20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 21*da0073e9SAndroid Build Coastguard Worker ops, 22*da0073e9SAndroid Build Coastguard Worker) 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input 25*da0073e9SAndroid Build Coastguard Workerimport torch._prims as prims 26*da0073e9SAndroid Build Coastguard Workerfrom torch._prims_common import CUDARngStateHelper 27*da0073e9SAndroid Build Coastguard Workerfrom torch._prims.executor import make_traced 28*da0073e9SAndroid Build Coastguard Workerimport torch._refs as refs 
if TEST_SCIPY:
    import scipy.special

# Diagnostic message fragments. They are not referenced by any test in this module.
NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"


class TestPrims(TestCase):
    """Device-parametrized tests for ``torch._prims`` operations and related refs/RNG helpers."""

    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        # ``b`` only supplies the target shape; its values are never read.
        def _wrapper(a, b, broadcast_dimensions):
            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            # is_contiguous must be *called*; the bound method itself is always truthy.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        # Full reduction to a 0-d tensor followed by a no-op broadcast to shape ().
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            # Tested here, as this OP is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
                y_np = scipy.special.cbrt(x_np)

                self.assertEqual(y, y_np, exact_device=False)

    @dtypes(torch.float32)
    def test_collapse(self, device, dtype):
        # Build the input on the parametrized device/dtype so every instantiation
        # actually exercises its own configuration.
        t = torch.rand(2, 2, 2, device=device, dtype=dtype)
        dim_ranges = [(0, 0), (0, 1), (1, 2), (0, 2)]
        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]

        for (start, end), shape in zip(dim_ranges, expected_shapes):
            expect = t.reshape(shape)

            # collapse always materializes a copy ...
            copy = prims.collapse(t, start, end)
            self.assertEqual(copy, expect)
            self.assertFalse(copy._is_view())

            # ... while collapse_view must return a view.
            view = prims.collapse_view(t, start, end)
            self.assertEqual(view, expect)
            self.assertTrue(view._is_view())

        # A transposed tensor cannot be collapsed into a view over all dims.
        t_discontig = t.transpose(0, 1)
        with self.assertRaises(ValueError, msg="no such view exists"):
            view = prims.collapse_view(t_discontig, 0, 2)

        copy = prims.collapse(t_discontig, 0, 1)
        self.assertEqual(copy, t_discontig.reshape(4, 2))

        error_dims = [(-1, 1), (0, 3), (1, -1)]
        for start, end in error_dims:
            for fn in [prims.collapse, prims.collapse_view]:
                with self.assertRaises(AssertionError):
                    fn(t, start, end)

    def test_aten_overload_to_prims(self, device):
        # This test is to ensure that the torch.ops.aten calls are replaced with refs
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a))

        with TorchRefsMode():
            gm = make_fx(func)(a)

        # Check that all call_function nodes are prims
        call_function_nodes = [n for n in gm.graph.nodes if n.op == "call_function"]
        all_prims_namespace = all(
            node.target.name().startswith("prims") for node in call_function_nodes
        )
        self.assertTrue(all_prims_namespace)

    @onlyCUDA
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        shapes = (
            (),
            (0,),
            (1,),
            (5,),  # a 1-tuple; a bare ``(5)`` would be the int 5
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7),
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shape_group, memory_format in pairs:
            for shape in shape_group:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone: eager torch.clone vs. the refs implementation
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

    @dtypes(torch.float32)
    def test_reshape_view_method(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        new_shape = 1, 5, 1, 5
        result_eager = a.reshape(*new_shape)
        result_refs = refs.reshape(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

        result_eager = a.view(*new_shape)
        result_refs = refs.view(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_philox_rand(self, device, dtype):
        sizes = (1000, 1000000)  # offsets of 4 and 8
        repeats = 2  # Checks multiple rand calls results with multiple philox_rand calls
        for size in sizes:
            # Record the (seed, offset) before each torch.rand call so philox_rand
            # can replay exactly the same draws below.
            torch.cuda.manual_seed(123)
            references = []
            results = []
            rng_states = []
            for _ in range(repeats):
                rng_states.append(CUDARngStateHelper.get_torch_state_as_tuple())
                references.append(torch.rand(size, device=device, dtype=dtype))

            torch.cuda.manual_seed(123)
            for seed, offset in rng_states:
                result, _ = torch.ops.rngprims.philox_rand((size,),
                                                           seed=seed,
                                                           offset=offset,
                                                           stride=None,
                                                           device=device,
                                                           dtype=dtype)
                results.append(result)

            for a, b in zip(references, results):
                self.assertEqual(a, b)

    @dtypes(torch.float32)
    def test_functional_rng_wrappers(self, device, dtype):
        # Reference draws from the global generator.
        torch.manual_seed(123)
        ref1 = torch.rand(10, device=device, dtype=dtype)
        ref2 = torch.rand(10, device=device, dtype=dtype)

        # Same draws through run_and_save_rng_state, capturing the state used.
        torch.manual_seed(123)
        rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
        rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)

        # Replaying with the captured states must reproduce the same values.
        res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
        res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref1, res3)
        self.assertEqual(ref2, res4)


class TestPrimsBasic(TestCase):
    """Non-device-parametrized smoke tests for prims dispatch and helpers."""

    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0: f32[2] = input('input')
$1: f32[2] = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        # Only checks that a real tensor times a complex Python scalar does not raise.
        prims.mul(torch.randn(2), 1 + 1j)

    def test_clone_complex(self):
        # Only checks that the python dispatcher handles a conjugated meta tensor.
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
            out = x + 1

    def test_check_deprecation_warning(self):
        with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
            torch._prims_common.check(True, lambda: 'message')


instantiate_device_type_tests(TestPrims, globals())


class TestRefs(TestCase):
    """Device-parametrized tests comparing ``torch._refs`` against eager ATen."""

    @dtypes(torch.float32)
    def test_constant_pad_nd_memory_format(self, device, dtype):
        # Test memory format is preserved in unambiguous cases
        for mf, ndim in (
                (torch.channels_last, 4),
                (torch.contiguous_format, 4),
                (torch.channels_last_3d, 5),
                (torch.contiguous_format, 5),
        ):
            a = torch.zeros([2] * ndim).to(memory_format=mf)
            res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim))
            self.assertTrue(res.is_contiguous(memory_format=mf))

        # Ambiguous cases

        # is_channels_last_ and is_contiguous_, results in channels_last output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))

        # is_channels_last_contiguous_ but not is_channels_last_, results in
        # contiguous output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous())

    def test_unbind(self):
        # If unbind returns empty tuple, it breaks some assumptions in some backward tests in test_ops.py.
        # So can't put this test into common_methods_invocations.py.
        a = torch.rand([3, 0, 4])
        actual = refs.unbind(a, 1)
        expect = torch.unbind(a, 1)
        self.assertEqual(actual, expect)

    def test_logspace_with_complex_input(self):
        actual = refs.logspace(2, 10 + 5j, steps=5)
        expect = torch.logspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    def test_linspace_with_complex_input(self):
        actual = refs.linspace(2, 10 + 5j, steps=5)
        expect = torch.linspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    # From https://github.com/pytorch/pytorch/issues/109558
    def test_infinite_loop_from_py_dispatcher(self):
        # enables prim decomps
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.ones(4)
            y = x.to(device="meta")

    def test_inferred_tags(self):
        self.assertEqual(torch.ops.prims.normal.default.tags, (torch.Tag.nondeterministic_seeded, torch.Tag.pt2_compliant_tag))


instantiate_device_type_tests(TestRefs, globals())


class TestDecomp(TestCase):
    """Checks that vararg method calls trace correctly under ``TorchRefsMode``."""

    @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
    def test_decomposition_method_vararg(self, device, dtype, op):
        # some ops have vararg variants for the methods. this tests it.
        # we don't have tests for varargs in OpInfo, so we need to
        # improvise this a bit.
        # The rule for general functions (the special cases being e.g. tensor
        # creation functions taking shapes) is that things can be vararg
        # if the method has only one argument of sequence type.
        # e.g. permute can be called on a 3d tensor t as t.permute(0, 2, 1)
        #      as well as t.permute([0, 2, 1])
        #      when the signature in native_functions.yaml
        #      shows arguments Tensor self, IntList dims
        # we might need to adjust things for the factory functions or
        # have them do their own test
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        # filter out empty tuple as that cannot be the varargs
        sample_inputs = (si for si in op.sample_inputs(device, dtype, requires_grad=False)
                         if (si.args[-1] if si.args else si.input))

        # just run one test, we assume there is a suitable one in the tests
        sample_input = next(sample_inputs)
        all_args = (sample_input.input,) + sample_input.args

        # in general, the methods take varargs and not (always?) the function
        # variants, the exception to this rule are the factory functions
        if op.is_factory_function:
            fn = op.op
        else:
            fn = op.method_variant
        # The last argument is splatted so it is passed as varargs.
        with TorchRefsMode():
            gm = make_fx(fn)(*all_args[:-1], *all_args[-1])

        # in case we add random factory functions
        torch.manual_seed(1)
        res = gm(*all_args[:-1], *all_args[-1])
        torch.manual_seed(1)
        expected = fn(*all_args[:-1], *all_args[-1])
        self.assertEqual(res, expected)


instantiate_device_type_tests(TestDecomp, globals())


if __name__ == "__main__":
    run_tests()