# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
import math

from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_function
# pylint: disable=protected-access


_TPU_FC_TO_SCOPE = '_tpu_feature_column_scope'
_SUPPORTED_SEQUENCE_COLUMNS = (fc._SequenceCategoricalColumn,
                               fc_lib.SequenceCategoricalColumn)


# For V2 columns, we support anything that inherits from CategoricalColumn
# other than those in the denylist. User-provided columns that inherit from
# CategoricalColumn may or may not be compatible; it is up to the user to
# manage TPU compatibility for custom columns.
_SUPPORTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.CategoricalColumn,)
_DENYLISTED_CATEGORICAL_COLUMNS_V2 = (fc_lib.HashedCategoricalColumn,
                                      fc_lib.BucketizedColumn,
                                      fc_lib.CrossedColumn)
_SUPPORTED_CATEGORICAL_COLUMNS = (fc._IdentityCategoricalColumn,
                                  fc._VocabularyFileCategoricalColumn,
                                  fc._VocabularyListCategoricalColumn,
                                  fc._WeightedCategoricalColumn,
                                  fc._SequenceCategoricalColumn
                                 ) + _SUPPORTED_CATEGORICAL_COLUMNS_V2
_SEQUENCE_FEATURE_LENGTH_POSTFIX = '_seq_length_'


def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     max_sequence_length=0,
                     learning_rate_fn=None,
                     use_safe_embedding_lookup=True):
  """TPU embedding_column for `tf.feature_column.embedding_column`.

  Note that the interface for TPU embedding_column is different from the
  non-TPU version. The following args available for the non-TPU version are
  NOT supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and
  trainable.

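  Example: a minimal sketch, assuming this module is imported as `tpu_fc` and
  using an illustrative feature name and bucket count:

    watched_video_id = tf.feature_column.categorical_column_with_identity(
        key='watched_video_id', num_buckets=1000000)
    watched_video_embedding = tpu_fc.embedding_column(
        watched_video_id, dimension=128)
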
  Args:
    categorical_column: A categorical_column returned from
        categorical_column_with_identity, weighted_categorical_column,
        categorical_column_with_vocabulary_file,
        categorical_column_with_vocabulary_list,
        sequence_categorical_column_with_identity,
        sequence_categorical_column_with_vocabulary_file,
        sequence_categorical_column_with_vocabulary_list
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    max_sequence_length: A non-negative integer specifying the max sequence
      length. Any sequence shorter than this will be padded with 0 embeddings
      and any sequence longer will be truncated. This must be positive for
      sequence features and 0 for non-sequence features.
    learning_rate_fn: A function that takes the global step and returns the
      learning rate for the embedding table. If you intend to use the same
      learning rate for multiple embedding tables, please ensure that you pass
      the exact same python function to all calls of embedding_column,
      otherwise performance may suffer.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A `_TPUEmbeddingColumn`.

  Raises:
    ValueError: if `dimension` is not > 0.
    ValueError: if `initializer` is specified but not callable.
    TypeError: if `categorical_column` is not a supported type.
  """
  if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2):
    raise TypeError('categorical_column for tpu embedding_column was '
                    f'denylisted type {type(categorical_column)}')
  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
    raise TypeError(
        'categorical_column for tpu embedding_column must be type {}, '
        'got {}.'.format(
            ' or '.join(
                [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS]),
            type(categorical_column)))
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  embedding_shape = categorical_column._num_buckets, dimension  # pylint: disable=protected-access

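  # The creator builds the embedding weights through the core feature-column
  # layer so that the CPU (eval/serving) path matches the non-TPU
  # embedding_column; in TPU mode the lookup itself is handled by the TPU
  # embedding infrastructure instead.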
  def _creator(weight_collections, scope):
    embedding_column_layer = fc._EmbeddingColumnLayer(
        embedding_shape=embedding_shape,
        initializer=initializer,
        weight_collections=weight_collections,
        trainable=True,
        name='embedding_column_layer')
    return embedding_column_layer(None, scope=scope)  # pylint: disable=not-callable

  column = _TPUEmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      layer_creator=_creator,
      ckpt_to_load_from=None,
      tensor_name_in_ckpt=None,
      max_norm=None,
      trainable=True,
      max_sequence_length=max_sequence_length,
      learning_rate_fn=learning_rate_fn,
      use_safe_embedding_lookup=use_safe_embedding_lookup)
  # For the embedding column, the initializer is hidden inside the creator fn,
  # which is not accessible later, so we attach it to a special field. Also
  # note that non-TPU embedding columns and non-TPU shared embedding columns
  # handle the initializer differently. See shared_embedding_columns for
  # details.
  column._tpu_initializer = initializer
  return column


def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             max_sequence_lengths=None,
                             learning_rate_fn=None,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  Note that the interface for TPU shared_embedding_columns is different from
  the non-TPU version. The following args available for the non-TPU version
  are NOT supported: ckpt_to_load_from, tensor_name_in_ckpt, max_norm and
  trainable.

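  Example: a minimal sketch, assuming this module is imported as `tpu_fc` and
  using illustrative feature names and sizes. The columns share one embedding
  table, so they must have the same number of buckets:

    watched_video_id = tf.feature_column.categorical_column_with_identity(
        key='watched_video_id', num_buckets=1000000)
    impressed_video_id = tf.feature_column.categorical_column_with_identity(
        key='impressed_video_id', num_buckets=1000000)
    columns = tpu_fc.shared_embedding_columns(
        [watched_video_id, impressed_video_id], dimension=128)
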
  Args:
    categorical_columns: A list of categorical_columns returned from
        categorical_column_with_identity, weighted_categorical_column,
        categorical_column_with_vocabulary_file,
        categorical_column_with_vocabulary_list,
        sequence_categorical_column_with_identity,
        sequence_categorical_column_with_vocabulary_file,
        sequence_categorical_column_with_vocabulary_list
    dimension: An integer specifying the dimension of the embedding; must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
      `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    max_sequence_lengths: A list of non-negative integers, either None or
      empty or the same length as the argument categorical_columns. Entries
      corresponding to non-sequence columns must be 0 and entries corresponding
      to sequence columns specify the max sequence length for the column. Any
      sequence shorter than this will be padded with 0 embeddings and any
      sequence longer will be truncated.
    learning_rate_fn: A function that takes the global step and returns the
      learning rate for the embedding table. If you intend to use the same
      learning rate for multiple embedding tables, please ensure that you pass
      the exact same python function to all calls of shared_embedding_columns,
      otherwise performance may suffer.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A list of `_TPUSharedEmbeddingColumn`.

  Raises:
    ValueError: if `dimension` is not > 0.
    ValueError: if `initializer` is specified but not callable.
    ValueError: if `max_sequence_lengths` is specified and not the same length
      as `categorical_columns`.
    ValueError: if `max_sequence_lengths` is positive for a non-sequence column
      or 0 for a sequence column.
  """
  for categorical_column in categorical_columns:
    if isinstance(categorical_column, _DENYLISTED_CATEGORICAL_COLUMNS_V2):
      raise TypeError('categorical_column for tpu shared_embedding_columns '
                      f'was denylisted type {type(categorical_column)}')
    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS):
      raise TypeError(
          'categorical_column for tpu shared_embedding_columns must be type '
          '{}, got {}.'.format(
              ' or '.join(
                  [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS]),
              type(categorical_column)))

  if not max_sequence_lengths:
    max_sequence_lengths = [0] * len(categorical_columns)
  if len(max_sequence_lengths) != len(categorical_columns):
    raise ValueError('max_sequence_lengths and categorical_columns must be of '
                     'the same length. len(max_sequence_lengths)={} '
                     'len(categorical_columns)={}.'.format(
                         len(max_sequence_lengths), len(categorical_columns)))

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
  num_buckets = sorted_columns[0]._num_buckets  # pylint: disable=protected-access

  for c in sorted_columns[1:]:
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              sorted_columns[0], num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  tpu_columns = []

  # Create the state (_SharedEmbeddingColumnLayer) here.
  for categorical_column, max_sequence_length in zip(
      categorical_columns, max_sequence_lengths):
    column = _TPUSharedEmbeddingColumn(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=None,
        tensor_name_in_ckpt=None,
        max_norm=None,
        trainable=True,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
    tpu_columns.append(column)

  return tpu_columns


class _TPUBaseEmbeddingColumn(object):
  """Base class for TPU Embedding Column."""

  def __init__(self,
               categorical_column,
               max_sequence_length=0,
               learning_rate_fn=None):
    self._tpu_categorical_column = categorical_column
    self._max_sequence_length = max_sequence_length
    self._learning_rate_fn = learning_rate_fn
    if (self.is_sequence_column() and max_sequence_length < 1):
      raise ValueError('max_sequence_length must be greater than 0 for '
                       'sequence columns. Got max_sequence_length={} for '
                       'sequence column {}.'.format(max_sequence_length,
                                                    categorical_column.name))
    if (not self.is_sequence_column() and max_sequence_length != 0):
      raise ValueError('Non-zero max_sequence_length={} specified for non-'
                       'sequence column {}.'.format(max_sequence_length,
                                                    categorical_column.name))

  def get_combiner(self):
    """Returns the embedding combiner."""
    raise NotImplementedError('not implemented')

  def get_embedding_table_size(self):
    """Returns the embedding table size, tuple of vocab size and dimension."""
    raise NotImplementedError('not implemented')

  def get_feature_key_name(self):
    """Returns the feature key name in the features dict."""
    raise NotImplementedError('not implemented')

  def get_weight_key_name(self):
    """Returns the key name for weights."""
    raise NotImplementedError('not implemented')

  def get_embedding_var_name(self):
    """Returns the embedding variable name.

    The feature key name and the embedding variable name usually map
    one-to-one, but for shared embedding columns the mapping is many-to-one.
    """
    raise NotImplementedError('not implemented')

  def get_initializer(self):
    """Returns the initializer."""
    raise NotImplementedError('not implemented')

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    raise NotImplementedError('not implemented')

  def is_sequence_column(self):
    return isinstance(self._tpu_categorical_column, _SUPPORTED_SEQUENCE_COLUMNS)

  def get_max_sequence_length(self):
    return self._max_sequence_length

  def get_learning_rate_fn(self):
    return self._learning_rate_fn

  def get_sequence_length_feature_key_name(self):
    """Get the key for the associated sequence length feature."""
    return get_sequence_length_feature_key_name_from_feature_key_name(
        self.get_feature_key_name())


class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
  """Core Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              layer_creator=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True,
              bypass_scope_validation=False):
    # Note: args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
    # are not supported on TPU. They are solely for matching the signature of
    # __new__ of the parent class fc._EmbeddingColumn.
    del bypass_scope_validation
    # pylint: disable=redundant-keyword-arg
    return fc._EmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        layer_creator=layer_creator,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               layer_creator=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True,
               bypass_scope_validation=False):
    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None
    # If true, scope validation is skipped to allow the same column to be used
    # in multiple variable scopes. By default, this is False, and we expect a
    # 1:1 mapping between feature columns and scopes.
    self._bypass_scope_validation = bypass_scope_validation

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """Returns the feature key name in the features dict."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """Returns the key name for weights, or None if unweighted."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """Returns the embedding variable name."""
    return self.categorical_column.name

  def get_initializer(self):
    return self._tpu_initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
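    # In TPU mode, the value keyed by the feature name at this point is the
    # embedding activation already computed by the TPU embedding engine, not
    # the raw categorical input, so it can be returned directly.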
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._EmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._EmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # inputs is a _LazyBuilder and for rank 1 tensors, it calls expand_dims(-1).
    # We need to undo this to match the standard CPU sequence embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return fc._SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
                                fc._SharedEmbeddingColumn):
  """Core Shared Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              initializer=None,
              shared_embedding_collection_name=None,
              ckpt_to_load_from=None,
              tensor_name_in_ckpt=None,
              max_norm=None,
              trainable=True,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True):
    return fc._SharedEmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        initializer=initializer,
        shared_embedding_collection_name=shared_embedding_collection_name,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               initializer=None,
               shared_embedding_collection_name=None,
               ckpt_to_load_from=None,
               tensor_name_in_ckpt=None,
               max_norm=None,
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True):

    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """Returns the feature key name in the features dict."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """Returns the key name for weights, or None if unweighted."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """Returns the embedding variable name."""
    return self.shared_embedding_collection_name

  def get_initializer(self):
    return self.initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._SharedEmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._SharedEmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        is_shared_embedding=True)
    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        is_shared_embedding=True)

    return fc._SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


def _record_variable_scope_and_name(embedding_var_name,
                                    embedding_var_name_in_fc,
                                    is_shared_embedding=False,
                                    bypass_scope_validation=False):
  """Add embedding variable name and scope to collection."""
  g = ops.get_default_graph()
  collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
  if not collection:
    collection.append({})

  var_def_dict = collection[0]

  captured_scope = variable_scope.get_variable_scope()
  captured_scope_name = captured_scope.name

  if embedding_var_name in var_def_dict:
    if (var_def_dict[embedding_var_name][0] != captured_scope_name and
        not is_shared_embedding and not bypass_scope_validation):
      raise ValueError(
          'For embedding var name {}, the variable scope name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       captured_scope_name,
                                       var_def_dict[embedding_var_name][0]))
    if var_def_dict[embedding_var_name][1] != embedding_var_name_in_fc:
      raise ValueError(
          'For embedding var name {}, the embedding name is different, '
          'got {}; expected {}'.format(embedding_var_name,
                                       embedding_var_name_in_fc,
                                       var_def_dict[embedding_var_name][1]))
  else:
    var_def_dict[embedding_var_name] = (captured_scope_name,
                                        embedding_var_name_in_fc)


def _is_running_on_cpu():
  """Returns True if the current context is CPU mode."""
  return tpu_function.get_tpu_context().number_of_shards is None


def get_sequence_length_feature_key_name_from_feature_key_name(feature_name):
  """Gets the name of the sequence length feature from that of the base feature.

  Args:
    feature_name: The feature key of a sequence column.

  Returns:
    A string which is the feature key for the associated sequence length
    column.
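
  For example, with the default postfix, feature key 'video_id' maps to
  'video_id_seq_length_'.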
684  """
685  return feature_name + _SEQUENCE_FEATURE_LENGTH_POSTFIX
686
687
688def split_sequence_columns(feature_columns):
689  """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.
690
691  For use in a TPUEstimator model_fn function. E.g.
692
693  def model_fn(features):
694    sequence_columns, feature_columns = (
695        tf.tpu.feature_column.split_sequence_columns(feature_columns))
696    input = tf.feature_column.input_layer(
697        features=features, feature_columns=feature_columns)
698    sequence_features, sequence_lengths = (
699        tf.contrib.feature_column.sequence_input_layer(
700            features=features, feature_columns=sequence_columns))
701
702  Args:
703    feature_columns: A list of _TPUEmbeddingColumns to split.
704
705  Returns:
706    Two lists of _TPUEmbeddingColumns, the first is the sequence columns and the
707    second is the non-sequence columns.
708  """
709  sequence_columns = []
710  non_sequence_columns = []
711  for column in feature_columns:
712    if not isinstance(column, (_TPUEmbeddingColumn, _TPUSharedEmbeddingColumn)):
713      raise TypeError(
714          'column must be a _TPUEmbeddingColumn or  _TPUSharedEmbeddingColumn '
715          f'but got {type(column)} instead.')
716    if column.is_sequence_column():
717      sequence_columns.append(column)
718    else:
719      non_sequence_columns.append(column)
720  return sequence_columns, non_sequence_columns
721