xref: /aosp_15_r20/external/pytorch/test/test_maskedtensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: masked operators"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerimport unittest
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
6*da0073e9SAndroid Build Coastguard Worker    decorateIf,
7*da0073e9SAndroid Build Coastguard Worker    TestCase,
8*da0073e9SAndroid Build Coastguard Worker    run_tests,
9*da0073e9SAndroid Build Coastguard Worker    make_tensor,
10*da0073e9SAndroid Build Coastguard Worker    parametrize,
11*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
12*da0073e9SAndroid Build Coastguard Worker)
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import (
14*da0073e9SAndroid Build Coastguard Worker    instantiate_device_type_tests,
15*da0073e9SAndroid Build Coastguard Worker    ops,
16*da0073e9SAndroid Build Coastguard Worker)
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import (
18*da0073e9SAndroid Build Coastguard Worker    SampleInput,
19*da0073e9SAndroid Build Coastguard Worker    binary_ufuncs,
20*da0073e9SAndroid Build Coastguard Worker    reduction_ops,
21*da0073e9SAndroid Build Coastguard Worker    unary_ufuncs,
22*da0073e9SAndroid Build Coastguard Worker)
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Workerfrom torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask
25*da0073e9SAndroid Build Coastguard Workerfrom torch.masked.maskedtensor.core import _masks_match, _tensors_match
26*da0073e9SAndroid Build Coastguard Workerfrom torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES
27*da0073e9SAndroid Build Coastguard Workerfrom torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES
28*da0073e9SAndroid Build Coastguard Workerfrom torch.masked.maskedtensor.reductions import REDUCE_NAMES
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Workerdef _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05):
32*da0073e9SAndroid Build Coastguard Worker    mask = mt_result.get_mask()
33*da0073e9SAndroid Build Coastguard Worker    mt_result_data = mt_result.get_data()
34*da0073e9SAndroid Build Coastguard Worker    if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
35*da0073e9SAndroid Build Coastguard Worker        mask = mask.to_dense()
36*da0073e9SAndroid Build Coastguard Worker    if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}:
37*da0073e9SAndroid Build Coastguard Worker        mt_result_data = mt_result_data.to_dense()
38*da0073e9SAndroid Build Coastguard Worker    a = mt_result_data.detach().masked_fill_(~mask, 0)
39*da0073e9SAndroid Build Coastguard Worker    b = t_result.detach().masked_fill_(~mask, 0)
40*da0073e9SAndroid Build Coastguard Worker    if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
41*da0073e9SAndroid Build Coastguard Worker        raise ValueError("The data in MaskedTensor a and Tensor b do not match")
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Workerdef _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08):
44*da0073e9SAndroid Build Coastguard Worker    mt_data1 = mt1.get_data()
45*da0073e9SAndroid Build Coastguard Worker    mt_data2 = mt2.get_data()
46*da0073e9SAndroid Build Coastguard Worker    if mt_data1.layout != mt_data2.layout:
47*da0073e9SAndroid Build Coastguard Worker        raise ValueError("mt1's data and mt2's data do not have the same layout. "
48*da0073e9SAndroid Build Coastguard Worker                         f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}")
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker    mask = mt1.get_mask()
51*da0073e9SAndroid Build Coastguard Worker    mask2 = mt2.get_mask()
52*da0073e9SAndroid Build Coastguard Worker    if not _masks_match(mt1, mt2):
53*da0073e9SAndroid Build Coastguard Worker        raise ValueError("mt1 and mt2 must have matching masks")
54*da0073e9SAndroid Build Coastguard Worker    if mask.layout != mask2.layout:
55*da0073e9SAndroid Build Coastguard Worker        raise ValueError("mt1's mask and mt2's mask do not have the same layout. "
56*da0073e9SAndroid Build Coastguard Worker                         f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}")
57*da0073e9SAndroid Build Coastguard Worker    if mask.layout in {torch.sparse_coo, torch.sparse_csr}:
58*da0073e9SAndroid Build Coastguard Worker        mask = mask.to_dense()
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker    if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}:
61*da0073e9SAndroid Build Coastguard Worker        mt_data1 = mt_data1.to_dense()
62*da0073e9SAndroid Build Coastguard Worker        mt_data2 = mt_data2.to_dense()
63*da0073e9SAndroid Build Coastguard Worker    a = mt_data1.detach().masked_fill_(~mask, 0)
64*da0073e9SAndroid Build Coastguard Worker    b = mt_data2.detach().masked_fill_(~mask, 0)
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker    if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol):
67*da0073e9SAndroid Build Coastguard Worker        raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match")
68*da0073e9SAndroid Build Coastguard Worker
def _compare_forward_backward(data, mask, fn):
    """Run ``fn`` on a MaskedTensor and on a dense reference, compare both passes.

    The dense reference fills masked-out positions with ``-inf`` (suitable for
    ops like softmax). Forward outputs must match, and so must the gradients
    of the two inputs after backpropagating through ``sum()``.
    """
    masked_input = masked_tensor(data, mask, requires_grad=True)
    masked_out = fn(masked_input)
    masked_out.sum().backward()

    dense_input = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_()
    dense_out = fn(dense_input)
    dense_out.sum().backward()

    _compare_mt_t(masked_out, dense_out)
    _compare_mt_t(masked_input.grad, dense_input.grad, atol=1e-06)
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Workerdef _create_random_mask(shape, device):
83*da0073e9SAndroid Build Coastguard Worker    return make_tensor(shape, device=device, dtype=torch.bool)
84*da0073e9SAndroid Build Coastguard Worker
def _generate_sample_data(
    device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided
):
    """Build SampleInputs (data plus a "mask" kwarg) over a fixed set of shapes.

    For ``torch.sparse_coo`` / ``torch.sparse_csr`` the mask is converted to
    that layout and the data is restricted to it via ``sparse_mask``.  CSR
    samples are only produced for 2-D shapes; other ranks are skipped.
    """
    assert layout in {
        torch.strided,
        torch.sparse_coo,
        torch.sparse_csr,
    }, "Layout must be strided/sparse_coo/sparse_csr"
    samples = []
    for shape in ([], [2], [3, 5], [3, 2, 1, 2]):
        data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)  # type: ignore[arg-type]
        mask = _create_random_mask(shape, device)
        if layout == torch.sparse_coo:
            mask = mask.to_sparse_coo().coalesce()
            # sparse_mask drops leaf status, so re-apply requires_grad.
            data = data.sparse_mask(mask).requires_grad_(requires_grad)
        elif layout == torch.sparse_csr:
            # CSR only supports 2-D tensors; skip every other rank.
            if data.ndim != 2 and mask.ndim != 2:
                continue
            mask = mask.to_sparse_csr()
            data = data.sparse_mask(mask)
        samples.append(SampleInput(data, kwargs={"mask": mask}))
    return samples
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Workerdef _fix_fn_name(fn_name):
115*da0073e9SAndroid Build Coastguard Worker    if fn_name[-1] == "_":
116*da0073e9SAndroid Build Coastguard Worker        fn_name = fn_name[:-1]
117*da0073e9SAndroid Build Coastguard Worker    return fn_name
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker
class TestBasics(TestCase):
    """Construction, validation, and basic-operation tests for MaskedTensor.

    NOTE(review): every test takes a ``device`` argument, so this class is
    presumably registered through ``instantiate_device_type_tests`` further
    down the file (not visible in this chunk) — confirm before relying on it.
    """

    def test_invalid_tensor_inputs(self, device):
        """masked_tensor rejects data/mask arguments that are not plain Tensors."""
        data = torch.randn((3, 4), device=device)
        mask = _create_random_mask((3, 4), device=device)
        mt = masked_tensor(data, mask)

        with self.assertRaisesRegex(TypeError, "data must be a Tensor"):
            masked_tensor(mt, mask)
        with self.assertRaisesRegex(TypeError, "data must be a Tensor"):
            masked_tensor(0, mask)
        with self.assertRaisesRegex(TypeError, "mask must be a Tensor"):
            masked_tensor(data, mt)
        with self.assertRaisesRegex(TypeError, "mask must be a Tensor"):
            masked_tensor(data, 0)

    def test_diff_layouts(self, device):
        """Sparse data with a strided mask is rejected."""
        data = torch.randn((3, 4), device=device).to_sparse_coo()
        mask = _create_random_mask((3, 4), device=device)
        with self.assertRaisesRegex(TypeError, "data and mask must have the same layout"):
            masked_tensor(data, mask)

    def test_diff_dim(self, device):
        """data and mask must have the same number of dimensions."""
        data = torch.randn((3, 4, 5), device=device)
        mask = _create_random_mask((3, 4), device=device)
        with self.assertRaisesRegex(ValueError, "data.dim\\(\\) must equal mask.dim\\(\\)"):
            masked_tensor(data, mask)

    def test_diff_sizes(self, device):
        """data and mask must have identical sizes."""
        data = torch.randn((3, 4), device=device)
        mask = _create_random_mask((3, 3), device=device)
        with self.assertRaisesRegex(ValueError, "data.size\\(\\) must equal mask.size\\(\\)"):
            masked_tensor(data, mask)

    def test_grad_warning(self, device):
        """Constructing from a requires_grad tensor emits a UserWarning."""
        data = torch.randn((3, 4), device=device, requires_grad=True)
        mask = _create_random_mask((3, 4), device=device)
        msg = "It is not recommended to create a MaskedTensor with a tensor that requires_grad."
        with self.assertWarnsRegex(UserWarning, msg):
            mt = masked_tensor(data, mask)

    def test_add(self, device):
        """Adding MaskedTensors requires matching masks; m0 + m0 doubles unmasked values."""
        data = torch.arange(5.0, device=device)
        mask = torch.tensor([True, True, False, True, False], device=device)
        m0 = masked_tensor(data, mask)
        m1 = masked_tensor(data, ~mask)
        with self.assertRaisesRegex(ValueError, "Input masks must match."):
            m0 + m1
        _compare_mts(m0 + m0, masked_tensor(torch.tensor([0., 2, 0, 6, 0], device=device), mask))

    def test_softmax(self, device):
        """Masked softmax matches dense softmax with -inf at masked positions (fwd + bwd)."""
        data = torch.randn((3, 4), device=device) * 0.1
        mask = torch.tensor(
            [
                [True, True, True, False],
                [False, True, False, True],
                [True, True, False, False],
            ],
            device=device
        )

        _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1))

    def test_where(self, device):
        """torch.where over MaskedTensors matches the dense result and gradients."""
        data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device)
        mask = data < 0

        mx = masked_tensor(data, mask, requires_grad=True)
        my = masked_tensor(torch.ones_like(data), ~mask, requires_grad=True)
        masked_res = torch.where(mask, torch.exp(mx), my)
        masked_res.sum().backward()

        x = data.detach().clone().requires_grad_()
        y = torch.ones_like(x, device=device, requires_grad=True)
        tensor_res = torch.where(mask, torch.exp(x), y)
        tensor_res.sum().backward()

        _compare_mt_t(masked_res, tensor_res)
        _compare_mt_t(mx.grad, x.grad)
        _compare_mt_t(my.grad, y.grad)

    def test_unfold(self, device):
        """Tensor.unfold forward/backward parity against the dense reference."""
        data = torch.rand(5, 5, device=device)
        mask = torch.rand(5, 5, device=device) > 0.5
        _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2))

    def test_nn_unfold(self, device):
        """torch.nn.functional.unfold forward/backward parity against the dense reference."""
        data = torch.rand(2, 5, 3, 4, device=device)
        mask = torch.rand(2, 5, 3, 4, device=device) > 0.5
        _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3)))

    def test_stack(self, device):
        """torch.stack over MaskedTensors matches stacking the raw data (fwd + bwd)."""
        masked_tensors = [
            masked_tensor(
                torch.rand(2, 5, 3, 4, device=device),
                torch.rand(2, 5, 3, 4, device=device) > 0.5,
                requires_grad=True,
            ) for _ in range(3)
        ]

        data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors]
        masked_res = torch.stack(masked_tensors)
        tensor_res = torch.stack(data_tensors)

        masked_res.sum().backward()
        tensor_res.sum().backward()
        _compare_mt_t(masked_res, tensor_res)
        for mt, t in zip(masked_tensors, data_tensors):
            _compare_mt_t(mt.grad, t.grad, atol=1e-06)

    def test_to_sparse(self, device):
        """to_sparse round-trips data and gradients for strided samples."""
        for sample in _generate_sample_data(device=device):
            data = sample.input
            mask = sample.kwargs["mask"]
            mt = masked_tensor(data.clone().detach(), mask, requires_grad=True)

            sparse_mt = mt.to_sparse()
            data.to_sparse().to_dense().sum().backward()
            sparse_mt.to_dense().sum().backward()

            _compare_mt_t(sparse_mt, data)
            _compare_mt_t(mt.grad, data.grad)

    def test_to_dense(self, device):
        """to_dense on sparse COO/CSR MaskedTensors matches dense references (fwd + bwd)."""
        samples = _generate_sample_data(
            device=device,
            layout=torch.sparse_coo
        ) + _generate_sample_data(device=device, layout=torch.sparse_csr)
        for sample in samples:
            data = sample.input
            mask = sample.kwargs["mask"]
            mt = masked_tensor(data, mask, requires_grad=True)

            dense_data = data.to_dense().detach().clone().requires_grad_(True)
            dense_mt = mt.to_dense()
            dense_data.sum().backward()
            dense_mt.sum().backward()

            _compare_mt_t(dense_mt, dense_data)
            _compare_mt_t(mt.grad.to_dense(), dense_data.grad)

    def test_to_dense_and_sparse_coo(self, device):
        """strided -> COO -> dense equals a MaskedTensor built sparse from the start."""
        for sample in _generate_sample_data(device=device, layout=torch.strided):
            data = sample.input
            mask = sample.kwargs["mask"]
            ms = mask.to_sparse_coo().coalesce()

            mt = masked_tensor(data, mask, requires_grad=True)
            mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True)

            converted = mt.to_sparse().to_dense()
            converted.sum().backward()

            converted2 = mts.to_dense()
            converted2.sum().backward()

            _compare_mts(converted, converted2)
            _compare_mts(mt.grad, mts.grad.to_dense())

    def test_to_dense_and_sparse_csr(self, device):
        """strided -> CSR -> dense equals a MaskedTensor built sparse from the start (2-D only)."""
        for sample in _generate_sample_data(device=device, layout=torch.strided):
            data = sample.input
            mask = sample.kwargs["mask"]
            if data.ndim != 2:
                continue
            ms = mask.to_sparse_csr()

            mt = masked_tensor(data, mask, requires_grad=True)
            mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True)

            converted = mt.to_sparse_csr().to_dense()
            converted.sum().backward()

            converted2 = mts.to_dense()
            converted2.sum().backward()

            _compare_mts(converted, converted2)
            _compare_mts(mt.grad, mts.grad.to_dense())

    def test_invalid_sparse_layout(self, device):
        """Sparse CSC layout is rejected by the constructor."""
        data = torch.randn((3, 4), device=device).to_sparse_csc()
        mask = _create_random_mask((3, 4), device=device).to_sparse_csc()
        with self.assertRaisesRegex(TypeError, "data layout of torch.sparse_csc is not supported"):
            masked_tensor(data, mask)

    def test_invalid_sparse_coo_values(self, device):
        """Sparse COO data and mask with different indices are rejected."""
        v = torch.tensor([3, 4, 5], dtype=torch.float32)
        i1 = torch.tensor([[0, 1, 1], [2, 0, 2]])
        i2 = torch.tensor([[0, 1, 1], [2, 1, 2]])

        t = torch.sparse_coo_tensor(i1, v, (2, 4), device=device)
        mask = torch.sparse_coo_tensor(i2, torch.tensor([True, True, True]), (2, 4), device=device)

        msg = "data and mask are both sparse COO tensors but do not have the same indices."
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t, mask)

    def test_invalid_sparse_csr_values(self, device):
        """Sparse CSR data and mask with differing crow or col indices are rejected."""
        crow_indices1 = [0, 2, 3]
        crow_indices2 = [0, 1, 3]
        col_indices1 = [0, 1, 2]
        col_indices2 = [1, 2, 3]

        values = [2, 3, 4]
        mask_values = [True, True, True]

        # t1/mask1 differ in crow indices; t2/mask2 differ in col indices.
        t1 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices1, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(values),
            size=(2, 4)
        )
        mask1 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(mask_values),
            dtype=torch.bool,
            size=(2, 4),
        )
        t2 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices1, dtype=torch.int64),
            torch.tensor(values),
            size=(2, 4),
        )
        mask2 = torch.sparse_csr_tensor(
            torch.tensor(crow_indices2, dtype=torch.int64),
            torch.tensor(col_indices2, dtype=torch.int64),
            torch.tensor(mask_values),
            dtype=torch.bool,
            size=(2, 4),
        )

        msg = "data and mask are both sparse CSR tensors but do not share either crow or col indices."
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t1, mask1)
        with self.assertRaisesRegex(ValueError, msg):
            masked_tensor(t2, mask2)

    def test_contiguous(self, device):
        """is_contiguous/contiguous work for strided data and raise for sparse data."""
        data = torch.randn((3, 3), device=device)

        contiguous_data = data.clone()
        mask1 = (contiguous_data > 0).bool()
        # as_strided with transposed-style strides yields a non-contiguous view.
        not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
        mask2 = (not_contiguous_data > 0).bool()

        contiguous_mt = masked_tensor(contiguous_data, mask1)
        not_contiguous_mt = masked_tensor(not_contiguous_data, mask2)

        contiguous_mt_sparse = masked_tensor(
            contiguous_data.to_sparse_coo(), mask1.to_sparse_coo()
        )
        not_contiguous_mt_sparse = masked_tensor(
            not_contiguous_data.to_sparse_coo(), mask2.to_sparse_coo()
        )

        self.assertEqual(contiguous_data.is_contiguous(), True)
        self.assertEqual(not_contiguous_data.is_contiguous(), False)

        self.assertEqual(contiguous_mt.is_contiguous(), True)
        self.assertEqual(not_contiguous_mt.is_contiguous(), False)

        error_msg = "MaskedTensors with sparse data do not have is_contiguous"
        for t in [contiguous_mt_sparse, not_contiguous_mt_sparse]:
            with self.assertRaisesRegex(ValueError, error_msg):
                t.is_contiguous()
            with self.assertRaisesRegex(ValueError, error_msg):
                t.contiguous()

        now_contiguous_mt = not_contiguous_mt.contiguous()

        _compare_mts(not_contiguous_mt, now_contiguous_mt)

        self.assertEqual(now_contiguous_mt.is_contiguous(), True)
        self.assertEqual(now_contiguous_mt.get_data().is_contiguous(), True)
        self.assertEqual(now_contiguous_mt.is_contiguous(), True)
396*da0073e9SAndroid Build Coastguard Worker
class TestUnary(TestCase):
    """Checks that native unary ops give matching results on MaskedTensors and Tensors."""

    def _get_test_data(self, fn_name):
        # Produce a random (data, mask) pair, restricting data to the
        # mathematical domain of the op under test.
        fn_name = _fix_fn_name(fn_name)
        data = torch.randn(10, 10)
        mask = torch.rand(10, 10) > 0.5
        if fn_name in ("log", "log10", "log1p", "log2", "sqrt"):
            data = data.mul(0.5).abs()  # non-negative domain
        if fn_name in ("rsqrt",):
            data = data.abs() + 1  # avoid division by zero
        if fn_name in ("acos", "arccos", "asin", "arcsin", "logit"):
            data = data.abs().mul(0.5).clamp(0, 1)  # [0, 1] domain
        if fn_name in ("atanh", "arctanh", "erfinv"):
            data = data.mul(0.5).clamp(-1, 1)  # [-1, 1] domain
        if fn_name in ("acosh", "arccosh"):
            data = data.abs() + 1  # [1, inf) domain
        if fn_name in ("bitwise_not",):
            data = data.mul(128).to(torch.int8)  # op requires an integral dtype
        return data, mask

    def _get_sample_kwargs(self, fn_name):
        # Only clamp/clip take keyword arguments (the bounds); every other op
        # under test takes none.
        fn_name = _fix_fn_name(fn_name)
        if fn_name in ("clamp", "clip"):
            return {"min": -0.5, "max": 0.5}
        return {}

    def _get_sample_args(self, fn_name, data, mask):
        # Build parallel positional-argument lists for the plain-Tensor call
        # and the MaskedTensor call; pow additionally needs an exponent.
        fn_name = _fix_fn_name(fn_name)
        t_args = [data]
        mt_args = [masked_tensor(data, mask)]
        if fn_name in ("pow",):
            t_args.append(2.0)
            mt_args.append(2.0)
        return t_args, mt_args

    @parametrize("fn", NATIVE_UNARY_FNS)
    def test_unary(self, fn):
        """Out-of-place unary op: MaskedTensor result matches the Tensor result."""
        torch.random.manual_seed(0)
        name = fn.__name__
        data, mask = self._get_test_data(name)
        kwargs = self._get_sample_kwargs(name)
        t_args, mt_args = self._get_sample_args(name, data, mask)
        _compare_mt_t(fn(*mt_args, **kwargs), fn(*t_args, **kwargs))

    @parametrize("fn", NATIVE_INPLACE_UNARY_FNS)
    def test_inplace_unary(self, fn):
        """In-place unary op: MaskedTensor result matches the Tensor result."""
        torch.random.manual_seed(0)
        name = fn.__name__
        data, mask = self._get_test_data(name)
        kwargs = self._get_sample_kwargs(name)
        t_args, mt_args = self._get_sample_args(name, data, mask)
        _compare_mt_t(fn(*mt_args, **kwargs), fn(*t_args, **kwargs))
459*da0073e9SAndroid Build Coastguard Worker
class TestBinary(TestCase):
    """Checks that native binary ops give matching results on MaskedTensors and Tensors."""

    def _get_test_data(self, fn_name):
        # Produce two random data tensors plus a shared mask, coercing the data
        # to an integral dtype for the bitwise ops that require one.
        fn_name = _fix_fn_name(fn_name)
        data0 = torch.randn(10, 10)
        data1 = torch.randn(10, 10)
        mask = torch.rand(10, 10) > 0.5
        if fn_name in ["bitwise_and", "bitwise_or", "bitwise_xor"]:
            data0 = data0.mul(128).to(torch.int8)
            data1 = data1.mul(128).to(torch.int8)
        if fn_name in ["bitwise_left_shift", "bitwise_right_shift"]:
            # Shift amounts must be non-negative integers.
            data0 = data0.abs().to(torch.int64)
            data1 = data1.abs().to(torch.int64)
        return data0, data1, mask

    def _get_sample_kwargs(self, fn_name):
        # No binary op under test takes keyword arguments.
        fn_name = _fix_fn_name(fn_name)
        kwargs = {}
        return kwargs

    def _yield_sample_args(self, fn_name, data0, data1, mask):
        """ Returns two sets of Tensor and MaskedTensor args for a binary function to compute.
            Tensor args are all the same (just the two provided data tensors),
            while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor)
        """
        fn_name = _fix_fn_name(fn_name)
        mt0 = masked_tensor(data0, mask)
        mt1 = masked_tensor(data1, mask)

        # (MaskedTensor, MaskedTensor)
        t_args = [data0, data1]
        mt_args = [mt0, mt1]
        yield t_args, mt_args

        # (MaskedTensor, Tensor)
        t_args = [data0, data1]
        mt_args = [mt0, data1]
        yield t_args, mt_args

    @parametrize("fn", NATIVE_BINARY_FNS)
    def test_binary(self, fn):
        """Out-of-place binary op: MaskedTensor result matches the Tensor result."""
        torch.random.manual_seed(0)
        fn_name = fn.__name__
        data0, data1, mask = self._get_test_data(fn_name)
        kwargs = self._get_sample_kwargs(fn_name)

        for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask):
            mt_result = fn(*mt_args, **kwargs)
            t_result = fn(*t_args, **kwargs)
            _compare_mt_t(mt_result, t_result)

    @parametrize("fn", NATIVE_INPLACE_BINARY_FNS)
    def test_inplace_binary(self, fn):
        """In-place binary op: MaskedTensor result matches the Tensor result."""
        torch.random.manual_seed(0)
        fn_name = fn.__name__
        data0, data1, mask = self._get_test_data(fn_name)
        kwargs = self._get_sample_kwargs(fn_name)

        for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask):
            mt_result = fn(*mt_args, **kwargs)
            t_result = fn(*t_args, **kwargs)
            _compare_mt_t(mt_result, t_result)

    @parametrize("fn_name", ["add", "add_"])
    def test_masks_match(self, fn_name):
        """Binary ops on two MaskedTensors with different masks must raise ValueError."""
        torch.random.manual_seed(0)
        fn = getattr(torch.ops.aten, fn_name)
        data0, data1, mask = self._get_test_data(fn_name)
        mask0 = mask
        mask1 = torch.rand(mask.size()) > 0.5
        mt0 = masked_tensor(data0, mask0)
        mt1 = masked_tensor(data1, mask1)
        # Use assertRaisesRegex (as the rest of this file does) rather than a
        # try/except with a bare `assert`, which is silently skipped under
        # `python -O` and reports a less helpful failure.
        with self.assertRaisesRegex(
            ValueError,
            "Input masks must match. If you need support for this, "
            "please open an issue on Github.",
        ):
            fn(mt0, mt1)
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Workerclass TestReductions(TestCase):
539*da0073e9SAndroid Build Coastguard Worker    def test_max_not_implemented(self):
540*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
541*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
542*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
543*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "torch._ops.aten.max.default"):
544*da0073e9SAndroid Build Coastguard Worker            mt.max()
545*da0073e9SAndroid Build Coastguard Worker
546*da0073e9SAndroid Build Coastguard Worker    def test_sum(self):
547*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]])
548*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
549*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
550*da0073e9SAndroid Build Coastguard Worker        _compare_mts(masked_tensor(torch.tensor(17.0), torch.tensor(True)), mt.sum())
551*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
552*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
553*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0.0, 4.0, 1.0, 13]),
554*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, False, True]),
555*da0073e9SAndroid Build Coastguard Worker            ),
556*da0073e9SAndroid Build Coastguard Worker            mt.sum(dim=0),
557*da0073e9SAndroid Build Coastguard Worker        )
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker    def test_sum_grad(self):
560*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
561*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
562*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m, requires_grad=True)
563*da0073e9SAndroid Build Coastguard Worker        mt.sum().backward()
564*da0073e9SAndroid Build Coastguard Worker        _compare_mts(mt.grad, masked_tensor(torch.tensor(1.0).expand_as(m), m))
565*da0073e9SAndroid Build Coastguard Worker
566*da0073e9SAndroid Build Coastguard Worker    def test_mean(self):
567*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]])
568*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
569*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
570*da0073e9SAndroid Build Coastguard Worker        _compare_mts(masked_tensor(torch.tensor(2.5), torch.tensor(True)), mt.mean())
571*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
572*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
573*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0.0, 4.0, 1.0, 3]),
574*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, False, True]),
575*da0073e9SAndroid Build Coastguard Worker            ),
576*da0073e9SAndroid Build Coastguard Worker            mt.mean(dim=0),
577*da0073e9SAndroid Build Coastguard Worker        )
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker    """
        The following block of tests, test_mean_grad_case_1[a through f], is used to test the functionality of
581*da0073e9SAndroid Build Coastguard Worker        the two different ways of constructing MaskedTensors:
582*da0073e9SAndroid Build Coastguard Worker            masked_tensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf
583*da0073e9SAndroid Build Coastguard Worker            as_masked_tensor(data, mask) -- differentiable constructor
584*da0073e9SAndroid Build Coastguard Worker
585*da0073e9SAndroid Build Coastguard Worker        Like torch.tensor(data), masked_tensor(data, mask) will provide a UserWarning if data.requires_grad=True
586*da0073e9SAndroid Build Coastguard Worker        as_masked_tensor does not take in requires_grad -- it just takes on the requires_grad from data
587*da0073e9SAndroid Build Coastguard Worker
588*da0073e9SAndroid Build Coastguard Worker        Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations
589*da0073e9SAndroid Build Coastguard Worker
590*da0073e9SAndroid Build Coastguard Worker        Assuming mt.mean().backward() is run after each constructor:
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker        Case 1a:
593*da0073e9SAndroid Build Coastguard Worker            values.requires_grad = True
594*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(values, mask, requires_grad=True)
595*da0073e9SAndroid Build Coastguard Worker        yields
596*da0073e9SAndroid Build Coastguard Worker            - Provide a UserWarning because values.requires_grad=True
597*da0073e9SAndroid Build Coastguard Worker            - values.grad = None
598*da0073e9SAndroid Build Coastguard Worker            - mt.grad is a MaskedTensor with the correct gradient
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Worker        Case 1b:
601*da0073e9SAndroid Build Coastguard Worker            values.requires_grad = False
602*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(values, mask, requires_grad=True)
603*da0073e9SAndroid Build Coastguard Worker        yields
604*da0073e9SAndroid Build Coastguard Worker            - values.grad = None
605*da0073e9SAndroid Build Coastguard Worker            - mt.grad is a MaskedTensor with the correct gradient
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker        Case 2a/2b:
608*da0073e9SAndroid Build Coastguard Worker            values.requires_grad = True/False
609*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(values, mask, requires_grad=False)
610*da0073e9SAndroid Build Coastguard Worker
611*da0073e9SAndroid Build Coastguard Worker            will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
612*da0073e9SAndroid Build Coastguard Worker            as expected. When values.requires_grad=True, we will also get a UserWarning
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker        Case 3a:
615*da0073e9SAndroid Build Coastguard Worker            values.requires_grad = True
616*da0073e9SAndroid Build Coastguard Worker            mt = as_masked_tensor(values, mask)
617*da0073e9SAndroid Build Coastguard Worker        yields
618*da0073e9SAndroid Build Coastguard Worker            - values.grad is a MaskedTensor with the correct gradient
619*da0073e9SAndroid Build Coastguard Worker            - mt.grad is None and gives a UserWarning that
620*da0073e9SAndroid Build Coastguard Worker              "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker        Case 3b:
623*da0073e9SAndroid Build Coastguard Worker            values.requires_grad = False
624*da0073e9SAndroid Build Coastguard Worker            mt = as_masked_tensor(values, mask)
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker            will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn"
627*da0073e9SAndroid Build Coastguard Worker            as expected.
628*da0073e9SAndroid Build Coastguard Worker    """
629*da0073e9SAndroid Build Coastguard Worker    def test_mean_grad_case_1a(self):
630*da0073e9SAndroid Build Coastguard Worker        """ values.requires_grad = True
631*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(values, mask, requires_grad=True)
632*da0073e9SAndroid Build Coastguard Worker        """
633*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
634*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
635*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
636*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(d, m, requires_grad=True)
637*da0073e9SAndroid Build Coastguard Worker        mt.mean().backward()
638*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(d.grad)
639*da0073e9SAndroid Build Coastguard Worker        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
640*da0073e9SAndroid Build Coastguard Worker
641*da0073e9SAndroid Build Coastguard Worker    def test_mean_grad_case_1b(self):
642*da0073e9SAndroid Build Coastguard Worker        """ values.requires_grad = False
643*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(values, mask, requires_grad=True)
644*da0073e9SAndroid Build Coastguard Worker        """
645*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
646*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
647*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m, requires_grad=True)
648*da0073e9SAndroid Build Coastguard Worker        mt.mean().backward()
649*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(d.grad)
650*da0073e9SAndroid Build Coastguard Worker        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker    def test_mean_grad_case_1c(self):
653*da0073e9SAndroid Build Coastguard Worker        """ values.requires_grad = True
654*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(values, mask, requires_grad=False)
655*da0073e9SAndroid Build Coastguard Worker        """
656*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
657*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
658*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"):
659*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(d, m, requires_grad=False)
660*da0073e9SAndroid Build Coastguard Worker        result = mt.mean()
661*da0073e9SAndroid Build Coastguard Worker        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
662*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
663*da0073e9SAndroid Build Coastguard Worker            result.backward()
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker    def test_mean_grad_case_1d(self):
667*da0073e9SAndroid Build Coastguard Worker        """ values.requires_grad = False
668*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(values, mask, requires_grad=False)
669*da0073e9SAndroid Build Coastguard Worker        """
670*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
671*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
672*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m, requires_grad=False)
673*da0073e9SAndroid Build Coastguard Worker        result = mt.mean()
674*da0073e9SAndroid Build Coastguard Worker        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
675*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
676*da0073e9SAndroid Build Coastguard Worker            result.backward()
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker    def test_mean_grad_case_1e(self):
679*da0073e9SAndroid Build Coastguard Worker        """ values.requires_grad = True
680*da0073e9SAndroid Build Coastguard Worker            mt = as_masked_tensor(values, mask)
681*da0073e9SAndroid Build Coastguard Worker        """
682*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True)
683*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
684*da0073e9SAndroid Build Coastguard Worker        mt = as_masked_tensor(d, m)
685*da0073e9SAndroid Build Coastguard Worker        mt.mean().backward()
686*da0073e9SAndroid Build Coastguard Worker        _compare_mts(d.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m))
687*da0073e9SAndroid Build Coastguard Worker        msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad"
688*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, msg):
689*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(mt.grad)
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker    def test_mean_grad_case_1f(self):
692*da0073e9SAndroid Build Coastguard Worker        """ values.requires_grad = False
693*da0073e9SAndroid Build Coastguard Worker            mt = as_masked_tensor(values, mask)
694*da0073e9SAndroid Build Coastguard Worker        """
695*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
696*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
697*da0073e9SAndroid Build Coastguard Worker        mt = as_masked_tensor(d, m)
698*da0073e9SAndroid Build Coastguard Worker        result = mt.mean()
699*da0073e9SAndroid Build Coastguard Worker        msg = "element 0 of tensors does not require grad and does not have a grad_fn"
700*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
701*da0073e9SAndroid Build Coastguard Worker            result.backward()
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker    def test_mean_dim_grad(self):
704*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
705*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, True, False], [False, True, False]])
706*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m, requires_grad=True)
707*da0073e9SAndroid Build Coastguard Worker        mt.mean(1).sum().backward()
708*da0073e9SAndroid Build Coastguard Worker        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m))
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    def test_amax(self):
711*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
712*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
713*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
714*da0073e9SAndroid Build Coastguard Worker        _compare_mts(masked_tensor(torch.tensor(3.0), torch.tensor(True)), mt.amax())
715*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
716*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
717*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0.0, -4.0, 1.0, 3]),
718*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, False, True]),
719*da0073e9SAndroid Build Coastguard Worker            ),
720*da0073e9SAndroid Build Coastguard Worker            mt.amax(dim=0),
721*da0073e9SAndroid Build Coastguard Worker        )
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker    def test_amax_grad(self):
724*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
725*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
726*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m, requires_grad=True)
727*da0073e9SAndroid Build Coastguard Worker        mt.amax().backward()
728*da0073e9SAndroid Build Coastguard Worker        _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m))
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker    def test_amin(self):
731*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]])
732*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
733*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
734*da0073e9SAndroid Build Coastguard Worker        _compare_mts(masked_tensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin())
735*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
736*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
737*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0.0, -4.0, 1.0, -3]),
738*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, False, True]),
739*da0073e9SAndroid Build Coastguard Worker            ),
740*da0073e9SAndroid Build Coastguard Worker            mt.amin(dim=0),
741*da0073e9SAndroid Build Coastguard Worker        )
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker    def test_amin_grad(self):
744*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 2], [3, 4, 5.0]])
745*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
746*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m, requires_grad=True)
747*da0073e9SAndroid Build Coastguard Worker        mt.amin().backward()
748*da0073e9SAndroid Build Coastguard Worker        _compare_mts(mt.grad, masked_tensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m))
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker    def test_prod(self):
751*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]])
752*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
753*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
754*da0073e9SAndroid Build Coastguard Worker        _compare_mts(masked_tensor(torch.tensor(0.0), torch.tensor(True)), mt.prod())
755*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
756*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
757*da0073e9SAndroid Build Coastguard Worker                torch.tensor([0.0, 4.0, 1.0, 0.0]),
758*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, False, True]),
759*da0073e9SAndroid Build Coastguard Worker            ),
760*da0073e9SAndroid Build Coastguard Worker            mt.prod(dim=0),
761*da0073e9SAndroid Build Coastguard Worker        )
762*da0073e9SAndroid Build Coastguard Worker
763*da0073e9SAndroid Build Coastguard Worker    def test_prod_grad(self):
764*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]])
765*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
766*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m, requires_grad=True)
767*da0073e9SAndroid Build Coastguard Worker        mt.prod().backward()
768*da0073e9SAndroid Build Coastguard Worker        _compare_mts(mt.grad, masked_tensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m))
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker    def test_all(self):
771*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[True, True, False, False], [False, True, True, True]])
772*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False, True], [False, True, False, True]])
773*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
774*da0073e9SAndroid Build Coastguard Worker        _compare_mts(masked_tensor(torch.tensor(False), torch.tensor(True)), mt.all())
775*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
776*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
777*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, True, False]),
778*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, False, True]),
779*da0073e9SAndroid Build Coastguard Worker            ),
780*da0073e9SAndroid Build Coastguard Worker            mt.all(dim=0),
781*da0073e9SAndroid Build Coastguard Worker        )
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, True, False], [False, True, False, False]])
784*da0073e9SAndroid Build Coastguard Worker        mt = masked_tensor(d, m)
785*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
786*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
787*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, False, True]),
788*da0073e9SAndroid Build Coastguard Worker                torch.tensor([True, True, True, False]),
789*da0073e9SAndroid Build Coastguard Worker            ),
790*da0073e9SAndroid Build Coastguard Worker            mt.all(dim=0),
791*da0073e9SAndroid Build Coastguard Worker        )
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker    def test_grad_dtype(self):
794*da0073e9SAndroid Build Coastguard Worker        d = torch.tensor([[True, True, False], [False, True, True]])
795*da0073e9SAndroid Build Coastguard Worker        m = torch.tensor([[True, False, False], [False, True, False]])
796*da0073e9SAndroid Build Coastguard Worker        msg = "Only Tensors of floating point and complex dtype can require gradients"
797*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
798*da0073e9SAndroid Build Coastguard Worker            masked_tensor(d, m, requires_grad=True)
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker    def test_any_true_dtype(self):
801*da0073e9SAndroid Build Coastguard Worker        mt = torch.masked.MaskedTensor(
802*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 2),
803*da0073e9SAndroid Build Coastguard Worker            torch.rand(2, 2) > 0.5
804*da0073e9SAndroid Build Coastguard Worker        )
805*da0073e9SAndroid Build Coastguard Worker        msg = "expected a boolean tensor"
806*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, msg):
807*da0073e9SAndroid Build Coastguard Worker            mt._is_any_true()
808*da0073e9SAndroid Build Coastguard Worker
809*da0073e9SAndroid Build Coastguard Worker    def test__is_any_true(self):
810*da0073e9SAndroid Build Coastguard Worker        mt = torch.masked.MaskedTensor(
811*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[True, True, False], [False, False, True]]),
812*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[True, False, False], [False, True, False]]),
813*da0073e9SAndroid Build Coastguard Worker        )
814*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
815*da0073e9SAndroid Build Coastguard Worker            masked_tensor(torch.tensor(True), torch.tensor(True)),
816*da0073e9SAndroid Build Coastguard Worker            mt._is_any_true(),
817*da0073e9SAndroid Build Coastguard Worker        )
818*da0073e9SAndroid Build Coastguard Worker
819*da0073e9SAndroid Build Coastguard Worker    def test__is_any_true_false(self):
820*da0073e9SAndroid Build Coastguard Worker        mt = torch.masked.MaskedTensor(
821*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[True, True, False], [False, False, True]]),
822*da0073e9SAndroid Build Coastguard Worker            torch.tensor([[False, False, False], [False, False, False]]),
823*da0073e9SAndroid Build Coastguard Worker        )
824*da0073e9SAndroid Build Coastguard Worker        _compare_mts(
825*da0073e9SAndroid Build Coastguard Worker            masked_tensor(torch.tensor(False), torch.tensor(True),),
826*da0073e9SAndroid Build Coastguard Worker            mt._is_any_true(),
827*da0073e9SAndroid Build Coastguard Worker        )
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker    def test_backward(self):
830*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/128557
831*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.detect_anomaly():
832*da0073e9SAndroid Build Coastguard Worker            mt = torch.masked.MaskedTensor(
833*da0073e9SAndroid Build Coastguard Worker                torch.rand(2, 2),
834*da0073e9SAndroid Build Coastguard Worker                torch.rand(2, 2) > 0.5,
835*da0073e9SAndroid Build Coastguard Worker                requires_grad=True
836*da0073e9SAndroid Build Coastguard Worker            )
837*da0073e9SAndroid Build Coastguard Worker            mt.sum().backward()
838*da0073e9SAndroid Build Coastguard Worker
839*da0073e9SAndroid Build Coastguard Worker
840*da0073e9SAndroid Build Coastguard Workerdef is_unary(op):
841*da0073e9SAndroid Build Coastguard Worker    return op.name in UNARY_NAMES
842*da0073e9SAndroid Build Coastguard Worker
843*da0073e9SAndroid Build Coastguard Workerdef is_binary(op):
844*da0073e9SAndroid Build Coastguard Worker    return op.name in BINARY_NAMES
845*da0073e9SAndroid Build Coastguard Worker
846*da0073e9SAndroid Build Coastguard Workerdef is_reduction(op):
847*da0073e9SAndroid Build Coastguard Worker    return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"}
848*da0073e9SAndroid Build Coastguard Worker
849*da0073e9SAndroid Build Coastguard Workermt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)]
850*da0073e9SAndroid Build Coastguard Workermt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)]
851*da0073e9SAndroid Build Coastguard Workermt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)]
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard WorkerMASKEDTENSOR_FLOAT_TYPES = {
854*da0073e9SAndroid Build Coastguard Worker    torch.float16,
855*da0073e9SAndroid Build Coastguard Worker    torch.float32,
856*da0073e9SAndroid Build Coastguard Worker    torch.float64,
857*da0073e9SAndroid Build Coastguard Worker}
858*da0073e9SAndroid Build Coastguard Worker
859*da0073e9SAndroid Build Coastguard Workerclass TestOperators(TestCase):
860*da0073e9SAndroid Build Coastguard Worker    def _convert_mt_args(self, args, mask, layout):
861*da0073e9SAndroid Build Coastguard Worker        return [
862*da0073e9SAndroid Build Coastguard Worker            masked_tensor(
863*da0073e9SAndroid Build Coastguard Worker                arg.sparse_mask(mask) if layout != torch.strided else arg, mask
864*da0073e9SAndroid Build Coastguard Worker            )
865*da0073e9SAndroid Build Coastguard Worker            if torch.is_tensor(arg)
866*da0073e9SAndroid Build Coastguard Worker            else arg
867*da0073e9SAndroid Build Coastguard Worker            for arg in args
868*da0073e9SAndroid Build Coastguard Worker        ]
869*da0073e9SAndroid Build Coastguard Worker
870*da0073e9SAndroid Build Coastguard Worker    def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided):
871*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=True)
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
874*da0073e9SAndroid Build Coastguard Worker            input = sample.input
875*da0073e9SAndroid Build Coastguard Worker            sample_args, sample_kwargs = sample.args, sample.kwargs
876*da0073e9SAndroid Build Coastguard Worker            mask = (
877*da0073e9SAndroid Build Coastguard Worker                _create_random_mask(input.shape, device)
878*da0073e9SAndroid Build Coastguard Worker                if "mask" not in sample_kwargs
879*da0073e9SAndroid Build Coastguard Worker                else sample_kwargs.pop("mask")
880*da0073e9SAndroid Build Coastguard Worker            )
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker            if layout == torch.sparse_coo:
883*da0073e9SAndroid Build Coastguard Worker                mask = mask.to_sparse_coo().coalesce()
884*da0073e9SAndroid Build Coastguard Worker                input = input.sparse_mask(mask)
885*da0073e9SAndroid Build Coastguard Worker            elif layout == torch.sparse_csr:
886*da0073e9SAndroid Build Coastguard Worker                if input.ndim != 2 or mask.ndim != 2:
887*da0073e9SAndroid Build Coastguard Worker                    continue
888*da0073e9SAndroid Build Coastguard Worker                mask = mask.to_sparse_csr()
889*da0073e9SAndroid Build Coastguard Worker                input = input.sparse_mask(mask)
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker            # Binary operations currently only support same size masks
892*da0073e9SAndroid Build Coastguard Worker            if is_binary(op):
893*da0073e9SAndroid Build Coastguard Worker                if input.shape != sample_args[0].shape:
894*da0073e9SAndroid Build Coastguard Worker                    continue
895*da0073e9SAndroid Build Coastguard Worker                # Binary operations also don't support kwargs right now
896*da0073e9SAndroid Build Coastguard Worker                else:
897*da0073e9SAndroid Build Coastguard Worker                    sample_kwargs = {}
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(input, mask)
900*da0073e9SAndroid Build Coastguard Worker            mt_args = self._convert_mt_args(sample_args, mask, layout)
901*da0073e9SAndroid Build Coastguard Worker
902*da0073e9SAndroid Build Coastguard Worker            mt_result = op(mt, *mt_args, **sample_kwargs)
903*da0073e9SAndroid Build Coastguard Worker            t_result = op(sample.input, *sample_args, **sample_kwargs)
904*da0073e9SAndroid Build Coastguard Worker
905*da0073e9SAndroid Build Coastguard Worker            _compare_mt_t(mt_result, t_result)
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker            # If the operation is binary, check that lhs = masked, rhs = regular tensor also works
908*da0073e9SAndroid Build Coastguard Worker            if is_binary(op) and layout == torch.strided:
909*da0073e9SAndroid Build Coastguard Worker                mt_result2 = op(mt, *sample_args, **sample_kwargs)
910*da0073e9SAndroid Build Coastguard Worker                _compare_mt_t(mt_result2, t_result)
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker    def _test_reduction_equality(self, device, dtype, op, layout=torch.strided):
913*da0073e9SAndroid Build Coastguard Worker        samples = op.sample_inputs(device, dtype, requires_grad=True)
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker        for sample in samples:
916*da0073e9SAndroid Build Coastguard Worker            input = sample.input
917*da0073e9SAndroid Build Coastguard Worker            # Reduction operations don't support more advanced args/kwargs right now
918*da0073e9SAndroid Build Coastguard Worker            sample_args, sample_kwargs = (), {}
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker            if input.dim() == 0 or input.numel() == 0:
921*da0073e9SAndroid Build Coastguard Worker                continue
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker            mask = _create_random_mask(input.shape, device)
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker            if torch.count_nonzero(mask) == 0:
926*da0073e9SAndroid Build Coastguard Worker                continue
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker            tensor_input = _combine_input_and_mask(op.op, input, mask)
929*da0073e9SAndroid Build Coastguard Worker            if layout == torch.sparse_coo:
930*da0073e9SAndroid Build Coastguard Worker                mask = mask.to_sparse_coo().coalesce()
931*da0073e9SAndroid Build Coastguard Worker                input = input.sparse_mask(mask)
932*da0073e9SAndroid Build Coastguard Worker            elif layout == torch.sparse_csr:
933*da0073e9SAndroid Build Coastguard Worker                if input.ndim != 2 or mask.ndim != 2:
934*da0073e9SAndroid Build Coastguard Worker                    continue
935*da0073e9SAndroid Build Coastguard Worker                mask = mask.to_sparse_csr()
936*da0073e9SAndroid Build Coastguard Worker                input = input.sparse_mask(mask)
937*da0073e9SAndroid Build Coastguard Worker
938*da0073e9SAndroid Build Coastguard Worker            mt = masked_tensor(input, mask)
939*da0073e9SAndroid Build Coastguard Worker            mt_args = self._convert_mt_args(sample_args, mask, layout)
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker            mt_result = op(mt, *mt_args, **sample_kwargs)
942*da0073e9SAndroid Build Coastguard Worker            t_result = op(tensor_input, *sample_args, **sample_kwargs)
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker            _compare_mt_t(mt_result, t_result)
945*da0073e9SAndroid Build Coastguard Worker
946*da0073e9SAndroid Build Coastguard Worker    @ops(mt_unary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
947*da0073e9SAndroid Build Coastguard Worker    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
948*da0073e9SAndroid Build Coastguard Worker    def test_unary_core(self, device, dtype, op, layout):
949*da0073e9SAndroid Build Coastguard Worker        # Skip tests that don't have len(kwargs) == 0
950*da0073e9SAndroid Build Coastguard Worker        skip_variants = {
951*da0073e9SAndroid Build Coastguard Worker            "decimals_0",
952*da0073e9SAndroid Build Coastguard Worker            "decimals_3",
953*da0073e9SAndroid Build Coastguard Worker            "decimals_neg_3",
954*da0073e9SAndroid Build Coastguard Worker        }
955*da0073e9SAndroid Build Coastguard Worker        if op.name == "round" and op.variant_test_name in skip_variants:
956*da0073e9SAndroid Build Coastguard Worker            return
957*da0073e9SAndroid Build Coastguard Worker        self._test_unary_binary_equality(device, dtype, op)
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker    @ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
960*da0073e9SAndroid Build Coastguard Worker    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
961*da0073e9SAndroid Build Coastguard Worker    # FIXME:
962*da0073e9SAndroid Build Coastguard Worker    # Result is just wrong; production logic should be fixed
963*da0073e9SAndroid Build Coastguard Worker    @decorateIf(
964*da0073e9SAndroid Build Coastguard Worker        unittest.expectedFailure,
965*da0073e9SAndroid Build Coastguard Worker        lambda params: (
966*da0073e9SAndroid Build Coastguard Worker            params["op"].name == "add" and
967*da0073e9SAndroid Build Coastguard Worker            params["dtype"] in [torch.float16, torch.float32] and
968*da0073e9SAndroid Build Coastguard Worker            params["device"] == "cpu" and
969*da0073e9SAndroid Build Coastguard Worker            params["layout"] == torch.sparse_csr
970*da0073e9SAndroid Build Coastguard Worker        )
971*da0073e9SAndroid Build Coastguard Worker    )
972*da0073e9SAndroid Build Coastguard Worker    # Result is just wrong; production logic should be fixed
973*da0073e9SAndroid Build Coastguard Worker    @decorateIf(
974*da0073e9SAndroid Build Coastguard Worker        unittest.expectedFailure,
975*da0073e9SAndroid Build Coastguard Worker        lambda params: (
976*da0073e9SAndroid Build Coastguard Worker            params["op"].name == "sub" and
977*da0073e9SAndroid Build Coastguard Worker            params["dtype"] in [torch.float16, torch.float32] and
978*da0073e9SAndroid Build Coastguard Worker            params["device"] == "cpu" and
979*da0073e9SAndroid Build Coastguard Worker            params["layout"] == torch.sparse_csr
980*da0073e9SAndroid Build Coastguard Worker        )
981*da0073e9SAndroid Build Coastguard Worker    )
982*da0073e9SAndroid Build Coastguard Worker    # Result is just wrong; production logic should be fixed
983*da0073e9SAndroid Build Coastguard Worker    @decorateIf(
984*da0073e9SAndroid Build Coastguard Worker        unittest.expectedFailure,
985*da0073e9SAndroid Build Coastguard Worker        lambda params: (
986*da0073e9SAndroid Build Coastguard Worker            params["op"].name == "eq" and
987*da0073e9SAndroid Build Coastguard Worker            params["dtype"] == torch.float64 and
988*da0073e9SAndroid Build Coastguard Worker            params["device"] == "cpu" and
989*da0073e9SAndroid Build Coastguard Worker            params["layout"] == torch.sparse_csr
990*da0073e9SAndroid Build Coastguard Worker        )
991*da0073e9SAndroid Build Coastguard Worker    )
992*da0073e9SAndroid Build Coastguard Worker    def test_binary_core(self, device, dtype, op, layout):
993*da0073e9SAndroid Build Coastguard Worker        self._test_unary_binary_equality(device, dtype, op, layout)
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker    @ops(mt_reduction_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
996*da0073e9SAndroid Build Coastguard Worker    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
997*da0073e9SAndroid Build Coastguard Worker    def test_reduction_all(self, device, dtype, op, layout):
998*da0073e9SAndroid Build Coastguard Worker        # argmin and argmax are not currently supported for torch.sparse_csr
999*da0073e9SAndroid Build Coastguard Worker        if op.name in {"argmin", "argmax"} and layout == torch.sparse_csr:
1000*da0073e9SAndroid Build Coastguard Worker            return
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker        self._test_reduction_equality(device, dtype, op, layout)
1003*da0073e9SAndroid Build Coastguard Worker
1004*da0073e9SAndroid Build Coastguard Worker
1005*da0073e9SAndroid Build Coastguard Workeronly_for = ("cpu", "cuda")
1006*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestBasics, globals(), only_for=only_for)
1009*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestUnary)
1010*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestBinary)
1011*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestReductions)
1012*da0073e9SAndroid Build Coastguard Worker
1013*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__':
1014*da0073e9SAndroid Build Coastguard Worker    run_tests()
1015