1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: masked operators"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 6*da0073e9SAndroid Build Coastguard Worker decorateIf, 7*da0073e9SAndroid Build Coastguard Worker TestCase, 8*da0073e9SAndroid Build Coastguard Worker run_tests, 9*da0073e9SAndroid Build Coastguard Worker make_tensor, 10*da0073e9SAndroid Build Coastguard Worker parametrize, 11*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 12*da0073e9SAndroid Build Coastguard Worker) 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 14*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 15*da0073e9SAndroid Build Coastguard Worker ops, 16*da0073e9SAndroid Build Coastguard Worker) 17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import ( 18*da0073e9SAndroid Build Coastguard Worker SampleInput, 19*da0073e9SAndroid Build Coastguard Worker binary_ufuncs, 20*da0073e9SAndroid Build Coastguard Worker reduction_ops, 21*da0073e9SAndroid Build Coastguard Worker unary_ufuncs, 22*da0073e9SAndroid Build Coastguard Worker) 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerfrom torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask 25*da0073e9SAndroid Build Coastguard Workerfrom torch.masked.maskedtensor.core import _masks_match, _tensors_match 26*da0073e9SAndroid Build Coastguard Workerfrom torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES 27*da0073e9SAndroid Build Coastguard Workerfrom torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES 28*da0073e9SAndroid Build Coastguard 
Workerfrom torch.masked.maskedtensor.reductions import REDUCE_NAMES 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Workerdef _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05): 32*da0073e9SAndroid Build Coastguard Worker mask = mt_result.get_mask() 33*da0073e9SAndroid Build Coastguard Worker mt_result_data = mt_result.get_data() 34*da0073e9SAndroid Build Coastguard Worker if mask.layout in {torch.sparse_coo, torch.sparse_csr}: 35*da0073e9SAndroid Build Coastguard Worker mask = mask.to_dense() 36*da0073e9SAndroid Build Coastguard Worker if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}: 37*da0073e9SAndroid Build Coastguard Worker mt_result_data = mt_result_data.to_dense() 38*da0073e9SAndroid Build Coastguard Worker a = mt_result_data.detach().masked_fill_(~mask, 0) 39*da0073e9SAndroid Build Coastguard Worker b = t_result.detach().masked_fill_(~mask, 0) 40*da0073e9SAndroid Build Coastguard Worker if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): 41*da0073e9SAndroid Build Coastguard Worker raise ValueError("The data in MaskedTensor a and Tensor b do not match") 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Workerdef _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08): 44*da0073e9SAndroid Build Coastguard Worker mt_data1 = mt1.get_data() 45*da0073e9SAndroid Build Coastguard Worker mt_data2 = mt2.get_data() 46*da0073e9SAndroid Build Coastguard Worker if mt_data1.layout != mt_data2.layout: 47*da0073e9SAndroid Build Coastguard Worker raise ValueError("mt1's data and mt2's data do not have the same layout. 
" 48*da0073e9SAndroid Build Coastguard Worker f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}") 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker mask = mt1.get_mask() 51*da0073e9SAndroid Build Coastguard Worker mask2 = mt2.get_mask() 52*da0073e9SAndroid Build Coastguard Worker if not _masks_match(mt1, mt2): 53*da0073e9SAndroid Build Coastguard Worker raise ValueError("mt1 and mt2 must have matching masks") 54*da0073e9SAndroid Build Coastguard Worker if mask.layout != mask2.layout: 55*da0073e9SAndroid Build Coastguard Worker raise ValueError("mt1's mask and mt2's mask do not have the same layout. " 56*da0073e9SAndroid Build Coastguard Worker f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}") 57*da0073e9SAndroid Build Coastguard Worker if mask.layout in {torch.sparse_coo, torch.sparse_csr}: 58*da0073e9SAndroid Build Coastguard Worker mask = mask.to_dense() 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}: 61*da0073e9SAndroid Build Coastguard Worker mt_data1 = mt_data1.to_dense() 62*da0073e9SAndroid Build Coastguard Worker mt_data2 = mt_data2.to_dense() 63*da0073e9SAndroid Build Coastguard Worker a = mt_data1.detach().masked_fill_(~mask, 0) 64*da0073e9SAndroid Build Coastguard Worker b = mt_data2.detach().masked_fill_(~mask, 0) 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): 67*da0073e9SAndroid Build Coastguard Worker raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match") 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Workerdef _compare_forward_backward(data, mask, fn): 70*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data, mask, requires_grad=True) 71*da0073e9SAndroid Build 
Coastguard Worker masked_res = fn(mt) 72*da0073e9SAndroid Build Coastguard Worker masked_res.sum().backward() 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() 75*da0073e9SAndroid Build Coastguard Worker tensor_res = fn(t) 76*da0073e9SAndroid Build Coastguard Worker tensor_res.sum().backward() 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(masked_res, tensor_res) 79*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mt.grad, t.grad, atol=1e-06) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Workerdef _create_random_mask(shape, device): 83*da0073e9SAndroid Build Coastguard Worker return make_tensor(shape, device=device, dtype=torch.bool) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Workerdef _generate_sample_data( 86*da0073e9SAndroid Build Coastguard Worker device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided 87*da0073e9SAndroid Build Coastguard Worker): 88*da0073e9SAndroid Build Coastguard Worker assert layout in { 89*da0073e9SAndroid Build Coastguard Worker torch.strided, 90*da0073e9SAndroid Build Coastguard Worker torch.sparse_coo, 91*da0073e9SAndroid Build Coastguard Worker torch.sparse_csr, 92*da0073e9SAndroid Build Coastguard Worker }, "Layout must be strided/sparse_coo/sparse_csr" 93*da0073e9SAndroid Build Coastguard Worker shapes = [ 94*da0073e9SAndroid Build Coastguard Worker [], 95*da0073e9SAndroid Build Coastguard Worker [2], 96*da0073e9SAndroid Build Coastguard Worker [3, 5], 97*da0073e9SAndroid Build Coastguard Worker [3, 2, 1, 2], 98*da0073e9SAndroid Build Coastguard Worker ] 99*da0073e9SAndroid Build Coastguard Worker inputs = [] 100*da0073e9SAndroid Build Coastguard Worker for s in shapes: 101*da0073e9SAndroid Build Coastguard Worker data = 
make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) # type: ignore[arg-type] 102*da0073e9SAndroid Build Coastguard Worker mask = _create_random_mask(s, device) 103*da0073e9SAndroid Build Coastguard Worker if layout == torch.sparse_coo: 104*da0073e9SAndroid Build Coastguard Worker mask = mask.to_sparse_coo().coalesce() 105*da0073e9SAndroid Build Coastguard Worker data = data.sparse_mask(mask).requires_grad_(requires_grad) 106*da0073e9SAndroid Build Coastguard Worker elif layout == torch.sparse_csr: 107*da0073e9SAndroid Build Coastguard Worker if data.ndim != 2 and mask.ndim != 2: 108*da0073e9SAndroid Build Coastguard Worker continue 109*da0073e9SAndroid Build Coastguard Worker mask = mask.to_sparse_csr() 110*da0073e9SAndroid Build Coastguard Worker data = data.sparse_mask(mask) 111*da0073e9SAndroid Build Coastguard Worker inputs.append(SampleInput(data, kwargs={"mask": mask})) 112*da0073e9SAndroid Build Coastguard Worker return inputs 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Workerdef _fix_fn_name(fn_name): 115*da0073e9SAndroid Build Coastguard Worker if fn_name[-1] == "_": 116*da0073e9SAndroid Build Coastguard Worker fn_name = fn_name[:-1] 117*da0073e9SAndroid Build Coastguard Worker return fn_name 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Workerclass TestBasics(TestCase): 121*da0073e9SAndroid Build Coastguard Worker def test_invalid_tensor_inputs(self, device): 122*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 4), device=device) 123*da0073e9SAndroid Build Coastguard Worker mask = _create_random_mask((3, 4), device=device) 124*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data, mask) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "data must be a Tensor"): 127*da0073e9SAndroid Build Coastguard Worker 
masked_tensor(mt, mask) 128*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "data must be a Tensor"): 129*da0073e9SAndroid Build Coastguard Worker masked_tensor(0, mask) 130*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "mask must be a Tensor"): 131*da0073e9SAndroid Build Coastguard Worker masked_tensor(data, mt) 132*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "mask must be a Tensor"): 133*da0073e9SAndroid Build Coastguard Worker masked_tensor(data, 0) 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker def test_diff_layouts(self, device): 136*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 4), device=device).to_sparse_coo() 137*da0073e9SAndroid Build Coastguard Worker mask = _create_random_mask((3, 4), device=device) 138*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "data and mask must have the same layout"): 139*da0073e9SAndroid Build Coastguard Worker masked_tensor(data, mask) 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker def test_diff_dim(self, device): 142*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 4, 5), device=device) 143*da0073e9SAndroid Build Coastguard Worker mask = _create_random_mask((3, 4), device=device) 144*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "data.dim\\(\\) must equal mask.dim\\(\\)"): 145*da0073e9SAndroid Build Coastguard Worker masked_tensor(data, mask) 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker def test_diff_sizes(self, device): 148*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 4), device=device) 149*da0073e9SAndroid Build Coastguard Worker mask = _create_random_mask((3, 3), device=device) 150*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, 
"data.size\\(\\) must equal mask.size\\(\\)"): 151*da0073e9SAndroid Build Coastguard Worker masked_tensor(data, mask) 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker def test_grad_warning(self, device): 154*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 4), device=device, requires_grad=True) 155*da0073e9SAndroid Build Coastguard Worker mask = _create_random_mask((3, 4), device=device) 156*da0073e9SAndroid Build Coastguard Worker msg = "It is not recommended to create a MaskedTensor with a tensor that requires_grad." 157*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, msg): 158*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data, mask) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker def test_add(self, device): 161*da0073e9SAndroid Build Coastguard Worker data = torch.arange(5.0, device=device) 162*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([True, True, False, True, False], device=device) 163*da0073e9SAndroid Build Coastguard Worker m0 = masked_tensor(data, mask) 164*da0073e9SAndroid Build Coastguard Worker m1 = masked_tensor(data, ~mask) 165*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Input masks must match."): 166*da0073e9SAndroid Build Coastguard Worker m0 + m1 167*da0073e9SAndroid Build Coastguard Worker _compare_mts(m0 + m0, masked_tensor(torch.tensor([0., 2, 0, 6, 0], device=device), mask)) 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker def test_softmax(self, device): 170*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 4), device=device) * 0.1 171*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor( 172*da0073e9SAndroid Build Coastguard Worker [ 173*da0073e9SAndroid Build Coastguard Worker [True, True, True, False], 174*da0073e9SAndroid Build Coastguard Worker [False, True, False, True], 
175*da0073e9SAndroid Build Coastguard Worker [True, True, False, False], 176*da0073e9SAndroid Build Coastguard Worker ], 177*da0073e9SAndroid Build Coastguard Worker device=device 178*da0073e9SAndroid Build Coastguard Worker ) 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1)) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker def test_where(self, device): 183*da0073e9SAndroid Build Coastguard Worker data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device) 184*da0073e9SAndroid Build Coastguard Worker mask = data < 0 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker mx = masked_tensor(data, mask, requires_grad=True) 187*da0073e9SAndroid Build Coastguard Worker my = masked_tensor(torch.ones_like(data), ~mask, requires_grad=True) 188*da0073e9SAndroid Build Coastguard Worker masked_res = torch.where(mask, torch.exp(mx), my) 189*da0073e9SAndroid Build Coastguard Worker masked_res.sum().backward() 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker x = data.detach().clone().requires_grad_() 192*da0073e9SAndroid Build Coastguard Worker y = torch.ones_like(x, device=device, requires_grad=True) 193*da0073e9SAndroid Build Coastguard Worker tensor_res = torch.where(mask, torch.exp(x), y) 194*da0073e9SAndroid Build Coastguard Worker tensor_res.sum().backward() 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(masked_res, tensor_res) 197*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mx.grad, x.grad) 198*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(my.grad, y.grad) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker def test_unfold(self, device): 201*da0073e9SAndroid Build Coastguard Worker data = torch.rand(5, 5, 
device=device) 202*da0073e9SAndroid Build Coastguard Worker mask = torch.rand(5, 5, device=device) > 0.5 203*da0073e9SAndroid Build Coastguard Worker _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2)) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker def test_nn_unfold(self, device): 206*da0073e9SAndroid Build Coastguard Worker data = torch.rand(2, 5, 3, 4, device=device) 207*da0073e9SAndroid Build Coastguard Worker mask = torch.rand(2, 5, 3, 4, device=device) > 0.5 208*da0073e9SAndroid Build Coastguard Worker _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3))) 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker def test_stack(self, device): 211*da0073e9SAndroid Build Coastguard Worker masked_tensors = [ 212*da0073e9SAndroid Build Coastguard Worker masked_tensor( 213*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 5, 3, 4, device=device), 214*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 5, 3, 4, device=device) > 0.5, 215*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 216*da0073e9SAndroid Build Coastguard Worker ) for _ in range(3) 217*da0073e9SAndroid Build Coastguard Worker ] 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors] 220*da0073e9SAndroid Build Coastguard Worker masked_res = torch.stack(masked_tensors) 221*da0073e9SAndroid Build Coastguard Worker tensor_res = torch.stack(data_tensors) 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker masked_res.sum().backward() 224*da0073e9SAndroid Build Coastguard Worker tensor_res.sum().backward() 225*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(masked_res, tensor_res) 226*da0073e9SAndroid Build Coastguard Worker for mt, t in zip(masked_tensors, data_tensors): 
227*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mt.grad, t.grad, atol=1e-06) 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker def test_to_sparse(self, device): 230*da0073e9SAndroid Build Coastguard Worker for sample in _generate_sample_data(device=device): 231*da0073e9SAndroid Build Coastguard Worker data = sample.input 232*da0073e9SAndroid Build Coastguard Worker mask = sample.kwargs["mask"] 233*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data.clone().detach(), mask, requires_grad=True) 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker sparse_mt = mt.to_sparse() 236*da0073e9SAndroid Build Coastguard Worker data.to_sparse().to_dense().sum().backward() 237*da0073e9SAndroid Build Coastguard Worker sparse_mt.to_dense().sum().backward() 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(sparse_mt, data) 240*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mt.grad, data.grad) 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker def test_to_dense(self, device): 243*da0073e9SAndroid Build Coastguard Worker samples = _generate_sample_data( 244*da0073e9SAndroid Build Coastguard Worker device=device, 245*da0073e9SAndroid Build Coastguard Worker layout=torch.sparse_coo 246*da0073e9SAndroid Build Coastguard Worker ) + _generate_sample_data(device=device, layout=torch.sparse_csr) 247*da0073e9SAndroid Build Coastguard Worker for sample in samples: 248*da0073e9SAndroid Build Coastguard Worker data = sample.input 249*da0073e9SAndroid Build Coastguard Worker mask = sample.kwargs["mask"] 250*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data, mask, requires_grad=True) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker dense_data = data.to_dense().detach().clone().requires_grad_(True) 253*da0073e9SAndroid Build Coastguard Worker 
dense_mt = mt.to_dense() 254*da0073e9SAndroid Build Coastguard Worker dense_data.sum().backward() 255*da0073e9SAndroid Build Coastguard Worker dense_mt.sum().backward() 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(dense_mt, dense_data) 258*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mt.grad.to_dense(), dense_data.grad) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker def test_to_dense_and_sparse_coo(self, device): 261*da0073e9SAndroid Build Coastguard Worker for sample in _generate_sample_data(device=device, layout=torch.strided): 262*da0073e9SAndroid Build Coastguard Worker data = sample.input 263*da0073e9SAndroid Build Coastguard Worker mask = sample.kwargs["mask"] 264*da0073e9SAndroid Build Coastguard Worker ms = mask.to_sparse_coo().coalesce() 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data, mask, requires_grad=True) 267*da0073e9SAndroid Build Coastguard Worker mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True) 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker converted = mt.to_sparse().to_dense() 270*da0073e9SAndroid Build Coastguard Worker converted.sum().backward() 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker converted2 = mts.to_dense() 273*da0073e9SAndroid Build Coastguard Worker converted2.sum().backward() 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker _compare_mts(converted, converted2) 276*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, mts.grad.to_dense()) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker def test_to_dense_and_sparse_csr(self, device): 279*da0073e9SAndroid Build Coastguard Worker for sample in _generate_sample_data(device=device, layout=torch.strided): 
280*da0073e9SAndroid Build Coastguard Worker data = sample.input 281*da0073e9SAndroid Build Coastguard Worker mask = sample.kwargs["mask"] 282*da0073e9SAndroid Build Coastguard Worker if data.ndim != 2: 283*da0073e9SAndroid Build Coastguard Worker continue 284*da0073e9SAndroid Build Coastguard Worker ms = mask.to_sparse_csr() 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data, mask, requires_grad=True) 287*da0073e9SAndroid Build Coastguard Worker mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True) 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker converted = mt.to_sparse_csr().to_dense() 290*da0073e9SAndroid Build Coastguard Worker converted.sum().backward() 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker converted2 = mts.to_dense() 293*da0073e9SAndroid Build Coastguard Worker converted2.sum().backward() 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker _compare_mts(converted, converted2) 296*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, mts.grad.to_dense()) 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker def test_invalid_sparse_layout(self, device): 299*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 4), device=device).to_sparse_csc() 300*da0073e9SAndroid Build Coastguard Worker mask = _create_random_mask((3, 4), device=device).to_sparse_csc() 301*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "data layout of torch.sparse_csc is not supported"): 302*da0073e9SAndroid Build Coastguard Worker masked_tensor(data, mask) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker def test_invalid_sparse_coo_values(self, device): 305*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([3, 4, 5], dtype=torch.float32) 
306*da0073e9SAndroid Build Coastguard Worker i1 = torch.tensor([[0, 1, 1], [2, 0, 2]]) 307*da0073e9SAndroid Build Coastguard Worker i2 = torch.tensor([[0, 1, 1], [2, 1, 2]]) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_coo_tensor(i1, v, (2, 4), device=device) 310*da0073e9SAndroid Build Coastguard Worker mask = torch.sparse_coo_tensor(i2, torch.tensor([True, True, True]), (2, 4), device=device) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker msg = "data and mask are both sparse COO tensors but do not have the same indices." 313*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, msg): 314*da0073e9SAndroid Build Coastguard Worker masked_tensor(t, mask) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker def test_invalid_sparse_csr_values(self, device): 317*da0073e9SAndroid Build Coastguard Worker crow_indices1 = [0, 2, 3] 318*da0073e9SAndroid Build Coastguard Worker crow_indices2 = [0, 1, 3] 319*da0073e9SAndroid Build Coastguard Worker col_indices1 = [0, 1, 2] 320*da0073e9SAndroid Build Coastguard Worker col_indices2 = [1, 2, 3] 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker values = [2, 3, 4] 323*da0073e9SAndroid Build Coastguard Worker mask_values = [True, True, True] 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker t1 = torch.sparse_csr_tensor( 326*da0073e9SAndroid Build Coastguard Worker torch.tensor(crow_indices1, dtype=torch.int64), 327*da0073e9SAndroid Build Coastguard Worker torch.tensor(col_indices1, dtype=torch.int64), 328*da0073e9SAndroid Build Coastguard Worker torch.tensor(values), 329*da0073e9SAndroid Build Coastguard Worker size=(2, 4) 330*da0073e9SAndroid Build Coastguard Worker ) 331*da0073e9SAndroid Build Coastguard Worker mask1 = torch.sparse_csr_tensor( 332*da0073e9SAndroid Build Coastguard Worker 
torch.tensor(crow_indices2, dtype=torch.int64), 333*da0073e9SAndroid Build Coastguard Worker torch.tensor(col_indices1, dtype=torch.int64), 334*da0073e9SAndroid Build Coastguard Worker torch.tensor(mask_values), 335*da0073e9SAndroid Build Coastguard Worker dtype=torch.bool, 336*da0073e9SAndroid Build Coastguard Worker size=(2, 4), 337*da0073e9SAndroid Build Coastguard Worker ) 338*da0073e9SAndroid Build Coastguard Worker t2 = torch.sparse_csr_tensor( 339*da0073e9SAndroid Build Coastguard Worker torch.tensor(crow_indices2, dtype=torch.int64), 340*da0073e9SAndroid Build Coastguard Worker torch.tensor(col_indices1, dtype=torch.int64), 341*da0073e9SAndroid Build Coastguard Worker torch.tensor(values), 342*da0073e9SAndroid Build Coastguard Worker size=(2, 4), 343*da0073e9SAndroid Build Coastguard Worker ) 344*da0073e9SAndroid Build Coastguard Worker mask2 = torch.sparse_csr_tensor( 345*da0073e9SAndroid Build Coastguard Worker torch.tensor(crow_indices2, dtype=torch.int64), 346*da0073e9SAndroid Build Coastguard Worker torch.tensor(col_indices2, dtype=torch.int64), 347*da0073e9SAndroid Build Coastguard Worker torch.tensor(mask_values), 348*da0073e9SAndroid Build Coastguard Worker dtype=torch.bool, 349*da0073e9SAndroid Build Coastguard Worker size=(2, 4), 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker msg = "data and mask are both sparse CSR tensors but do not share either crow or col indices." 
353*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, msg): 354*da0073e9SAndroid Build Coastguard Worker masked_tensor(t1, mask1) 355*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, msg): 356*da0073e9SAndroid Build Coastguard Worker masked_tensor(t2, mask2) 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker def test_contiguous(self, device): 359*da0073e9SAndroid Build Coastguard Worker data = torch.randn((3, 3), device=device) 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker contiguous_data = data.clone() 362*da0073e9SAndroid Build Coastguard Worker mask1 = (contiguous_data > 0).bool() 363*da0073e9SAndroid Build Coastguard Worker not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2)) 364*da0073e9SAndroid Build Coastguard Worker mask2 = (not_contiguous_data > 0).bool() 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker contiguous_mt = masked_tensor(contiguous_data, mask1) 367*da0073e9SAndroid Build Coastguard Worker not_contiguous_mt = masked_tensor(not_contiguous_data, mask2) 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker contiguous_mt_sparse = masked_tensor( 370*da0073e9SAndroid Build Coastguard Worker contiguous_data.to_sparse_coo(), mask1.to_sparse_coo() 371*da0073e9SAndroid Build Coastguard Worker ) 372*da0073e9SAndroid Build Coastguard Worker not_contiguous_mt_sparse = masked_tensor( 373*da0073e9SAndroid Build Coastguard Worker not_contiguous_data.to_sparse_coo(), mask2.to_sparse_coo() 374*da0073e9SAndroid Build Coastguard Worker ) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker self.assertEqual(contiguous_data.is_contiguous(), True) 377*da0073e9SAndroid Build Coastguard Worker self.assertEqual(not_contiguous_data.is_contiguous(), False) 378*da0073e9SAndroid Build Coastguard Worker 
379*da0073e9SAndroid Build Coastguard Worker self.assertEqual(contiguous_mt.is_contiguous(), True) 380*da0073e9SAndroid Build Coastguard Worker self.assertEqual(not_contiguous_mt.is_contiguous(), False) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker error_msg = "MaskedTensors with sparse data do not have is_contiguous" 383*da0073e9SAndroid Build Coastguard Worker for t in [contiguous_mt_sparse, not_contiguous_mt_sparse]: 384*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, error_msg): 385*da0073e9SAndroid Build Coastguard Worker t.is_contiguous() 386*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, error_msg): 387*da0073e9SAndroid Build Coastguard Worker t.contiguous() 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker now_contiguous_mt = not_contiguous_mt.contiguous() 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker _compare_mts(not_contiguous_mt, now_contiguous_mt) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker self.assertEqual(now_contiguous_mt.is_contiguous(), True) 394*da0073e9SAndroid Build Coastguard Worker self.assertEqual(now_contiguous_mt.get_data().is_contiguous(), True) 395*da0073e9SAndroid Build Coastguard Worker self.assertEqual(now_contiguous_mt.is_contiguous(), True) 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Workerclass TestUnary(TestCase): 398*da0073e9SAndroid Build Coastguard Worker def _get_test_data(self, fn_name): 399*da0073e9SAndroid Build Coastguard Worker data = torch.randn(10, 10) 400*da0073e9SAndroid Build Coastguard Worker mask = torch.rand(10, 10) > 0.5 401*da0073e9SAndroid Build Coastguard Worker fn_name = _fix_fn_name(fn_name) 402*da0073e9SAndroid Build Coastguard Worker if fn_name in ["log", "log10", "log1p", "log2", "sqrt"]: 403*da0073e9SAndroid Build Coastguard 
Worker data = data.mul(0.5).abs() 404*da0073e9SAndroid Build Coastguard Worker if fn_name in ["rsqrt"]: 405*da0073e9SAndroid Build Coastguard Worker data = data.abs() + 1 # Void division by zero 406*da0073e9SAndroid Build Coastguard Worker if fn_name in ["acos", "arccos", "asin", "arcsin", "logit"]: 407*da0073e9SAndroid Build Coastguard Worker data = data.abs().mul(0.5).clamp(0, 1) 408*da0073e9SAndroid Build Coastguard Worker if fn_name in ["atanh", "arctanh", "erfinv"]: 409*da0073e9SAndroid Build Coastguard Worker data = data.mul(0.5).clamp(-1, 1) 410*da0073e9SAndroid Build Coastguard Worker if fn_name in ["acosh", "arccosh"]: 411*da0073e9SAndroid Build Coastguard Worker data = data.abs() + 1 412*da0073e9SAndroid Build Coastguard Worker if fn_name in ["bitwise_not"]: 413*da0073e9SAndroid Build Coastguard Worker data = data.mul(128).to(torch.int8) 414*da0073e9SAndroid Build Coastguard Worker return data, mask 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker def _get_sample_kwargs(self, fn_name): 417*da0073e9SAndroid Build Coastguard Worker fn_name = _fix_fn_name(fn_name) 418*da0073e9SAndroid Build Coastguard Worker kwargs = {} 419*da0073e9SAndroid Build Coastguard Worker if fn_name in ["clamp", "clip"]: 420*da0073e9SAndroid Build Coastguard Worker kwargs["min"] = -0.5 421*da0073e9SAndroid Build Coastguard Worker kwargs["max"] = 0.5 422*da0073e9SAndroid Build Coastguard Worker return kwargs 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker def _get_sample_args(self, fn_name, data, mask): 425*da0073e9SAndroid Build Coastguard Worker fn_name = _fix_fn_name(fn_name) 426*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(data, mask) 427*da0073e9SAndroid Build Coastguard Worker t_args = [data] 428*da0073e9SAndroid Build Coastguard Worker mt_args = [mt] 429*da0073e9SAndroid Build Coastguard Worker if fn_name in ["pow"]: 430*da0073e9SAndroid Build Coastguard Worker t_args += [2.0] 
431*da0073e9SAndroid Build Coastguard Worker mt_args += [2.0] 432*da0073e9SAndroid Build Coastguard Worker return t_args, mt_args 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker @parametrize("fn", NATIVE_UNARY_FNS) 435*da0073e9SAndroid Build Coastguard Worker def test_unary(self, fn): 436*da0073e9SAndroid Build Coastguard Worker torch.random.manual_seed(0) 437*da0073e9SAndroid Build Coastguard Worker fn_name = fn.__name__ 438*da0073e9SAndroid Build Coastguard Worker data, mask = self._get_test_data(fn_name) 439*da0073e9SAndroid Build Coastguard Worker kwargs = self._get_sample_kwargs(fn_name) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker t_args, mt_args = self._get_sample_args(fn_name, data, mask) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker mt_result = fn(*mt_args, **kwargs) 444*da0073e9SAndroid Build Coastguard Worker t_result = fn(*t_args, **kwargs) 445*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mt_result, t_result) 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker @parametrize("fn", NATIVE_INPLACE_UNARY_FNS) 448*da0073e9SAndroid Build Coastguard Worker def test_inplace_unary(self, fn): 449*da0073e9SAndroid Build Coastguard Worker torch.random.manual_seed(0) 450*da0073e9SAndroid Build Coastguard Worker fn_name = fn.__name__ 451*da0073e9SAndroid Build Coastguard Worker data, mask = self._get_test_data(fn_name) 452*da0073e9SAndroid Build Coastguard Worker kwargs = self._get_sample_kwargs(fn_name) 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker t_args, mt_args = self._get_sample_args(fn_name, data, mask) 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Worker mt_result = fn(*mt_args, **kwargs) 457*da0073e9SAndroid Build Coastguard Worker t_result = fn(*t_args, **kwargs) 458*da0073e9SAndroid Build Coastguard 
Worker _compare_mt_t(mt_result, t_result) 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Workerclass TestBinary(TestCase): 461*da0073e9SAndroid Build Coastguard Worker def _get_test_data(self, fn_name): 462*da0073e9SAndroid Build Coastguard Worker fn_name = _fix_fn_name(fn_name) 463*da0073e9SAndroid Build Coastguard Worker data0 = torch.randn(10, 10) 464*da0073e9SAndroid Build Coastguard Worker data1 = torch.randn(10, 10) 465*da0073e9SAndroid Build Coastguard Worker mask = torch.rand(10, 10) > 0.5 466*da0073e9SAndroid Build Coastguard Worker if fn_name in ["bitwise_and", "bitwise_or", "bitwise_xor"]: 467*da0073e9SAndroid Build Coastguard Worker data0 = data0.mul(128).to(torch.int8) 468*da0073e9SAndroid Build Coastguard Worker data1 = data1.mul(128).to(torch.int8) 469*da0073e9SAndroid Build Coastguard Worker if fn_name in ["bitwise_left_shift", "bitwise_right_shift"]: 470*da0073e9SAndroid Build Coastguard Worker data0 = data0.abs().to(torch.int64) 471*da0073e9SAndroid Build Coastguard Worker data1 = data1.abs().to(torch.int64) 472*da0073e9SAndroid Build Coastguard Worker return data0, data1, mask 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker def _get_sample_kwargs(self, fn_name): 475*da0073e9SAndroid Build Coastguard Worker fn_name = _fix_fn_name(fn_name) 476*da0073e9SAndroid Build Coastguard Worker kwargs = {} 477*da0073e9SAndroid Build Coastguard Worker return kwargs 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker def _yield_sample_args(self, fn_name, data0, data1, mask): 480*da0073e9SAndroid Build Coastguard Worker """ Returns two sets of Tensor and MaskedTensor args for a binary function to compute. 
481*da0073e9SAndroid Build Coastguard Worker Tensor args are all the same (just the two provided data tensors), 482*da0073e9SAndroid Build Coastguard Worker while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor) 483*da0073e9SAndroid Build Coastguard Worker """ 484*da0073e9SAndroid Build Coastguard Worker fn_name = _fix_fn_name(fn_name) 485*da0073e9SAndroid Build Coastguard Worker mt0 = masked_tensor(data0, mask) 486*da0073e9SAndroid Build Coastguard Worker mt1 = masked_tensor(data1, mask) 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker t_args = [data0, data1] 489*da0073e9SAndroid Build Coastguard Worker mt_args = [mt0, mt1] 490*da0073e9SAndroid Build Coastguard Worker yield t_args, mt_args 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker t_args = [data0, data1] 493*da0073e9SAndroid Build Coastguard Worker mt_args = [mt0, data1] 494*da0073e9SAndroid Build Coastguard Worker yield t_args, mt_args 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker @parametrize("fn", NATIVE_BINARY_FNS) 497*da0073e9SAndroid Build Coastguard Worker def test_binary(self, fn): 498*da0073e9SAndroid Build Coastguard Worker torch.random.manual_seed(0) 499*da0073e9SAndroid Build Coastguard Worker fn_name = fn.__name__ 500*da0073e9SAndroid Build Coastguard Worker data0, data1, mask = self._get_test_data(fn_name) 501*da0073e9SAndroid Build Coastguard Worker kwargs = self._get_sample_kwargs(fn_name) 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask): 504*da0073e9SAndroid Build Coastguard Worker mt_result = fn(*mt_args, **kwargs) 505*da0073e9SAndroid Build Coastguard Worker t_result = fn(*t_args, **kwargs) 506*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mt_result, t_result) 507*da0073e9SAndroid Build Coastguard 
Worker 508*da0073e9SAndroid Build Coastguard Worker @parametrize("fn", NATIVE_INPLACE_BINARY_FNS) 509*da0073e9SAndroid Build Coastguard Worker def test_inplace_binary(self, fn): 510*da0073e9SAndroid Build Coastguard Worker torch.random.manual_seed(0) 511*da0073e9SAndroid Build Coastguard Worker fn_name = fn.__name__ 512*da0073e9SAndroid Build Coastguard Worker data0, data1, mask = self._get_test_data(fn_name) 513*da0073e9SAndroid Build Coastguard Worker kwargs = self._get_sample_kwargs(fn_name) 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask): 516*da0073e9SAndroid Build Coastguard Worker mt_result = fn(*mt_args, **kwargs) 517*da0073e9SAndroid Build Coastguard Worker t_result = fn(*t_args, **kwargs) 518*da0073e9SAndroid Build Coastguard Worker _compare_mt_t(mt_result, t_result) 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker @parametrize("fn_name", ["add", "add_"]) 521*da0073e9SAndroid Build Coastguard Worker def test_masks_match(self, fn_name): 522*da0073e9SAndroid Build Coastguard Worker torch.random.manual_seed(0) 523*da0073e9SAndroid Build Coastguard Worker fn = getattr(torch.ops.aten, fn_name) 524*da0073e9SAndroid Build Coastguard Worker data0, data1, mask = self._get_test_data(fn_name) 525*da0073e9SAndroid Build Coastguard Worker mask0 = mask 526*da0073e9SAndroid Build Coastguard Worker mask1 = torch.rand(mask.size()) > 0.5 527*da0073e9SAndroid Build Coastguard Worker mt0 = masked_tensor(data0, mask0) 528*da0073e9SAndroid Build Coastguard Worker mt1 = masked_tensor(data1, mask1) 529*da0073e9SAndroid Build Coastguard Worker try: 530*da0073e9SAndroid Build Coastguard Worker fn(mt0, mt1) 531*da0073e9SAndroid Build Coastguard Worker raise AssertionError 532*da0073e9SAndroid Build Coastguard Worker except ValueError as e: 533*da0073e9SAndroid Build Coastguard Worker assert ( 534*da0073e9SAndroid Build 
Coastguard Worker "Input masks must match. If you need support for this, please open an issue on Github." 535*da0073e9SAndroid Build Coastguard Worker == str(e) 536*da0073e9SAndroid Build Coastguard Worker ) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Workerclass TestReductions(TestCase): 539*da0073e9SAndroid Build Coastguard Worker def test_max_not_implemented(self): 540*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 541*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 542*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 543*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "torch._ops.aten.max.default"): 544*da0073e9SAndroid Build Coastguard Worker mt.max() 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker def test_sum(self): 547*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]]) 548*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 549*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 550*da0073e9SAndroid Build Coastguard Worker _compare_mts(masked_tensor(torch.tensor(17.0), torch.tensor(True)), mt.sum()) 551*da0073e9SAndroid Build Coastguard Worker _compare_mts( 552*da0073e9SAndroid Build Coastguard Worker masked_tensor( 553*da0073e9SAndroid Build Coastguard Worker torch.tensor([0.0, 4.0, 1.0, 13]), 554*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, False, True]), 555*da0073e9SAndroid Build Coastguard Worker ), 556*da0073e9SAndroid Build Coastguard Worker mt.sum(dim=0), 557*da0073e9SAndroid Build Coastguard Worker ) 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker def test_sum_grad(self): 560*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 
5.0]]) 561*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 562*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=True) 563*da0073e9SAndroid Build Coastguard Worker mt.sum().backward() 564*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, masked_tensor(torch.tensor(1.0).expand_as(m), m)) 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker def test_mean(self): 567*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]]) 568*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 569*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 570*da0073e9SAndroid Build Coastguard Worker _compare_mts(masked_tensor(torch.tensor(2.5), torch.tensor(True)), mt.mean()) 571*da0073e9SAndroid Build Coastguard Worker _compare_mts( 572*da0073e9SAndroid Build Coastguard Worker masked_tensor( 573*da0073e9SAndroid Build Coastguard Worker torch.tensor([0.0, 4.0, 1.0, 3]), 574*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, False, True]), 575*da0073e9SAndroid Build Coastguard Worker ), 576*da0073e9SAndroid Build Coastguard Worker mt.mean(dim=0), 577*da0073e9SAndroid Build Coastguard Worker ) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker """ 580*da0073e9SAndroid Build Coastguard Worker The following block of tests "test_mean_grad_case_1[a through e] are used to test the functionality of 581*da0073e9SAndroid Build Coastguard Worker the two different ways of constructing MaskedTensors: 582*da0073e9SAndroid Build Coastguard Worker masked_tensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf 583*da0073e9SAndroid Build Coastguard Worker as_masked_tensor(data, mask) -- differentiable constructor 584*da0073e9SAndroid Build Coastguard Worker 
585*da0073e9SAndroid Build Coastguard Worker Like torch.tensor(data), masked_tensor(data, mask) will provide a UserWarning if data.requires_grad=True 586*da0073e9SAndroid Build Coastguard Worker as_masked_tensor does not take in requires_grad -- it just takes on the requires_grad from data 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker Assuming mt.mean().backward() is run after each constructor: 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker Case 1a: 593*da0073e9SAndroid Build Coastguard Worker values.requires_grad = True 594*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(values, mask, requires_grad=True) 595*da0073e9SAndroid Build Coastguard Worker yields 596*da0073e9SAndroid Build Coastguard Worker - Provide a UserWarning because values.requires_grad=True 597*da0073e9SAndroid Build Coastguard Worker - values.grad = None 598*da0073e9SAndroid Build Coastguard Worker - mt.grad is a MaskedTensor with the correct gradient 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Worker Case 1b: 601*da0073e9SAndroid Build Coastguard Worker values.requires_grad = False 602*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(values, mask, requires_grad=True) 603*da0073e9SAndroid Build Coastguard Worker yields 604*da0073e9SAndroid Build Coastguard Worker - values.grad = None 605*da0073e9SAndroid Build Coastguard Worker - mt.grad is a MaskedTensor with the correct gradient 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker Case 2a/2b: 608*da0073e9SAndroid Build Coastguard Worker values.requires_grad = True/False 609*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(values, mask, requires_grad=False) 
610*da0073e9SAndroid Build Coastguard Worker 611*da0073e9SAndroid Build Coastguard Worker will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn" 612*da0073e9SAndroid Build Coastguard Worker as expected. When values.requires_grad=True, we will also get a UserWarning 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Worker Case 3a: 615*da0073e9SAndroid Build Coastguard Worker values.requires_grad = True 616*da0073e9SAndroid Build Coastguard Worker mt = as_masked_tensor(values, mask) 617*da0073e9SAndroid Build Coastguard Worker yields 618*da0073e9SAndroid Build Coastguard Worker - values.grad is a MaskedTensor with the correct gradient 619*da0073e9SAndroid Build Coastguard Worker - mt.grad is None and gives a UserWarning that 620*da0073e9SAndroid Build Coastguard Worker "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad" 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker Case 3b: 623*da0073e9SAndroid Build Coastguard Worker values.requires_grad = False 624*da0073e9SAndroid Build Coastguard Worker mt = as_masked_tensor(values, mask) 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn" 627*da0073e9SAndroid Build Coastguard Worker as expected. 
628*da0073e9SAndroid Build Coastguard Worker """ 629*da0073e9SAndroid Build Coastguard Worker def test_mean_grad_case_1a(self): 630*da0073e9SAndroid Build Coastguard Worker """ values.requires_grad = True 631*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(values, mask, requires_grad=True) 632*da0073e9SAndroid Build Coastguard Worker """ 633*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) 634*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 635*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"): 636*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=True) 637*da0073e9SAndroid Build Coastguard Worker mt.mean().backward() 638*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(d.grad) 639*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) 640*da0073e9SAndroid Build Coastguard Worker 641*da0073e9SAndroid Build Coastguard Worker def test_mean_grad_case_1b(self): 642*da0073e9SAndroid Build Coastguard Worker """ values.requires_grad = False 643*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(values, mask, requires_grad=True) 644*da0073e9SAndroid Build Coastguard Worker """ 645*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 646*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 647*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=True) 648*da0073e9SAndroid Build Coastguard Worker mt.mean().backward() 649*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(d.grad) 650*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) 651*da0073e9SAndroid Build Coastguard 
Worker 652*da0073e9SAndroid Build Coastguard Worker def test_mean_grad_case_1c(self): 653*da0073e9SAndroid Build Coastguard Worker """ values.requires_grad = True 654*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(values, mask, requires_grad=False) 655*da0073e9SAndroid Build Coastguard Worker """ 656*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) 657*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 658*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"): 659*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=False) 660*da0073e9SAndroid Build Coastguard Worker result = mt.mean() 661*da0073e9SAndroid Build Coastguard Worker msg = "element 0 of tensors does not require grad and does not have a grad_fn" 662*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 663*da0073e9SAndroid Build Coastguard Worker result.backward() 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker 666*da0073e9SAndroid Build Coastguard Worker def test_mean_grad_case_1d(self): 667*da0073e9SAndroid Build Coastguard Worker """ values.requires_grad = False 668*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(values, mask, requires_grad=False) 669*da0073e9SAndroid Build Coastguard Worker """ 670*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 671*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 672*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=False) 673*da0073e9SAndroid Build Coastguard Worker result = mt.mean() 674*da0073e9SAndroid Build Coastguard Worker msg = "element 0 of tensors does not require grad and does not have a grad_fn" 675*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 676*da0073e9SAndroid Build Coastguard Worker result.backward() 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker def test_mean_grad_case_1e(self): 679*da0073e9SAndroid Build Coastguard Worker """ values.requires_grad = True 680*da0073e9SAndroid Build Coastguard Worker mt = as_masked_tensor(values, mask) 681*da0073e9SAndroid Build Coastguard Worker """ 682*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) 683*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 684*da0073e9SAndroid Build Coastguard Worker mt = as_masked_tensor(d, m) 685*da0073e9SAndroid Build Coastguard Worker mt.mean().backward() 686*da0073e9SAndroid Build Coastguard Worker _compare_mts(d.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) 687*da0073e9SAndroid Build Coastguard Worker msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. 
Its .grad" 688*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, msg): 689*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(mt.grad) 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker def test_mean_grad_case_1f(self): 692*da0073e9SAndroid Build Coastguard Worker """ values.requires_grad = False 693*da0073e9SAndroid Build Coastguard Worker mt = as_masked_tensor(values, mask) 694*da0073e9SAndroid Build Coastguard Worker """ 695*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 696*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 697*da0073e9SAndroid Build Coastguard Worker mt = as_masked_tensor(d, m) 698*da0073e9SAndroid Build Coastguard Worker result = mt.mean() 699*da0073e9SAndroid Build Coastguard Worker msg = "element 0 of tensors does not require grad and does not have a grad_fn" 700*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 701*da0073e9SAndroid Build Coastguard Worker result.backward() 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker def test_mean_dim_grad(self): 704*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 705*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, True, False], [False, True, False]]) 706*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=True) 707*da0073e9SAndroid Build Coastguard Worker mt.mean(1).sum().backward() 708*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m)) 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker def test_amax(self): 711*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]]) 712*da0073e9SAndroid Build Coastguard Worker m = 
torch.tensor([[True, False, False, True], [False, True, False, True]]) 713*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 714*da0073e9SAndroid Build Coastguard Worker _compare_mts(masked_tensor(torch.tensor(3.0), torch.tensor(True)), mt.amax()) 715*da0073e9SAndroid Build Coastguard Worker _compare_mts( 716*da0073e9SAndroid Build Coastguard Worker masked_tensor( 717*da0073e9SAndroid Build Coastguard Worker torch.tensor([0.0, -4.0, 1.0, 3]), 718*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, False, True]), 719*da0073e9SAndroid Build Coastguard Worker ), 720*da0073e9SAndroid Build Coastguard Worker mt.amax(dim=0), 721*da0073e9SAndroid Build Coastguard Worker ) 722*da0073e9SAndroid Build Coastguard Worker 723*da0073e9SAndroid Build Coastguard Worker def test_amax_grad(self): 724*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 725*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 726*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=True) 727*da0073e9SAndroid Build Coastguard Worker mt.amax().backward() 728*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m)) 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker def test_amin(self): 731*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]]) 732*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 733*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 734*da0073e9SAndroid Build Coastguard Worker _compare_mts(masked_tensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin()) 735*da0073e9SAndroid Build Coastguard Worker _compare_mts( 736*da0073e9SAndroid Build Coastguard Worker masked_tensor( 737*da0073e9SAndroid Build Coastguard Worker 
torch.tensor([0.0, -4.0, 1.0, -3]), 738*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, False, True]), 739*da0073e9SAndroid Build Coastguard Worker ), 740*da0073e9SAndroid Build Coastguard Worker mt.amin(dim=0), 741*da0073e9SAndroid Build Coastguard Worker ) 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker def test_amin_grad(self): 744*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 745*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 746*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=True) 747*da0073e9SAndroid Build Coastguard Worker mt.amin().backward() 748*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, masked_tensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m)) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker def test_prod(self): 751*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]]) 752*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 753*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 754*da0073e9SAndroid Build Coastguard Worker _compare_mts(masked_tensor(torch.tensor(0.0), torch.tensor(True)), mt.prod()) 755*da0073e9SAndroid Build Coastguard Worker _compare_mts( 756*da0073e9SAndroid Build Coastguard Worker masked_tensor( 757*da0073e9SAndroid Build Coastguard Worker torch.tensor([0.0, 4.0, 1.0, 0.0]), 758*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, False, True]), 759*da0073e9SAndroid Build Coastguard Worker ), 760*da0073e9SAndroid Build Coastguard Worker mt.prod(dim=0), 761*da0073e9SAndroid Build Coastguard Worker ) 762*da0073e9SAndroid Build Coastguard Worker 763*da0073e9SAndroid Build Coastguard Worker def test_prod_grad(self): 764*da0073e9SAndroid Build Coastguard 
Worker d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]]) 765*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 766*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m, requires_grad=True) 767*da0073e9SAndroid Build Coastguard Worker mt.prod().backward() 768*da0073e9SAndroid Build Coastguard Worker _compare_mts(mt.grad, masked_tensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m)) 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker def test_all(self): 771*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[True, True, False, False], [False, True, True, True]]) 772*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 773*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 774*da0073e9SAndroid Build Coastguard Worker _compare_mts(masked_tensor(torch.tensor(False), torch.tensor(True)), mt.all()) 775*da0073e9SAndroid Build Coastguard Worker _compare_mts( 776*da0073e9SAndroid Build Coastguard Worker masked_tensor( 777*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, True, False]), 778*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, False, True]), 779*da0073e9SAndroid Build Coastguard Worker ), 780*da0073e9SAndroid Build Coastguard Worker mt.all(dim=0), 781*da0073e9SAndroid Build Coastguard Worker ) 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, True, False], [False, True, False, False]]) 784*da0073e9SAndroid Build Coastguard Worker mt = masked_tensor(d, m) 785*da0073e9SAndroid Build Coastguard Worker _compare_mts( 786*da0073e9SAndroid Build Coastguard Worker masked_tensor( 787*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, False, True]), 788*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, True, True, False]), 789*da0073e9SAndroid Build 
Coastguard Worker ), 790*da0073e9SAndroid Build Coastguard Worker mt.all(dim=0), 791*da0073e9SAndroid Build Coastguard Worker ) 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker def test_grad_dtype(self): 794*da0073e9SAndroid Build Coastguard Worker d = torch.tensor([[True, True, False], [False, True, True]]) 795*da0073e9SAndroid Build Coastguard Worker m = torch.tensor([[True, False, False], [False, True, False]]) 796*da0073e9SAndroid Build Coastguard Worker msg = "Only Tensors of floating point and complex dtype can require gradients" 797*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 798*da0073e9SAndroid Build Coastguard Worker masked_tensor(d, m, requires_grad=True) 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker def test_any_true_dtype(self): 801*da0073e9SAndroid Build Coastguard Worker mt = torch.masked.MaskedTensor( 802*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 2), 803*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 2) > 0.5 804*da0073e9SAndroid Build Coastguard Worker ) 805*da0073e9SAndroid Build Coastguard Worker msg = "expected a boolean tensor" 806*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, msg): 807*da0073e9SAndroid Build Coastguard Worker mt._is_any_true() 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Worker def test__is_any_true(self): 810*da0073e9SAndroid Build Coastguard Worker mt = torch.masked.MaskedTensor( 811*da0073e9SAndroid Build Coastguard Worker torch.tensor([[True, True, False], [False, False, True]]), 812*da0073e9SAndroid Build Coastguard Worker torch.tensor([[True, False, False], [False, True, False]]), 813*da0073e9SAndroid Build Coastguard Worker ) 814*da0073e9SAndroid Build Coastguard Worker _compare_mts( 815*da0073e9SAndroid Build Coastguard Worker masked_tensor(torch.tensor(True), torch.tensor(True)), 
816*da0073e9SAndroid Build Coastguard Worker mt._is_any_true(), 817*da0073e9SAndroid Build Coastguard Worker ) 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker def test__is_any_true_false(self): 820*da0073e9SAndroid Build Coastguard Worker mt = torch.masked.MaskedTensor( 821*da0073e9SAndroid Build Coastguard Worker torch.tensor([[True, True, False], [False, False, True]]), 822*da0073e9SAndroid Build Coastguard Worker torch.tensor([[False, False, False], [False, False, False]]), 823*da0073e9SAndroid Build Coastguard Worker ) 824*da0073e9SAndroid Build Coastguard Worker _compare_mts( 825*da0073e9SAndroid Build Coastguard Worker masked_tensor(torch.tensor(False), torch.tensor(True),), 826*da0073e9SAndroid Build Coastguard Worker mt._is_any_true(), 827*da0073e9SAndroid Build Coastguard Worker ) 828*da0073e9SAndroid Build Coastguard Worker 829*da0073e9SAndroid Build Coastguard Worker def test_backward(self): 830*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/128557 831*da0073e9SAndroid Build Coastguard Worker with torch.autograd.detect_anomaly(): 832*da0073e9SAndroid Build Coastguard Worker mt = torch.masked.MaskedTensor( 833*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 2), 834*da0073e9SAndroid Build Coastguard Worker torch.rand(2, 2) > 0.5, 835*da0073e9SAndroid Build Coastguard Worker requires_grad=True 836*da0073e9SAndroid Build Coastguard Worker ) 837*da0073e9SAndroid Build Coastguard Worker mt.sum().backward() 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker 840*da0073e9SAndroid Build Coastguard Workerdef is_unary(op): 841*da0073e9SAndroid Build Coastguard Worker return op.name in UNARY_NAMES 842*da0073e9SAndroid Build Coastguard Worker 843*da0073e9SAndroid Build Coastguard Workerdef is_binary(op): 844*da0073e9SAndroid Build Coastguard Worker return op.name in BINARY_NAMES 845*da0073e9SAndroid Build Coastguard Worker 
def is_reduction(op):
    """Return whether OpInfo ``op`` is a reduction exercised by ``TestOperators``.

    ``op`` qualifies when its name is in MaskedTensor's ``REDUCE_NAMES`` and is
    not one of ``all``, ``mean``, ``std`` or ``var``, which this file keeps out
    of the generic OpInfo-driven reduction comparison.
    """
    return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"}


# OpInfo entries whose names MaskedTensor implements, grouped by op family.
mt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)]
mt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)]
mt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)]

# Dtypes the OpInfo-driven tests in TestOperators are restricted to.
MASKEDTENSOR_FLOAT_TYPES = {
    torch.float16,
    torch.float32,
    torch.float64,
}


class TestOperators(TestCase):
    """Check MaskedTensor ops against the same ops applied to plain tensors.

    Each test draws OpInfo sample inputs and wraps them in a MaskedTensor in the
    parametrized layout. The MaskedTensor result must agree with the plain
    tensor result on every unmasked element (checked by ``_compare_mt_t``).
    """

    def _convert_mt_args(self, args, mask, layout):
        """Wrap every tensor in ``args`` as a MaskedTensor sharing ``mask``.

        For a non-strided ``layout``, each tensor is first passed through
        ``sparse_mask(mask)``, so ``mask`` is expected to already be in that
        layout. Non-tensor arguments are returned unchanged.
        """
        return [
            masked_tensor(
                arg.sparse_mask(mask) if layout != torch.strided else arg, mask
            )
            if torch.is_tensor(arg)
            else arg
            for arg in args
        ]

    def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided):
        """Compare ``op`` on MaskedTensors with ``op`` on the plain sample tensors.

        Args:
            device, dtype: forwarded to ``op.sample_inputs``.
            op: OpInfo of a unary or binary ufunc supported by MaskedTensor.
            layout: ``torch.strided``, ``torch.sparse_coo`` or ``torch.sparse_csr``.
                The input, the mask and any tensor arguments are converted to
                this layout before wrapping.

        The mask is random (``_create_random_mask``) unless the sample supplies
        a ``"mask"`` kwarg. Samples that cannot be expressed here are skipped:
        non-2-D inputs or masks for CSR, and binary samples whose operands
        differ in shape. Binary ops are called without the sample's kwargs.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            data = sample.input
            # Copy kwargs so popping "mask" below does not mutate the SampleInput.
            sample_args, sample_kwargs = sample.args, dict(sample.kwargs)
            mask = (
                _create_random_mask(data.shape, device)
                if "mask" not in sample_kwargs
                else sample_kwargs.pop("mask")
            )

            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                data = data.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                if data.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                data = data.sparse_mask(mask)

            if is_binary(op):
                # Binary operations currently only support same size masks
                if data.shape != sample_args[0].shape:
                    continue
                # Binary operations also don't support kwargs right now
                sample_kwargs = {}

            mt = masked_tensor(data, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            # The reference result is computed on the original dense sample input.
            t_result = op(sample.input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

            # If the operation is binary, check that lhs = masked, rhs = regular tensor also works
            if is_binary(op) and layout == torch.strided:
                mt_result2 = op(mt, *sample_args, **sample_kwargs)
                _compare_mt_t(mt_result2, t_result)

    def _test_reduction_equality(self, device, dtype, op, layout=torch.strided):
        """Compare a MaskedTensor reduction with the plain reduction of the
        tensor built by ``_combine_input_and_mask(op.op, input, mask)``.

        Only each sample's input tensor is used; its args and kwargs are dropped.
        The following are skipped: 0-d or empty inputs, randomly drawn all-False
        masks, and (for CSR) non-2-D inputs.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            data = sample.input
            # Reduction operations don't support more advanced args/kwargs right now
            sample_args, sample_kwargs = (), {}

            if data.dim() == 0 or data.numel() == 0:
                continue

            mask = _create_random_mask(data.shape, device)

            if torch.count_nonzero(mask) == 0:
                continue

            # Reference input, built from the dense data and mask before any
            # sparse conversion below.
            tensor_input = _combine_input_and_mask(op.op, data, mask)
            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                data = data.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                if data.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                data = data.sparse_mask(mask)

            mt = masked_tensor(data, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            t_result = op(tensor_input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

    @ops(mt_unary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_unary_core(self, device, dtype, op, layout):
        # Skip tests that don't have len(kwargs) == 0
        skip_variants = {
            "decimals_0",
            "decimals_3",
            "decimals_neg_3",
        }
        if op.name == "round" and op.variant_test_name in skip_variants:
            return
        # Forward ``layout`` so the sparse_coo / sparse_csr variants actually
        # exercise sparse inputs instead of silently re-running the strided case.
        self._test_unary_binary_equality(device, dtype, op, layout)

    @ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    # FIXME:
    # Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "add" and
            params["dtype"] in [torch.float16, torch.float32] and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    # Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "sub" and
            params["dtype"] in [torch.float16, torch.float32] and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    # Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "eq" and
            params["dtype"] == torch.float64 and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    def test_binary_core(self, device, dtype, op, layout):
        self._test_unary_binary_equality(device, dtype, op, layout)

    @ops(mt_reduction_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_reduction_all(self, device, dtype, op, layout):
        # argmin and argmax are not currently supported for torch.sparse_csr
        if op.name in {"argmin", "argmax"} and layout == torch.sparse_csr:
            return

        self._test_reduction_equality(device, dtype, op, layout)


only_for = ("cpu", "cuda")

# Device-generic suites are expanded into per-device classes (CPU/CUDA only).
for _suite in (TestOperators, TestBasics):
    instantiate_device_type_tests(_suite, globals(), only_for=only_for)

# The remaining suites only need their @parametrize-d tests expanded.
for _suite in (TestUnary, TestBinary, TestReductions):
    instantiate_parametrized_tests(_suite)

# Remove the loop variable so test loaders that scan module globals for
# TestCase subclasses do not pick up a suite a second time under this name.
del _suite

if __name__ == '__main__':
    run_tests()