# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU Feature Column Library."""
import copy
import math

import enum

from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.feature_column import feature_column_lib as fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.feature_column import _is_running_on_cpu
from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name
from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2
from tensorflow.python.tpu.feature_column import _SUPPORTED_SEQUENCE_COLUMNS
from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access

_ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core']
_TENSOR_CORE_MASK_KEY_SUFFIX = '__TENSOR_CORE_MASK'


class EmbeddingDevice(enum.Enum):
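  """Device options for the embedding lookup; mirrors _ALLOWED_DEVICES."""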
  CPU = 1
  TPU_TENSOR_CORE = 2
  TPU_EMBEDDING_CORE = 3


@tf_export(v1=['tpu.experimental.embedding_column'])
def embedding_column_v2(categorical_column,
                        dimension,
                        combiner='mean',
                        initializer=None,
                        max_sequence_length=0,
                        learning_rate_fn=None,
                        embedding_lookup_device=None,
                        tensor_core_shape=None,
                        use_safe_embedding_lookup=True):
  """TPU version of `tf.compat.v1.feature_column.embedding_column`.

  Note that the interface for `tf.tpu.experimental.embedding_column` is
  different from that of `tf.compat.v1.feature_column.embedding_column`: The
  following arguments are NOT supported: `ckpt_to_load_from`,
  `tensor_name_in_ckpt`, `max_norm` and `trainable`.

  Use this function in place of `tf.compat.v1.feature_column.embedding_column`
  when you want to use the TPU to accelerate your embedding lookups via TPU
  embeddings.

  ```
  column = tf.feature_column.categorical_column_with_identity(...)
  tpu_column = tf.tpu.experimental.embedding_column(column, 10)
  ...
  def model_fn(features):
    dense_feature = tf.keras.layers.DenseFeatures(tpu_column)
    embedded_feature = dense_feature(features)
    ...

  estimator = tf.estimator.tpu.TPUEstimator(
      model_fn=model_fn,
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
        feature_columns=[tpu_column],
        ...))
  ```
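
  The embedding lookup can also be placed explicitly. For example, to run the
  lookup for this column on TensorCore instead of the embedding core (a
  sketch; the padded input shape `[None, 128]` is illustrative):

  ```
  tpu_column = tf.tpu.experimental.embedding_column(
      column,
      10,
      embedding_lookup_device='tpu_tensor_core',
      tensor_core_shape=[None, 128])
  ```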

  Args:
    categorical_column: A categorical column returned from
        `categorical_column_with_identity`, `weighted_categorical_column`,
        `categorical_column_with_vocabulary_file`,
        `categorical_column_with_vocabulary_list`,
        `sequence_categorical_column_with_identity`,
        `sequence_categorical_column_with_vocabulary_file`,
        `sequence_categorical_column_with_vocabulary_list`
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    max_sequence_length: A non-negative integer specifying the max sequence
      length. Any sequence shorter than this will be padded with 0 embeddings
      and any sequence longer will be truncated. This must be positive for
      sequence features and 0 for non-sequence features.
    learning_rate_fn: A function that takes global step and returns learning
      rate for the embedding table. If you intend to use the same learning rate
      for multiple embedding tables, please ensure that you pass the exact same
      python function to all calls of embedding_column, otherwise performance
      may suffer.
    embedding_lookup_device: The device on which to run the embedding lookup.
      Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core".
      If specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
      If not specified, the default behavior is embedding lookup on
      "tpu_embedding_core" for training and "cpu" for inference.
      Valid options for training: ["tpu_embedding_core", "tpu_tensor_core"]
      Valid options for serving: ["cpu", "tpu_tensor_core"]
      For training, tpu_embedding_core is good for large embedding vocab (>1M);
      otherwise, tpu_tensor_core is often sufficient.
      For serving, doing embedding lookup on tpu_tensor_core during serving is
      a way to reduce host cpu usage in cases where that is a bottleneck.
    tensor_core_shape: If supplied, a list of integers which specifies
      the intended dense shape to run embedding lookup for this feature on
      TensorCore. The batch dimension can be left None or -1 to indicate
      a dynamic shape. Only rank 2 shapes are currently supported.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A `_TPUEmbeddingColumnV2`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
  """

  if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
    raise TypeError(
        'categorical_column for tpu '
        'embedding_column must be type {}, got {}.'.format(' or '.join([
            cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2
        ]), type(categorical_column)))
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if tensor_core_shape and len(tensor_core_shape) != 2:
    raise ValueError(
        'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  if (embedding_lookup_device and
      embedding_lookup_device not in _ALLOWED_DEVICES):
    raise ValueError(
        f'If set, embedding_lookup_device must be in {_ALLOWED_DEVICES}')

  if embedding_lookup_device == 'cpu':
    embedding_lookup_device = EmbeddingDevice.CPU
  elif embedding_lookup_device == 'tpu_tensor_core':
    embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
  elif embedding_lookup_device == 'tpu_embedding_core':
    embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

  if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
    if not tensor_core_shape:
      raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                       'tensor_core_shape to be set.')
    if isinstance(categorical_column, _SUPPORTED_SEQUENCE_COLUMNS):
      raise ValueError('embedding_lookup_device=tpu_tensor_core currently does '
                       'not support sequence columns.')

  if not embedding_lookup_device:
    return _TPUEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        use_safe_embedding_lookup=use_safe_embedding_lookup)
  else:
    return _TPUDeviceSpecificEmbeddingColumnV2(
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn,
        embedding_lookup_device=embedding_lookup_device,
        tensor_core_shape=tensor_core_shape,
        use_safe_embedding_lookup=use_safe_embedding_lookup)


@tf_export(v1=['tpu.experimental.shared_embedding_columns'])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                max_sequence_lengths=None,
                                learning_rate_fn=None,
                                embedding_lookup_device=None,
                                tensor_core_shape=None,
                                use_safe_embedding_lookup=True):
217  """TPU version of `tf.compat.v1.feature_column.shared_embedding_columns`.
218
219  Note that the interface for `tf.tpu.experimental.shared_embedding_columns` is
220  different from that of `tf.compat.v1.feature_column.shared_embedding_columns`:
221  The following arguments are NOT supported: `ckpt_to_load_from`,
222  `tensor_name_in_ckpt`, `max_norm` and `trainable`.
223
224  Use this function in place of
225  tf.compat.v1.feature_column.shared_embedding_columns` when you want to use the
226  TPU to accelerate your embedding lookups via TPU embeddings.
227
228  ```
229  column_a = tf.feature_column.categorical_column_with_identity(...)
230  column_b = tf.feature_column.categorical_column_with_identity(...)
231  tpu_columns = tf.tpu.experimental.shared_embedding_columns(
232      [column_a, column_b], 10)
233  ...
234  def model_fn(features):
235    dense_feature = tf.keras.layers.DenseFeature(tpu_columns)
236    embedded_feature = dense_feature(features)
237    ...
238
239  estimator = tf.estimator.tpu.TPUEstimator(
240      model_fn=model_fn,
241      ...
242      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
243          column=tpu_columns,
244          ...))
245  ```
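
  For example, sharing one table between a non-sequence and a sequence column
  (a sketch; the feature names and sizes are illustrative):

  ```
  user_col = tf.feature_column.categorical_column_with_identity(
      'user_id', num_buckets=1000)
  hist_col = tf.feature_column.sequence_categorical_column_with_identity(
      'watch_history', num_buckets=1000)
  tpu_columns = tf.tpu.experimental.shared_embedding_columns(
      [user_col, hist_col], dimension=10, max_sequence_lengths=[0, 10])
  ```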

  Args:
    categorical_columns: A list of categorical columns returned from
      `categorical_column_with_identity`, `weighted_categorical_column`,
      `categorical_column_with_vocabulary_file`,
      `categorical_column_with_vocabulary_list`,
      `sequence_categorical_column_with_identity`,
      `sequence_categorical_column_with_vocabulary_file`,
      `sequence_categorical_column_with_vocabulary_list`
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row for a non-sequence column. For more information, see
      `tf.feature_column.embedding_column`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard
      deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    max_sequence_lengths: A list of non-negative integers, either None or empty
      or with the same length as `categorical_columns`. Entries
      corresponding to non-sequence columns must be 0 and entries corresponding
      to sequence columns specify the max sequence length for the column. Any
      sequence shorter than this will be padded with 0 embeddings and any
      sequence longer will be truncated.
    learning_rate_fn: A function that takes global step and returns learning
      rate for the embedding table. If you intend to use the same learning rate
      for multiple embedding tables, please ensure that you pass the exact same
      python function to all calls of shared_embedding_columns, otherwise
      performance may suffer.
    embedding_lookup_device: The device on which to run the embedding lookup.
      Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If
      specifying "tpu_tensor_core", a tensor_core_shape must be supplied.
      If not specified, the default behavior is embedding
      lookup on "tpu_embedding_core" for training and "cpu" for inference.
      Valid options for training: ["tpu_embedding_core", "tpu_tensor_core"]
      Valid options for serving: ["cpu", "tpu_tensor_core"]
      For training, tpu_embedding_core is good for large embedding vocab (>1M);
      otherwise, tpu_tensor_core is often sufficient.
      For serving, doing embedding lookup on tpu_tensor_core during serving is
      a way to reduce host cpu usage in cases where that is a bottleneck.
    tensor_core_shape: If supplied, a list of integers which specifies the
      intended dense shape to run embedding lookup for this feature on
      TensorCore. The batch dimension can be left None or -1 to indicate a
      dynamic shape. Only rank 2 shapes are currently supported.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error, though the output result might be 0 or omitted.

  Returns:
    A list of `_TPUSharedEmbeddingColumnV2`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if `initializer` is specified but not callable.
    ValueError: if `max_sequence_lengths` is specified and not the same length
      as `categorical_columns`.
    ValueError: if `max_sequence_lengths` is positive for a non-sequence column
      or 0 for a sequence column.
  """

  for categorical_column in categorical_columns:
    if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2):
      raise TypeError(
          'categorical_column for tpu '
          'shared_embedding_columns must be type {}, got {}.'.format(
              ' or '.join(
                  [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2]),
              type(categorical_column)))

  if not max_sequence_lengths:
    max_sequence_lengths = [0] * len(categorical_columns)
  if len(max_sequence_lengths) != len(categorical_columns):
    raise ValueError('max_sequence_lengths and categorical_columns must be of '
                     'the same length. len(max_sequence_lengths)={} '
                     'len(categorical_columns)={}.'.format(
                         len(max_sequence_lengths), len(categorical_columns)))

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if tensor_core_shape and len(tensor_core_shape) != 2:
    raise ValueError(
        'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape))

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if
  # the user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
  num_buckets = sorted_columns[0]._num_buckets  # pylint: disable=protected-access

  for c in sorted_columns[1:]:
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              sorted_columns[0], num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  tpu_columns = []

  column_creator = fc_lib.SharedEmbeddingColumnCreator(
      dimension=dimension, initializer=initializer, ckpt_to_load_from=None,
      tensor_name_in_ckpt=None, num_buckets=num_buckets, trainable=None,
      name=shared_embedding_collection_name)

  if (embedding_lookup_device and
      embedding_lookup_device not in _ALLOWED_DEVICES):
    raise ValueError(
        f'If set, embedding_lookup_device must be in {_ALLOWED_DEVICES}')

  if embedding_lookup_device == 'cpu':
    embedding_lookup_device = EmbeddingDevice.CPU
  elif embedding_lookup_device == 'tpu_tensor_core':
    embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE
  elif embedding_lookup_device == 'tpu_embedding_core':
    embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE

  if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE:
    if not tensor_core_shape:
      raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires '
                       'tensor_core_shape to be set.')
    for c in sorted_columns:
      if isinstance(c, _SUPPORTED_SEQUENCE_COLUMNS):
        raise ValueError('embedding_lookup_device=tpu_tensor_core currently '
                         'does not support sequence columns.')

  # Create the state (_SharedEmbeddingColumnLayer) here.
  for categorical_column, max_sequence_length in zip(
      categorical_columns, max_sequence_lengths):
    if not embedding_lookup_device:
      column = _TPUSharedEmbeddingColumnV2(
          categorical_column=categorical_column,
          shared_embedding_column_creator=column_creator,
          combiner=combiner,
          initializer=initializer,
          shared_embedding_collection_name=shared_embedding_collection_name,
          max_sequence_length=max_sequence_length,
          learning_rate_fn=learning_rate_fn,
          use_safe_embedding_lookup=use_safe_embedding_lookup)
    else:
      column = _TPUSharedDeviceSpecificEmbeddingColumnV2(
          categorical_column=categorical_column,
          shared_embedding_column_creator=column_creator,
          combiner=combiner,
          initializer=initializer,
          shared_embedding_collection_name=shared_embedding_collection_name,
          max_sequence_length=max_sequence_length,
          learning_rate_fn=learning_rate_fn,
          embedding_lookup_device=embedding_lookup_device,
          tensor_core_shape=tensor_core_shape,
          use_safe_embedding_lookup=use_safe_embedding_lookup)
    tpu_columns.append(column)

  return tpu_columns


class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
  """Core Embedding Column."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner='mean',
              initializer=None,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True,
              bypass_scope_validation=False):
    del bypass_scope_validation
    # pylint: disable=redundant-keyword-arg
    return fc_lib.EmbeddingColumn.__new__(
        cls,
        categorical_column,
        dimension,
        combiner=combiner,
        initializer=initializer,
        ckpt_to_load_from=None,
        tensor_name_in_ckpt=None,
        max_norm=None,
        trainable=True,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __getnewargs__(self):
    return (self._tpu_categorical_column, self.dimension, self.combiner,
            self.initializer, self._max_sequence_length, self._learning_rate_fn,
            self.use_safe_embedding_lookup, self._bypass_scope_validation)

  def __deepcopy__(self, memo):
    return _TPUEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()))

  def __init__(self,
               categorical_column,
               dimension,
               combiner='mean',
               initializer=None,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True,
               bypass_scope_validation=False):
    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._key = None
    # If true, scope validation is skipped to allow the same column to be used
    # in multiple variable scopes. By default, this is False, and we expect a
    # 1:1 mapping between feature columns and scopes.
    self._bypass_scope_validation = bypass_scope_validation

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets, self.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self.categorical_column.name

  def get_initializer(self):
    return self.initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn._get_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn._get_dense_tensor(
          self, inputs, weight_collections, trainable)

    # TPU mode
    # Get the embeddings from the LazyBuilder.
    tensor = inputs.get(self.get_feature_key_name())

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return tensor

  def create_state(self, state_manager):
    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.create_state(
          self, state_manager)

    # create_state is called on the EmbeddingColumn to create its embedding
    # variables under feature column V2. We are running on TPU here, so
    # instead record the variable scope for
    # _create_tpu_embedding_variables_and_ops.
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

  def get_dense_tensor(self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn.get_dense_tensor(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.get_dense_tensor(
          self, transformation_cache, state_manager)

    # TPU mode
    # Get the embeddings from the FeatureTransformationCache.
    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)

    return tensor

  def _get_sequence_dense_tensor(
      self, inputs, weight_collections=None, trainable=None):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn._get_sequence_dense_tensor(
            self, inputs, weight_collections, trainable)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn._get_sequence_dense_tensor(
          self, inputs, weight_collections, trainable)

    tensor = inputs.get(self.get_feature_key_name())
    tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name())

    # inputs is a _LazyBuilder and for rank 1 tensors, it calls expand_dims(-1).
    # We need to undo this to match the standard CPU sequence embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        'embedding_weights',
        bypass_scope_validation=self._bypass_scope_validation)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.EmbeddingColumn.get_sequence_dense_tensor(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.EmbeddingColumn.get_sequence_dense_tensor(
          self, transformation_cache, state_manager)

    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)
    tensor_lengths = transformation_cache.get(
        self.get_sequence_length_feature_key_name(),
        state_manager)

    # FeatureTransformationCache expands rank 1 tensors (like sequence length)
    # to rank 2. We need to undo this to match the standard CPU sequence
    # embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


class _TPUSharedEmbeddingColumnV2(_TPUBaseEmbeddingColumn,
                                  fc_lib.SharedEmbeddingColumn):
  """Core Shared Embedding Column."""

  def __new__(cls,
              categorical_column,
              shared_embedding_column_creator,
              combiner='mean',
              initializer=None,
              shared_embedding_collection_name=None,
              max_sequence_length=0,
              learning_rate_fn=None,
              use_safe_embedding_lookup=True):
    # pylint: disable=redundant-keyword-arg
    return fc_lib.SharedEmbeddingColumn.__new__(
        cls,
        categorical_column,
        combiner=combiner,
        shared_embedding_column_creator=shared_embedding_column_creator,
        max_norm=None,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  def __getnewargs__(self):
    return (self._tpu_categorical_column, self.shared_embedding_column_creator,
            self.combiner, self._initializer,
            self._shared_embedding_collection_name, self._max_sequence_length,
            self._learning_rate_fn, self.use_safe_embedding_lookup)

  def __deepcopy__(self, memo):
    return _TPUSharedEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()))

  def __init__(self,
               categorical_column,
               shared_embedding_column_creator,
               combiner='mean',
               initializer=None,
               shared_embedding_collection_name=None,
               max_sequence_length=0,
               learning_rate_fn=None,
               use_safe_embedding_lookup=True):

    _TPUBaseEmbeddingColumn.__init__(
        self,
        categorical_column,
        max_sequence_length=max_sequence_length,
        learning_rate_fn=learning_rate_fn)
    self._initializer = initializer
    self._shared_embedding_collection_name = shared_embedding_collection_name

  def get_combiner(self):
    return self.combiner

  def get_embedding_table_size(self):
    """Returns num_ids and width."""
    return (self.categorical_column._num_buckets,
            self.shared_embedding_column_creator.dimension)

  def get_feature_key_name(self):
    """get_feature_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.categorical_column.name
    return self.categorical_column.name

  def get_weight_key_name(self):
    """get_weight_key_name."""
    if self.is_categorical_column_weighted():
      return self.categorical_column.weight_feature_key
    return None

  def get_embedding_var_name(self):
    """get_embedding_var_name."""
    return self._shared_embedding_collection_name

  def get_initializer(self):
    return self._initializer

  def is_categorical_column_weighted(self):
    """Check if the categorical column of the embedding column is weighted."""
    if isinstance(
        self.categorical_column,
        (
            fc._WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_lib.WeightedCategoricalColumn)):
      return True
    return False

  def _get_dense_tensor_internal(
      self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal(
          self, transformation_cache, state_manager)

    # TPU mode
    # Get the embeddings from the FeatureTransformationCache.
    tensor = transformation_cache.get(self.get_feature_key_name(),
                                      state_manager)

    # Add to collection for _create_tpu_embedding_variables_and_ops
    # Note that in Feature Column V2, shared embeddings have no scope.
    _record_variable_scope_and_name(
        self.get_embedding_var_name(),
        self.shared_embedding_column_creator._name,
        is_shared_embedding=True)
    return tensor

  def get_sequence_dense_tensor(
      self, transformation_cache, state_manager):
    if tpu.under_tpu_inference_context():
      def host_computation():
        return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor(
            self, transformation_cache, state_manager)
      return tpu.outside_compilation(host_computation)

    if _is_running_on_cpu():
      return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor(
          self, transformation_cache, state_manager)

    tensor = self._get_dense_tensor_internal(
        transformation_cache, state_manager)
    tensor_lengths = transformation_cache.get(
        self.get_sequence_length_feature_key_name(),
        state_manager)

    # FeatureTransformationCache expands rank 1 tensors (like sequence length)
    # to rank 2. We need to undo this to match the standard CPU sequence
    # embedding.
    tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

    return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=tensor, sequence_length=tensor_lengths)


def split_sequence_columns_v2(feature_columns):
  """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns.

  For use in a TPUEstimator model_fn function. E.g.

  ```
  def model_fn(features):
    sequence_columns, feature_columns = (
        tf.tpu.feature_column.split_sequence_columns(feature_columns))
    input = tf.feature_column.input_layer(
        features=features, feature_columns=feature_columns)
    sequence_features, sequence_lengths = (
        tf.contrib.feature_column.sequence_input_layer(
            features=features, feature_columns=sequence_columns))
  ```

  Args:
    feature_columns: A list of _TPUEmbeddingColumns to split.

  Returns:
    Two lists of _TPUEmbeddingColumns, the first is the sequence columns and the
    second is the non-sequence columns.
  """
  sequence_columns = []
  non_sequence_columns = []
  for column in feature_columns:
    if not isinstance(column, (_TPUEmbeddingColumnV2,
                               _TPUSharedEmbeddingColumnV2)):
      raise TypeError(
          'column must be a _TPUEmbeddingColumnV2 or '
          f'_TPUSharedEmbeddingColumnV2 but got {type(column)} instead.')
    if column.is_sequence_column():
      sequence_columns.append(column)
    else:
      non_sequence_columns.append(column)
  return sequence_columns, non_sequence_columns


def sparse_embedding_aggregate_slice(params,
                                     values_and_values_mask,
                                     combiner='mean',
                                     name='sparse_embedding_aggregate_slice'):
  """Uses XLA's dynamic slice operations to perform embedding lookups.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  Args:
    params: Tensor of embedding table. Rank 2 (table_size x embedding dim)
    values_and_values_mask: A two-tuple containing:
      values - Tensor of embedding indices. Rank 2 (batch x n_indices)
      values_mask - Tensor of mask / weights. Rank 2 (batch x n_indices)
    combiner: The combiner to use for the embedding lookup. Currently supports
      'sum' and 'mean'.
    name: Optional name scope for created ops

  Returns:
    Rank 2 tensor of aggregated (per batch element) embedding vectors.

  Raises:
    ValueError: Combiner is not supported.
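
  For example, with `combiner='mean'` (a minimal sketch; the table and index
  values are illustrative):

  ```
  params = tf.constant([[0., 0.], [2., 4.], [4., 8.], [6., 6.]])
  values = tf.constant([[1, 2, 0]])    # batch of 1, padded to 3 indices
  mask = tf.constant([[1., 1., 0.]])   # the third slot is padding
  out = sparse_embedding_aggregate_slice(params, (values, mask))
  # out == [[3., 6.]], i.e. (params[1] + params[2]) / 2; the masked-out
  # padding index contributes nothing.
  ```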
813  """
814  values, values_mask = values_and_values_mask  # unpack the two-tuple
815  with ops.name_scope(name):
816    _, embedding_dimension = params.get_shape().as_list()
817    n_batch, n_indices_padded = values.get_shape().as_list()
818    if not n_batch:
819      n_batch = -1
820
821    emb_lookup = array_ops.reshape(
822        embedding_ops.embedding_lookup(
823            params, array_ops.reshape(values, [n_batch, n_indices_padded])),
824        [n_batch, n_indices_padded, embedding_dimension])
825
826    values_mask_broadcast = array_ops.reshape(values_mask,
827                                              [n_batch, n_indices_padded, 1])
828    aggregate_emb = math_ops.reduce_sum(
829        emb_lookup * values_mask_broadcast, axis=1)
830    if combiner == 'sum':
831      return aggregate_emb
832    elif combiner == 'mean':
833      # In the case we have an empty row, both aggregate_emb and
834      # math_ops.reduce_sum(values_mask_broadcast, axis=1) will be 0. Thus,
835      # we can take max it with a non-zero value to prevent NaNs. Note that
836      # math_ops.reduce_sum(values_mask_broadcast, axis=1) will have integer
837      # values so 1.0 is the smallest value.
838      return aggregate_emb / math_ops.maximum(
839          math_ops.reduce_sum(values_mask_broadcast, axis=1), 1.0)
840    else:
841      raise ValueError('Dense TPU Embedding does not support combiner '
842                       'other than sum and mean.')


def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size):
  """Creates statically-sized Tensors containing indices and weights.

  From third_party/cloud_tpu/models/movielens/tpu_embedding.py

  Also computes sparse_indices.values % embedding_table_size, for equivalent
  functionality to sparse_column_with_integerized_feature. The returned
  padded weight Tensor also doubles as a mask indicating which values in
  the returned padded indices Tensor are indices versus padded zeros.

  Args:
    sparse_indices: SparseTensor of embedding lookup indices.
    padded_size: Number of columns of the returned Tensors. Indices which fall
      out of bounds will be truncated to the padded size.

  Returns:
    (sparse_indices.values padded to the specified size,
     a mask the same size as the returned padded values in which 0s
     indicate padded locations and 1s (or values from sparse_weights)
     indicate actual values)
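
  For example (a minimal sketch; the indices and values are illustrative):

  ```
  sp = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]], values=[3, 1, 4], dense_shape=[2, 3])
  padded_values, padded_mask = pad_sparse_embedding_lookup_indices(sp, 4)
  # padded_values == [[3, 1, 0, 0], [4, 0, 0, 0]]
  # padded_mask   == [[1., 1., 0., 0.], [1., 0., 0., 0.]]
  ```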
865  """
866  batch_size = sparse_indices.dense_shape[0]
867  sparse_indices = sparse_ops.sparse_slice(sparse_indices, [0, 0],
868                                           [batch_size, padded_size])
869  indices, values = sparse_indices.indices, sparse_indices.values
870
871  padded_values = array_ops.scatter_nd(
872      indices,
873      math_ops.cast(values, dtypes.int32),
874      shape=(batch_size, padded_size))
875
876  weights = array_ops.ones_like(values, dtype=dtypes.float32)
877  padded_mask = array_ops.scatter_nd(
878      indices, weights, shape=(batch_size, padded_size))
879
880  return padded_values, padded_mask
881
882
883def _check_invalid_cases(embedding_lookup_device):
884  """Checks for invalid embedding_lookup_device configurations."""
885  if (tpu.under_tpu_inference_context() and
886      embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE):
887    raise ValueError(
888        'Using embedding_lookup_device=tpu_embedding_core during inference '
889        'is not supported.')
890  if embedding_lookup_device == EmbeddingDevice.CPU:
891    if not tpu.under_tpu_inference_context():
892      raise ValueError(
893          'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" '
894          'during training is not supported.')


class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2):
  """TPUEmbeddingColumn which allows serving on TensorCore."""

  def __new__(cls, *args, **kwargs):
    # For __new__, capture the inference dense shape and the embedding lookup
    # device, then call the parent.
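    # (fc_lib.EmbeddingColumn is namedtuple-based, so these extra arguments
    # cannot flow through its __new__; they are stashed on the class here and
    # re-attached per instance in __init__ below.)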
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    return _TPUEmbeddingColumnV2.__new__(cls, *args, **kwargs)

  def __init__(self, *args, **kwargs):
    # For __init__, capture the inference dense shape and the embedding lookup
    # device, then call the parent.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    _TPUEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def __deepcopy__(self, memo):
    return _TPUDeviceSpecificEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
        tensor_core_shape=self._tensor_core_shape,
        embedding_lookup_device=self._embedding_lookup_device)

  def create_state(self, state_manager):
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return fc_lib.EmbeddingColumn.create_state(self, state_manager)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).create_state(state_manager)

    # TPU_TENSOR_CORE case.
    return fc_lib.EmbeddingColumn.create_state(self, state_manager)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the dense tensor, dispatching on the embedding lookup device."""
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).get_dense_tensor(transformation_cache, state_manager)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self).get_dense_tensor(transformation_cache, state_manager)

    # TPU_TENSOR_CORE case.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input tensors.
      sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                               state_manager)

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and padded.
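      # For a hypothetical column named 'watched', that means the input
      # pipeline provides both 'watched' (dense ids padded to
      # tensor_core_shape) and 'watched' + _TENSOR_CORE_MASK_KEY_SUFFIX
      # (the matching float mask).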
      values = transformation_cache.get(self.categorical_column.name,
                                        state_manager)
      mask = transformation_cache.get(
          self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
          state_manager)
    embedding_weights = state_manager.get_variable(
        self, name='embedding_weights')
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor(inputs, weight_collections,
                                           trainable)
    # TPU_EMBEDDING_CORE case.
    elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor(inputs, weight_collections,
                                           trainable)

    # TPU_TENSOR_CORE case.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input tensors.
      sparse_tensor = inputs.get(self.get_feature_key_name())

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and padded.
      values = inputs.get(self.get_feature_key_name())
      mask = inputs.get(self.get_feature_key_name() +
                        _TENSOR_CORE_MASK_KEY_SUFFIX)

    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    if (weight_collections and
        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())


class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2):
  """TPUSharedEmbeddingColumnV2 which allows serving on TensorCore."""

  def __new__(cls, *args, **kwargs):
    # For __new__, capture the inference dense shape and the embedding lookup
    # device, then call the parent.
    if 'tensor_core_shape' in kwargs:
      cls._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      cls._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']

    return _TPUSharedEmbeddingColumnV2.__new__(cls, *args, **kwargs)

  def __init__(self, *args, **kwargs):
    # For __init__, capture the inference dense shape and the embedding lookup
    # device, then call the parent.
    if 'tensor_core_shape' in kwargs:
      self._tensor_core_shape = kwargs['tensor_core_shape']
      del kwargs['tensor_core_shape']
    if 'embedding_lookup_device' in kwargs:
      self._embedding_lookup_device = kwargs['embedding_lookup_device']
      del kwargs['embedding_lookup_device']
    _TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs)

  def __deepcopy__(self, memo):
    return _TPUSharedDeviceSpecificEmbeddingColumnV2(
        *(copy.deepcopy(a, memo) for a in self.__getnewargs__()),
        tensor_core_shape=self._tensor_core_shape,
        embedding_lookup_device=self._embedding_lookup_device)

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Returns the dense tensor, dispatching on the embedding lookup device."""
    _check_invalid_cases(self._embedding_lookup_device)
    # CPU case.
    is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU
    is_cpu = is_cpu or _is_running_on_cpu()
    if is_cpu:
      return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor_internal(transformation_cache,
                                                    state_manager)
    # TPU_EMBEDDING_CORE case.
    if self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE:
      return super(_TPUSharedDeviceSpecificEmbeddingColumnV2,
                   self)._get_dense_tensor_internal(transformation_cache,
                                                    state_manager)

    # TPU_TENSOR_CORE case.
    if tpu.under_tpu_inference_context():
      # For inference, use outside compile to densify and pad the input tensors.
      sparse_tensor = transformation_cache.get(self.categorical_column.name,
                                               state_manager)

      def host_computation():
        return pad_sparse_embedding_lookup_indices(sparse_tensor,
                                                   self._tensor_core_shape[1])

      values, mask = tpu.outside_compilation(host_computation)
    else:
      # For training, the inputs should already have been densified and padded.
      values = transformation_cache.get(self.categorical_column.name,
                                        state_manager)
      mask = transformation_cache.get(
          self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX,
          state_manager)

    # Do a dense embedding lookup on TensorCore.
    embedding_weights = self.shared_embedding_column_creator.embedding_weights
    return sparse_embedding_aggregate_slice(embedding_weights, (values, mask),
                                            self.get_combiner())
1093