xref: /aosp_15_r20/external/pytorch/test/test_masked.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: masked operators"]
2
3"""Tests for masked operations.
4"""
5
6import itertools
7import torch
8from typing import List, Any
9from functools import wraps
10import unittest
11from torch.testing._internal.common_utils import skipIfTorchDynamo
12
13
14from torch.testing._internal.common_utils import \
15    (TestCase, parametrize, suppress_warnings, _TestParametrizer, run_tests)
16from torch.testing._internal.common_methods_invocations import \
17    (op_db, SampleInput)
18from torch.testing._internal.common_device_type import \
19    (instantiate_device_type_tests, ops, onlyNativeDeviceTypes, precisionOverride)
20
21
def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
    """Reference implementation of a masked reduction on a strided input.

    Applies the reduction ``op`` along the given dimension(s) to the
    elements of ``input`` that are valid according to the mask tensor.

    The op is applied to each elementary slice of input with args and
    kwargs with the following constraints:

    1. Prior to applying the op:

      A. if kwargs contains an item with key 'dim_position' then it is
         removed from kwargs. The value of 'dim_position' is an
         integer that describes the dim argument position: while
         typically the dim argument appears at the 0-th position of
         the op arguments (excluding input), for instance, sum(input,
         dim), there exist reductions that have extra arguments
         prior to the dim argument, for instance, norm(input, ord, dim).

      B. if args or kwargs contains dim or keepdim arguments, these
         will be removed or replaced with None so that the op is
         applied to the elementary slice using the default dim and
         keepdim value.

    2. The elementary slice of the input is defined as the flattened
      slice that has no masked out elements and when op is applied,
      the result will be a scalar value (assuming keepdim=False). For
      example, an input tensor to a reduction operation op having
      dim=1 and keepdim=True argument:

       [[1 * 2 * *]
        [* 3 4 * 5]]

      (* denotes masked out elements) has the following elementary
      slices: [1, 2] and [3, 4, 5]. The result of
      apply_masked_reduction_along_dim is

       [[op([1, 2], *args0, **kwargs, dim=None, keepdim=False)]
        [op([3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)]]

      where args0 is args where the dim value is replaced with None if
      present.

      Using the same example data, if the op is called with dim=(0, 1)
      and keepdim=False, there is one elementary slice: [1, 2, 3, 4,
      5]; and the corresponding result of the op is:

        op([1, 2, 3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)

    3. If the elementary slice is empty, the corresponding output
      value is nan if dtype is float, otherwise, 0.  An empty
      elementary slice corresponds to fully masked-out output, so the
      corresponding specific value of the output will not be important
      because we use a masked equality check for comparing the results
      of masked operations.

    Args:
        op: reduction callable invoked as ``op(data, *args0, **kwargs)``.
        input: strided tensor to reduce.
        *args: positional arguments of the op; the dim argument, if
            given positionally, sits at index ``dim_position``.
        **kwargs: keyword arguments of the op. ``mask``,
            ``dim_position``, ``keepdim`` and ``dim`` are consumed here;
            everything else (including ``dtype``) is forwarded to op.

    Returns:
        Tensor holding the per-slice reduction results, with the reduced
        dimensions kept as size 1 when ``keepdim`` is true.
    """
    # eliminate mask and dim_position keyword arguments:
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)

    # An explicit dtype=None means "not specified"; fall back to the
    # input dtype instead of failing on None.is_floating_point below.
    dtype = kwargs.get('dtype')
    if dtype is None:
        dtype = input.dtype
    if input.ndim == 0:
        # scalar input is an elementary slice
        return op(input, *args, **kwargs).to(dtype=dtype)

    # eliminate keepdim keyword argument if specified:
    keepdim = kwargs.pop('keepdim', False)

    # eliminate dim argument that may appear both as args or kwargs
    # element:
    if dim_pos < len(args):
        # dim is specified in args
        assert 'dim' not in kwargs, (args, kwargs)
        dim = args[dim_pos]
        args0 = args[:dim_pos] + (None,) + args[dim_pos + 1:]
    else:
        # dim may be specified in kwargs
        dim = kwargs.pop('dim', None)
        args0 = args

    # dimensions along which the reduction operation is applied:
    dim_ = torch.masked._canonical_dim(dim, input.ndim)
    # slices in product(*ranges) define all elementary slices:
    ranges: List[Any] = []
    # shape of output for the keepdim=True case:
    shape = []
    for i in range(input.ndim):
        if i in dim_:
            # reduced dimension: take it whole, output extent is 1
            ranges.append((slice(None),))
            shape.append(1)
        else:
            # kept dimension: enumerate every index along it
            ranges.append(range(input.shape[i]))
            shape.append(input.shape[i])

    # keepdim=True version of the output, filled with nan or 0:
    output = input.new_full(shape, float('nan') if dtype.is_floating_point else 0, dtype=dtype)

    # apply op to all elementary slices:
    if mask is None:
        inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
    else:
        inpmask = torch.masked._input_mask(input, mask=mask)
    for s in itertools.product(*ranges):
        # data of an elementary slice holds only masked-in elements:
        data = input[s].flatten()[inpmask[s].flatten().argwhere()]
        if not data.numel():
            # empty elementary slice: keep the nan/0 placeholder
            continue
        # output[s] has extent 1 along every reduced dimension, so
        # index 0 addresses the single result cell of this slice
        output[s][0] = op(data, *args0, **kwargs)

    if not keepdim:
        # reshape output for the keepdim=False case
        shape = [size for i, size in enumerate(shape) if i not in dim_]
        output = output.reshape(shape)
    return output
136
137
def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
    """Reference implementation of a masked normalization on a strided input.

    For every 1-D slice of ``input`` taken along the dim argument, ``op``
    is applied to the masked-in elements only and the result is written
    back to the same positions; masked-out positions stay zero.

    Args:
        op: normalization callable invoked as ``op(data, *args0, **kwargs)``.
        input: strided tensor to normalize.
        *args: positional arguments of the op; the dim argument must be
            present at index ``dim_position``.
        **kwargs: keyword arguments of the op. ``mask`` and
            ``dim_position`` (default 0) are consumed here, the rest is
            forwarded to op.

    Returns:
        Tensor shaped like ``input`` (of dtype ``kwargs['dtype']`` when
        given), or ``op(input, *args, **kwargs)`` for a 0-dim input.
    """
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)
    if input.ndim == 0:
        # scalar input: nothing to slice, defer to the op directly
        return op(input, *args, **kwargs)

    out_dtype = kwargs.get('dtype', input.dtype)
    dim = args[dim_pos]
    # the selected elements of a slice are laid out along dimension 0,
    # so the op is always asked to normalize along dim=0
    slice_args = (*args[:dim_pos], 0, *args[dim_pos + 1:])
    result = torch.zeros_like(input, dtype=out_dtype)

    if mask is not None:
        valid = torch.masked._input_mask(input, mask=mask)
    else:
        valid = input.new_ones([], dtype=torch.bool).expand(input.shape)

    axis = dim % input.ndim
    # every index along the other dimensions, a full slice along ``axis``
    index_space = [
        (slice(None),) if i == axis else range(extent)
        for i, extent in enumerate(input.shape)
    ]
    for idx in itertools.product(*index_space):
        selected = valid[idx].argwhere()
        result[idx][selected] = op(input[idx][selected], *slice_args, **kwargs)
    return result
161
162
# Reference implementations keyed by the short name of the masked op
# (the part of the op name after the last dot). Each entry forwards its
# arguments to one of the apply_masked_*_along_dim helpers; dim_position
# tells the helper where the dim argument sits among the positional
# arguments of the wrapped torch function.
reference_functions = {
    'norm': lambda *args, **kwargs: apply_masked_reduction_along_dim(
        torch.linalg.vector_norm, *args, **{**kwargs, 'dim_position': 1}),
    'var': lambda *args, **kwargs: apply_masked_reduction_along_dim(
        torch.var, *args, **{**kwargs, 'dim_position': 0}),
    'std': lambda *args, **kwargs: apply_masked_reduction_along_dim(
        torch.std, *args, **{**kwargs, 'dim_position': 0}),
    'softmax': lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.softmax, *args, **kwargs),
    'log_softmax': lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.log_softmax, *args, **kwargs),
    'softmin': lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.nn.functional.softmin, *args, **kwargs),
    'normalize': lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.nn.functional.normalize, *args, **{**kwargs, 'dim_position': 1}),
}
173
# OpInfo entries whose name starts with 'masked.'
masked_ops = [info for info in op_db if info.name.startswith('masked.')]
# ... restricted to those whose short name has a reference implementation
masked_ops_with_references = [
    info for info in masked_ops
    if info.name.rpartition('.')[2] in reference_functions
]
# ... restricted to those that declare sparse COO or sparse CSR support
masked_ops_with_non_strided_support = [
    info for info in masked_ops
    if info.supports_sparse or info.supports_sparse_csr
]
177
178
179def _tensor_to_strided(obj):
180    # after gh-59958 is resolved, replace the usage of this function
181    # with torch.Tensor.to_dense
182    if torch.is_tensor(obj):
183        if obj.layout == torch.strided:
184            return obj
185        return obj.to_dense()
186    return obj
187
188
def to_strided(obj):
    """Return a copy of the pytree ``obj`` with every tensor leaf made strided.

    Non-tensor leaves are passed through unchanged.
    """
    tree_map = torch.utils._pytree.tree_map
    return tree_map(_tensor_to_strided, obj)
193
194
def to_sparse_coo(obj):
    """Convert the tensor content of object to sparse coo tensor content.

    ``obj`` is traversed as a pytree. Tensor leaves are converted with
    ``Tensor.to_sparse()``; non-tensor leaves (e.g. integer dims or None
    values found in args/kwargs) are returned unchanged, mirroring
    ``to_strided``.
    """
    def convert(leaf):
        # only tensors can change layout; leave everything else alone
        return leaf.to_sparse() if torch.is_tensor(leaf) else leaf

    return torch.utils._pytree.tree_map(convert, obj)
199
200
def to_sparse_csr(obj):
    """Convert the tensor content of object to sparse csr tensor content.

    ``obj`` is traversed as a pytree. Tensor leaves are converted with
    ``Tensor.to_sparse_csr()``; non-tensor leaves (e.g. integer dims or
    None values found in args/kwargs) are returned unchanged, mirroring
    ``to_strided``.
    """
    def convert(leaf):
        # only tensors can change layout; leave everything else alone
        return leaf.to_sparse_csr() if torch.is_tensor(leaf) else leaf

    return torch.utils._pytree.tree_map(convert, obj)
205
206
class mask_layouts(_TestParametrizer):
    """Decorator class for parametrization of test function with an input
    layout argument and an extra argument of sample inputs generator.
    The sample_inputs generator provides samples with all supported
    layouts for the mask argument.
    """
    def _parametrize_test(self, test, generic_cls, device_cls):

        def layout_label(layout):
            # Drop the leading 'torch.' from e.g. 'torch.sparse_coo'.
            # str.lstrip('torch.') must not be used for this: it strips
            # any leading run of the characters t/o/r/c/h/. rather than
            # the prefix, which only happens to work for some names.
            text = str(layout)
            prefix = 'torch.'
            return text[len(prefix):] if text.startswith(prefix) else text

        @wraps(test)
        def wrap(self, layout, device, dtype, op):
            layout_name = layout_label(layout)
            if layout == torch.strided:
                # strided layouts are always supported
                sample_inputs_func = op.sample_inputs
            elif layout == torch.sparse_coo:
                if not op.supports_sparse:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_coo
            elif layout == torch.sparse_csr:
                if not op.supports_sparse_csr:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_csr
            else:
                raise NotImplementedError(f'{layout}')

            def with_mask(sample_input, new_mask):
                # same sample (input cloned), with the mask kwarg replaced
                new_kwargs = sample_input.kwargs.copy()
                new_kwargs.update(mask=new_mask)
                return SampleInput(sample_input.input.clone(),
                                   args=sample_input.args,
                                   kwargs=new_kwargs)

            def sample_inputs_generator():
                for sample_input in sample_inputs_func(device, dtype):
                    mask = sample_input.kwargs.get('mask')
                    if mask is None:
                        yield sample_input
                        continue
                    if layout == sample_input.input.layout:
                        yield sample_input
                    # additionally yield the sample with its mask converted
                    # to every layout other than the one under test
                    if layout != torch.strided:
                        yield with_mask(sample_input, mask.to_dense())
                    if layout != torch.sparse_coo and op.supports_sparse:
                        yield with_mask(sample_input, mask.to_sparse())
                    if layout != torch.sparse_csr and op.supports_sparse_csr and sample_input.input.ndim == 2:
                        yield with_mask(sample_input, mask.to_sparse_csr())

            test(self, layout, device, dtype, op, sample_inputs_generator())

        # one parametrization per layout; the layout name is the test-name suffix
        for layout in (torch.strided, torch.sparse_coo, torch.sparse_csr):
            yield (wrap, layout_label(layout), {'layout': layout}, lambda _: [])
263
264
class TestMasked(TestCase):
    """Tests of masked operators against reference implementations,
    across mask layouts, and of the sparse ``torch.masked._where`` helper.
    """

    def assertEqualMasked(self, actual, expected, mask):
        """Assert that actual and expected agree where mask is true.

        ``actual`` is converted to strided layout first. When ``mask`` is
        not None, positions where mask is false are replaced by zero in
        both operands, so their values do not affect the comparison.
        """
        strided = to_strided(actual)
        if mask is not None:
            strided = torch.where(mask, strided, strided.new_zeros([]))
            expected = torch.where(mask, expected, expected.new_zeros([]))
        self.assertEqual(strided, expected, exact_device=False)

    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(masked_ops_with_references)
    @precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-4})
    def test_reference_masked(self, device, dtype, op):
        """Compare each masked op with its entry in reference_functions
        on every sample input, at the positions selected by the output mask.
        """
        # short op name (text after the last dot) selects the reference
        op_name = op.name.rsplit('.', 1)[-1]
        ref_op = reference_functions[op_name]
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in sample_inputs:
            t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
            if op_name in {'var', 'std'} and not (t_inp.dtype.is_floating_point or t_inp.dtype.is_complex):
                # torch.var/torch.std does not support integer inputs
                continue
            actual = op.op(t_inp, *t_args, **t_kwargs)
            expected = ref_op(t_inp, *t_args, **t_kwargs)
            # without a mask all output elements are compared
            if t_kwargs.get('mask') is None:
                outmask = None
            else:
                outmask = torch.masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
            self.assertEqualMasked(actual, expected, outmask)

    @mask_layouts()
    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(masked_ops_with_non_strided_support)
    @precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-3})
    def test_mask_layout(self, layout, device, dtype, op, sample_inputs):
        """Check that the op result has the requested layout and matches
        the result of the same op applied to strided versions of its inputs.
        """
        for sample in sample_inputs:
            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
            actual = op.op(t_inp, *t_args, **t_kwargs)

            assert actual.layout == layout

            # check masked invariance:
            #  op(inp, mask).to_dense() == op(inp.to_dense(), mask.to_dense()) at outmask
            #
            r_inp, r_args, r_kwargs = to_strided((t_inp, t_args, t_kwargs))
            if r_kwargs.get('mask') is None:
                outmask = None
            else:
                outmask = torch.masked._output_mask(op.op, r_inp, *r_args, **r_kwargs)
            expected = op.op(r_inp, *r_args, **r_kwargs)
            self.assertEqualMasked(actual, expected, outmask)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
    @parametrize("sparse_kind,fill_value", [('coo', 0), ('hybrid_coo', 0),
                                            ('coo', 123), ('hybrid_coo', 123),
                                            ('csr', 0), ('csr', 123)],
                 name_fn=lambda sparse_kind, fill_value: f'{sparse_kind}_fill_value_{fill_value}')
    def test_where(self, sparse_kind, fill_value):
        """Check torch.masked._where on sparse mask/input pairs.

        The result is compared against a hand-written expected sparse
        tensor and against dense torch.where on the densified operands.
        """

        # Select, per sparse kind, how a dense tensor is converted and how
        # a stored value is overwritten in place.
        is_hybrid = False
        if sparse_kind == 'coo':

            def to_sparse(dense):
                return dense.to_sparse(2)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'hybrid_coo':
            is_hybrid = True

            def to_sparse(dense):
                # NOTE(review): to_sparse(1) on a 2-D tensor presumably keeps
                # one dense dimension, i.e. values are stored per row - confirm
                return dense.to_sparse(1)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'csr':

            def to_sparse(dense):
                return dense.to_sparse_csr()

            def set_values(sparse, index, value):
                sparse.values()[index] = value

        else:
            assert 0, sparse_kind

        mask = torch.tensor([[1, 0, 1, 0, 0],
                             [1, 1, 1, 1, 0],
                             [0, 1, 0, 1, 0],
                             [0, 0, 0, 0, 0],
                             [0, 0, 1, 1, 0],
                             [1, 1, 0, 0, 0]]).to(dtype=bool)
        mask = to_sparse(mask)
        # make some specified mask elements as explicit masked-out masks:
        # (the hybrid case indexes the values with a pair, the other
        # cases with a single position in the values sequence)
        if is_hybrid:
            set_values(mask, (1, 1), False)
            set_values(mask, (-2, -2), False)
        else:
            set_values(mask, 3, False)
            set_values(mask, -3, False)

        input = torch.tensor([[1, 0, 0, 0, -1],
                              [2, 3, 0, 0, -2],
                              [0, 4, 5, 0, -3],
                              [0, 0, 6, 7, 0],
                              [0, 8, 9, 0, -3],
                              [10, 11, 0, 0, -5]])
        input = to_sparse(input)
        # make specified input elements have zero values:
        # F is the value written into the expected tensor below: fill_value
        # in the hybrid case, 0 otherwise.
        if is_hybrid:
            set_values(input, (1, 1), 0)
            set_values(input, (-1, 0), 0)
            F = fill_value
        else:
            set_values(input, 3, 0)
            set_values(input, -3, 0)
            F = 0

        # expected where result:
        Z = 99
        # Z value corresponds to masked-in elements that are not
        # specified in the input and it will be replaced with a zero
        tmp = torch.tensor([[1, F, Z, F, F],
                            [2, F, Z, Z, F],
                            [F, 4, F, Z, F],
                            [0, 0, 0, 0, 0],
                            [F, F, 9, F, F],
                            [Z, 11, F, F, F]])
        tmp = to_sparse(tmp)


        sparse = torch.masked._where(mask, input,
                                     torch.tensor(fill_value, dtype=input.dtype, device=input.device))

        # Build the expected sparse result from tmp's indices with the Z
        # placeholders replaced by 0, and an all-true outmask that has the
        # same indices as the actual result.
        if tmp.layout == torch.sparse_coo:
            expected_sparse = torch.sparse_coo_tensor(
                tmp.indices(),
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
                input.shape)
            outmask = torch.sparse_coo_tensor(sparse.indices(),
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
                                              sparse.shape)._coalesced_(True)
        elif tmp.layout == torch.sparse_csr:
            expected_sparse = torch.sparse_csr_tensor(
                tmp.crow_indices(),
                tmp.col_indices(),
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
                input.shape)
            outmask = torch.sparse_csr_tensor(sparse.crow_indices(), sparse.col_indices(),
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
                                              sparse.shape)
        else:
            assert 0

        self.assertEqual(sparse, expected_sparse)

        # check invariance:
        #  torch.where(mask.to_dense(), input.to_dense(), fill_value)
        #    == where(mask, input, fill_value).to_dense(fill_value)
        # (here positions outside the masks are filled with F on both sides)
        expected = torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, F))
        dense = torch.where(outmask.to_dense(), sparse.to_dense(), torch.full(sparse.shape, F))
        self.assertEqual(dense, expected)
430
431
# Register the device-specific variants of TestMasked in this module's
# namespace; except_for='meta' presumably skips the meta device - confirm
# against instantiate_device_type_tests.
instantiate_device_type_tests(TestMasked, globals(), except_for='meta')

# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
436