# Owner(s): ["module: mta"]

import itertools
import os
import random
import re
import unittest
import weakref
from contextlib import nullcontext
from numbers import Number

import torch
from torch.testing import make_tensor
from torch.testing._comparison import default_tolerances
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    floating_types,
    floating_types_and,
    integral_types_and,
)
from torch.testing._internal.common_methods_invocations import (
    foreach_binary_op_db,
    foreach_other_op_db,
    foreach_pointwise_op_db,
    foreach_reduce_op_db,
    foreach_unary_op_db,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    skipIfTorchDynamo,
    TEST_WITH_ROCM,
    TestCase,
)


_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"


class RegularFuncWrapper:
    def __init__(self, func):
        self.func = func

    def __call__(self, inputs, scalars=None, **kwargs):
        if scalars is not None:
            assert len(inputs) == 3
            # We need to distribute each scalar to the regular func and it needs
            # special consideration as it is a keyword only argument to the
            # regular func. (Strangely, it is not a keyword only argument to the
            # foreach func)
            return [
                self.func(*i, value=scalars[idx], **kwargs)
                for idx, i in enumerate(zip(*inputs))
            ]
        if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
            # binary op with tensorlist and scalar.
            inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
        return [self.func(*i, **kwargs) for i in zip(*inputs)]


class ForeachFuncWrapper:
    def __init__(self, func):
        self.func = func
        # Some foreach functions don't have in-place implementations.
        self.is_inplace = False if func is None else func.__name__.endswith("_")

    def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
        actual = None
        zero_size = kwargs.pop("zero_size", False)
        if (
            is_cuda
            and torch.autograd.kineto_available()
            and torch.profiler.ProfilerActivity.CUDA
            in torch.profiler.supported_activities()
        ):
            with torch.profiler.profile() as p:
                actual = self.func(*inputs, **kwargs)
            keys = tuple([e.key for e in p.key_averages()])
            mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
            assert (
                mta_called == (expect_fastpath and (not zero_size))
            ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}"
        else:
            actual = self.func(*inputs, **kwargs)
        if self.is_inplace:
            assert id(inputs[0]) == id(actual)
        return actual


class InplaceForeachVersionBumpCheck:
    def __init__(
        self,
        testcase: TestCase,
        tensorlist: "List[torch.Tensor]",  # noqa: F821
    ) -> None:
        self._testcase = testcase
        self._tensorlist = tensorlist
        self._orig_version_counts = [t._version for t in tensorlist]

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        # note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
        self._testcase.assertGreaterEqual(
            [t._version for t in self._tensorlist], self._orig_version_counts
        )
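# Example (illustrative) usage of the check above: wrap an in-place foreach call
# and assert that every tensor's version counter was bumped (or at least not
# decreased), e.g.
#
#     with InplaceForeachVersionBumpCheck(self, tensors):
#         torch._foreach_add_(tensors, 1)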


def get_transform_func(num_tensors, dtype, device, is_fastpath):
    def transform(t):
        if not torch.is_tensor(t):
            return t
        if torch.is_tensor(t) and t.ndim == 0:
            return t
        return make_tensor(
            (num_tensors, num_tensors),
            dtype=dtype,
            device=device,
            requires_grad=True,
            noncontiguous=not is_fastpath,
        )

    return transform


# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
# as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
@unittest.mock.patch.dict(os.environ, {"KINETO_LOG_LEVEL": "5"})
class TestForeach(TestCase):
    @property
    def is_cuda(self):
        return self.device_type == "cuda"

    def _get_funcs(self, op):
        return (
            ForeachFuncWrapper(op.method_variant),
            RegularFuncWrapper(op.ref),
            ForeachFuncWrapper(op.inplace_variant),
            RegularFuncWrapper(op.ref_inplace),
        )

    # note(crcrpar): Make sure 0-size tensors are appropriately ignored by `multi_tensor_apply`
    # which is originally reported in https://github.com/pytorch/pytorch/issues/94865.
    # rel:
    # - https://github.com/pytorch/pytorch/pull/94655
    # - https://github.com/pytorch/pytorch/issues/100701
    # - https://github.com/pytorch/pytorch/pull/100811
    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
        dtypes=(torch.float32,),
    )
    def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
        wrapped_op, _, inplace_op, _ = self._get_funcs(op)

        for sample in op.sample_zero_size_inputs(device, dtype):
            if op.method_variant is not None:
                wrapped_op(
                    (sample.input, *sample.args),
                    is_cuda=self.is_cuda,
                    expect_fastpath=True,
                    zero_size=True,
                )

            if op.inplace_variant is not None:
                with InplaceForeachVersionBumpCheck(self, sample.input):
                    inplace_op(
                        (sample.input, *sample.args),
                        is_cuda=self.is_cuda,
                        expect_fastpath=True,
                        zero_size=True,
                    )

    @skipIfRocmVersionLessThan((6, 0))
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
    )
    @parametrize(
        "noncontiguous,inplace",
        [(False, False), (False, True), (True, False), (True, True)],
        name_fn=lambda x, y: "{}_{}".format(
            "fastpath" if not x else "slowpath", "inplace" if y else "outplace"
        ),
    )
    def test_parity(self, device, dtype, op, noncontiguous, inplace):
        if inplace:
            _, _, func, ref = self._get_funcs(op)
        else:
            func, ref, _, _ = self._get_funcs(op)
        for sample in op.sample_inputs(
            device, dtype, noncontiguous=noncontiguous, allow_higher_dtype_scalars=True
        ):
            ref_kwargs = sample.kwargs
            # div promotes ints to floats, so we cannot go on the fastpath there
            div_slowpath = (
                dtype in integral_types_and(torch.bool) and op.name == "_foreach_div"
            )
            expect_fastpath = not (
                noncontiguous or sample.disable_fastpath or div_slowpath
            )
            ref_input, ctxmgr = sample.input, nullcontext()
            if inplace:
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
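            # If the foreach call raises, expect the reference to raise the same
            # error type; otherwise compare the foreach result against the reference.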
            try:
                with ctxmgr:
                    actual = func(
                        [sample.input, *sample.args],
                        self.is_cuda,
                        expect_fastpath,
                        **sample.kwargs,
                    )
            except Exception as e:
                with self.assertRaises(type(e)):
                    ref([ref_input, *sample.ref_args], **ref_kwargs)
            else:
                expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
                self.assertEqual(expected, actual)

    def _binary_test(
        self,
        dtype,
        op,
        ref,
        inputs,
        is_fastpath,
        is_inplace,
        *,
        alpha,
        scalar_self_arg: bool,
    ):
        ref_inputs = (
            [[t.clone().detach() for t in inputs[0]], inputs[1]]
            if is_inplace
            else inputs
        )
        try:
            with InplaceForeachVersionBumpCheck(
                self, inputs[0]
            ) if op.is_inplace else nullcontext():
                actual = op(inputs, self.is_cuda, is_fastpath)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                if not scalar_self_arg:
                    ref(ref_inputs)
                else:
                    [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
        else:
            expected = (
                ref(ref_inputs)
                if not scalar_self_arg
                else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
            )
            self.assertEqual(actual, expected)
        if alpha is not None and not scalar_self_arg:
            kwargs = {"alpha": alpha}
            ref_inputs = inputs
            try:
                op_kwargs = {}
                op_kwargs.update(kwargs)
                with InplaceForeachVersionBumpCheck(
                    self, inputs[0]
                ) if op.is_inplace else nullcontext():
                    actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                    ref(ref_inputs, **kwargs)
            else:
                expected = ref(ref_inputs, **kwargs)
                if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
                    self.assertEqual(
                        expected, actual, atol=1.0e-3, rtol=default_tolerances(dtype)[0]
                    )
                else:
                    self.assertEqual(expected, actual)

    @ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db))
    @parametrize("is_fastpath", (True, False))
    def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath):
        def clone(arg):
            if isinstance(arg, (list, tuple)):
                return [clone(a) for a in arg]
            if torch.is_tensor(arg):
                return arg.clone().detach().requires_grad_()
            else:
                return arg

        scalar_self_arg_test_complete = False
        for i, sample in enumerate(
            op.sample_inputs(
                device,
                dtype,
                noncontiguous=not is_fastpath,
                allow_higher_dtype_scalars=True,
            )
        ):
            (rhs_arg,) = sample.args
            kwargs = sample.kwargs or {}
            alpha = kwargs.pop("alpha", None)
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete:
                scalar_self_arg_test_complete = True
                self._binary_test(
                    dtype,
                    wrapped_op,
                    ref,
                    [rhs_arg, sample.input],
                    is_fastpath,
                    False,
                    alpha=alpha,
                    scalar_self_arg=True,
                )
                if op.supports_autograd and dtype == torch.float32:
                    transformed_sample = sample.transform(
                        get_transform_func(
                            len(sample.input), dtype, device, is_fastpath
                        )
                    )
                    tensors = transformed_sample.input
                    (rhs_arg,) = transformed_sample.args
                    ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
                    sum(
                        wrapped_op(
                            [rhs_arg, tensors], is_cuda=False, expect_fastpath=False
                        )
                    ).mean().backward()
                    sum(ref.func(ref_rhs_arg, t) for t in ref_tensors).mean().backward()
                    self.assertEqual(
                        [t.grad for t in tensors], [t.grad for t in ref_tensors]
                    )

    @ops(foreach_pointwise_op_db)
    @parametrize("is_fastpath", (True, False))
    def test_pointwise_op_with_tensor_of_scalarlist_overload(
        self, device, dtype, op, is_fastpath
    ):
        for sample in op.sample_inputs(
            device,
            dtype,
            noncontiguous=not is_fastpath,
            allow_higher_dtype_scalars=True,
        ):
            assert isinstance(sample.args, tuple)
            assert len(sample.args) == 2
            inputs = [sample.input, *sample.args]
            kwargs = sample.kwargs.copy()
            disable_fastpath = sample.disable_fastpath and is_fastpath
            wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
            scalars = kwargs.pop("scalars", None)

            if is_fastpath and scalars:
                sample = sample.transform(
                    lambda t: t.clone().detach() if torch.is_tensor(t) else t
                )
                inputs = [sample.input, *sample.args]
                tensor_values = torch.tensor(scalars)
                # 1D Tensor of scalars
                for is_inplace, op_, ref_ in (
                    (False, wrapped_op, ref),
                    (True, inplace_op, inplace_ref),
                ):
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values,
                        **kwargs,
                    )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values[0],
                        custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
                        **kwargs,
                    )
                    if self.is_cuda:
                        self._pointwise_test(
                            op_,
                            ref_,
                            inputs,
                            is_fastpath and not disable_fastpath,
                            is_inplace,
                            scalars=tensor_values.cuda(),
                            custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
                            **kwargs,
                        )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=tensor_values[:2],
                        custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
                        **kwargs,
                    )
                    self._pointwise_test(
                        op_,
                        ref_,
                        inputs,
                        is_fastpath and not disable_fastpath,
                        is_inplace,
                        scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
                        custom_values_err="Expected scalars to be contiguous.",
                        **kwargs,
                    )

            # Tests of implicit broadcasting
            N = len(sample.input)
            inputs = [
                [
                    make_tensor(
                        (N, N),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for _ in range(N)
                ],
                [
                    make_tensor(
                        (N - i, 1),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for i in range(N)
                ],
                [
                    make_tensor(
                        (1, N - i),
                        device=device,
                        dtype=dtype,
                        noncontiguous=not is_fastpath,
                    )
                    for i in range(N)
                ],
            ]
            self._pointwise_test(
                wrapped_op,
                ref,
                inputs,
                is_fastpath and disable_fastpath,
                is_inplace=False,
                scalars=scalars,
                **kwargs,
            )
            self._pointwise_test(
                inplace_op,
                inplace_ref,
                inputs,
                is_fastpath and disable_fastpath,
                is_inplace=True,
                scalars=scalars,
                **kwargs,
            )

    def _pointwise_test(
        self,
        op,
        ref,
        inputs,
        is_fastpath,
        is_inplace,
        *,
        scalars=None,
        custom_values_err=None,
        **kwargs,
    ):
        ref_inputs = (
            [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]]
            if is_inplace
            else inputs
        )
        try:
            with (
                InplaceForeachVersionBumpCheck(self, inputs[0])
                if is_inplace
                else nullcontext()
            ):
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                ref(ref_inputs, **kwargs)
        else:
            expected = ref(ref_inputs, **kwargs)
            self.assertEqual(expected, actual)
        if scalars is not None:
            kwargs = kwargs.copy()
            kwargs["scalars"] = scalars
            try:
                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
            except RuntimeError as e:
                # Match with error messages from regular non-foreach reference if no
                # custom error message was provided.
                if custom_values_err is None:
                    with self.assertRaisesRegex(
                        type(e), re.escape(str(e).splitlines()[0])
                    ):
                        ref(ref_inputs, **kwargs)
                else:
                    self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
            else:
                expected = ref(ref_inputs, **kwargs)
                self.assertEqual(expected, actual)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
        # TODO: enable empty list case
        for tensors in [
            [torch.randn([0], device=device, dtype=dtype)],
            [torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)],
        ]:
            res = torch._foreach_add(tensors, 1)
            self.assertEqual(res, tensors)

            torch._foreach_add_(tensors, 1)
            self.assertEqual(res, tensors)

            # Regression test for https://github.com/pytorch/pytorch/issues/113156
            torch._foreach_mul_(tensors, 1)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype):
        # default tensor stride is (9, 9, 3, 1).
        tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype)
        strided_tensor = torch.ones(
            (2, 1, 3, 3), device=device, dtype=dtype
        ).as_strided((2, 1, 3, 3), (9, 1, 3, 1))
        left_inputs = [tensor, strided_tensor]
        right_inputs = [strided_tensor, tensor]
        compare_result = tensor + strided_tensor
        foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add)
        out = foreach_add_check_(
            (left_inputs, right_inputs), is_cuda=True, expect_fastpath=True
        )
        for res in out:
            self.assertEqual(res, compare_result)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
        foreach_op, ref = op.method_variant, op.ref
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]

        if ref == torch.sub and dtype == torch.bool:
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                [ref(t, 1) for t in tensors]
            with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                foreach_op(tensors, 1)
            return

        expected = [ref(t, 1) for t in tensors]
        res = foreach_op(tensors, 1)
        self.assertEqual(res, expected)

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        allowed_dtypes=[torch.float],
    )
    def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
        foreach_op = op.method_variant
        tensors = [
            torch.tensor([1.1], dtype=torch.float, device=device),
            torch.tensor([1], dtype=torch.long, device=device),
        ]
        runtime_error = None
        try:
            foreach_op(tensors, 1)
        except RuntimeError as e:
            runtime_error = e
        self.assertIsNone(runtime_error)

    @skipIfTorchDynamo("Different error msgs, TODO")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_error_cases(self, device, dtype, op):
        foreach_op, foreach_op_, ref, ref_ = (
            op.method_variant,
            op.inplace_variant,
            op.ref,
            op.ref_inplace,
        )
        tensors1 = []
        tensors2 = []
        ops_to_test = [foreach_op, foreach_op_]

        # Empty lists
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError, "Tensor list must have at least one tensor."
            ):
                fop(tensors1, tensors2)

        # One empty list
        tensors1.append(torch.tensor([1], device=device, dtype=dtype))
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor list must have same number of elements as scalar list.",
            ):
                fop(tensors1, tensors2)

        # Lists have different numbers of tensors
        tensors2.append(torch.tensor([1], device=device))
        tensors2.append(torch.tensor([1], device=device))
        for fop in ops_to_test:
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 1 and 2",
            ):
                fop(tensors1, tensors2)
            with self.assertRaisesRegex(
                RuntimeError,
                "Tensor lists must have the same number of tensors, got 2 and 1",
            ):
                fop(tensors2, tensors1)

        # Corresponding tensors with different sizes that aren't compatible with broadcasting.
        # If sizes differ, foreach takes the slow path, so the error messages are expected
        # to match those of the regular torch functions.
        tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
        tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]

        if dtype == torch.bool and foreach_op == torch._foreach_sub:
            for fop in ops_to_test:
                with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
                    fop(tensors1, tensors2)
            return
        with self.assertRaisesRegex(
            RuntimeError,
            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
        ):
            foreach_op(tensors1, tensors2)
        with self.assertRaisesRegex(
            RuntimeError,
            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
        ):
            foreach_op_(tensors1, tensors2)

        # different devices
        if self.device_type == "cuda" and torch.cuda.device_count() > 1:
            tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
            tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
            with self.assertRaisesRegex(
                RuntimeError, "Expected all tensors to be on the same device"
            ):
                foreach_op([tensor1], [tensor2])
            if (
                dtype in integral_types_and(torch.bool)
                and foreach_op == torch._foreach_div
            ):
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    foreach_op_([tensor1], [tensor2])
            else:
                with self.assertRaisesRegex(
                    RuntimeError, "Expected all tensors to be on the same device"
                ):
                    foreach_op_([tensor1], [tensor2])

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=OpDTypes.supported,
    )
    def test_binary_op_list_slow_path(self, device, dtype, op):
        foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
        # 0-strides
        tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
        tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
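        # The expanded rhs has 0 strides, so the fused fastpath is not expected
        # here; the checks below pass is_fastpath=False accordingly.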
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # different strides
        tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
        tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
        inputs = ([tensor1], [tensor2.t()])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # non-contiguous
        tensor1 = make_tensor(
            (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
        )
        tensor2 = make_tensor(
            (5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True
        )
        self.assertFalse(tensor1.is_contiguous())
        self.assertFalse(tensor2.is_contiguous())
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

        # sliced tensor
        tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
        tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[
            :, :, :, ::7
        ]
        inputs = ([tensor1], [tensor2])
        self._binary_test(
            dtype,
            foreach_op,
            native_op,
            inputs,
            is_fastpath=False,
            is_inplace=False,
            alpha=None,
            scalar_self_arg=False,
        )
        self._binary_test(
            dtype,
            foreach_op_,
            native_op_,
            inputs,
            is_fastpath=False,
            is_inplace=True,
            alpha=None,
            scalar_self_arg=False,
        )

    @ops(
        filter(lambda op: op.supports_out, foreach_binary_op_db),
        dtypes=floating_types_and(torch.half, torch.bfloat16),
    )
    def test_binary_op_float_inf_nan(self, device, dtype, op):
        inputs = (
            [
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
            [
                torch.tensor([-float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("inf")], device=device, dtype=dtype),
                torch.tensor([float("nan")], device=device, dtype=dtype),
            ],
        )
        op, ref, inplace_op, inplace_ref = self._get_funcs(op)
        self._binary_test(
            dtype, op, ref, inputs, True, False, alpha=None, scalar_self_arg=False
        )
        self._binary_test(
            dtype,
            inplace_op,
            inplace_ref,
            inputs,
            True,
            True,
            alpha=None,
            scalar_self_arg=False,
        )

    # note: The three tests below (suffixed with `_tensors_on_different_devices`)
    # check whether foreach works with lists of tensors on different devices,
    # where tensors at the same index share a device, e.g., ['cuda', 'cpu'].
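    # Depending on the op, such mixed-device inputs either fall back to the slow
    # path or raise; both outcomes are checked against the per-tensor reference.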
    @onlyCUDA
    @ops(foreach_unary_op_db)
    def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
        method, ref, inplace_method, ref_inplace = self._get_funcs(op)
        # tensors: ['cuda', 'cpu']
        tensors = next(
            iter(
                op.sample_inputs(
                    device,
                    dtype,
                    num_input_tensors=[2],
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors[1] = tensors[1].to("cpu")
        if not op.supports_out:
            try:
                actual = method((tensors,), False, False, zero_size=False)
            except RuntimeError as e:
                with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
                    ref((tensors,))
            else:
                expected = ref((tensors,))
                self.assertEqual(expected, actual)

        try:
            inplace_method((tensors,), False, False, zero_size=False)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), str(e).splitlines()[0]):
                ref_inplace((tensors,))
        else:
            if not op.supports_out:
                self.assertEqual(expected, tensors)
            else:
                self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)

    @onlyCUDA
    @ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
    def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
        _cuda_tensors = next(
            iter(
                op.sample_inputs(
                    device,
                    dtype,
                    num_input_tensors=[2],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        _cpu_tensors = next(
            iter(
                op.sample_inputs(
                    "cpu",
                    dtype,
                    num_input_tensors=[2],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
        native_op, native_op_ = op.ref, op.ref_inplace
        try:
            actual = foreach_op(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
            self.assertEqual(expected, actual)
        try:
            foreach_op_(tensors1, tensors2)
        except RuntimeError as e:
            with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
                [native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
        else:
            self.assertEqual(actual, tensors1)

    @onlyCUDA
    @ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
    def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
        # tensors1: ['cuda', 'cpu']
        # tensors2: ['cuda', 'cpu']
        # tensors3: ['cuda', 'cpu']
        # first tensorlist is zero-size when float32
        _cuda_tensors = list(
            op.sample_inputs(
                device,
                dtype,
                num_input_tensors=[3],
                same_size=True,
                allow_higher_dtype_scalars=True,
            )
        )[int(dtype == torch.float32)].input
        _cpu_tensors = next(
            iter(
                op.sample_inputs(
                    "cpu",
                    dtype,
                    num_input_tensors=[3],
                    same_size=True,
                    allow_higher_dtype_scalars=True,
                )
            )
        ).input
        tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))

        foreach_op, foreach_op_, native_op = (
            op.method_variant,
            op.inplace_variant,
            op.ref,
        )
        actual = foreach_op(tensors1, tensors2, tensors3)
        expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
        self.assertEqual(expected, actual)

        # note(mkozuki): with dtypes limited to FP32 and FP64, the in-place ops can be run safely.
        foreach_op_(tensors1, tensors2, tensors3)
        self.assertEqual(expected, tensors1)

    # note: BFloat16 has the same number of exponent bits as FP32,
    # so if the squared L2 norm overflows in BF16, it also overflows in FP32.
    @onlyCUDA
    @ops(
        [o for o in foreach_reduce_op_db if "norm" in o.name],
        allowed_dtypes=(torch.half, torch.bfloat16),
    )
    def test_foreach_l2_large_value_input(self, device, dtype, op):
        ord, N = 2, 10
        max_value = torch.finfo(dtype).max
        scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
        inputs = (
            [
                t * scaler
                for t in next(
                    iter(
                        op.sample_inputs(
                            device,
                            dtype,
                            requires_grad=True,
                            num_input_tensors=[N],
                            low=1,
                        )
                    )
                ).input
            ][:-1],
        )
        # make sure that the minimum squared L2 norm per tensor is greater than the max value of `dtype`.
        self.assertTrue(scaler * scaler * N > max_value)
        fn, ref_fn, *_ = self._get_funcs(op)
        actual = fn(
            inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False
        )
        expect = ref_fn(inputs, ord=ord)

        if dtype == torch.float16:
            # making sure the reference L2 norm values are in the range of FP16.
            self.assertFalse(any(torch.isinf(e) for e in expect))
        else:
            self.assertTrue(
                all(
                    inputs[0][i].numel() == 0 or torch.isinf(e)
                    for i, e in enumerate(expect)
                )
            )
        self.assertEqual(expect, actual, equal_nan=False)

    @onlyCUDA
    @ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
    @parametrize("use_cuda_graph", (False, True))
    def test_big_num_tensors(self, device, dtype, op, use_cuda_graph):
        N = 600
        tensorlist = [
            make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False)
            for _ in range(N)
        ]
        fn, ref_fn, *_ = self._get_funcs(op)

        import math

        if op.name == "_foreach_norm":
            ords = (1, 2, math.inf)
        else:
            ords = (None,)

        for ord in ords:
            kwargs = {"ord": ord} if ord else {}
            if not use_cuda_graph:
                actual = fn(
                    inputs=[tensorlist],
                    is_cuda=True,
                    expect_fastpath=True,
                    zero_size=False,
                    **kwargs,
                )
            else:
                # When using CUDA graphs and the tensor metadata doesn't fit in
                # the static kernel argument space, multi_tensor_apply creates
                # the launch arguments once, uses cudaUserObject_t to tie its
                # lifetime to the graph, and reuses it throughout replays. This
                # test verifies multi_tensor_apply's behavior in that scenario.
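                # The raw op (`fn.func`) is captured directly below, bypassing the
                # profiler-based fastpath check in `ForeachFuncWrapper.__call__`.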
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    actual = fn.func(tensorlist, **kwargs)
                g.replay()
            expect = ref_fn(inputs=[tensorlist], **kwargs)

            self.assertEqual(expect, actual, equal_nan=True)

    @onlyCUDA
    @ops(foreach_reduce_op_db)
    def test_foreach_reduce_large_input(self, device, dtype, op):
        # test inputs larger than kChunkSize = 65536
        N = 65536 * 2
        disable_fastpath = False
        kwargs = {}
        if op.name == "_foreach_norm":
            ord = 2
            disable_fastpath = not (
                ord in (1, 2)
                and dtype in floating_types_and(torch.half, torch.bfloat16)
            )
            kwargs["ord"] = ord

        inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
        wrapped_op, ref, _, _ = self._get_funcs(op)
        self.assertEqual(
            ref(inputs, **kwargs),
            wrapped_op(
                inputs, self.is_cuda, not disable_fastpath, zero_size=False, **kwargs
            ),
        )

    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_other_op_db,
        dtypes=(torch.float,),
    )
    def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
        inplace_op = op.inplace_variant
        if inplace_op is None:
            self.skipTest("no in-place op available")

        sample = next(
            iter(
                op.sample_inputs(
                    dtype=dtype, device=device, num_input_tensors=[2], same_size=True
                )
            )
        )
        sample.input[0].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)
        sample.input[1].requires_grad_(True)
        with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
            inplace_op(sample.input, *sample.args)

        _tensors = [
            t.clone().detach().requires_grad_(i == 0)
            for i, t in enumerate(sample.input)
        ]
        tensors = [t.clone() for t in _tensors]
        inplace_op(tensors, *sample.args)
        self.assertIsNotNone(tensors[0].grad_fn)
        self.assertIsNone(tensors[1].grad_fn)

    @onlyCUDA
    @ops(
        filter(
            lambda op: op.supports_out,
            foreach_unary_op_db
            + foreach_binary_op_db
            + foreach_pointwise_op_db
            + foreach_other_op_db,
        ),
        dtypes=(torch.float,),
    )
    def test_outplace_with_invalid_grads(self, device, dtype, op):
        func, *_ = self._get_funcs(op)
        sample = next(
            iter(
                op.sample_inputs(
                    dtype=dtype,
                    device=device,
                    requires_grad=True,
                    num_input_tensors=[2],
                    same_size=True,
                )
            )
        )
        self.assertTrue(all(t.requires_grad for t in sample.input))
        (out1, out2) = func(
            [sample.input, *sample.args],
            is_cuda=False,
            expect_fastpath=False,
            **sample.kwargs,
        )
        out1.backward(torch.ones_like(out1))
        self.assertIsNotNone(sample.input[0].grad)
        self.assertIsNone(sample.input[1].grad)

    @ops(
        filter(
            lambda op: op.backward_requires_result,
            foreach_unary_op_db
            + foreach_binary_op_db
            + foreach_pointwise_op_db
            + foreach_other_op_db,
        ),
        dtypes=(torch.float32,),
    )
    def test_lifetime_of_grad_fn_when_result_is_saved(self, device, dtype, op):
        def get_ref(func, sample):
            class Foo:
                pass

            out = func(
                (sample.input, *sample.args),
                is_cuda=False,
                expect_fastpath=False,
                **sample.kwargs,
            )
            foo = Foo()
            meta_dict = out[0].grad_fn.metadata
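            # Stash a sentinel object in the grad_fn's metadata and keep only a
            # weak reference to it, so the test below can tell whether deleting
            # `out` actually frees the autograd node (and the result it saved).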
            meta_dict[0] = foo
            ref = weakref.ref(foo)
            return out, ref

        def _test(func, sample):
            out, ref = get_ref(func, sample)
            self.assertIsNotNone(ref())
            del out
            self.assertIsNone(ref())

        func = self._get_funcs(op)[0]
        for sample in op.sample_inputs(
            device, dtype, requires_grad=True, num_input_tensors=[1]
        ):
            for key in ("is_fastpath", "disable_fastpath"):
                if key in sample.kwargs:
                    del sample.kwargs[key]
            # note: `_foreach_pow.Scalar` and `_foreach_pow.ScalarList` don't depend on `result`
            # see: https://github.com/pytorch/pytorch/blob/5403c777/tools/autograd/derivatives.yaml#L3048-L3049
            if op.name == "_foreach_pow":
                if (
                    isinstance(sample.args[0], list)
                    and isinstance(sample.args[0][0], Number)
                ) or (
                    isinstance(sample.args[0], Number)
                    and not isinstance(sample.args[0], float)
                ):
                    continue
                if isinstance(sample.args[0], float):
                    new_args = (sample.input,)
                    sample.input = sample.args[0]
                    sample.args = new_args
            _test(func, sample)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_tensors_grouping(self):
        num_tensors_per_list = 10
        num_devices = torch.cuda.device_count()
        dtypes = (torch.float16, torch.float32, torch.float64)
        list1 = [
            torch.tensor(
                i,
                device=torch.device("cuda", random.randint(0, num_devices - 1)),
                dtype=dtypes[random.randint(0, 2)],
            )
            for i in range(num_tensors_per_list)
        ]
        list2 = [None for _ in list1]
        list3 = [torch.rand_like(t) for t in list1]
        nested_tensorlists = [list1, list2, list3]
        grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(
            nested_tensorlists, with_indices=True
        )
        num_tensors_seen = 0
        for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
            for t in itertools.chain(l1, l3):
                self.assertEqual(t.device, device)
                self.assertEqual(t.dtype, dtype)
                num_tensors_seen += 1
            self.assertEqual(len(l1), len(l2))
            self.assertTrue(all(p is None for p in l2))
            for i, index in enumerate(indices):
                self.assertEqual(l1[i], list1[index])
                self.assertEqual(l2[i], list2[index])
                self.assertEqual(l3[i], list3[index])
        self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)

    @onlyCUDA
    def test_0dim_tensor_overload_cpu_ok(self):
        tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
        scalar_cpu_tensor = torch.tensor(4.0, device="cpu")

        # For mul and div, the scalar is allowed to be on CPU too
        actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
        actual = torch._foreach_div(tensors, scalar_cpu_tensor)
        self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])

    @onlyCUDA
    def test_div_reciprocal(self):
        expect_m, expect_e = torch.frexp(
            torch.div(torch.tensor(0.1, device="cuda"), 10.0)
        )
        actual_m, actual_e = torch.frexp(
            torch._foreach_div([torch.tensor(0.1, device="cuda")], [10.0])[0]
        )
        self.assertEqual(expect_m, actual_m)
        self.assertEqual(expect_e, actual_e)

    @onlyCUDA
    def test_0dim_tensor_overload_exception(self):
        # check exceptions of fast path
        tensors = [
            make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)
        ]
        with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
            torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)

        tensors = [
            make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")
        ]
        with self.assertRaisesRegex(
            RuntimeError, "scalar tensor expected to be 0 dim but"
        ):
            torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
        with self.assertRaisesRegex(
            RuntimeError, "scalar tensor expected to be 0 dim but"
        ):
            torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
    def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
        foreach_copy_ = op.inplace_variant
        copy_ = op.ref_inplace
        for non_blocking in (False, True):
            for sample in op.sample_inputs(
                device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
            ):
                with torch.no_grad():
                    ref_input = [t.clone().detach() for t in sample.input]
                foreach_copy_(sample.input, sample.args[0], non_blocking)
                for t, s in zip(ref_input, sample.args[0]):
                    copy_(t, s, non_blocking)
                self.assertEqual(sample.input, ref_input)
                if torch.cuda.device_count() > 1:
                    device = torch.device("cuda", 1)
                    rhs_tensors = [t.to(device) for t in sample.args[0]]
                    foreach_copy_(sample.input, rhs_tensors, non_blocking)
                    for t, s in zip(ref_input, rhs_tensors):
                        copy_(t, s, non_blocking)
                    self.assertEqual(ref_input, sample.input)

    @onlyCUDA
    @ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
    def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
        # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
        foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
        for sample in op.sample_inputs(
            device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True
        ):
            for src_dtype in floating_types_and(torch.half, torch.bfloat16):
                if src_dtype == dtype:
                    continue
                self_tensors = [t.clone() for t in sample.input]
                src_tensors = [t.to(src_dtype) for t in self_tensors]
                out = foreach_copy_(
                    (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True
                )
                ref_out = [
                    torch.empty_like(t).copy_(s)
                    for t, s in zip(self_tensors, src_tensors)
                ]
                for t, ref_t in zip(out, ref_out):
                    self.assertTrue(torch.equal(t, ref_t))

    # Test reverse-mode & forward-mode AD if supported.
    @onlyCUDA
    @ops(
        foreach_unary_op_db
        + foreach_binary_op_db
        + foreach_pointwise_op_db
        + foreach_reduce_op_db
        + foreach_other_op_db,
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.float64, torch.complex128),
    )
    @parametrize(
        "inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace"
    )
    def test_autodiff(self, device, dtype, op, inplace):
        if (not inplace) and not op.supports_out:
            self.skipTest("out-of-place not implemented")
        if inplace and op.has_no_in_place:
            self.skipTest("in-place not implemented")
        if not (
            op.supports_autograd
            or op.supports_inplace_autograd
            or op.supports_forward_ad
        ):
            self.skipTest("neither reverse mode nor forward mode supported")

        # note(crcrpar): without restricting the value range here, some out-of-place
        # unary functions fail the checks below, unlike their in-place and/or complex
        # counterparts.
        if (
            (not inplace)
            and dtype == torch.float64
            and op.name
            in (
                "_foreach_acos",
                "_foreach_asin",
                "_foreach_log10",
                "_foreach_log1p",
                "_foreach_log2",
                "_foreach_log",
                "_foreach_pow",
                "_foreach_sqrt",
            )
        ):
            value_range = {"low": 0.5, "high": 1.0}
        else:
            value_range = {}
        for sample in op.sample_inputs(
            device,
            dtype,
            requires_grad=True,
            num_input_tensors=[5],
            allow_higher_dtype_scalars=True,
            **value_range,
        ):
            # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
            if op.name == "_foreach_pow" and isinstance(sample.input, Number):
                continue

            func = None
            if inplace:
                # Call `clone` to avoid in-place modifications, as in
                # `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
                def inplace_func(*tensorlist):
                    kwargs = (
                        {"alpha": sample.kwargs["alpha"]}
                        if "alpha" in sample.kwargs
                        else {}
                    )
                    op.inplace_variant(
                        tuple(t.clone() for t in tensorlist), *sample.args, **kwargs
                    )
                    return tensorlist

                func = inplace_func
            else:

                def outplace_func(*tensorlist):
                    kwargs = (
                        {"alpha": sample.kwargs["alpha"]}
                        if "alpha" in sample.kwargs
                        else {}
                    )
                    return op.method_variant(tensorlist, *sample.args, **kwargs)

                func = outplace_func

            working_sample, err_msg_pattern = check_autodiff_sample(
                op, sample, dtype, inplace
            )

            def call_gradcheck():
                gradcheck(
                    func,
                    sample.input,
                    raise_exception=True,
                    check_forward_ad=op.supports_forward_ad,
                    check_batched_forward_grad=False,
                    check_backward_ad=op.supports_autograd,
                    check_batched_grad=False,
                )

            if not working_sample:
                if not err_msg_pattern:
                    # lhs of float64 and rhs of complex.
                    continue
                with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
                    call_gradcheck()
                continue
            call_gradcheck()

            # Test per-tensor `grad_fn` behavior.
            if inplace and op.supports_inplace_autograd:
                # per-tensor `grad_fn` check.
                hook_buffer = []

                def get_grad_fn_hook(i):
                    def hook(grad_inputs, grad_outputs) -> None:
                        hook_buffer.append(i)

                    return hook

                _inputs = [t.clone().detach().requires_grad_() for t in sample.input]
                inputs = [t.clone() for t in _inputs]
                kwargs = (
                    {"alpha": sample.kwargs["alpha"]}
                    if "alpha" in sample.kwargs
                    else {}
                )
                op.inplace_variant(inputs, *sample.args, **kwargs)

                self.assertEqual(len({t.grad_fn for t in inputs}), len(inputs))

                for i, t in enumerate(inputs):
                    t.grad_fn.register_hook(get_grad_fn_hook(i))

                torch.autograd.grad(
                    inputs[0],
                    inputs=(_inputs[0],),
                    grad_outputs=(torch.rand_like(inputs[0]),),
                    retain_graph=True,
                )
                self.assertEqual(hook_buffer, [0])
                hook_buffer.clear()

                # tensors have different shapes.
                sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inputs]).sum()
                grad_output = torch.rand_like(sum_of_cloned_tensors)
                torch.autograd.grad(
                    sum_of_cloned_tensors,
                    inputs=tuple(_inputs),
                    grad_outputs=(grad_output,),
                    retain_graph=False,
                )
                self.assertEqual(hook_buffer, list(reversed(range(len(inputs)))))


# TODO(crcrpar): Hide this inside torch/testing/_internal.
# Doing so would end up adding another layer to `foreach_inputs_sample_func.__call__`
# so that this function could be used as something like the first argument of the `filter` function.
# Even after moving this function to testing, I personally think it'd be better to check the error message.
def check_autodiff_sample(op, sample, dtype, is_inplace):
    if op.name == "_foreach_abs" and is_inplace and dtype == torch.complex128:
        return False, "In-place abs is not supported for complex tensors."
    if op.name == "_foreach_sub" and (
        (
            isinstance(sample.args[-1], list)
            and any(isinstance(a, bool) for a in sample.args[-1])
        )
        or isinstance(sample.args[-1], bool)
    ):
        return False, _BOOL_SUB_ERR_MSG
    if op.name == "_foreach_norm" and (not is_inplace):
        return (
            False,
            "Trying to set a forward gradient that has a different size than that of the original Tensor, "
            "this is not supported. Tensor is of size [] while the given forward gradient is of size [1, 1].",
        )
    rhs_arg_has_complex_number = sample.args and (
        (
            isinstance(sample.args[-1], list)
            and any(isinstance(a, complex) for a in sample.args[-1])
        )
        or (isinstance(sample.args[-1], complex))
    )
    if rhs_arg_has_complex_number and dtype == torch.float64:
        if op.name in (
            "_foreach_clamp_max",
            "_foreach_clamp_min",
            "_foreach_maximum",
            "_foreach_minimum",
        ):
            return False, "clamp is not supported for complex types"
        if op.name == "_foreach_lerp" and is_inplace:
            return False, "value cannot be converted to type double without overflow"
        if not is_inplace:
            return False, ""
        else:
            if op.name == "_foreach_pow":
                return False, "Found dtype Double but expected ComplexDouble"
            if op.name in (
                "_foreach_add",
                "_foreach_sub",
                "_foreach_mul",
                "_foreach_div",
            ):
                return (
                    False,
                    "result type ComplexDouble can't be cast to the desired output type Double",
                )
    return True, ""


instantiate_device_type_tests(TestForeach, globals())


if __name__ == "__main__":
    run_tests()