xref: /aosp_15_r20/external/pytorch/test/test_maskedtensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: masked operators"]
2
3import torch
4import unittest
5from torch.testing._internal.common_utils import (
6    decorateIf,
7    TestCase,
8    run_tests,
9    make_tensor,
10    parametrize,
11    instantiate_parametrized_tests,
12)
13from torch.testing._internal.common_device_type import (
14    instantiate_device_type_tests,
15    ops,
16)
17from torch.testing._internal.common_methods_invocations import (
18    SampleInput,
19    binary_ufuncs,
20    reduction_ops,
21    unary_ufuncs,
22)
23
24from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask
25from torch.masked.maskedtensor.core import _masks_match, _tensors_match
26from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES
27from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES
28from torch.masked.maskedtensor.reductions import REDUCE_NAMES
29
30
31def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05):
32    mask = mt_result.get_mask()
33    mt_result_data = mt_result.get_data()
34    if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
35        mask = mask.to_dense()
36    if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}:
37        mt_result_data = mt_result_data.to_dense()
38    a = mt_result_data.detach().masked_fill_(~mask, 0)
39    b = t_result.detach().masked_fill_(~mask, 0)
40    if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
41        raise ValueError("The data in MaskedTensor a and Tensor b do not match")
42
43def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08):
44    mt_data1 = mt1.get_data()
45    mt_data2 = mt2.get_data()
46    if mt_data1.layout != mt_data2.layout:
47        raise ValueError("mt1's data and mt2's data do not have the same layout. "
48                         f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}")
49
50    mask = mt1.get_mask()
51    mask2 = mt2.get_mask()
52    if not _masks_match(mt1, mt2):
53        raise ValueError("mt1 and mt2 must have matching masks")
54    if mask.layout != mask2.layout:
55        raise ValueError("mt1's mask and mt2's mask do not have the same layout. "
56                         f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}")
57    if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
58        mask = mask.to_dense()
59
60    if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}:
61        mt_data1 = mt_data1.to_dense()
62        mt_data2 = mt_data2.to_dense()
63    a = mt_data1.detach().masked_fill_(~mask, 0)
64    b = mt_data2.detach().masked_fill_(~mask, 0)
65
66    if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
67        raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")
68
def _compare_forward_backward(data, mask, fn):
    """Run ``fn`` on a MaskedTensor and on a plain reference Tensor and compare.

    The reference Tensor is ``data`` with every masked-out entry replaced by
    ``-inf``. Both the forward results and the input gradients (obtained by
    back-propagating the sum of each result) are compared with ``_compare_mt_t``.
    """
    # Masked path.
    masked_input = masked_tensor(data, mask, requires_grad=True)
    masked_output = fn(masked_input)
    masked_output.sum().backward()

    # Reference path: a fresh leaf tensor with -inf in the masked-out slots.
    filled = data.masked_fill(~mask, float("-inf"))
    reference_input = filled.detach().clone().requires_grad_()
    reference_output = fn(reference_input)
    reference_output.sum().backward()

    _compare_mt_t(masked_output, reference_output)
    _compare_mt_t(masked_input.grad, reference_input.grad, atol=1e-06)
80
81
82def _create_random_mask(shape, device):
83    return make_tensor(shape, device=device, dtype=torch.bool)
84
def _generate_sample_data(
    device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided
):
    """Create (data, mask) sample pairs for a fixed set of shapes.

    Args:
        device: device on which data and mask are created.
        dtype: dtype of the data tensors.
        requires_grad: forwarded to ``make_tensor`` for the data tensors.
        layout: one of ``torch.strided``, ``torch.sparse_coo``, ``torch.sparse_csr``.

    Returns:
        A list of ``SampleInput`` objects whose ``input`` is the data tensor and
        whose ``kwargs["mask"]`` is the mask, both in the requested layout.
        For ``torch.sparse_csr`` only 2-D shapes are emitted.
    """
    assert layout in {
        torch.strided,
        torch.sparse_coo,
        torch.sparse_csr,
    }, "Layout must be strided/sparse_coo/sparse_csr"
    shapes = [
        [],
        [2],
        [3, 5],
        [3, 2, 1, 2],
    ]
    inputs = []
    for s in shapes:
        data = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad)  # type: ignore[arg-type]
        mask = _create_random_mask(s, device)
        if layout == torch.sparse_coo:
            mask = mask.to_sparse_coo().coalesce()
            data = data.sparse_mask(mask).requires_grad_(requires_grad)
        elif layout == torch.sparse_csr:
            # Only 2-D samples are converted to CSR: skip the sample unless
            # BOTH operands are 2-D (the previous `and` only skipped when
            # both were non-2-D).
            if data.ndim != 2 or mask.ndim != 2:
                continue
            mask = mask.to_sparse_csr()
            data = data.sparse_mask(mask)
        inputs.append(SampleInput(data, kwargs={"mask": mask}))
    return inputs
113
114def _fix_fn_name(fn_name):
115    if fn_name[-1] == "_":
116        fn_name = fn_name[:-1]
117    return fn_name
118
119
class TestBasics(TestCase):
    """Construction, validation, layout conversion and basic-op tests for MaskedTensor."""

    def test_invalid_tensor_inputs(self, device):
        """Passing a MaskedTensor or a non-Tensor as data or mask raises TypeError."""
        data = torch.randn((3, 4), device=device)
        mask = _create_random_mask((3, 4), device=device)
        mt = masked_tensor(data, mask)

        for bad_data in (mt, 0):
            with self.assertRaisesRegex(TypeError, "data must be a Tensor"):
                masked_tensor(bad_data, mask)
        for bad_mask in (mt, 0):
            with self.assertRaisesRegex(TypeError, "mask must be a Tensor"):
                masked_tensor(data, bad_mask)

    def test_diff_layouts(self, device):
        """Sparse COO data combined with a strided mask is rejected."""
        data = torch.randn((3, 4), device=device).to_sparse_coo()
        mask = _create_random_mask((3, 4), device=device)
        with self.assertRaisesRegex(TypeError, "data and mask must have the same layout"):
            masked_tensor(data, mask)

    def test_diff_dim(self, device):
        """Data and mask with different numbers of dimensions are rejected."""
        data = torch.randn((3, 4, 5), device=device)
        mask = _create_random_mask((3, 4), device=device)
        with self.assertRaisesRegex(ValueError, "data.dim\\(\\) must equal mask.dim\\(\\)"):
            masked_tensor(data, mask)

    def test_diff_sizes(self, device):
        """Data and mask with equal rank but different sizes are rejected."""
        data = torch.randn((3, 4), device=device)
        mask = _create_random_mask((3, 3), device=device)
        with self.assertRaisesRegex(ValueError, "data.size\\(\\) must equal mask.size\\(\\)"):
            masked_tensor(data, mask)

    def test_grad_warning(self, device):
        """Building a MaskedTensor from data that requires grad emits a UserWarning."""
        data = torch.randn((3, 4), device=device, requires_grad=True)
        mask = _create_random_mask((3, 4), device=device)
        msg = "It is not recommended to create a MaskedTensor with a tensor that requires_grad."
        with self.assertWarnsRegex(UserWarning, msg):
            # Only the warning matters here; the result is intentionally discarded.
            masked_tensor(data, mask)

    def test_add(self, device):
        """Adding MaskedTensors requires matching masks; matching masks add elementwise."""
        data = torch.arange(5.0, device=device)
        mask = torch.tensor([True, True, False, True, False], device=device)
        m0 = masked_tensor(data, mask)
        m1 = masked_tensor(data, ~mask)
        with self.assertRaisesRegex(ValueError, "Input masks must match."):
            m0 + m1
        expected = masked_tensor(torch.tensor([0., 2, 0, 6, 0], device=device), mask)
        _compare_mts(m0 + m0, expected)

    def test_softmax(self, device):
        """Softmax over the last dim matches the -inf-filled dense reference."""
        data = torch.randn((3, 4), device=device) * 0.1
        mask = torch.tensor(
            [
                [True, True, True, False],
                [False, True, False, True],
                [True, True, False, False],
            ],
            device=device
        )

        _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1))

    def test_where(self, device):
        """``torch.where`` on MaskedTensors matches the dense result, forward and backward."""
        data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device)
        mask = data < 0

        # mx is visible where the condition holds, my where it does not.
        mx = masked_tensor(data, mask, requires_grad=True)
        my = masked_tensor(torch.ones_like(data), ~mask, requires_grad=True)
        masked_res = torch.where(mask, torch.exp(mx), my)
        masked_res.sum().backward()

        x = data.detach().clone().requires_grad_()
        y = torch.ones_like(x, device=device, requires_grad=True)
        tensor_res = torch.where(mask, torch.exp(x), y)
        tensor_res.sum().backward()

        _compare_mt_t(masked_res, tensor_res)
        _compare_mt_t(mx.grad, x.grad)
        _compare_mt_t(my.grad, y.grad)

    def test_unfold(self, device):
        """``Tensor.unfold`` matches the dense reference, forward and backward."""
        data = torch.rand(5, 5, device=device)
        mask = torch.rand(5, 5, device=device) > 0.5
        _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2))

    def test_nn_unfold(self, device):
        """``torch.nn.functional.unfold`` matches the dense reference, forward and backward."""
        data = torch.rand(2, 5, 3, 4, device=device)
        mask = torch.rand(2, 5, 3, 4, device=device) > 0.5
        _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3)))

    def test_stack(self, device):
        """``torch.stack`` of MaskedTensors matches stacking their data, forward and backward."""
        masked_tensors = [
            masked_tensor(
                torch.rand(2, 5, 3, 4, device=device),
                torch.rand(2, 5, 3, 4, device=device) > 0.5,
                requires_grad=True,
            ) for _ in range(3)
        ]

        # Independent leaf copies of the underlying data serve as the reference.
        data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors]
        masked_res = torch.stack(masked_tensors)
        tensor_res = torch.stack(data_tensors)

        masked_res.sum().backward()
        tensor_res.sum().backward()
        _compare_mt_t(masked_res, tensor_res)
        for mt, t in zip(masked_tensors, data_tensors):
            _compare_mt_t(mt.grad, t.grad, atol=1e-06)

    def test_to_sparse(self, device):
        """``to_sparse`` on a strided MaskedTensor matches the dense data, forward and backward."""
        for sample in _generate_sample_data(device=device):
            data = sample.input
            mask = sample.kwargs["mask"]
            mt = masked_tensor(data.clone().detach(), mask, requires_grad=True)

            sparse_mt = mt.to_sparse()
            data.to_sparse().to_dense().sum().backward()
            sparse_mt.to_dense().sum().backward()

            _compare_mt_t(sparse_mt, data)
            _compare_mt_t(mt.grad, data.grad)

    def test_to_dense(self, device):
        """``to_dense`` on sparse COO/CSR MaskedTensors matches densifying the data."""
        samples = _generate_sample_data(
            device=device,
            layout=torch.sparse_coo
        ) + _generate_sample_data(device=device, layout=torch.sparse_csr)
        for sample in samples:
            data = sample.input
            mask = sample.kwargs["mask"]
            mt = masked_tensor(data, mask, requires_grad=True)

            dense_data = data.to_dense().detach().clone().requires_grad_(True)
            dense_mt = mt.to_dense()
            dense_data.sum().backward()
            dense_mt.sum().backward()

            _compare_mt_t(dense_mt, dense_data)
            _compare_mt_t(mt.grad.to_dense(), dense_data.grad)

    def test_to_dense_and_sparse_coo(self, device):
        """A strided -> COO -> dense round trip agrees with a COO-built MaskedTensor densified."""
        for sample in _generate_sample_data(device=device, layout=torch.strided):
            data = sample.input
            mask = sample.kwargs["mask"]
            ms = mask.to_sparse_coo().coalesce()

            mt = masked_tensor(data, mask, requires_grad=True)
            mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True)

            converted = mt.to_sparse().to_dense()
            converted.sum().backward()

            converted2 = mts.to_dense()
            converted2.sum().backward()

            _compare_mts(converted, converted2)
            _compare_mts(mt.grad, mts.grad.to_dense())

    def test_to_dense_and_sparse_csr(self, device):
        """A strided -> CSR -> dense round trip agrees with a CSR-built MaskedTensor densified."""
        for sample in _generate_sample_data(device=device, layout=torch.strided):
            data = sample.input
            mask = sample.kwargs["mask"]
            # Only the 2-D samples are exercised with the CSR layout.
            if data.ndim != 2:
                continue
            ms = mask.to_sparse_csr()

            mt = masked_tensor(data, mask, requires_grad=True)
            mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True)

            converted = mt.to_sparse_csr().to_dense()
            converted.sum().backward()

            converted2 = mts.to_dense()
            converted2.sum().backward()

            _compare_mts(converted, converted2)
            _compare_mts(mt.grad, mts.grad.to_dense())

    def test_invalid_sparse_layout(self, device):
        """Sparse CSC data is rejected with a TypeError."""
        data = torch.randn((3, 4), device=device).to_sparse_csc()
        mask = _create_random_mask((3, 4), device=device).to_sparse_csc()
        with self.assertRaisesRegex(TypeError, "data layout of torch.sparse_csc is not supported"):
            masked_tensor(data, mask)

    def test_invalid_sparse_coo_values(self, device):
        """COO data and mask whose indices differ are rejected."""
        v = torch.tensor([3, 4, 5], dtype=torch.float32)
        i1 = torch.tensor([[0, 1, 1], [2, 0, 2]])
        i2 = torch.tensor([[0, 1, 1], [2, 1, 2]])

        t = torch.sparse_coo_tensor(i1, v, (2, 4), device=device)
        mask = torch.sparse_coo_tensor(i2, torch.tensor([True, True, True]), (2, 4), device=device)

        msg = "data and mask are both sparse COO tensors but do not have the same indices."
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t, mask)

    def test_invalid_sparse_csr_values(self, device):
        """CSR data and mask that differ in crow or col indices are rejected."""
        crow_indices1 = [0, 2, 3]
        crow_indices2 = [0, 1, 3]
        col_indices1 = [0, 1, 2]
        col_indices2 = [1, 2, 3]

        values = [2, 3, 4]
        mask_values = [True, True, True]

        # All tensors are created on `device`, mirroring the COO test, so the
        # check runs on the device the test was instantiated for.
        # Pair 1: same col indices, different crow indices.
        t1 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices1, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(values),
            size=(2, 4),
            device=device,
        )
        mask1 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(mask_values),
            dtype=torch.bool,
            size=(2, 4),
            device=device,
        )
        # Pair 2: same crow indices, different col indices.
        t2 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(values),
            size=(2, 4),
            device=device,
        )
        mask2 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices2, dtype=torch.int64),
            torch.tensor(mask_values),
            dtype=torch.bool,
            size=(2, 4),
            device=device,
        )

        msg = "data and mask are both sparse CSR tensors but do not share either crow or col indices."
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t1, mask1)
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t2, mask2)

    def test_contiguous(self, device):
        """``is_contiguous`` / ``contiguous`` work for strided data and raise for sparse data."""
        data = torch.randn((3, 3), device=device)

        contiguous_data = data.clone()
        mask1 = (contiguous_data > 0).bool()
        # Strides (1, 2) over a 2x2 window produce a non-contiguous view.
        not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
        mask2 = (not_contiguous_data > 0).bool()

        contiguous_mt = masked_tensor(contiguous_data, mask1)
        not_contiguous_mt = masked_tensor(not_contiguous_data, mask2)

        contiguous_mt_sparse = masked_tensor(
            contiguous_data.to_sparse_coo(), mask1.to_sparse_coo()
        )
        not_contiguous_mt_sparse = masked_tensor(
            not_contiguous_data.to_sparse_coo(), mask2.to_sparse_coo()
        )

        self.assertTrue(contiguous_data.is_contiguous())
        self.assertFalse(not_contiguous_data.is_contiguous())

        self.assertTrue(contiguous_mt.is_contiguous())
        self.assertFalse(not_contiguous_mt.is_contiguous())

        error_msg = "MaskedTensors with sparse data do not have is_contiguous"
        for t in [contiguous_mt_sparse, not_contiguous_mt_sparse]:
            with self.assertRaisesRegex(ValueError, error_msg):
                t.is_contiguous()
            with self.assertRaisesRegex(ValueError, error_msg):
                t.contiguous()

        now_contiguous_mt = not_contiguous_mt.contiguous()

        _compare_mts(not_contiguous_mt, now_contiguous_mt)

        # Check the wrapper, its data and its mask (the last check used to be
        # a verbatim repeat of the first one).
        self.assertTrue(now_contiguous_mt.is_contiguous())
        self.assertTrue(now_contiguous_mt.get_data().is_contiguous())
        self.assertTrue(now_contiguous_mt.get_mask().is_contiguous())
396
class TestUnary(TestCase):
    """Compare native unary ops applied to MaskedTensors with the plain-Tensor results."""

    def _get_test_data(self, fn_name):
        """Return a random ``(data, mask)`` pair adjusted to the input domain of ``fn_name``."""
        data = torch.randn(10, 10)
        mask = torch.rand(10, 10) > 0.5
        fn_name = _fix_fn_name(fn_name)
        if fn_name in ("log", "log10", "log1p", "log2", "sqrt"):
            data = data.mul(0.5).abs()
        if fn_name == "rsqrt":
            data = data.abs() + 1  # Avoid division by zero
        if fn_name in ("acos", "arccos", "asin", "arcsin", "logit"):
            data = data.abs().mul(0.5).clamp(0, 1)
        if fn_name in ("atanh", "arctanh", "erfinv"):
            data = data.mul(0.5).clamp(-1, 1)
        if fn_name in ("acosh", "arccosh"):
            data = data.abs() + 1
        if fn_name == "bitwise_not":
            data = data.mul(128).to(torch.int8)
        return data, mask

    def _get_sample_kwargs(self, fn_name):
        """Return the keyword arguments ``fn_name`` needs (bounds for clamp/clip, else none)."""
        fn_name = _fix_fn_name(fn_name)
        kwargs = {}
        if fn_name in ("clamp", "clip"):
            kwargs["min"] = -0.5
            kwargs["max"] = 0.5
        return kwargs

    def _get_sample_args(self, fn_name, data, mask):
        """Return ``(tensor_args, masked_tensor_args)`` positional argument lists for ``fn_name``."""
        fn_name = _fix_fn_name(fn_name)
        mt = masked_tensor(data, mask)
        t_args = [data]
        mt_args = [mt]
        if fn_name == "pow":
            # pow needs an exponent as its second positional argument.
            t_args += [2.0]
            mt_args += [2.0]
        return t_args, mt_args

    def _check_against_tensor(self, fn):
        """Apply ``fn`` to a MaskedTensor and to its plain data and compare the results."""
        torch.random.manual_seed(0)
        fn_name = fn.__name__
        data, mask = self._get_test_data(fn_name)
        kwargs = self._get_sample_kwargs(fn_name)

        t_args, mt_args = self._get_sample_args(fn_name, data, mask)

        mt_result = fn(*mt_args, **kwargs)
        t_result = fn(*t_args, **kwargs)
        _compare_mt_t(mt_result, t_result)

    @parametrize("fn", NATIVE_UNARY_FNS)
    def test_unary(self, fn):
        """Every function in NATIVE_UNARY_FNS agrees with the plain-Tensor result."""
        self._check_against_tensor(fn)

    @parametrize("fn", NATIVE_INPLACE_UNARY_FNS)
    def test_inplace_unary(self, fn):
        """Every function in NATIVE_INPLACE_UNARY_FNS agrees with the plain-Tensor result."""
        self._check_against_tensor(fn)
459
class TestBinary(TestCase):
    """Compare native binary ops applied to MaskedTensors with the plain-Tensor results."""

    def _get_test_data(self, fn_name):
        """Return random ``(data0, data1, mask)`` with dtypes suited to ``fn_name``."""
        fn_name = _fix_fn_name(fn_name)
        data0 = torch.randn(10, 10)
        data1 = torch.randn(10, 10)
        mask = torch.rand(10, 10) > 0.5
        if fn_name in ("bitwise_and", "bitwise_or", "bitwise_xor"):
            data0 = data0.mul(128).to(torch.int8)
            data1 = data1.mul(128).to(torch.int8)
        if fn_name in ("bitwise_left_shift", "bitwise_right_shift"):
            data0 = data0.abs().to(torch.int64)
            data1 = data1.abs().to(torch.int64)
        return data0, data1, mask

    def _get_sample_kwargs(self, fn_name):
        """Return the keyword arguments for ``fn_name``; currently always empty."""
        return {}

    def _yield_sample_args(self, fn_name, data0, data1, mask):
        """ Returns two sets of Tensor and MaskedTensor args for a binary function to compute.
            Tensor args are all the same (just the two provided data tensors),
            while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor)
        """
        mt0 = masked_tensor(data0, mask)
        mt1 = masked_tensor(data1, mask)

        # (MaskedTensor, MaskedTensor)
        yield [data0, data1], [mt0, mt1]

        # (MaskedTensor, Tensor)
        yield [data0, data1], [mt0, data1]

    def _check_against_tensor(self, fn):
        """Apply ``fn`` to each sample argument set and compare masked vs. plain results."""
        torch.random.manual_seed(0)
        fn_name = fn.__name__
        data0, data1, mask = self._get_test_data(fn_name)
        kwargs = self._get_sample_kwargs(fn_name)

        for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask):
            mt_result = fn(*mt_args, **kwargs)
            t_result = fn(*t_args, **kwargs)
            _compare_mt_t(mt_result, t_result)

    @parametrize("fn", NATIVE_BINARY_FNS)
    def test_binary(self, fn):
        """Every function in NATIVE_BINARY_FNS agrees with the plain-Tensor result."""
        self._check_against_tensor(fn)

    @parametrize("fn", NATIVE_INPLACE_BINARY_FNS)
    def test_inplace_binary(self, fn):
        """Every function in NATIVE_INPLACE_BINARY_FNS agrees with the plain-Tensor result."""
        self._check_against_tensor(fn)

    @parametrize("fn_name", ["add", "add_"])
    def test_masks_match(self, fn_name):
        """A binary op on MaskedTensors with different masks raises ValueError with the exact message."""
        torch.random.manual_seed(0)
        fn = getattr(torch.ops.aten, fn_name)
        data0, data1, mask = self._get_test_data(fn_name)
        mask0 = mask
        mask1 = torch.rand(mask.size()) > 0.5
        mt0 = masked_tensor(data0, mask0)
        mt1 = masked_tensor(data1, mask1)
        # assertRaises fails the test if nothing is raised and, unlike a bare
        # `assert`, the message check below is not stripped under `python -O`.
        with self.assertRaises(ValueError) as ctx:
            fn(mt0, mt1)
        self.assertEqual(
            "Input masks must match. If you need support for this, please open an issue on Github.",
            str(ctx.exception),
        )
537
538class TestReductions(TestCase):
539    def test_max_not_implemented(self):
540        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
541        m = torch.tensor([[True, False, False], [False, True, False]])
542        mt = masked_tensor(d, m)
543        with self.assertRaisesRegex(TypeError, "torch._ops.aten.max.default"):
544            mt.max()
545
546    def test_sum(self):
547        d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]])
548        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
549        mt = masked_tensor(d, m)
550        _compare_mts(masked_tensor(torch.tensor(17.0), torch.tensor(True)), mt.sum())
551        _compare_mts(
552            masked_tensor(
553                torch.tensor([0.0, 4.0, 1.0, 13]),
554                torch.tensor([True, True, False, True]),
555            ),
556            mt.sum(dim=0),
557        )
558
559    def test_sum_grad(self):
560        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
561        m = torch.tensor([[True, False, False], [False, True, False]])
562        mt = masked_tensor(d, m, requires_grad=True)
563        mt.sum().backward()
564        _compare_mts(mt.grad, masked_tensor(torch.tensor(1.0).expand_as(m), m))
565
566    def test_mean(self):
567        d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]])
568        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
569        mt = masked_tensor(d, m)
570        _compare_mts(masked_tensor(torch.tensor(2.5), torch.tensor(True)), mt.mean())
571        _compare_mts(
572            masked_tensor(
573                torch.tensor([0.0, 4.0, 1.0, 3]),
574                torch.tensor([True, True, False, True]),
575            ),
576            mt.mean(dim=0),
577        )
578
    """
        The following block of tests "test_mean_grad_case_1[a through f]" is used to test the functionality of
        the two different ways of constructing MaskedTensors:
            masked_tensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf
            as_masked_tensor(data, mask) -- differentiable constructor

        Like torch.tensor(data), masked_tensor(data, mask) will provide a UserWarning if data.requires_grad=True
        as_masked_tensor does not take in requires_grad -- it just takes on the requires_grad from data

        Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations

        Assuming mt.mean().backward() is run after each constructor:

        Case 1a:
            values.requires_grad = True
            mt = masked_tensor(values, mask, requires_grad=True)
        yields
            - Provide a UserWarning because values.requires_grad=True
            - values.grad = None
            - mt.grad is a MaskedTensor with the correct gradient

        Case 1b:
            values.requires_grad = False
            mt = masked_tensor(values, mask, requires_grad=True)
        yields
            - values.grad = None
            - mt.grad is a MaskedTensor with the correct gradient

        Case 1c/1d:
            values.requires_grad = True/False
            mt = masked_tensor(values, mask, requires_grad=False)

            will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
            as expected. When values.requires_grad=True, we will also get a UserWarning

        Case 1e:
            values.requires_grad = True
            mt = as_masked_tensor(values, mask)
        yields
            - values.grad is a MaskedTensor with the correct gradient
            - mt.grad is None and gives a UserWarning that
              "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"

        Case 1f:
            values.requires_grad = False
            mt = as_masked_tensor(values, mask)

            will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
            as expected.
    """
629    def test_mean_grad_case_1a(self):
630        """ values.requires_grad = True
631            mt = masked_tensor(values, mask, requires_grad=True)
632        """
633        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
634        m = torch.tensor([[True, False, False], [False, True, False]])
635        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
636            mt = masked_tensor(d, m, requires_grad=True)
637        mt.mean().backward()
638        self.assertIsNone(d.grad)
639        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
640
641    def test_mean_grad_case_1b(self):
642        """ values.requires_grad = False
643            mt = masked_tensor(values, mask, requires_grad=True)
644        """
645        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
646        m = torch.tensor([[True, False, False], [False, True, False]])
647        mt = masked_tensor(d, m, requires_grad=True)
648        mt.mean().backward()
649        self.assertIsNone(d.grad)
650        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
651
652    def test_mean_grad_case_1c(self):
653        """ values.requires_grad = True
654            mt = masked_tensor(values, mask, requires_grad=False)
655        """
656        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
657        m = torch.tensor([[True, False, False], [False, True, False]])
658        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
659            mt = masked_tensor(d, m, requires_grad=False)
660        result = mt.mean()
661        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
662        with self.assertRaisesRegex(RuntimeError, msg):
663            result.backward()
664
665
666    def test_mean_grad_case_1d(self):
667        """ values.requires_grad = False
668            mt = masked_tensor(values, mask, requires_grad=False)
669        """
670        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
671        m = torch.tensor([[True, False, False], [False, True, False]])
672        mt = masked_tensor(d, m, requires_grad=False)
673        result = mt.mean()
674        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
675        with self.assertRaisesRegex(RuntimeError, msg):
676            result.backward()
677
678    def test_mean_grad_case_1e(self):
679        """ values.requires_grad = True
680            mt = as_masked_tensor(values, mask)
681        """
682        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
683        m = torch.tensor([[True, False, False], [False, True, False]])
684        mt = as_masked_tensor(d, m)
685        mt.mean().backward()
686        _compare_mts(d.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
687        msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
688        with self.assertWarnsRegex(UserWarning, msg):
689            self.assertIsNone(mt.grad)
690
691    def test_mean_grad_case_1f(self):
692        """ values.requires_grad = False
693            mt = as_masked_tensor(values, mask)
694        """
695        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
696        m = torch.tensor([[True, False, False], [False, True, False]])
697        mt = as_masked_tensor(d, m)
698        result = mt.mean()
699        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
700        with self.assertRaisesRegex(RuntimeError, msg):
701            result.backward()
702
703    def test_mean_dim_grad(self):
704        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
705        m = torch.tensor([[True, True, False], [False, True, False]])
706        mt = masked_tensor(d, m, requires_grad=True)
707        mt.mean(1).sum().backward()
708        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m))
709
710    def test_amax(self):
711        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
712        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
713        mt = masked_tensor(d, m)
714        _compare_mts(masked_tensor(torch.tensor(3.0), torch.tensor(True)), mt.amax())
715        _compare_mts(
716            masked_tensor(
717                torch.tensor([0.0, -4.0, 1.0, 3]),
718                torch.tensor([True, True, False, True]),
719            ),
720            mt.amax(dim=0),
721        )
722
723    def test_amax_grad(self):
724        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
725        m = torch.tensor([[True, False, False], [False, True, False]])
726        mt = masked_tensor(d, m, requires_grad=True)
727        mt.amax().backward()
728        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m))
729
730    def test_amin(self):
731        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
732        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
733        mt = masked_tensor(d, m)
734        _compare_mts(masked_tensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin())
735        _compare_mts(
736            masked_tensor(
737                torch.tensor([0.0, -4.0, 1.0, -3]),
738                torch.tensor([True, True, False, True]),
739            ),
740            mt.amin(dim=0),
741        )
742
743    def test_amin_grad(self):
744        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
745        m = torch.tensor([[True, False, False], [False, True, False]])
746        mt = masked_tensor(d, m, requires_grad=True)
747        mt.amin().backward()
748        _compare_mts(mt.grad, masked_tensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m))
749
750    def test_prod(self):
751        d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]])
752        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
753        mt = masked_tensor(d, m)
754        _compare_mts(masked_tensor(torch.tensor(0.0), torch.tensor(True)), mt.prod())
755        _compare_mts(
756            masked_tensor(
757                torch.tensor([0.0, 4.0, 1.0, 0.0]),
758                torch.tensor([True, True, False, True]),
759            ),
760            mt.prod(dim=0),
761        )
762
763    def test_prod_grad(self):
764        d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]])
765        m = torch.tensor([[True, False, False], [False, True, False]])
766        mt = masked_tensor(d, m, requires_grad=True)
767        mt.prod().backward()
768        _compare_mts(mt.grad, masked_tensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m))
769
770    def test_all(self):
771        d = torch.tensor([[True, True, False, False], [False, True, True, True]])
772        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
773        mt = masked_tensor(d, m)
774        _compare_mts(masked_tensor(torch.tensor(False), torch.tensor(True)), mt.all())
775        _compare_mts(
776            masked_tensor(
777                torch.tensor([True, True, True, False]),
778                torch.tensor([True, True, False, True]),
779            ),
780            mt.all(dim=0),
781        )
782
783        m = torch.tensor([[True, False, True, False], [False, True, False, False]])
784        mt = masked_tensor(d, m)
785        _compare_mts(
786            masked_tensor(
787                torch.tensor([True, True, False, True]),
788                torch.tensor([True, True, True, False]),
789            ),
790            mt.all(dim=0),
791        )
792
793    def test_grad_dtype(self):
794        d = torch.tensor([[True, True, False], [False, True, True]])
795        m = torch.tensor([[True, False, False], [False, True, False]])
796        msg = "Only Tensors of floating point and complex dtype can require gradients"
797        with self.assertRaisesRegex(RuntimeError, msg):
798            masked_tensor(d, m, requires_grad=True)
799
800    def test_any_true_dtype(self):
801        mt = torch.masked.MaskedTensor(
802            torch.rand(2, 2),
803            torch.rand(2, 2) > 0.5
804        )
805        msg = "expected a boolean tensor"
806        with self.assertRaisesRegex(ValueError, msg):
807            mt._is_any_true()
808
809    def test__is_any_true(self):
810        mt = torch.masked.MaskedTensor(
811            torch.tensor([[True, True, False], [False, False, True]]),
812            torch.tensor([[True, False, False], [False, True, False]]),
813        )
814        _compare_mts(
815            masked_tensor(torch.tensor(True), torch.tensor(True)),
816            mt._is_any_true(),
817        )
818
819    def test__is_any_true_false(self):
820        mt = torch.masked.MaskedTensor(
821            torch.tensor([[True, True, False], [False, False, True]]),
822            torch.tensor([[False, False, False], [False, False, False]]),
823        )
824        _compare_mts(
825            masked_tensor(torch.tensor(False), torch.tensor(True),),
826            mt._is_any_true(),
827        )
828
829    def test_backward(self):
830        # See https://github.com/pytorch/pytorch/issues/128557
831        with torch.autograd.detect_anomaly():
832            mt = torch.masked.MaskedTensor(
833                torch.rand(2, 2),
834                torch.rand(2, 2) > 0.5,
835                requires_grad=True
836            )
837            mt.sum().backward()
838
839
def is_unary(op):
    """Return True when ``op.name`` is listed in ``UNARY_NAMES``."""
    op_name = op.name
    return op_name in UNARY_NAMES
842
def is_binary(op):
    """Return True when ``op.name`` is listed in ``BINARY_NAMES``."""
    op_name = op.name
    return op_name in BINARY_NAMES
845
def is_reduction(op):
    """Return True when ``op.name`` is in ``REDUCE_NAMES`` and not explicitly excluded.

    "all", "mean", "std" and "var" are always rejected.
    NOTE(review): the reason for excluding these four is not stated in this
    file — confirm before changing the list.
    """
    excluded = {"all", "mean", "std", "var"}
    if op.name in excluded:
        return False
    return op.name in REDUCE_NAMES
848
# Narrow the generic op lists down to the entries accepted by the is_unary /
# is_binary / is_reduction predicates.
mt_unary_ufuncs = list(filter(is_unary, unary_ufuncs))
mt_binary_ufuncs = list(filter(is_binary, binary_ufuncs))
mt_reduction_ufuncs = list(filter(is_reduction, reduction_ops))

# Floating-point dtypes passed as ``allowed_dtypes`` to the ``@ops`` decorator.
MASKEDTENSOR_FLOAT_TYPES = {torch.float16, torch.float32, torch.float64}
858
class TestOperators(TestCase):
    """Checks that ops on MaskedTensors agree with the same ops on plain tensors.

    Each test applies an op to a MaskedTensor and to the corresponding regular
    tensor and compares the two results with ``_compare_mt_t``. The tests are
    parametrized over the strided, sparse COO and sparse CSR layouts.
    """

    def _convert_mt_args(self, args, mask, layout):
        """Wrap every tensor in ``args`` as a MaskedTensor that uses ``mask``.

        For non-strided layouts the tensor is first passed through
        ``sparse_mask(mask)``. Non-tensor arguments are returned unchanged.
        """
        return [
            masked_tensor(
                arg.sparse_mask(mask) if layout != torch.strided else arg, mask
            )
            if torch.is_tensor(arg)
            else arg
            for arg in args
        ]

    def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided):
        """Compare ``op`` on MaskedTensor inputs with ``op`` on the plain sample inputs.

        Args:
            device: device passed to ``op.sample_inputs`` and the random mask helper.
            dtype: dtype passed to ``op.sample_inputs``.
            op: the unary or binary operator under test.
            layout: layout used for the MaskedTensor data and mask
                (``torch.strided``, ``torch.sparse_coo`` or ``torch.sparse_csr``).

        Raises:
            ValueError: raised by ``_compare_mt_t`` when the results disagree.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            data = sample.input
            sample_args, sample_kwargs = sample.args, sample.kwargs
            # Use the sample's own mask when it supplies one, otherwise a random one.
            mask = (
                _create_random_mask(data.shape, device)
                if "mask" not in sample_kwargs
                else sample_kwargs.pop("mask")
            )

            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                data = data.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                # Only 2-D samples are exercised for the CSR layout.
                if data.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                data = data.sparse_mask(mask)

            if is_binary(op):
                # Binary operations currently only support same size masks
                if data.shape != sample_args[0].shape:
                    continue
                # Binary operations also don't support kwargs right now
                sample_kwargs = {}

            mt = masked_tensor(data, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            # The reference result is computed on the original dense sample input.
            t_result = op(sample.input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

            # If the operation is binary, check that lhs = masked, rhs = regular tensor also works
            if is_binary(op) and layout == torch.strided:
                mt_result2 = op(mt, *sample_args, **sample_kwargs)
                _compare_mt_t(mt_result2, t_result)

    def _test_reduction_equality(self, device, dtype, op, layout=torch.strided):
        """Compare a reduction on a MaskedTensor with the same reduction on a plain tensor.

        The plain-tensor input comes from ``_combine_input_and_mask(op.op, data, mask)``.
        Zero-dimensional or empty samples, and masks with no True entry, are skipped.

        Args:
            device: device passed to ``op.sample_inputs`` and the random mask helper.
            dtype: dtype passed to ``op.sample_inputs``.
            op: the reduction operator under test.
            layout: layout used for the MaskedTensor data and mask.

        Raises:
            ValueError: raised by ``_compare_mt_t`` when the results disagree.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            data = sample.input
            # Reduction operations don't support more advanced args/kwargs right now
            sample_args, sample_kwargs = (), {}

            if data.dim() == 0 or data.numel() == 0:
                continue

            mask = _create_random_mask(data.shape, device)

            if torch.count_nonzero(mask) == 0:
                continue

            # Build the dense reference input before the data/mask are made sparse.
            tensor_input = _combine_input_and_mask(op.op, data, mask)
            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                data = data.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                # Only 2-D samples are exercised for the CSR layout.
                if data.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                data = data.sparse_mask(mask)

            mt = masked_tensor(data, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            t_result = op(tensor_input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

    @ops(mt_unary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_unary_core(self, device, dtype, op, layout):
        """Unary ops on MaskedTensors match plain tensors for every parametrized layout."""
        # Skip tests that don't have len(kwargs) == 0
        skip_variants = {
            "decimals_0",
            "decimals_3",
            "decimals_neg_3",
        }
        if op.name == "round" and op.variant_test_name in skip_variants:
            return
        # Forward `layout`; without it every parametrization silently re-ran
        # the strided case and the sparse layouts were never exercised.
        self._test_unary_binary_equality(device, dtype, op, layout)

    @ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    # FIXME:
    # Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "add" and
            params["dtype"] in [torch.float16, torch.float32] and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    # Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "sub" and
            params["dtype"] in [torch.float16, torch.float32] and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    # Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "eq" and
            params["dtype"] == torch.float64 and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    def test_binary_core(self, device, dtype, op, layout):
        """Binary ops on MaskedTensors match plain tensors for every parametrized layout."""
        self._test_unary_binary_equality(device, dtype, op, layout)

    @ops(mt_reduction_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_reduction_all(self, device, dtype, op, layout):
        """Reductions on MaskedTensors match plain tensors for every parametrized layout."""
        # argmin and argmax are not currently supported for torch.sparse_csr
        if op.name in {"argmin", "argmax"} and layout == torch.sparse_csr:
            return

        self._test_reduction_equality(device, dtype, op, layout)
1003
1004
# Device types handed to instantiate_device_type_tests via ``only_for``.
only_for = ("cpu", "cuda")
# TestOperators and TestBasics take a ``device`` argument, so they go through
# the device-type instantiation helper.
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)

instantiate_device_type_tests(TestBasics, globals(), only_for=only_for)
# The remaining suites only need their @parametrize decorators expanded.
instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)
instantiate_parametrized_tests(TestReductions)

# Run the suite when this file is executed directly.
if __name__ == '__main__':
    run_tests()
1015