# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings without Embedding Accelerator."""

from typing import Any, Dict, Iterable, Optional, Union, Text

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("tpu.experimental.embedding.TPUEmbeddingV0")
class TPUEmbeddingV0(tpu_embedding_base.TPUEmbeddingBase):
  """The TPUEmbedding mid level API running on TPU without Embedding accelerator.

  NOTE: This mid level API is not intended for large embedding table lookup.
  Embedding tables will be replicated across devices rather than sharded
  across them. To do large embedding table lookup, please use the
  `tpu.experimental.embedding.TPUEmbedding` class. This class is an alternative
  way to do embedding lookups when the TPU doesn't support any version of
  the embedding feature. See
  `tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed
  explanation.

  This class has to be created under the `TPUStrategy`, otherwise a
  RuntimeError will be raised.

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0(
        feature_config=feature_config,
        optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  When creating a distributed dataset that is to be passed to the lookup
  operation, a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```

  Below is an example of a training and evaluation step:

  ```python
  optimizer = tf.keras.optimizers.SGD(0.1)

  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      with tf.GradientTape() as tape:
        tape.watch(embedding.embedding_tables.values())
        activations = embedding(embedding_features)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(loss,
                                          embedding.embedding_tables.values())
      optimizer.apply_gradients(list(zip(embedding_gradients,
                                         embedding.embedding_tables.values())))
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      activations = embedding(embedding_features)
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))
  ```

  NOTE: The optimizer used here is a Keras optimizer. In order to keep the slot
  variable creation consistent between the Keras optimizer and the embedding
  optimizer, the `slot_variable_creation_fn` argument of the embedding
  optimizer has to be set to the Keras optimizer's `add_slot` function. Also
  note that the slot names might differ slightly between them.

  ```python
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)

  def slot_variable_creation_fn(table, slot_names, slot_initializers):
    slots = {}
    for slot, initializer in zip(slot_names, slot_initializers):
      slots[slot] = optimizer.add_slot(table, slot, initializer)
    return slots

  embedding_optimizer = tf.tpu.experimental.embedding.Adagrad(
      learning_rate=0.1,
      slot_variable_creation_fn=slot_variable_creation_fn)

  # Use the embedding optimizer to create the mid level API and the Keras
  # optimizer to apply gradients.
  ```
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer]):  # pylint:disable=protected-access
    super(TPUEmbeddingV0, self).__init__(feature_config, optimizer)
    self._strategy = distribution_strategy_context.get_strategy()
    if not isinstance(self._strategy,
                      (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)):
      raise RuntimeError(
          "TPUEmbeddingV0 should be created under TPUStrategy but found {}."
          .format(self._strategy))
    self._built = False

  @property
  def embedding_tables(
      self) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    self._maybe_build()
    # Only return the tables and not the slot variables.
    return {
        table: self._variables[table.name]["parameters"]
        for table in self._table_config
    }

  def _create_variables_and_slots(
      self) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
    """Create variables for TPU embeddings.

    Note that this will always ensure that the variables are created under the
    TPUStrategy.

    Returns:
      A dict of dicts. The outer dict is keyed by the table names and the inner
      dicts are keyed by 'parameters' and the slot variable names.
    """
    variables = {}
    for table in self._table_config:
      # Create a TPUDistributedVariable for each table.
      variables[table.name] = self._create_variables(table, trainable=True)
    return variables

  def _maybe_build(self):
    if not self._built:
      # This can be called while tracing a function, so we wrap the
      # initialization code with init_scope so that it runs eagerly. This means
      # that it will not be included in the function graph generated by
      # tracing, so we can be sure that we only initialize the TPU for
      # embeddings exactly once.
      with ops.init_scope():
        self.build()

  def _apply_combiner_to_embeddings(
      self,
      embeddings: ops.Tensor,
      weight: ops.Tensor,
      combiner: Optional[Text] = None) -> ops.Tensor:
    """Apply the combiner to the embedding lookup result on the second to last axis.

    Args:
      embeddings: A Tensor of the embedding lookup result.
      weight: A Tensor of weights with the same shape as the embeddings.
      combiner: One of "mean", "sum", "sqrtn". Defaults to "mean".

    Raises:
      ValueError: If the combiner is not one of 'mean', 'sqrtn' or 'sum'.
    Returns:
      A Tensor.
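
    For example, a minimal sketch (illustrative values only, not part of this
    API) of what the "mean" combiner computes over the second to last axis:

    ```python
    import tensorflow as tf

    embeddings = tf.constant([[[2., 2.], [4., 4.]]])  # [batch=1, bag=2, dim=2]
    weight = tf.constant([[[1.], [1.]]])              # [batch=1, bag=2, 1]
    summed = tf.reduce_sum(embeddings, axis=-2)       # [[6., 6.]]
    weight_sum = tf.reduce_sum(weight, axis=-2)       # [[2.]]
    mean = tf.math.divide_no_nan(summed, weight_sum)  # [[3., 3.]]
    ```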
    """
    if combiner is None:
      combiner = "mean"
    if combiner == "sum":
      embeddings = math_ops.reduce_sum(embeddings, axis=-2)
    elif combiner == "mean":
      embeddings = math_ops.reduce_sum(embeddings, axis=-2)
      weight_sum = math_ops.reduce_sum(weight, axis=-2)
      embeddings = math_ops.div_no_nan(embeddings, weight_sum)
    elif combiner == "sqrtn":
      embeddings = math_ops.reduce_sum(embeddings, axis=-2)
      weight_squared = math_ops.pow(weight, 2)
      weight_sum = math_ops.reduce_sum(weight_squared, axis=-2)
      weight_sum_sqrt = math_ops.sqrt(weight_sum)
      embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt)
    else:
      raise ValueError(
          f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
    return embeddings

  def _pad_or_truncate_with_sequence_length(self, embeddings: ops.Tensor,
                                            sequence_length: int) -> ops.Tensor:
    """Pad or truncate the embedding lookup result based on the sequence length.

    Args:
      embeddings: A rank 3 Tensor of the embedding lookup result.
      sequence_length: The max sequence length set in the feature config.

    Returns:
      A Tensor with its second to last axis padded or truncated.
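
    A minimal sketch of the shape behavior (illustrative values only):

    ```python
    import tensorflow as tf

    # [batch=2, seq=5, dim=8] truncated to sequence_length=3 -> [2, 3, 8].
    truncated = tf.slice(tf.ones([2, 5, 8]), begin=[0, 0, 0], size=[-1, 3, -1])

    # [batch=2, seq=2, dim=8] zero padded to sequence_length=3 -> [2, 3, 8].
    padded = tf.pad(tf.ones([2, 2, 8]), paddings=[[0, 0], [0, 1], [0, 0]])
    ```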
    """
    original_sequence_length = embeddings.shape[1]
    if original_sequence_length > sequence_length:
      embeddings = array_ops.slice(
          embeddings, begin=[0, 0, 0], size=[-1, sequence_length, -1])
    else:
      embeddings = array_ops.pad(
          embeddings,
          paddings=[[0, 0], [0, sequence_length - original_sequence_length],
                    [0, 0]])
    return embeddings

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Apply embedding lookup on TPUs using the TensorCore.

    Note that all the sparse and ragged tensors will be converted to dense
    tensors on CPU and then passed to the TPU to do embedding lookup. Large
    embedding lookup is not supported by this API; use the `TPUEmbedding` mid
    level API instead.

    Args:
      features: a nested structure of Tensors, SparseTensors or RaggedTensors.
      weights: a nested structure of Tensors, SparseTensors or RaggedTensors or
        None for no weights. If not None, structure must match that of inputs,
        but entries are allowed to be None.

    Returns:
      A nested structure of Tensors with the same structure as inputs.
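
    A minimal usage sketch, assuming `embedding`, `dataset_iterator` and the
    surrounding `strategy` were set up as in the class docstring (names here
    are illustrative):

    ```python
    @tf.function
    def lookup_step(embedding_features):
      # Returns activations with the same nested structure as feature_config.
      return embedding.embedding_lookup(embedding_features)

    activations = strategy.run(lookup_step, args=(next(dataset_iterator),))
    ```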
    """
    if not self._built:
      self.build()
    nest.assert_same_structure(features, self._feature_config)

    flat_inputs = nest.flatten(features)
    flat_weights = [None] * len(flat_inputs)
    if weights is not None:
      nest.assert_same_structure(features, weights)
      flat_weights = nest.flatten(weights)
    flat_features = nest.flatten_with_joined_string_paths(self._feature_config)

    outputs = []
    for inp, weight, (path, feature) in zip(flat_inputs, flat_weights,
                                            flat_features):
      table = self.embedding_tables[feature.table]

      if weight is not None:
        if isinstance(inp, ops.Tensor):
          raise ValueError(
              "Weight specified for {}, but input is dense.".format(path))
        elif type(weight) is not type(inp):
          raise ValueError(
              "Weight for {} is of type {} but it does not match type of the "
              "input which is {}.".format(path, type(weight), type(inp)))
        elif feature.max_sequence_length > 0:
          raise ValueError("Weight specified for {}, but this is a sequence "
                           "feature.".format(path))

      if isinstance(inp, ops.Tensor):
        if feature.max_sequence_length > 0:
          raise ValueError(
              "Feature {} is a sequence feature but a dense tensor "
              "was passed.".format(path))
        outputs.append(embedding_ops.embedding_lookup_v2(table, inp))

      elif isinstance(inp, sparse_tensor.SparseTensor):
        outputs.append(
            self._embedding_lookup_for_sparse_tensor(inp, weight, table,
                                                     feature))
      elif isinstance(inp, ragged_tensor.RaggedTensor):
        outputs.append(
            self._embedding_lookup_for_ragged_tensor(inp, weight, table,
                                                     feature))
      else:
        raise ValueError("Input {} is type {}. Tensor, SparseTensor or "
                         "RaggedTensor expected.".format(path, type(inp)))
    return nest.pack_sequence_as(self._feature_config, outputs)

  def _embedding_lookup_for_sparse_tensor(
      self, inp: sparse_tensor.SparseTensor,
      weight: Optional[sparse_tensor.SparseTensor],
      table: tf_variables.Variable,
      feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
    """Embedding lookup for sparse tensor based on its feature config.

    Args:
      inp: a single SparseTensor input.
      weight: None or SparseTensor which has the same shape as the input.
      table: a table variable.
      feature: a feature config.

    Returns:
      Embedding lookup result.
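
    A minimal sketch (illustrative values only) of the densification that
    happens on the host before the lookup:

    ```python
    import tensorflow as tf

    ids = tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]], values=[3, 1, 4], dense_shape=[2, 3])
    dense_ids = tf.sparse.to_dense(ids)  # [[3, 1, 0], [4, 0, 0]]
    weight = tf.sparse.to_dense(tf.sparse.SparseTensor(
        ids.indices, tf.ones_like(ids.values, dtype=tf.float32),
        ids.dense_shape))
    # weight == [[1., 1., 0.], [1., 0., 0.]] masks out the padded positions.
    ```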
    """

    # This computation needs to be placed outside of the TPU as the size of the
    # indices and values can change across batches, which would cause the
    # program to re-compile.
    def sparse_to_dense_computation(inp, weight):
      if weight is None:
        weight = sparse_tensor.SparseTensor(
            inp.indices,
            array_ops.ones_like(inp.values, dtype=dtypes.float32),
            dense_shape=inp.dense_shape)
      # Pad the sparse tensor to be a dense tensor.
      inp = sparse_ops.sparse_tensor_to_dense(inp)
      weight = sparse_ops.sparse_tensor_to_dense(weight)
      return inp, weight

    inp, weight = tpu.outside_compilation(
        sparse_to_dense_computation, inp=inp, weight=weight)

    embeddings = embedding_ops.embedding_lookup_v2(table, inp)
    weight = array_ops.expand_dims(weight, -1)
    embeddings *= weight
    if not feature.output_shape and feature.max_sequence_length > 0:
      embeddings = self._pad_or_truncate_with_sequence_length(
          embeddings, feature.max_sequence_length)
    else:
      embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                      feature.table.combiner)
    return embeddings

  def _embedding_lookup_for_ragged_tensor(
      self, inp: ragged_tensor.RaggedTensor,
      weight: Optional[ragged_tensor.RaggedTensor],
      table: tf_variables.Variable,
      feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor:
    """Embedding lookup for ragged tensor based on its feature config.

    Args:
      inp: a single rank 2 RaggedTensor input.
      weight: None or RaggedTensor which has the same shape as the input.
      table: a table variable.
      feature: a feature config.

    Returns:
      Embedding lookup result.

    Raises:
      ValueError: if the input ragged tensor is not rank 2, or the output shape
        set in the feature config is not compatible with the first dimension
        size of the input.
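
    A minimal sketch (illustrative values only) of how `output_shape` interacts
    with the data batch size when the output batch size is larger:

    ```python
    import tensorflow as tf

    # output_shape = [2, 3] implies an output batch size of 2 * 3 = 6; with a
    # data batch size of 2 the sequence length is 6 // 2 = 3.
    ids = tf.ragged.constant([[3, 1], [4, 1, 5, 9]])
    dense_ids = ids.to_tensor(shape=(2, 3))  # [[3, 1, 0], [4, 1, 5]]
    # Weights and the combiner are ignored; the [2, 3, dim] activations are
    # later reshaped to output_shape + [dim].
    ```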
    """
    if inp.shape.rank != 2:
      raise ValueError(
          "Only rank 2 ragged tensor is supported, but got rank {}".format(
              inp.shape.rank))
    batch_size = inp.shape[0]

    # This computation needs to be placed outside of the TPU as the size of the
    # row splits and values can change across batches, which would cause the
    # program to re-compile.
    def ragged_to_dense_outside_compilation(inp, weight, batch_size, feature):
      if weight is None:
        weight = ragged_tensor.RaggedTensor.from_row_splits(
            array_ops.ones_like(inp.values, dtype=dtypes.float32),
            inp.row_splits)
      if not feature.output_shape and feature.max_sequence_length > 0:
        inp = inp.to_tensor(shape=(batch_size, feature.max_sequence_length))
        # Ignore weight if it is a sequence feature.
        weight = array_ops.ones_like(inp, dtype=dtypes.float32)
      elif feature.output_shape:
        # Eagerly run the following op as the result has to be a number in
        # order to use it as part of the output shape.
        with ops.init_scope():
          output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy()
        # If the output batch size matches the data batch size, treat it as
        # normal ragged input.
        if output_batch_size == batch_size:
          inp, weight = inp.to_tensor(), weight.to_tensor()
        # If the data batch size is a factor of the output batch size, the
        # quotient is the sequence length. Ignore the weights and the
        # combiner.
        elif output_batch_size > batch_size and output_batch_size % batch_size == 0:
          # Pad or truncate in the sequence dimension.
          seq_length = output_batch_size // batch_size
          inp = inp.to_tensor(shape=(batch_size, seq_length))
          # Ignore weight if it is a sequence feature.
          weight = array_ops.ones_like(inp, dtype=dtypes.float32)
        else:
          raise ValueError(
              "The output shape set in the FeatureConfig should have a batch "
              "size that is a multiple of the input data batch size. But "
              "instead got output shape {}, input data batch size {}".format(
                  feature.output_shape, batch_size))
      else:
        inp, weight = inp.to_tensor(), weight.to_tensor()
      return inp, weight

    inp, weight = tpu.outside_compilation(
        ragged_to_dense_outside_compilation,
        inp=inp,
        weight=weight,
        batch_size=batch_size,
        feature=feature)

    embeddings = embedding_ops.embedding_lookup_v2(table, inp)
    weight = array_ops.expand_dims(weight, -1)
    embeddings *= weight

    if feature.output_shape:
      with ops.init_scope():
        output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy()
      if output_batch_size == batch_size:
        embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                        feature.table.combiner)
      embeddings = array_ops.reshape(
          embeddings, shape=feature.output_shape + [feature.table.dim])
    else:
      if feature.max_sequence_length == 0:
        embeddings = self._apply_combiner_to_embeddings(embeddings, weight,
                                                        feature.table.combiner)
    return embeddings
441