xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_embedding_base.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base Class for TPU Embeddings Mid level APIs."""
16
17import functools
18from typing import Any, Dict, Iterable, Optional, Union, Text
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.ops import variables as tf_variables
22from tensorflow.python.tpu import tpu_embedding_v2_utils
23from tensorflow.python.trackable import autotrackable
24from tensorflow.python.util import nest
25
26
class TPUEmbeddingBase(autotrackable.AutoTrackable):
  """The TPUEmbedding Base class.

  This class only contains the basic logic to check the feature config and table
  config for the tpu embedding mid level APIs.
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer] = None):  # pylint:disable=protected-access
    """Creates the TPUEmbeddingBase object.

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` configs.
      optimizer: An optional optimizer instance from
        `tf.tpu.experimental.embedding`; used as the default for any table
        whose config does not specify its own optimizer.

    Raises:
      ValueError: If a table's optimizer is not an instance of one of the
        supported optimizer classes, or if two tables share the same name.
    """
    self._feature_config = feature_config
    # Flatten once and reuse: the same flat feature order drives both the
    # output shapes and the table ordering below.
    flat_features = nest.flatten(feature_config)
    self._output_shapes = [feature.output_shape for feature in flat_features]
    # Set table order here to the order of the first occurrence of the table in
    # a feature provided by the user. The order of this struct must be fixed
    # to provide the user with deterministic behavior over multiple
    # instantiations. (Dedup via `in` on a list deliberately preserves whatever
    # equality semantics TableConfig objects have.)
    self._table_config = []
    for feature in flat_features:
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = set()  # set membership is O(1); names are plain strings.
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if (table.optimizer is not None and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        # Auto-name unnamed tables by their deterministic position.
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.add(table.name)

    self._built = False

  @property
  def embedding_tables(self):
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    raise NotImplementedError

  def _create_variables(self, table: tpu_embedding_v2_utils.TableConfig,
                        trainable: bool) -> Dict[Text, tf_variables.Variable]:
    """Create all variables including table variables and slot variables.

    Args:
      table: The table config to create variables for.
      trainable: Whether the table (parameter) variable is trainable; slot
        variables are always created non-trainable.

    Returns:
      A dict mapping variable names to variables. The table itself is stored
      under the key "parameters"; any optimizer slots are keyed by slot name.
    """
    variable_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
      del shape
      # _add_variable_with_custom_getter clears the shape sometimes, so we
      # take the global shape from outside the getter.
      initial_value = functools.partial(
          initializer, variable_shape, dtype=dtype)
      return tf_variables.Variable(
          name=name,
          initial_value=initial_value,
          shape=variable_shape,
          dtype=dtype,
          trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
      # Use add_variable_with_custom_getter here so that we take advantage of
      # the checkpoint loading to allow restore before the variables get
      # created which avoids double initialization.
      return self._add_variable_with_custom_getter(
          name=name,
          initializer=initializer,
          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)

    parameters = variable_creator(
        table.name, table.initializer, trainable=trainable)

    def slot_creator(name, initializer):
      # Slot variables live under the table's name scope and are never
      # trainable.
      return variable_creator(table.name + "/" + name, initializer, False)

    if table.optimizer is not None:
      slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
    else:
      slot_vars = {}
    slot_vars["parameters"] = parameters
    return slot_vars

  def _create_variables_and_slots(self):
    """Create variables and slots variables for TPU embeddings."""
    raise NotImplementedError

  def build(self):
    """Create variables and slots variables for TPU embeddings.

    Idempotent: a second call is a no-op once variables exist.
    """
    if self._built:
      return
    self._variables = self._create_variables_and_slots()
    self._built = True

  def __call__(self, features: Any, weights: Optional[Any] = None) -> Any:
    """Call the mid level api to do embedding lookup.

    Args:
      features: A nested structure of features to look up.
      weights: An optional nested structure of weights matching `features`.

    Returns:
      The result of `embedding_lookup` on the (lazily built) tables.
    """
    # build() is a no-op when already built, so no need to re-check _built.
    self.build()
    return self.embedding_lookup(features, weights)

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Lookup the embedding table using the input features."""
    raise NotImplementedError
146