# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util.tf_export import tf_export


def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
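  # A single group containing every replica, e.g. [[0, 1, 2, 3]] when
  # num_shards == 4.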
  group_assignment = [list(range(num_shards))]
  return group_assignment


def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the sub-group
      size (group_assignment.get_shape()[1]).
    group_assignment: Optional 2d int32 lists with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
      in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is concatenated by data from different replicas.
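
  Example (a minimal sketch, assuming a two-replica computation built with
  `tpu.replicate` in which each replica holds an `x` of shape `[2, 8]`):

    # Each replica splits `x` into two [2, 4] pieces along dimension 1,
    # exchanges pieces with its peer, and concatenates the received pieces
    # along dimension 0, so every replica ends up with a [4, 4] result.
    y = all_to_all(
        x,
        concat_dimension=0,
        split_dimension=1,
        split_count=2,
        group_assignment=[[0, 1]])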
61  """
62  if group_assignment is None:
63    group_assignment = _create_default_group_assignment()
64  return gen_tpu_ops.all_to_all(
65      x,
66      group_assignment,
67      concat_dimension=concat_dimension,
68      split_dimension=split_dimension,
69      split_count=split_count,
70      name=name)
71
72
@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]


@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 lists with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica ids
      in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
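
  Example (a minimal sketch, assuming a four-replica computation in which
  every replica holds a tensor `x`):

    # Sum within the replica pairs {0, 2} and {1, 3} instead of across
    # all four replicas.
    y = cross_replica_sum(x, group_assignment=[[0, 2], [1, 3]])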
101  """
102  if group_assignment is None:
103    group_assignment = _create_default_group_assignment()
104
105  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
106
107
def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id may appear at most once in the source column and at most
  once in the target column. For any replica id not in the target column,
  this op returns a zero tensor with the same shape and dtype as the input x.

  For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: 2d int lists with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
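
  Example (a minimal sketch, equivalent to the prose example above, assuming
  a four-replica computation):

    # Rotate values one replica to the right; replica 0, which is not a
    # target, receives zeros.
    y = collective_permute(x, [[0, 1], [1, 2], [2, 3]])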
130  """
131  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
132
133
@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with source/target pairs reversed. The gradient with respect
  # to input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross-replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]


# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are fed via infeed. Remove when
# both types are supported.

_SUPPORTED_INFEED_DTYPES = set([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32
])


@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
                                         table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training mode. "
        "This is not expected. Consider putting your Optimizer.minimize code "
        "behind the training mode condition check. For Estimator, you can "
        "do \n\n"
        "    if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "        train_op = opt.minimize(loss)\n"
        "\n")

  if lookup_id < 0 or lookup_id >= len(table_gradients):
    raise RuntimeError(
        "Gradients (w.r.t. TPUEmbedding activations) generated for table_id {} "
        "and lookup_id {}. The lookup_id attribute is outside the expected "
        "range [0, {}).".format(table_id, lookup_id, len(table_gradients)))

  if table_gradients[lookup_id] is not None:
    raise RuntimeError(
        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
        "table_id {} and lookup_id {}. This happens when there are multiple "
        "calls to tf.gradients in a graph containing TPU embeddings. "
        "TF cannot identify which gradient to use for updating the embedding "
        "variables. Consider placing tf.stop_gradient around tensors where "
        "variable update is not required. Previous gradients were generated by "
        "the following callstack: {}.".format(
            table_id, lookup_id, table_gradients[lookup_id].op.traceback))

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we return zeros(shape=[1]). The actual gradient w.r.t. the
      # embedding activations (grad_wrt_activations) has the same shape as
      # the activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
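
  Example (a minimal sketch; the shape and dtype must match what the host
  enqueues through the corresponding infeed enqueue op):

    # Inside the TPU computation, pull one [128, 8] float32 tensor from
    # the infeed queue.
    x = infeed_dequeue(tf.float32, shape=[128, 8])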
222  """
223  if dtype not in _SUPPORTED_INFEED_DTYPES:
224    raise TypeError(
225        "Operation '{}' has type {} which is not a supported TPU infeed type. "
226        "Supported types are: {}".format(name, dtype,
227                                         list(_SUPPORTED_INFEED_DTYPES)))
228
229  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
230
231
# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types of
      each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
      shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
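
  Example (a minimal sketch; the host is assumed to enqueue a matching
  (images, labels) tuple):

    # Dequeue a batch of images and labels in one op.
    images, labels = infeed_dequeue_tuple(
        dtypes=[tf.float32, tf.int32],
        shapes=[[128, 28, 28, 1], [128]])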
249  """
250  for dtype in dtypes:
251    if dtype not in _SUPPORTED_INFEED_DTYPES:
252      raise TypeError(
253          "{} is not a supported TPU infeed type. Supported types are: "
254          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
255  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
256
257
258# pylint: enable=redefined-outer-name
259
260
# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      This argument has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
      with respect to the embedding activations. The embedding tables are
      updated from these gradients via the optimizers specified in the TPU
      embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag: see the comments in
      //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
      Multiple tables can share the same dynamic learning rate tag as
      specified in the configuration. If the learning rates for all tables
      are constant, this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
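
  Example (a minimal sketch; `activations` as returned by the embedding
  receive op and a serialized `config` proto are assumed to already exist):

    # Feed the loss gradients w.r.t. the embedding activations back to the
    # TPU embedding layer so the tables can be updated.
    grads = tf.gradients(loss, activations)
    send_op = send_tpu_embedding_gradients(inputs=grads, config=config)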
286  """
287  if learning_rates is None:
288    learning_rates = []
289  return gen_tpu_ops.send_tpu_embedding_gradients(
290      inputs=inputs, learning_rates=learning_rates, config=config, name=name)
291
292
293send_tpu_embedding_gradients.__doc__ = (
294    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
295
296
# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
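
  Example (a minimal sketch; assumes a configuration with a single embedding
  table and a batch of four samples, each contributing one id):

    # One 1D tensor of ids per table, enqueued to TPU core 0.
    enqueue_op = enqueue_tpu_embedding_integer_batch(
        batch=[tf.constant([3, 1, 4, 1])],
        device_ordinal=0)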
318  """
319  if mode_override is None:
320    mode_override = "unspecified"
321  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
322      batch=batch,
323      device_ordinal=device_ordinal,
324      mode_override=mode_override,
325      name=name)
326
327
328enqueue_tpu_embedding_integer_batch.__doc__ = (
329    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
330
331
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal b * nf +
      f, where nf is the number of features from the corresponding table, f is
      in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed,
      and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e.,
      per (training example, feature) -- aggregation weights. Both float32 and
      float64 are allowed and will be converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
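
  Example (a minimal sketch; assumes a single table with one feature
  (nf = 1) and batch size 2, where sample 0 looks up ids {5, 9} and
  sample 1 looks up id {2}):

    # With nf = 1, sample_indices[i] is simply the batch index b.
    enqueue_op = enqueue_tpu_embedding_sparse_batch(
        sample_indices=[tf.constant([0, 0, 1])],
        embedding_indices=[tf.constant([5, 9, 2])],
        aggregation_weights=[tf.constant([1.0, 1.0, 1.0])],
        device_ordinal=0,
        combiners=["sum"])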
372  """
373  if mode_override is None:
374    mode_override = "unspecified"
375  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
376      sample_indices=sample_indices,
377      embedding_indices=embedding_indices,
378      aggregation_weights=aggregation_weights,
379      device_ordinal=device_ordinal,
380      combiners=combiners,
381      mode_override=mode_override,
382      name=name)
383
384
385enqueue_tpu_embedding_sparse_batch.__doc__ = (
386    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
387
388
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example to
      which the corresponding embedding_indices and aggregation_weights values
      belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). If
      the size of its first dimension is 0, we assume each embedding_indices
      belongs to a different sample. Both int32 and int64 are allowed and will
      be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both
      int32 and int64 are allowed and will be converted to int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will be
      converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      that of sample_indices. If equal to 0, the corresponding feature is
      considered to be a non-sequence feature. If greater than 0, the
      corresponding feature is a sequence feature with the given maximal
      length. If None, then we assume a list of all zeroes.
    num_features: A list of integers, the size of which is equal to that of
      sample_indices. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_indices is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
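
  Example (a minimal sketch; assumes a single table whose ids arrive as a
  `tf.SparseTensor` `sp_ids` with batch size 2, e.g. sample 0 holds ids
  {5, 9} and sample 1 holds id {2}):

    # Pass the SparseTensor components straight through; an empty weights
    # tensor means every weight is 1.
    enqueue_op = enqueue_tpu_embedding_sparse_tensor_batch(
        sample_indices=[sp_ids.indices],
        embedding_indices=[sp_ids.values],
        aggregation_weights=[tf.zeros([0], tf.float32)],
        table_ids=[0],
        device_ordinal=0)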
449  """
450  if mode_override is None:
451    mode_override = "unspecified"
452  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
453      sample_indices=sample_indices,
454      embedding_indices=embedding_indices,
455      aggregation_weights=aggregation_weights,
456      table_ids=table_ids,
457      device_ordinal=device_ordinal,
458      max_sequence_lengths=max_sequence_lengths,
459      combiners=combiners,
460      mode_override=mode_override,
461      num_features=num_features,
462      name=name)
463
464
465enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
466    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
467
468
# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_splits: A list of rank 1 Tensors specifying the break points for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to ids.values in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to the values field of a
      RaggedTensor with the same row_splits as ids in embedding_lookup(), when
      ids is a RaggedTensor. Both float32 and float64 are allowed and will be
      converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_splits, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      that of sample_splits. If equal to 0, the corresponding feature is
      considered to be a non-sequence feature. If greater than 0, the
      corresponding feature is a sequence feature with the given maximal
      length. If None, then we assume a list of all zeroes.
    num_features: A list of integers, the size of which must be equal to that
      of sample_splits. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_splits is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
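
  Example (a minimal sketch; assumes a single table and batch size 2, where
  sample 0 looks up ids {5, 9} and sample 1 looks up id {2}):

    # tf.ragged.constant([[5, 9], [2]]) has row_splits [0, 2, 3] and
    # values [5, 9, 2]; pass those components directly.
    enqueue_op = enqueue_tpu_embedding_ragged_tensor_batch(
        sample_splits=[tf.constant([0, 2, 3])],
        embedding_indices=[tf.constant([5, 9, 2])],
        aggregation_weights=[tf.constant([1.0, 1.0, 1.0])],
        table_ids=[0],
        device_ordinal=0)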
529  """
530  if mode_override is None:
531    mode_override = "unspecified"
532  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
533      sample_splits=sample_splits,
534      embedding_indices=embedding_indices,
535      aggregation_weights=aggregation_weights,
536      table_ids=table_ids,
537      device_ordinal=device_ordinal,
538      max_sequence_lengths=max_sequence_lengths,
539      combiners=combiners,
540      mode_override=mode_override,
541      num_features=num_features,
542      name=name)
543
544
545enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
546    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
547
548
def enqueue_tpu_embedding_arbitrary_tensor_batch(sample_indices_or_row_splits,
                                                 embedding_indices,
                                                 aggregation_weights,
                                                 device_ordinal,
                                                 combiners=None,
                                                 mode_override=None,
                                                 name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices_or_row_splits: A list of rank 1 or 2 Tensors. When rank 2,
      the tensors specify the training example to which the corresponding
      embedding_indices and aggregation_weights values belong. If the size of
      its first dimension is 0, we assume each embedding_indices belongs to a
      different sample. Both int32 and int64 are allowed and will be converted
      to int32 internally. When rank 1, the tensors specify the row splits for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. When enqueuing an N-D ragged tensor, only the last
      dimension is allowed to be ragged, and the row splits are given as a 1-D
      dense tensor. When empty, we assume a dense tensor is passed to the op.
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. Both float32 and float64 are allowed and will
      be converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise mode_override
      is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingArbitraryTensorBatch operation.
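
  Example (a minimal sketch; the same ragged feature as in
  enqueue_tpu_embedding_ragged_tensor_batch above, expressed with rank 1
  row splits):

    # Row splits [0, 2, 3] describe two rows: ids {5, 9} and id {2}.
    enqueue_op = enqueue_tpu_embedding_arbitrary_tensor_batch(
        sample_indices_or_row_splits=[tf.constant([0, 2, 3])],
        embedding_indices=[tf.constant([5, 9, 2])],
        aggregation_weights=[tf.constant([1.0, 1.0, 1.0])],
        device_ordinal=0)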
594  """
595  if mode_override is None:
596    mode_override = "unspecified"
597  return gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
598      sample_indices_or_row_splits=sample_indices_or_row_splits,
599      embedding_indices=embedding_indices,
600      aggregation_weights=aggregation_weights,
601      device_ordinal=device_ordinal,
602      combiners=combiners,
603      mode_override=mode_override,
604      name=name)
605
606
607enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__ = (
608    gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__)
609