1# Owner(s): ["module: unknown"] 2 3import hypothesis.strategies as st 4from hypothesis import given 5import numpy as np 6import torch 7from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo 8import torch.testing._internal.hypothesis_utils as hu 9hu.assert_deadline_disabled() 10 11 12class PruningOpTest(TestCase): 13 14 # Generate rowwise mask vector based on indicator and threshold value. 15 # indicator is a vector that contains one value per weight row and it 16 # represents the importance of a row. 17 # We mask a row if its indicator value is less than the threshold. 18 def _generate_rowwise_mask(self, embedding_rows): 19 indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32)) 20 threshold = float(np.random.random_sample()) 21 mask = torch.BoolTensor([True if val >= threshold else False for val in indicator]) 22 return mask 23 24 def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype): 25 embedding_weights = None 26 if weights_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: 27 embedding_weights = torch.randint(0, 100, (embedding_rows, embedding_dims), dtype=weights_dtype) 28 else: 29 embedding_weights = torch.rand((embedding_rows, embedding_dims), dtype=weights_dtype) 30 mask = self._generate_rowwise_mask(embedding_rows) 31 32 def get_pt_result(embedding_weights, mask, indices_type): 33 return torch._rowwise_prune(embedding_weights, mask, indices_type) 34 35 # Reference implementation. 36 def get_reference_result(embedding_weights, mask, indices_type): 37 num_embeddings = mask.size()[0] 38 compressed_idx_out = torch.zeros(num_embeddings, dtype=indices_type) 39 pruned_weights_out = embedding_weights[mask[:]] 40 idx = 0 41 for i in range(mask.size()[0]): 42 if mask[i]: 43 compressed_idx_out[i] = idx 44 idx = idx + 1 45 else: 46 compressed_idx_out[i] = -1 47 return (pruned_weights_out, compressed_idx_out) 48 49 pt_pruned_weights, pt_compressed_indices_map = get_pt_result( 50 embedding_weights, mask, indices_type) 51 ref_pruned_weights, ref_compressed_indices_map = get_reference_result( 52 embedding_weights, mask, indices_type) 53 54 torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights) 55 self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map) 56 self.assertEqual(pt_compressed_indices_map.dtype, indices_type) 57 58 59 @skipIfTorchDynamo() 60 @given( 61 embedding_rows=st.integers(1, 100), 62 embedding_dims=st.integers(1, 100), 63 weights_dtype=st.sampled_from([torch.float64, torch.float32, 64 torch.float16, torch.int8, 65 torch.int16, torch.int32, torch.int64]) 66 ) 67 def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype): 68 self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype) 69 70 71 @skipIfTorchDynamo() 72 @given( 73 embedding_rows=st.integers(1, 100), 74 embedding_dims=st.integers(1, 100), 75 weights_dtype=st.sampled_from([torch.float64, torch.float32, 76 torch.float16, torch.int8, 77 torch.int16, torch.int32, torch.int64]) 78 ) 79 def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype): 80 self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype) 81 82 83if __name__ == '__main__': 84 run_tests() 85