# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for TPU Embeddings mid level API utils on TPU."""
16
17from absl.testing import parameterized
18
19from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
20from tensorflow.python.compat import v2_compat
21from tensorflow.python.platform import test
22from tensorflow.python.tpu import tpu_embedding_v2_utils
23
24
25class TPUEmbeddingOptimizerTest(parameterized.TestCase, test.TestCase):
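  """Tests gradient-clipping validation and equality/hashing of optimizers."""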

  @parameterized.parameters(tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_accumulation_off(self, optimizer):
    with self.assertRaisesRegex(ValueError, 'accumulation'):
      optimizer(use_gradient_accumulation=False, clipvalue=0.)
    with self.assertRaisesRegex(ValueError, 'accumulation'):
      optimizer(use_gradient_accumulation=False, clipvalue=(None, 1.))

  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_tuple(self, optimizer):
    opt = optimizer(clipvalue=(-1., 1.))
    self.assertEqual(-1., opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)

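  # A scalar clipvalue expands to the symmetric range (-clipvalue, clipvalue).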
  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_single_value(self, optimizer):
    opt = optimizer(clipvalue=1.)
    self.assertEqual(-1., opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)

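  # A None endpoint in the clipvalue tuple leaves that bound unset.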
  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_grad_clip_with_tuple_and_none(self, optimizer):
    opt = optimizer(clipvalue=(None, 1))
    self.assertIsNone(opt.clip_gradient_min)
    self.assertEqual(1., opt.clip_gradient_max)

  @parameterized.parameters(tpu_embedding_v2_utils.SGD,
                            tpu_embedding_v2_utils.Adagrad,
                            tpu_embedding_v2_utils.Adam,
                            tpu_embedding_v2_utils.FTRL)
  def test_equal_and_hash_function(self, optimizer):
    opt1 = optimizer(0.1)
    opt2 = optimizer(0.1)
    opt3 = optimizer(0.2)
    self.assertEqual(opt1, opt2)
    self.assertEqual(hash(opt1), hash(opt2))
    self.assertNotEqual(opt1, opt3)
    self.assertNotEqual(hash(opt1), hash(opt3))


class ConfigTest(test.TestCase):
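  """Tests the repr() output of TableConfig and FeatureConfig."""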

  def test_table_config_repr(self):
    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=2, dim=4,
        combiner='sum', name='table')

    self.assertEqual(
        repr(table),
        'TableConfig(vocabulary_size=2, dim=4, initializer=None, '
        'optimizer=None, combiner=\'sum\', name=\'table\')')

  def test_feature_config_repr(self):
    table = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=2, dim=4, initializer=None,
        combiner='sum', name='table')

    feature_config = tpu_embedding_v2_utils.FeatureConfig(
        table=table, name='feature')

    self.assertEqual(
        repr(feature_config),
        'FeatureConfig(table=TableConfig(vocabulary_size=2, dim=4, '
        'initializer=None, optimizer=None, combiner=\'sum\', name=\'table\'), '
        'max_sequence_length=0, validate_weights_and_indices=True, '
        'name=\'feature\')')


class TPUEmbeddingConfigurationTest(test.TestCase):
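  """Tests logging of large TPUEmbeddingConfiguration protos."""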
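  # The logging helper should emit the full config across multiple log
  # statements, each short enough to escape the logger's truncation limit.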
  def test_no_truncate(self):
    truncate_length = 14937  # Experimentally determined max loggable length.

    config = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration()
    for i in range(500):
      td = config.table_descriptor.add()
      td.name = 'table_{}'.format(i)
      td.vocabulary_size = i
    config.num_hosts = 2
    config.num_tensor_cores = 4
    config.batch_size_per_tensor_core = 128

    self.assertGreater(
        len(str(config)), truncate_length,
        'Sanity check: the generated config must be long enough to require '
        'truncation.')

    with self.assertLogs() as logs:
      tpu_embedding_v2_utils.log_tpu_embedding_configuration(config)

    self.assertIn('table_499', ''.join(logs.output))
    for line in logs.output:
      self.assertLess(
          len(line), truncate_length,
          'Each logged line should be shorter than the truncation length.')


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()