# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util.tf_export import tf_export


def _create_default_group_assignment():
  """Returns a single group containing all shards of the current TPU context."""
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment


def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the sub-group
      size (group_assignment.get_shape()[1]).
    group_assignment: An optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` formed by concatenating data from the different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)


@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]


@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: An optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
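

# Illustrative sketch: cross_replica_sum with the default group assignment can
# be used to average a per-replica value across all replicas. The helper name
# below is hypothetical, and it assumes x is a float tensor inside a
# tpu_shard_context (so number_of_shards is set).
def _example_cross_replica_mean(x):
  num_shards = tpu_function.get_tpu_context().number_of_shards or 1
  return cross_replica_sum(x) / num_shards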


def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id may appear at most once in the source column and at most
  once in the target column. For a replica id that does not appear in the
  target column, this op returns a zero tensor with the same shape and dtype
  as the input x.

  For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs: `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: A 2d int list with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)


@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with source/target pairs reversed. The gradient with respect
  # to input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross-replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
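

# Illustrative sketch, mirroring the docstring example above: with four
# replicas, these pairs send each replica's value to the next replica, and
# replica 0 (absent from the target column) receives zeros. The helper name is
# hypothetical, and it is assumed to run inside a replicated TPU computation.
def _example_shift_to_next_replica(x):
  return collective_permute(x, source_target_pairs=[[0, 1], [1, 2], [2, 3]])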


# This extra type checking exists to give a more helpful error message in the
# common case that uint8 and int64 values are fed via infeed. Remove when both
# types are supported.

_SUPPORTED_INFEED_DTYPES = set([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32
])


@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref(
      "tpu_embedding_gradients_table_%d" % table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training mode. "
        "This is not expected. Consider putting your Optimizer.minimize code "
        "behind the training mode condition check. For Estimator, you can "
        "do \n\n"
        "    if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "      train_op = opt.minimize(loss)\n"
        "\n")

  if lookup_id < 0 or lookup_id >= len(table_gradients):
    raise RuntimeError(
        "Gradients (w.r.t. TPUEmbedding activations) generated for table_id "
        "{} and lookup_id {}. The lookup_id attribute is outside the expected "
        "range [0, {}).".format(table_id, lookup_id, len(table_gradients)))

  if table_gradients[lookup_id] is not None:
    raise RuntimeError(
        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
        "table_id {} and lookup_id {}. This happens when there are multiple "
        "calls to tf.gradients in a graph containing TPU embeddings. "
        "TF cannot identify which gradient to use for updating the embedding "
        "variables. Consider placing tf.stop_gradient around tensors where "
        "variable update is not required. Previous gradients were generated "
        "by the following callstack: {}.".format(
            table_id, lookup_id, table_gradients[lookup_id].op.traceback))

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we will return zeros(shape=[1]). The actual gradient w.r.t.
      # the embedding activations (grad_wrt_activations) has the same shape as
      # the activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype not in _SUPPORTED_INFEED_DTYPES:
    raise TypeError(
        "Operation '{}' has type {} which is not a supported TPU infeed type. "
        "Supported types are: {}".format(name, dtype,
                                         list(_SUPPORTED_INFEED_DTYPES)))

  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)


# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types of
      each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
      shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
  """
  for dtype in dtypes:
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)


# pylint: enable=redefined-outer-name
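

# Illustrative sketch: dequeuing a (features, labels) tuple that a matching
# InfeedEnqueueTuple op on the host is assumed to feed. The shapes and dtypes
# below are hypothetical placeholders for an image-classification input.
def _example_dequeue_features_and_labels():
  return infeed_dequeue_tuple(
      dtypes=[dtypes.float32, dtypes.int32],
      shapes=[[128, 224, 224, 3], [128]])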


# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      This argument has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
      with respect to the embedding activations. The embedding tables are
      updated from these gradients via the optimizers specified in the TPU
      embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag: see the comments in
      //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
      Multiple tables can share the same dynamic learning rate tag as
      specified in the configuration. If the learning rates for all tables
      are constant, this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  if learning_rates is None:
    learning_rates = []
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=learning_rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)


# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
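

# Illustrative sketch: enqueueing one 1D tensor of ids per embedding table on
# TPU core 0, leaving the mode to the TPUEmbeddingConfiguration. The helper
# and tensor names, and the assumption of exactly two tables, are hypothetical.
def _example_enqueue_integer_batch(table0_ids, table1_ids):
  return enqueue_tpu_embedding_integer_batch(
      batch=[table0_ids, table1_ids], device_ordinal=0)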


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed, and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample --
      i.e., per (training example, feature) -- aggregation weights. Both
      float32 and float64 are allowed and will be converted to float32
      internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
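

# Illustrative sketch: enqueueing ids for a single table from a
# tf.sparse.SparseTensor, assuming each table has exactly one feature
# (nf == 1), so the sample index reduces to the example index b. The helper
# and variable names are hypothetical; unit weights are passed explicitly.
def _example_enqueue_sparse_batch(sp_ids):
  sample_indices = sp_ids.indices[:, 0]  # int64 is converted internally.
  embedding_indices = sp_ids.values
  weights = array_ops.ones_like(embedding_indices, dtype=dtypes.float32)
  return enqueue_tpu_embedding_sparse_batch(
      sample_indices=[sample_indices],
      embedding_indices=[embedding_indices],
      aggregation_weights=[weights],
      device_ordinal=0)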


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example
      to which the corresponding embedding_indices and aggregation_weights
      values belong. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume each embedding_indices belongs to a different sample. Both int32
      and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will
      be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      sample_indices. If equal to 0, the corresponding feature is considered
      to be a non-sequence feature; if greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      then we assume a list of all zeroes.
    num_features: A list of integers, the size of which is equal to
      sample_indices. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_indices is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
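

# Illustrative sketch: enqueueing two SparseTensors of ids, each looked up in
# its own table. Empty weight tensors rely on the documented behavior that a
# first dimension of size 0 means all weights are 1. The names are
# hypothetical.
def _example_enqueue_sparse_tensor_batch(sp_ids_a, sp_ids_b):
  empty_weights = array_ops.zeros([0], dtype=dtypes.float32)
  return enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=[sp_ids_a.indices, sp_ids_b.indices],
      embedding_indices=[sp_ids_a.values, sp_ids_b.values],
      aggregation_weights=[empty_weights, empty_weights],
      table_ids=[0, 1],
      device_ordinal=0)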


# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_splits: A list of rank 1 Tensors specifying the break points for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to ids.values in embedding_lookup(), when ids is
      a RaggedTensor. Both int32 and int64 are allowed and will be converted
      to int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to the values field of a
      RaggedTensor with the same row_splits as ids in embedding_lookup(), when
      ids is a RaggedTensor. Both float32 and float64 are allowed and will be
      converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_splits, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      sample_splits. If equal to 0, the corresponding feature is considered
      to be a non-sequence feature; if greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      then we assume a list of all zeroes.
    num_features: A list of integers, the size of which must be equal to
      sample_splits. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_splits is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to
      'unspecified', the mode set in TPUEmbeddingConfiguration is used,
      otherwise mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
      sample_splits=sample_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
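

# Illustrative sketch: enqueueing ids held in a tf.RaggedTensor by passing its
# row_splits and flat values, with explicit unit weights. The helper and
# variable names are hypothetical.
def _example_enqueue_ragged_tensor_batch(ragged_ids):
  weights = array_ops.ones_like(ragged_ids.values, dtype=dtypes.float32)
  return enqueue_tpu_embedding_ragged_tensor_batch(
      sample_splits=[ragged_ids.row_splits],
      embedding_indices=[ragged_ids.values],
      aggregation_weights=[weights],
      table_ids=[0],
      device_ordinal=0)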


def enqueue_tpu_embedding_arbitrary_tensor_batch(sample_indices_or_row_splits,
                                                 embedding_indices,
                                                 aggregation_weights,
                                                 device_ordinal,
                                                 combiners=None,
                                                 mode_override=None,
                                                 name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices_or_row_splits: A list of rank 1 or 2 Tensors. When rank 2,
      the tensors specify the training example to which the corresponding
      embedding_indices and aggregation_weights values belong. If the size of
      its first dimension is 0, we assume each embedding_indices belongs to a
      different sample. When rank 1, the tensors specify the row splits for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. When enqueuing an N-D ragged tensor, only the last
      dimension is allowed to be ragged, and the row splits is a 1-D dense
      tensor. When empty, we assume a dense tensor is passed to the op. Both
      int32 and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. Both float32 and float64 are allowed and
      will be converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to
      'unspecified', the mode set in TPUEmbeddingConfiguration is used,
      otherwise mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingArbitraryTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
      sample_indices_or_row_splits=sample_indices_or_row_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__)
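

# Illustrative sketch: the "arbitrary" variant accepts dense, sparse, or
# ragged layouts per table. Here a dense rank 1 batch of ids (one id per
# example) is enqueued by passing an empty sample_indices_or_row_splits
# tensor, per the docstring above. The names are hypothetical.
def _example_enqueue_dense_arbitrary_batch(dense_ids):
  empty_layout = array_ops.zeros([0], dtype=dtypes.int32)
  weights = array_ops.ones_like(dense_ids, dtype=dtypes.float32)
  return enqueue_tpu_embedding_arbitrary_tensor_batch(
      sample_indices_or_row_splits=[empty_layout],
      embedding_indices=[dense_ids],
      aggregation_weights=[weights],
      device_ordinal=0)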