# Owner(s): ["module: unknown"]

import hypothesis.strategies as st
from hypothesis import given
import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()


# Weight dtypes exercised by the tests. Integer dtypes are populated with
# torch.randint, the floating-point ones with torch.rand.
_INTEGRAL_WEIGHT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
_WEIGHT_DTYPES = (torch.float64, torch.float32, torch.float16) + _INTEGRAL_WEIGHT_DTYPES


class PruningOpTest(TestCase):
    """Checks torch._rowwise_prune against a pure-Python reference."""

    def _generate_rowwise_mask(self, embedding_rows):
        """Return a random boolean mask with one entry per embedding row.

        A random "indicator" value (the importance of a row) is drawn for
        every row together with one random threshold. A row whose indicator
        is below the threshold is masked out (``False``); every other row is
        kept (``True``).

        Args:
            embedding_rows: number of rows, i.e. length of the mask.

        Returns:
            A 1-D ``torch.bool`` tensor of length ``embedding_rows``.
        """
        indicator = torch.from_numpy(
            np.random.random_sample(embedding_rows).astype(np.float32))
        threshold = float(np.random.random_sample())
        # Vectorised comparison; yields the bool tensor directly instead of
        # looping over the elements in Python.
        return indicator >= threshold

    def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
        """Compare torch._rowwise_prune with the reference on random inputs.

        Args:
            embedding_rows: number of rows of the random weight matrix.
            embedding_dims: number of columns of the random weight matrix.
            indices_type: dtype requested for the compressed indices map.
            weights_dtype: dtype of the random weight matrix.
        """
        if weights_dtype in _INTEGRAL_WEIGHT_DTYPES:
            embedding_weights = torch.randint(
                0, 100, (embedding_rows, embedding_dims), dtype=weights_dtype)
        else:
            embedding_weights = torch.rand(
                (embedding_rows, embedding_dims), dtype=weights_dtype)
        mask = self._generate_rowwise_mask(embedding_rows)

        def get_pt_result(embedding_weights, mask, indices_type):
            return torch._rowwise_prune(embedding_weights, mask, indices_type)

        # Reference implementation. Deliberately kept as a plain loop so it
        # stays an independent, easy-to-audit oracle for the operator.
        def get_reference_result(embedding_weights, mask, indices_type):
            num_embeddings = mask.size(0)
            # Pruned rows map to -1; kept rows are overwritten below.
            compressed_idx_out = torch.full((num_embeddings,), -1, dtype=indices_type)
            pruned_weights_out = embedding_weights[mask]
            next_idx = 0
            for row, keep in enumerate(mask.tolist()):
                if keep:
                    # Kept rows are numbered consecutively in original order.
                    compressed_idx_out[row] = next_idx
                    next_idx += 1
            return (pruned_weights_out, compressed_idx_out)

        pt_pruned_weights, pt_compressed_indices_map = get_pt_result(
            embedding_weights, mask, indices_type)
        ref_pruned_weights, ref_compressed_indices_map = get_reference_result(
            embedding_weights, mask, indices_type)

        torch.testing.assert_close(pt_pruned_weights, ref_pruned_weights)
        self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map)
        self.assertEqual(pt_compressed_indices_map.dtype, indices_type)

    @skipIfTorchDynamo()
    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHT_DTYPES),
    )
    def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        """Run the comparison with torch.int as the indices dtype."""
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)

    @skipIfTorchDynamo()
    @given(
        embedding_rows=st.integers(1, 100),
        embedding_dims=st.integers(1, 100),
        weights_dtype=st.sampled_from(_WEIGHT_DTYPES),
    )
    def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype):
        """Run the comparison with torch.int64 as the indices dtype."""
        self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype)


if __name__ == '__main__':
    run_tests()