# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU Embeddings mid level API utils on TPU."""

from absl.testing import parameterized

from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.compat import v2_compat
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2_utils


class TPUEmbeddingOptimizerTest(parameterized.TestCase, test.TestCase):
  """Checks constructor argument handling of the embedding optimizers."""

  @parameterized.parameters(tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_accumulation_off(self, optimizer):
    # Clipping must be rejected when gradient accumulation is disabled, for
    # both a scalar clipvalue and a (min, max) tuple.
    with self.assertRaisesRegex(ValueError, 'accumulation'):
      optimizer(use_gradient_accumulation=False, clipvalue=0.)
    with self.assertRaisesRegex(ValueError, 'accumulation'):
      optimizer(use_gradient_accumulation=False, clipvalue=(None, 1.))

  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_tuple(self, optimizer):
    # A (min, max) tuple maps directly onto the two clip bounds.
    opt = optimizer(clipvalue=(-1., 1.))
    self.assertEqual(-1., opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)

  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_single_value(self, optimizer):
    # A scalar clipvalue c is expected to become the symmetric range (-c, c).
    opt = optimizer(clipvalue=1.)
    self.assertEqual(-1., opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)

  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_tuple_and_none(self, optimizer):
    # None in the tuple leaves that bound unset.
    opt = optimizer(clipvalue=(None, 1))
    self.assertIsNone(opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)

  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_equal_and_hash_function(self, optimizer):
    # Optimizers built with the same arguments compare and hash equal;
    # a different first positional argument makes them differ.
    opt1 = optimizer(0.1)
    opt2 = optimizer(0.1)
    opt3 = optimizer(0.2)
    self.assertEqual(opt1, opt2)
    self.assertEqual(hash(opt1), hash(opt2))
    self.assertNotEqual(opt1, opt3)
    self.assertNotEqual(hash(opt1), hash(opt3))


class ConfigTest(test.TestCase):
  """Checks the repr() output of TableConfig and FeatureConfig."""

  def test_table_config_repr(self):
    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=2, dim=4,
        combiner='sum', name='table')

    self.assertEqual(
        repr(table),
        'TableConfig(vocabulary_size=2, dim=4, initializer=None, '
        'optimizer=None, combiner=\'sum\', name=\'table\')')

  def test_feature_config_repr(self):
    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=2, dim=4, initializer=None,
        combiner='sum', name='table')

    feature_config = tpu_embedding_v2_utils.FeatureConfig(
        table=table, name='feature')

    # The nested table is expected to appear via its own repr.
    self.assertEqual(
        repr(feature_config),
        'FeatureConfig(table=TableConfig(vocabulary_size=2, dim=4, '
        'initializer=None, optimizer=None, combiner=\'sum\', name=\'table\'), '
        'max_sequence_length=0, validate_weights_and_indices=True, '
        'name=\'feature\')')


class TPUEmbeddingConfigurationTest(test.TestCase):
  """Checks that logging a large embedding configuration is not truncated."""

  def test_no_truncate(self):
    truncate_length = 14937  # Experimentally maximum string length loggable.
    num_tables = 500

    config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
    # These fields do not depend on the table index, so set them once
    # instead of reassigning them on every loop iteration.
    config.num_hosts = 2
    config.num_tensor_cores = 4
    config.batch_size_per_tensor_core = 128
    for i in range(num_tables):
      td = config.table_descriptor.add()
      td.name = 'table_{}'.format(i)
      td.vocabulary_size = i

    self.assertGreater(
        len(str(config)), truncate_length,
        'Test sanity check: generated config should be of truncating length.')

    with self.assertLogs() as logs:
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)

    # The last table must still be present in the combined log output...
    self.assertIn('table_{}'.format(num_tables - 1), ''.join(logs.output))
    # ...while every individual log record stays below the truncation limit.
    for line in logs.output:
      self.assertLess(
          len(line), truncate_length,
          'Logging function lines should not be of truncating length.')


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()