# Owner(s): ["module: decompositions"]

from functools import partial
from itertools import product
import unittest

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    TestCase,
    TEST_SCIPY,
    set_default_dtype,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    dtypes,
    ops,
    OpDTypes,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
)

from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims_common import CUDARngStateHelper
from torch._prims.executor import make_traced
import torch._refs as refs


if TEST_SCIPY:
    import scipy.special

NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"


class TestPrims(TestCase):
    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        def _wrapper(a, b, broadcast_dimensions):
            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            # Tested here, as this op is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
                y_np = scipy.special.cbrt(x_np)

                self.assertEqual(y, y_np, exact_device=False)

    @dtypes(torch.float32)
    def test_collapse(self, device, dtype):
        t = torch.rand(2, 2, 2)
        dim_ranges = [(0, 0), (0, 1), (1, 2), (0, 2)]
        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]

        for (start, end), shape in zip(dim_ranges, expected_shapes):
            expect = t.reshape(shape)

            copy = prims.collapse(t, start, end)
            self.assertEqual(copy, expect)
            self.assertFalse(copy._is_view())

            view = prims.collapse_view(t, start, end)
            self.assertEqual(view, expect)
            self.assertTrue(view._is_view())

        t_discontig = t.transpose(0, 1)
        with self.assertRaises(ValueError, msg="no such view exists"):
            view = prims.collapse_view(t_discontig, 0, 2)

        copy = prims.collapse(t_discontig, 0, 1)
        self.assertEqual(copy, t_discontig.reshape(4, 2))

        error_dims = [(-1, 1), (0, 3), (1, -1)]
        for start, end in error_dims:
            for fn in [prims.collapse, prims.collapse_view]:
                with self.assertRaises(AssertionError):
                    fn(t, start, end)
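
    @dtypes(torch.float32)
    def test_collapse_full_range(self, device, dtype):
        # A minimal companion sketch to test_collapse above, assuming
        # prims.collapse / prims.collapse_view keep the inclusive-end dim
        # convention shown there: collapsing the full dim range of a
        # contiguous tensor should match torch.flatten.
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        t = make_arg((2, 3, 4))
        self.assertEqual(prims.collapse(t, 0, t.ndim - 1), t.flatten())
        self.assertEqual(prims.collapse_view(t, 0, t.ndim - 1), t.flatten())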

    def test_aten_overload_to_prims(self, device):
        # This test ensures that torch.ops.aten calls are replaced with refs,
        # which trace down to prims under make_fx
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a))

        with TorchRefsMode():
            gm = make_fx(func)(a)

        # Check that all call_function nodes are prims
        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        all_prims_namespace = all(
            node.target.name().startswith("prims") for node in call_function_nodes
        )
        self.assertTrue(all_prims_namespace)

    @onlyCUDA
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)
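
    @onlyCUDA
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var_vs_torch(self, device, dtype, correction):
        # A minimal companion sketch to test_var above, assuming prims.var's
        # `correction` has the same meaning as torch.var's (a variance divisor
        # of N - correction); the prim should then agree with the eager op on a
        # full reduction.
        a = make_tensor((5, 5), device=device, dtype=dtype)
        self.assertEqual(
            prims.var(a, [0, 1], correction=correction),
            torch.var(a, dim=[0, 1], correction=correction),
        )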

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        shapes = (
            (),
            (0,),
            (1,),
            (5,),
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7)
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shapes, memory_format in pairs:
            for shape in shapes:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

    @dtypes(torch.float32)
    def test_reshape_view_method(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        new_shape = 1, 5, 1, 5
        result_eager = a.reshape(*new_shape)
        result_refs = refs.reshape(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

        result_eager = a.view(*new_shape)
        result_refs = refs.view(a, *new_shape)
        self.assertEqual(result_eager, result_refs)
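
    @dtypes(torch.float32)
    def test_flatten_method(self, device, dtype):
        # A minimal sketch in the spirit of test_reshape_view_method above,
        # assuming refs.flatten mirrors Tensor.flatten's (start_dim, end_dim)
        # signature; the ref should agree with the method on a contiguous input.
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((2, 3, 4))
        self.assertEqual(a.flatten(1, 2), refs.flatten(a, 1, 2))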

    @onlyCUDA
    @dtypes(torch.float32)
    def test_philox_rand(self, device, dtype):
        sizes = (1000, 1000000)  # offsets of 4 and 8
        repeats = 2  # Checks that multiple torch.rand calls match multiple philox_rand calls
        for size in sizes:
            torch.cuda.manual_seed(123)
            references = []
            results = []
            rng_states = []
            for _ in range(repeats):
                rng_states.append(CUDARngStateHelper.get_torch_state_as_tuple())
                references.append(torch.rand(size, device=device, dtype=dtype))

            torch.cuda.manual_seed(123)
            for idx in range(repeats):
                seed, offset = rng_states[idx]
                result, _ = torch.ops.rngprims.philox_rand(
                    (size,),
                    seed=seed,
                    offset=offset,
                    stride=None,
                    device=device,
                    dtype=dtype,
                )
                results.append(result)

            for a, b in zip(references, results):
                self.assertEqual(a, b)

    @dtypes(torch.float32)
    def test_functional_rng_wrappers(self, device, dtype):
        torch.manual_seed(123)
        ref1 = torch.rand(10, device=device, dtype=dtype)
        ref2 = torch.rand(10, device=device, dtype=dtype)

        torch.manual_seed(123)
        rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
        rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)

        res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
        res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref1, res3)
        self.assertEqual(ref2, res4)


class TestPrimsBasic(TestCase):
    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0: f32[2] = input('input')
$1: f32[2] = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        prims.mul(torch.randn(2), 1 + 1j)

    def test_clone_complex(self):
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
            out = x + 1

    def test_check_deprecation_warning(self):
        with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
            torch._prims_common.check(True, lambda: 'message')


instantiate_device_type_tests(TestPrims, globals())


class TestRefs(TestCase):
    @dtypes(torch.float32)
    def test_constant_pad_nd_memory_format(self, device, dtype):
        # Test that memory format is preserved in unambiguous cases
        for mf, ndim in (
            (torch.channels_last, 4),
            (torch.contiguous_format, 4),
            (torch.channels_last_3d, 5),
            (torch.contiguous_format, 5),
        ):
            a = torch.zeros([2] * ndim).to(memory_format=mf)
            res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim))
            self.assertTrue(res.is_contiguous(memory_format=mf))

        # Ambiguous cases

        # is_channels_last_ and is_contiguous_, results in channels_last output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))

        # is_channels_last_contiguous_ but not is_channels_last_, results in
        # contiguous output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous())

    def test_unbind(self):
        # If unbind returns an empty tuple, it breaks some assumptions in some
        # backward tests in test_ops.py, so this test can't go into
        # common_methods_invocations.py.
        a = torch.rand([3, 0, 4])
        actual = refs.unbind(a, 1)
        expect = torch.unbind(a, 1)
        self.assertEqual(actual, expect)

    def test_logspace_with_complex_input(self):
        actual = refs.logspace(2, 10 + 5j, steps=5)
        expect = torch.logspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    def test_linspace_with_complex_input(self):
        actual = refs.linspace(2, 10 + 5j, steps=5)
        expect = torch.linspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    # From https://github.com/pytorch/pytorch/issues/109558
    def test_infinite_loop_from_py_dispatcher(self):
        # enables prim decomps
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.ones(4)
            y = x.to(device="meta")

    def test_inferred_tags(self):
        self.assertEqual(
            torch.ops.prims.normal.default.tags,
            (torch.Tag.nondeterministic_seeded, torch.Tag.pt2_compliant_tag),
        )


instantiate_device_type_tests(TestRefs, globals())


class TestDecomp(TestCase):
    @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
    def test_decomposition_method_vararg(self, device, dtype, op):
        # Some ops have vararg variants of their methods; this tests them.
        # We don't have tests for varargs in OpInfo, so we need to improvise
        # this a bit.
        # The rule for general functions (the special cases being e.g. tensor
        # creation functions taking shapes) is that things can be vararg
        # if the method has only one argument of sequence type.
        # E.g. permute can be called on a 3d tensor t as t.permute(0, 2, 1)
        # as well as t.permute([0, 2, 1]), since the signature in
        # native_functions.yaml shows the arguments "Tensor self, IntList dims".
        # We might need to adjust things for the factory functions or
        # have them do their own test.
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
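
        # A quick illustrative sketch of the rule above: torch.Tensor.permute
        # takes a single IntList `dims` argument, so both spellings must agree.
        t = torch.randn(2, 3, 4, device=device)
        self.assertEqual(t.permute(0, 2, 1), t.permute([0, 2, 1]))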

        # filter out empty tuple as that cannot be the varargs
        sample_inputs = (si for si in op.sample_inputs(device, dtype, requires_grad=False)
                         if (si.args[-1] if si.args else si.input))

        # just run one test, we assume there is a suitable one in the tests
        sample_input = next(sample_inputs)
        all_args = (sample_input.input,) + sample_input.args

        # in general, the methods take varargs and not (always?) the function
        # variants, the exception to this rule are the factory functions
        if op.is_factory_function:
            fn = op.op
        else:
            fn = op.method_variant
        with TorchRefsMode():
            gm = make_fx(fn)(*all_args[:-1], *all_args[-1])

        # in case we add random factory functions
        torch.manual_seed(1)
        res = gm(*all_args[:-1], *all_args[-1])
        torch.manual_seed(1)
        expected = fn(*all_args[:-1], *all_args[-1])
        self.assertEqual(res, expected)


instantiate_device_type_tests(TestDecomp, globals())


if __name__ == "__main__":
    run_tests()