# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Tests for python.tpu.feature_column."""

import copy

from absl.testing import parameterized
from keras.feature_column import dense_features as df_lib
from keras.feature_column import sequence_feature_column as sfc_lib

from tensorflow.python.client import session
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.tpu import feature_column_v2 as tpu_fc
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_function


def _initialized_session():
  """Returns a new session after running the variable and table initializers."""
  sess = session.Session()
  sess.run(variables_lib.global_variables_initializer())
  sess.run(lookup_ops.tables_initializer())
  return sess


def _count_ops_of_type(op_type):
  """Returns how many ops in the default graph have type `op_type`."""
  return sum(
      1 for op in ops.get_default_graph().get_operations()
      if op.type == op_type)


class _TestStateManager(fc_lib.StateManager):
  """StateManager that caches created variables per (feature column, name)."""

  def __init__(self, trainable=True):
    """Initializes the manager.

    Args:
      trainable: If False, every variable is created non-trainable regardless
        of the `trainable` argument passed to `create_variable`.
    """
    # Maps feature_column -> {variable name -> variable}.
    self._all_variables = {}
    self._trainable = trainable

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Returns the variable for (feature_column, name), creating it if needed.

    Args:
      feature_column: Column that owns the variable; used as the cache key.
      name: Variable name, unique within `feature_column`.
      shape: Shape forwarded to `variable_scope.get_variable`.
      dtype: Dtype forwarded to `variable_scope.get_variable`.
      trainable: Whether the caller wants the variable trainable; combined
        (logical and) with the manager-wide `trainable` flag.
      use_resource: Forwarded to `variable_scope.get_variable`.
      initializer: Forwarded to `variable_scope.get_variable`.

    Returns:
      The cached variable if one was already created for this column and name,
      otherwise a newly created one.
    """
    var_dict = self._all_variables.setdefault(feature_column, {})
    if name not in var_dict:
      var_dict[name] = variable_scope.get_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          trainable=self._trainable and trainable,
          use_resource=use_resource,
          initializer=initializer)
    return var_dict[name]

  def get_variable(self, feature_column, name):
    """Returns the previously created variable; raises KeyError if absent."""
    return self._all_variables[feature_column][name]


class EmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):
  """Tests for `tpu_fc.embedding_column_v2` and `_TPUEmbeddingColumnV2`."""

  def test_defaults(self):
    """Checks the attributes of a column built with default arguments."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column, dimension=embedding_dimension)
    # Can't test default initializer as it's a random function.
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual('mean', embedding_column.combiner)
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual((embedding_dimension,), embedding_column.variable_shape)

  def test_all_constructor_args(self):
    """Checks that explicit constructor arguments are stored on the column."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column,
        dimension=embedding_dimension,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer')
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    self.assertEqual('my_combiner', embedding_column.combiner)
    self.assertEqual('my_initializer', embedding_column.initializer())
    self.assertEqual('aaa_embedding', embedding_column.name)
    self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
    }, embedding_column._parse_example_spec)

  @parameterized.named_parameters(
      {
          'testcase_name': 'use_safe_embedding_lookup',
          'use_safe_embedding_lookup': True,
      }, {
          'testcase_name': 'dont_use_safe_embedding_lookup',
          'use_safe_embedding_lookup': False,
      })
  @test_util.deprecated_graph_mode_only
  def test_feature_layer_cpu(self, use_safe_embedding_lookup):
    """Checks dense and sequence lookups through the Keras feature layers."""
    # Inputs.
    vocabulary_size = 3
    sparse_input = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        # example 2, ids []
        # example 3, ids [1]
        indices=((0, 0), (1, 0), (1, 1), (3, 0)),
        values=(2, 0, 1, 1),
        dense_shape=(4, 2))

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Expected lookup result, using combiner='mean'.
    expected_lookups = (
        # example 0, ids [2], embedding = [7, 11]
        (7., 11.),
        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
        (2., 3.5),
        # example 2, ids [], embedding = [0, 0]
        (0., 0.),
        # example 3, ids [1], embedding = [3, 5]
        (3., 5.),
    )
    expected_lookups_sequence = (
        # example 0, ids [2], embedding = [[7, 11], [0, 0]]
        ((7., 11.), (0., 0.),),
        # example 1, ids [0, 1], embedding = [[1, 2], [3, 5]]
        ((1., 2.), (3., 5.),),
        # example 2, ids [], embedding = [[0, 0], [0, 0]]
        ((0., 0.), (0., 0.),),
        # example 3, ids [1], embedding = [[3, 5], [0, 0]]
        ((3., 5.), (0., 0.),),
    )

    # Build columns.
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    sequence_categorical_column = (
        fc_lib.sequence_categorical_column_with_identity(
            key='bbb', num_buckets=vocabulary_size))
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column,
        dimension=embedding_dimension,
        initializer=_initializer,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
    sequence_embedding_column = tpu_fc.embedding_column_v2(
        sequence_categorical_column,
        dimension=embedding_dimension,
        initializer=_initializer,
        max_sequence_length=2,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

    # Provide sparse input and get dense result.
    features = {'aaa': sparse_input, 'bbb': sparse_input}
    dense_features = df_lib.DenseFeatures([embedding_column])
    sequence_features = sfc_lib.SequenceFeatures([sequence_embedding_column])
    embedding_lookup = dense_features(features)
    sequence_embedding_lookup = sequence_features(features)

    # Assert expected embedding variable and lookups.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertCountEqual(
        ('dense_features/aaa_embedding/embedding_weights:0',
         'sequence_features/bbb_embedding/embedding_weights:0',),
        tuple(v.name for v in global_vars))
    with _initialized_session():
      # Both variables use the same initializer, so either one may be first.
      self.assertAllEqual(embedding_values, global_vars[0])
      self.assertAllEqual(expected_lookups, embedding_lookup)
      self.assertAllEqual(expected_lookups_sequence,
                          sequence_embedding_lookup[0].eval())
      # The graph will still have SparseFillEmptyRows due to sequence being
      # a Rank3 embedding lookup.
      expected_fill_ops = 2 if use_safe_embedding_lookup else 1
      self.assertEqual(expected_fill_ops,
                       _count_ops_of_type('SparseFillEmptyRows'))

  def test_deepcopy(self):
    """Checks that deepcopy keeps dimension and max sequence length."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column, dimension=2)
    embedding_column_copy = copy.deepcopy(embedding_column)
    self.assertEqual(embedding_column.dimension,
                     embedding_column_copy.dimension)
    self.assertEqual(embedding_column._max_sequence_length,
                     embedding_column_copy._max_sequence_length)

  def test_with_scope_validation(self):
    """Checks that reusing a column in another variable scope raises."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=embedding_dimension,
        combiner='mean',
        initializer=initializer,
        max_sequence_length=0,
        learning_rate_fn=None,
        use_safe_embedding_lookup=True,
        bypass_scope_validation=False)
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    state_manager = _TestStateManager()
    with tpu_function.tpu_shard_context(1):
      with variable_scope.variable_scope('tower1/scope1'):
        embedding_column.create_state(state_manager)
      with variable_scope.variable_scope('tower2/scope2'):
        # With default scope validation, the same column cannot be used in a
        # new variable scope.
        with self.assertRaisesRegex(ValueError,
                                    'the variable scope name is different'):
          embedding_column.create_state(state_manager)

  def test_bypass_scope_validation(self):
    """Checks that bypass_scope_validation permits reuse across scopes."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=embedding_dimension,
        combiner='mean',
        initializer=initializer,
        max_sequence_length=0,
        learning_rate_fn=None,
        use_safe_embedding_lookup=True,
        bypass_scope_validation=True)
    self.assertIs(categorical_column, embedding_column.categorical_column)
    self.assertEqual(embedding_dimension, embedding_column.dimension)
    state_manager = _TestStateManager()
    with tpu_function.tpu_shard_context(1):
      with variable_scope.variable_scope('tower1/scope1'):
        embedding_column.create_state(state_manager)
      with variable_scope.variable_scope('tower2/scope2'):
        embedding_column.create_state(state_manager)

  def test_deepcopy_with_bypass_scope_validation(self):
    """Checks that deepcopy keeps the bypass and safe-lookup settings."""
    categorical_column = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    embedding_dimension = 2
    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=embedding_dimension,
        combiner='mean',
        initializer=initializer,
        max_sequence_length=0,
        use_safe_embedding_lookup=False,
        bypass_scope_validation=True)
    embedding_column_copy = copy.deepcopy(embedding_column)
    self.assertEqual(embedding_dimension, embedding_column_copy.dimension)
    self.assertEqual(embedding_column._max_sequence_length,
                     embedding_column_copy._max_sequence_length)
    self.assertTrue(embedding_column_copy._bypass_scope_validation)
    self.assertFalse(embedding_column_copy.use_safe_embedding_lookup)


class SharedEmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):
  """Tests for `tpu_fc.shared_embedding_columns_v2`."""

  @test_util.deprecated_graph_mode_only
  def test_defaults(self):
    """Checks the attributes of shared columns built with default arguments."""
    vocabulary_size = 3
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_dimension = 2
    embedding_column_b, embedding_column_a = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_b, categorical_column_a],
        dimension=embedding_dimension)
    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_a.get_embedding_table_size())
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_b.get_embedding_table_size())
    self.assertEqual('mean', embedding_column_a.combiner)
    self.assertEqual('mean', embedding_column_b.combiner)
    self.assertIsNotNone(embedding_column_a.get_initializer())
    self.assertIsNotNone(embedding_column_b.get_initializer())
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_a.get_embedding_var_name())
    self.assertEqual('aaa_bbb_shared_embedding',
                     embedding_column_b.get_embedding_var_name())
    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
    self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
    self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)

  @test_util.deprecated_graph_mode_only
  def test_all_constructor_args(self):
    """Checks that explicit arguments are stored on both shared columns."""
    vocabulary_size = 3
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_dimension = 2
    embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_a, categorical_column_b],
        dimension=embedding_dimension,
        combiner='my_combiner',
        initializer=lambda: 'my_initializer',
        shared_embedding_collection_name='var_scope_name')
    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_a.get_embedding_table_size())
    self.assertEqual((vocabulary_size, embedding_dimension),
                     embedding_column_b.get_embedding_table_size())
    self.assertEqual('my_combiner', embedding_column_a.combiner)
    self.assertEqual('my_combiner', embedding_column_b.combiner)
    self.assertEqual('my_initializer', embedding_column_a.get_initializer()())
    self.assertEqual('my_initializer', embedding_column_b.get_initializer()())
    self.assertEqual('var_scope_name',
                     embedding_column_a.get_embedding_var_name())
    self.assertEqual('var_scope_name',
                     embedding_column_b.get_embedding_var_name())
    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
    self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
    self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)

  @parameterized.named_parameters(
      {
          'testcase_name': 'use_safe_embedding_lookup',
          'use_safe_embedding_lookup': True
      }, {
          'testcase_name': 'dont_use_safe_embedding_lookup',
          'use_safe_embedding_lookup': False
      })
  @test_util.deprecated_graph_mode_only
  def test_feature_layer_cpu(self, use_safe_embedding_lookup):
    """Checks dense and sequence lookups that share one embedding table."""
    # Inputs.
    vocabulary_size = 3
    input_a = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 0, 1),
        dense_shape=(2, 2))
    input_b = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        # example 2, ids []
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 0, 1),
        dense_shape=(3, 2))
    input_features = {'aaa': input_a, 'bbb': input_b}

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.)  # id 2
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Expected lookup result, using combiner='mean'.
    expected_lookups_a = (
        # example 0:
        (7., 11.),  # ids [2], embedding = [7, 11]
        # example 1:
        (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
    )
    expected_lookups_b = (
        # example 0:
        ((7., 11.), (0., 0.),),  # ids [2], embedding = [[7, 11], [0, 0]]
        # example 1:
        ((1., 2.), (3., 5.),),  # ids [0, 1], embedding = [[1, 2], [3, 5]]
        # example 2:
        ((0., 0.), (0., 0.),),  # ids [], embedding = [[0, 0], [0, 0]]
    )

    # Build columns.
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.sequence_categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_column_a, embedding_column_b = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_a, categorical_column_b],
        dimension=embedding_dimension,
        initializer=_initializer,
        max_sequence_lengths=[0, 2],
        use_safe_embedding_lookup=use_safe_embedding_lookup)

    # Provide sparse input and get dense result.
    dense_features = df_lib.DenseFeatures([embedding_column_a])
    sequence_features = sfc_lib.SequenceFeatures([embedding_column_b])
    embedding_lookup_a = dense_features(input_features)
    embedding_lookup_b = sequence_features(input_features)

    # Assert expected embedding variable and lookups.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertCountEqual(
        ('aaa_bbb_shared_embedding:0',),
        tuple(v.name for v in global_vars))
    embedding_var = global_vars[0]
    with _initialized_session():
      self.assertAllEqual(embedding_values, embedding_var)
      self.assertAllEqual(expected_lookups_a, embedding_lookup_a)
      self.assertAllEqual(expected_lookups_b,
                          embedding_lookup_b[0].eval())
      # The graph will still have SparseFillEmptyRows due to sequence being
      # a Rank3 embedding lookup.
      expected_fill_ops = 2 if use_safe_embedding_lookup else 1
      self.assertEqual(expected_fill_ops,
                       _count_ops_of_type('SparseFillEmptyRows'))

  def test_deepcopy(self):
    """Checks that deepcopy keeps the shared embedding collection name."""
    vocabulary_size = 3
    categorical_column_a = fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = fc_lib.categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_dimension = 2
    columns = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_b, categorical_column_a],
        dimension=embedding_dimension)
    columns_copy = copy.deepcopy(columns)
    self.assertEqual(
        [column._shared_embedding_collection_name for column in columns],
        [column._shared_embedding_collection_name for column in columns_copy])


class DeviceSpecificEmbeddingColumnTestV2(test.TestCase,
                                          parameterized.TestCase):
  """Tests for columns built with an explicit `embedding_lookup_device`."""

  @parameterized.named_parameters(
      {
          'testcase_name': 'invalid_shared',
          'shared': True,
      }, {
          'testcase_name': 'invalid_not_shared',
          'shared': False,
      })
  @test_util.deprecated_graph_mode_only
  def test_invalid_cases(self, shared):
    """Checks that unsupported device/mode combinations raise ValueError."""

    # Inputs.
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1), (1, 4)),
        values=(2, 0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=3)

    # Training on TPU with cpu embedding lookups is not supported.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=2,
          embedding_lookup_device='cpu',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=2,
          embedding_lookup_device='cpu',
          tensor_core_shape=[None, 3])
    dense_features = df_lib.DenseFeatures(embedding_column)
    with self.assertRaisesRegex(
        ValueError,
        r'.*embedding_lookup_device=\"cpu\" during training is not'):
      dense_features(input_features)

    # Inference with TPU Embedding Hardware is not supported.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=2,
          embedding_lookup_device='tpu_embedding_core',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=2,
          embedding_lookup_device='tpu_embedding_core',
          tensor_core_shape=[None, 3])
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # Always leave the inference context, even when an assertion fails.
    try:
      dense_features = df_lib.DenseFeatures(embedding_column)
      with self.assertRaisesRegex(
          ValueError,
          r'Using embedding_lookup_device=tpu_embedding_core during inference is '
      ):
        dense_features(input_features)
    finally:
      context.Exit()

  @parameterized.named_parameters(
      {
          'testcase_name': 'combiner_mean_shared',
          'shared': True,
          'combiner': 'mean'
      }, {
          'testcase_name': 'combiner_sum_shared',
          'shared': True,
          'combiner': 'sum'
      }, {
          'testcase_name': 'combiner_sqrtn_shared',
          'shared': True,
          'combiner': 'sqrtn'
      }, {
          'testcase_name': 'combiner_mean_not_shared',
          'shared': False,
          'combiner': 'mean'
      }, {
          'testcase_name': 'combiner_sum_not_shared',
          'shared': False,
          'combiner': 'sum'
      }, {
          'testcase_name': 'combiner_sqrtn_not_shared',
          'shared': False,
          'combiner': 'sqrtn'
      })
  @test_util.deprecated_graph_mode_only
  def test_dense_embedding_lookup(self, shared, combiner):
    """Checks 'tpu_tensor_core' lookups for each combiner, shared or not."""
    # Inputs.
    vocabulary_size = 3
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1, 3]
        indices=((0, 0), (1, 0), (1, 1), (1, 4)),
        values=(2, 0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.),  # id 2
        (13., 17.)  # id 3
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=vocabulary_size)

    # tensor_core_shape is [None, 3]: a dynamic batch size with a fixed second
    # dimension of 3.
    if shared:
      embedding_column = tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=embedding_dimension,
          initializer=_initializer,
          combiner=combiner,
          embedding_lookup_device='tpu_tensor_core',
          tensor_core_shape=[None, 3])
    else:
      embedding_column = tpu_fc.embedding_column_v2(
          categorical_column_input,
          dimension=embedding_dimension,
          initializer=_initializer,
          combiner=combiner,
          embedding_lookup_device='tpu_tensor_core',
          tensor_core_shape=[None, 3])

    # Run in TPUContexts so that we hit the intended densification case.
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # The try/finally guarantees the context is exited on the early return for
    # 'sqrtn' as well as on assertion failures.
    try:
      with tpu_function.tpu_shard_context(1):
        dense_features = df_lib.DenseFeatures(embedding_column)
        # Sqrtn combiner not supported for now.
        if combiner == 'sqrtn':
          with self.assertRaisesRegex(
              ValueError, 'Dense TPU Embedding does not support combiner'):
            dense_features(input_features)
          return
        if combiner == 'mean':
          expected_lookups = (
              # example 0:
              (7., 11.),  # ids [2], embedding = [7, 11]
              # example 1:
              (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) =
              # [2, 3.5]
          )
        elif combiner == 'sum':
          expected_lookups = (
              # example 0:
              (7., 11.),  # ids [2], embedding = [7, 11]
              # example 1:
              (4., 7),  # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7]
          )

        embedding_lookup = dense_features(input_features)

        # Assert expected embedding variable and lookups.
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        if shared:
          self.assertCountEqual(('inp_shared_embedding:0',),
                                tuple(v.name for v in global_vars))
        else:
          self.assertCountEqual(
              ('dense_features/inp_embedding/embedding_weights:0',),
              tuple(v.name for v in global_vars))

        embedding_var = global_vars[0]
        with _initialized_session():
          self.assertAllEqual(embedding_values, embedding_var)
          eval_res = embedding_lookup.eval()
          self.assertAllEqual(expected_lookups, eval_res)
    finally:
      context.Exit()

  @test_util.deprecated_graph_mode_only
  def test_empty_row(self):
    """Checks that an example with no ids yields an all-zero embedding."""
    # Inputs.
    vocabulary_size = 3
    input_sparse_tensor = sparse_tensor.SparseTensorValue(
        # example 0, ids []
        # example 1, ids [0, 1, 3]
        indices=((1, 0), (1, 1), (1, 4)),
        values=(0, 1, 3),
        dense_shape=(2, 5))
    input_features = {'inp': input_sparse_tensor}

    # Embedding variable.
    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 5.),  # id 1
        (7., 11.),  # id 2
        (13., 17.)  # id 3
    )

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    # Build columns.
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=vocabulary_size)

    # tensor_core_shape is [None, 3]: a dynamic batch size with a fixed second
    # dimension of 3.
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column_input,
        dimension=embedding_dimension,
        initializer=_initializer,
        combiner='mean',
        embedding_lookup_device='tpu_tensor_core',
        tensor_core_shape=[None, 3])

    # Run in TPUContexts so that we hit the intended densification case.
    context = tpu._TPUInferenceContext('tpu_inference')
    context.Enter()
    # Always leave the inference context, even when an assertion fails.
    try:
      with tpu_function.tpu_shard_context(1):
        dense_features = df_lib.DenseFeatures(embedding_column)
        expected_lookups = (
            # example 0:
            (0., 0.),  # ids [], embedding = [0, 0]
            # example 1:
            (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) =
            # [2, 3.5]
        )

        embedding_lookup = dense_features(input_features)

        # Assert expected embedding variable and lookups.
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertCountEqual(
            ('dense_features/inp_embedding/embedding_weights:0',),
            tuple(v.name for v in global_vars))

        embedding_var = global_vars[0]
        with _initialized_session():
          self.assertAllEqual(embedding_values, embedding_var)
          eval_res = embedding_lookup.eval()
          self.assertAllEqual(expected_lookups, eval_res)
    finally:
      context.Exit()

  @test_util.deprecated_graph_mode_only
  def test_error_dense_shape_invalid(self):
    """Checks that a tensor_core_shape with three entries is rejected."""
    categorical_column_input = fc_lib.categorical_column_with_identity(
        key='inp', num_buckets=5)
    with self.assertRaisesRegex(ValueError, 'tensor_core_shape must be size 2'):
      tpu_fc.shared_embedding_columns_v2([categorical_column_input],
                                         dimension=20,
                                         tensor_core_shape=[None, 20, 15])


if __name__ == '__main__':
  test.main()