# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import warnings
from typing import Any
from typing_extensions import TypeGuard

import torch
from torch.overrides import get_default_nowrap_functions


__all__ = [
    "MaskedTensor",
    "is_masked_tensor",
]


def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
    r"""Returns True if the input is a MaskedTensor, else False.

    Args:
        obj: any input

    Examples:

        >>> # xdoctest: +SKIP
        >>> from torch.masked import MaskedTensor
        >>> data = torch.arange(6).reshape(2,3)
        >>> mask = torch.tensor([[True, False, False], [True, True, False]])
        >>> mt = MaskedTensor(data, mask)
        >>> is_masked_tensor(mt)
        True
    """
    return isinstance(obj, MaskedTensor)


def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
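    """Return True if tensors ``a`` and ``b`` hold matching values.

    With ``exact=True`` elements are compared for equality; otherwise
    ``torch.allclose`` is used with the given tolerances. Sparse COO and CSR
    inputs are compared through their indices and values.

    Example::

        >>> # xdoctest: +SKIP
        >>> _tensors_match(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0]))
        True
    """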
    if is_masked_tensor(a) or is_masked_tensor(b):
        raise ValueError("Neither `a` nor `b` can be a MaskedTensor.")
    if a.layout != b.layout:
        raise ValueError(
            f"`a` and `b` must have the same layout. Got {a.layout} and {b.layout}"
        )

    if a.dtype != b.dtype:
        b = b.type(a.dtype)
    if a.layout == b.layout == torch.sparse_coo:
        return _tensors_match(a.values(), b.values(), exact) and _tensors_match(
            a.indices(), b.indices(), exact
        )
    elif a.layout == b.layout == torch.sparse_csr:
        return (
            _tensors_match(a.crow_indices(), b.crow_indices(), exact)
            and _tensors_match(a.col_indices(), b.col_indices(), exact)
            and _tensors_match(a.values(), b.values(), exact)
        )
    if exact:
        return (a.dim() == b.dim()) and torch.eq(a, b).all().item()
    return (a.dim() == b.dim()) and torch.allclose(a, b, rtol=rtol, atol=atol)


def _masks_match(a, b):
    if is_masked_tensor(a) and is_masked_tensor(b):
        mask_a = a.get_mask()
        mask_b = b.get_mask()
        return _tensors_match(mask_a, mask_b, exact=True)
    return True


def _map_mt_args_kwargs(args, kwargs, map_fn):
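    """Apply ``map_fn`` to every MaskedTensor in ``args`` and ``kwargs``.

    Plain tensors and other values pass through unchanged; lists and tuples
    are mapped recursively.

    Example (a minimal sketch, assuming ``mt`` is a MaskedTensor)::

        >>> # xdoctest: +SKIP
        >>> args, kwargs = _map_mt_args_kwargs((mt, 1), {"dim": 0}, _get_data)
    """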
    def _helper(a, map_fn):
        if is_masked_tensor(a):
            return map_fn(a)
        elif torch.is_tensor(a):
            return a
        elif isinstance(a, list):
            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
            return a_impl
        elif isinstance(a, tuple):
            a_impl, _ = _map_mt_args_kwargs(a, {}, map_fn)
            return tuple(a_impl)
        else:
            return a

    if kwargs is None:
        kwargs = {}
    impl_args = [_helper(a, map_fn) for a in args]
    # Map the kwarg values themselves, not the last positional arg.
    impl_kwargs = {k: _helper(v, map_fn) for k, v in kwargs.items()}
    return impl_args, impl_kwargs


def _wrap_result(result_data, result_mask):
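    """Rewrap raw results into MaskedTensors, mirroring list/tuple structure.

    Returns ``NotImplemented`` for anything other than a tensor (or a list or
    tuple of tensors) so that callers can fall back to another handler.
    """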
    if isinstance(result_data, list):
        return [_wrap_result(r, m) for (r, m) in zip(result_data, result_mask)]
    if isinstance(result_data, tuple):
        return tuple(_wrap_result(r, m) for (r, m) in zip(result_data, result_mask))
    if torch.is_tensor(result_data):
        return MaskedTensor(result_data, result_mask)
    # Expect result_data and result_mask to be Tensors only
    return NotImplemented


def _masked_tensor_str(data, mask, formatter):
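    """Format ``data`` for printing, rendering masked-out elements as ``--``.

    Example (a minimal sketch; the exact spacing depends on ``formatter``)::

        >>> # xdoctest: +SKIP
        >>> _masked_tensor_str(torch.tensor([1.0, 2.0]), torch.tensor([True, False]), "{0:8.4f}")
        '[  1.0000,       --]'
    """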
    if data.layout in {torch.sparse_coo, torch.sparse_csr}:
        data = data.to_dense()
        mask = mask.to_dense()
    if data.dim() == 1:
        formatted_elements = [
            formatter.format(d.item()) if isinstance(d.item(), float) else str(d.item())
            for d in data
        ]
        max_len = max(8 if x[1] else len(x[0]) for x in zip(formatted_elements, ~mask))
        return (
            "["
            + ", ".join(
                [
                    "--".rjust(max_len) if m else e
                    for (e, m) in zip(formatted_elements, ~mask)
                ]
            )
            + "]"
        )
    sub_strings = [_masked_tensor_str(d, m, formatter) for (d, m) in zip(data, mask)]
    sub_strings = ["\n".join(["  " + si for si in s.split("\n")]) for s in sub_strings]
    return "[\n" + ",\n".join(sub_strings) + "\n]"


def _get_data(a):
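    """Return the underlying data tensor if ``a`` is a MaskedTensor, else ``a``."""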
    if is_masked_tensor(a):
        return a._masked_data
    return a


def _maybe_get_mask(a):
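    """Return the mask if ``a`` is a MaskedTensor, else ``None``."""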
    if is_masked_tensor(a):
        return a.get_mask()
    return None


class MaskedTensor(torch.Tensor):
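    r"""A ``torch.Tensor`` subclass that pairs a data tensor with a boolean
    mask of the same shape, where ``True`` marks an element as specified.

    Example (adapted from the ``is_masked_tensor`` docstring above)::

        >>> # xdoctest: +SKIP
        >>> data = torch.arange(6).reshape(2, 3)
        >>> mask = torch.tensor([[True, False, False], [True, True, False]])
        >>> mt = MaskedTensor(data, mask)
    """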
    @staticmethod
    def __new__(cls, data, mask, requires_grad=False):
        if is_masked_tensor(data) or not torch.is_tensor(data):
            raise TypeError("data must be a Tensor")
        if is_masked_tensor(mask) or not torch.is_tensor(mask):
            raise TypeError("mask must be a Tensor")
        # Use a wrapper Tensor of the given size.
        kwargs = {
            "device": data.device,
            "dtype": data.dtype,
            "layout": data.layout,
            "requires_grad": requires_grad,
            "dispatch_sizes_strides_policy": "strides",
            "dispatch_layout": True,
        }
        warnings.warn(
            (
                "The PyTorch API of MaskedTensors is in prototype stage "
                "and will change in the near future. Please open a GitHub issue "
                "for feature requests and see our documentation on the torch.masked "
                "module for further information about the project."
            ),
            UserWarning,
            stacklevel=2,
        )
        if data.requires_grad:
            warnings.warn(
                "It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
                "To avoid this, you can use data.clone().detach()",
                UserWarning,
                stacklevel=2,
            )
        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]

    def _preprocess_data(self, data, mask):
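        """Coalesce sparse inputs, align the data to the mask's sparsity
        pattern, and store clones of both on the instance."""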
        from .._ops import _sparse_coo_where, _sparse_csr_where

        if data.layout != mask.layout:
            raise TypeError("data and mask must have the same layout.")
        if data.layout == torch.sparse_coo:
            data = data.coalesce()
            mask = mask.coalesce()
            if data._nnz() != mask._nnz():
                data = _sparse_coo_where(mask, data, torch.tensor(0))
        elif data.layout == torch.sparse_csr:
            if data._nnz() != mask._nnz():
                data = _sparse_csr_where(mask, data, torch.tensor(0))

        # Have to pick awkward names to not conflict with existing fields such as data
        self._masked_data = data.clone()
        self._masked_mask = mask.clone()

    def _validate_members(self):
        data = self._masked_data
        mask = self.get_mask()
        if type(data) != type(mask):
            raise TypeError(
                f"data and mask must have the same type. Got {type(data)} and {type(mask)}"
            )
        if data.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
            raise TypeError(f"data layout of {data.layout} is not supported.")
        if data.layout == torch.sparse_coo:
            if not _tensors_match(data.indices(), mask.indices(), exact=True):
                raise ValueError(
                    "data and mask are both sparse COO tensors but do not have the same indices."
                )
        elif data.layout == torch.sparse_csr:
            if not _tensors_match(
                data.crow_indices(), mask.crow_indices(), exact=True
            ) or not _tensors_match(data.col_indices(), mask.col_indices(), exact=True):
                raise ValueError(
                    "data and mask are both sparse CSR tensors but do not share either crow or col indices."
                )
        if mask.dtype != torch.bool:
            raise TypeError("mask must have dtype bool.")
        if data.dtype not in {
            torch.float16,
            torch.float32,
            torch.float64,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
        }:
            raise TypeError(f"{data.dtype} is not supported in MaskedTensor.")
        if data.dim() != mask.dim():
            raise ValueError("data.dim() must equal mask.dim()")
        if data.size() != mask.size():
            raise ValueError("data.size() must equal mask.size()")

    def __init__(self, data, mask, requires_grad=False):
        self._preprocess_data(data, mask)
        self._validate_members()

    @staticmethod
    def _from_values(data, mask):
        """Differentiable constructor for MaskedTensor"""

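        # Example (a minimal sketch; gradients flow to `data` but not `mask`):
        #   >>> data = torch.randn(2, 2, requires_grad=True)
        #   >>> mask = data > 0
        #   >>> mt = MaskedTensor._from_values(data, mask)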
        class Constructor(torch.autograd.Function):
            @staticmethod
            def forward(ctx, data, mask):
                return MaskedTensor(data, mask)

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, None

        result = Constructor.apply(data, mask)
        return result

    def _set_data_mask(self, data, mask):
        self._masked_data = data
        self._masked_mask = mask
        self._validate_members()

    def __repr__(self):
        formatter = "{0:8.4f}"
        if self.dim() == 0:
            scalar_data = self.get_data().item()
            data_formatted = (
                formatter.format(scalar_data)
                if isinstance(scalar_data, float)
                else str(scalar_data)
            )
            if not self.get_mask().item():
                data_formatted = "--"
            return (
                "MaskedTensor("
                + data_formatted
                + ", "
                + str(self.get_mask().item())
                + ")"
            )
        s = _masked_tensor_str(self.get_data(), self.get_mask(), formatter)
        s = "\n".join("  " + si for si in s.split("\n"))
        return "MaskedTensor(\n" + s + "\n)"

    # Seems like this needs to be defined before __torch_dispatch__ to work
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        from ._ops_refs import _MASKEDTENSOR_FUNCTION_TABLE

        if func in _MASKEDTENSOR_FUNCTION_TABLE:
            return _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        with torch._C.DisableTorchFunctionSubclass():
            ret = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return ret
            else:
                return torch._tensor._convert(ret, cls)

    @classmethod
    def unary(cls, fn, data, mask):
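        """Apply an elementwise ``fn`` to ``data``, keeping ``mask`` unchanged."""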
        return MaskedTensor(fn(data), mask)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        func = func.overloadpacket

        from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE

        if func in _MASKEDTENSOR_DISPATCH_TABLE:
            return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)

        msg = (
            f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
            "If you would like this operator to be supported, please file an issue for a feature request at "
            "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
            "In the case that the semantics for the operator are not trivial, it would be appreciated "
            "to also include a proposal for the semantics."
        )
        warnings.warn(msg)
        return NotImplemented

    def __lt__(self, other):
        if is_masked_tensor(other):
            return MaskedTensor(self.get_data() < _get_data(other), self.get_mask())
        return MaskedTensor(self.get_data() < other, self.get_mask())

    def to_tensor(self, value):
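        """Return a plain tensor with masked-out elements replaced by ``value``.

        Example (a minimal sketch)::

            >>> # xdoctest: +SKIP
            >>> mt = MaskedTensor(torch.tensor([1.0, 2.0]), torch.tensor([True, False]))
            >>> mt.to_tensor(0.0)
            tensor([1., 0.])
        """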
        return self.get_data().masked_fill(~self.get_mask(), value)

    def get_data(self):
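        """Return the data tensor through an autograd-aware accessor.

        Gradients flowing back through this accessor are rewrapped as
        MaskedTensors using this tensor's mask.
        """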
        class GetData(torch.autograd.Function):
            @staticmethod
            def forward(ctx, self):
                return self._masked_data

            @staticmethod
            def backward(ctx, grad_output):
                if is_masked_tensor(grad_output):
                    return grad_output
                return MaskedTensor(grad_output, self.get_mask())

        return GetData.apply(self)

    def get_mask(self):
        return self._masked_mask

    def is_sparse_coo(self):
        return self.layout == torch.sparse_coo

    def is_sparse_csr(self):
        return self.layout == torch.sparse_csr

    # Update later to support more sparse layouts
    @property
    def is_sparse(self):
        return self.is_sparse_coo() or self.is_sparse_csr()
