1# mypy: ignore-errors 2 3import os 4 5import torch 6from torch.testing import make_tensor # noqa: F401 7from torch.testing._internal.opinfo.core import ( # noqa: F401 8 BinaryUfuncInfo, 9 ErrorInput, 10 generate_elementwise_binary_tensors, 11 ReductionOpInfo, 12 sample_inputs_reduction, 13 SampleInput, 14) 15 16 17def _check_validate(op_info, sample): 18 def _check_fail(sample): 19 try: 20 op_info( 21 sample.sample_input.input, 22 *sample.sample_input.args, 23 **sample.sample_input.kwargs, 24 ) 25 except sample.error_type: 26 pass 27 except Exception as msg: 28 raise AssertionError( # noqa: B904 29 f"{op_info.name} on {sample.sample_input=} expected exception " 30 f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}" 31 ) 32 else: 33 raise AssertionError( 34 f"{op_info.name} on {sample.sample_input=} expected exception " 35 f"{sample.error_type}: {sample.error_regex}, got none." 36 ) 37 38 def _check_success(sample): 39 try: 40 op_info(sample.input, *sample.args, **sample.kwargs) 41 except Exception as msg: 42 raise AssertionError( # noqa: B904 43 f"{op_info.name} on {sample=} expected to succeed " 44 f", got {type(msg).__name__}: {msg}" 45 ) 46 47 if isinstance(sample, ErrorInput): 48 _check_fail(sample) 49 else: 50 _check_success(sample) 51 52 53def _sample_inputs_sparse( 54 sample_inputs, 55 maybe_failing_sample_inputs, 56 validate_sample_input, 57 op_info, 58 *args, 59 **kwargs, 60): 61 check_validate = ( 62 os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1" 63 ) 64 for sample in sample_inputs(op_info, *args, **kwargs): 65 sample = validate_sample_input(op_info, sample, check_validate=check_validate) 66 if isinstance(sample, SampleInput): 67 yield sample 68 # Error inputs are handled in error_inputs_sparse 69 70 for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs): 71 sample = validate_sample_input(op_info, sample, check_validate=check_validate) 72 if isinstance(sample, SampleInput): 73 yield sample 
74 75 76def _error_inputs_sparse( 77 maybe_failing_sample_inputs, validate_sample_input, op_info, *args, **kwargs 78): 79 check_validate = ( 80 os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1" 81 ) 82 for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs): 83 sample = validate_sample_input(op_info, sample, check_validate=check_validate) 84 if isinstance(sample, ErrorInput): 85 yield sample 86 # Sample inputs are handled in sample_inputs_sparse 87 88 89def _apply_requires_grad_to_samples(sample_inputs): 90 """Decorator to _maybe_failing_sample_inputs_... generator functions 91 that clones and sets requires_grad argument to tensors in sample 92 input arguments. This is needed when the generated samples share 93 tensor instances. 94 """ 95 96 def wrapper(op_info, device, dtype, requires_grad, layout, **kwargs): 97 def apply_requires_grad(x): 98 if ( 99 not isinstance(x, torch.Tensor) 100 or x.requires_grad 101 or not requires_grad 102 or not (x.is_floating_point() or x.is_complex()) 103 ): 104 return x 105 return x.detach().clone().requires_grad_(requires_grad) 106 107 if requires_grad: 108 for sample_input in sample_inputs( 109 op_info, device, dtype, requires_grad, layout, **kwargs 110 ): 111 yield sample_input.transform(apply_requires_grad) 112 else: 113 yield from sample_inputs( 114 op_info, device, dtype, requires_grad, layout, **kwargs 115 ) 116 117 return wrapper 118 119 120def sample_inputs_sparse_reduction( 121 op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs 122): 123 """Sample inputs for reduction operations on sparse tensors.""" 124 layout_name = str(layout).split(".", 1)[-1].rsplit("_coo", 1)[0] 125 op_supports_layout = getattr(op_info, "supports_" + layout_name) 126 if not op_supports_layout: 127 return 128 129 for sample_input in sample_inputs_reduction( 130 op_info, device, dtype, requires_grad, **kwargs 131 ): 132 if sample_input.input.ndim == 0: 133 # scalar sparse tensors are not supported 
134 continue 135 136 if layout in { 137 torch.sparse_csr, 138 torch.sparse_csc, 139 torch.sparse_bsr, 140 torch.sparse_bsc, 141 }: 142 if sample_input.input.ndim < 2: 143 # conversion to sparse compressed tensors requires at 144 # least 2 dimensional tensors 145 continue 146 if sample_input.input.ndim > 2 and (sample_input.input == 0).any(): 147 # Skip batched sparse compressed samples that contain 148 # explicit zeros because to_sparse(layout=..) will 149 # fail, see gh-98495. 150 # TODO: remove this if-block after gh-98495 is fixed. 151 continue 152 153 if layout in {torch.sparse_bsr, torch.sparse_bsc} and blocksize is None: 154 blocksize = (1, 1) 155 156 yield SampleInput( 157 sample_input.input.detach() 158 .to_sparse(layout=layout, blocksize=blocksize) 159 .requires_grad_(requires_grad), 160 args=sample_input.args, 161 kwargs=sample_input.kwargs, 162 ) 163 164 if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex): 165 # uncoalesced samples 166 inp = sample_input.input.detach().to_sparse(layout=layout) 167 inp = torch.sparse_coo_tensor( 168 inp.indices().repeat(1, 2), 169 inp.values().repeat(2), 170 inp.shape, 171 dtype=inp.dtype, 172 device=inp.device, 173 ) 174 assert not inp.is_coalesced() 175 yield SampleInput( 176 inp.requires_grad_(requires_grad), 177 args=sample_input.args, 178 kwargs=sample_input.kwargs, 179 ) 180 181 if sample_input.input.ndim > 2: 182 # hybrid samples 183 yield SampleInput( 184 sample_input.input.detach() 185 .to_sparse( 186 layout=layout, 187 blocksize=blocksize, 188 dense_dim=sample_input.input.ndim - 2, 189 ) 190 .requires_grad_(requires_grad), 191 args=sample_input.args, 192 kwargs=sample_input.kwargs, 193 ) 194 195 196def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=False): 197 """Return the specified sample when it is valid and supported by the 198 operation. Otherwise, return the sample as ErrorInput instance. 
199 200 When check_validate is True, the result is validated against 201 calling the op on the sample. 202 """ 203 UNSPECIFIED = object() 204 if op_info.name == "sum": 205 sample = _validate_sample_input_sparse_reduction_sum(sample) 206 207 if op_info.name in {"masked.sum"}: 208 mask = sample.kwargs.get("mask", UNSPECIFIED) 209 if ( 210 mask not in {None, UNSPECIFIED} 211 and mask.ndim > 2 212 and mask.layout is torch.strided 213 and (mask == 0).any() 214 ): 215 # TODO: remove this if-block after gh-98495 is fixed. 216 sample = ErrorInput( 217 sample, 218 error_regex="Expect the same number of specified elements per batch.", 219 ) 220 elif not sample.kwargs.get("keepdim"): 221 sample = ErrorInput( 222 sample, 223 error_type=(AssertionError, RuntimeError), 224 error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported", 225 ) 226 elif mask is UNSPECIFIED: 227 sample = ErrorInput( 228 sample, 229 error_type=ValueError, 230 error_regex="masked (.*) expects explicit mask for sparse_csr tensor input", 231 ) 232 elif sample.input.ndim > 2: 233 sample = ErrorInput( 234 sample, 235 error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.", 236 ) 237 238 if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}: 239 t_inp = sample.input 240 batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim() 241 mask = sample.kwargs.get("mask") 242 if ( 243 mask is not None 244 and mask.ndim > 2 245 and mask.layout is torch.strided 246 and (mask == 0).any() 247 ): 248 # TODO: remove this if-block after gh-98495 is fixed. 
249 sample = ErrorInput( 250 sample, 251 error_regex="Expect the same number of specified elements per batch.", 252 ) 253 elif mask is None: 254 sample = ErrorInput( 255 sample, 256 error_type=ValueError, 257 error_regex="masked (.*) expects explicit mask for sparse_csr tensor input", 258 ) 259 elif ( 260 mask.layout is sample.input.layout 261 and mask.ndim > 2 262 and op_info.name == "masked.mean" 263 ): 264 sample = ErrorInput( 265 sample, 266 error_type=TypeError, 267 error_regex=( 268 "where[(][)] received an invalid combination of arguments" 269 " - got [(]Tensor, Tensor, NoneType[)]" 270 ), 271 ) 272 elif not sample.kwargs.get("keepdim"): 273 sample = ErrorInput( 274 sample, 275 error_type=(AssertionError, RuntimeError), 276 error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported", 277 ) 278 elif ( 279 sample.input.ndim > 2 280 and (sample.kwargs.get("dim") not in {0, 1}) 281 and mask.ndim > 2 282 and mask.layout is not torch.strided 283 ): 284 if sample.kwargs.get("dim") == (0, -1): 285 sample = ErrorInput( 286 sample, 287 error_regex="tensor dimensionality must be sum of batch, base, and dense dimensionalities", 288 ) 289 elif op_info.name == "masked.prod": 290 sample = ErrorInput( 291 sample, 292 error_regex="input_dim == 2 INTERNAL ASSERT FAILED at", 293 ) 294 else: 295 sample = ErrorInput( 296 sample, 297 error_type=AssertionError, 298 error_regex="Sparse CSR tensors are 2D and only support reduction along dim 0 or 1.", 299 ) 300 elif sample.input.ndim > 2: 301 sample = ErrorInput( 302 sample, 303 error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.", 304 ) 305 elif ( 306 mask.layout is t_inp.layout 307 and mask._nnz() != t_inp._nnz() 308 and t_inp.dense_dim() > 0 309 ): 310 sample = ErrorInput( 311 sample, 312 error_regex="Index tensor must have the same number of dimensions as src tensor", 313 ) 314 315 if check_validate: 316 _check_validate(op_info, sample) 317 318 return sample 319 320 
def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False):
    """Classify a ``sum`` sample on a sparse input.

    Returns ``sample`` unchanged when no known-unsupported case applies,
    otherwise an ``ErrorInput`` wrapping it with the expected message.
    ``check_validate`` is accepted for signature parity with the other
    validators and is not used here.
    """
    # NOTE: When fixing a failing sample case, remove the
    #       corresponding if-block
    t_inp, t_kwargs = sample.input, sample.kwargs
    dim = t_kwargs.get("dim")
    keepdim = t_kwargs.get("keepdim")
    layout = t_inp.layout
    if isinstance(dim, (int, list, tuple)) and layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
            return ErrorInput(
                sample,
                error_regex=(
                    "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout"
                ),
            )
        # From here on the layout can only be torch.sparse_csr.
        if not keepdim:
            return ErrorInput(
                sample,
                error_regex=(
                    "reduction operations on CSR tensors with keepdim=False is unsupported"
                ),
            )
        if t_inp.dim() != 2:
            return ErrorInput(
                sample,
                error_regex=("input_dim == 2 INTERNAL ASSERT"),
            )
        if t_inp.dtype == torch.bool:
            return ErrorInput(
                sample,
                error_regex=("_sparse_csr_sum_cpu not implemented for 'Bool'"),
            )
        if t_inp.dtype == torch.complex32:
            return ErrorInput(
                sample,
                error_regex=(
                    "_sparse_csr_sum_cuda not implemented for 'ComplexHalf'"
                ),
            )
    return sample


def _maybe_failing_sample_inputs_sparse_reduction_sum(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Generator of samples that are known to fail or that were failing in past."""
    # NOTE: When fixing a failing case, remove the Exception comment
    #       but keep the `yield sample` statement.
    # NOTE(review): `device` is not forwarded to torch.tensor below, so the
    # samples are created on the default device - confirm this is intended.
    matrix = [[0, 1], [2, 3]]
    hybrid = [[[0, 1]], [[2, 3]]]

    def make_sample(data, op_kwargs, **to_sparse_kwargs):
        # Build one sparse sample lazily; each call creates a fresh tensor.
        return SampleInput(
            torch.tensor(data, dtype=dtype)
            .to_sparse(layout=layout, **to_sparse_kwargs)
            .requires_grad_(requires_grad),
            kwargs=op_kwargs,
        )

    if layout in [
        torch.sparse_csr,
        torch.sparse_csc,
    ]:
        # NotImplementedError: Could not run 'aten::sum.IntList_out' with arguments from the 'SparseCsrCPU' backend.
        yield make_sample(matrix, dict(dim=0, keepdim=True))
        yield make_sample(hybrid, dict(dim=0), dense_dim=1)
        yield make_sample(matrix, dict(dim=(0,)))
        yield make_sample(matrix, dict(dim=(0,), keepdim=True))
        yield make_sample(hybrid, dict(dim=(0,)), dense_dim=1)

        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
        yield make_sample(matrix, dict(dim=0))

    if layout in [
        torch.sparse_bsr,
        torch.sparse_bsc,
    ]:
        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
        yield make_sample(matrix, dict(dim=0, keepdim=True), blocksize=(2, 2))
        yield make_sample(hybrid, dict(dim=0), dense_dim=1, blocksize=(1, 1))
        yield make_sample(matrix, dict(dim=(0,)), blocksize=(1, 1))
        yield make_sample(matrix, dict(dim=(0,), keepdim=True), blocksize=(1, 1))
        yield make_sample(hybrid, dict(dim=(0,)), blocksize=(1, 1), dense_dim=1)

        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
        yield make_sample(matrix, dict(dim=0), blocksize=(1, 1))


def sample_inputs_sparse_reduction_sum(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for sum on sparse tensors."""
    yield from _sample_inputs_sparse(
        sample_inputs_sparse_reduction,
        _maybe_failing_sample_inputs_sparse_reduction_sum,
        _validate_sample_input_sparse_reduction,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def error_inputs_sparse_reduction_sum(op_info, device, layout, **kwargs):
    """Error inputs for sum on sparse tensors."""
    dtype = torch.float64
    requires_grad = False
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_reduction_sum,
        _validate_sample_input_sparse_reduction,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def sample_inputs_sparse_elementwise_binary_operation(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for elementwise binary operations on sparse tensors.

    The samples include regular, zero-sized, batched, and hybrid
    sparse tensors as well as rhs scalars. All tensors are full tensors.
    """

    def _to_sparse(tensor, **to_sparse_kwargs):
        return (
            tensor.detach().to_sparse(**to_sparse_kwargs).requires_grad_(requires_grad)
        )

    is_compressed = layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }
    is_block = layout in {torch.sparse_bsr, torch.sparse_bsc}

    for sample_input in generate_elementwise_binary_tensors(
        op_info,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=True,
        **kwargs,
    ):
        lhs, rhs = sample_input.input, sample_input.args[0]
        min_dense_dim = 0
        max_dense_dim = lhs.ndim - 1
        if is_compressed:
            if lhs.ndim < 2:
                # sparse compressed tensors sparse_dim must be 2
                continue
            max_dense_dim = lhs.ndim - 2

        for dense_dim in range(min_dense_dim, max_dense_dim + 1):
            if is_block:
                blocksizes = [(1, 1)]
                if lhs.numel() > 0:
                    # Also try a single block covering both sparse dimensions.
                    blocksizes.append(
                        (
                            lhs.shape[lhs.ndim - 2 - dense_dim],
                            lhs.shape[lhs.ndim - 1 - dense_dim],
                        )
                    )
            else:
                blocksizes = [None]
            for blocksize in blocksizes:
                to_sparse_kwargs = dict(
                    layout=layout, dense_dim=dense_dim, blocksize=blocksize
                )
                lhs_sparse = _to_sparse(lhs, **to_sparse_kwargs)
                rhs_sparse = _to_sparse(rhs, **to_sparse_kwargs)
                # op(sparse, sparse)
                yield SampleInput(
                    lhs_sparse,
                    args=(rhs_sparse, *sample_input.args[1:]),
                    kwargs=sample_input.kwargs,
                )
                # op(sparse, scalar)
                yield SampleInput(
                    lhs_sparse,
                    args=(
                        make_tensor(
                            (), dtype=dtype, device=device, requires_grad=requires_grad
                        ),
                        *sample_input.args[1:],
                    ),
                    kwargs=sample_input.kwargs,
                )


def _validate_sample_input_elementwise_binary_sparse_mul(sample):
    """Classify a ``mul`` sample with a sparse lhs.

    Returns ``sample`` unchanged when no known-unsupported case applies,
    otherwise an ``ErrorInput`` wrapping it with the expected message.
    The checks are ordered; the first matching case wins.
    """
    # NOTE: When fixing a failing sample case, remove the
    #       corresponding if-block
    t_inp, t_args = sample.input, sample.args
    batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
    layout = t_inp.layout
    dtype = t_inp.dtype
    if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0:
        return ErrorInput(
            sample,
            error_regex=(
                "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input"
                " tensors with sparse_dim[(][)]!=2 is not supported"
            ),
        )
    elif layout is torch.sparse_csc and t_args[0].ndim > 0:
        return ErrorInput(
            sample, error_regex="Expected result Tensor to be of format CSR"
        )
    elif layout is torch.sparse_bsr and t_args[0].ndim > 0:
        return ErrorInput(
            sample,
            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsr",
        )
    elif layout is torch.sparse_bsc and t_args[0].ndim > 0:
        return ErrorInput(
            sample,
            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsc",
        )
    elif (
        layout is torch.sparse_coo
        and dtype is torch.bool
        and t_args[0].ndim > 0
        and t_inp.is_cpu
        and t_inp.numel() > 0
        and t_inp.dense_dim() > 0
    ):
        return ErrorInput(
            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Bool'"
        )
    elif (
        layout in {torch.sparse_coo, torch.sparse_csr}
        and dtype is torch.bool
        and t_inp._nnz() > 0
        and t_args[0].ndim > 0
        and t_inp.is_cpu
        and t_inp.numel() > 0
    ):
        return ErrorInput(
            sample, error_regex="\"mul_out_sparse\" not implemented for 'Bool'"
        )
    elif (
        layout is torch.sparse_csr
        and t_args[0].layout is torch.strided
        and 0 < t_args[0].ndim
        and t_args[0].ndim < t_inp.ndim
    ):
        return ErrorInput(
            sample, error_regex="sparse_mask_sparse_csr expects self to be 2D"
        )
    elif layout is torch.sparse_csr and (
        (t_args[0].layout is torch.strided and 0 < t_args[0].ndim)
        or (t_args[0].layout is layout and t_inp.shape != t_args[0].shape)
    ):
        return ErrorInput(
            sample,
            error_regex=(
                "expects sparse inputs with equal dimensionality, number of sparse dimensions,"
                " and shape of sparse dimensions"
            ),
        )
    elif (
        layout is torch.sparse_csr
        and t_inp.dense_dim() > 0
        and t_inp._nnz() > 0
        and t_inp.is_cpu
        and dtype is torch.float16
        and t_args[0].ndim > 0
    ):
        return ErrorInput(
            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Half'"
        )
    return sample


@_apply_requires_grad_to_samples
def _maybe_failing_sample_inputs_sparse_elementwise_binary_mul(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Generator of samples that are known to fail or that were failing in past."""
    # NOTE: When fixing a failing case, remove the Exception comment
    #       but keep the `yield sample` statement.

    blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
    regular = torch.tensor([[1, 2], [3, 4]], device=device, dtype=dtype).to_sparse(
        layout=layout, dense_dim=0, blocksize=blocksize
    )
    batch = torch.tensor(
        [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], device=device, dtype=dtype
    ).to_sparse(layout=layout, dense_dim=0, blocksize=blocksize)
    hybrid = torch.tensor(
        [[[1], [2]], [[3], [4]]], device=device, dtype=dtype
    ).to_sparse(layout=layout, dense_dim=1, blocksize=blocksize)

    if layout is torch.sparse_csr:
        # RuntimeError: crow_indices is supposed to be a vector, but got 2 dimensional tensor
        yield SampleInput(batch, args=(batch,))
        # RuntimeError: Only tensors with two sparse dimensions can be
        # converted to the SparseCsr layout, got self with 3 sparse
        # dimensions.
        yield SampleInput(
            torch.zeros_like(hybrid).requires_grad_(requires_grad),
            args=(torch.zeros_like(hybrid).requires_grad_(requires_grad),),
        )
        if dtype is torch.complex32:
            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
            yield SampleInput(regular, args=(regular,))
        if dtype is torch.bool and regular.is_cpu:
            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
            yield SampleInput(regular, args=(regular,))
    if layout is torch.sparse_csc:
        # RuntimeError: Expected result Tensor to be of format CSR
        yield SampleInput(regular, args=(regular,))
    if layout is torch.sparse_bsr:
        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
        yield SampleInput(regular, args=(regular,))
    if layout is torch.sparse_bsc:
        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsc
        yield SampleInput(regular, args=(regular,))
    if layout is torch.sparse_coo:
        if dtype is torch.complex32:
            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
            yield SampleInput(regular, args=(regular,))
        if dtype is torch.bool and regular.is_cpu:
            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
            yield SampleInput(regular, args=(regular,))
        if dtype in {torch.bool, torch.float16} and regular.is_cpu:
            # RuntimeError: "addcmul_cpu_out" not implemented for '(Bool|Half)'
            yield SampleInput(hybrid, args=(hybrid,))


def _validate_sample_input_sparse_elementwise_binary_operation(
    op_info, sample, check_validate=False
):
    """Validate a sample of an elementwise binary operation on sparse input.

    Only ``mul`` has op-specific checks. When ``check_validate`` is true,
    the classification is verified by calling the op on the sample.
    """
    if op_info.name == "mul":
        sample = _validate_sample_input_elementwise_binary_sparse_mul(sample)

    if check_validate:
        _check_validate(op_info, sample)
    return sample


def sample_inputs_sparse_mul(op_info, device, dtype, requires_grad, layout, **kwargs):
    """Sample inputs for mul operation on sparse tensors."""
    yield from _sample_inputs_sparse(
        sample_inputs_sparse_elementwise_binary_operation,
        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
        _validate_sample_input_sparse_elementwise_binary_operation,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
    """Error inputs for mul operation on sparse tensors."""
    dtype = torch.float64
    requires_grad = False
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
        _validate_sample_input_sparse_elementwise_binary_operation,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def _like_fns_other_layout(layout):
    """Return the layout that like-function samples request instead of ``layout``.

    CSR and CSC map to each other, as do BSR and BSC; every other layout
    maps to ``torch.strided``.
    """
    if layout is torch.sparse_csr:
        return torch.sparse_csc
    if layout is torch.sparse_csc:
        return torch.sparse_csr
    if layout is torch.sparse_bsr:
        return torch.sparse_bsc
    if layout is torch.sparse_bsc:
        return torch.sparse_bsr
    return torch.strided


def _sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Generate like-function samples from ``TestCase.generate_simple_inputs``.

    For each generated tensor, samples are yielded with no keywords, with
    the tensor's own device/dtype/layout, and with a different dtype,
    device (only when CUDA is available) and layout.
    """
    from torch.testing._internal.common_utils import TestCase

    for tensor in TestCase().generate_simple_inputs(
        layout,
        device=device,
        dtype=dtype,
        enable_batch=True,
        enable_hybrid=True,
        enable_zero_sized=True,
        enable_non_contiguous_indices=False,
        enable_non_contiguous_values=False,
    ):
        yield SampleInput(tensor, args=(), kwargs={})
        yield SampleInput(
            tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
        )

        if dtype is not torch.float64:
            yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))

        if torch.cuda.is_available():
            other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
            yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))

        yield SampleInput(
            tensor, args=(), kwargs=dict(layout=_like_fns_other_layout(layout))
        )

        if layout is not torch.sparse_coo:
            yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))


def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
    """Classify a like-function sample on a sparse input.

    Returns an ``ErrorInput`` for the known-unsupported cases below and
    ``sample`` otherwise. Note that the early ``ErrorInput`` returns skip
    the ``check_validate`` verification.
    """
    if sample.input.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    } and op_info.name not in {"zeros_like"}:
        if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
            return ErrorInput(
                sample,
                error_regex=(
                    "empty_like with different sparse layout is not supported"
                    " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
                ),
            )
    if sample.input.layout is torch.sparse_coo:
        return ErrorInput(
            sample,
            error_regex=(
                "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
            ),
        )
    if check_validate:
        _check_validate(op_info, sample)
    return sample


def _maybe_failing_sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Generator of like-function samples requesting another device or layout.

    Samples are produced only when CUDA is available and ``layout`` is
    not ``torch.sparse_coo``.
    """
    if torch.cuda.is_available() and layout is not torch.sparse_coo:
        other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
        other_layout = _like_fns_other_layout(layout)

        blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None

        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
                layout=layout, blocksize=blocksize
            ),
            kwargs=dict(device=other_device),
        )

        yield SampleInput(
            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
                layout=layout, blocksize=blocksize
            ),
            kwargs=dict(layout=other_layout),
        )


def sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for like-functions on sparse tensors."""
    yield from _sample_inputs_sparse(
        _sample_inputs_sparse_like_fns,
        _maybe_failing_sample_inputs_sparse_like_fns,
        _validate_sample_input_sparse_like_fns,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
    """Error inputs for like-functions on sparse tensors."""
    dtype = torch.float64
    requires_grad = False
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_like_fns,
        _validate_sample_input_sparse_like_fns,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )


def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
    """Validate a sample for ops without a dedicated sparse validator.

    Only ``to_sparse`` has a check: a single integer positional argument
    other than 2 on a sparse compressed input is classified as an
    ``ErrorInput``.
    """
    if op_info.name == "to_sparse":
        if (
            sample.input.layout
            in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
            and len(sample.args) == 1
            and isinstance(sample.args[0], int)
            and sample.args[0] != 2
        ):
            sample = ErrorInput(
                sample,
                error_regex="sparse dim argument must be 2 for sparse_compressed_to_sparse",
            )

    if check_validate:
        _check_validate(op_info, sample)
    return sample


def validate_sample_input_sparse(op_info, sample, check_validate=False):
    """Return the specified sample when it is valid and supported by the
    operation. Otherwise, return the sample as ErrorInput instance.

    When check_validate is True, the result is validated against
    calling the op on the sample.
    """
    if isinstance(op_info, ReductionOpInfo):
        return _validate_sample_input_sparse_reduction(
            op_info, sample, check_validate=check_validate
        )
    elif isinstance(op_info, BinaryUfuncInfo):
        return _validate_sample_input_sparse_elementwise_binary_operation(
            op_info, sample, check_validate=check_validate
        )
    else:
        return _validate_sample_input_sparse_default(
            op_info, sample, check_validate=check_validate
        )