1# Owner(s): ["module: masked operators"] 2 3import torch 4import unittest 5from torch.testing._internal.common_utils import ( 6 decorateIf, 7 TestCase, 8 run_tests, 9 make_tensor, 10 parametrize, 11 instantiate_parametrized_tests, 12) 13from torch.testing._internal.common_device_type import ( 14 instantiate_device_type_tests, 15 ops, 16) 17from torch.testing._internal.common_methods_invocations import ( 18 SampleInput, 19 binary_ufuncs, 20 reduction_ops, 21 unary_ufuncs, 22) 23 24from torch.masked import as_masked_tensor, masked_tensor, _combine_input_and_mask 25from torch.masked.maskedtensor.core import _masks_match, _tensors_match 26from torch.masked.maskedtensor.unary import NATIVE_INPLACE_UNARY_FNS, NATIVE_UNARY_FNS, UNARY_NAMES 27from torch.masked.maskedtensor.binary import NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS, BINARY_NAMES 28from torch.masked.maskedtensor.reductions import REDUCE_NAMES 29 30 31def _compare_mt_t(mt_result, t_result, rtol=1e-05, atol=1e-05): 32 mask = mt_result.get_mask() 33 mt_result_data = mt_result.get_data() 34 if mask.layout in {torch.sparse_coo, torch.sparse_csr}: 35 mask = mask.to_dense() 36 if mt_result_data.layout in {torch.sparse_coo, torch.sparse_csr}: 37 mt_result_data = mt_result_data.to_dense() 38 a = mt_result_data.detach().masked_fill_(~mask, 0) 39 b = t_result.detach().masked_fill_(~mask, 0) 40 if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): 41 raise ValueError("The data in MaskedTensor a and Tensor b do not match") 42 43def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08): 44 mt_data1 = mt1.get_data() 45 mt_data2 = mt2.get_data() 46 if mt_data1.layout != mt_data2.layout: 47 raise ValueError("mt1's data and mt2's data do not have the same layout. 
" 48 f"mt1.get_data().layout = {mt_data1.layout} while mt2.get_data().layout = {mt_data2.layout}") 49 50 mask = mt1.get_mask() 51 mask2 = mt2.get_mask() 52 if not _masks_match(mt1, mt2): 53 raise ValueError("mt1 and mt2 must have matching masks") 54 if mask.layout != mask2.layout: 55 raise ValueError("mt1's mask and mt2's mask do not have the same layout. " 56 f"mt1.get_mask().layout = {mask.layout} while mt2.get_mask().layout = {mask2.layout}") 57 if mask.layout in {torch.sparse_coo, torch.sparse_csr}: 58 mask = mask.to_dense() 59 60 if mt_data1.layout in {torch.sparse_coo, torch.sparse_csr}: 61 mt_data1 = mt_data1.to_dense() 62 mt_data2 = mt_data2.to_dense() 63 a = mt_data1.detach().masked_fill_(~mask, 0) 64 b = mt_data2.detach().masked_fill_(~mask, 0) 65 66 if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): 67 raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match") 68 69def _compare_forward_backward(data, mask, fn): 70 mt = masked_tensor(data, mask, requires_grad=True) 71 masked_res = fn(mt) 72 masked_res.sum().backward() 73 74 t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() 75 tensor_res = fn(t) 76 tensor_res.sum().backward() 77 78 _compare_mt_t(masked_res, tensor_res) 79 _compare_mt_t(mt.grad, t.grad, atol=1e-06) 80 81 82def _create_random_mask(shape, device): 83 return make_tensor(shape, device=device, dtype=torch.bool) 84 85def _generate_sample_data( 86 device="cpu", dtype=torch.float, requires_grad=True, layout=torch.strided 87): 88 assert layout in { 89 torch.strided, 90 torch.sparse_coo, 91 torch.sparse_csr, 92 }, "Layout must be strided/sparse_coo/sparse_csr" 93 shapes = [ 94 [], 95 [2], 96 [3, 5], 97 [3, 2, 1, 2], 98 ] 99 inputs = [] 100 for s in shapes: 101 data = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) # type: ignore[arg-type] 102 mask = _create_random_mask(s, device) 103 if layout == torch.sparse_coo: 104 mask = mask.to_sparse_coo().coalesce() 105 
data = data.sparse_mask(mask).requires_grad_(requires_grad) 106 elif layout == torch.sparse_csr: 107 if data.ndim != 2 and mask.ndim != 2: 108 continue 109 mask = mask.to_sparse_csr() 110 data = data.sparse_mask(mask) 111 inputs.append(SampleInput(data, kwargs={"mask": mask})) 112 return inputs 113 114def _fix_fn_name(fn_name): 115 if fn_name[-1] == "_": 116 fn_name = fn_name[:-1] 117 return fn_name 118 119 120class TestBasics(TestCase): 121 def test_invalid_tensor_inputs(self, device): 122 data = torch.randn((3, 4), device=device) 123 mask = _create_random_mask((3, 4), device=device) 124 mt = masked_tensor(data, mask) 125 126 with self.assertRaisesRegex(TypeError, "data must be a Tensor"): 127 masked_tensor(mt, mask) 128 with self.assertRaisesRegex(TypeError, "data must be a Tensor"): 129 masked_tensor(0, mask) 130 with self.assertRaisesRegex(TypeError, "mask must be a Tensor"): 131 masked_tensor(data, mt) 132 with self.assertRaisesRegex(TypeError, "mask must be a Tensor"): 133 masked_tensor(data, 0) 134 135 def test_diff_layouts(self, device): 136 data = torch.randn((3, 4), device=device).to_sparse_coo() 137 mask = _create_random_mask((3, 4), device=device) 138 with self.assertRaisesRegex(TypeError, "data and mask must have the same layout"): 139 masked_tensor(data, mask) 140 141 def test_diff_dim(self, device): 142 data = torch.randn((3, 4, 5), device=device) 143 mask = _create_random_mask((3, 4), device=device) 144 with self.assertRaisesRegex(ValueError, "data.dim\\(\\) must equal mask.dim\\(\\)"): 145 masked_tensor(data, mask) 146 147 def test_diff_sizes(self, device): 148 data = torch.randn((3, 4), device=device) 149 mask = _create_random_mask((3, 3), device=device) 150 with self.assertRaisesRegex(ValueError, "data.size\\(\\) must equal mask.size\\(\\)"): 151 masked_tensor(data, mask) 152 153 def test_grad_warning(self, device): 154 data = torch.randn((3, 4), device=device, requires_grad=True) 155 mask = _create_random_mask((3, 4), device=device) 156 msg = 
"It is not recommended to create a MaskedTensor with a tensor that requires_grad." 157 with self.assertWarnsRegex(UserWarning, msg): 158 mt = masked_tensor(data, mask) 159 160 def test_add(self, device): 161 data = torch.arange(5.0, device=device) 162 mask = torch.tensor([True, True, False, True, False], device=device) 163 m0 = masked_tensor(data, mask) 164 m1 = masked_tensor(data, ~mask) 165 with self.assertRaisesRegex(ValueError, "Input masks must match."): 166 m0 + m1 167 _compare_mts(m0 + m0, masked_tensor(torch.tensor([0., 2, 0, 6, 0], device=device), mask)) 168 169 def test_softmax(self, device): 170 data = torch.randn((3, 4), device=device) * 0.1 171 mask = torch.tensor( 172 [ 173 [True, True, True, False], 174 [False, True, False, True], 175 [True, True, False, False], 176 ], 177 device=device 178 ) 179 180 _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1)) 181 182 def test_where(self, device): 183 data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device) 184 mask = data < 0 185 186 mx = masked_tensor(data, mask, requires_grad=True) 187 my = masked_tensor(torch.ones_like(data), ~mask, requires_grad=True) 188 masked_res = torch.where(mask, torch.exp(mx), my) 189 masked_res.sum().backward() 190 191 x = data.detach().clone().requires_grad_() 192 y = torch.ones_like(x, device=device, requires_grad=True) 193 tensor_res = torch.where(mask, torch.exp(x), y) 194 tensor_res.sum().backward() 195 196 _compare_mt_t(masked_res, tensor_res) 197 _compare_mt_t(mx.grad, x.grad) 198 _compare_mt_t(my.grad, y.grad) 199 200 def test_unfold(self, device): 201 data = torch.rand(5, 5, device=device) 202 mask = torch.rand(5, 5, device=device) > 0.5 203 _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2)) 204 205 def test_nn_unfold(self, device): 206 data = torch.rand(2, 5, 3, 4, device=device) 207 mask = torch.rand(2, 5, 3, 4, device=device) > 0.5 208 _compare_forward_backward(data, mask, lambda t: 
torch.nn.functional.unfold(t, kernel_size=(2, 3))) 209 210 def test_stack(self, device): 211 masked_tensors = [ 212 masked_tensor( 213 torch.rand(2, 5, 3, 4, device=device), 214 torch.rand(2, 5, 3, 4, device=device) > 0.5, 215 requires_grad=True, 216 ) for _ in range(3) 217 ] 218 219 data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors] 220 masked_res = torch.stack(masked_tensors) 221 tensor_res = torch.stack(data_tensors) 222 223 masked_res.sum().backward() 224 tensor_res.sum().backward() 225 _compare_mt_t(masked_res, tensor_res) 226 for mt, t in zip(masked_tensors, data_tensors): 227 _compare_mt_t(mt.grad, t.grad, atol=1e-06) 228 229 def test_to_sparse(self, device): 230 for sample in _generate_sample_data(device=device): 231 data = sample.input 232 mask = sample.kwargs["mask"] 233 mt = masked_tensor(data.clone().detach(), mask, requires_grad=True) 234 235 sparse_mt = mt.to_sparse() 236 data.to_sparse().to_dense().sum().backward() 237 sparse_mt.to_dense().sum().backward() 238 239 _compare_mt_t(sparse_mt, data) 240 _compare_mt_t(mt.grad, data.grad) 241 242 def test_to_dense(self, device): 243 samples = _generate_sample_data( 244 device=device, 245 layout=torch.sparse_coo 246 ) + _generate_sample_data(device=device, layout=torch.sparse_csr) 247 for sample in samples: 248 data = sample.input 249 mask = sample.kwargs["mask"] 250 mt = masked_tensor(data, mask, requires_grad=True) 251 252 dense_data = data.to_dense().detach().clone().requires_grad_(True) 253 dense_mt = mt.to_dense() 254 dense_data.sum().backward() 255 dense_mt.sum().backward() 256 257 _compare_mt_t(dense_mt, dense_data) 258 _compare_mt_t(mt.grad.to_dense(), dense_data.grad) 259 260 def test_to_dense_and_sparse_coo(self, device): 261 for sample in _generate_sample_data(device=device, layout=torch.strided): 262 data = sample.input 263 mask = sample.kwargs["mask"] 264 ms = mask.to_sparse_coo().coalesce() 265 266 mt = masked_tensor(data, mask, requires_grad=True) 267 mts 
= masked_tensor(data.sparse_mask(ms), ms, requires_grad=True) 268 269 converted = mt.to_sparse().to_dense() 270 converted.sum().backward() 271 272 converted2 = mts.to_dense() 273 converted2.sum().backward() 274 275 _compare_mts(converted, converted2) 276 _compare_mts(mt.grad, mts.grad.to_dense()) 277 278 def test_to_dense_and_sparse_csr(self, device): 279 for sample in _generate_sample_data(device=device, layout=torch.strided): 280 data = sample.input 281 mask = sample.kwargs["mask"] 282 if data.ndim != 2: 283 continue 284 ms = mask.to_sparse_csr() 285 286 mt = masked_tensor(data, mask, requires_grad=True) 287 mts = masked_tensor(data.sparse_mask(ms), ms, requires_grad=True) 288 289 converted = mt.to_sparse_csr().to_dense() 290 converted.sum().backward() 291 292 converted2 = mts.to_dense() 293 converted2.sum().backward() 294 295 _compare_mts(converted, converted2) 296 _compare_mts(mt.grad, mts.grad.to_dense()) 297 298 def test_invalid_sparse_layout(self, device): 299 data = torch.randn((3, 4), device=device).to_sparse_csc() 300 mask = _create_random_mask((3, 4), device=device).to_sparse_csc() 301 with self.assertRaisesRegex(TypeError, "data layout of torch.sparse_csc is not supported"): 302 masked_tensor(data, mask) 303 304 def test_invalid_sparse_coo_values(self, device): 305 v = torch.tensor([3, 4, 5], dtype=torch.float32) 306 i1 = torch.tensor([[0, 1, 1], [2, 0, 2]]) 307 i2 = torch.tensor([[0, 1, 1], [2, 1, 2]]) 308 309 t = torch.sparse_coo_tensor(i1, v, (2, 4), device=device) 310 mask = torch.sparse_coo_tensor(i2, torch.tensor([True, True, True]), (2, 4), device=device) 311 312 msg = "data and mask are both sparse COO tensors but do not have the same indices." 
313 with self.assertRaisesRegex(ValueError, msg): 314 masked_tensor(t, mask) 315 316 def test_invalid_sparse_csr_values(self, device): 317 crow_indices1 = [0, 2, 3] 318 crow_indices2 = [0, 1, 3] 319 col_indices1 = [0, 1, 2] 320 col_indices2 = [1, 2, 3] 321 322 values = [2, 3, 4] 323 mask_values = [True, True, True] 324 325 t1 = torch.sparse_csr_tensor( 326 torch.tensor(crow_indices1, dtype=torch.int64), 327 torch.tensor(col_indices1, dtype=torch.int64), 328 torch.tensor(values), 329 size=(2, 4) 330 ) 331 mask1 = torch.sparse_csr_tensor( 332 torch.tensor(crow_indices2, dtype=torch.int64), 333 torch.tensor(col_indices1, dtype=torch.int64), 334 torch.tensor(mask_values), 335 dtype=torch.bool, 336 size=(2, 4), 337 ) 338 t2 = torch.sparse_csr_tensor( 339 torch.tensor(crow_indices2, dtype=torch.int64), 340 torch.tensor(col_indices1, dtype=torch.int64), 341 torch.tensor(values), 342 size=(2, 4), 343 ) 344 mask2 = torch.sparse_csr_tensor( 345 torch.tensor(crow_indices2, dtype=torch.int64), 346 torch.tensor(col_indices2, dtype=torch.int64), 347 torch.tensor(mask_values), 348 dtype=torch.bool, 349 size=(2, 4), 350 ) 351 352 msg = "data and mask are both sparse CSR tensors but do not share either crow or col indices." 
353 with self.assertRaisesRegex(ValueError, msg): 354 masked_tensor(t1, mask1) 355 with self.assertRaisesRegex(ValueError, msg): 356 masked_tensor(t2, mask2) 357 358 def test_contiguous(self, device): 359 data = torch.randn((3, 3), device=device) 360 361 contiguous_data = data.clone() 362 mask1 = (contiguous_data > 0).bool() 363 not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2)) 364 mask2 = (not_contiguous_data > 0).bool() 365 366 contiguous_mt = masked_tensor(contiguous_data, mask1) 367 not_contiguous_mt = masked_tensor(not_contiguous_data, mask2) 368 369 contiguous_mt_sparse = masked_tensor( 370 contiguous_data.to_sparse_coo(), mask1.to_sparse_coo() 371 ) 372 not_contiguous_mt_sparse = masked_tensor( 373 not_contiguous_data.to_sparse_coo(), mask2.to_sparse_coo() 374 ) 375 376 self.assertEqual(contiguous_data.is_contiguous(), True) 377 self.assertEqual(not_contiguous_data.is_contiguous(), False) 378 379 self.assertEqual(contiguous_mt.is_contiguous(), True) 380 self.assertEqual(not_contiguous_mt.is_contiguous(), False) 381 382 error_msg = "MaskedTensors with sparse data do not have is_contiguous" 383 for t in [contiguous_mt_sparse, not_contiguous_mt_sparse]: 384 with self.assertRaisesRegex(ValueError, error_msg): 385 t.is_contiguous() 386 with self.assertRaisesRegex(ValueError, error_msg): 387 t.contiguous() 388 389 now_contiguous_mt = not_contiguous_mt.contiguous() 390 391 _compare_mts(not_contiguous_mt, now_contiguous_mt) 392 393 self.assertEqual(now_contiguous_mt.is_contiguous(), True) 394 self.assertEqual(now_contiguous_mt.get_data().is_contiguous(), True) 395 self.assertEqual(now_contiguous_mt.is_contiguous(), True) 396 397class TestUnary(TestCase): 398 def _get_test_data(self, fn_name): 399 data = torch.randn(10, 10) 400 mask = torch.rand(10, 10) > 0.5 401 fn_name = _fix_fn_name(fn_name) 402 if fn_name in ["log", "log10", "log1p", "log2", "sqrt"]: 403 data = data.mul(0.5).abs() 404 if fn_name in ["rsqrt"]: 405 data = data.abs() + 1 # Void 
division by zero 406 if fn_name in ["acos", "arccos", "asin", "arcsin", "logit"]: 407 data = data.abs().mul(0.5).clamp(0, 1) 408 if fn_name in ["atanh", "arctanh", "erfinv"]: 409 data = data.mul(0.5).clamp(-1, 1) 410 if fn_name in ["acosh", "arccosh"]: 411 data = data.abs() + 1 412 if fn_name in ["bitwise_not"]: 413 data = data.mul(128).to(torch.int8) 414 return data, mask 415 416 def _get_sample_kwargs(self, fn_name): 417 fn_name = _fix_fn_name(fn_name) 418 kwargs = {} 419 if fn_name in ["clamp", "clip"]: 420 kwargs["min"] = -0.5 421 kwargs["max"] = 0.5 422 return kwargs 423 424 def _get_sample_args(self, fn_name, data, mask): 425 fn_name = _fix_fn_name(fn_name) 426 mt = masked_tensor(data, mask) 427 t_args = [data] 428 mt_args = [mt] 429 if fn_name in ["pow"]: 430 t_args += [2.0] 431 mt_args += [2.0] 432 return t_args, mt_args 433 434 @parametrize("fn", NATIVE_UNARY_FNS) 435 def test_unary(self, fn): 436 torch.random.manual_seed(0) 437 fn_name = fn.__name__ 438 data, mask = self._get_test_data(fn_name) 439 kwargs = self._get_sample_kwargs(fn_name) 440 441 t_args, mt_args = self._get_sample_args(fn_name, data, mask) 442 443 mt_result = fn(*mt_args, **kwargs) 444 t_result = fn(*t_args, **kwargs) 445 _compare_mt_t(mt_result, t_result) 446 447 @parametrize("fn", NATIVE_INPLACE_UNARY_FNS) 448 def test_inplace_unary(self, fn): 449 torch.random.manual_seed(0) 450 fn_name = fn.__name__ 451 data, mask = self._get_test_data(fn_name) 452 kwargs = self._get_sample_kwargs(fn_name) 453 454 t_args, mt_args = self._get_sample_args(fn_name, data, mask) 455 456 mt_result = fn(*mt_args, **kwargs) 457 t_result = fn(*t_args, **kwargs) 458 _compare_mt_t(mt_result, t_result) 459 460class TestBinary(TestCase): 461 def _get_test_data(self, fn_name): 462 fn_name = _fix_fn_name(fn_name) 463 data0 = torch.randn(10, 10) 464 data1 = torch.randn(10, 10) 465 mask = torch.rand(10, 10) > 0.5 466 if fn_name in ["bitwise_and", "bitwise_or", "bitwise_xor"]: 467 data0 = data0.mul(128).to(torch.int8) 
468 data1 = data1.mul(128).to(torch.int8) 469 if fn_name in ["bitwise_left_shift", "bitwise_right_shift"]: 470 data0 = data0.abs().to(torch.int64) 471 data1 = data1.abs().to(torch.int64) 472 return data0, data1, mask 473 474 def _get_sample_kwargs(self, fn_name): 475 fn_name = _fix_fn_name(fn_name) 476 kwargs = {} 477 return kwargs 478 479 def _yield_sample_args(self, fn_name, data0, data1, mask): 480 """ Returns two sets of Tensor and MaskedTensor args for a binary function to compute. 481 Tensor args are all the same (just the two provided data tensors), 482 while the MaskedTensor args tests both (MaskedTensor, MaskedTensor) and (MaskedTensor, Tensor) 483 """ 484 fn_name = _fix_fn_name(fn_name) 485 mt0 = masked_tensor(data0, mask) 486 mt1 = masked_tensor(data1, mask) 487 488 t_args = [data0, data1] 489 mt_args = [mt0, mt1] 490 yield t_args, mt_args 491 492 t_args = [data0, data1] 493 mt_args = [mt0, data1] 494 yield t_args, mt_args 495 496 @parametrize("fn", NATIVE_BINARY_FNS) 497 def test_binary(self, fn): 498 torch.random.manual_seed(0) 499 fn_name = fn.__name__ 500 data0, data1, mask = self._get_test_data(fn_name) 501 kwargs = self._get_sample_kwargs(fn_name) 502 503 for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask): 504 mt_result = fn(*mt_args, **kwargs) 505 t_result = fn(*t_args, **kwargs) 506 _compare_mt_t(mt_result, t_result) 507 508 @parametrize("fn", NATIVE_INPLACE_BINARY_FNS) 509 def test_inplace_binary(self, fn): 510 torch.random.manual_seed(0) 511 fn_name = fn.__name__ 512 data0, data1, mask = self._get_test_data(fn_name) 513 kwargs = self._get_sample_kwargs(fn_name) 514 515 for (t_args, mt_args) in self._yield_sample_args(fn_name, data0, data1, mask): 516 mt_result = fn(*mt_args, **kwargs) 517 t_result = fn(*t_args, **kwargs) 518 _compare_mt_t(mt_result, t_result) 519 520 @parametrize("fn_name", ["add", "add_"]) 521 def test_masks_match(self, fn_name): 522 torch.random.manual_seed(0) 523 fn = getattr(torch.ops.aten, 
fn_name) 524 data0, data1, mask = self._get_test_data(fn_name) 525 mask0 = mask 526 mask1 = torch.rand(mask.size()) > 0.5 527 mt0 = masked_tensor(data0, mask0) 528 mt1 = masked_tensor(data1, mask1) 529 try: 530 fn(mt0, mt1) 531 raise AssertionError 532 except ValueError as e: 533 assert ( 534 "Input masks must match. If you need support for this, please open an issue on Github." 535 == str(e) 536 ) 537 538class TestReductions(TestCase): 539 def test_max_not_implemented(self): 540 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 541 m = torch.tensor([[True, False, False], [False, True, False]]) 542 mt = masked_tensor(d, m) 543 with self.assertRaisesRegex(TypeError, "torch._ops.aten.max.default"): 544 mt.max() 545 546 def test_sum(self): 547 d = torch.tensor([[0, 1, 2, 6], [3, 4, 5.0, 7]]) 548 m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 549 mt = masked_tensor(d, m) 550 _compare_mts(masked_tensor(torch.tensor(17.0), torch.tensor(True)), mt.sum()) 551 _compare_mts( 552 masked_tensor( 553 torch.tensor([0.0, 4.0, 1.0, 13]), 554 torch.tensor([True, True, False, True]), 555 ), 556 mt.sum(dim=0), 557 ) 558 559 def test_sum_grad(self): 560 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 561 m = torch.tensor([[True, False, False], [False, True, False]]) 562 mt = masked_tensor(d, m, requires_grad=True) 563 mt.sum().backward() 564 _compare_mts(mt.grad, masked_tensor(torch.tensor(1.0).expand_as(m), m)) 565 566 def test_mean(self): 567 d = torch.tensor([[0, 1, 3, 2], [3, 4, 1.0, 4]]) 568 m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 569 mt = masked_tensor(d, m) 570 _compare_mts(masked_tensor(torch.tensor(2.5), torch.tensor(True)), mt.mean()) 571 _compare_mts( 572 masked_tensor( 573 torch.tensor([0.0, 4.0, 1.0, 3]), 574 torch.tensor([True, True, False, True]), 575 ), 576 mt.mean(dim=0), 577 ) 578 579 """ 580 The following block of tests "test_mean_grad_case_1[a through e] are used to test the functionality of 581 the two different 
ways of constructing MaskedTensors: 582 masked_tensor(data, mask, requires_grad=True/False) -- NO differentiable constructor and always a leaf 583 as_masked_tensor(data, mask) -- differentiable constructor 584 585 Like torch.tensor(data), masked_tensor(data, mask) will provide a UserWarning if data.requires_grad=True 586 as_masked_tensor does not take in requires_grad -- it just takes on the requires_grad from data 587 588 Therefore, there are 6 cases to test and we use `mean` as a proxy to test the different combinations 589 590 Assuming mt.mean().backward() is run after each constructor: 591 592 Case 1a: 593 values.requires_grad = True 594 mt = masked_tensor(values, mask, requires_grad=True) 595 yields 596 - Provide a UserWarning because values.requires_grad=True 597 - values.grad = None 598 - mt.grad is a MaskedTensor with the correct gradient 599 600 Case 1b: 601 values.requires_grad = False 602 mt = masked_tensor(values, mask, requires_grad=True) 603 yields 604 - values.grad = None 605 - mt.grad is a MaskedTensor with the correct gradient 606 607 Case 2a/2b: 608 values.requires_grad = True/False 609 mt = masked_tensor(values, mask, requires_grad=False) 610 611 will both yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn" 612 as expected. When values.requires_grad=True, we will also get a UserWarning 613 614 Case 3a: 615 values.requires_grad = True 616 mt = as_masked_tensor(values, mask) 617 yields 618 - values.grad is a MaskedTensor with the correct gradient 619 - mt.grad is None and gives a UserWarning that 620 "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad" 621 622 Case 3b: 623 values.requires_grad = False 624 mt = as_masked_tensor(values, mask) 625 626 will yield a RuntimeError of "element 0 of tensors does not require grad and does not have a grad_fn" 627 as expected. 
628 """ 629 def test_mean_grad_case_1a(self): 630 """ values.requires_grad = True 631 mt = masked_tensor(values, mask, requires_grad=True) 632 """ 633 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) 634 m = torch.tensor([[True, False, False], [False, True, False]]) 635 with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"): 636 mt = masked_tensor(d, m, requires_grad=True) 637 mt.mean().backward() 638 self.assertIsNone(d.grad) 639 _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) 640 641 def test_mean_grad_case_1b(self): 642 """ values.requires_grad = False 643 mt = masked_tensor(values, mask, requires_grad=True) 644 """ 645 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 646 m = torch.tensor([[True, False, False], [False, True, False]]) 647 mt = masked_tensor(d, m, requires_grad=True) 648 mt.mean().backward() 649 self.assertIsNone(d.grad) 650 _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) 651 652 def test_mean_grad_case_1c(self): 653 """ values.requires_grad = True 654 mt = masked_tensor(values, mask, requires_grad=False) 655 """ 656 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) 657 m = torch.tensor([[True, False, False], [False, True, False]]) 658 with self.assertWarnsRegex(UserWarning, "It is not recommended to create a MaskedTensor"): 659 mt = masked_tensor(d, m, requires_grad=False) 660 result = mt.mean() 661 msg = "element 0 of tensors does not require grad and does not have a grad_fn" 662 with self.assertRaisesRegex(RuntimeError, msg): 663 result.backward() 664 665 666 def test_mean_grad_case_1d(self): 667 """ values.requires_grad = False 668 mt = masked_tensor(values, mask, requires_grad=False) 669 """ 670 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 671 m = torch.tensor([[True, False, False], [False, True, False]]) 672 mt = masked_tensor(d, m, requires_grad=False) 673 result = mt.mean() 674 msg = "element 0 of tensors does not 
require grad and does not have a grad_fn" 675 with self.assertRaisesRegex(RuntimeError, msg): 676 result.backward() 677 678 def test_mean_grad_case_1e(self): 679 """ values.requires_grad = True 680 mt = as_masked_tensor(values, mask) 681 """ 682 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]], requires_grad=True) 683 m = torch.tensor([[True, False, False], [False, True, False]]) 684 mt = as_masked_tensor(d, m) 685 mt.mean().backward() 686 _compare_mts(d.grad, masked_tensor(torch.tensor([[0.5, 0, 0], [0, 0.5, 0]]), m)) 687 msg = "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad" 688 with self.assertWarnsRegex(UserWarning, msg): 689 self.assertIsNone(mt.grad) 690 691 def test_mean_grad_case_1f(self): 692 """ values.requires_grad = False 693 mt = as_masked_tensor(values, mask) 694 """ 695 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 696 m = torch.tensor([[True, False, False], [False, True, False]]) 697 mt = as_masked_tensor(d, m) 698 result = mt.mean() 699 msg = "element 0 of tensors does not require grad and does not have a grad_fn" 700 with self.assertRaisesRegex(RuntimeError, msg): 701 result.backward() 702 703 def test_mean_dim_grad(self): 704 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 705 m = torch.tensor([[True, True, False], [False, True, False]]) 706 mt = masked_tensor(d, m, requires_grad=True) 707 mt.mean(1).sum().backward() 708 _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.5, 0.5, 0], [0, 1, 0]]), m)) 709 710 def test_amax(self): 711 d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]]) 712 m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 713 mt = masked_tensor(d, m) 714 _compare_mts(masked_tensor(torch.tensor(3.0), torch.tensor(True)), mt.amax()) 715 _compare_mts( 716 masked_tensor( 717 torch.tensor([0.0, -4.0, 1.0, 3]), 718 torch.tensor([True, True, False, True]), 719 ), 720 mt.amax(dim=0), 721 ) 722 723 def test_amax_grad(self): 724 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 725 m = 
torch.tensor([[True, False, False], [False, True, False]]) 726 mt = masked_tensor(d, m, requires_grad=True) 727 mt.amax().backward() 728 _compare_mts(mt.grad, masked_tensor(torch.tensor([[0.0, 0, 0], [0, 1, 0]]), m)) 729 730 def test_amin(self): 731 d = torch.tensor([[0, 1, 3, -3], [3, -4, 1.0, 3]]) 732 m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 733 mt = masked_tensor(d, m) 734 _compare_mts(masked_tensor(torch.tensor(-4.0), torch.tensor(True)), mt.amin()) 735 _compare_mts( 736 masked_tensor( 737 torch.tensor([0.0, -4.0, 1.0, -3]), 738 torch.tensor([True, True, False, True]), 739 ), 740 mt.amin(dim=0), 741 ) 742 743 def test_amin_grad(self): 744 d = torch.tensor([[0, 1, 2], [3, 4, 5.0]]) 745 m = torch.tensor([[True, False, False], [False, True, False]]) 746 mt = masked_tensor(d, m, requires_grad=True) 747 mt.amin().backward() 748 _compare_mts(mt.grad, masked_tensor(torch.tensor([[1.0, 0, 0], [0, 0, 0]]), m)) 749 750 def test_prod(self): 751 d = torch.tensor([[0, 1, 3, 0.0], [float("nan"), 4, 1.0, 5.0]]) 752 m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 753 mt = masked_tensor(d, m) 754 _compare_mts(masked_tensor(torch.tensor(0.0), torch.tensor(True)), mt.prod()) 755 _compare_mts( 756 masked_tensor( 757 torch.tensor([0.0, 4.0, 1.0, 0.0]), 758 torch.tensor([True, True, False, True]), 759 ), 760 mt.prod(dim=0), 761 ) 762 763 def test_prod_grad(self): 764 d = torch.tensor([[2, float("nan"), 2], [3, 4, 5.0]]) 765 m = torch.tensor([[True, False, False], [False, True, False]]) 766 mt = masked_tensor(d, m, requires_grad=True) 767 mt.prod().backward() 768 _compare_mts(mt.grad, masked_tensor(torch.tensor([[4.0, 0, 0], [0, 2, 0]]), m)) 769 770 def test_all(self): 771 d = torch.tensor([[True, True, False, False], [False, True, True, True]]) 772 m = torch.tensor([[True, False, False, True], [False, True, False, True]]) 773 mt = masked_tensor(d, m) 774 _compare_mts(masked_tensor(torch.tensor(False), torch.tensor(True)), 
mt.all()) 775 _compare_mts( 776 masked_tensor( 777 torch.tensor([True, True, True, False]), 778 torch.tensor([True, True, False, True]), 779 ), 780 mt.all(dim=0), 781 ) 782 783 m = torch.tensor([[True, False, True, False], [False, True, False, False]]) 784 mt = masked_tensor(d, m) 785 _compare_mts( 786 masked_tensor( 787 torch.tensor([True, True, False, True]), 788 torch.tensor([True, True, True, False]), 789 ), 790 mt.all(dim=0), 791 ) 792 793 def test_grad_dtype(self): 794 d = torch.tensor([[True, True, False], [False, True, True]]) 795 m = torch.tensor([[True, False, False], [False, True, False]]) 796 msg = "Only Tensors of floating point and complex dtype can require gradients" 797 with self.assertRaisesRegex(RuntimeError, msg): 798 masked_tensor(d, m, requires_grad=True) 799 800 def test_any_true_dtype(self): 801 mt = torch.masked.MaskedTensor( 802 torch.rand(2, 2), 803 torch.rand(2, 2) > 0.5 804 ) 805 msg = "expected a boolean tensor" 806 with self.assertRaisesRegex(ValueError, msg): 807 mt._is_any_true() 808 809 def test__is_any_true(self): 810 mt = torch.masked.MaskedTensor( 811 torch.tensor([[True, True, False], [False, False, True]]), 812 torch.tensor([[True, False, False], [False, True, False]]), 813 ) 814 _compare_mts( 815 masked_tensor(torch.tensor(True), torch.tensor(True)), 816 mt._is_any_true(), 817 ) 818 819 def test__is_any_true_false(self): 820 mt = torch.masked.MaskedTensor( 821 torch.tensor([[True, True, False], [False, False, True]]), 822 torch.tensor([[False, False, False], [False, False, False]]), 823 ) 824 _compare_mts( 825 masked_tensor(torch.tensor(False), torch.tensor(True),), 826 mt._is_any_true(), 827 ) 828 829 def test_backward(self): 830 # See https://github.com/pytorch/pytorch/issues/128557 831 with torch.autograd.detect_anomaly(): 832 mt = torch.masked.MaskedTensor( 833 torch.rand(2, 2), 834 torch.rand(2, 2) > 0.5, 835 requires_grad=True 836 ) 837 mt.sum().backward() 838 839 840def is_unary(op): 841 return op.name in UNARY_NAMES 

def is_binary(op):
    """Return True if OpInfo ``op`` names a binary op listed in MaskedTensor's ``BINARY_NAMES``."""
    return op.name in BINARY_NAMES

def is_reduction(op):
    """Return True if OpInfo ``op`` is a MaskedTensor reduction exercised by ``TestOperators``.

    ``all``, ``mean``, ``std`` and ``var`` are members of ``REDUCE_NAMES`` but are
    deliberately excluded from the generic OpInfo comparison below.
    """
    return op.name in REDUCE_NAMES and op.name not in {"all", "mean", "std", "var"}

# OpInfo lists filtered down to the ops MaskedTensor implements.
mt_unary_ufuncs = [op for op in unary_ufuncs if is_unary(op)]
mt_binary_ufuncs = [op for op in binary_ufuncs if is_binary(op)]
mt_reduction_ufuncs = [op for op in reduction_ops if is_reduction(op)]

# Floating-point dtypes the OpInfo-driven tests below are restricted to.
MASKEDTENSOR_FLOAT_TYPES = {
    torch.float16,
    torch.float32,
    torch.float64,
}

class TestOperators(TestCase):
    """OpInfo-driven checks that MaskedTensor ops agree with the equivalent
    plain-Tensor ops on the elements selected by the mask."""

    def _convert_mt_args(self, args, mask, layout):
        """Wrap every tensor in ``args`` as a MaskedTensor that shares ``mask``.

        For non-strided layouts the tensor is first restricted to the mask's
        sparsity pattern via ``sparse_mask``. Non-tensor arguments are passed
        through unchanged.
        """
        return [
            masked_tensor(
                arg.sparse_mask(mask) if layout != torch.strided else arg, mask
            )
            if torch.is_tensor(arg)
            else arg
            for arg in args
        ]

    def _test_unary_binary_equality(self, device, dtype, op, layout=torch.strided):
        """Check ``op`` on MaskedTensors against ``op`` on the dense sample inputs.

        For each OpInfo sample:
        - The mask comes from the sample's ``mask`` kwarg if present (it is
          popped), otherwise a random boolean mask of the input's shape is used.
        - For sparse layouts the input and mask are converted; ``sparse_csr``
          only handles 2-D samples, so other samples are skipped.
        - Binary ops are only tested when the rhs has the same shape as the
          input, and their kwargs are dropped.
        - The reference result is computed on the original dense
          ``sample.input``; results are compared only where the mask is set.
        - For strided binary ops, a MaskedTensor lhs combined with a plain
          Tensor rhs is checked as well.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            input = sample.input
            sample_args, sample_kwargs = sample.args, sample.kwargs
            mask = (
                _create_random_mask(input.shape, device)
                if "mask" not in sample_kwargs
                else sample_kwargs.pop("mask")
            )

            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                input = input.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                if input.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                input = input.sparse_mask(mask)

            if is_binary(op):
                # Binary operations currently only support same size masks
                if input.shape != sample_args[0].shape:
                    continue
                # Binary operations also don't support kwargs right now
                sample_kwargs = {}

            mt = masked_tensor(input, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            # Reference uses the untouched dense input, even for sparse layouts.
            t_result = op(sample.input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

            # If the operation is binary, check that lhs = masked, rhs = regular tensor also works
            if is_binary(op) and layout == torch.strided:
                mt_result2 = op(mt, *sample_args, **sample_kwargs)
                _compare_mt_t(mt_result2, t_result)

    def _test_reduction_equality(self, device, dtype, op, layout=torch.strided):
        """Check reduction ``op`` on MaskedTensors against ``op`` on a combined dense tensor.

        Only the positional input is reduced; the sample's extra args/kwargs are
        ignored. Samples that are 0-dim or empty, or whose random mask selects
        nothing, are skipped.

        The reference input is ``_combine_input_and_mask(op.op, input, mask)``,
        built from the dense input before any sparse conversion. Presumably it
        replaces masked-out entries with a value neutral for the reduction
        (see ``torch.masked``); confirm there.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)

        for sample in samples:
            input = sample.input
            # Reduction operations don't support more advanced args/kwargs right now
            sample_args, sample_kwargs = (), {}

            if input.dim() == 0 or input.numel() == 0:
                continue

            mask = _create_random_mask(input.shape, device)

            if torch.count_nonzero(mask) == 0:
                continue

            tensor_input = _combine_input_and_mask(op.op, input, mask)
            if layout == torch.sparse_coo:
                mask = mask.to_sparse_coo().coalesce()
                input = input.sparse_mask(mask)
            elif layout == torch.sparse_csr:
                if input.ndim != 2 or mask.ndim != 2:
                    continue
                mask = mask.to_sparse_csr()
                input = input.sparse_mask(mask)

            mt = masked_tensor(input, mask)
            mt_args = self._convert_mt_args(sample_args, mask, layout)

            mt_result = op(mt, *mt_args, **sample_kwargs)
            t_result = op(tensor_input, *sample_args, **sample_kwargs)

            _compare_mt_t(mt_result, t_result)

    @ops(mt_unary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_unary_core(self, device, dtype, op, layout):
        """Unary MaskedTensor ops must match their plain-Tensor counterparts."""
        # Skip tests that don't have len(kwargs) == 0
        skip_variants = {
            "decimals_0",
            "decimals_3",
            "decimals_neg_3",
        }
        if op.name == "round" and op.variant_test_name in skip_variants:
            self.skipTest("round variants with non-empty kwargs are not supported")
        # FIXME: ``layout`` is not forwarded to the helper, so the sparse
        # parametrizations would silently re-run the strided comparison and
        # report sparse coverage that does not exist. Skip them explicitly
        # until sparse unary support is verified and ``layout`` is passed through.
        if layout != torch.strided:
            self.skipTest("sparse layouts are not yet exercised for unary ops")
        self._test_unary_binary_equality(device, dtype, op)

    @ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    # FIXME: Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "add" and
            params["dtype"] in [torch.float16, torch.float32] and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    # FIXME: Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "sub" and
            params["dtype"] in [torch.float16, torch.float32] and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    # FIXME: Result is just wrong; production logic should be fixed
    @decorateIf(
        unittest.expectedFailure,
        lambda params: (
            params["op"].name == "eq" and
            params["dtype"] == torch.float64 and
            params["device"] == "cpu" and
            params["layout"] == torch.sparse_csr
        )
    )
    def test_binary_core(self, device, dtype, op, layout):
        """Binary MaskedTensor ops must match their plain-Tensor counterparts."""
        self._test_unary_binary_equality(device, dtype, op, layout)

    @ops(mt_reduction_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES)  # type: ignore[arg-type]
    @parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
    def test_reduction_all(self, device, dtype, op, layout):
        """Full (no-dim) MaskedTensor reductions must match the combined-tensor reference."""
        # argmin and argmax are not currently supported for torch.sparse_csr
        if op.name in {"argmin", "argmax"} and layout == torch.sparse_csr:
            self.skipTest("argmin/argmax are not supported for torch.sparse_csr")

        self._test_reduction_equality(device, dtype, op, layout)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)

instantiate_device_type_tests(TestBasics, globals(), only_for=only_for)
instantiate_parametrized_tests(TestUnary)
instantiate_parametrized_tests(TestBinary)
instantiate_parametrized_tests(TestReductions)

if __name__ == '__main__':
    run_tests()