# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base Class for TPU Embeddings Mid level APIs."""

import functools
from typing import Any, Dict, Iterable, Optional, Union, Text

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest


class TPUEmbeddingBase(autotrackable.AutoTrackable):
  """The TPUEmbedding Base class.

  This class only contains the basic logic to check the feature config and
  table config for the tpu embedding mid level APIs.
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer] = None):  # pylint:disable=protected-access
    """Creates the TPUEmbeddingBase object.

    Args:
      feature_config: A nested structure of
        `tf.tpu.experimental.embedding.FeatureConfig` objects.
      optimizer: Default optimizer applied to every table whose `TableConfig`
        does not specify its own. May be None if all tables set one.

    Raises:
      ValueError: If a table's optimizer is not an instance of one of the
        supported optimizer classes, or if two tables share a name.
    """
    self._feature_config = feature_config
    # Flatten once and reuse below; flattening is deterministic, which keeps
    # the derived orderings stable across instantiations.
    flat_features = nest.flatten(feature_config)
    self._output_shapes = [feature.output_shape for feature in flat_features]
    # Set table order here to the order of the first occurrence of the table in
    # a feature provided by the user. The order of this struct must be fixed
    # to provide the user with deterministic behavior over multiple
    # instantiations.
    self._table_config = []
    for feature in flat_features:
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain) Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if (table.optimizer is not None and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError("{} is an unsupported optimizer class. Please pass an "
                         "instance of one of the optimizer classes under "
                         "tf.tpu.experimental.embedding.".format(
                             type(table.optimizer)))
      if table.name is None:
        # Auto-name unnamed tables by their position in the config.
        table.name = "table_{}".format(i)
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.append(table.name)

    self._built = False

  @property
  def embedding_tables(self):
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    raise NotImplementedError

  def _create_variables(self, table: tpu_embedding_v2_utils.TableConfig,
                        trainable: bool) -> Dict[Text, tf_variables.Variable]:
    """Create all variables including table variables and slot variables.

    Args:
      table: The table config whose parameter and slot variables to create.
      trainable: Whether the parameter variable is trainable (slot variables
        are always created as non-trainable).

    Returns:
      A dict mapping variable name (`"parameters"` plus any optimizer slot
      names) to the created `tf.Variable`.
    """
    variable_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
      del shape
      # _add_variable_with_custom_getter clears the shape sometimes, so we
      # take the global shape from outside the getter.
      initial_value = functools.partial(
          initializer, variable_shape, dtype=dtype)
      return tf_variables.Variable(
          name=name,
          initial_value=initial_value,
          shape=variable_shape,
          dtype=dtype,
          trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
      # Use add_variable_with_custom_getter here so that we take advantage of
      # the checkpoint loading to allow restore before the variables get
      # created which avoids double initialization.
      return self._add_variable_with_custom_getter(
          name=name,
          initializer=initializer,
          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)

    parameters = variable_creator(
        table.name, table.initializer, trainable=trainable)

    def slot_creator(name, initializer):
      # Slot variables live under the table's name scope and never train
      # directly through gradients.
      return variable_creator(table.name + "/" + name, initializer, False)

    if table.optimizer is not None:
      slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
    else:
      slot_vars = {}
    slot_vars["parameters"] = parameters
    return slot_vars

  def _create_variables_and_slots(self):
    """Create variables and slots variables for TPU embeddings."""
    raise NotImplementedError

  def build(self):
    """Create variables and slots variables for TPU embeddings."""
    if self._built:
      return
    self._variables = self._create_variables_and_slots()
    self._built = True

  def __call__(self, features: Any, weights: Optional[Any] = None) -> Any:
    """Call the mid level api to do embedding lookup."""
    if not self._built:
      self.build()
    return self.embedding_lookup(features, weights)

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Lookup the embedding table using the input features."""
    raise NotImplementedError