# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import sym_float, Tensor
from torch._prims_common import corresponding_real_dtype
from torch.masked import _docs
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
from torch.masked.maskedtensor.creation import as_masked_tensor


if TYPE_CHECKING:
    from torch.types import _dtype as DType

    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]


__all__: List[str] = []

# All masked reduction/normalization operations have the same
# signatures. Here we introduce docstring templates that are applied
# to docstrings of reduction/normalization functions via
# _apply_docstring_templates decorator.


def _apply_docstring_templates(func):
    """Decorator that applies docstring templates to function docstring
    and returns the function instance.
    """

    doc_string = getattr(_docs, f"{func.__name__}_docstring", None)
    if doc_string is None:
        warnings.warn(
            f"No documentation string available for {func.__name__}."
            " PyTorch team should run `python tools/update_masked_docs.py`"
            " to generate the missing docstrings."
        )
    else:
        func.__doc__ = doc_string

    # Expose function as public symbol
    __all__.append(func.__name__)

    return func


def _generate_docstring(func):
    """A utility function called from the tools/update_masked_docs.py
    script to update the module torch.masked._docs.py
    """
    docstring_templates = dict(
        reduction_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        reduction_descr="""\
Returns {operation name} of all the elements in the :attr:`input`
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
elements are masked out according to the boolean tensor
:attr:`mask`.""",
        reduction_args="""\
If :attr:`keepdim` is ``True``, the output tensor is of the same size
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
size 1. Otherwise, :attr:`dim` is squeezed (see
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
``len(dim)``) fewer dimension(s).

The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True
then the corresponding element in :attr:`input` tensor will be
included in {operation name} computation, otherwise the element is
ignored.

When all elements of :attr:`input` along the given dimension
:attr:`dim` are ignored (fully masked-out), the corresponding element
of the output tensor will have undefined value: it may or may not
correspond to the identity value of {operation name} operation; the
choice may correspond to the value that leads to the most efficient
storage of :attr:`output` tensor.

The mask of the output tensor can be computed as
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
dtype=torch.bool)``.

The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than that of the :attr:`input` tensor.

Args:
    input (Tensor): the input tensor
    {args_declarations}

Keyword args:
    {kwargs_declarations}""",
        reduction_example="""\
Example::

    >>> input = {example_input}
    >>> input
    {indent_example_input}
    >>> mask = {example_mask}
    >>> mask
    {indent_example_mask}
    >>> {full_function_name}(input, {example_args}, mask=mask)
    {indent_example_output}
""",
        reduction_identity="""\
The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
        reduction_identity_dtype="""\
The identity value of {operation name} operation, which is used to start the
reduction, depends on input dtype. For instance, for float32, uint8,
and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
        normalization_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        normalization_descr="""\
Returns {operation name} of all the slices in the :attr:`input` tensor
along :attr:`dim` while the :attr:`input` elements are masked out
according to the boolean tensor :attr:`mask`.

{definition}""",
        normalization_args="""\
The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True then
the corresponding element in :attr:`input` tensor will be included in
{operation name} computation, otherwise the element is ignored.
The values of masked-out elements of the output tensor are undefined:
they may or may not be set to zero or nan; the choice may correspond to
the value that leads to the most efficient storage of the :attr:`output`
tensor.

The mask of the {operation name} output tensor can be computed as
``torch.broadcast_to(mask, input.shape)``.

The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than that of the :attr:`input` tensor.

Args:
    input (Tensor): the input tensor
    {args_declarations}

Keyword args:
    {kwargs_declarations}""",
        normalization_example="""\
Example::

    >>> input = {example_input}
    >>> input
    {indent_example_input}
    >>> mask = {example_mask}
    >>> mask
    {indent_example_mask}
    >>> {full_function_name}(input, {example_args}, mask=mask)
    {indent_example_output}
""",
    )

    args_and_kwargs = dict(
        # argument name suffixes separated by double underscore will
        # be removed in the final documentation string.
        sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
        cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
        amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        norm=(
            (
                "ord",
                "dim",
            ),
            ("keepdim=False", "dtype=None", "mask=None"),
        ),
        var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
        log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
        softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
        normalize=(
            (
                "ord__required",
                "dim__as_int",
            ),
            ("eps=1e-12", "dtype=None", "mask=None"),
        ),
    )

    argument_declarations = dict(
        dim="""\
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
  Default: None, which is equivalent to ``tuple(range(input.ndim))``.""",
        dim__as_int="""\
dim (int): the dimension along which {operation name} is computed.""",
        ord="""\
ord (int, float, optional): the order of vector norm. Default: 2.
  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        ord__required="""\
ord (int, float): the order of vector norm. Default: 2.
  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        unbiased="""\
unbiased (bool): when True, use Bessel's correction, otherwise, compute
  the uncorrected sample variance.""",
        eps="""\
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
        keepdim="""\
keepdim (bool, optional): whether the output tensor has
  :attr:`dim` retained or not. Default: {default}.""",
        dtype="""\
dtype (:class:`torch.dtype`, optional): the desired data type
  of returned tensor.  If specified, the input tensor is
  cast to :attr:`dtype` before the operation is
  performed. Default: {default}.""",
        mask="""\
mask (:class:`torch.Tensor`, optional): the boolean tensor
  containing the binary mask of validity of input tensor
  elements.
  Default: None, which is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
    )

    definitions = dict(
        softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
defined as ``exp(x[i])/sum(exp(x))``.""",
        log_softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
        softmin="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
defined as ``exp(-x[i])/sum(exp(-x))``.""",
        normalize="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
defined as ``x[i]/max(norm(x, p), eps)``.""",
        cumsum="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``sum(x[:i+1])``.""",
        cumprod="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumprod of i-th element in ``x`` is
defined as ``prod(x[:i+1])``.""",
    )

    reduction_names = dict(
        sum="sum",
        prod="product",
        amax="maximum",
        amin="minimum",
        argmax="argmax",
        argmin="argmin",
        mean="mean",
        median="median",
        norm="norm",
        var="variance",
        std="standard_deviation",
        logsumexp="logsumexp",
    )

    normalization_names = dict(
        softmax="softmax",
        log_softmax="log_softmax",
        softmin="softmin",
        normalize="normalize",
        cumsum="cumulative_sum",
        cumprod="cumulative_prod",
    )

    operation_names = {}
    operation_names.update(reduction_names)
    operation_names.update(normalization_names)

    # Default example data:
    example_dim = 1
    example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
    example_mask = torch.tensor([[True, False, True], [False, False, False]])
    example_args: Tuple[Any, ...]
    if func.__name__ in {"norm", "normalize"}:
        example_args = (2.0, example_dim)
        example_input = example_input.to(dtype=torch.float32)
    elif func.__name__ in {"var", "std"}:
        example_args = (example_dim, False)
    elif func.__name__ == "median":
        example_args = (example_dim,)
        example_input = example_input.to(dtype=torch.float32)
    else:
        example_args = (example_dim,)

    operation_args: Tuple[str, ...]
    operation_kwargs: Tuple[str, ...]
    operation_args, operation_kwargs = args_and_kwargs[func.__name__]
    arg_declarations = [
        "\n    ".join(
            argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines()
        )
        for a in operation_args
    ]
    kwarg_declarations = [
        "\n    ".join(
            argument_declarations.get(
                a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.'
            )
            .format(default=a.split("=", 1)[1])
            .splitlines()
        )
        for a in operation_kwargs
    ]

    if func.__name__ in reduction_names:
        op_kind = "reduction"
        doc_sections = ["signature", "descr", "identity", "args", "example"]
    elif func.__name__ in normalization_names:
        op_kind = "normalization"
        doc_sections = ["signature", "descr", "args", "example"]
        example_input = example_input.to(dtype=torch.float32)
    else:
        assert 0  # add function name to operation names dictionaries
    example_output = func(example_input, *example_args, mask=example_mask)

    template_data = {
        "function_name": func.__name__,
        "full_function_name": func.__module__ + "." + func.__name__,
        "operation name": operation_names[func.__name__],
        "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
        "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
        # one-line representation of a tensor:
        "example_input": " ".join(str(example_input).split()),
        "example_args": ", ".join(map(str, example_args)),
        "example_mask": " ".join(str(example_mask).split()),
        # multi-line representation of a tensor with indent
        "indent_example_input": ("\n    ").join(str(example_input).splitlines()),
        "indent_example_mask": ("\n    ").join(str(example_mask).splitlines()),
        "indent_example_output": ("\n    ").join(str(example_output).splitlines()),
    }

    if func.__name__ in reduction_names:
        template_data.update(
            identity_uint8=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.uint8)
            ),
            identity_int32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.int32)
            ),
            identity_float32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.float32)
            ),
        )
        if func.__name__ == "norm":
            template_data.update(
                identity_ord_ninf=_reduction_identity(
                    func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
                )
            )
    elif func.__name__ in normalization_names:
        template_data.update(definition=definitions[func.__name__])
    else:
        assert 0  # add function name to operation names dictionaries
    template_data.update(
        args_declarations=("\n    ".join(arg_declarations)).format_map(template_data)
    )
    template_data.update(
        kwargs_declarations=("\n    ".join(kwarg_declarations)).format_map(
            template_data
        )
    )

    # Apply function name info to docstring templates:
    templates = {
        k: v.format_map(template_data)
        for k, v in docstring_templates.items()
        if k.startswith(op_kind)
    }
    templates.update(
        (k, v.format_map(template_data) if isinstance(v, str) else v)
        for k, v in template_data.items()
    )
    # Apply docstring templates to the function docstring:
    if func.__doc__ is None:
        doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
    else:
        doc_template = func.__doc__
    return doc_template.format_map(templates)


def _reduction_identity(op_name: str, input: Tensor, *args):
    """Return the identity value of a reduction operation as a scalar
    tensor for the given input, or None if the identity value cannot
    be uniquely defined for the given input.

    The identity value of an operation is defined as the initial value
    of the reduction that satisfies ``op(op_identity, value) == value``
    for any value in the domain of the operation. Put another way,
    including or excluding the identity value in a list of operands
    will not change the reduction result.

    See https://github.com/pytorch/rfcs/pull/27 for more information.

    """
    dtype: DType = input.dtype
    device = input.device
    op_name = op_name.rsplit(".", 1)[-1]  # lstrip module name when present
    if op_name in {"sum", "cumsum"}:
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name in {"prod", "cumprod"}:
        return torch.tensor(1, dtype=dtype, device=device)
    elif op_name in {"amax", "argmax", "logaddexp"}:
        if torch.is_floating_point(input):
            return torch.tensor(-torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
    elif op_name in {"logsumexp"}:
        if torch.is_floating_point(input):
            return torch.tensor(-torch.inf, dtype=dtype, device=device)
        elif torch.is_complex(input):
            return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
    elif op_name in {"amin", "argmin"}:
        if torch.is_floating_point(input):
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        elif torch.is_signed(input) or dtype == torch.uint8:
            return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
    elif op_name == "mean":
        # Strictly speaking, the identity value of the mean operation
        # is the mean of the input. Since the mean value depends on
        # the dim argument and it may be a non-scalar tensor, we
        # consider the identity value of the mean operation ambiguous.
        # Moreover, the mean value of empty input is undefined.
        return None
    elif op_name == "norm":
        ord = args[0] if args else 2
        if ord == float("-inf"):
            assert torch.is_floating_point(input), input.dtype
            return torch.tensor(torch.inf, dtype=dtype, device=device)
        return torch.tensor(0, dtype=dtype, device=device)
    elif op_name == "median":
        # We use NaN for now because the implementation is currently using torch.nanmedian
        # and NaN is the identity for that function since it gets ignored
        dtype = input.dtype if torch.is_floating_point(input) else torch.float
        return torch.tensor(torch.nan, dtype=dtype, device=device)
    elif op_name in {"var", "std"}:
        return None
    raise NotImplementedError(f"identity of {op_name} on {dtype} input")

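# An illustrative sketch of _reduction_identity behavior (assuming a CPU
# float32 tensor; outputs shown as of this revision):
#
#   >>> t = torch.zeros(3, dtype=torch.float32)
#   >>> _reduction_identity("sum", t)
#   tensor(0.)
#   >>> _reduction_identity("amax", t)
#   tensor(-inf)
#   >>> _reduction_identity("mean", t) is None
#   True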

def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
    """Return dim argument as a tuple of sorted dim values."""
    dims: List[int] = []
    if dim == ():
        # Currently, `dim=()` in reduction operations means "reduce
        # over all dimensions", while in the future it will mean "no
        # reduction". See https://github.com/pytorch/pytorch/issues/29137
        # When gh-29137 is resolved, this if-block must be deleted.
        dim = None
    if dim is None:
        return tuple(range(ndim))
    ndim = max(ndim, 1)
    dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
    for d in dim_:
        if d in dims:
            raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
        if d >= ndim or d < -ndim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
            )
        dims.append(d % ndim)
    return tuple(sorted(dims))

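# A minimal sketch of _canonical_dim: dims are wrapped, deduplicated,
# and sorted, and dim=None (or dim=()) selects all dimensions:
#
#   >>> _canonical_dim(None, 2)
#   (0, 1)
#   >>> _canonical_dim(-1, 2)
#   (1,)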

def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
    # Flatten N-D indices to 1-D indices
    flat_indices = indices.new_zeros(indices.size(1))
    for d, sz in enumerate(shape):
        flat_indices.mul_(sz)
        flat_indices.add_(indices[d])
    return flat_indices

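# A small sketch of the row-major flattening performed above: for shape
# (2, 3), the N-D index (i, j) maps to the flat index i * 3 + j:
#
#   >>> idx = torch.tensor([[0, 1], [2, 0]])  # columns are (0, 2) and (1, 0)
#   >>> _sparse_coo_flatten_indices(idx, (2, 3))
#   tensor([2, 3])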

def _any(input: Tensor, dim: tuple, keepdim: bool):
    # Support torch.any with tuple dim argument.
    # Workaround of https://github.com/pytorch/pytorch/issues/56586
    r = input
    for d in reversed(dim):
        r = r.any(dim=d, keepdim=keepdim)
    return r

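# A brief sketch of _any reducing over a tuple of dims one dim at a time:
#
#   >>> m = torch.tensor([[True, False], [False, False]])
#   >>> _any(m, (0,), False)
#   tensor([ True, False])
#   >>> _any(m, (0, 1), False)
#   tensor(True)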

def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is a boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.
    """

    assert input.layout == torch.sparse_coo
    assert mask.layout == input.layout
    assert mask.shape == input.shape
    assert mask.dense_dim() == input.dense_dim()  # TODO: eliminate this restriction

    input = input.coalesce()

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        mask_values = _any(
            mask.values(), tuple(range(1, input.sparse_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()

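# A hedged 1-D illustration of the invariant above (values chosen
# arbitrarily). Position 2 is masked-in but unspecified in the input, so
# it becomes an explicit zero; position 1 is masked-out and stays
# unspecified:
#
#   >>> input = torch.tensor([1., 2., 0.]).to_sparse()
#   >>> mask = torch.tensor([True, False, True]).to_sparse()
#   >>> r = _sparse_coo_where(mask, input, torch.tensor(-1.))
#   >>> r.to_dense()
#   tensor([1., 0., 0.])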

def _sparse_coo_scatter_reduction_helper(
    op,
    mask_input: Tensor,
    dims: Tuple[int, ...],
    keepdim: bool,
    dtype: Optional[DType] = None,
) -> Tensor:
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )

    output_dtype = dtype
    values, indices = mask_input._values(), mask_input._indices()
    input_dims = mask_input.dim()
    num_sparse_dims = mask_input.sparse_dim()
    reduced_sparse_dims = []
    retained_sparse_dims = []
    reduced_dense_dims = []

    # promote dtype if specified
    if values.dtype != output_dtype:
        values = values.to(output_dtype)

    if keepdim:
        output_shape = tuple(
            1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
        )
    else:
        output_shape = tuple(
            si for (i, si) in enumerate(mask_input.shape) if i not in dims
        )

    for d in dims:
        if d >= input_dims:
            continue

        if d < num_sparse_dims:
            reduced_sparse_dims.append(d)
        else:
            reduced_dense_dims.append(d + 1 - num_sparse_dims)

    # Reduce dense dimensions
    if len(reduced_dense_dims) > 0:
        if reduce == "sum":
            new_values = values
            new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
        else:
            # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
            return NotImplemented
    else:
        new_values = values.clone()

    # Reduce sparse dimensions
    if len(reduced_sparse_dims) == num_sparse_dims:
        if reduce in {"amax", "amin"} and new_values.size(0) == 0:
            # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
            # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
            # See https://github.com/pytorch/pytorch/issues/61901
            new_values = _reduction_identity(reduce, new_values)
        else:
            new_values = op(new_values, dim=0)
        if keepdim:
            for _ in range(num_sparse_dims):
                new_values = new_values.unsqueeze(0)
        return new_values.to(dtype=output_dtype).to_sparse()
    else:
        new_indices = indices.clone()
        if keepdim:
            # zero out reduced sparse dimensions if keepdim = True
            # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
            new_indices[reduced_sparse_dims, :] = 0
        else:
            # remove reduced sparse dimensions if keepdim = False
            if len(reduced_sparse_dims) > 0:
                retained_sparse_dims = [
                    i
                    for i in range(num_sparse_dims)
                    if i not in set(reduced_sparse_dims)
                ]
                new_indices = new_indices.index_select(
                    0, torch.tensor(retained_sparse_dims).to(mask_input.device)
                )

    # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
    if new_indices.numel() > 0:
        # lexsort indices and get index tensor for scatter reduction
        new_indices, inverse_indices = torch.unique(
            new_indices, return_inverse=True, dim=1
        )
        out_shape = list(new_values.shape)
        out_shape[0] = new_indices.shape[1]
        for _ in range(new_values.ndim - 1):
            inverse_indices = inverse_indices.unsqueeze(-1)
        scatter_indices = inverse_indices.expand(new_values.shape)
        # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
        if output_dtype in {torch.bfloat16, torch.float16}:
            new_values = new_values.to(torch.float)
            out = new_values.new_empty(out_shape)
            new_values = out.scatter_reduce_(
                0, scatter_indices, new_values, reduce=reduce, include_self=False
            )
            new_values = new_values.to(dtype=output_dtype)
        else:
            out = new_values.new_empty(out_shape)
            new_values = out.scatter_reduce_(
                0, scatter_indices, new_values, reduce=reduce, include_self=False
            )

    return torch.sparse_coo_tensor(
        new_indices,
        new_values,
        output_shape,
        dtype=output_dtype,
        device=mask_input.device,
    )

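# An illustrative call of the helper above (assuming an explicit dtype
# is passed, as the public reductions here do): summing a 2x2 sparse
# matrix over dim 1 folds duplicated row indices via scatter_reduce_:
#
#   >>> s = torch.tensor([[1., 0.], [2., 3.]]).to_sparse()
#   >>> r = _sparse_coo_scatter_reduction_helper(torch.sum, s, (1,), False, torch.float32)
#   >>> r.to_dense()
#   tensor([1., 5.])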

def _sparse_csr_segment_reduction_helper(
    op,
    mask_input: Tensor,
    dims: Tuple[int, ...],
    keepdim: bool,
    dtype: Optional[DType] = None,
) -> Tensor:
    # Currently, sparse CSR tensors are always 2D with no dense
    # dimensions, so keepdim must be True
    # FIXME: when dense dimensions are implemented for CSR tensors
    assert (
        keepdim
    ), "reduction operations on CSR tensors with keepdim=False are unsupported"
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )
    device = mask_input.device
    output_dtype = dtype
    values, crow_indices, col_indices = (
        mask_input.values(),
        mask_input.crow_indices(),
        mask_input.col_indices(),
    )

    # promote dtype if specified
    if values.dtype != output_dtype:
        values = values.to(output_dtype)

    if len(dims) == 0:
        return mask_input
    if len(dims) == 1:
        if dims[0] == 0:
            new_col_indices, scatter_indices = torch.unique(
                col_indices, return_inverse=True
            )
            new_nnz = new_col_indices.shape[0]
            new_crow_indices = torch.tensor([0, new_nnz])
            new_values = values.new_empty(new_col_indices.shape)
            new_values.scatter_reduce_(
                0, scatter_indices, values, reduce, include_self=False
            )
            new_shape = [1, mask_input.size(1)]
        else:
            assert (
                dims[0] == 1
            ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
            # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
            # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
            new_crow_indices = torch.cat(
                (
                    crow_indices.new_zeros(1),
                    torch.cumsum(torch.diff(crow_indices) != 0, 0),
                ),
                0,
            )
            new_nnz = new_crow_indices[-1]
            new_col_indices = col_indices.new_zeros(new_nnz)
            new_values = torch._segment_reduce(values, reduce, offsets=crow_indices)  # type: ignore[attr-defined]
            new_shape = [mask_input.size(0), 1]
    else:
        assert len(dims) == 2
        nnz = min(1, values.numel())
        if nnz == 1:
            op_kwargs = {"keepdim": True, "dtype": output_dtype}
            # amax and amin do not support dtype kwarg
            if reduce in ["amax", "amin"]:
                del op_kwargs["dtype"]
            new_values = op(values, 0, **op_kwargs)
        else:
            new_values = torch.empty(0, dtype=output_dtype)
        new_col_indices = col_indices.new_zeros(nnz)
        new_crow_indices = torch.tensor([0, nnz])
        new_shape = [1, nnz]

    return torch.sparse_csr_tensor(
        new_crow_indices,
        new_col_indices,
        new_values,
        new_shape,
        dtype=output_dtype,
        device=device,
    )


def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse CSR tensors."""
    # TODO: implement sparse CSR specific where operator for efficiency
    return _sparse_coo_where(
        mask.to_sparse_coo(), input.to_sparse_coo(), fill_value
    ).to_sparse_csr()


def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """torch.where with sparse inputs support.

    _where implements the following invariant:

      _where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is a boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.
    """
    if mask.layout == torch.strided:
        return torch.where(mask, input, fill_value)
    elif mask.layout == torch.sparse_coo:
        return _sparse_coo_where(mask, input, fill_value)
    elif mask.layout == torch.sparse_csr:
        return _sparse_csr_where(mask, input, fill_value)
    else:
        raise ValueError(
            f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
        )

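# For strided inputs, _where reduces to plain torch.where; an
# illustrative call:
#
#   >>> _where(torch.tensor([True, False]), torch.tensor([1., 2.]), torch.tensor(0.))
#   tensor([1., 0.])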

def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor:
    """Return canonical input mask.

    A canonical input mask is defined as a boolean mask tensor whose
    shape and layout match the shape and layout of the input.

    The canonical input mask is computed from the :attr:`mask` tensor
    content to meet the following criteria:

    1. The shape of the canonical input mask is the same as the shape
       of :attr:`input` tensor. If the mask tensor has a smaller shape
       than the shape of the :attr:`input`, broadcasting rules will be
       applied. Downcasting of mask is not supported.

    2. The layout of the canonical input mask is the same as the
       layout of the :attr:`input` tensor. If the mask has different
       layout, it will be converted to the expected layout.  In the
       case of sparse COO layout, the canonical input mask will be
       coalesced.

    3. The dtype of the canonical input mask is torch.bool. If the
       mask dtype is not bool, it will be converted to bool using the
       `.to(dtype=torch.bool)` method call.

    4. The elements of the canonical input mask have boolean values
       copied from the content of the :attr:`mask` tensor (after
       possible broadcasting and dtype conversion transforms).  In
       general, the sparsity pattern of the sparse canonical input
       mask need not be the same as the sparsity pattern of the
       sparse :attr:`input` tensor.

    """
    if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
        raise ValueError(
            f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
        )

    mask = kwargs.get("mask")

    # default mask
    if mask is None:
        raise ValueError("_input_mask requires explicit mask")

    # mask shape must match with input shape
    if mask.shape != input.shape:
        if mask.ndim > input.ndim:
            raise IndexError(
                "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
            )
        if mask.layout == torch.strided:
            mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
        elif mask.layout == torch.sparse_coo:
            mask = torch._sparse_broadcast_to(mask, input.shape)
        else:
            assert mask.layout == torch.sparse_csr
            # Broadcasting of CSR tensors is not implemented. Working
            # around by using COO layout.
            mask = torch._sparse_broadcast_to(
                mask.to_sparse(), input.shape
            ).to_sparse_csr()

    # mask layout must match with input layout
    if mask.layout != input.layout:
        if input.layout == torch.strided:
            mask = mask.to_dense()
        elif input.layout == torch.sparse_coo:
            if mask.layout == torch.strided:
                mask = mask.to_sparse(input.sparse_dim())
            else:
                mask = mask.to_sparse()
        else:
            assert input.layout == torch.sparse_csr
            mask = mask.to_sparse_csr()

    # sparse mask must be coalesced
    if mask.layout == torch.sparse_coo:
        mask = mask.coalesce()

    # mask is a boolean tensor
    mask = mask.to(dtype=torch.bool)

    return mask

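# A minimal sketch of mask canonicalization for a strided input: a
# smaller boolean mask is broadcast to the input's shape:
#
#   >>> x = torch.zeros(2, 3)
#   >>> _input_mask(x, mask=torch.tensor([True, False, True]))
#   tensor([[ True, False,  True],
#           [ True, False,  True]])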

def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
    """Return output mask of masked operation applied to given arguments."""
    if callable(op):
        is_reduction = op.__name__ in {
            "sum",
            "prod",
            "amax",
            "amin",
            "argmax",
            "argmin",
            "mean",
            "median",
            "norm",
            "var",
            "std",
            "logsumexp",
        }
        is_normalization = op.__name__ in {
            "softmax",
            "log_softmax",
            "softmin",
            "normalize",
            "cumsum",
            "cumprod",
        }
        if is_reduction:
            if op.__name__ == "norm":
                if args:
                    args = args[1:]  # lstrip ord argument
            dim = args[0] if args else kwargs.get("dim")
            outmask = _input_mask(input, *args, **kwargs)
            keepdim = kwargs.get("keepdim", False)
            dim_ = _canonical_dim(dim, input.ndim)
            return _any(outmask, dim_, bool(keepdim))
        elif is_normalization:
            return _input_mask(input, *args, **kwargs)
        else:
            raise ValueError(
                f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
            )
    else:
        raise ValueError(
            f"_output_mask expected masked operation (got {type(op).__name__} object)"
        )

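# A brief sketch of _output_mask for a reduction: the output element is
# valid wherever at least one input element along dim is masked-in:
#
#   >>> x = torch.zeros(2, 3)
#   >>> m = torch.tensor([[True, False, True], [False, False, False]])
#   >>> _output_mask(torch.sum, x, 1, mask=m)
#   tensor([ True, False])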

def _combine_input_and_mask(
    op, input: Union[MaskedTensor, Tensor], mask, *args
) -> Tensor:
    def helper(input, mask):
        if mask is None:
            return input
        canonical_mask = _input_mask(input, mask=mask)
        if callable(op):
            fill_value = _reduction_identity(op.__name__, input, *args)
            return _where(canonical_mask, input, fill_value)
        else:
            raise ValueError(
                f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
            )

    class Combine(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, mask):
            """Return input with masked-out elements eliminated for the given operations."""
            ctx.save_for_backward(mask)

            if mask is not None:
                ctx.mark_non_differentiable(mask)

            return helper(input, mask)

        @staticmethod
        def backward(ctx, grad_output):
            (mask,) = ctx.saved_tensors
            grad_data = (
                grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
            )
            result = as_masked_tensor(grad_data, mask)
            return result, None

    return (
        Combine.apply(input.get_data(), input.get_mask())  # type: ignore[union-attr]
        if is_masked_tensor(input)
        else helper(input, mask)
    )

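# An illustrative call of the helper above: masked-out elements are
# replaced by the reduction identity of the given op (0 for sum), so a
# subsequent regular reduction ignores them:
#
#   >>> x = torch.tensor([1., 2., 3.])
#   >>> _combine_input_and_mask(sum, x, torch.tensor([True, False, True]))
#   tensor([1., 0., 3.])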

@_apply_docstring_templates
def sum(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )

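# An illustrative use of masked sum (integer inputs smaller than int64
# are promoted to int64 when dtype is not given):
#
#   >>> x = torch.tensor([[1, 2, 3], [4, 5, 6]])
#   >>> m = torch.tensor([[True, False, True], [False, False, False]])
#   >>> sum(x, 1, mask=m)
#   tensor([4, 0])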

@_apply_docstring_templates
def prod(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        result = mask_input
        result = result.to(dtype=dtype)
        for d in reversed(dim_):
            result = result.prod(dim=d, keepdim=bool(keepdim))
        return result
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
        )

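# An illustrative use of masked cumsum: masked-out elements contribute
# the additive identity 0 to the running sum:
#
#   >>> x = torch.tensor([1, 2, 3])
#   >>> cumsum(x, 0, mask=torch.tensor([True, False, True]))
#   tensor([1, 1, 4])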

@_apply_docstring_templates
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def amax(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    mask_input = _combine_input_and_mask(amax, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )

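# An illustrative use of masked amax on a float input: a fully
# masked-out row evaluates to the identity -inf (the undefined-value
# caveat described in the docstring templates above):
#
#   >>> x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
#   >>> m = torch.tensor([[True, False, True], [False, False, False]])
#   >>> amax(x, 1, mask=m)
#   tensor([3., -inf])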

@_apply_docstring_templates
def amin(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype

    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def argmax(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    mask_input = _combine_input_and_mask(argmax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def argmin(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    mask_input = _combine_input_and_mask(argmin, input, mask)
    if mask_input.layout == torch.strided:
        return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
        )


@_apply_docstring_templates
def mean(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined. Due to this ambiguity, the elements of the output
tensor with strided layout that correspond to fully masked-out
elements have ``nan`` values.
1380
1381{reduction_args}
1382
1383{reduction_example}"""
1384    if dtype is None:
1385        dtype = input.dtype
1386    if input.layout == torch.strided:
1387        if mask is None:
1388            # TODO: compute count analytically
1389            count = sum(
1390                torch.ones(input.shape, dtype=torch.int64, device=input.device),
1391                dim,
1392                keepdim=keepdim,
1393            )
1394            total = sum(input, dim, keepdim=keepdim, dtype=dtype)
1395        else:
1396            inmask = _input_mask(input, mask=mask)
1397            count = inmask.sum(dim=dim, keepdim=bool(keepdim))
1398            total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
1399        return total / count
1400    elif input.layout == torch.sparse_csr:
1401        mask_input = _combine_input_and_mask(mean, input, mask)
1402        dim_ = _canonical_dim(dim, mask_input.ndim)
1403        if mask is None:
1404            raise ValueError(
1405                "masked mean expects explicit mask for sparse_csr tensor input"
1406            )
1407        return _sparse_csr_segment_reduction_helper(
1408            torch.mean, mask_input, dim_, bool(keepdim), dtype
1409        )
1410    else:
1411        raise ValueError(
1412            f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
1413        )
1414
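# Illustrative sketch of masked mean, showing the ``nan`` result for a fully
# masked-out row that the docstring above describes (count 0 gives 0/0):
#
#   >>> x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
#   >>> mask = torch.tensor([[True, False], [False, False]])
#   >>> torch.masked._ops.mean(x, dim=1, mask=mask)
#   tensor([1., nan])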

@_apply_docstring_templates
def median(
    input: Union[Tensor, MaskedTensor],
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along the
given dimension(s) :attr:`dim` are masked out, the identity value of
the median is undefined. Due to this ambiguity, the elements of the
output tensor with strided layout that correspond to fully masked-out
elements have ``nan`` values.
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    is_float = torch.is_floating_point(input)
    if not is_float:
        input = input.to(dtype=torch.float)
    mask_input = _combine_input_and_mask(median, input, mask)
    if mask_input.layout == torch.strided:
        output = torch.nanmedian(mask_input, dim_, keepdim).values
        if is_float:
            return output
        elif not torch.isnan(output).any():
            return output.to(dtype=dtype)
        else:
            raise ValueError(
                "masked median expects no fully masked-out rows if dtype is not floating point"
            )
    else:
        raise ValueError(
            f"masked median expects strided tensor (got {mask_input.layout} tensor)"
        )

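# Illustrative sketch of masked median: masked-out elements are filled with
# ``nan`` and then ignored by the underlying ``torch.nanmedian`` reduction.
#
#   >>> x = torch.tensor([3.0, 1.0, 5.0, 2.0])
#   >>> mask = torch.tensor([True, True, True, False])
#   >>> torch.masked._ops.median(x, mask=mask)  # median of [3., 1., 5.]
#   tensor(3.)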

@_apply_docstring_templates
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(logsumexp, input, mask)
    if mask_input.layout == torch.strided:
        return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
        )

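# Illustrative sketch of masked logsumexp, assuming masked-out elements are
# filled with the logsumexp identity ``-inf`` so they contribute
# ``exp(-inf) == 0`` to the sum. Output shown approximately:
#
#   >>> x = torch.tensor([1.0, 2.0, 3.0])
#   >>> mask = torch.tensor([True, False, True])
#   >>> torch.masked._ops.logsumexp(x, dim=0, mask=mask)  # log(exp(1) + exp(3))
#   tensor(3.1269)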

# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
def logaddexp(
    input: Union[Tensor, MaskedTensor],
    other: Union[Tensor, MaskedTensor],
    *,
    dtype: Optional[DType] = None,
    input_mask: Optional[Tensor] = None,
    other_mask: Optional[Tensor] = None,
) -> Tensor:
    """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor

    Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
    tensor. The :attr:`input` elements are masked out according to the boolean tensor
    :attr:`input_mask` and the :attr:`other` elements are masked out according to the
    boolean tensor :attr:`other_mask`.

    The shapes of a mask tensor and the tensor to be masked
    don't need to match, but they must be :ref:`broadcastable
    <broadcasting-semantics>` and the dimensionality of the mask
    tensor must not be greater than that of the tensor to be masked.

    Args:
        input (Tensor): the input tensor
        other (Tensor): the second input tensor

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type
          of returned tensor.  If specified, the output tensor is
          cast to :attr:`dtype` after the operation is
          performed. Default: None.
        input_mask (:class:`torch.Tensor`, optional): the boolean tensor
          containing the binary mask of validity of :attr:`input` tensor elements.
          Default: None, which is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
        other_mask (:class:`torch.Tensor`, optional): the boolean tensor
          containing the binary mask of validity of :attr:`other` tensor elements.
          Default: None, which is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.

    Example::

        >>> input = torch.tensor([-100.0, -200, -300])
        >>> input
        tensor([-100., -200., -300.])
        >>> other = torch.tensor([-1.0, -2, -3])
        >>> other
        tensor([-1., -2., -3.])
        >>> mask = torch.tensor([True, False, True])
        >>> mask
        tensor([ True, False,  True])
        >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
        tensor([-1., -inf, -3.])"""
    if dtype is None:
        dtype = input.dtype
    if input.layout == torch.strided and other.layout == torch.strided:
        mask_input = _combine_input_and_mask(logaddexp, input, input_mask)
        mask_other = _combine_input_and_mask(logaddexp, other, other_mask)
        return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
        )


@_apply_docstring_templates
def norm(
    input: Union[Tensor, MaskedTensor],
    ord: Optional[float] = 2.0,
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

The identity value of the norm operation, which is used to start the
reduction, is ``{identity_float32}``, except for ``ord=-inf``, where
it is ``{identity_ord_ninf}``.

{reduction_args}

{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    mask_input = _combine_input_and_mask(norm, input, mask, ord)
    if mask_input.layout == torch.strided:
        dim_ = _canonical_dim(dim, input.ndim)
        return torch.linalg.vector_norm(
            mask_input, ord, dim_, bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
        )

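# Illustrative sketch of masked norm with the default ``ord=2.0``: masked-out
# elements are replaced with the identity value 0 before the reduction.
#
#   >>> x = torch.tensor([3.0, 4.0, 100.0])
#   >>> mask = torch.tensor([True, True, False])
#   >>> torch.masked._ops.norm(x, mask=mask)  # == vector_norm([3., 4., 0.])
#   tensor(5.)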

def _std_var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims,
    unbiased: Optional[bool],
    *,
    correction_opt: Optional[Union[int, float]],
    keepdim: Optional[bool],
    dtype: Optional[DType],
    mask: Optional[Tensor],
    take_sqrt: Optional[bool],
) -> Tensor:
    assert (
        unbiased is None or correction_opt is None
    ), "Only one of unbiased and correction may be given"
    correction = 1.0
    if unbiased is not None:
        correction = 1.0 if unbiased else 0.0
    if correction_opt is not None:
        correction = sym_float(correction_opt)

    if dtype is None:
        dtype = input.dtype
        if not (dtype.is_floating_point or dtype.is_complex):
            dtype = torch.float32
    compute_dtype = dtype
    if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
        compute_dtype = torch.float32
    if input.layout == torch.strided:
        if mask is None:
            # TODO: compute count analytically
            count = sum(
                torch.ones(input.shape, dtype=torch.int64, device=input.device),
                dim,
                keepdim=True,
            )
            sample_total = sum(input, dim, keepdim=True, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            count = inmask.sum(dim=dim, keepdim=True)
            sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
        # TODO: replace torch.subtract/divide/square/maximum with
        # masked subtract/divide/square/maximum when these become
        # available.
        sample_mean = torch.divide(sample_total, count)
        x = torch.subtract(input, sample_mean)
        if mask is None:
            total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
        else:
            total = sum(
                x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask  # type: ignore[possibly-undefined]
            )
        if not keepdim:
            count = count.reshape(total.shape)
        if correction != 0:
            real_dtype = (
                corresponding_real_dtype(compute_dtype)
                if compute_dtype.is_complex
                else compute_dtype
            )
            count = count.to(real_dtype)
            count = torch.subtract(count, correction)
            count = torch.maximum(count, count.new_zeros([]))
        output = torch.divide(total, count).to(dtype=dtype)
        if take_sqrt:
            output = torch.sqrt(output)
        return output
    else:
        raise ValueError(
            f"masked std/var expects strided tensor (got {input.layout} tensor)"
        )


@_apply_docstring_templates
def var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of the sample variance operation is undefined. The
elements of the output tensor with strided layout that correspond to
fully masked-out elements have ``nan`` values.
{reduction_args}
{reduction_example}"""
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=False,
    )

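# Illustrative sketch of masked var: per _std_var above, the default
# ``correction`` is 1.0, i.e. an unbiased sample variance over the unmasked
# elements.
#
#   >>> x = torch.tensor([1.0, 2.0, 3.0, 100.0])
#   >>> mask = torch.tensor([True, True, True, False])
#   >>> torch.masked._ops.var(x, mask=mask)  # sample var of [1., 2., 3.]
#   tensor(1.)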

@_apply_docstring_templates
def std(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of the sample standard deviation operation is undefined. The
elements of the output tensor with strided layout that correspond to
fully masked-out elements have ``nan`` values.
{reduction_args}
{reduction_example}"""
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=True,
    )

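# Illustrative sketch of masked std: same as ``var`` above but with
# ``take_sqrt=True``, so ``std(x, mask=m) == sqrt(var(x, mask=m))``.
# Output shown approximately:
#
#   >>> x = torch.tensor([2.0, 4.0, 100.0])
#   >>> mask = torch.tensor([True, True, False])
#   >>> torch.masked._ops.std(x, mask=mask)  # sqrt of sample var of [2., 4.]
#   tensor(1.4142)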

@_apply_docstring_templates
def softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
        )

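# Illustrative sketch of masked softmax, assuming masked-out elements are
# filled with ``-inf`` (the ``amax`` identity used above), so they receive
# zero weight. Output shown approximately:
#
#   >>> x = torch.tensor([1.0, 2.0, 3.0])
#   >>> mask = torch.tensor([True, False, True])
#   >>> torch.masked._ops.softmax(x, dim=0, mask=mask)
#   tensor([0.1192, 0.0000, 0.8808])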

@_apply_docstring_templates
def log_softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
        )

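# Illustrative sketch of masked log_softmax: under the same ``-inf`` fill as
# masked softmax, masked-out positions end up with ``-inf`` log-probability.
# Output shown approximately:
#
#   >>> x = torch.tensor([1.0, 2.0, 3.0])
#   >>> mask = torch.tensor([True, False, True])
#   >>> torch.masked._ops.log_softmax(x, dim=0, mask=mask)
#   tensor([-2.1269, -inf, -0.1269])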

@_apply_docstring_templates
def softmin(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amin, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
        )

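# Illustrative sketch of masked softmin, assuming masked-out elements are
# filled with ``inf`` (the ``amin`` identity used above), so they receive
# zero weight. Output shown approximately:
#
#   >>> x = torch.tensor([1.0, 2.0, 3.0])
#   >>> mask = torch.tensor([True, False, True])
#   >>> torch.masked._ops.softmin(x, dim=0, mask=mask)
#   tensor([0.8808, 0.0000, 0.1192])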

@_apply_docstring_templates
def normalize(
    input: Union[Tensor, MaskedTensor],
    ord: float,
    dim: int,
    *,
    eps: float = 1e-12,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # TODO: eliminate mask_input as unnecessary when using masked divide.
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
        # TODO: replace torch.maximum with masked maximum when available.
        denom = torch.maximum(nrm_, nrm_.new_full([], eps))
        # TODO: replace torch.divide with masked divide when available.
        return torch.divide(mask_input, denom)
    else:
        raise ValueError(
            f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"
        )

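# Illustrative sketch of masked normalize: the masked norm (clamped below by
# ``eps``) divides the zero-filled input, so masked-out positions come out 0.
#
#   >>> x = torch.tensor([3.0, 4.0, 100.0])
#   >>> mask = torch.tensor([True, True, False])
#   >>> torch.masked._ops.normalize(x, 2.0, dim=0, mask=mask)
#   tensor([0.6000, 0.8000, 0.0000])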