from absl.testing import parameterized import tensorflow as tf from google.protobuf import text_format from fcp.dictionary import dictionary_pb2 from fcp.tensorflow import dictionary_ops class DictionaryOpsTest(tf.test.TestCase, parameterized.TestCase): def test_direct_tf_use_literal_dictionary(self): dictionary = dictionary_pb2.DictionaryDescription() text_format.Merge( 'special_ids: < unk: 0 > ' 'vocabulary: < ' ' index: < token: "a" token: "b" token: "c" token: "d" >' '>', dictionary) lookup = dictionary_ops.dictionary_lookup( tf.constant(['a', 'b', 'a', 'a', 'd', 'X']), dictionary_description_proto=dictionary.SerializeToString()) with tf.compat.v1.Session() as sess: tokenized = sess.run(lookup) self.assertEqual([1, 2, 1, 1, 4, 0], tokenized.tolist()) @parameterized.named_parameters( ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) def test_build_dictionary_with_output_blocklist(self, vocabulary_type): # Build a dictionary, explicitly blocklisting the first token and # implicitly blocklisting the last token via output_size. dictionary = dictionary_ops.Dictionary.from_tokens( ['01', '02', '10', '11'], unk_id=0, output_blocklist_tokens=['01'], output_size=4, vocabulary_type=vocabulary_type) if vocabulary_type in ( dictionary_ops.VocabularyType.TOKEN_INDEX, ): result = dictionary_ops.dictionary_lookup( [['01', '02', '10', '11', '12']], dictionary_description_proto=dictionary.dictionary_description_proto) with tf.compat.v1.Session() as sess: tokenized = sess.run(result) self.assertEqual([[1, 2, 3, 4, 0]], tokenized.tolist()) self.assertEqual( [1, 4], list(dictionary.dictionary_description.output_blocklist_ids.id)) @parameterized.named_parameters( ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) def test_build_dictionary(self, vocabulary_type): dictionary = dictionary_ops.Dictionary.from_tokens( ['A', 'a', 'B', 'c'], unk_id=0, vocabulary_type=vocabulary_type) result = dictionary_ops.dictionary_lookup( [['A', 'a', 'B', 'b', 'C', 'c', 'D', 'd']], dictionary_description_proto=dictionary.dictionary_description_proto) expected = [[1, 2, 3, 0, 0, 4, 0, 0]] with tf.compat.v1.Session() as sess: tokenized = sess.run(result) self.assertEqual(expected, tokenized.tolist()) @parameterized.named_parameters( ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) def test_dictionary_should_raise_with_duplicate_tokens(self, vocabulary_type): with self.assertRaisesRegex(ValueError, 'Duplicate tokens'): dictionary_ops.Dictionary.from_tokens(['01', '02', '11', '10', '11'], vocabulary_type=vocabulary_type) @parameterized.named_parameters( ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) def test_lookup_in_python(self, vocabulary_type): dictionary = dictionary_ops.Dictionary.from_tokens( ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type) self.assertLen(dictionary, 5) self.assertListEqual([1, 2, 3, 4, 0], dictionary.lookup(['01', '02', '10', '11', '12'])) @parameterized.named_parameters( ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) def test_reverse_lookup_in_python(self, vocabulary_type): dictionary = dictionary_ops.Dictionary.from_tokens( ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type) self.assertLen(dictionary, 5) rlookup = [ t.decode('utf-8') for t in dictionary.reverse_lookup([3, 2, 1, 4, 0]) ] self.assertListEqual(['10', '02', '01', '11', ''], rlookup) def test_literal_dictionary_in_python(self): dictionary_description = dictionary_pb2.DictionaryDescription() text_format.Merge( 'special_ids: < unk: 0 > ' 'vocabulary: < ' ' index: < token: "a" token: "b" token: "c" token: "d" >' '>', dictionary_description) dictionary = dictionary_ops.Dictionary.from_dictionary_description( dictionary_description) self.assertListEqual([b'a', b'b', b'c', b'd'], dictionary.tokens) if __name__ == '__main__': # Required since the test still relies on v1 Session.run behavior. tf.compat.v1.disable_v2_behavior() tf.test.main()