1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15"""TPU Feature Column Library.""" 16import math 17 18from tensorflow.python.feature_column import feature_column as fc 19from tensorflow.python.feature_column import feature_column_lib as fc_lib 20from tensorflow.python.framework import ops 21from tensorflow.python.ops import array_ops 22from tensorflow.python.ops import init_ops 23from tensorflow.python.ops import variable_scope 24from tensorflow.python.tpu import tpu 25from tensorflow.python.tpu import tpu_function 26# pylint: disable=protected-access 27 28 29_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope' 30_SUPPORTED_SEQUENCE_COLUMNS = (fc._SequenceCategoricalColumn, 31 fc_lib.SequenceCategoricalColumn) 32 33 34# For V2 columns, we support anything that inherits from CategoricalColumn 35# other than those in the denylist. User-provided columns that inherit from 36# CategoricalColumn may or may not be compatible; it is up to the user to 37# manage TPU compatibility for custom columns. 38_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.CategoricalColumn,) 39_DENYLISTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.HashedCategoricalColumn, 40 fc_lib.BucketizedColumn, 41 fc_lib.CrossedColumn) 42_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn, 43 fc._VocabularyFileCategoricalColumn, 44 fc._VocabularyListCategoricalColumn, 45 fc._WeightedCategoricalColumn, 46 fc._SequenceCategoricalColumn 47 ) + _SUPPORTED_CATEGORICAL_COLUMNS_V2 48_SEQUENCE_FEATURE_LENGTH_POSTFIX = '_seq_length_' 49 50 51def embedding_column(categorical_column, 52 dimension, 53 combiner='mean', 54 initializer=None, 55 max_sequence_length=0, 56 learning_rate_fn=None, 57 use_safe_embedding_lookup=True): 58 """TPU embedding_column for `tf.feature_column.embedding_column`. 59 60 Note that the interface for TPU embedding_column is different from the non-TPU 61 version. The following args available for the non-TPU version are NOT 62 supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable. 63 64 Args: 65 categorical_column: A categorical_column returned from 66 categorical_column_with_identity, weighted_categorical_column, 67 categorical_column_with_vocabulary_file, 68 categorical_column_with_vocabulary_list, 69 sequence_categorical_column_with_identity, 70 sequence_categorical_column_with_vocabulary_file, 71 sequence_categorical_column_with_vocabulary_list 72 dimension: An integer specifying dimension of the embedding, must be > 0. 73 combiner: A string specifying how to reduce if there are multiple entries 74 in a single row for a non-sequence column. For more information, see 75 `tf.feature_column.embedding_column`. 76 initializer: A variable initializer function to be used in embedding 77 variable initialization. If not specified, defaults to 78 `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and 79 standard deviation `1/sqrt(dimension)`. 80 max_sequence_length: An non-negative integer specifying the max sequence 81 length. Any sequence shorter then this will be padded with 0 embeddings 82 and any sequence longer will be truncated. This must be positive for 83 sequence features and 0 for non-sequence features. 84 learning_rate_fn: A function that takes global step and returns learning 85 rate for the embedding table. If you intend to use the same learning rate 86 for multiple embedding tables, please ensure that you pass the exact same 87 python function to all calls of embedding_column, otherwise performence 88 may suffer. 89 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 90 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 91 there are no empty rows and all weights and ids are positive at the 92 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 93 input tensors. Defaults to true, consider turning off if the above checks 94 are not needed. Note that having empty rows will not trigger any error 95 though the output result might be 0 or omitted. 96 97 Returns: 98 A _TPUEmbeddingColumn. 99 100 Raises: 101 ValueError: if `dimension` not > 0. 102 ValueError: if `initializer` is specified but not callable. 103 TypeError: if categorical_column is not a supported type. 104 """ 105 if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2): 106 raise TypeError('categorical_column for tpu ' 107 ' embedding_column was ' 108 f'denylisted type {type(categorical_column)}') 109 if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): 110 raise TypeError( 111 'categorical_column for tpu ' 112 ' embedding_column must be type {}, got {}.'.format(' or '.join([ 113 cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS 114 ]), type(categorical_column))) 115 if (dimension is None) or (dimension < 1): 116 raise ValueError('Invalid dimension {}.'.format(dimension)) 117 118 if (initializer is not None) and (not callable(initializer)): 119 raise ValueError('initializer must be callable if specified. ' 120 'Embedding of column_name: {}'.format( 121 categorical_column.name)) 122 if initializer is None: 123 initializer = init_ops.truncated_normal_initializer( 124 mean=0.0, stddev=1 / math.sqrt(dimension)) 125 126 embedding_shape = categorical_column._num_buckets, dimension # pylint: disable=protected-access 127 128 def _creator(weight_collections, scope): 129 embedding_column_layer = fc._EmbeddingColumnLayer( 130 embedding_shape=embedding_shape, 131 initializer=initializer, 132 weight_collections=weight_collections, 133 trainable=True, 134 name='embedding_column_layer') 135 return embedding_column_layer(None, scope=scope) # pylint: disable=not-callable 136 137 column = _TPUEmbeddingColumn( 138 categorical_column=categorical_column, 139 dimension=dimension, 140 combiner=combiner, 141 layer_creator=_creator, 142 ckpt_to_load_from=None, 143 tensor_name_in_ckpt=None, 144 max_norm=None, 145 trainable=True, 146 max_sequence_length=max_sequence_length, 147 learning_rate_fn=learning_rate_fn, 148 use_safe_embedding_lookup=use_safe_embedding_lookup) 149 # For Embedding column, the initializer is hidden inside the creator Fn, which 150 # is not accessible later. So, we attach it to a special field. Also note 151 # that non-TPU Embedding column and non-TPU shared Embedding column handle the 152 # initializer differently. See shared_embedding_columns for details. 153 column._tpu_initializer = initializer 154 return column 155 156 157def shared_embedding_columns(categorical_columns, 158 dimension, 159 combiner='mean', 160 initializer=None, 161 shared_embedding_collection_name=None, 162 max_sequence_lengths=None, 163 learning_rate_fn=None, 164 use_safe_embedding_lookup=True): 165 """List of dense columns that convert from sparse, categorical input. 166 167 Note that the interface for TPU embedding_column is different from the non-TPU 168 version. The following args available for the non-TPU version are NOT 169 supported: ckpt_to_load_from, tensor_name_in_ckp, max_norm and trainable. 170 171 Args: 172 categorical_columns: A list of categorical_columns returned from 173 categorical_column_with_identity, weighted_categorical_column, 174 categorical_column_with_vocabulary_file, 175 categorical_column_with_vocabulary_list, 176 sequence_categorical_column_with_identity, 177 sequence_categorical_column_with_vocabulary_file, 178 sequence_categorical_column_with_vocabulary_list 179 dimension: An integer specifying dimension of the embedding, must be > 0. 180 combiner: A string specifying how to reduce if there are multiple entries 181 in a single row for a non-sequence column. For more information, see 182 `tf.feature_column.embedding_column`. 183 initializer: A variable initializer function to be used in embedding 184 variable initialization. If not specified, defaults to 185 `tf.truncated_normal_initializer` with mean `0.0` and standard deviation 186 `1/sqrt(dimension)`. 187 shared_embedding_collection_name: Optional name of the collection where 188 shared embedding weights are added. If not given, a reasonable name will 189 be chosen based on the names of `categorical_columns`. This is also used 190 in `variable_scope` when creating shared embedding weights. 191 max_sequence_lengths: An list of non-negative integers, either None or 192 empty or the same length as the argument categorical_columns. Entries 193 corresponding to non-sequence columns must be 0 and entries corresponding 194 to sequence columns specify the max sequence length for the column. Any 195 sequence shorter then this will be padded with 0 embeddings and any 196 sequence longer will be truncated. 197 learning_rate_fn: A function that takes global step and returns learning 198 rate for the embedding table. If you intend to use the same learning rate 199 for multiple embedding tables, please ensure that you pass the exact same 200 python function to all calls of shared_embedding_columns, otherwise 201 performence may suffer. 202 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 203 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 204 there are no empty rows and all weights and ids are positive at the 205 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 206 input tensors. Defaults to true, consider turning off if the above checks 207 are not needed. Note that having empty rows will not trigger any error 208 though the output result might be 0 or omitted. 209 210 Returns: 211 A _TPUEmbeddingColumn. 212 213 Raises: 214 ValueError: if `dimension` not > 0. 215 ValueError: if `initializer` is specified but not callable. 216 ValueError: if `max_sequence_lengths` is specified and not the same length 217 as `categorical_columns`. 218 ValueError: if `max_sequence_lengths` is positive for a non sequence column 219 or 0 for a sequence column. 220 """ 221 for categorical_column in categorical_columns: 222 if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2): 223 raise TypeError('categorical_column for tpu ' 224 ' embedding_column was denylisted type ' 225 f'{type(categorical_column)}') 226 if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS): 227 raise TypeError( 228 'categorical_column for tpu ' 229 ' shared_embedding_columns must be type {}, got {}.'.format( 230 ' or '.join( 231 [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS]), 232 type(categorical_column))) 233 234 if not max_sequence_lengths: 235 max_sequence_lengths = [0] * len(categorical_columns) 236 if len(max_sequence_lengths) != len(categorical_columns): 237 raise ValueError('max_sequence_lengths and categorical_columns must be of ' 238 'the same length. len(max_sequence_lengths)={} ' 239 'len(categorical_columns)={}.'.format( 240 len(max_sequence_lengths), len(categorical_columns))) 241 242 if (dimension is None) or (dimension < 1): 243 raise ValueError('Invalid dimension {}.'.format(dimension)) 244 245 if (initializer is not None) and (not callable(initializer)): 246 raise ValueError('initializer must be callable if specified. ') 247 if initializer is None: 248 initializer = init_ops.truncated_normal_initializer( 249 mean=0.0, stddev=1 / math.sqrt(dimension)) 250 251 # Sort the columns so the default collection name is deterministic even if the 252 # user passes columns from an unsorted collection, such as dict.values(). 253 sorted_columns = sorted(categorical_columns, key=lambda x: x.name) 254 num_buckets = sorted_columns[0]._num_buckets # pylint: disable=protected-access 255 256 for c in sorted_columns[1:]: 257 if num_buckets != c._num_buckets: # pylint: disable=protected-access 258 raise ValueError( 259 'To use shared_embedding_column, all categorical_columns must have ' 260 'the same number of buckets. Given column: {} with buckets: {} does ' 261 'not match column: {} with buckets: {}'.format( 262 sorted_columns[0], num_buckets, c, c._num_buckets)) # pylint: disable=protected-access 263 264 if not shared_embedding_collection_name: 265 shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) 266 shared_embedding_collection_name += '_shared_embedding' 267 268 tpu_columns = [] 269 270 # Create the state (_SharedEmbeddingColumnLayer) here. 271 for categorical_column, max_sequence_length in zip( 272 categorical_columns, max_sequence_lengths): 273 column = _TPUSharedEmbeddingColumn( 274 categorical_column=categorical_column, 275 dimension=dimension, 276 combiner=combiner, 277 initializer=initializer, 278 shared_embedding_collection_name=shared_embedding_collection_name, 279 ckpt_to_load_from=None, 280 tensor_name_in_ckpt=None, 281 max_norm=None, 282 trainable=True, 283 max_sequence_length=max_sequence_length, 284 learning_rate_fn=learning_rate_fn, 285 use_safe_embedding_lookup=use_safe_embedding_lookup) 286 tpu_columns.append(column) 287 288 return tpu_columns 289 290 291class _TPUBaseEmbeddingColumn(object): 292 """Base class for TPU Embedding Column.""" 293 294 def __init__(self, 295 categorical_column, 296 max_sequence_length=0, 297 learning_rate_fn=None): 298 self._tpu_categorical_column = categorical_column 299 self._max_sequence_length = max_sequence_length 300 self._learning_rate_fn = learning_rate_fn 301 if (self.is_sequence_column() and max_sequence_length < 1): 302 raise ValueError('max_sequence_length must be greater than 0 for ' 303 'sequence columns. Got max_sequence_length={} for ' 304 'sequence column {}.'.format(max_sequence_length, 305 categorical_column.name)) 306 if (not self.is_sequence_column() and max_sequence_length != 0): 307 raise ValueError('Non zero max_seq_length={} specified for non ' 308 'sequence column {}.'.format(max_sequence_length, 309 categorical_column.name)) 310 311 def get_combiner(self): 312 """Returns the embedding combiner.""" 313 raise NotImplementedError('not implemented') 314 315 def get_embedding_table_size(self): 316 """Returns the embedding table size, tuple of vocab size and dimension.""" 317 raise NotImplementedError('not implemented') 318 319 def get_feature_key_name(self): 320 """Returns the feature key name in the features dict.""" 321 raise NotImplementedError('not impl') 322 323 def get_weight_key_name(self): 324 """Return the key name for weights.""" 325 raise NotImplementedError('not impl') 326 327 def get_embedding_var_name(self): 328 """Returns the embedding variable name. 329 330 Feature key name and embedding variable name are usually one-to-one mapping. 331 But for shared embedding columns, it is many-to-one mapping. 332 """ 333 raise NotImplementedError('not impl') 334 335 def get_initializer(self): 336 """Returns the initializer.""" 337 raise NotImplementedError('not impl') 338 339 def is_categorical_column_weighted(self): 340 """Check if the categorical column of the embedding column is weighted.""" 341 raise NotImplementedError('not impl') 342 343 def is_sequence_column(self): 344 return isinstance(self._tpu_categorical_column, _SUPPORTED_SEQUENCE_COLUMNS) 345 346 def get_max_sequence_length(self): 347 return self._max_sequence_length 348 349 def get_learning_rate_fn(self): 350 return self._learning_rate_fn 351 352 def get_sequence_length_feature_key_name(self): 353 """Get the key for the associated sequence length feature.""" 354 return get_sequence_length_feature_key_name_from_feature_key_name( 355 self.get_feature_key_name()) 356 357 358class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn): 359 """Core Embedding Column.""" 360 361 def __new__(cls, 362 categorical_column, 363 dimension, 364 combiner='mean', 365 layer_creator=None, 366 ckpt_to_load_from=None, 367 tensor_name_in_ckpt=None, 368 max_norm=None, 369 trainable=True, 370 max_sequence_length=0, 371 learning_rate_fn=None, 372 use_safe_embedding_lookup=True, 373 bypass_scope_validation=False): 374 # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable 375 # are not supported on TPU. They are solely for matching the signature of 376 # __new__ of parent class fc._EmbeddingColumn. 377 del bypass_scope_validation 378 # pylint: disable=redundant-keyword-arg 379 return fc._EmbeddingColumn.__new__( 380 cls, 381 categorical_column, 382 dimension, 383 combiner=combiner, 384 layer_creator=layer_creator, 385 ckpt_to_load_from=ckpt_to_load_from, 386 tensor_name_in_ckpt=tensor_name_in_ckpt, 387 max_norm=max_norm, 388 trainable=trainable, 389 use_safe_embedding_lookup=use_safe_embedding_lookup) 390 391 def __init__(self, 392 categorical_column, 393 dimension, 394 combiner='mean', 395 layer_creator=None, 396 ckpt_to_load_from=None, 397 tensor_name_in_ckpt=None, 398 max_norm=None, 399 trainable=True, 400 max_sequence_length=0, 401 learning_rate_fn=None, 402 use_safe_embedding_lookup=True, 403 bypass_scope_validation=False): 404 _TPUBaseEmbeddingColumn.__init__( 405 self, 406 categorical_column, 407 max_sequence_length=max_sequence_length, 408 learning_rate_fn=learning_rate_fn) 409 self._key = None 410 # If true, scope validation is skipped to allow the same column to be used 411 # in multiple variable scopes. By default, this is False, and we expect a 412 # 1:1 mapping between feature columns and scopes. 413 self._bypass_scope_validation = bypass_scope_validation 414 415 def get_combiner(self): 416 return self.combiner 417 418 def get_embedding_table_size(self): 419 """Returns num_ids and width.""" 420 return (self.categorical_column._num_buckets, self.dimension) 421 422 def get_feature_key_name(self): 423 """get_feature_key_name.""" 424 if self.is_categorical_column_weighted(): 425 return self.categorical_column.categorical_column.name 426 return self.categorical_column.name 427 428 def get_weight_key_name(self): 429 """get_weight_key_name.""" 430 if self.is_categorical_column_weighted(): 431 return self.categorical_column.weight_feature_key 432 return None 433 434 def get_embedding_var_name(self): 435 """get_embedding_var_name.""" 436 return self.categorical_column.name 437 438 def get_initializer(self): 439 return self._tpu_initializer 440 441 def is_categorical_column_weighted(self): 442 """Check if the categorical column of the embedding column is weighted.""" 443 if isinstance( 444 self.categorical_column, 445 ( 446 fc._WeightedCategoricalColumn, # pylint: disable=protected-access 447 fc_lib.WeightedCategoricalColumn)): 448 return True 449 return False 450 451 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 452 if tpu.under_tpu_inference_context(): 453 def host_computation(): 454 return fc._EmbeddingColumn._get_dense_tensor( 455 self, inputs, weight_collections, trainable) 456 return tpu.outside_compilation(host_computation) 457 458 if _is_running_on_cpu(): 459 return fc._EmbeddingColumn._get_dense_tensor( 460 self, inputs, weight_collections, trainable) 461 462 # TPU mode 463 # Get the embeddings from the LazyBuilder. 464 tensor = inputs.get(self.get_feature_key_name()) 465 466 # Add to collection for _create_tpu_embedding_variables_and_ops 467 _record_variable_scope_and_name( 468 self.get_embedding_var_name(), 469 'embedding_weights', 470 bypass_scope_validation=self._bypass_scope_validation) 471 472 return tensor 473 474 def _get_sequence_dense_tensor( 475 self, inputs, weight_collections=None, trainable=None): 476 if tpu.under_tpu_inference_context(): 477 def host_computation(): 478 return fc._EmbeddingColumn._get_sequence_dense_tensor( 479 self, inputs, weight_collections, trainable) 480 return tpu.outside_compilation(host_computation) 481 482 if _is_running_on_cpu(): 483 return fc._EmbeddingColumn._get_sequence_dense_tensor( 484 self, inputs, weight_collections, trainable) 485 486 tensor = inputs.get(self.get_feature_key_name()) 487 tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name()) 488 489 # inputs is a _LazyBuilder and for rank 1 tensors, it calls expand_dims(-1). 490 # We need to undo this to match the standard CPU sequence embedding. 491 tensor_lengths = array_ops.squeeze(tensor_lengths, -1) 492 493 # Add to collection for _create_tpu_embedding_variables_and_ops 494 _record_variable_scope_and_name( 495 self.get_embedding_var_name(), 496 'embedding_weights', 497 bypass_scope_validation=self._bypass_scope_validation) 498 499 return fc._SequenceDenseColumn.TensorSequenceLengthPair( 500 dense_tensor=tensor, sequence_length=tensor_lengths) 501 502 503class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn, 504 fc._SharedEmbeddingColumn): 505 """Core Shared Embedding Column.""" 506 507 def __new__(cls, 508 categorical_column, 509 dimension, 510 combiner='mean', 511 initializer=None, 512 shared_embedding_collection_name=None, 513 ckpt_to_load_from=None, 514 tensor_name_in_ckpt=None, 515 max_norm=None, 516 trainable=True, 517 max_sequence_length=0, 518 learning_rate_fn=None, 519 use_safe_embedding_lookup=True): 520 return fc._SharedEmbeddingColumn.__new__( 521 cls, 522 categorical_column, 523 dimension, 524 combiner=combiner, 525 initializer=initializer, 526 shared_embedding_collection_name=shared_embedding_collection_name, 527 ckpt_to_load_from=ckpt_to_load_from, 528 tensor_name_in_ckpt=tensor_name_in_ckpt, 529 max_norm=max_norm, 530 trainable=trainable, 531 use_safe_embedding_lookup=use_safe_embedding_lookup) 532 533 def __init__(self, 534 categorical_column, 535 dimension, 536 combiner='mean', 537 initializer=None, 538 shared_embedding_collection_name=None, 539 ckpt_to_load_from=None, 540 tensor_name_in_ckpt=None, 541 max_norm=None, 542 trainable=True, 543 max_sequence_length=0, 544 learning_rate_fn=None, 545 use_safe_embedding_lookup=True): 546 547 _TPUBaseEmbeddingColumn.__init__( 548 self, 549 categorical_column, 550 max_sequence_length=max_sequence_length, 551 learning_rate_fn=learning_rate_fn) 552 self._key = None 553 554 def get_combiner(self): 555 return self.combiner 556 557 def get_embedding_table_size(self): 558 """Returns num_ids and width.""" 559 return (self.categorical_column._num_buckets, self.dimension) 560 561 def get_feature_key_name(self): 562 """get_feature_key_name.""" 563 if self.is_categorical_column_weighted(): 564 return self.categorical_column.categorical_column.name 565 return self.categorical_column.name 566 567 def get_weight_key_name(self): 568 """get_weight_key_name.""" 569 if self.is_categorical_column_weighted(): 570 return self.categorical_column.weight_feature_key 571 return None 572 573 def get_embedding_var_name(self): 574 """get_embedding_var_name.""" 575 return self.shared_embedding_collection_name 576 577 def get_initializer(self): 578 return self.initializer 579 580 def is_categorical_column_weighted(self): 581 """Check if the categorical column of the embedding column is weighted.""" 582 if isinstance( 583 self.categorical_column, 584 ( 585 fc._WeightedCategoricalColumn, # pylint: disable=protected-access 586 fc_lib.WeightedCategoricalColumn)): 587 return True 588 return False 589 590 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 591 if tpu.under_tpu_inference_context(): 592 def host_computation(): 593 return fc._SharedEmbeddingColumn._get_dense_tensor( 594 self, inputs, weight_collections, trainable) 595 return tpu.outside_compilation(host_computation) 596 597 if _is_running_on_cpu(): 598 return fc._SharedEmbeddingColumn._get_dense_tensor( 599 self, inputs, weight_collections, trainable) 600 601 # TPU mode 602 # Get the embeddings from the LazyBuilder. 603 tensor = inputs.get(self.get_feature_key_name()) 604 605 # Add to collection for _create_tpu_embedding_variables_and_ops 606 _record_variable_scope_and_name( 607 self.get_embedding_var_name(), 608 'embedding_weights', 609 is_shared_embedding=True) 610 return tensor 611 612 def _get_sequence_dense_tensor( 613 self, inputs, weight_collections=None, trainable=None): 614 if tpu.under_tpu_inference_context(): 615 def host_computation(): 616 return fc._SharedEmbeddingColumn._get_sequence_dense_tensor( 617 self, inputs, weight_collections, trainable) 618 return tpu.outside_compilation(host_computation) 619 620 if _is_running_on_cpu(): 621 return fc._SharedEmbeddingColumn._get_sequence_dense_tensor( 622 self, inputs, weight_collections, trainable) 623 624 tensor = inputs.get(self.get_feature_key_name()) 625 tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name()) 626 627 # Add to collection for _create_tpu_embedding_variables_and_ops 628 _record_variable_scope_and_name( 629 self.get_embedding_var_name(), 630 'embedding_weights', 631 is_shared_embedding=True) 632 633 return fc._SequenceDenseColumn.TensorSequenceLengthPair( 634 dense_tensor=tensor, sequence_length=tensor_lengths) 635 636 637def _record_variable_scope_and_name(embedding_var_name, 638 embedding_var_name_in_fc, 639 is_shared_embedding=False, 640 bypass_scope_validation=False): 641 """Add embedding variable name and scope to collection.""" 642 g = ops.get_default_graph() 643 collection = g.get_collection_ref(_TPU_FC_TO_SCOPE) 644 if not collection: 645 collection.append({}) 646 647 var_def_dict = collection[0] 648 649 captured_scope = variable_scope.get_variable_scope() 650 captured_scope_name = captured_scope.name 651 652 if embedding_var_name in var_def_dict: 653 if (var_def_dict[embedding_var_name][0] != captured_scope_name and 654 not is_shared_embedding and not bypass_scope_validation): 655 raise ValueError( 656 'For embedding var name {}, the variable scope name is different, ' 657 'got {}; expected {}'.format(embedding_var_name, 658 captured_scope_name, 659 var_def_dict[embedding_var_name][0])) 660 if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc: 661 raise ValueError( 662 'For embedding var name {}, the embedding name is different, ' 663 'got {}; expected {}'.format(embedding_var_name, 664 embedding_var_name_in_fc, 665 var_def_dict[embedding_var_name][1])) 666 else: 667 var_def_dict[embedding_var_name] = (captured_scope_name, 668 embedding_var_name_in_fc) 669 670 671def _is_running_on_cpu(): 672 """Returns True if the current context is CPU model.""" 673 return tpu_function.get_tpu_context().number_of_shards is None 674 675 676def get_sequence_length_feature_key_name_from_feature_key_name(feature_name): 677 """Gets the name of the sequence length feature from that of the base feature. 678 679 Args: 680 feature_name: The feature key of a sequence column. 681 682 Returns: 683 A string which is the feature key for the associated feature length column. 684 """ 685 return feature_name + _SEQUENCE_FEATURE_LENGTH_POSTFIX 686 687 688def split_sequence_columns(feature_columns): 689 """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns. 690 691 For use in a TPUEstimator model_fn function. E.g. 692 693 def model_fn(features): 694 sequence_columns, feature_columns = ( 695 tf.tpu.feature_column.split_sequence_columns(feature_columns)) 696 input = tf.feature_column.input_layer( 697 features=features, feature_columns=feature_columns) 698 sequence_features, sequence_lengths = ( 699 tf.contrib.feature_column.sequence_input_layer( 700 features=features, feature_columns=sequence_columns)) 701 702 Args: 703 feature_columns: A list of _TPUEmbeddingColumns to split. 704 705 Returns: 706 Two lists of _TPUEmbeddingColumns, the first is the sequence columns and the 707 second is the non-sequence columns. 708 """ 709 sequence_columns = [] 710 non_sequence_columns = [] 711 for column in feature_columns: 712 if not isinstance(column, (_TPUEmbeddingColumn, _TPUSharedEmbeddingColumn)): 713 raise TypeError( 714 'column must be a _TPUEmbeddingColumn or _TPUSharedEmbeddingColumn ' 715 f'but got {type(column)} instead.') 716 if column.is_sequence_column(): 717 sequence_columns.append(column) 718 else: 719 non_sequence_columns.append(column) 720 return sequence_columns, non_sequence_columns 721