xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/feature_column_v2_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===================================================================
15"""Tests for python.tpu.feature_column."""
16
17import copy
18
19from absl.testing import parameterized
20from keras.feature_column import dense_features as df_lib
21from keras.feature_column import sequence_feature_column as sfc_lib
22
23from tensorflow.python.client import session
24from tensorflow.python.feature_column import feature_column_lib as fc_lib
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import sparse_tensor
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import init_ops
30from tensorflow.python.ops import lookup_ops
31from tensorflow.python.ops import parsing_ops
32from tensorflow.python.ops import variable_scope
33from tensorflow.python.ops import variables as variables_lib
34from tensorflow.python.platform import test
35from tensorflow.python.tpu import feature_column_v2 as tpu_fc
36from tensorflow.python.tpu import tpu
37from tensorflow.python.tpu import tpu_function
38
39
def _initialized_session():
  """Creates a `Session` and runs the variable and table initializers in it.

  Returns:
    The new session. Callers in this file use it as a context manager, which
    makes it the default session inside the `with` block.
  """
  sess = session.Session()
  # Build and run each initializer in turn: global variables first, then
  # lookup tables (same order as creating/running them one by one).
  for make_init_op in (variables_lib.global_variables_initializer,
                       lookup_ops.tables_initializer):
    sess.run(make_init_op())
  return sess
45
46
class _TestStateManager(fc_lib.StateManager):
  """In-memory `StateManager` that memoizes variables per feature column.

  Variables are created with `variable_scope.get_variable` and cached under
  `(feature_column, name)`; a repeated `create_variable` call for the same
  pair returns the cached variable instead of creating a new one.
  """

  def __init__(self, trainable=True):
    # feature_column -> {variable name -> variable}
    self._all_variables = {}
    # ANDed with the per-call `trainable` argument in `create_variable`.
    self._trainable = trainable

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Returns the variable `name` for `feature_column`, creating it once."""
    per_column = self._all_variables.setdefault(feature_column, {})
    if name in per_column:
      return per_column[name]
    new_var = variable_scope.get_variable(
        name=name,
        shape=shape,
        dtype=dtype,
        trainable=self._trainable and trainable,
        use_resource=use_resource,
        initializer=initializer)
    per_column[name] = new_var
    return new_var

  def get_variable(self, feature_column, name):
    """Returns a previously created variable; raises KeyError if missing."""
    per_column = self._all_variables[feature_column]
    return per_column[name]
79
80
class EmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):
  """Tests for `tpu_fc.embedding_column_v2` and `_TPUEmbeddingColumnV2`."""

  def test_defaults(self):
    """Checks the attributes of a column built with only required args."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column, dimension=embedding_dimension)
    # Can't test default initializer as it's a random function.
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual('mean', embedding_column.combiner)
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual((embedding_dimension,), embedding_column.variable_shape)

  def test_all_constructor_args(self):
    """Checks that explicit constructor args are reflected on the column."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column,
        dimension=embedding_dimension,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer')
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual('my_combiner', embedding_column.combiner)
    self.assertEqual('my_initializer', embedding_column.initializer())
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column._parse_example_spec)

  @parameterized.named_parameters(
      {
          'testcase_name': 'use_safe_embedding_lookup',
          'use_safe_embedding_lookup': True,
      }, {
          'testcase_name': 'dont_use_safe_embedding_lookup',
          'use_safe_embedding_lookup': False,
      })
  @test_util.deprecated_graph_mode_only
  def test_feature_layer_cpu(self, use_safe_embedding_lookup):
    """Checks dense and sequence lookups through the Keras feature layers."""
    # Inputs.
    vocabulary_size = 3
    sparse_input = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        # example 2, ids []
        # example 3, ids [1]
        indices=((0, 0), (1, 0), (1, 1), (3, 0)),
        values=(2, 0, 1, 1),
        dense_shape=(4, 2))

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Expected lookup result, using combiner='mean'.
    expected_lookups = (
        # example 0, ids [2], embedding = [7, 11]
        (7., 11.),
        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
        (2., 3.5),
        # example 2, ids [], embedding = [0, 0]
        (0., 0.),
        # example 3, ids [1], embedding = [3, 5]
        (3., 5.),
    )
    # Expected sequence lookup result: one embedding per position, with
    # zero rows filling the positions that have no id (length 2).
    expected_lookups_sequence = (
        # example 0, ids [2], embedding = [[7, 11], [0, 0]]
        ((7., 11.), (0., 0.),),
        # example 1, ids [0, 1], embedding = [[1, 2], [3, 5]]
        ((1., 2.), (3., 5.),),
        # example 2, ids [], embedding = [[0, 0], [0, 0]]
        ((0., 0.), (0., 0.),),
        # example 3, ids [1], embedding = [[3, 5], [0, 0]]
        ((3., 5.), (0., 0.),),
    )

    # Build columns.
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    sequence_categorical_column = (
        fc_lib.sequence_categorical_column_with_identity(
            key='bbb', num_buckets=vocabulary_size))
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column,
        dimension=embedding_dimension,
        initializer=_initializer,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
    sequence_embedding_column = tpu_fc.embedding_column_v2(
        sequence_categorical_column,
        dimension=embedding_dimension,
        initializer=_initializer,
        max_sequence_length=2,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

    # Provide sparse input and get dense result.
    features = {'aaa': sparse_input, 'bbb': sparse_input}
    dense_features = df_lib.DenseFeatures([embedding_column])
    sequence_features = sfc_lib.SequenceFeatures([sequence_embedding_column])
    embedding_lookup = dense_features(features)
    sequence_embedding_lookup = sequence_features(features)

    # Assert expected embedding variables and lookups.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertCountEqual(
        ('dense_features/aaa_embedding/embedding_weights:0',
         'sequence_features/bbb_embedding/embedding_weights:0',),
        tuple([v.name for v in global_vars]))
    # The graph still has one SparseFillEmptyRows op without the safe lookup,
    # because the sequence column performs a rank-3 embedding lookup; the safe
    # lookup is expected to add a second one.
    expected_fill_empty_rows_count = 2 if use_safe_embedding_lookup else 1
    with _initialized_session():
      # The name check above is order-insensitive, so verify every variable
      # rather than only `global_vars[0]`; both are built by `_initializer`.
      for embedding_var in global_vars:
        self.assertAllEqual(embedding_values, embedding_var)
      self.assertAllEqual(expected_lookups, embedding_lookup)
      self.assertAllEqual(expected_lookups_sequence,
                          sequence_embedding_lookup[0].eval())
      op_types = [x.type for x in ops.get_default_graph().get_operations()]
      self.assertEqual(expected_fill_empty_rows_count,
                       op_types.count('SparseFillEmptyRows'))

  def test_deepcopy(self):
    """Checks that deepcopy preserves dimension and max_sequence_length."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column, dimension=2)
    embedding_column_copy = copy.deepcopy(embedding_column)
    self.assertEqual(embedding_column.dimension,
                     embedding_column_copy.dimension)
    self.assertEqual(embedding_column._max_sequence_length,
                     embedding_column_copy._max_sequence_length)

  def test_with_scope_validation(self):
    """Reusing a column in a different variable scope raises by default."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=embedding_dimension,
        combiner='mean',
        initializer=initializer,
        max_sequence_length=0,
        learning_rate_fn=None,
        use_safe_embedding_lookup=True,
        bypass_scope_validation=False)
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    state_manager = _TestStateManager()
    with tpu_function.tpu_shard_context(1):
      with variable_scope.variable_scope('tower1/scope1'):
        embedding_column.create_state(state_manager)
      with variable_scope.variable_scope('tower2/scope2'):
        # With default scope validation, the same column cannot be used in a new
        # variable scope.
        with self.assertRaisesRegex(ValueError,
                                    'the variable scope name is different'):
          embedding_column.create_state(state_manager)

  def test_bypass_scope_validation(self):
    """With bypass_scope_validation=True, reuse in a new scope does not raise."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=embedding_dimension,
        combiner='mean',
        initializer=initializer,
        max_sequence_length=0,
        learning_rate_fn=None,
        use_safe_embedding_lookup=True,
        bypass_scope_validation=True)
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    state_manager = _TestStateManager()
    with tpu_function.tpu_shard_context(1):
      with variable_scope.variable_scope('tower1/scope1'):
        embedding_column.create_state(state_manager)
      with variable_scope.variable_scope('tower2/scope2'):
        embedding_column.create_state(state_manager)

  def test_deepcopy_with_bypass_scope_validation(self):
    """Checks that deepcopy keeps the scope-validation and safe-lookup flags."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=embedding_dimension,
        combiner='mean',
        initializer=initializer,
        max_sequence_length=0,
        use_safe_embedding_lookup=False,
        bypass_scope_validation=True)
    embedding_column_copy = copy.deepcopy(embedding_column)
    self.assertEqual(embedding_dimension, embedding_column_copy.dimension)
    self.assertEqual(embedding_column._max_sequence_length,
                     embedding_column_copy._max_sequence_length)
    self.assertTrue(embedding_column_copy._bypass_scope_validation)
    self.assertFalse(embedding_column_copy.use_safe_embedding_lookup)
298    self.assertFalse(embedding_column_copy.use_safe_embedding_lookup)
299
300
class SharedEmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):
  """Tests for `tpu_fc.shared_embedding_columns_v2`."""

  @test_util.deprecated_graph_mode_only
  def test_defaults(self):
    """Checks the attributes of shared columns built with only required args."""
    vocabulary_size = 3
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_dimension = 2
    embedding_column_b, embedding_column_a = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_b, categorical_column_a],
        dimension=embedding_dimension)
    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
    # Check both columns (previously column a was asserted twice).
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_a.get_embedding_table_size())
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_b.get_embedding_table_size())
    self.assertEqual('mean', embedding_column_a.combiner)
    self.assertEqual('mean', embedding_column_b.combiner)
    self.assertIsNotNone(embedding_column_a.get_initializer())
    self.assertIsNotNone(embedding_column_b.get_initializer())
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_a.get_embedding_var_name())
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_b.get_embedding_var_name())
    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
    self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
    self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)

  @test_util.deprecated_graph_mode_only
  def test_all_constructor_args(self):
    """Checks that explicit constructor args are reflected on both columns."""
    vocabulary_size = 3
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_dimension = 2
    embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_a, categorical_column_b],
        dimension=embedding_dimension,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer',
        shared_embedding_collection_name='var_scope_name')
    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
    # Check both columns (previously column a was asserted twice).
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_a.get_embedding_table_size())
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_b.get_embedding_table_size())
    self.assertEqual('my_combiner', embedding_column_a.combiner)
    self.assertEqual('my_combiner', embedding_column_b.combiner)
    self.assertEqual('my_initializer', embedding_column_a.get_initializer()())
    self.assertEqual('my_initializer', embedding_column_b.get_initializer()())
    self.assertEqual('var_scope_name',
                     embedding_column_a.get_embedding_var_name())
    self.assertEqual('var_scope_name',
                     embedding_column_b.get_embedding_var_name())
    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
    self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
    self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)

  @parameterized.named_parameters(
      {
          'testcase_name': 'use_safe_embedding_lookup',
          'use_safe_embedding_lookup': True
      }, {
          'testcase_name': 'dont_use_safe_embedding_lookup',
          'use_safe_embedding_lookup': False
      })
  @test_util.deprecated_graph_mode_only
  def test_feature_layer_cpu(self, use_safe_embedding_lookup):
    """Checks a dense and a sequence column sharing one table on CPU."""
    # Inputs.
    vocabulary_size = 3
    input_a = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 0, 1),
        dense_shape=(2, 2))
    input_b = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        # example 2, ids []
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 0, 1),
        dense_shape=(3, 2))
    input_features = {'aaa': input_a, 'bbb': input_b}

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Expected lookup result, using combiner='mean'.
    expected_lookups_a = (
        # example 0:
        (7., 11.),  # ids [2], embedding = [7, 11]
        # example 1:
        (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
    )
    expected_lookups_b = (
        # example 0:
        ((7., 11.), (0., 0.),),  # ids [2], embedding = [[7, 11], [0, 0]]
        # example 1:
        ((1., 2.), (3., 5.),),  # ids [0, 1], embedding = [[1, 2], [3, 5]]
        # example 2:
        ((0., 0.), (0., 0.),),  # ids [], embedding = [[0, 0], [0, 0]]
    )

    # Build columns.
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.sequence_categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_a, categorical_column_b],
        dimension=embedding_dimension,
        initializer=_initializer,
        max_sequence_lengths=[0, 2],
        use_safe_embedding_lookup=use_safe_embedding_lookup)

    # Provide sparse input and get dense result.
    dense_features = df_lib.DenseFeatures([embedding_column_a])
    sequence_features = sfc_lib.SequenceFeatures([embedding_column_b])
    embedding_lookup_a = dense_features(input_features)
    embedding_lookup_b = sequence_features(input_features)

    # Assert expected embedding variable and lookups.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertCountEqual(
        ('aaa_bbb_shared_embedding:0',),
        tuple([v.name for v in global_vars]))
    embedding_var = global_vars[0]
    # The graph still has one SparseFillEmptyRows op without the safe lookup,
    # because the sequence column performs a rank-3 embedding lookup; the safe
    # lookup is expected to add a second one.
    expected_fill_empty_rows_count = 2 if use_safe_embedding_lookup else 1
    with _initialized_session():
      self.assertAllEqual(embedding_values, embedding_var)
      self.assertAllEqual(expected_lookups_a, embedding_lookup_a)
      self.assertAllEqual(expected_lookups_b,
                          embedding_lookup_b[0].eval())
      op_types = [x.type for x in ops.get_default_graph().get_operations()]
      self.assertEqual(expected_fill_empty_rows_count,
                       op_types.count('SparseFillEmptyRows'))

  def test_deepcopy(self):
    """Checks that deepcopy preserves the shared embedding collection name."""
    vocabulary_size = 3
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_dimension = 2
    columns = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_b, categorical_column_a],
        dimension=embedding_dimension)
    columns_copy = copy.deepcopy(columns)
    self.assertEqual(
        [column._shared_embedding_collection_name for column in columns],
        [column._shared_embedding_collection_name for column in columns_copy])
477
478
class DeviceSpecificEmbeddingColumnTestV2(test.TestCase,
                                          parameterized.TestCase):
  """Tests for `embedding_lookup_device` and `tensor_core_shape` handling."""

  @parameterized.named_parameters(
      {
          'testcase_name': 'invalid_shared',
          'shared': True,
      }, {
          'testcase_name': 'invalid_not_shared',
          'shared': False,
      })
  @test_util.deprecated_graph_mode_only
  def test_invalid_cases(self, shared):
    """Unsupported device/mode combinations raise ValueError."""

    # Inputs.
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1), (1, 4)),
        values=(2, 0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=3)

    # Training on TPU with cpu embedding lookups is not supported.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=2,
          embedding_lookup_device='cpu',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=2,
          embedding_lookup_device='cpu',
          tensor_core_shape=[None, 3])
    dense_features = df_lib.DenseFeatures(embedding_column)
    with self.assertRaisesRegex(
        ValueError,
        r'.*embedding_lookup_device=\"cpu\" during training is not'):
      dense_features(input_features)

    # Inference with TPU embedding hardware is not supported.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=2,
          embedding_lookup_device='tpu_embedding_core',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=2,
          embedding_lookup_device='tpu_embedding_core',
          tensor_core_shape=[None, 3])
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # Exit in `finally` so a failing assertion cannot leave the inference
    # context entered.
    try:
      dense_features = df_lib.DenseFeatures(embedding_column)
      with self.assertRaisesRegex(
          ValueError, r'Using embedding_lookup_device=tpu_embedding_core '
          r'during inference is '):
        dense_features(input_features)
    finally:
      context.Exit()

  @parameterized.named_parameters(
      {
          'testcase_name': 'combiner_mean_shared',
          'shared': True,
          'combiner': 'mean'
      }, {
          'testcase_name': 'combiner_sum_shared',
          'shared': True,
          'combiner': 'sum'
      }, {
          'testcase_name': 'combiner_sqrtn_shared',
          'shared': True,
          'combiner': 'sqrtn'
      }, {
          'testcase_name': 'combiner_mean_not_shared',
          'shared': False,
          'combiner': 'mean'
      }, {
          'testcase_name': 'combiner_sum_not_shared',
          'shared': False,
          'combiner': 'sum'
      }, {
          'testcase_name': 'combiner_sqrtn_not_shared',
          'shared': False,
          'combiner': 'sqrtn'
      })
  @test_util.deprecated_graph_mode_only
  def test_dense_embedding_lookup(self, shared, combiner):
    """Checks densified tensor-core lookups for each combiner."""
    # Inputs.
    vocabulary_size = 3
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1, 3]
        indices=((0, 0), (1, 0), (1, 1), (1, 4)),
        values=(2, 0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Embedding variable.
    # NOTE(review): 4 rows for a 3-bucket column; `_initializer` returns them
    # as-is and the variable assertion below expects all 4 — confirm intended.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.),  # id 2
        (13., 17.)  # id 3
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=vocabulary_size)

    # Set tensor_core_shape to [None, 3] to get a dynamic batch size and some
    # padding. NOTE(review): the id at column 4 of example 1 lies outside this
    # width and is absent from the expected lookups below — it appears to be
    # dropped; confirm.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=embedding_dimension,
          initializer=_initializer,
          combiner=combiner,
          embedding_lookup_device='tpu_tensor_core',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=embedding_dimension,
          initializer=_initializer,
          combiner=combiner,
          embedding_lookup_device='tpu_tensor_core',
          tensor_core_shape=[None, 3])

    # Run in TPUContexts so that we hit the intended densification case.
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # `finally` guarantees Exit() also on the early `sqrtn` return and on any
    # failing assertion (the original skipped Exit() in both cases).
    try:
      with tpu_function.tpu_shard_context(1):
        dense_features = df_lib.DenseFeatures(embedding_column)
        # Sqrtn combiner not supported for now.
        if combiner == 'sqrtn':
          with self.assertRaisesRegex(
              ValueError, 'Dense TPU Embedding does not support combiner'):
            dense_features(input_features)
          return
        if combiner == 'mean':
          expected_lookups = (
              # example 0, ids [2], embedding = [7, 11]
              (7., 11.),
              # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5])
              #   = [2, 3.5]
              (2., 3.5),
          )
        elif combiner == 'sum':
          expected_lookups = (
              # example 0, ids [2], embedding = [7, 11]
              (7., 11.),
              # example 1, ids [0, 1], embedding = sum([1, 2] + [3, 5])
              #   = [4, 7]
              (4., 7.),
          )

        embedding_lookup = dense_features(input_features)

        # Assert expected embedding variable and lookups.
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        if shared:
          self.assertCountEqual(('inp_shared_embedding:0',),
                                tuple([v.name for v in global_vars]))
        else:
          self.assertCountEqual(
              ('dense_features/inp_embedding/embedding_weights:0',),
              tuple([v.name for v in global_vars]))

        embedding_var = global_vars[0]
        with _initialized_session():
          self.assertAllEqual(embedding_values, embedding_var)
          eval_res = embedding_lookup.eval()
          self.assertAllEqual(expected_lookups, eval_res)
    finally:
      context.Exit()

  @test_util.deprecated_graph_mode_only
  def test_empty_row(self):
    """A row with no ids yields a zero embedding under the mean combiner."""
    # Inputs.
    vocabulary_size = 3
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        # example 0, ids []
        # example 1, ids [0, 1, 3]
        indices=((1, 0), (1, 1), (1, 4)),
        values=(0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Embedding variable.
    # NOTE(review): 4 rows for a 3-bucket column; `_initializer` returns them
    # as-is and the variable assertion below expects all 4 — confirm intended.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.),  # id 2
        (13., 17.)  # id 3
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=vocabulary_size)

    # Set tensor_core_shape to [None, 3] to get a dynamic batch size and some
    # padding. NOTE(review): the id at column 4 of example 1 lies outside this
    # width and is absent from the expected lookups below — confirm.
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column_input,
        dimension=embedding_dimension,
        initializer=_initializer,
        combiner='mean',
        embedding_lookup_device='tpu_tensor_core',
        tensor_core_shape=[None, 3])

    # Run in TPUContexts so that we hit the intended densification case.
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # Exit in `finally` so a failing assertion cannot leave the inference
    # context entered.
    try:
      with tpu_function.tpu_shard_context(1):
        dense_features = df_lib.DenseFeatures(embedding_column)
        expected_lookups = (
            # example 0, ids [], embedding = [0, 0]
            (0., 0.),
            # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5])
            #   = [2, 3.5]
            (2., 3.5),
        )

        embedding_lookup = dense_features(input_features)

        # Assert expected embedding variable and lookups.
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertCountEqual(
            ('dense_features/inp_embedding/embedding_weights:0',),
            tuple([v.name for v in global_vars]))

        embedding_var = global_vars[0]
        with _initialized_session():
          self.assertAllEqual(embedding_values, embedding_var)
          eval_res = embedding_lookup.eval()
          self.assertAllEqual(expected_lookups, eval_res)
    finally:
      context.Exit()

  @test_util.deprecated_graph_mode_only
  def test_error_dense_shape_invalid(self):
    """A tensor_core_shape with other than two dimensions is rejected."""
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=5)
    with self.assertRaisesRegex(ValueError, 'tensor_core_shape must be size 2'):
      tpu_fc.shared_embedding_columns_v2([categorical_column_input],
                                         dimension=20,
                                         tensor_core_shape=[None, 20, 15])
744
745
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner.
  test.main()
748