1# Owner(s): ["module: masked operators"] 2 3"""Tests for masked operations. 4""" 5 6import itertools 7import torch 8from typing import List, Any 9from functools import wraps 10import unittest 11from torch.testing._internal.common_utils import skipIfTorchDynamo 12 13 14from torch.testing._internal.common_utils import \ 15 (TestCase, parametrize, suppress_warnings, _TestParametrizer, run_tests) 16from torch.testing._internal.common_methods_invocations import \ 17 (op_db, SampleInput) 18from torch.testing._internal.common_device_type import \ 19 (instantiate_device_type_tests, ops, onlyNativeDeviceTypes, precisionOverride) 20 21 22def apply_masked_reduction_along_dim(op, input, *args, **kwargs): 23 """Applies reduction op along given dimension to strided x 24 elements that are valid according to mask tensor. 25 26 The op is applied to each elementary slice of input with args and 27 kwargs with the following constraints: 28 29 1. Prior applying the op: 30 31 A. if kwargs contains an item with key 'dim_position' then it is 32 removed from kwargs. The value of 'dim_position' is an 33 integer that describes the dim argument position: while 34 typically the dim argument appears at the 0-th position of 35 the op arguments (excluding input), for instance, sum(input, 36 dim), then there exists reductions that have extra arguments 37 prior the dim argument, for instance, norm(input, ord, dim). 38 39 B. if args or kwargs contains dim or keepdim arguments, these 40 will be removed or replaced with None so that the op is 41 applied to elementary slice using the default dim and keepdim 42 value. 43 44 2. The elementary slice of the input is defined as the flattened 45 slice that has no masked out elements and when op is applied, 46 the result will be a scalar value (assuming keepdim=False). For 47 example, an input tensor to a reduction operation op having 48 dim=0 and keepdim=True argument: 49 50 [[1 * 2 * *] 51 [* 3 4 * 5]] 52 53 (* denotes masked out elements) has the following elementary 54 slices: [1, 2] and [3, 4, 5]. The result of 55 apply_masked_reduction_along_dim is 56 57 [[op([1, 2], *args0, **kwargs, dim=None, keepdim=False)] 58 [op([3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)]] 59 60 where args0 is args where dim value is replased with None if 61 present. 62 63 Using the same example data, if the op is called with dim=(0, 1) 64 and keepdim=False, there is one elementary slice: [1, 2, 3, 4, 65 5]; and the corresponding result of the op is: 66 67 op([1, 2, 3, 4, 5], *args0, **kwargs, dim=None, keepdim=False) 68 69 3. If the elementary slice is empty, the corresponding output 70 value is nan if dtype is float, otherwise, 0. An empty 71 elementary slice corresponds to fully masked-out output, so, the 72 corresponding specific value of the output will not be important 73 because we used masked equality check for comparing the results 74 of masked operations. 75 """ 76 # eliminate mask and dim_position keyword arguments: 77 mask = kwargs.pop('mask', None) 78 dim_pos = kwargs.pop('dim_position', 0) 79 80 dtype = kwargs.get('dtype', input.dtype) 81 if input.ndim == 0: 82 # scalar input is an elementary slice 83 return op(input, *args, **kwargs).to(dtype=dtype) 84 85 # eliminate keepdim keyword argument if specified: 86 keepdim = kwargs.pop('keepdim', False) 87 88 # eliminate dim argument that may appear both as args or kwargs 89 # element: 90 if dim_pos < len(args): 91 # dim is specified in args 92 assert 'dim' not in kwargs, (args, kwargs) 93 dim = args[dim_pos] 94 args0 = args[:dim_pos] + (None,) + args[dim_pos + 1:] 95 else: 96 # dim may be specified in kwargs 97 dim = kwargs.pop('dim', None) 98 args0 = args 99 100 # dimensions along which the reduction operation is applied: 101 dim_ = torch.masked._canonical_dim(dim, input.ndim) 102 # slices in product(*ranges) define all elementary slices: 103 ranges: List[Any] = [] 104 # shape of output for the keepdim=True case: 105 shape = [] 106 for i in range(input.ndim): 107 if i in dim_: 108 ranges.append((slice(None),)) 109 shape.append(1) 110 else: 111 ranges.append(range(input.shape[i])) 112 shape.append(input.shape[i]) 113 114 # keepdim=True version of the output, filled with nan or 0: 115 output = input.new_full(shape, float('nan') if dtype.is_floating_point else 0, dtype=dtype) 116 117 # apply op to all elementary slices: 118 if mask is None: 119 inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape) 120 else: 121 inpmask = torch.masked._input_mask(input, mask=mask) 122 for s in itertools.product(*ranges): 123 # data of an elementary slice is 1D sequence and has only 124 # masked-in elements: 125 data = input[s].flatten()[inpmask[s].flatten().argwhere()] 126 if not data.numel(): 127 # empty elementary slice 128 continue 129 output[s][0] = op(data, *args0, **kwargs) 130 131 if not keepdim: 132 # reshape output for the keepdim=False case 133 shape = [shape[i] for i in range(len(shape)) if i not in dim_] 134 output = output.reshape(shape) 135 return output 136 137 138def apply_masked_normalization_along_dim(op, input, *args, **kwargs): 139 """Applies normalization op along given dimension to strided x 140 elements that are valid according to mask tensor. 141 """ 142 mask = kwargs.pop('mask', None) 143 dim_pos = kwargs.pop('dim_position', 0) 144 if input.ndim == 0: # scalar input 145 return op(input, *args, **kwargs) 146 dtype = kwargs.get('dtype', input.dtype) 147 dim = args[dim_pos] 148 args0 = args[:dim_pos] + (0,) + args[dim_pos + 1:] 149 output = torch.zeros_like(input, dtype=dtype) 150 if mask is None: 151 inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape) 152 else: 153 inpmask = torch.masked._input_mask(input, mask=mask) 154 dim_ = dim % input.ndim 155 left_ranges = tuple(map(range, input.shape[:dim_])) 156 right_ranges = tuple(map(range, input.shape[dim_ + 1:])) 157 for s in itertools.product(*(left_ranges + ((slice(None),),) + right_ranges)): 158 indices = inpmask[s].argwhere() 159 output[s][indices] = op(input[s][indices], *args0, **kwargs) 160 return output 161 162 163reference_functions = dict( 164 norm=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.linalg.vector_norm, *args, **dict(kwargs, dim_position=1)), 165 var=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.var, *args, **dict(kwargs, dim_position=0)), 166 std=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.std, *args, **dict(kwargs, dim_position=0)), 167 softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.softmax, *args, **kwargs), 168 log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.log_softmax, *args, **kwargs), 169 softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.nn.functional.softmin, *args, **kwargs), 170 normalize=lambda *args, **kwargs: apply_masked_normalization_along_dim( 171 torch.nn.functional.normalize, *args, **dict(kwargs, dim_position=1)), 172) 173 174masked_ops = [op for op in op_db if op.name.startswith('masked.')] 175masked_ops_with_references = [op for op in masked_ops if op.name.rsplit('.', 1)[-1] in reference_functions] 176masked_ops_with_non_strided_support = [op for op in masked_ops if op.supports_sparse or op.supports_sparse_csr] 177 178 179def _tensor_to_strided(obj): 180 # after gh-59958 is resolved, replace the usage of this function 181 # with torch.Tensor.to_dense 182 if torch.is_tensor(obj): 183 if obj.layout == torch.strided: 184 return obj 185 return obj.to_dense() 186 return obj 187 188 189def to_strided(obj): 190 """Convert the tensor content of object to strided tensor content. 191 """ 192 return torch.utils._pytree.tree_map(_tensor_to_strided, obj) 193 194 195def to_sparse_coo(obj): 196 """Convert the tensor content of object to sparse coo tensor content. 197 """ 198 return torch.utils._pytree.tree_map(torch.Tensor.to_sparse, obj) 199 200 201def to_sparse_csr(obj): 202 """Convert the tensor content of object to sparse csr tensor content. 203 """ 204 return torch.utils._pytree.tree_map(torch.Tensor.to_sparse_csr, obj) 205 206 207class mask_layouts(_TestParametrizer): 208 """Decorator class for parametrization of test function with an input 209 layout argument and an extra argument of sample inputs generator. 210 The sample_inputs generator provides samples with all supported 211 layouts for the mask argument. 212 """ 213 def _parametrize_test(self, test, generic_cls, device_cls): 214 215 @wraps(test) 216 def wrap(self, layout, device, dtype, op): 217 layout_name = str(layout).lstrip('torch.') 218 if layout == torch.strided: 219 # strided layouts are always supported 220 sample_inputs_func = op.sample_inputs 221 elif layout == torch.sparse_coo: 222 if not op.supports_sparse: 223 raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout") 224 sample_inputs_func = op.sample_inputs_sparse_coo 225 elif layout == torch.sparse_csr: 226 if not op.supports_sparse_csr: 227 raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout") 228 sample_inputs_func = op.sample_inputs_sparse_csr 229 else: 230 raise NotImplementedError(f'{layout}') 231 232 def sample_inputs_generator(): 233 for sample_input in sample_inputs_func(device, dtype): 234 mask = sample_input.kwargs.get('mask') 235 if mask is None: 236 yield sample_input 237 else: 238 if layout == sample_input.input.layout: 239 yield sample_input 240 if layout != torch.strided: 241 sample_input_kwargs = sample_input.kwargs.copy() 242 sample_input_kwargs.update(mask=mask.to_dense()) 243 yield SampleInput(sample_input.input.clone(), 244 args=sample_input.args, 245 kwargs=sample_input_kwargs) 246 if layout != torch.sparse_coo and op.supports_sparse: 247 sample_input_kwargs = sample_input.kwargs.copy() 248 sample_input_kwargs.update(mask=mask.to_sparse()) 249 yield SampleInput(sample_input.input.clone(), 250 args=sample_input.args, 251 kwargs=sample_input_kwargs) 252 if layout != torch.sparse_csr and op.supports_sparse_csr and sample_input.input.ndim == 2: 253 sample_input_kwargs = sample_input.kwargs.copy() 254 sample_input_kwargs.update(mask=mask.to_sparse_csr()) 255 yield SampleInput(sample_input.input.clone(), 256 args=sample_input.args, 257 kwargs=sample_input_kwargs) 258 259 test(self, layout, device, dtype, op, sample_inputs_generator()) 260 261 for layout in (torch.strided, torch.sparse_coo, torch.sparse_csr): 262 yield (wrap, str(layout).lstrip('torch.'), {'layout': layout}, lambda _: []) 263 264 265class TestMasked(TestCase): 266 267 def assertEqualMasked(self, actual, expected, mask): 268 strided = to_strided(actual) 269 if mask is not None: 270 strided = torch.where(mask, strided, strided.new_zeros([])) 271 expected = torch.where(mask, expected, expected.new_zeros([])) 272 self.assertEqual(strided, expected, exact_device=False) 273 274 @onlyNativeDeviceTypes 275 @suppress_warnings 276 @ops(masked_ops_with_references) 277 @precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-4}) 278 def test_reference_masked(self, device, dtype, op): 279 op_name = op.name.rsplit('.', 1)[-1] 280 ref_op = reference_functions[op_name] 281 sample_inputs = op.sample_inputs(device, dtype) 282 for sample_input in sample_inputs: 283 t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs 284 if op_name in {'var', 'std'} and not (t_inp.dtype.is_floating_point or t_inp.dtype.is_complex): 285 # torch.var/torch.std does not support integer inputs 286 continue 287 actual = op.op(t_inp, *t_args, **t_kwargs) 288 expected = ref_op(t_inp, *t_args, **t_kwargs) 289 if t_kwargs.get('mask') is None: 290 outmask = None 291 else: 292 outmask = torch.masked._output_mask(op.op, t_inp, *t_args, **t_kwargs) 293 self.assertEqualMasked(actual, expected, outmask) 294 295 @mask_layouts() 296 @onlyNativeDeviceTypes 297 @suppress_warnings 298 @ops(masked_ops_with_non_strided_support) 299 @precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-3}) 300 def test_mask_layout(self, layout, device, dtype, op, sample_inputs): 301 for sample in sample_inputs: 302 t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs 303 actual = op.op(t_inp, *t_args, **t_kwargs) 304 305 assert actual.layout == layout 306 307 # check masked invariance: 308 # op(inp, mask).to_dense() == op(inp.to_dense(), mask.to_dense()) at outmask 309 # 310 r_inp, r_args, r_kwargs = to_strided((t_inp, t_args, t_kwargs)) 311 if r_kwargs.get('mask') is None: 312 outmask = None 313 else: 314 outmask = torch.masked._output_mask(op.op, r_inp, *r_args, **r_kwargs) 315 expected = op.op(r_inp, *r_args, **r_kwargs) 316 self.assertEqualMasked(actual, expected, outmask) 317 318 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992") 319 @parametrize("sparse_kind,fill_value", [('coo', 0), ('hybrid_coo', 0), 320 ('coo', 123), ('hybrid_coo', 123), 321 ('csr', 0), ('csr', 123)], 322 name_fn=lambda sparse_kind, fill_value: f'{sparse_kind}_fill_value_{fill_value}') 323 def test_where(self, sparse_kind, fill_value): 324 325 is_hybrid = False 326 if sparse_kind == 'coo': 327 328 def to_sparse(dense): 329 return dense.to_sparse(2) 330 331 def set_values(sparse, index, value): 332 sparse._values()[index] = value 333 334 elif sparse_kind == 'hybrid_coo': 335 is_hybrid = True 336 337 def to_sparse(dense): 338 return dense.to_sparse(1) 339 340 def set_values(sparse, index, value): 341 sparse._values()[index] = value 342 343 elif sparse_kind == 'csr': 344 345 def to_sparse(dense): 346 return dense.to_sparse_csr() 347 348 def set_values(sparse, index, value): 349 sparse.values()[index] = value 350 351 else: 352 assert 0, sparse_kind 353 354 mask = torch.tensor([[1, 0, 1, 0, 0], 355 [1, 1, 1, 1, 0], 356 [0, 1, 0, 1, 0], 357 [0, 0, 0, 0, 0], 358 [0, 0, 1, 1, 0], 359 [1, 1, 0, 0, 0]]).to(dtype=bool) 360 mask = to_sparse(mask) 361 # make some specified mask elements as explicit masked-out masks: 362 if is_hybrid: 363 set_values(mask, (1, 1), False) 364 set_values(mask, (-2, -2), False) 365 else: 366 set_values(mask, 3, False) 367 set_values(mask, -3, False) 368 369 input = torch.tensor([[1, 0, 0, 0, -1], 370 [2, 3, 0, 0, -2], 371 [0, 4, 5, 0, -3], 372 [0, 0, 6, 7, 0], 373 [0, 8, 9, 0, -3], 374 [10, 11, 0, 0, -5]]) 375 input = to_sparse(input) 376 # make specified input elements have zero values: 377 if is_hybrid: 378 set_values(input, (1, 1), 0) 379 set_values(input, (-1, 0), 0) 380 F = fill_value 381 else: 382 set_values(input, 3, 0) 383 set_values(input, -3, 0) 384 F = 0 385 386 # expected where result: 387 Z = 99 388 # Z value corresponds to masked-in elements that are not 389 # specified in the input and it will be replaced with a zero 390 tmp = torch.tensor([[1, F, Z, F, F], 391 [2, F, Z, Z, F], 392 [F, 4, F, Z, F], 393 [0, 0, 0, 0, 0], 394 [F, F, 9, F, F], 395 [Z, 11, F, F, F]]) 396 tmp = to_sparse(tmp) 397 398 399 sparse = torch.masked._where(mask, input, 400 torch.tensor(fill_value, dtype=input.dtype, device=input.device)) 401 402 if tmp.layout == torch.sparse_coo: 403 expected_sparse = torch.sparse_coo_tensor( 404 tmp.indices(), 405 torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)), 406 input.shape) 407 outmask = torch.sparse_coo_tensor(sparse.indices(), 408 sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool), 409 sparse.shape)._coalesced_(True) 410 elif tmp.layout == torch.sparse_csr: 411 expected_sparse = torch.sparse_csr_tensor( 412 tmp.crow_indices(), 413 tmp.col_indices(), 414 torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)), 415 input.shape) 416 outmask = torch.sparse_csr_tensor(sparse.crow_indices(), sparse.col_indices(), 417 sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool), 418 sparse.shape) 419 else: 420 assert 0 421 422 self.assertEqual(sparse, expected_sparse) 423 424 # check invariance: 425 # torch.where(mask.to_dense(), input.to_dense(), fill_value) 426 # == where(mask, input, fill_value).to_dense(fill_value) 427 expected = torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, F)) 428 dense = torch.where(outmask.to_dense(), sparse.to_dense(), torch.full(sparse.shape, F)) 429 self.assertEqual(dense, expected) 430 431 432instantiate_device_type_tests(TestMasked, globals(), except_for='meta') 433 434if __name__ == "__main__": 435 run_tests() 436