# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import warnings

import torch

from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor


__all__ = []  # type: ignore[var-annotated]


def _masked_all_all(data, mask=None):
    if mask is None:
        return data.all()
    return data.masked_fill(~mask, True).all()


def _masked_all_dim(data, dim, keepdim=False, mask=None):
    if mask is None:
        return torch.all(data, dim=dim, keepdim=keepdim)
    return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)


def _masked_all(*args, **kwargs):
    if len(args) == 1 and len(kwargs) == 1:
        return _masked_all_all(args[0], mask=kwargs["mask"])
    return _masked_all_dim(*args, **kwargs)
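# Illustrative sketch (not part of the module): the helpers above implement a
# masked `all` by filling masked-out positions with True (the identity for
# logical-and) before reducing, and `_masked_all` chooses between the full and
# per-dim variants from the call shape: one positional argument plus a `mask`
# keyword means a full reduction. For example:
#
#   data = torch.tensor([True, False, True])
#   mask = torch.tensor([True, False, True])  # False marks an element as masked out
#   _masked_all_all(data, mask=mask)          # tensor(True): the masked-out False is ignored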


def _multidim_any(mask, dim, keepdim):
    if isinstance(dim, int):
        return _multidim_any(mask, [dim], keepdim)
    for d in sorted(dim, reverse=True):
        mask = torch.any(mask, dim=d, keepdim=keepdim)
    return mask
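# Illustrative sketch (not part of the module): _multidim_any collapses a mask
# over one or more dimensions with torch.any, so an output element stays
# "specified" if any contributing input element was specified. Dimensions are
# reduced from highest to lowest so that earlier reductions do not shift the
# remaining dimension indices when keepdim=False. For example:
#
#   mask = torch.tensor([[True, False], [False, False]])
#   _multidim_any(mask, 1, keepdim=False)       # tensor([True, False])
#   _multidim_any(mask, [0, 1], keepdim=False)  # tensor(True)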


def _get_masked_fn(fn):
    if fn == "all":
        return _masked_all
    return getattr(torch.masked, fn)
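# Note: "all" is routed to the local helpers above instead of torch.masked,
# presumably because torch.masked does not provide an "all" reduction with the
# required mask semantics; every other name is looked up on torch.masked
# (e.g. torch.masked.sum, torch.masked.mean).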


def _torch_reduce_all(fn):
    def reduce_all(self):
        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
        # For a full reduction, torch.argmin/torch.argmax need to return the flattened index of the
        # min/max element, but that operation isn't supported correctly for sparse layouts.
        # Therefore, this implementation reconstructs the flat index from the tensor's strides
        # (a worked example follows this function).
        if fn == "all":
            result_data = masked_fn(data, mask=mask)

        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
            indices = (
                data.to_sparse_coo().indices()
                if not self.is_sparse_coo()
                else data.indices()
            )
            idx = indices.unbind(1)[sparse_idx]
            stride = data.size().numel() / torch.tensor(
                data.size(), device=data.device
            ).cumprod(0)
            result_data = torch.sum(idx * stride)

        # For all other reductions on sparse COO/CSR tensors, simply pass in the values.
        elif self.is_sparse:
            result_data = masked_fn(masked_tensor(data.values(), mask))

        else:
            result_data = masked_fn(self, mask=mask)

        return as_masked_tensor(result_data, torch.any(mask))

    return reduce_all
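# Illustrative sketch (not part of the module): the stride computation above
# turns the COO coordinates of the winning element into the flat index that a
# dense full-reduction argmin/argmax would return. Assuming a 2x3 tensor whose
# min/max element sits at coordinates (1, 2):
#
#   total elements   = 2 * 3 = 6              # data.size().numel()
#   cumprod of sizes = [2, 6]                 # torch.tensor([2, 3]).cumprod(0)
#   strides          = 6 / [2, 6] = [3, 1]
#   flat index       = 1 * 3 + 2 * 1 = 5      # torch.sum(idx * stride)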


def _torch_reduce_dim(fn):
    def reduce_dim(self, dim, keepdim=False, dtype=None):
        if self.is_sparse:
            msg = (
                f"The sparse version of {fn} is not implemented in reductions.\n"
                "If you would like this operator to be supported, please file an issue for a feature request at "
                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
                "If the semantics of the operator are not trivial, please also include "
                "a proposal for the semantics."
            )
            warnings.warn(msg)
            return NotImplemented
        if not is_masked_tensor(self):
            raise TypeError("Input to reduce_dim must be a MaskedTensor")

        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
        else:
            result_data = masked_fn(
                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask()
            )
        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))

    return reduce_dim


def _torch_reduce(fn):
    def reduce_fn(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        return _torch_reduce_dim(fn)(*args, **kwargs)

    return reduce_fn
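# Illustrative sketch (not part of the module): reduce_fn infers the reduction
# flavor from the call shape. A single positional argument with no kwargs is a
# full reduction; anything else is treated as a dimension-wise reduction.
# Assuming a MaskedTensor `mt`:
#
#   reduce_fn(mt)         # full reduction, via _torch_reduce_all(fn)
#   reduce_fn(mt, 0)      # per-dim reduction, via _torch_reduce_dim(fn)
#   reduce_fn(mt, dim=0)  # also per-dim: the kwarg makes len(kwargs) != 0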


def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
    return input, dim, keepdim, dtype


def _torch_grad_reduce(fn):
    def grad_reduce(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        # TODO: autograd.Function doesn't support kwargs; normalize them to positional args
        input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
        return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)

    return grad_reduce


REDUCE_NAMES = [
    "sum",
    "mean",
    "amin",
    "amax",
    "argmin",
    "argmax",
    "prod",
    "all",
    "norm",
    "var",
    "std",
]

NATIVE_REDUCE_MAP = {
    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
}
TORCH_REDUCE_MAP = {
    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
TENSOR_REDUCE_MAP = {
    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}

NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
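# Note: the three maps register the same reductions under the three callables a
# caller might hit: the ATen overload packet (e.g. torch.ops.aten.sum), the
# functional form (torch.sum), and the Tensor method (torch.Tensor.sum), so a
# reduction can be recognized regardless of how it was invoked.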


def _is_reduction(fn):
    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP


def _apply_reduction(fn, *args, **kwargs):
    if fn in NATIVE_REDUCE_MAP:
        return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TORCH_REDUCE_MAP:
        return TORCH_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TENSOR_REDUCE_MAP:
        return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
    return NotImplemented
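# Illustrative sketch (not part of the module): _is_reduction/_apply_reduction
# are the entry points the MaskedTensor dispatch machinery is expected to call
# when it intercepts a function. A hypothetical caller would look roughly like:
#
#   def handle(func, *args, **kwargs):
#       if _is_reduction(func):
#           return _apply_reduction(func, *args, **kwargs)
#       return NotImplemented
#
# so that, e.g., torch.mean(mt) on a MaskedTensor `mt` reduces only the
# unmasked values, and the result's mask is torch.any of the input mask.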
177