xref: /aosp_15_r20/external/pytorch/test/test_pruning_op.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport hypothesis.strategies as st
4*da0073e9SAndroid Build Coastguard Workerfrom hypothesis import given
5*da0073e9SAndroid Build Coastguard Workerimport numpy as np
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
8*da0073e9SAndroid Build Coastguard Workerimport torch.testing._internal.hypothesis_utils as hu
9*da0073e9SAndroid Build Coastguard Workerhu.assert_deadline_disabled()
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
class PruningOpTest(TestCase):
    """Compare ``torch._rowwise_prune`` against a pure-Python reference implementation."""

    # Weight dtypes exercised by both index-width variants below.
    _WEIGHT_DTYPES = [torch.float64, torch.float32, torch.float16,
                      torch.int8, torch.int16, torch.int32, torch.int64]

    # Shared hypothesis strategies so the 32-bit and 64-bit index tests cannot
    # drift apart. Evaluated in the class namespace, so the decorators below can
    # reference it by bare name.
    _ROWWISE_PRUNE_STRATEGIES = {
        "embedding_rows": st.integers(1, 100),
        "embedding_dims": st.integers(1, 100),
        "weights_dtype": st.sampled_from(_WEIGHT_DTYPES),
    }

    def _generate_rowwise_mask(self, embedding_rows):
        """Return a random boolean keep-mask with one entry per weight row.

        Each row gets a random float32 "indicator" value in [0, 1) representing
        its importance. A row is kept (``True``) when its indicator is greater
        than or equal to a random threshold, and masked out (``False``) otherwise.

        Args:
            embedding_rows: number of rows, i.e. the length of the returned mask.

        Returns:
            A 1-D ``torch.bool`` tensor of length ``embedding_rows``.
        """
        indicator = torch.from_numpy(
            np.random.random_sample(embedding_rows).astype(np.float32))
        threshold = float(np.random.random_sample())
        # One vectorized comparison instead of a Python loop over 0-d tensors.
        return indicator >= threshold

    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        """Check ``torch._rowwise_prune`` against a reference for one configuration.

        Args:
            embedding_rows: number of rows in the generated weight matrix.
            embedding_dims: number of columns in the generated weight matrix.
            indices_type: dtype requested for the compressed-indices map.
            weights_dtype: dtype of the generated weight matrix.
        """
        shape = (embedding_rows, embedding_dims)
        if weights_dtype.is_floating_point:
            embedding_weights = torch.rand(shape, dtype=weights_dtype)
        else:
            # torch.rand only produces floating-point values; use integers in [0, 100).
            embedding_weights = torch.randint(0, 100, shape, dtype=weights_dtype)
        mask = self._generate_rowwise_mask(embedding_rows)

        def get_reference_result(embedding_weights, mask, indices_type):
            """Reference implementation of row-wise pruning.

            Returns the rows of ``embedding_weights`` selected by ``mask``, and a
            map from each original row index to that row's position among the
            kept rows (``-1`` for pruned rows), with dtype ``indices_type``.
            """
            pruned_weights_out = embedding_weights[mask]
            # Start with every row marked as pruned, then number the kept rows.
            compressed_idx_out = torch.full((mask.size(0),), -1, dtype=indices_type)
            next_idx = 0
            for i, keep in enumerate(mask.tolist()):
                if keep:
                    compressed_idx_out[i] = next_idx
                    next_idx += 1
            return pruned_weights_out, compressed_idx_out

        pt_pruned_weights, pt_compressed_indices_map = torch._rowwise_prune(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)

    @skipIfTorchDynamo()
    @given(**_ROWWISE_PRUNE_STRATEGIES)
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int32, weights_dtype)

    @skipIfTorchDynamo()
    @given(**_ROWWISE_PRUNE_STRATEGIES)
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Script entry point: hand control to the shared PyTorch test runner.
    run_tests()
85