# Owner(s): ["module: unknown"]

import hypothesis.strategies as st
from hypothesis import given
import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()


class PruningOpTest(TestCase):

    # Generate a rowwise mask vector based on an indicator and a threshold value.
    # The indicator is a vector with one value per weight row, representing the
    # importance of that row. A row is pruned (its mask value is False) if its
    # indicator value is less than the threshold.
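    # For example (illustrative values only): with indicator = [0.2, 0.9, 0.4]
    # and threshold = 0.5, the generated mask would be [False, True, False],
    # so only row 1 survives pruning.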
    def _generate_rowwise_mask(self, embedding_rows):
        indicator = torch.from_numpy(np.random.random_sample(embedding_rows).astype(np.float32))
        threshold = float(np.random.random_sample())
        # Keep (True) the rows whose indicator value is at least the threshold.
        mask = indicator >= threshold
        return mask

    # Check torch._rowwise_prune against the Python reference implementation
    # below, using random weights and a random rowwise mask.
    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        if weights_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            embedding_weights = torch.randint(0, 100, (embedding_rows, embedding_dims), dtype=weights_dtype)
        else:
            embedding_weights = torch.rand((embedding_rows, embedding_dims), dtype=weights_dtype)
        mask = self._generate_rowwise_mask(embedding_rows)

        def get_pt_result(embedding_weights, mask, indices_type):
            return torch._rowwise_prune(embedding_weights, mask, indices_type)

        # Reference implementation: keep only the rows whose mask value is True
        # and build a compressed index map that holds the new row index of each
        # kept row and -1 for each pruned row.
        def get_reference_result(embedding_weights, mask, indices_type):
            num_embeddings = mask.size()[0]
            compressed_idx_out = torch.zeros(num_embeddings, dtype=indices_type)
            pruned_weights_out = embedding_weights[mask]
            idx = 0
            for i in range(num_embeddings):
                if mask[i]:
                    compressed_idx_out[i] = idx
                    idx = idx + 1
                else:
                    compressed_idx_out[i] = -1
            return (pruned_weights_out, compressed_idx_out)

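        # Worked example (illustrative, not part of the test): for a mask of
        # [True, False, True], get_reference_result keeps rows 0 and 2 of the
        # weights and returns compressed indices [0, -1, 1]; torch._rowwise_prune
        # is expected to produce the same pair of tensors.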
        pt_pruned_weights, pt_compressed_indices_map = get_pt_result(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)


    @skipIfTorchDynamo()
    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from([torch.float64, torch.float32,
                                       torch.float16, torch.int8,
                                       torch.int16, torch.int32, torch.int64])
    )
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)


    @skipIfTorchDynamo()
    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from([torch.float64, torch.float32,
                                       torch.float16, torch.int8,
                                       torch.int16, torch.int32, torch.int64])
    )
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)


if __name__ == '__main__':
    run_tests()