1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Mid level API for TPU Embeddings without Embedding Accelerator.""" 16 17from typing import Any, Dict, Iterable, Optional, Union, Text 18 19from tensorflow.python.distribute import distribution_strategy_context 20from tensorflow.python.distribute import tpu_strategy 21from tensorflow.python.framework import dtypes 22from tensorflow.python.framework import ops 23from tensorflow.python.framework import sparse_tensor 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import embedding_ops 26from tensorflow.python.ops import math_ops 27from tensorflow.python.ops import sparse_ops 28from tensorflow.python.ops import variables as tf_variables 29from tensorflow.python.ops.ragged import ragged_tensor 30from tensorflow.python.tpu import tpu 31from tensorflow.python.tpu import tpu_embedding_base 32from tensorflow.python.tpu import tpu_embedding_v2_utils 33from tensorflow.python.util import nest 34from tensorflow.python.util.tf_export import tf_export 35 36 37@tf_export("tpu.experimental.embedding.TPUEmbeddingV0") 38class TPUEmbeddingV0(tpu_embedding_base.TPUEmbeddingBase): 39 """The TPUEmbedding mid level API running on TPU without Embedding accelerator. 40 41 NOTE: This mid level API is not intended for large embedding table lookup. 42 Embedding tables will be replicated across devices rather than sharding 43 across them. To do large embedding table lookup, please use the 44 `tpu.experimental.embedding.TPUEmbedding` class. This class is an alternative 45 way to do embedding lookups when the TPU doesn't support any version of 46 embedding feature. See 47 `tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed 48 explanation. 49 50 This class has to be created under the `TPUStrategy`, Otherwise a RuntimeError 51 will be raised. 52 ```python 53 strategy = tf.distribute.TPUStrategy(...) 54 with strategy.scope(): 55 embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0( 56 feature_config=feature_config, 57 optimizer=tf.tpu.experimental.embedding.SGD(0.1)) 58 ``` 59 When creating a distributed dataset that is to be passed to the lookup 60 operation a special input option must be specified: 61 62 ```python 63 distributed_dataset = ( 64 strategy.distribute_datasets_from_function( 65 dataset_fn=..., 66 options=tf.distribute.InputOptions( 67 experimental_fetch_to_device=False)) 68 dataset_iterator = iter(distributed_dataset) 69 ``` 70 71 Below is an example of a training and evaluation step: 72 73 ```python 74 optimizer = tf.keras.optimizers.SGD(0.1) 75 76 @tf.function 77 def training_step(dataset_iterator, num_steps): 78 def tpu_step(embedding_features): 79 with tf.GradientTape() as tape: 80 tape.watch(embedding.embedding_table.values()) 81 activation = embedding(embedding_features) 82 model_output = model(activations) 83 loss = ... # some function of labels and model_output 84 85 embedding_gradients = tape.gradient(loss, 86 embedding.embedding_table.values()) 87 optimizer.apply_gradients(list(zip(gradients, 88 mid_level_api.embedding_tables.values()))) 89 # Insert your model gradient and optimizer application here 90 91 for _ in tf.range(num_steps): 92 strategy.run(tpu_step, args=(next(dataset_iterator), )) 93 94 @tf.function 95 def evalution_step(dataset_iterator, num_steps): 96 def tpu_step(embedding_features): 97 activations = embedding(embedding_features) 98 model_output = model(activations) 99 # Insert your evaluation code here. 100 101 for _ in tf.range(num_steps): 102 strategy.run(tpu_step, args=(next(dataset_iterator), )) 103 ``` 104 105 NOTE: The optimizer used here is a Keras optimizer. In order to make the slot 106 variable creation stay consistent between Keras optimizers and 107 embedding optimizers, the `slot_variable_creation_fn` argument of the 108 embedding optimizers has to be passed with the Keras `add_slot` function. Also 109 note that the slot names might be slightly different between them. 110 111 ```python 112 optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1) 113 114 def slot_variable_creation_fn(table, slot_names, slot_initializers): 115 slots = {} 116 for slot, initializer in zip(slot_names, slot_initializers): 117 slots[slot] = optimizer.add_slot(table, slot, initializer) 118 return slots 119 120 embedding_optimizer = tf.experimental.embedding.Adagrad( 121 learning_rate=0.1, 122 slot_variable_creation_fn=slot_variable_creation_fn) 123 124 # Use the embedding optimizer to create mid level api and keras optimizer to 125 # apply gradients. 126 ``` 127 """ 128 129 def __init__( 130 self, 131 feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic 132 optimizer: Optional[tpu_embedding_v2_utils._Optimizer]): # pylint:disable=protected-access 133 super(TPUEmbeddingV0, self).__init__(feature_config, optimizer) 134 self._strategy = distribution_strategy_context.get_strategy() 135 if not isinstance(self._strategy, 136 (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2)): 137 raise RuntimeError( 138 "TPUEmbeddingV0 should be created under TPUStrategy but found {}." 139 .format(self._strategy)) 140 self._built = False 141 142 @property 143 def embedding_tables( 144 self) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]: 145 """Returns a dict of embedding tables, keyed by `TableConfig`.""" 146 self._maybe_build() 147 # Only return the tables and not the slot variables. 148 return { 149 table: self._variables[table.name]["parameters"] 150 for table in self._table_config 151 } 152 153 def _create_variables_and_slots( 154 self) -> Dict[Text, Dict[Text, tf_variables.Variable]]: 155 """Create variables for TPU embeddings. 156 157 Note that this will always ensure that the variable is created under the 158 TPUStrategy. 159 160 Returns: 161 A dict of dicts. The outer dict is keyed by the table names and the inner 162 dicts are keyed by 'parameters' and the slot variable names. 163 """ 164 variables = {} 165 for table in self._table_config: 166 # created TPUDistributedVariable. 167 variables[table.name] = self._create_variables(table, trainable=True) 168 return variables 169 170 def _maybe_build(self): 171 if not self._built: 172 # This can be called while tracing a function, so we wrap the 173 # initialization code with init_scope so it runs eagerly, this means that 174 # it will not be included in the function graph generated by tracing so 175 # that we can be sure that we only initialize the TPU for embeddings 176 # exactly once. 177 with ops.init_scope(): 178 self.build() 179 180 def _apply_combiner_to_embeddings( 181 self, 182 embeddings: ops.Tensor, 183 weight: ops.Tensor, 184 combiner: Optional[Text] = None) -> ops.Tensor: 185 """Apply the combiner to the embedding look up result on second to last axis. 186 187 Args: 188 embeddings: A Tensor of the embedding lookup result. 189 weight: A Tensor of weight which has the same shape of the embeddings. 190 combiner: One of "mean", "sum", "sqrtn". Defaults to "mean". 191 192 Raises: 193 ValueError: If the combiner is not one of 'mean', 'sqrtn' or 'sum'. 194 Returns: 195 A Tensor. 196 """ 197 if combiner is None: 198 combiner = "mean" 199 if combiner == "sum": 200 embeddings = math_ops.reduce_sum(embeddings, axis=-2) 201 elif combiner == "mean": 202 embeddings = math_ops.reduce_sum(embeddings, axis=-2) 203 weight_sum = math_ops.reduce_sum(weight, axis=-2) 204 embeddings = math_ops.div_no_nan(embeddings, weight_sum) 205 elif combiner == "sqrtn": 206 embeddings = math_ops.reduce_sum(embeddings, axis=-2) 207 weight_squared = math_ops.pow(weight, 2) 208 weight_sum = math_ops.reduce_sum(weight_squared, axis=-2) 209 weight_sum_sqrt = math_ops.sqrt(weight_sum) 210 embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt) 211 else: 212 raise ValueError( 213 f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}") 214 return embeddings 215 216 def _pad_or_truncate_with_sequence_length(self, embeddings: ops.Tensor, 217 sequence_length: int) -> ops.Tensor: 218 """Pad or truncate the embedding lookup result based on the sequence length. 219 220 Args: 221 embeddings: A rank 3 Tensor of the embedding lookup result. 222 sequence_length: number of the max sequence length set in the feature 223 config. 224 225 Returns: 226 A Tensor with second last axis padded or truncated. 227 """ 228 original_sequence_length = embeddings.shape[1] 229 if original_sequence_length > sequence_length: 230 embeddings = array_ops.slice( 231 embeddings, begin=[0, 0, 0], size=[-1, sequence_length, -1]) 232 else: 233 embeddings = array_ops.pad( 234 embeddings, 235 paddings=[[0, 0], [0, sequence_length - original_sequence_length], 236 [0, 0]]) 237 return embeddings 238 239 def embedding_lookup(self, 240 features: Any, 241 weights: Optional[Any] = None) -> Any: 242 """Apply embedding lookup on TPUs using Tensorcore. 243 244 Note that all the sparse and ragged tensors will be converted to dense 245 tensors on CPU and then passed to the TPU to do embedding look up. Large 246 embedding lookup is not supported by this API, use the TPUEmbedding mid 247 level api instead. 248 249 Args: 250 features: a nested structure of Tensors, SparseTensors or RaggedTensors. 251 weights: a nested structure of Tensors, SparseTensors or RaggedTensors or 252 None for no weights. If not None, structure must match that of inputs, 253 but entries are allowed to be None. 254 255 Returns: 256 A nested structure of Tensors with the same structure as inputs. 257 """ 258 if not self._built: 259 self.build() 260 nest.assert_same_structure(features, self._feature_config) 261 262 flat_inputs = nest.flatten(features) 263 flat_weights = [None] * len(flat_inputs) 264 if weights is not None: 265 nest.assert_same_structure(features, weights) 266 flat_weights = nest.flatten(weights) 267 flat_features = nest.flatten_with_joined_string_paths(self._feature_config) 268 269 outputs = [] 270 for inp, weight, (path, feature) in zip(flat_inputs, flat_weights, 271 flat_features): 272 table = self.embedding_tables[feature.table] 273 274 if weight is not None: 275 if isinstance(inp, ops.Tensor): 276 raise ValueError( 277 "Weight specified for {}, but input is dense.".format(path)) 278 elif type(weight) is not type(inp): 279 raise ValueError( 280 "Weight for {} is of type {} but it does not match type of the " 281 "input which is {}.".format(path, type(weight), type(inp))) 282 elif feature.max_sequence_length > 0: 283 raise ValueError("Weight specified for {}, but this is a sequence " 284 "feature.".format(path)) 285 286 if isinstance(inp, ops.Tensor): 287 if feature.max_sequence_length > 0: 288 raise ValueError( 289 "Feature {} is a sequence feature but a dense tensor " 290 "was passed.".format(path)) 291 outputs.append(embedding_ops.embedding_lookup_v2(table, inp)) 292 293 elif isinstance(inp, sparse_tensor.SparseTensor): 294 outputs.append( 295 self._embedding_lookup_for_sparse_tensor(inp, weight, table, 296 feature)) 297 elif isinstance(inp, ragged_tensor.RaggedTensor): 298 outputs.append( 299 self._embedding_lookup_for_ragged_tensor(inp, weight, table, 300 feature)) 301 else: 302 raise ValueError("Input {} is type {}. Tensor, SparseTensor or " 303 "RaggedTensor expected.".format(path, type(inp))) 304 return nest.pack_sequence_as(self._feature_config, outputs) 305 306 def _embedding_lookup_for_sparse_tensor( 307 self, inp: sparse_tensor.SparseTensor, 308 weight: Optional[sparse_tensor.SparseTensor], 309 table: tf_variables.Variable, 310 feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: 311 """Embedding lookup for sparse tensor based on its feature config. 312 313 Args: 314 inp: a single SparseTensor input. 315 weight: None or SparseTensor which has the same shape of the input. 316 table: a table variable. 317 feature: a feature config. 318 319 Returns: 320 Embedding lookup result. 321 """ 322 323 # This computation needs to placed outside of tpu as the size of the 324 # indices and values can change for different batch which can cause 325 # the program to re-compile. 326 def sparse_to_dense_computation(inp, weight): 327 if weight is None: 328 weight = sparse_tensor.SparseTensor( 329 inp.indices, 330 array_ops.ones_like(inp.values, dtype=dtypes.float32), 331 dense_shape=inp.dense_shape) 332 # Pad the sparse tensor to be dense tensor. 333 inp = sparse_ops.sparse_tensor_to_dense(inp) 334 weight = sparse_ops.sparse_tensor_to_dense(weight) 335 return inp, weight 336 337 inp, weight = tpu.outside_compilation( 338 sparse_to_dense_computation, inp=inp, weight=weight) 339 340 embeddings = embedding_ops.embedding_lookup_v2(table, inp) 341 weight = array_ops.expand_dims(weight, -1) 342 embeddings *= weight 343 if not feature.output_shape and feature.max_sequence_length > 0: 344 embeddings = self._pad_or_truncate_with_sequence_length( 345 embeddings, feature.max_sequence_length) 346 else: 347 embeddings = self._apply_combiner_to_embeddings(embeddings, weight, 348 feature.table.combiner) 349 return embeddings 350 351 def _embedding_lookup_for_ragged_tensor( 352 self, inp: ragged_tensor.RaggedTensor, 353 weight: Optional[ragged_tensor.RaggedTensor], 354 table: tf_variables.Variable, 355 feature: tpu_embedding_v2_utils.FeatureConfig) -> ops.Tensor: 356 """Embedding lookup for ragged tensor based on its feature config. 357 358 Args: 359 inp: a single rank 2 RaggedTensor input. 360 weight: None or RaggedTensor which has the same shape of the input. 361 table: a table variable. 362 feature: a feature config. 363 364 Returns: 365 Embedding lookup result. 366 367 Raises: 368 ValueError: if input ragged tensor is not rank 2 or output shape set in 369 the feature config doesn't match with the first dim size of the input. 370 """ 371 if inp.shape.rank != 2: 372 raise ValueError( 373 "Only rank 2 ragged tensor is supported, but got rank {}".format( 374 inp.shape.rank)) 375 batch_size = inp.shape[0] 376 377 # This computation needs to placed outside of tpu as the size of the row 378 # splits and values can change for different batch which can cause 379 # the program to re-compile. 380 def ragged_to_dense_outside_compilation(inp, weight, batch_size, feature): 381 if weight is None: 382 weight = ragged_tensor.RaggedTensor.from_row_splits( 383 array_ops.ones_like(inp.values, dtype=dtypes.float32), 384 inp.row_splits) 385 if not feature.output_shape and feature.max_sequence_length > 0: 386 inp = inp.to_tensor(shape=(batch_size, feature.max_sequence_length)) 387 # Ignore weight if it is a sequence feature. 388 weight = array_ops.ones_like(inp, dtype=dtypes.float32) 389 elif feature.output_shape: 390 # Eagerly run the following op as the result as to be a number in 391 # order to use it as part of the output shape. 392 with ops.init_scope(): 393 output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy() 394 # If the output batch size matches the data batch size, treat it as 395 # normal ragged input. 396 if output_batch_size == batch_size: 397 inp, weight = inp.to_tensor(), weight.to_tensor() 398 # If the data batch size is a factor of the output batch size, the 399 # divide result will be the sequence length. Ignore the weights and 400 # combiner. 401 elif output_batch_size > batch_size and output_batch_size % batch_size == 0: 402 # Pad or truncate in the sequence dimension 403 seq_length = output_batch_size // batch_size 404 inp = inp.to_tensor(shape=(batch_size, seq_length)) 405 # Ignore weight if it is a sequence feature. 406 weight = array_ops.ones_like(inp, dtype=dtypes.float32) 407 else: 408 raise ValueError( 409 "Output shape set in the FeatureConfig should be the factor of " 410 "the input data batch size. But instead got output shape {}, " 411 "input data batch size {}".format(feature.output_shape, 412 batch_size)) 413 else: 414 inp, weight = inp.to_tensor(), weight.to_tensor() 415 return inp, weight 416 417 inp, weight = tpu.outside_compilation( 418 ragged_to_dense_outside_compilation, 419 inp=inp, 420 weight=weight, 421 batch_size=batch_size, 422 feature=feature) 423 424 embeddings = embedding_ops.embedding_lookup_v2(table, inp) 425 weight = array_ops.expand_dims(weight, -1) 426 embeddings *= weight 427 428 if feature.output_shape: 429 with ops.init_scope(): 430 output_batch_size = math_ops.reduce_prod(feature.output_shape).numpy() 431 if output_batch_size == batch_size: 432 embeddings = self._apply_combiner_to_embeddings(embeddings, weight, 433 feature.table.combiner) 434 embeddings = array_ops.reshape( 435 embeddings, shape=feature.output_shape + [feature.table.dim]) 436 else: 437 if feature.max_sequence_length == 0: 438 embeddings = self._apply_combiner_to_embeddings(embeddings, weight, 439 feature.table.combiner) 440 return embeddings 441