# Owner(s): ["module: masked operators"] import torch import unittest from torch.testing._internal.common_utils import ( decorateIf, TestCase, run_tests, make_tensor, parametrize, instantiate_parametrized_tests, ) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, ops, ) from torch.testing._internal.common_methods_invocations import ( SampleInput, binary_ufuncs, reduction_ops, unary_ufuncs, ) from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask from torch.masked.maskedtensor.core import _masks_match, _tensors_match from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES from torch.masked.maskedtensor.reductions import REDUCE_NAMES def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05): mask = mt_result.get_mask() mt_result_data = mt_result.get_data() if mask.layout in {torch.sparse_coo, torch.sparse_csr}: mask = mask.to_dense() if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}: mt_result_data = mt_result_data.to_dense() a = mt_result_data.detach().masked_fill_(~mask, 0) b = t_result.detach().masked_fill_(~mask, 0) if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): raise ValueError("The data in MaskedTensor a and Tensor b do not match") def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08): mt_data1 = mt1.get_data() mt_data2 = mt2.get_data() if mt_data1.layout != mt_data2.layout: raise ValueError("mt1's data and mt2's data do not have the same layout. " f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}") mask = mt1.get_mask() mask2 = mt2.get_mask() if not _masks_match(mt1, mt2): raise ValueError("mt1 and mt2 must have matching masks") if mask.layout != mask2.layout: raise ValueError("mt1's mask and mt2's mask do not have the same layout. " f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}") if mask.layout in {torch.sparse_coo, torch.sparse_csr}: mask = mask.to_dense() if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}: mt_data1 = mt_data1.to_dense() mt_data2 = mt_data2.to_dense() a = mt_data1.detach().masked_fill_(~mask, 0) b = mt_data2.detach().masked_fill_(~mask, 0) if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match") def _compare_forward_backward(data, mask, fn): mt = masked_tensor(data, mask, requires_grad=True) masked_res = fn(mt) masked_res.sum().backward() t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() tensor_res = fn(t) tensor_res.sum().backward() _compare_mt_t(masked_res, tensor_res) _compare_mt_t(mt.grad, t.grad, atol=1e-06) def _create_random_mask(shape, device): return make_tensor(shape, device=device, dtype=torch.bool) def _generate_sample_data( device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided ): assert layout in { torch.strided, torch.sparse_coo, torch.sparse_csr, }, "Layout must be strided/sparse_coo/sparse_csr" shapes = [ [], [2], [3, 5], [3, 2, 1, 2], ] inputs = [] for s in shapes: data = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) # type: ignore[arg-type] mask = _create_random_mask(s, device) if layout == torch.sparse_coo: mask = mask.to_sparse_coo().coalesce() data = data.sparse_mask(mask).requires_grad_(requires_grad) elif layout == torch.sparse_csr: if data.ndim != 2 and mask.ndim != 2: continue mask = mask.to_sparse_csr() data = data.sparse_mask(mask) inputs.append(SampleInput(data, kwargs={"mask": mask})) return inputs def _fix_fn_name(fn_name): if fn_name[-1] == "_": fn_name = fn_name[:-1] return fn_name class TestBasics(TestCase): def test_invalid_tensor_inputs(self, device): data = torch.randn((3, 4), device=device) mask = _create_random_mask((3, 4), device=device) mt = masked_tensor(data, mask) with self.assertRaisesRegex(TypeError, "data must be a Tensor"): masked_tensor(mt, mask) with self.assertRaisesRegex(TypeError, "data must be a Tensor"): masked_tensor(0, mask) with self.assertRaisesRegex(TypeError, "mask must be a Tensor"): masked_tensor(data, mt) with self.assertRaisesRegex(TypeError, "mask must be a Tensor"): masked_tensor(data, 0) def test_diff_layouts(self, device): data = torch.randn((3, 4), device=device).to_sparse_coo() mask = _create_random_mask((3, 4), device=device) with self.assertRaisesRegex(TypeError, "data and mask must have the same layout"): masked_tensor(data, mask) def test_diff_dim(self, device): data = torch.randn((3, 4, 5), device=device) mask = _create_random_mask((3, 4), device=device) with self.assertRaisesRegex(ValueError, "data.dim\\(\\) must equal mask.dim\\(\\)"): masked_tensor(data, mask) def test_diff_sizes(self, device): data = torch.randn((3, 4), device=device) mask = _create_random_mask((3, 3), device=device) with self.assertRaisesRegex(ValueError, "data.size\\(\\) must equal mask.size\\(\\)"): masked_tensor(data, mask) def test_grad_warning(self, device): data = torch.randn((3, 4), device=device, requires_grad=True) mask = _create_random_mask((3, 4), device=device) msg = "It is not recommended to create a MaskedTensor with a tensor that requires_grad." with self.assertWarnsRegex(UserWarning, msg): mt = masked_tensor(data, mask) def test_add(self, device): data = torch.arange(5.0, device=device) mask = torch.tensor([True, True, False, True, False], device=device) m0 = masked_tensor(data, mask) m1 = masked_tensor(data, ~mask) with self.assertRaisesRegex(ValueError, "Input masks must match."): m0 + m1 _compare_mts(m0 + m0, masked_tensor(torch.tensor([0., 2, 0, 6, 0], device=device), mask)) def test_softmax(self, device): data = torch.randn((3, 4), device=device) * 0.1 mask = torch.tensor( [ [True, True, True, False], [False, True, False, True], [True, True, False, False], ], device=device ) _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1)) def test_where(self, device): data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device) mask = data < 0 mx = masked_tensor(data, mask, requires_grad=True) my = masked_tensor(torch.ones_like(data), ~mask, requires_grad=True) masked_res = torch.where(mask, torch.exp(mx), my) masked_res.sum().backward() x = data.detach().clone().requires_grad_() y = torch.ones_like(x, device=device, requires_grad=True) tensor_res = torch.where(mask, torch.exp(x), y) tensor_res.sum().backward() _compare_mt_t(masked_res, tensor_res) _compare_mt_t(mx.grad, x.grad) _compare_mt_t(my.grad, y.grad) def test_unfold(self, device): data = torch.rand(5, 5, device=device) mask = torch.rand(5, 5, device=device) > 0.5 _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2)) def test_nn_unfold(self, device): data = torch.rand(2, 5, 3, 4, device=device) mask = torch.rand(2, 5, 3, 4, device=device) > 0.5 _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3))) def test_stack(self, device): masked_tensors = [ masked_tensor( torch.rand(2, 5, 3, 4, device=device), torch.rand(2, 5, 3, 4, device=device) > 0.5, requires_grad=True, ) for _ in range(3) ] data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors] masked_res = torch.stack(masked_tensors) tensor_res = torch.stack(data_tensors) masked_res.sum().backward() tensor_res.sum().backward() _compare_mt_t(masked_res, tensor_res) for mt, t in zip(masked_tensors, data_tensors): _compare_mt_t(mt.grad, t.grad, atol=1e-06) def test_to_sparse(self, device): for sample in _generate_sample_data(device=device): data = sample.input mask = sample.kwargs["mask"] mt = masked_tensor(data.clone().detach(), mask, requires_grad=True) sparse_mt = mt.to_sparse() data.to_sparse().to_dense().sum().backward() sparse_mt.to_dense().sum().backward() _compare_mt_t(sparse_mt, data) _compare_mt_t(mt.grad, data.grad) def test_to_dense(self, device): samples = _generate_sample_data( device=device, layout=torch.sparse_coo ) + _generate_sample_data(device=device, layout=torch.sparse_csr) for sample in samples: data = sample.input mask = sample.kwargs["mask"] mt = masked_tensor(data, mask, requires_grad=True) dense_data = data.to_dense().detach().clone().requires_grad_(True) dense_mt = mt.to_dense() dense_data.sum().backward() dense_mt.sum().backward() _compare_mt_t(dense_mt, dense_data) _compare_mt_t(mt.grad.to_dense(), dense_data.grad) def test_to_dense_and_sparse_coo(self, device): for sample in _generate_sample_data(device=device, layout=torch.strided): data = sample.input mask = sample.kwargs["mask"] ms = mask.to_sparse_coo().coalesce() mt = masked_tensor(data, mask, requires_grad=True) mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True) converted = mt.to_sparse().to_dense() converted.sum().backward() converted2 = mts.to_dense() converted2.sum().backward() _compare_mts(converted, converted2) _compare_mts(mt.grad, mts.grad.to_dense()) def test_to_dense_and_sparse_csr(self, device): for sample in _generate_sample_data(device=device, layout=torch.strided): data = sample.input mask = sample.kwargs["mask"] if data.ndim != 2: continue ms = mask.to_sparse_csr() mt = masked_tensor(data, mask, requires_grad=True) mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True) converted = mt.to_sparse_csr().to_dense() converted.sum().backward() converted2 = mts.to_dense() converted2.sum().backward() _compare_mts(converted, converted2) _compare_mts(mt.grad, mts.grad.to_dense()) def test_invalid_sparse_layout(self, device): data = torch.randn((3, 4), device=device).to_sparse_csc() mask = _create_random_mask((3, 4), device=device).to_sparse_csc() with self.assertRaisesRegex(TypeError, "data layout of torch.sparse_csc is not supported"): masked_tensor(data, mask) def test_invalid_sparse_coo_values(self, device): v = torch.tensor([3, 4, 5], dtype=torch.float32) i1 = torch.tensor([[0, 1, 1], [2, 0, 2]]) i2 = torch.tensor([[0, 1, 1], [2, 1, 2]]) t = torch.sparse_coo_tensor(i1, v, (2, 4), device=device) mask = torch.sparse_coo_tensor(i2, torch.tensor([True, True, True]), (2, 4), device=device) msg = "data and mask are both sparse COO tensors but do not have the same indices." with self.assertRaisesRegex(ValueError, msg): masked_tensor(t, mask) def test_invalid_sparse_csr_values(self, device): crow_indices1 = [0, 2, 3] crow_indices2 = [0, 1, 3] col_indices1 = [0, 1, 2] col_indices2 = [1, 2, 3] values = [2, 3, 4] mask_values = [True, True, True] t1 = torch.sparse_csr_tensor( torch.tensor(crow_indices1, dtype=torch.int64), torch.tensor(col_indices1, dtype=torch.int64), torch.tensor(values), size=(2, 4) ) mask1 = torch.sparse_csr_tensor( torch.tensor(crow_indices2, dtype=torch.int64), torch.tensor(col_indices1, dtype=torch.int64), torch.tensor(mask_values), dtype=torch.bool, size=(2, 4), ) t2 = torch.sparse_csr_tensor( torch.tensor(crow_indices2, dtype=torch.int64), torch.tensor(col_indices1, dtype=torch.int64), torch.tensor(values), size=(2, 4), ) mask2 = torch.sparse_csr_tensor( torch.tensor(crow_indices2, dtype=torch.int64), torch.tensor(col_indices2, dtype=torch.int64), torch.tensor(mask_values), dtype=torch.bool, size=(2, 4), ) msg = "data and mask are both sparse CSR tensors but do not share either crow or col indices." with self.assertRaisesRegex(ValueError, msg): masked_tensor(t1, mask1) with self.assertRaisesRegex(ValueError, msg): masked_tensor(t2, mask2) def test_contiguous(self, device): data = torch.randn((3, 3), device=device) contiguous_data = data.clone() mask1 = (contiguous_data > 0).bool() not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2)) mask2 = (not_contiguous_data > 0).bool() contiguous_mt = masked_tensor(contiguous_data, mask1) not_contiguous_mt = masked_tensor(not_contiguous_data, mask2) contiguous_mt_sparse = masked_tensor( contiguous_data.to_sparse_coo(), mask1.to_sparse_coo() ) not_contiguous_mt_sparse = masked_tensor( not_contiguous_data.to_sparse_coo(), mask2.to_sparse_coo() ) self.assertEqual(contiguous_data.is_contiguous(), True) self.assertEqual(not_contiguous_data.is_contiguous(), False) self.assertEqual(contiguous_mt.is_contiguous(), True) self.assertEqual(not_contiguous_mt.is_contiguous(), False) error_msg = "MaskedTensors with sparse data do not have is_contiguous" for t in [contiguous_mt_sparse, not_contiguous_mt_sparse]: with self.assertRaisesRegex(ValueError, error_msg): t.is_contiguous() with self.assertRaisesRegex(ValueError, error_msg): t.contiguous() now_contiguous_mt = not_contiguous_mt.contiguous() _compare_mts(not_contiguous_mt, now_contiguous_mt) self.assertEqual(now_contiguous_mt.is_contiguous(), True) self.assertEqual(now_contiguous_mt.get_data().is_contiguous(), True) self.assertEqual(now_contiguous_mt.is_contiguous(), True) class TestUnary(TestCase): def _get_test_data(self, fn_name): data = torch.randn(10, 10) mask = torch.rand(10, 10) > 0.5 fn_name = _fix_fn_name(fn_name) if fn_name in ["log", "log10", "log1p", "log2", "sqrt"]: data = data.mul(0.5).abs() if fn_name in ["rsqrt"]: data = data.abs() + 1 # Void division by zero if fn_name in ["acos", "arccos", "asin", "arcsin", "logit"]: data = data.abs().mul(0.5).clamp(0, 1) if fn_name in ["atanh", "arctanh", "erfinv"]: data = data.mul(0.5).clamp(-1, 1) if fn_name in ["acosh", "arccosh"]: data = data.abs() + 1 if fn_name in ["bitwise_not"]: data = data.mul(128).to(torch.int8) return data, mask def _get_sample_kwargs(self, fn_name): fn_name = _fix_fn_name(fn_name) kwargs = {} if fn_name in ["clamp", "clip"]: kwargs["min"] = -0.5 kwargs["max"] = 0.5 return kwargs def _get_sample_args(self, fn_name, data, mask): fn_name = _fix_fn_name(fn_name) mt = masked_tensor(data, mask) t_args = [data] mt_args = [mt] if fn_name in ["pow"]: t_args += [2.0] mt_args += [2.0] return t_args, mt_args @parametrize("fn", NATIVE_UNARY_FNS) def test_unary(self, fn): torch.random.manual_seed(0) fn_name = fn.__name__ data, mask = self._get_test_data(fn_name) kwargs = self._get_sample_kwargs(fn_name) t_args, mt_args = self._get_sample_args(fn_name, data, mask) mt_result = fn(*mt_args, **kwargs) t_result = fn(*t_args, **kwargs) _compare_mt_t(mt_result, t_result) @parametrize("fn", NATIVE_INPLACE_UNARY_FNS) def test_inplace_unary(self, fn): torch.random.manual_seed(0) fn_name = fn.__name__ data, mask = self._get_test_data(fn_name) kwargs = self._get_sample_kwargs(fn_name) t_args, mt_args = self._get_sample_args(fn_name, data, mask) mt_result = fn(*mt_args, **kwargs) t_result = fn(*t_args, **kwargs) _compare_mt_t(mt_result, t_result) class TestBinary(TestCase): def _get_test_data(self, fn_name): fn_name = _fix_fn_name(fn_name) data0 = torch.randn(10, 10) data1 = torch.randn(10, 10) mask = torch.rand(10, 10) > 0.5 if fn_name in ["bitwise_and", "bitwise_or", "bitwise_xor"]: data0 = data0.mul(128).to(torch.int8) data1 = data1.mul(128).to(torch.int8) if fn_name in ["bitwise_left_shift", "bitwise_right_shift"]: data0 = data0.abs().to(torch.int64) data1 = data1.abs().to(torch.int64) return data0, data1, mask def _get_sample_kwargs(self, fn_name): fn_name = _fix_fn_name(fn_name) kwargs = {} return kwargs def _yield_sample_args(self, fn_name, data0, data1, mask): """ Returns two sets of Tensor and MaskedTensor args for a binary function to compute. Tensor args are all the same (just the two provided data tensors), while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor) """ fn_name = _fix_fn_name(fn_name) mt0 = masked_tensor(data0, mask) mt1 = masked_tensor(data1, mask) t_args = [data0, data1] mt_args = [mt0, mt1] yield t_args, mt_args t_args = [data0, data1] mt_args = [mt0, data1] yield t_args, mt_args @parametrize("fn", NATIVE_BINARY_FNS) def test_binary(self, fn): torch.random.manual_seed(0) fn_name = fn.__name__ data0, data1, mask = self._get_test_data(fn_name) kwargs = self._get_sample_kwargs(fn_name) for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask): mt_result = fn(*mt_args, **kwargs) t_result = fn(*t_args, **kwargs) _compare_mt_t(mt_result, t_result) @parametrize("fn", NATIVE_INPLACE_BINARY_FNS) def test_inplace_binary(self, fn): torch.random.manual_seed(0) fn_name = fn.__name__ data0, data1, mask = self._get_test_data(fn_name) kwargs = self._get_sample_kwargs(fn_name) for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask): mt_result = fn(*mt_args, **kwargs) t_result = fn(*t_args, **kwargs) _compare_mt_t(mt_result, t_result) @parametrize("fn_name", ["add", "add_"]) def test_masks_match(self, fn_name): torch.random.manual_seed(0) fn = getattr(torch.ops.aten, fn_name) data0, data1, mask = self._get_test_data(fn_name) mask0 = mask mask1 = torch.rand(mask.size()) > 0.5 mt0 = masked_tensor(data0, mask0) mt1 = masked_tensor(data1, mask1) try: fn(mt0, mt1) raise AssertionError except ValueError as e: assert ( "Input masks must match. If you need support for this, please open an issue on Github." == str(e) ) class TestReductions(TestCase): def test_max_not_implemented(self): d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = masked_tensor(d, m) with self.assertRaisesRegex(TypeError, "torch._ops.aten.max.default"): mt.max() def test_sum(self): d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]]) m = torch.tensor([[True, False, False, True], [False, True, False, True]]) mt = masked_tensor(d, m) _compare_mts(masked_tensor(torch.tensor(17.0), torch.tensor(True)), mt.sum()) _compare_mts( masked_tensor( torch.tensor([0.0, 4.0, 1.0, 13]), torch.tensor([True, True, False, True]), ), mt.sum(dim=0), ) def test_sum_grad(self): d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = masked_tensor(d, m, requires_grad=True) mt.sum().backward() _compare_mts(mt.grad, masked_tensor(torch.tensor(1.0).expand_as(m), m)) def test_mean(self): d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]]) m = torch.tensor([[True, False, False, True], [False, True, False, True]]) mt = masked_tensor(d, m) _compare_mts(masked_tensor(torch.tensor(2.5), torch.tensor(True)), mt.mean()) _compare_mts( masked_tensor( torch.tensor([0.0, 4.0, 1.0, 3]), torch.tensor([True, True, False, True]), ), mt.mean(dim=0), ) """ The following block of tests "test_mean_grad_case_1[a through e] are used to test the functionality of the two different ways of constructing MaskedTensors: masked_tensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf as_masked_tensor(data, mask) -- differentiable constructor Like torch.tensor(data), masked_tensor(data, mask) will provide a UserWarning if data.requires_grad=True as_masked_tensor does not take in requires_grad -- it just takes on the requires_grad from data Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations Assuming mt.mean().backward() is run after each constructor: Case 1a: values.requires_grad = True mt = masked_tensor(values, mask, requires_grad=True) yields - Provide a UserWarning because values.requires_grad=True - values.grad = None - mt.grad is a MaskedTensor with the correct gradient Case 1b: values.requires_grad = False mt = masked_tensor(values, mask, requires_grad=True) yields - values.grad = None - mt.grad is a MaskedTensor with the correct gradient Case 2a/2b: values.requires_grad = True/False mt = masked_tensor(values, mask, requires_grad=False) will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn" as expected. When values.requires_grad=True, we will also get a UserWarning Case 3a: values.requires_grad = True mt = as_masked_tensor(values, mask) yields - values.grad is a MaskedTensor with the correct gradient - mt.grad is None and gives a UserWarning that "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad" Case 3b: values.requires_grad = False mt = as_masked_tensor(values, mask) will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn" as expected. """ def test_mean_grad_case_1a(self): """ values.requires_grad = True mt = masked_tensor(values, mask, requires_grad=True) """ d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) m = torch.tensor([[True, False, False], [False, True, False]]) with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"): mt = masked_tensor(d, m, requires_grad=True) mt.mean().backward() self.assertIsNone(d.grad) _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) def test_mean_grad_case_1b(self): """ values.requires_grad = False mt = masked_tensor(values, mask, requires_grad=True) """ d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = masked_tensor(d, m, requires_grad=True) mt.mean().backward() self.assertIsNone(d.grad) _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) def test_mean_grad_case_1c(self): """ values.requires_grad = True mt = masked_tensor(values, mask, requires_grad=False) """ d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) m = torch.tensor([[True, False, False], [False, True, False]]) with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"): mt = masked_tensor(d, m, requires_grad=False) result = mt.mean() msg = "element 0 of tensors does not require grad and does not have a grad_fn" with self.assertRaisesRegex(RuntimeError, msg): result.backward() def test_mean_grad_case_1d(self): """ values.requires_grad = False mt = masked_tensor(values, mask, requires_grad=False) """ d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = masked_tensor(d, m, requires_grad=False) result = mt.mean() msg = "element 0 of tensors does not require grad and does not have a grad_fn" with self.assertRaisesRegex(RuntimeError, msg): result.backward() def test_mean_grad_case_1e(self): """ values.requires_grad = True mt = as_masked_tensor(values, mask) """ d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) m = torch.tensor([[True, False, False], [False, True, False]]) mt = as_masked_tensor(d, m) mt.mean().backward() _compare_mts(d.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad" with self.assertWarnsRegex(UserWarning, msg): self.assertIsNone(mt.grad) def test_mean_grad_case_1f(self): """ values.requires_grad = False mt = as_masked_tensor(values, mask) """ d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = as_masked_tensor(d, m) result = mt.mean() msg = "element 0 of tensors does not require grad and does not have a grad_fn" with self.assertRaisesRegex(RuntimeError, msg): result.backward() def test_mean_dim_grad(self): d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, True, False], [False, True, False]]) mt = masked_tensor(d, m, requires_grad=True) mt.mean(1).sum().backward() _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m)) def test_amax(self): d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]]) m = torch.tensor([[True, False, False, True], [False, True, False, True]]) mt = masked_tensor(d, m) _compare_mts(masked_tensor(torch.tensor(3.0), torch.tensor(True)), mt.amax()) _compare_mts( masked_tensor( torch.tensor([0.0, -4.0, 1.0, 3]), torch.tensor([True, True, False, True]), ), mt.amax(dim=0), ) def test_amax_grad(self): d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = masked_tensor(d, m, requires_grad=True) mt.amax().backward() _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m)) def test_amin(self): d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]]) m = torch.tensor([[True, False, False, True], [False, True, False, True]]) mt = masked_tensor(d, m) _compare_mts(masked_tensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin()) _compare_mts( masked_tensor( torch.tensor([0.0, -4.0, 1.0, -3]), torch.tensor([True, True, False, True]), ), mt.amin(dim=0), ) def test_amin_grad(self): d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = masked_tensor(d, m, requires_grad=True) mt.amin().backward() _compare_mts(mt.grad, masked_tensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m)) def test_prod(self): d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]]) m = torch.tensor([[True, False, False, True], [False, True, False, True]]) mt = masked_tensor(d, m) _compare_mts(masked_tensor(torch.tensor(0.0), torch.tensor(True)), mt.prod()) _compare_mts( masked_tensor( torch.tensor([0.0, 4.0, 1.0, 0.0]), torch.tensor([True, True, False, True]), ), mt.prod(dim=0), ) def test_prod_grad(self): d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]]) m = torch.tensor([[True, False, False], [False, True, False]]) mt = masked_tensor(d, m, requires_grad=True) mt.prod().backward() _compare_mts(mt.grad, masked_tensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m)) def test_all(self): d = torch.tensor([[True, True, False, False], [False, True, True, True]]) m = torch.tensor([[True, False, False, True], [False, True, False, True]]) mt = masked_tensor(d, m) _compare_mts(masked_tensor(torch.tensor(False), torch.tensor(True)), mt.all()) _compare_mts( masked_tensor( torch.tensor([True, True, True, False]), torch.tensor([True, True, False, True]), ), mt.all(dim=0), ) m = torch.tensor([[True, False, True, False], [False, True, False, False]]) mt = masked_tensor(d, m) _compare_mts( masked_tensor( torch.tensor([True, True, False, True]), torch.tensor([True, True, True, False]), ), mt.all(dim=0), ) def test_grad_dtype(self): d = torch.tensor([[True, True, False], [False, True, True]]) m = torch.tensor([[True, False, False], [False, True, False]]) msg = "Only Tensors of floating point and complex dtype can require gradients" with self.assertRaisesRegex(RuntimeError, msg): masked_tensor(d, m, requires_grad=True) def test_any_true_dtype(self): mt = torch.masked.MaskedTensor( torch.rand(2, 2), torch.rand(2, 2) > 0.5 ) msg = "expected a boolean tensor" with self.assertRaisesRegex(ValueError, msg): mt._is_any_true() def test__is_any_true(self): mt = torch.masked.MaskedTensor( torch.tensor([[True, True, False], [False, False, True]]), torch.tensor([[True, False, False], [False, True, False]]), ) _compare_mts( masked_tensor(torch.tensor(True), torch.tensor(True)), mt._is_any_true(), ) def test__is_any_true_false(self): mt = torch.masked.MaskedTensor( torch.tensor([[True, True, False], [False, False, True]]), torch.tensor([[False, False, False], [False, False, False]]), ) _compare_mts( masked_tensor(torch.tensor(False), torch.tensor(True),), mt._is_any_true(), ) def test_backward(self): # See https://github.com/pytorch/pytorch/issues/128557 with torch.autograd.detect_anomaly(): mt = torch.masked.MaskedTensor( torch.rand(2, 2), torch.rand(2, 2) > 0.5, requires_grad=True ) mt.sum().backward() def is_unary(op): return op.name in UNARY_NAMES def is_binary(op): return op.name in BINARY_NAMES def is_reduction(op): return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"} mt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)] mt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)] mt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)] MASKEDTENSOR_FLOAT_TYPES = { torch.float16, torch.float32, torch.float64, } class TestOperators(TestCase): def _convert_mt_args(self, args, mask, layout): return [ masked_tensor( arg.sparse_mask(mask) if layout != torch.strided else arg, mask ) if torch.is_tensor(arg) else arg for arg in args ] def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided): samples = op.sample_inputs(device, dtype, requires_grad=True) for sample in samples: input = sample.input sample_args, sample_kwargs = sample.args, sample.kwargs mask = ( _create_random_mask(input.shape, device) if "mask" not in sample_kwargs else sample_kwargs.pop("mask") ) if layout == torch.sparse_coo: mask = mask.to_sparse_coo().coalesce() input = input.sparse_mask(mask) elif layout == torch.sparse_csr: if input.ndim != 2 or mask.ndim != 2: continue mask = mask.to_sparse_csr() input = input.sparse_mask(mask) # Binary operations currently only support same size masks if is_binary(op): if input.shape != sample_args[0].shape: continue # Binary operations also don't support kwargs right now else: sample_kwargs = {} mt = masked_tensor(input, mask) mt_args = self._convert_mt_args(sample_args, mask, layout) mt_result = op(mt, *mt_args, **sample_kwargs) t_result = op(sample.input, *sample_args, **sample_kwargs) _compare_mt_t(mt_result, t_result) # If the operation is binary, check that lhs = masked, rhs = regular tensor also works if is_binary(op) and layout == torch.strided: mt_result2 = op(mt, *sample_args, **sample_kwargs) _compare_mt_t(mt_result2, t_result) def _test_reduction_equality(self, device, dtype, op, layout=torch.strided): samples = op.sample_inputs(device, dtype, requires_grad=True) for sample in samples: input = sample.input # Reduction operations don't support more advanced args/kwargs right now sample_args, sample_kwargs = (), {} if input.dim() == 0 or input.numel() == 0: continue mask = _create_random_mask(input.shape, device) if torch.count_nonzero(mask) == 0: continue tensor_input = _combine_input_and_mask(op.op, input, mask) if layout == torch.sparse_coo: mask = mask.to_sparse_coo().coalesce() input = input.sparse_mask(mask) elif layout == torch.sparse_csr: if input.ndim != 2 or mask.ndim != 2: continue mask = mask.to_sparse_csr() input = input.sparse_mask(mask) mt = masked_tensor(input, mask) mt_args = self._convert_mt_args(sample_args, mask, layout) mt_result = op(mt, *mt_args, **sample_kwargs) t_result = op(tensor_input, *sample_args, **sample_kwargs) _compare_mt_t(mt_result, t_result) @ops(mt_unary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES) # type: ignore[arg-type] @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr]) def test_unary_core(self, device, dtype, op, layout): # Skip tests that don't have len(kwargs) == 0 skip_variants = { "decimals_0", "decimals_3", "decimals_neg_3", } if op.name == "round" and op.variant_test_name in skip_variants: return self._test_unary_binary_equality(device, dtype, op) @ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES) # type: ignore[arg-type] @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr]) # FIXME: # Result is just wrong; production logic should be fixed @decorateIf( unittest.expectedFailure, lambda params: ( params["op"].name == "add" and params["dtype"] in [torch.float16, torch.float32] and params["device"] == "cpu" and params["layout"] == torch.sparse_csr ) ) # Result is just wrong; production logic should be fixed @decorateIf( unittest.expectedFailure, lambda params: ( params["op"].name == "sub" and params["dtype"] in [torch.float16, torch.float32] and params["device"] == "cpu" and params["layout"] == torch.sparse_csr ) ) # Result is just wrong; production logic should be fixed @decorateIf( unittest.expectedFailure, lambda params: ( params["op"].name == "eq" and params["dtype"] == torch.float64 and params["device"] == "cpu" and params["layout"] == torch.sparse_csr ) ) def test_binary_core(self, device, dtype, op, layout): self._test_unary_binary_equality(device, dtype, op, layout) @ops(mt_reduction_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES) # type: ignore[arg-type] @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr]) def test_reduction_all(self, device, dtype, op, layout): # argmin and argmax are not currently supported for torch.sparse_csr if op.name in {"argmin", "argmax"} and layout == torch.sparse_csr: return self._test_reduction_equality(device, dtype, op, layout) only_for = ("cpu", "cuda") instantiate_device_type_tests(TestOperators, globals(), only_for=only_for) instantiate_device_type_tests(TestBasics, globals(), only_for=only_for) instantiate_parametrized_tests(TestUnary) instantiate_parametrized_tests(TestBinary) instantiate_parametrized_tests(TestReductions) if __name__ == '__main__': run_tests()