xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/opinfo/definitions/sparse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import os
4
5import torch
6from torch.testing import make_tensor  # noqa: F401
7from torch.testing._internal.opinfo.core import (  # noqa: F401
8    BinaryUfuncInfo,
9    ErrorInput,
10    generate_elementwise_binary_tensors,
11    ReductionOpInfo,
12    sample_inputs_reduction,
13    SampleInput,
14)
15
16
17def _check_validate(op_info, sample):
18    def _check_fail(sample):
19        try:
20            op_info(
21                sample.sample_input.input,
22                *sample.sample_input.args,
23                **sample.sample_input.kwargs,
24            )
25        except sample.error_type:
26            pass
27        except Exception as msg:
28            raise AssertionError(  # noqa: B904
29                f"{op_info.name} on {sample.sample_input=} expected exception "
30                f"{sample.error_type}: {sample.error_regex}, got {type(msg).__name__}: {msg}"
31            )
32        else:
33            raise AssertionError(
34                f"{op_info.name} on {sample.sample_input=} expected exception "
35                f"{sample.error_type}: {sample.error_regex}, got none."
36            )
37
38    def _check_success(sample):
39        try:
40            op_info(sample.input, *sample.args, **sample.kwargs)
41        except Exception as msg:
42            raise AssertionError(  # noqa: B904
43                f"{op_info.name} on {sample=} expected to succeed "
44                f", got {type(msg).__name__}: {msg}"
45            )
46
47    if isinstance(sample, ErrorInput):
48        _check_fail(sample)
49    else:
50        _check_success(sample)
51
52
53def _sample_inputs_sparse(
54    sample_inputs,
55    maybe_failing_sample_inputs,
56    validate_sample_input,
57    op_info,
58    *args,
59    **kwargs,
60):
61    check_validate = (
62        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
63    )
64    for sample in sample_inputs(op_info, *args, **kwargs):
65        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
66        if isinstance(sample, SampleInput):
67            yield sample
68        # Error inputs are handled in error_inputs_sparse
69
70    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
71        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
72        if isinstance(sample, SampleInput):
73            yield sample
74
75
76def _error_inputs_sparse(
77    maybe_failing_sample_inputs, validate_sample_input, op_info, *args, **kwargs
78):
79    check_validate = (
80        os.environ.get("PYTORCH_TEST_CHECK_VALIDATE_SPARSE_SAMPLES", "0") == "1"
81    )
82    for sample in maybe_failing_sample_inputs(op_info, *args, **kwargs):
83        sample = validate_sample_input(op_info, sample, check_validate=check_validate)
84        if isinstance(sample, ErrorInput):
85            yield sample
86        # Sample inputs are handled in sample_inputs_sparse
87
88
89def _apply_requires_grad_to_samples(sample_inputs):
90    """Decorator to _maybe_failing_sample_inputs_... generator functions
91    that clones and sets requires_grad argument to tensors in sample
92    input arguments. This is needed when the generated samples share
93    tensor instances.
94    """
95
96    def wrapper(op_info, device, dtype, requires_grad, layout, **kwargs):
97        def apply_requires_grad(x):
98            if (
99                not isinstance(x, torch.Tensor)
100                or x.requires_grad
101                or not requires_grad
102                or not (x.is_floating_point() or x.is_complex())
103            ):
104                return x
105            return x.detach().clone().requires_grad_(requires_grad)
106
107        if requires_grad:
108            for sample_input in sample_inputs(
109                op_info, device, dtype, requires_grad, layout, **kwargs
110            ):
111                yield sample_input.transform(apply_requires_grad)
112        else:
113            yield from sample_inputs(
114                op_info, device, dtype, requires_grad, layout, **kwargs
115            )
116
117    return wrapper
118
119
120def sample_inputs_sparse_reduction(
121    op_info, device, dtype, requires_grad, layout, blocksize=None, **kwargs
122):
123    """Sample inputs for reduction operations on sparse tensors."""
124    layout_name = str(layout).split(".", 1)[-1].rsplit("_coo", 1)[0]
125    op_supports_layout = getattr(op_info, "supports_" + layout_name)
126    if not op_supports_layout:
127        return
128
129    for sample_input in sample_inputs_reduction(
130        op_info, device, dtype, requires_grad, **kwargs
131    ):
132        if sample_input.input.ndim == 0:
133            # scalar sparse tensors are not supported
134            continue
135
136        if layout in {
137            torch.sparse_csr,
138            torch.sparse_csc,
139            torch.sparse_bsr,
140            torch.sparse_bsc,
141        }:
142            if sample_input.input.ndim < 2:
143                # conversion to sparse compressed tensors requires at
144                # least 2 dimensional tensors
145                continue
146            if sample_input.input.ndim > 2 and (sample_input.input == 0).any():
147                # Skip batched sparse compressed samples that contain
148                # explicit zeros because to_sparse(layout=..) will
149                # fail, see gh-98495.
150                # TODO: remove this if-block after gh-98495 is fixed.
151                continue
152
153        if layout in {torch.sparse_bsr, torch.sparse_bsc} and blocksize is None:
154            blocksize = (1, 1)
155
156        yield SampleInput(
157            sample_input.input.detach()
158            .to_sparse(layout=layout, blocksize=blocksize)
159            .requires_grad_(requires_grad),
160            args=sample_input.args,
161            kwargs=sample_input.kwargs,
162        )
163
164        if layout is torch.sparse_coo and (dtype.is_floating_point or dtype.is_complex):
165            # uncoalesced samples
166            inp = sample_input.input.detach().to_sparse(layout=layout)
167            inp = torch.sparse_coo_tensor(
168                inp.indices().repeat(1, 2),
169                inp.values().repeat(2),
170                inp.shape,
171                dtype=inp.dtype,
172                device=inp.device,
173            )
174            assert not inp.is_coalesced()
175            yield SampleInput(
176                inp.requires_grad_(requires_grad),
177                args=sample_input.args,
178                kwargs=sample_input.kwargs,
179            )
180
181        if sample_input.input.ndim > 2:
182            # hybrid samples
183            yield SampleInput(
184                sample_input.input.detach()
185                .to_sparse(
186                    layout=layout,
187                    blocksize=blocksize,
188                    dense_dim=sample_input.input.ndim - 2,
189                )
190                .requires_grad_(requires_grad),
191                args=sample_input.args,
192                kwargs=sample_input.kwargs,
193            )
194
195
def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=False):
    """Return the specified sample when it is valid and supported by the
    operation. Otherwise, return the sample as ErrorInput instance.

    When check_validate is True, the result is validated against
    calling the op on the sample.

    Dispatch is on ``op_info.name``:

    * ``"sum"`` -- delegated to ``_validate_sample_input_sparse_reduction_sum``;
    * ``"masked.sum"`` and ``"masked.amax"``/``"masked.amin"``/
      ``"masked.mean"``/``"masked.prod"`` -- classified by the if/elif
      chains below. The chains are order-sensitive: the first matching
      condition decides which error is expected;
    * any other name -- the sample is returned unchanged.

    Args:
        op_info: the OpInfo under test. Only ``name`` is read here; the op
            itself is called only when ``check_validate`` is True.
        sample: a SampleInput whose ``kwargs`` may contain ``mask``,
            ``keepdim`` and ``dim``.
        check_validate: when True, ``_check_validate`` runs the op on the
            (possibly wrapped) sample to confirm the classification.

    Returns:
        ``sample`` itself, or an ``ErrorInput`` wrapping it.
    """
    # Sentinel that distinguishes "no mask kwarg given" from an explicit
    # ``mask=None`` in the masked.sum branch below.
    UNSPECIFIED = object()
    if op_info.name == "sum":
        sample = _validate_sample_input_sparse_reduction_sum(sample)

    if op_info.name in {"masked.sum"}:
        mask = sample.kwargs.get("mask", UNSPECIFIED)
        if (
            mask not in {None, UNSPECIFIED}
            and mask.ndim > 2
            and mask.layout is torch.strided
            and (mask == 0).any()
        ):
            # Batched strided mask that contains zeros.
            # TODO: remove this if-block after gh-98495 is fixed.
            sample = ErrorInput(
                sample,
                error_regex="Expect the same number of specified elements per batch.",
            )
        elif not sample.kwargs.get("keepdim"):
            # keepdim missing or falsy.
            sample = ErrorInput(
                sample,
                error_type=(AssertionError, RuntimeError),
                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
            )
        elif mask is UNSPECIFIED:
            # No mask kwarg at all.
            sample = ErrorInput(
                sample,
                error_type=ValueError,
                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
            )
        elif sample.input.ndim > 2:
            sample = ErrorInput(
                sample,
                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
            )

    if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}:
        t_inp = sample.input
        # NOTE(review): batch_dim is computed but never read below; the
        # dense_dim()/sparse_dim() calls are kept unchanged.
        batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
        # Unlike masked.sum, a missing mask and mask=None are treated alike.
        mask = sample.kwargs.get("mask")
        if (
            mask is not None
            and mask.ndim > 2
            and mask.layout is torch.strided
            and (mask == 0).any()
        ):
            # Batched strided mask that contains zeros.
            # TODO: remove this if-block after gh-98495 is fixed.
            sample = ErrorInput(
                sample,
                error_regex="Expect the same number of specified elements per batch.",
            )
        elif mask is None:
            sample = ErrorInput(
                sample,
                error_type=ValueError,
                error_regex="masked (.*) expects explicit mask for sparse_csr tensor input",
            )
        elif (
            mask.layout is sample.input.layout
            and mask.ndim > 2
            and op_info.name == "masked.mean"
        ):
            # Batched mask with the same (sparse) layout as the input,
            # masked.mean only.
            sample = ErrorInput(
                sample,
                error_type=TypeError,
                error_regex=(
                    "where[(][)] received an invalid combination of arguments"
                    " - got [(]Tensor, Tensor, NoneType[)]"
                ),
            )
        elif not sample.kwargs.get("keepdim"):
            # keepdim missing or falsy.
            sample = ErrorInput(
                sample,
                error_type=(AssertionError, RuntimeError),
                error_regex="reduction operations on (CSR|CSC) tensors with keepdim=False is unsupported",
            )
        elif (
            sample.input.ndim > 2
            and (sample.kwargs.get("dim") not in {0, 1})
            and mask.ndim > 2
            and mask.layout is not torch.strided
        ):
            # Batched input and batched non-strided mask, reducing over a
            # dim other than 0 or 1; the expected error depends on dim / op.
            if sample.kwargs.get("dim") == (0, -1):
                sample = ErrorInput(
                    sample,
                    error_regex="tensor dimensionality must be sum of batch, base, and dense dimensionalities",
                )
            elif op_info.name == "masked.prod":
                sample = ErrorInput(
                    sample,
                    error_regex="input_dim == 2 INTERNAL ASSERT FAILED at",
                )
            else:
                sample = ErrorInput(
                    sample,
                    error_type=AssertionError,
                    error_regex="Sparse CSR tensors are 2D and only support reduction along dim 0 or 1.",
                )
        elif sample.input.ndim > 2:
            sample = ErrorInput(
                sample,
                error_regex="crow_indices is supposed to be a vector, but got 3 dimensional tensor.",
            )
        elif (
            mask.layout is t_inp.layout
            and mask._nnz() != t_inp._nnz()
            and t_inp.dense_dim() > 0
        ):
            # Same-layout mask whose nnz differs from a hybrid input's nnz.
            sample = ErrorInput(
                sample,
                error_regex="Index tensor must have the same number of dimensions as src tensor",
            )

    if check_validate:
        _check_validate(op_info, sample)

    return sample
319
320
321def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False):
322    # NOTE: When fixing a failing sample case, remove the
323    #       corresponding if-block
324    t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
325    dim = t_kwargs.get("dim")
326    keepdim = t_kwargs.get("keepdim")
327    layout = t_inp.layout
328    if isinstance(dim, (int, list, tuple)):
329        if layout in {
330            torch.sparse_csr,
331            torch.sparse_csc,
332            torch.sparse_bsr,
333            torch.sparse_bsc,
334        }:
335            if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
336                return ErrorInput(
337                    sample,
338                    error_regex=(
339                        "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout"
340                    ),
341                )
342            if layout in {torch.sparse_csr, torch.sparse_csc} and not keepdim:
343                return ErrorInput(
344                    sample,
345                    error_regex=(
346                        "reduction operations on CSR tensors with keepdim=False is unsupported"
347                    ),
348                )
349            if t_inp.dim() != 2:
350                return ErrorInput(
351                    sample,
352                    error_regex=("input_dim == 2 INTERNAL ASSERT"),
353                )
354            if layout == torch.sparse_csr:
355                if t_inp.dtype == torch.bool:
356                    return ErrorInput(
357                        sample,
358                        error_regex=("_sparse_csr_sum_cpu not implemented for 'Bool'"),
359                    )
360                if t_inp.dtype == torch.complex32:
361                    return ErrorInput(
362                        sample,
363                        error_regex=(
364                            "_sparse_csr_sum_cuda not implemented for 'ComplexHalf'"
365                        ),
366                    )
367    return sample
368
369
370def _maybe_failing_sample_inputs_sparse_reduction_sum(
371    op_info, device, dtype, requires_grad, layout, **kwargs
372):
373    """Generator of samples that are known to fail or that were failing in past."""
374    # NOTE: When fixing a failing case, remove the Exception comment
375    #       but keep the `yield sample` statement.
376    if layout in [
377        torch.sparse_csr,
378        torch.sparse_csc,
379    ]:
380        # NotImplementedError: Could not run 'aten::sum.IntList_out' with arguments from the 'SparseCsrCPU' backend.
381        yield SampleInput(
382            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
383            .to_sparse(layout=layout)
384            .requires_grad_(requires_grad),
385            kwargs=dict(dim=0, keepdim=True),
386        )
387        yield SampleInput(
388            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
389            .to_sparse(layout=layout, dense_dim=1)
390            .requires_grad_(requires_grad),
391            kwargs=dict(dim=0),
392        )
393        yield SampleInput(
394            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
395            .to_sparse(layout=layout)
396            .requires_grad_(requires_grad),
397            kwargs=dict(dim=(0,)),
398        )
399        yield SampleInput(
400            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
401            .to_sparse(layout=layout)
402            .requires_grad_(requires_grad),
403            kwargs=dict(dim=(0,), keepdim=True),
404        )
405        yield SampleInput(
406            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
407            .to_sparse(layout=layout, dense_dim=1)
408            .requires_grad_(requires_grad),
409            kwargs=dict(dim=(0,)),
410        )
411
412        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
413        yield SampleInput(
414            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
415            .to_sparse(layout=layout)
416            .requires_grad_(requires_grad),
417            kwargs=dict(dim=0),
418        )
419
420    if layout in [
421        torch.sparse_bsr,
422        torch.sparse_bsc,
423    ]:
424        # RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
425        yield SampleInput(
426            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
427            .to_sparse(layout=layout, blocksize=(2, 2))
428            .requires_grad_(requires_grad),
429            kwargs=dict(dim=0, keepdim=True),
430        )
431        yield SampleInput(
432            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
433            .to_sparse(layout=layout, dense_dim=1, blocksize=(1, 1))
434            .requires_grad_(requires_grad),
435            kwargs=dict(dim=0),
436        )
437        yield SampleInput(
438            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
439            .to_sparse(layout=layout, blocksize=(1, 1))
440            .requires_grad_(requires_grad),
441            kwargs=dict(dim=(0,)),
442        )
443        yield SampleInput(
444            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
445            .to_sparse(layout=layout, blocksize=(1, 1))
446            .requires_grad_(requires_grad),
447            kwargs=dict(dim=(0,), keepdim=True),
448        )
449        yield SampleInput(
450            torch.tensor([[[0, 1]], [[2, 3]]], dtype=dtype)
451            .to_sparse(layout=layout, blocksize=(1, 1), dense_dim=1)
452            .requires_grad_(requires_grad),
453            kwargs=dict(dim=(0,)),
454        )
455
456        # RuntimeError: torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size [2]
457        yield SampleInput(
458            torch.tensor([[0, 1], [2, 3]], dtype=dtype)
459            .to_sparse(layout=layout, blocksize=(1, 1))
460            .requires_grad_(requires_grad),
461            kwargs=dict(dim=0),
462        )
463
464
def sample_inputs_sparse_reduction_sum(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for sum on sparse tensors.

    Combines the generic sparse reduction samples with the maybe-failing
    ``sum`` samples and keeps only those that the reduction validator
    classifies as regular inputs.
    """
    generators = (
        sample_inputs_sparse_reduction,
        _maybe_failing_sample_inputs_sparse_reduction_sum,
        _validate_sample_input_sparse_reduction,
    )
    yield from _sample_inputs_sparse(
        *generators,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )
480
481
def error_inputs_sparse_reduction_sum(op_info, device, layout, **kwargs):
    """Error inputs for sum on sparse tensors.

    The maybe-failing ``sum`` samples are always generated as float64
    tensors without autograd, and only the ones that the reduction
    validator classifies as errors are yielded.
    """
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_reduction_sum,
        _validate_sample_input_sparse_reduction,
        op_info,
        device,
        torch.float64,  # dtype
        False,  # requires_grad
        layout,
        **kwargs,
    )
496
497
def sample_inputs_sparse_elementwise_binary_operation(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for elementwise binary operations on sparse tensors.

    Each dense ``(lhs, rhs)`` pair from ``generate_elementwise_binary_tensors``
    (zeros excluded) is converted to ``layout``. This happens once for every
    admissible number of dense dimensions and, for BSR/BSC, once for every
    block size. Each conversion produces two samples: ``op(sparse, sparse)``
    and ``op(sparse, 0-d tensor)``. The inputs cover regular, zero-sized,
    batched and hybrid sparse tensors; all tensors are full tensors.
    """
    compressed_layouts = {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }

    def sparsify(tensor, **conversion):
        return tensor.detach().to_sparse(**conversion).requires_grad_(requires_grad)

    dense_pairs = generate_elementwise_binary_tensors(
        op_info,
        device=device,
        dtype=dtype,
        requires_grad=requires_grad,
        exclude_zero=True,
        **kwargs,
    )
    for pair in dense_pairs:
        lhs = pair.input
        rhs = pair.args[0]
        trailing_args = pair.args[1:]
        if layout in compressed_layouts:
            if lhs.ndim < 2:
                # sparse compressed tensors sparse_dim must be 2
                continue
            sparse_ndim = 2
        else:
            sparse_ndim = 1

        # dense_dim takes every value in 0 .. lhs.ndim - sparse_ndim
        for dense_dim in range(lhs.ndim - sparse_ndim + 1):
            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                blocksizes = [(1, 1)]
                if lhs.numel() > 0:
                    # a single block covering the whole sparse part
                    rows = lhs.shape[lhs.ndim - 2 - dense_dim]
                    cols = lhs.shape[lhs.ndim - 1 - dense_dim]
                    blocksizes.append((rows, cols))
            else:
                blocksizes = [None]

            for blocksize in blocksizes:
                conversion = dict(
                    layout=layout, dense_dim=dense_dim, blocksize=blocksize
                )
                lhs_sparse = sparsify(lhs, **conversion)
                rhs_sparse = sparsify(rhs, **conversion)
                # op(sparse, sparse)
                yield SampleInput(
                    lhs_sparse,
                    args=(rhs_sparse, *trailing_args),
                    kwargs=pair.kwargs,
                )
                # op(sparse, scalar)
                scalar = make_tensor(
                    (), dtype=dtype, device=device, requires_grad=requires_grad
                )
                yield SampleInput(
                    lhs_sparse,
                    args=(scalar, *trailing_args),
                    kwargs=pair.kwargs,
                )
567
568
569def _validate_sample_input_elementwise_binary_sparse_mul(sample):
570    # NOTE: When fixing a failing sample case, remove the
571    #       corresponding if-block
572    t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
573    batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
574    layout = t_inp.layout
575    dtype = t_inp.dtype
576    if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0:
577        return ErrorInput(
578            sample,
579            error_regex=(
580                "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input"
581                " tensors with sparse_dim[(][)]!=2 is not supported"
582            ),
583        )
584    elif layout is torch.sparse_csc and t_args[0].ndim > 0:
585        return ErrorInput(
586            sample, error_regex="Expected result Tensor to be of format CSR"
587        )
588    elif layout is torch.sparse_bsr and t_args[0].ndim > 0:
589        return ErrorInput(
590            sample,
591            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsr",
592        )
593    elif layout is torch.sparse_bsc and t_args[0].ndim > 0:
594        return ErrorInput(
595            sample,
596            error_regex="empty_sparse_compressed expected sparse compressed [(]non-block[)] tensor layout but got SparseBsc",
597        )
598    elif (
599        layout is torch.sparse_coo
600        and dtype is torch.bool
601        and t_args[0].ndim > 0
602        and t_inp.is_cpu
603        and t_inp.numel() > 0
604        and t_inp.dense_dim() > 0
605    ):
606        return ErrorInput(
607            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Bool'"
608        )
609    elif (
610        layout in {torch.sparse_coo, torch.sparse_csr}
611        and dtype is torch.bool
612        and t_inp._nnz() > 0
613        and t_args[0].ndim > 0
614        and t_inp.is_cpu
615        and t_inp.numel() > 0
616    ):
617        return ErrorInput(
618            sample, error_regex="\"mul_out_sparse\" not implemented for 'Bool'"
619        )
620    elif (
621        layout is torch.sparse_csr
622        and t_args[0].layout is torch.strided
623        and 0 < t_args[0].ndim
624        and t_args[0].ndim < t_inp.ndim
625    ):
626        return ErrorInput(
627            sample, error_regex="sparse_mask_sparse_csr expects self to be 2D"
628        )
629    elif layout is torch.sparse_csr and (
630        (t_args[0].layout is torch.strided and 0 < t_args[0].ndim)
631        or (t_args[0].layout is layout and t_inp.shape != t_args[0].shape)
632    ):
633        return ErrorInput(
634            sample,
635            error_regex=(
636                "expects sparse inputs with equal dimensionality, number of sparse dimensions,"
637                " and shape of sparse dimensions"
638            ),
639        )
640    elif (
641        layout is torch.sparse_csr
642        and t_inp.dense_dim() > 0
643        and t_inp._nnz() > 0
644        and t_inp.is_cpu
645        and dtype is torch.float16
646        and t_args[0].ndim > 0
647    ):
648        return ErrorInput(
649            sample, error_regex="\"addcmul_cpu_out\" not implemented for 'Half'"
650        )
651    return sample
652
653
@_apply_requires_grad_to_samples
def _maybe_failing_sample_inputs_sparse_elementwise_binary_mul(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Generator of samples that are known to fail or that were failing in past.

    Three fixtures are built on ``device`` with ``dtype`` in ``layout``:
    a regular 2x2 tensor, a batched 2x2x2 tensor and a hybrid 2x2x1 tensor
    (one dense dimension). BSR/BSC use a ``(1, 1)`` block size. Which
    ``mul`` pairs are yielded depends on the layout, the dtype and whether
    the tensors are on the CPU.
    """
    # NOTE: When fixing a failing case, remove the Exception comment
    #       but keep the `yield sample` statement.

    blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None

    def _sparse(data, dense_dim):
        dense = torch.tensor(data, device=device, dtype=dtype)
        return dense.to_sparse(layout=layout, dense_dim=dense_dim, blocksize=blocksize)

    regular = _sparse([[1, 2], [3, 4]], 0)
    batch = _sparse([[[1, 2], [3, 4]], [[4, 5], [6, 7]]], 0)
    hybrid = _sparse([[[1], [2]], [[3], [4]]], 1)

    if layout is torch.sparse_csr:
        # RuntimeError: crow_indices is supposed to be a vector, but got 2 dimensional tensor
        yield SampleInput(batch, args=(batch,))
        # RuntimeError: Only tensors with two sparse dimensions can be
        # converted to the SparseCsr layout, got self with 3 sparse
        # dimensions.
        lhs_zeros = torch.zeros_like(hybrid).requires_grad_(requires_grad)
        rhs_zeros = torch.zeros_like(hybrid).requires_grad_(requires_grad)
        yield SampleInput(lhs_zeros, args=(rhs_zeros,))
        if dtype is torch.complex32:
            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
            yield SampleInput(regular, args=(regular,))
        if dtype is torch.bool and regular.is_cpu:
            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
            yield SampleInput(regular, args=(regular,))
    elif layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}:
        # CSC: RuntimeError: Expected result Tensor to be of format CSR
        # BSR: RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsr
        # BSC: RuntimeError: empty_sparse_compressed expected sparse compressed (non-block) tensor layout but got SparseBsc
        yield SampleInput(regular, args=(regular,))
    elif layout is torch.sparse_coo:
        if dtype is torch.complex32:
            # RuntimeError: "mul_out_sparse" not implemented for 'ComplexHalf'
            yield SampleInput(regular, args=(regular,))
        if dtype is torch.bool and regular.is_cpu:
            # RuntimeError: "mul_out_sparse" not implemented for 'Bool'
            yield SampleInput(regular, args=(regular,))
        if dtype in {torch.bool, torch.float16} and regular.is_cpu:
            # RuntimeError: "addcmul_cpu_out" not implemented for '(Bool|Half)'
            yield SampleInput(hybrid, args=(hybrid,))
708
709
710def _validate_sample_input_sparse_elementwise_binary_operation(
711    op_info, sample, check_validate=False
712):
713    if op_info.name == "mul":
714        sample = _validate_sample_input_elementwise_binary_sparse_mul(sample)
715
716    if check_validate:
717        _check_validate(op_info, sample)
718    return sample
719
720
def sample_inputs_sparse_mul(op_info, device, dtype, requires_grad, layout, **kwargs):
    """Sample inputs for mul operation on sparse tensors.

    Combines the generic elementwise-binary sparse samples with the
    maybe-failing ``mul`` samples. Only the ones the binary validator
    classifies as regular inputs are yielded.
    """
    generators = (
        sample_inputs_sparse_elementwise_binary_operation,
        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
        _validate_sample_input_sparse_elementwise_binary_operation,
    )
    yield from _sample_inputs_sparse(
        *generators,
        op_info,
        device,
        dtype,
        requires_grad,
        layout,
        **kwargs,
    )
734
735
def error_inputs_sparse_mul(op_info, device, layout, **kwargs):
    """Error inputs for mul operation on sparse tensors.

    The maybe-failing ``mul`` samples are always generated as float64
    tensors without autograd. Only the ones the binary validator
    classifies as errors are yielded.
    """
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_elementwise_binary_mul,
        _validate_sample_input_sparse_elementwise_binary_operation,
        op_info,
        device,
        torch.float64,  # dtype
        False,  # requires_grad
        layout,
        **kwargs,
    )
750
751
752def _sample_inputs_sparse_like_fns(
753    op_info, device, dtype, requires_grad, layout, **kwargs
754):
755    from torch.testing._internal.common_utils import TestCase
756
757    for tensor in TestCase().generate_simple_inputs(
758        layout,
759        device=device,
760        dtype=dtype,
761        enable_batch=True,
762        enable_hybrid=True,
763        enable_zero_sized=True,
764        enable_non_contiguous_indices=False,
765        enable_non_contiguous_values=False,
766    ):
767        yield SampleInput(tensor, args=(), kwargs={})
768        yield SampleInput(
769            tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout)
770        )
771
772        if dtype is not torch.float64:
773            yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64))
774
775        if torch.cuda.is_available():
776            other_device = "cuda" if tensor.device.type == "cpu" else "cpu"
777            yield SampleInput(tensor, args=(), kwargs=dict(device=other_device))
778
779        if layout is torch.sparse_csr:
780            other_layout = torch.sparse_csc
781        elif layout is torch.sparse_csc:
782            other_layout = torch.sparse_csr
783        elif layout is torch.sparse_bsr:
784            other_layout = torch.sparse_bsc
785        elif layout is torch.sparse_bsc:
786            other_layout = torch.sparse_bsr
787        else:
788            other_layout = torch.strided
789        yield SampleInput(tensor, args=(), kwargs=dict(layout=other_layout))
790
791        if layout is not torch.sparse_coo:
792            yield SampleInput(tensor, args=(), kwargs=dict(layout=torch.sparse_coo))
793
794
795def _validate_sample_input_sparse_like_fns(op_info, sample, check_validate=False):
796    if sample.input.layout in {
797        torch.sparse_csr,
798        torch.sparse_csc,
799        torch.sparse_bsr,
800        torch.sparse_bsc,
801    } and op_info.name not in {"zeros_like"}:
802        if sample.kwargs.get("layout", sample.input.layout) != sample.input.layout:
803            return ErrorInput(
804                sample,
805                error_regex=(
806                    "empty_like with different sparse layout is not supported"
807                    " \\(self is Sparse(Csc|Csr|Bsc|Bsr) but you requested Sparse(Csr|Csc|Bsr|Bsc)\\)"
808                ),
809            )
810    if sample.input.layout is torch.sparse_coo:
811        return ErrorInput(
812            sample,
813            error_regex=(
814                "Could not run 'aten::normal_' with arguments from the 'Sparse(CPU|CUDA)' backend."
815            ),
816        )
817    if check_validate:
818        _check_validate(op_info, sample)
819    return sample
820
821
822def _maybe_failing_sample_inputs_sparse_like_fns(
823    op_info, device, dtype, requires_grad, layout, **kwargs
824):
825    if torch.cuda.is_available() and layout is not torch.sparse_coo:
826        other_device = "cuda" if torch.device(device).type == "cpu" else "cpu"
827        if layout is torch.sparse_csr:
828            other_layout = torch.sparse_csc
829        elif layout is torch.sparse_csc:
830            other_layout = torch.sparse_csr
831        elif layout is torch.sparse_bsr:
832            other_layout = torch.sparse_bsc
833        elif layout is torch.sparse_bsc:
834            other_layout = torch.sparse_bsr
835        else:
836            other_layout = torch.strided
837
838        blocksize = (1, 1) if layout in {torch.sparse_bsr, torch.sparse_bsc} else None
839
840        yield SampleInput(
841            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
842                layout=layout, blocksize=blocksize
843            ),
844            kwargs=dict(device=other_device),
845        )
846
847        yield SampleInput(
848            torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device).to_sparse(
849                layout=layout, blocksize=blocksize
850            ),
851            kwargs=dict(layout=other_layout),
852        )
853
854
def sample_inputs_sparse_like_fns(
    op_info, device, dtype, requires_grad, layout, **kwargs
):
    """Sample inputs for like-functions on sparse tensors.

    Delegates to ``_sample_inputs_sparse`` with the like-function sample
    generator, maybe-failing sample generator and validator, forwarding all
    arguments unchanged.
    """
    like_fn_hooks = (
        _sample_inputs_sparse_like_fns,
        _maybe_failing_sample_inputs_sparse_like_fns,
        _validate_sample_input_sparse_like_fns,
    )
    yield from _sample_inputs_sparse(
        *like_fn_hooks, op_info, device, dtype, requires_grad, layout, **kwargs
    )
870
871
def error_inputs_sparse_like_fns(op_info, device, layout, **kwargs):
    """Error inputs for like-functions on sparse tensors.

    Delegates to ``_error_inputs_sparse`` with the like-function
    maybe-failing sample generator and validator. It fixes the dtype to
    float64 and ``requires_grad`` to False.
    """
    error_dtype, needs_grad = torch.float64, False
    yield from _error_inputs_sparse(
        _maybe_failing_sample_inputs_sparse_like_fns,
        _validate_sample_input_sparse_like_fns,
        op_info,
        device,
        error_dtype,
        needs_grad,
        layout,
        **kwargs,
    )
886
887
888def _validate_sample_input_sparse_default(op_info, sample, check_validate=False):
889    if op_info.name == "to_sparse":
890        if (
891            sample.input.layout
892            in {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc}
893            and len(sample.args) == 1
894            and isinstance(sample.args[0], int)
895            and sample.args[0] != 2
896        ):
897            sample = ErrorInput(
898                sample,
899                error_regex="sparse dim argument must be 2 for sparse_compressed_to_sparse",
900            )
901
902    if check_validate:
903        _check_validate(op_info, sample)
904    return sample
905
906
def validate_sample_input_sparse(op_info, sample, check_validate=False):
    """Return the specified sample when it is valid and supported by the
    operation. Otherwise, return the sample as ErrorInput instance.

    The validator is chosen by OpInfo kind: reductions first, then binary
    ufuncs, then the default validator. When check_validate is True, the
    chosen validator also checks its result by calling the op on the sample.
    """
    if isinstance(op_info, ReductionOpInfo):
        validator = _validate_sample_input_sparse_reduction
    elif isinstance(op_info, BinaryUfuncInfo):
        validator = _validate_sample_input_sparse_elementwise_binary_operation
    else:
        validator = _validate_sample_input_sparse_default
    return validator(op_info, sample, check_validate=check_validate)
925        )
926