# Owner(s): ["module: masked operators"]

"""Tests for masked operations.
"""

import itertools
import torch
from typing import List, Any
from functools import wraps
import unittest
from torch.testing._internal.common_utils import skipIfTorchDynamo


from torch.testing._internal.common_utils import \
    (TestCase, parametrize, suppress_warnings, _TestParametrizer, run_tests)
from torch.testing._internal.common_methods_invocations import \
    (op_db, SampleInput)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, onlyNativeDeviceTypes, precisionOverride)


def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
    """Applies reduction op along given dimension to strided input
    elements that are valid according to mask tensor.

    The op is applied to each elementary slice of input with args and
    kwargs with the following constraints:

    1. Prior to applying the op:

      A. if kwargs contains an item with key 'dim_position' then it is
         removed from kwargs. The value of 'dim_position' is an
         integer that describes the dim argument position: while
         typically the dim argument appears at the 0-th position of
         the op arguments (excluding input), for instance, sum(input,
         dim), there exist reductions that have extra arguments
         prior to the dim argument, for instance, norm(input, ord, dim).

      B. if args or kwargs contains dim or keepdim arguments, these
         will be removed or replaced with None so that the op is
         applied to elementary slice using the default dim and keepdim
         value.

    2. The elementary slice of the input is defined as the flattened
      slice that has no masked out elements and when op is applied,
      the result will be a scalar value (assuming keepdim=False). For
      example, an input tensor to a reduction operation op having
      dim=1 and keepdim=True argument:

       [[1 * 2 * *]
        [* 3 4 * 5]]

      (* denotes masked out elements) has the following elementary
      slices: [1, 2] and [3, 4, 5]. The result of
      apply_masked_reduction_along_dim is

       [[op([1, 2], *args0, **kwargs, dim=None, keepdim=False)]
        [op([3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)]]

      where args0 is args where dim value is replaced with None if
      present.

      Using the same example data, if the op is called with dim=(0, 1)
      and keepdim=False, there is one elementary slice: [1, 2, 3, 4,
      5]; and the corresponding result of the op is:

        op([1, 2, 3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)

    3. If the elementary slice is empty, the corresponding output
      value is nan if dtype is float, otherwise, 0.  An empty
      elementary slice corresponds to fully masked-out output, so the
      specific value of the output is not important because we use a
      masked equality check for comparing the results of masked
      operations.
    """
    # eliminate mask and dim_position keyword arguments:
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)

    dtype = kwargs.get('dtype', input.dtype)
    if input.ndim == 0:
        # scalar input is an elementary slice
        return op(input, *args, **kwargs).to(dtype=dtype)

    # eliminate keepdim keyword argument if specified:
    keepdim = kwargs.pop('keepdim', False)

    # eliminate dim argument that may appear either in args or in
    # kwargs:
    if dim_pos < len(args):
        # dim is specified in args
        assert 'dim' not in kwargs, (args, kwargs)
        dim = args[dim_pos]
        args0 = args[:dim_pos] + (None,) + args[dim_pos + 1:]
    else:
        # dim may be specified in kwargs
        dim = kwargs.pop('dim', None)
        args0 = args

    # dimensions along which the reduction operation is applied:
    dim_ = torch.masked._canonical_dim(dim, input.ndim)
    # slices in product(*ranges) define all elementary slices:
    ranges: List[Any] = []
    # shape of output for the keepdim=True case:
    shape = []
    for i in range(input.ndim):
        if i in dim_:
            ranges.append((slice(None),))
            shape.append(1)
        else:
            ranges.append(range(input.shape[i]))
            shape.append(input.shape[i])

    # keepdim=True version of the output, filled with nan or 0:
    output = input.new_full(shape, float('nan') if dtype.is_floating_point else 0, dtype=dtype)

    # apply op to all elementary slices:
    if mask is None:
        inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
    else:
        inpmask = torch.masked._input_mask(input, mask=mask)
    for s in itertools.product(*ranges):
        # data of an elementary slice is 1D sequence and has only
        # masked-in elements:
        data = input[s].flatten()[inpmask[s].flatten().argwhere()]
        if not data.numel():
            # empty elementary slice
            continue
        output[s][0] = op(data, *args0, **kwargs)

    if not keepdim:
        # reshape output for the keepdim=False case
        shape = [shape[i] for i in range(len(shape)) if i not in dim_]
        output = output.reshape(shape)
    return output


def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
    """Applies normalization op along given dimension to strided input
    elements that are valid according to mask tensor.
    """
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)
    if input.ndim == 0:  # scalar input
        return op(input, *args, **kwargs)
    dtype = kwargs.get('dtype', input.dtype)
    dim = args[dim_pos]
    args0 = args[:dim_pos] + (0,) + args[dim_pos + 1:]
    output = torch.zeros_like(input, dtype=dtype)
    if mask is None:
        inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
    else:
        inpmask = torch.masked._input_mask(input, mask=mask)
    dim_ = dim % input.ndim
    left_ranges = tuple(map(range, input.shape[:dim_]))
    right_ranges = tuple(map(range, input.shape[dim_ + 1:]))
    for s in itertools.product(*(left_ranges + ((slice(None),),) + right_ranges)):
        indices = inpmask[s].argwhere()
        output[s][indices] = op(input[s][indices], *args0, **kwargs)
    return output


reference_functions = dict(
    norm=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.linalg.vector_norm, *args, **dict(kwargs, dim_position=1)),
    var=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.var, *args, **dict(kwargs, dim_position=0)),
    std=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.std, *args, **dict(kwargs, dim_position=0)),
    softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.softmax, *args, **kwargs),
    log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.log_softmax, *args, **kwargs),
    softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.nn.functional.softmin, *args, **kwargs),
    normalize=lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.nn.functional.normalize, *args, **dict(kwargs, dim_position=1)),
)

masked_ops = [op for op in op_db if op.name.startswith('masked.')]
masked_ops_with_references = [op for op in masked_ops if op.name.rsplit('.', 1)[-1] in reference_functions]
masked_ops_with_non_strided_support = [op for op in masked_ops if op.supports_sparse or op.supports_sparse_csr]


def _tensor_to_strided(obj):
    # after gh-59958 is resolved, replace the usage of this function
    # with torch.Tensor.to_dense
    if torch.is_tensor(obj):
        if obj.layout == torch.strided:
            return obj
        return obj.to_dense()
    return obj


def to_strided(obj):
    """Convert the tensor content of object to strided tensor content.
    """
    return torch.utils._pytree.tree_map(_tensor_to_strided, obj)


def to_sparse_coo(obj):
    """Convert the tensor content of object to sparse coo tensor content.
    """
    return torch.utils._pytree.tree_map(torch.Tensor.to_sparse, obj)


def to_sparse_csr(obj):
    """Convert the tensor content of object to sparse csr tensor content.
    """
    return torch.utils._pytree.tree_map(torch.Tensor.to_sparse_csr, obj)


class mask_layouts(_TestParametrizer):
    """Decorator class for parametrization of test function with an input
    layout argument and an extra argument of sample inputs generator.
    The sample_inputs generator provides samples with all supported
    layouts for the mask argument.
    """
    def _parametrize_test(self, test, generic_cls, device_cls):

        @wraps(test)
        def wrap(self, layout, device, dtype, op):
            layout_name = str(layout).lstrip('torch.')
            if layout == torch.strided:
                # strided layouts are always supported
                sample_inputs_func = op.sample_inputs
            elif layout == torch.sparse_coo:
                if not op.supports_sparse:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_coo
            elif layout == torch.sparse_csr:
                if not op.supports_sparse_csr:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_csr
            else:
                raise NotImplementedError(f'{layout}')

            def sample_inputs_generator():
                for sample_input in sample_inputs_func(device, dtype):
                    mask = sample_input.kwargs.get('mask')
                    if mask is None:
                        yield sample_input
                    else:
                        if layout == sample_input.input.layout:
                            yield sample_input
                        if layout != torch.strided:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_dense())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)
                        if layout != torch.sparse_coo and op.supports_sparse:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_sparse())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)
                        if layout != torch.sparse_csr and op.supports_sparse_csr and sample_input.input.ndim == 2:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_sparse_csr())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)

            test(self, layout, device, dtype, op, sample_inputs_generator())

        for layout in (torch.strided, torch.sparse_coo, torch.sparse_csr):
            yield (wrap, str(layout).lstrip('torch.'), {'layout': layout}, lambda _: [])


class TestMasked(TestCase):

    def assertEqualMasked(self, actual, expected, mask):
        strided = to_strided(actual)
        if mask is not None:
            strided = torch.where(mask, strided, strided.new_zeros([]))
            expected = torch.where(mask, expected, expected.new_zeros([]))
        self.assertEqual(strided, expected, exact_device=False)

    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(masked_ops_with_references)
    @precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-4})
    def test_reference_masked(self, device, dtype, op):
        op_name = op.name.rsplit('.', 1)[-1]
        ref_op = reference_functions[op_name]
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in sample_inputs:
            t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
            if op_name in {'var', 'std'} and not (t_inp.dtype.is_floating_point or t_inp.dtype.is_complex):
                # torch.var/torch.std does not support integer inputs
                continue
            actual = op.op(t_inp, *t_args, **t_kwargs)
            expected = ref_op(t_inp, *t_args, **t_kwargs)
            if t_kwargs.get('mask') is None:
                outmask = None
            else:
                outmask = torch.masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
            self.assertEqualMasked(actual, expected, outmask)

    @mask_layouts()
    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(masked_ops_with_non_strided_support)
    @precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-3})
    def test_mask_layout(self, layout, device, dtype, op, sample_inputs):
        for sample in sample_inputs:
            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
            actual = op.op(t_inp, *t_args, **t_kwargs)

            assert actual.layout == layout

            # check masked invariance:
            #  op(inp, mask).to_dense() == op(inp.to_dense(), mask.to_dense()) at outmask
            #
            r_inp, r_args, r_kwargs = to_strided((t_inp, t_args, t_kwargs))
            if r_kwargs.get('mask') is None:
                outmask = None
            else:
                outmask = torch.masked._output_mask(op.op, r_inp, *r_args, **r_kwargs)
            expected = op.op(r_inp, *r_args, **r_kwargs)
            self.assertEqualMasked(actual, expected, outmask)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
    @parametrize("sparse_kind,fill_value", [('coo', 0), ('hybrid_coo', 0),
                                            ('coo', 123), ('hybrid_coo', 123),
                                            ('csr', 0), ('csr', 123)],
                 name_fn=lambda sparse_kind, fill_value: f'{sparse_kind}_fill_value_{fill_value}')
    def test_where(self, sparse_kind, fill_value):

        is_hybrid = False
        if sparse_kind == 'coo':

            def to_sparse(dense):
                return dense.to_sparse(2)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'hybrid_coo':
            is_hybrid = True

            def to_sparse(dense):
                return dense.to_sparse(1)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'csr':

            def to_sparse(dense):
                return dense.to_sparse_csr()

            def set_values(sparse, index, value):
                sparse.values()[index] = value

        else:
            assert 0, sparse_kind

        mask = torch.tensor([[1, 0, 1, 0, 0],
                             [1, 1, 1, 1, 0],
                             [0, 1, 0, 1, 0],
                             [0, 0, 0, 0, 0],
                             [0, 0, 1, 1, 0],
                             [1, 1, 0, 0, 0]]).to(dtype=bool)
        mask = to_sparse(mask)
        # mark some specified mask elements as explicitly masked-out:
        if is_hybrid:
            set_values(mask, (1, 1), False)
            set_values(mask, (-2, -2), False)
        else:
            set_values(mask, 3, False)
            set_values(mask, -3, False)

        input = torch.tensor([[1, 0, 0, 0, -1],
                              [2, 3, 0, 0, -2],
                              [0, 4, 5, 0, -3],
                              [0, 0, 6, 7, 0],
                              [0, 8, 9, 0, -3],
                              [10, 11, 0, 0, -5]])
        input = to_sparse(input)
        # make specified input elements have zero values:
        if is_hybrid:
            set_values(input, (1, 1), 0)
            set_values(input, (-1, 0), 0)
            F = fill_value
        else:
            set_values(input, 3, 0)
            set_values(input, -3, 0)
            F = 0

        # expected where result:
        Z = 99
        # Z value corresponds to masked-in elements that are not
        # specified in the input and it will be replaced with a zero
        tmp = torch.tensor([[1, F, Z, F, F],
                            [2, F, Z, Z, F],
                            [F, 4, F, Z, F],
                            [0, 0, 0, 0, 0],
                            [F, F, 9, F, F],
                            [Z, 11, F, F, F]])
        tmp = to_sparse(tmp)

        sparse = torch.masked._where(mask, input,
                                     torch.tensor(fill_value, dtype=input.dtype, device=input.device))

        if tmp.layout == torch.sparse_coo:
            expected_sparse = torch.sparse_coo_tensor(
                tmp.indices(),
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
                input.shape)
            outmask = torch.sparse_coo_tensor(sparse.indices(),
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
                                              sparse.shape)._coalesced_(True)
        elif tmp.layout == torch.sparse_csr:
            expected_sparse = torch.sparse_csr_tensor(
                tmp.crow_indices(),
                tmp.col_indices(),
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
                input.shape)
            outmask = torch.sparse_csr_tensor(sparse.crow_indices(), sparse.col_indices(),
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
                                              sparse.shape)
        else:
            assert 0

        self.assertEqual(sparse, expected_sparse)

        # check invariance:
        #  torch.where(mask.to_dense(), input.to_dense(), fill_value)
        #    == where(mask, input, fill_value).to_dense(fill_value)
        expected = torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, F))
        dense = torch.where(outmask.to_dense(), sparse.to_dense(), torch.full(sparse.shape, F))
        self.assertEqual(dense, expected)


instantiate_device_type_tests(TestMasked, globals(), except_for='meta')

if __name__ == "__main__":
    run_tests()