# Owner(s): ["module: masked operators"]

"""Tests for masked operations.
"""

import itertools
import torch
from typing import List, Any
from functools import wraps
import unittest


from torch.testing._internal.common_utils import \
    (TestCase, parametrize, suppress_warnings, _TestParametrizer, run_tests,
     skipIfTorchDynamo)
from torch.testing._internal.common_methods_invocations import \
    (op_db, SampleInput)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, onlyNativeDeviceTypes, precisionOverride)


def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
    """Applies the reduction op along the given dimension to the strided
    input elements that are valid according to the mask tensor.

    The op is applied to each elementary slice of input with args and
    kwargs under the following constraints:

    1. Prior to applying the op:

      A. if kwargs contains an item with key 'dim_position' then it is
         removed from kwargs. The value of 'dim_position' is an
         integer that describes the position of the dim argument:
         while typically the dim argument appears at the 0-th position
         of the op arguments (excluding input), for instance,
         sum(input, dim), there exist reductions that have extra
         arguments before the dim argument, for instance, norm(input,
         ord, dim).

      B. if args or kwargs contains dim or keepdim arguments, these
         will be removed or replaced with None so that the op is
         applied to each elementary slice using the default dim and
         keepdim values.

    2. An elementary slice of the input is defined as a flattened
      slice that has no masked-out elements and for which applying the
      op produces a scalar value (assuming keepdim=False). For
      example, an input tensor to a reduction operation op called with
      dim=1 and keepdim=True arguments:

       [[1 * 2 * *]
        [* 3 4 * 5]]

      (* denotes masked-out elements) has the following elementary
      slices: [1, 2] and [3, 4, 5]. The result of
      apply_masked_reduction_along_dim is

       [[op([1, 2], *args0, **kwargs, dim=None, keepdim=False)]
        [op([3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)]]

      where args0 is args with the dim value replaced with None if
      present.

      Using the same example data, if the op is called with dim=(0, 1)
      and keepdim=False, there is one elementary slice: [1, 2, 3, 4,
      5]; and the corresponding result of the op is:

        op([1, 2, 3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)

    3. If an elementary slice is empty, the corresponding output value
      is nan if dtype is floating-point, otherwise 0. An empty
      elementary slice corresponds to a fully masked-out output, so
      the specific value of the output does not matter because we use
      a masked equality check for comparing the results of masked
      operations.
    """
    # eliminate mask and dim_position keyword arguments:
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)

    dtype = kwargs.get('dtype', input.dtype)
    if input.ndim == 0:
        # scalar input is an elementary slice
        return op(input, *args, **kwargs).to(dtype=dtype)

    # eliminate keepdim keyword argument if specified:
    keepdim = kwargs.pop('keepdim', False)

    # eliminate dim argument that may appear both as args or kwargs
    # element:
    if dim_pos < len(args):
        # dim is specified in args
        assert 'dim' not in kwargs, (args, kwargs)
        dim = args[dim_pos]
        args0 = args[:dim_pos] + (None,) + args[dim_pos + 1:]
    else:
        # dim may be specified in kwargs
        dim = kwargs.pop('dim', None)
        args0 = args

    # dimensions along which the reduction operation is applied:
    dim_ = torch.masked._canonical_dim(dim, input.ndim)
    # slices in product(*ranges) define all elementary slices:
    ranges: List[Any] = []
    # shape of output for the keepdim=True case:
    shape = []
    for i in range(input.ndim):
        if i in dim_:
            ranges.append((slice(None),))
            shape.append(1)
        else:
            ranges.append(range(input.shape[i]))
            shape.append(input.shape[i])

    # keepdim=True version of the output, filled with nan or 0:
    output = input.new_full(shape, float('nan') if dtype.is_floating_point else 0, dtype=dtype)

    # apply op to all elementary slices:
    if mask is None:
        inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
    else:
        inpmask = torch.masked._input_mask(input, mask=mask)
    for s in itertools.product(*ranges):
        # data of an elementary slice is a 1D sequence and has only
        # masked-in elements:
        data = input[s].flatten()[inpmask[s].flatten().argwhere()]
        if not data.numel():
            # empty elementary slice
            continue
        output[s][0] = op(data, *args0, **kwargs)

    if not keepdim:
        # reshape output for the keepdim=False case
        shape = [shape[i] for i in range(len(shape)) if i not in dim_]
        output = output.reshape(shape)
    return output


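# A minimal sketch demonstrating apply_masked_reduction_along_dim on the
# docstring example above, using torch.sum as the reduction op. This is
# illustrative only and is not collected by the test runner (its name
# does not start with 'test').
def _demo_apply_masked_reduction_along_dim():
    # input and mask match the docstring example; the '*' entries of the
    # pattern are the False (masked-out) positions of the mask:
    input = torch.tensor([[1, 0, 2, 0, 0],
                          [0, 3, 4, 0, 5]])
    mask = torch.tensor([[True, False, True, False, False],
                         [False, True, True, False, True]])
    # reduce along dim=1 with keepdim=True; the elementary slices are
    # [1, 2] and [3, 4, 5], so the expected result is [[3], [12]]:
    result = apply_masked_reduction_along_dim(
        torch.sum, input, 1, mask=mask, keepdim=True)
    assert result.tolist() == [[3], [12]], result

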
def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
    """Applies the normalization op along the given dimension to the
    strided input elements that are valid according to the mask tensor.
    """
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)
    if input.ndim == 0:  # scalar input
        return op(input, *args, **kwargs)
    dtype = kwargs.get('dtype', input.dtype)
    dim = args[dim_pos]
    args0 = args[:dim_pos] + (0,) + args[dim_pos + 1:]
    output = torch.zeros_like(input, dtype=dtype)
    if mask is None:
        inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
    else:
        inpmask = torch.masked._input_mask(input, mask=mask)
    dim_ = dim % input.ndim
    left_ranges = tuple(map(range, input.shape[:dim_]))
    right_ranges = tuple(map(range, input.shape[dim_ + 1:]))
    for s in itertools.product(*(left_ranges + ((slice(None),),) + right_ranges)):
        indices = inpmask[s].argwhere()
        output[s][indices] = op(input[s][indices], *args0, **kwargs)
    return output


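# A minimal sketch demonstrating apply_masked_normalization_along_dim
# (illustrative only; not collected by the test runner): masked-out
# positions stay zero while masked-in positions are normalized.
def _demo_apply_masked_normalization_along_dim():
    input = torch.tensor([[1.0, 2.0, 3.0]])
    mask = torch.tensor([[True, True, False]])
    # softmax over the masked-in elements [1., 2.] of the last dim; the
    # masked-out position remains 0:
    result = apply_masked_normalization_along_dim(
        torch.softmax, input, -1, mask=mask)
    assert torch.allclose(result[0, :2], torch.softmax(input[0, :2], 0))
    assert result[0, 2] == 0, result

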
reference_functions = dict(
    norm=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.linalg.vector_norm, *args, **dict(kwargs, dim_position=1)),
    var=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.var, *args, **dict(kwargs, dim_position=0)),
    std=lambda *args, **kwargs: apply_masked_reduction_along_dim(torch.std, *args, **dict(kwargs, dim_position=0)),
    softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.softmax, *args, **kwargs),
    log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.log_softmax, *args, **kwargs),
    softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.nn.functional.softmin, *args, **kwargs),
    normalize=lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.nn.functional.normalize, *args, **dict(kwargs, dim_position=1)),
)

masked_ops = [op for op in op_db if op.name.startswith('masked.')]
masked_ops_with_references = [op for op in masked_ops if op.name.rsplit('.', 1)[-1] in reference_functions]
masked_ops_with_non_strided_support = [op for op in masked_ops if op.supports_sparse or op.supports_sparse_csr]


def _tensor_to_strided(obj):
    # after gh-59958 is resolved, replace the usage of this function
    # with torch.Tensor.to_dense
    if torch.is_tensor(obj):
        if obj.layout == torch.strided:
            return obj
        return obj.to_dense()
    return obj


def to_strided(obj):
    """Convert the tensor content of an object to strided tensor content.
    """
    return torch.utils._pytree.tree_map(_tensor_to_strided, obj)


def to_sparse_coo(obj):
    """Convert the tensor content of an object to sparse COO tensor content.
    """
    return torch.utils._pytree.tree_map(torch.Tensor.to_sparse, obj)


def to_sparse_csr(obj):
    """Convert the tensor content of an object to sparse CSR tensor content.
    """
    return torch.utils._pytree.tree_map(torch.Tensor.to_sparse_csr, obj)


class mask_layouts(_TestParametrizer):
    """Decorator class for parametrizing a test function with an input
    layout argument and an extra sample inputs generator argument. The
    sample inputs generator provides samples with all supported layouts
    for the mask argument.
    """
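    # Intended usage (a sketch mirroring test_mask_layout below): the
    # decorated test receives a layout argument and a generator of sample
    # inputs whose masks cover all supported layouts:
    #
    #   @mask_layouts()
    #   @ops(masked_ops_with_non_strided_support)
    #   def test_foo(self, layout, device, dtype, op, sample_inputs):
    #       for sample in sample_inputs:
    #           ...
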
    def _parametrize_test(self, test, generic_cls, device_cls):

        @wraps(test)
        def wrap(self, layout, device, dtype, op):
            layout_name = str(layout).lstrip('torch.')
            if layout == torch.strided:
                # strided layouts are always supported
                sample_inputs_func = op.sample_inputs
            elif layout == torch.sparse_coo:
                if not op.supports_sparse:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_coo
            elif layout == torch.sparse_csr:
                if not op.supports_sparse_csr:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_csr
            else:
                raise NotImplementedError(f'{layout}')

            def sample_inputs_generator():
                for sample_input in sample_inputs_func(device, dtype):
                    mask = sample_input.kwargs.get('mask')
                    if mask is None:
                        yield sample_input
                    else:
                        if layout == sample_input.input.layout:
                            yield sample_input
                        if layout != torch.strided:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_dense())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)
                        if layout != torch.sparse_coo and op.supports_sparse:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_sparse())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)
                        if layout != torch.sparse_csr and op.supports_sparse_csr and sample_input.input.ndim == 2:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_sparse_csr())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)

            test(self, layout, device, dtype, op, sample_inputs_generator())

        for layout in (torch.strided, torch.sparse_coo, torch.sparse_csr):
            yield (wrap, str(layout).lstrip('torch.'), {'layout': layout}, lambda _: [])


class TestMasked(TestCase):

    def assertEqualMasked(self, actual, expected, mask):
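        # Masked equality check: zero out the masked-out elements of both
        # actual and expected before comparing, so unspecified values at
        # fully masked-out positions cannot cause spurious mismatches.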
        strided = to_strided(actual)
        if mask is not None:
            strided = torch.where(mask, strided, strided.new_zeros([]))
            expected = torch.where(mask, expected, expected.new_zeros([]))
        self.assertEqual(strided, expected, exact_device=False)

    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(masked_ops_with_references)
    @precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-4})
    def test_reference_masked(self, device, dtype, op):
        op_name = op.name.rsplit('.', 1)[-1]
        ref_op = reference_functions[op_name]
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in sample_inputs:
            t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
            if op_name in {'var', 'std'} and not (t_inp.dtype.is_floating_point or t_inp.dtype.is_complex):
                # torch.var/torch.std does not support integer inputs
                continue
            actual = op.op(t_inp, *t_args, **t_kwargs)
            expected = ref_op(t_inp, *t_args, **t_kwargs)
            if t_kwargs.get('mask') is None:
                outmask = None
            else:
                outmask = torch.masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
            self.assertEqualMasked(actual, expected, outmask)

    @mask_layouts()
    @onlyNativeDeviceTypes
    @suppress_warnings
    @ops(masked_ops_with_non_strided_support)
    @precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-3})
    def test_mask_layout(self, layout, device, dtype, op, sample_inputs):
        for sample in sample_inputs:
            t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
            actual = op.op(t_inp, *t_args, **t_kwargs)

            assert actual.layout == layout

            # check masked invariance:
            #  op(inp, mask).to_dense() == op(inp.to_dense(), mask.to_dense()) at outmask
            #
            r_inp, r_args, r_kwargs = to_strided((t_inp, t_args, t_kwargs))
            if r_kwargs.get('mask') is None:
                outmask = None
            else:
                outmask = torch.masked._output_mask(op.op, r_inp, *r_args, **r_kwargs)
            expected = op.op(r_inp, *r_args, **r_kwargs)
            self.assertEqualMasked(actual, expected, outmask)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
    @parametrize("sparse_kind,fill_value", [('coo', 0), ('hybrid_coo', 0),
                                            ('coo', 123), ('hybrid_coo', 123),
                                            ('csr', 0), ('csr', 123)],
                 name_fn=lambda sparse_kind, fill_value: f'{sparse_kind}_fill_value_{fill_value}')
    def test_where(self, sparse_kind, fill_value):

        is_hybrid = False
        if sparse_kind == 'coo':

            def to_sparse(dense):
                return dense.to_sparse(2)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'hybrid_coo':
            is_hybrid = True

            def to_sparse(dense):
                return dense.to_sparse(1)

            def set_values(sparse, index, value):
                sparse._values()[index] = value

        elif sparse_kind == 'csr':

            def to_sparse(dense):
                return dense.to_sparse_csr()

            def set_values(sparse, index, value):
                sparse.values()[index] = value

        else:
            assert 0, sparse_kind

        mask = torch.tensor([[1, 0, 1, 0, 0],
                             [1, 1, 1, 1, 0],
                             [0, 1, 0, 1, 0],
                             [0, 0, 0, 0, 0],
                             [0, 0, 1, 1, 0],
                             [1, 1, 0, 0, 0]]).to(dtype=bool)
        mask = to_sparse(mask)
        # mark some specified mask elements as explicitly masked out:
        if is_hybrid:
            set_values(mask, (1, 1), False)
            set_values(mask, (-2, -2), False)
        else:
            set_values(mask, 3, False)
            set_values(mask, -3, False)

        input = torch.tensor([[1, 0, 0, 0, -1],
                              [2, 3, 0, 0, -2],
                              [0, 4, 5, 0, -3],
                              [0, 0, 6, 7, 0],
                              [0, 8, 9, 0, -3],
                              [10, 11, 0, 0, -5]])
        input = to_sparse(input)
        # make specified input elements have zero values:
        if is_hybrid:
            set_values(input, (1, 1), 0)
            set_values(input, (-1, 0), 0)
            F = fill_value
        else:
            set_values(input, 3, 0)
            set_values(input, -3, 0)
            F = 0

        # expected where result:
        Z = 99
        # the Z value corresponds to masked-in elements that are not
        # specified in the input; it will be replaced with zero below
        tmp = torch.tensor([[1, F, Z, F, F],
                            [2, F, Z, Z, F],
                            [F, 4, F, Z, F],
                            [0, 0, 0, 0, 0],
                            [F, F, 9, F, F],
                            [Z, 11, F, F, F]])
        tmp = to_sparse(tmp)

        sparse = torch.masked._where(mask, input,
                                     torch.tensor(fill_value, dtype=input.dtype, device=input.device))

        if tmp.layout == torch.sparse_coo:
            expected_sparse = torch.sparse_coo_tensor(
                tmp.indices(),
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
                input.shape)
            outmask = torch.sparse_coo_tensor(sparse.indices(),
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
                                              sparse.shape)._coalesced_(True)
        elif tmp.layout == torch.sparse_csr:
            expected_sparse = torch.sparse_csr_tensor(
                tmp.crow_indices(),
                tmp.col_indices(),
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
                input.shape)
            outmask = torch.sparse_csr_tensor(sparse.crow_indices(), sparse.col_indices(),
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
                                              sparse.shape)
        else:
            assert 0

        self.assertEqual(sparse, expected_sparse)

        # check invariance:
        #  torch.where(mask.to_dense(), input.to_dense(), fill_value)
        #    == where(mask, input, fill_value).to_dense(fill_value)
        expected = torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, F))
        dense = torch.where(outmask.to_dense(), sparse.to_dense(), torch.full(sparse.shape, F))
        self.assertEqual(dense, expected)


instantiate_device_type_tests(TestMasked, globals(), except_for='meta')

if __name__ == "__main__":
    run_tests()