# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for TPU Embeddings mid level API on CPU."""
16
17import numpy as np
18
19from tensorflow.python.compat import v2_compat
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import sparse_tensor
23from tensorflow.python.ops import init_ops_v2
24from tensorflow.python.ops.ragged import ragged_tensor
25from tensorflow.python.platform import test
26from tensorflow.python.tpu import tpu_embedding_for_serving
27from tensorflow.python.tpu import tpu_embedding_v2_utils
28from tensorflow.python.util import nest


class TPUEmbeddingForServingTest(test.TestCase):
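  """Tests dense, sparse and ragged lookups through TPUEmbeddingForServing on CPU."""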

  def setUp(self):
    super(TPUEmbeddingForServingTest, self).setUp()

    self.embedding_values = np.array(list(range(32)), dtype=np.float64)
    self.initializer = init_ops_v2.Constant(self.embedding_values)
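    # Both tables below share these 32 values; the Constant initializer lays
    # them out row-major into each table's (vocabulary_size, dim) shape, i.e.
    # 8x4 for 'video' and 16x2 for 'user', as the comments below show.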
    # Embedding for video initialized to
    # 0 1 2 3
    # 4 5 6 7
    # ...
    self.table_video = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=8,
        dim=4,
        initializer=self.initializer,
        combiner='sum',
        name='video')
    # Embedding for user initialized to
    # 0 1
    # 2 3
    # 4 5
    # 6 7
    # ...
    self.table_user = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=16,
        dim=2,
        initializer=self.initializer,
        combiner='mean',
        name='user')
    self.feature_config = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_video, name='watched'),
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_video, name='favorited'),
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends'))

    self.batch_size = 2
    self.data_batch_size = 4
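    # The sparse/ragged inputs built from the lists below all have
    # data_batch_size (4) rows.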

    # One (global) batch of inputs
    # sparse tensor for watched:
    # row 0: 0
    # row 1: 0, 1
    # row 2: 0, 1
    # row 3: 1
    self.feature_watched_indices = [[0, 0], [1, 0], [1, 1],
                                    [2, 0], [2, 1], [3, 0]]
    self.feature_watched_values = [0, 0, 1, 0, 1, 1]
    self.feature_watched_row_lengths = [1, 2, 2, 1]
    # sparse tensor for favorited:
    # row 0: 0, 1
    # row 1: 1
    # row 2: 0
    # row 3: 0, 1
    self.feature_favorited_indices = [[0, 0], [0, 1], [1, 0],
                                      [2, 0], [3, 0], [3, 1]]
    self.feature_favorited_values = [0, 1, 1, 0, 0, 1]
    self.feature_favorited_row_lengths = [2, 1, 1, 2]
    # sparse tensor for friends:
    # row 0: 3
    # row 1: 0, 1, 2
    # row 2: 3
    # row 3: 0, 1, 2
    self.feature_friends_indices = [[0, 0], [1, 0], [1, 1], [1, 2],
                                    [2, 0], [3, 0], [3, 1], [3, 2]]
    self.feature_friends_values = [3, 0, 1, 2, 3, 0, 1, 2]
    self.feature_friends_row_lengths = [1, 3, 1, 3]

  def _create_mid_level(self):
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    return tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=self.feature_config, optimizer=optimizer)

  def _get_dense_tensors(self, dtype=dtypes.int32):
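    """Returns the three feature value lists as 1-D dense tensors."""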
    feature0 = constant_op.constant(self.feature_watched_values, dtype=dtype)
    feature1 = constant_op.constant(self.feature_favorited_values, dtype=dtype)
    feature2 = constant_op.constant(self.feature_friends_values, dtype=dtype)
    return (feature0, feature1, feature2)

  def test_cpu_dense_lookup(self):
    mid_level = self._create_mid_level()
    features = self._get_dense_tensors()
    results = mid_level(features, weights=None)
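    # For dense inputs each id is looked up independently and no combiner is
    # applied, so the expected result is a plain gather from each table.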
    all_lookups = []
    for feature, config in zip(nest.flatten(features), self.feature_config):
      table = mid_level.embedding_tables[config.table].numpy()
      all_lookups.append(table[feature.numpy()])
    self.assertAllClose(results, nest.pack_sequence_as(results, all_lookups))

  def test_cpu_dense_lookup_with_weights(self):
    mid_level = self._create_mid_level()
    features = self._get_dense_tensors()
    weights = self._get_dense_tensors(dtype=dtypes.float32)

    with self.assertRaisesRegex(
        ValueError, 'Weight specified for .*, but input is dense.'):
      mid_level(features, weights=weights)

  def _get_sparse_tensors(self, dtype=dtypes.int32):
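    """Returns the three features as SparseTensors with 4 rows each."""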
    feature0 = sparse_tensor.SparseTensor(
        indices=self.feature_watched_indices,
        values=constant_op.constant(self.feature_watched_values, dtype=dtype),
        dense_shape=[self.data_batch_size, 2])
    feature1 = sparse_tensor.SparseTensor(
        indices=self.feature_favorited_indices,
        values=constant_op.constant(self.feature_favorited_values, dtype=dtype),
        dense_shape=[self.data_batch_size, 2])
    feature2 = sparse_tensor.SparseTensor(
        indices=self.feature_friends_indices,
        values=constant_op.constant(self.feature_friends_values, dtype=dtype),
        dense_shape=[self.data_batch_size, 3])
    return (feature0, feature1, feature2)

  def test_cpu_sparse_lookup(self):
    mid_level = self._create_mid_level()
    features = self._get_sparse_tensors()
    results = mid_level(features, weights=None)
    reduced = []
    for feature, config in zip(nest.flatten(features), self.feature_config):
      table = mid_level.embedding_tables[config.table].numpy()
      all_lookups = table[feature.values.numpy()]
      # With row starts we can use reduceat in numpy. Get row starts from the
      # ragged tensor API.
      ragged = ragged_tensor.RaggedTensor.from_sparse(feature)
      row_starts = ragged.row_starts().numpy()
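      # E.g. 'watched' has row lengths [1, 2, 2, 1], so row_starts is
      # [0, 1, 3, 5] and reduceat sums lookups [0:1], [1:3], [3:5] and [5:].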
      reduced.append(np.add.reduceat(all_lookups, row_starts))
      if config.table.combiner == 'mean':
        # For 'mean', divide by the row lengths.
        reduced[-1] /= np.expand_dims(ragged.row_lengths().numpy(), axis=1)
    self.assertAllClose(results, nest.pack_sequence_as(results, reduced))

  def test_cpu_sparse_lookup_with_weights(self):
    mid_level = self._create_mid_level()
    features = self._get_sparse_tensors()
    weights = self._get_sparse_tensors(dtype=dtypes.float32)
    results = mid_level(features, weights=weights)
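    # Golden values: a per-row weighted sum of embeddings; for 'mean' tables
    # it is normalized by the per-row sum of weights, not the row length.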
    weighted_sum = []
    for feature, weight, config in zip(nest.flatten(features),
                                       nest.flatten(weights),
                                       self.feature_config):
      table = mid_level.embedding_tables[config.table].numpy()
      # Expand dims is needed here so the multiplication broadcasts properly.
      weight = np.expand_dims(weight.values.numpy(), axis=1)
      all_lookups = table[feature.values.numpy()] * weight
      # With row starts we can use reduceat in numpy. Get row starts from the
      # ragged tensor API.
      row_starts = ragged_tensor.RaggedTensor.from_sparse(feature).row_starts()
      row_starts = row_starts.numpy()
      weighted_sum.append(np.add.reduceat(all_lookups, row_starts))
      if config.table.combiner == 'mean':
        weighted_sum[-1] /= np.add.reduceat(weight, row_starts)
    self.assertAllClose(results, nest.pack_sequence_as(results,
                                                       weighted_sum))

  def test_cpu_sparse_lookup_with_non_sparse_weights(self):
    mid_level = self._create_mid_level()
    features = self._get_sparse_tensors()
    weights = self._get_dense_tensors(dtype=dtypes.float32)
    with self.assertRaisesRegex(
        ValueError, 'but it does not match type of the input which is'):
      mid_level(features, weights=weights)

  def _get_ragged_tensors(self, dtype=dtypes.int32):
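    """Returns the three features as RaggedTensors built from row lengths."""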
    feature0 = ragged_tensor.RaggedTensor.from_row_lengths(
        values=constant_op.constant(self.feature_watched_values, dtype=dtype),
        row_lengths=self.feature_watched_row_lengths)
    feature1 = ragged_tensor.RaggedTensor.from_row_lengths(
        values=constant_op.constant(self.feature_favorited_values, dtype=dtype),
        row_lengths=self.feature_favorited_row_lengths)
    feature2 = ragged_tensor.RaggedTensor.from_row_lengths(
        values=constant_op.constant(self.feature_friends_values, dtype=dtype),
        row_lengths=self.feature_friends_row_lengths)
    return (feature0, feature1, feature2)

  def test_cpu_ragged_lookup_with_weights(self):
    mid_level = self._create_mid_level()
    features = self._get_ragged_tensors()
    weights = self._get_ragged_tensors(dtype=dtypes.float32)
    results = mid_level(features, weights=weights)
    weighted_sum = []
    for feature, weight, config in zip(nest.flatten(features),
                                       nest.flatten(weights),
                                       self.feature_config):
      table = mid_level.embedding_tables[config.table].numpy()
      # Expand dims is needed here so the multiplication broadcasts properly.
      weight = np.expand_dims(weight.values.numpy(), axis=1)
      all_lookups = table[feature.values.numpy()] * weight
      row_starts = feature.row_starts().numpy()
      weighted_sum.append(np.add.reduceat(all_lookups, row_starts))
      if config.table.combiner == 'mean':
        weighted_sum[-1] /= np.add.reduceat(weight, row_starts)
    self.assertAllClose(results, nest.pack_sequence_as(results,
                                                       weighted_sum))

  def test_cpu_invalid_structure_for_features(self):
    mid_level = self._create_mid_level()
    # Pass only two features; self.feature_config has three, so the input
    # structure does not match.
    features = tuple(self._get_sparse_tensors()[:2])
    with self.assertRaises(ValueError):
      mid_level(features, weights=None)

  def test_cpu_invalid_structure_for_weights(self):
    mid_level = self._create_mid_level()
    features = self._get_sparse_tensors()
    # Pass only two weights; self.feature_config has three, so the weights
    # structure does not match (weights must cover every feature or be None).
    weights = tuple(self._get_dense_tensors(dtype=dtypes.float32)[:2])
    with self.assertRaises(ValueError):
      mid_level(features, weights=weights)

  def _numpy_sequence_lookup(self, table, indices, values, batch_size,
                             max_sequence_length, dim):
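    """Computes a reference sequence-feature lookup in numpy.

    Rows are truncated to max_sequence_length ids, their embeddings are
    gathered from the table and scattered into a
    [batch_size, max_sequence_length, dim] array; unset positions stay zero.
    """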
    # First we truncate to max_sequence_length.
    valid_entries = np.nonzero(indices[:, 1] < max_sequence_length)[0]
    indices = indices[valid_entries]
    values = values[valid_entries]
    # Then we gather the values.
    lookup = table[values]
    # Then we scatter them into the result array.
    scatter_result = np.zeros([batch_size, max_sequence_length, dim])
    for i, index in enumerate(indices):
      scatter_result[index[0], index[1], :] = lookup[i]
    return scatter_result

  def test_cpu_sequence_lookup_sparse(self):
    feature_config = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends', max_sequence_length=2),)
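    # 'friends' rows 1 and 3 hold three ids each; with max_sequence_length=2
    # only the first two ids of each row appear in the result.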
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    mid_level = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=optimizer)
    features = self._get_sparse_tensors()[2:3]
    result = mid_level(features, weights=None)

    golden = self._numpy_sequence_lookup(
        mid_level.embedding_tables[self.table_user].numpy(),
        features[0].indices.numpy(),
        features[0].values.numpy(),
        self.data_batch_size,
        feature_config[0].max_sequence_length,
        self.table_user.dim)

    self.assertAllClose(result[0], golden)

  def test_cpu_sequence_lookup_ragged(self):
    feature_config = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_user, name='friends', max_sequence_length=2),)
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    mid_level = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=optimizer)
    features = self._get_ragged_tensors()[2:3]
    result = mid_level(features, weights=None)

    sparse_ver = features[0].to_sparse()
    golden = self._numpy_sequence_lookup(
        mid_level.embedding_tables[self.table_user].numpy(),
        sparse_ver.indices.numpy(),
        sparse_ver.values.numpy(),
        self.data_batch_size,
        feature_config[0].max_sequence_length,
        self.table_user.dim)

    self.assertAllClose(result[0], golden)

  def test_cpu_high_dimensional_lookup_ragged(self):
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=self.table_user, name='friends', output_shape=[2, 2]),)
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    mid_level = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=optimizer)
    features = self._get_ragged_tensors()[2:3]
    result = mid_level(features, weights=None)

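    # output_shape [2, 2] plus the user table's embedding dim of 2 yields a
    # dense result of shape (2, 2, 2).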
    self.assertAllClose(result[0].shape, (2, 2, 2))

  def test_cpu_high_dimensional_sequence_lookup_ragged(self):
    # The data batch size (4) divides the product of the output shape
    # (2 * 4 = 8); the quotient (2) becomes the sequence length.
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=self.table_user, name='friends', output_shape=[2, 4]),)
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    mid_level = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=optimizer)
    features = self._get_ragged_tensors()[2:3]
    result = mid_level(features, weights=None)
    self.assertAllClose(result[0].shape, (2, 4, 2))

  def test_cpu_high_dimensional_invalid_lookup_ragged(self):
    # The product of the output shape (3) is not divisible by the data batch
    # size (4), so the lookup raises an error.
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=self.table_user, name='friends', output_shape=[3]),)
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    mid_level = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=optimizer)
    features = self._get_ragged_tensors()[2:3]
    with self.assertRaisesRegex(
        ValueError,
        'Output shape set in the FeatureConfig should be the factor'):
      mid_level(features, weights=None)

  def test_cpu_no_optimizer(self):
    feature_config = (
        tpu_embedding_v2_utils.FeatureConfig(
            table=self.table_video, name='watched', max_sequence_length=2),)
    mid_level = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=None)
    # Build the layer manually to create the variables. Normally calling
    # enqueue would do this.
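    # With optimizer=None only the 'parameters' variable is created; there
    # are no optimizer slot variables.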
    mid_level.build()
    self.assertEqual(
        list(mid_level._variables[self.table_video.name].keys()),
        ['parameters'])

  def test_cpu_multiple_creation(self):
    feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=self.table_user, name='friends', max_sequence_length=2),)
    optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
    embedding_one = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=optimizer)
    embedding_two = tpu_embedding_for_serving.TPUEmbeddingForServing(
        feature_config=feature_config, optimizer=optimizer)

    # Both of the TPU embedding tables should be able to build on CPU.
    embedding_one.build()
    embedding_two.build()


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()