# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank)) if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))

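
# A minimal usage sketch for the `_clip` helper above, assuming eager execution;
# the values and the `_example_clip_usage` name are illustrative only and are
# not part of the library.
def _example_clip_usage():
  import tensorflow as tf  # local import keeps the sketch self-contained
  params = tf.constant([[3.0, 4.0], [6.0, 8.0]])  # two gathered embeddings
  ids = tf.constant([0, 1])
  # Each embedding row is clipped to an l2-norm of at most 1.0; for example,
  # the first row [3.0, 4.0] (norm 5.0) becomes [0.6, 0.8].
  return _clip(params, ids, max_norm=1.0)
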

def _colocate_with(param):
  if ops.inside_function() and hasattr(param, "handle"):
    # The `ops.colocate_with` will hard-code a device string if `param.device`
    # is known, which will then break serving. We capture it here so that it
    # produces a tensor without a device.
    return ops.colocate_with(ops.get_default_graph().capture(param.handle))
  else:
    return ops.colocate_with(param)


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None:
    raise ValueError("params must be specified")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("Length of params is currently 0. "
                     "Need at least one param.")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.BaseResourceVariable)
        for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with _colocate_with(params[0]):
        result = _clip(
            array_ops.gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with _colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in range(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in range(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with _colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError(
            f"Unrecognized partition strategy: {partition_strategy}. "
            "Must be either `mod` or `div`.")

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in range(np):
        pids = gather_ids[p]
        with ops.device_v2(None):
          with _colocate_with(params[p]):
            result = array_ops.gather(params[p], pids)
            if transform_fn:
              # If transform_fn is provided, the clip_by_norm precedes
              # the transform and hence must be co-located. See below
              # for the counterpart if transform_fn is not provided.
              result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with _colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret

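
# A plain-Python sketch of the "div" strategy arithmetic used above, assuming
# 13 ids spread over 5 partitions (the example quoted in the docstrings below).
# The `_example_div_partition` helper is illustrative only; it mirrors the
# tensor math with ordinary integers for readability.
def _example_div_partition(num_total_ids=13, num_partitions=5):
  ids_per_partition = num_total_ids // num_partitions  # 2 ids per partition
  extras = num_total_ids % num_partitions  # the first 3 partitions get one more
  assignments = []
  for flat_id in range(num_total_ids):
    p = max(flat_id // (ids_per_partition + 1),
            (flat_id - extras) // ids_per_partition)
    new_id = (flat_id % (ids_per_partition + 1) if p < extras
              else (flat_id - extras) % ids_per_partition)
    assignments.append((flat_id, p, new_id))
  # Groups ids as [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
  return assignments
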

@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`.  It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
  with a partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  If the input ids are ragged tensors, partition variables are not supported and
  the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` or a 'RaggedTensor' with type `int32` or `int64` containing
      the ids to be looked up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range.  If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` or a 'RaggedTensor', depending on the input, with the same type
    as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  if isinstance(ids, ragged_tensor.RaggedTensor):
    return embedding_lookup_ragged(params, ids,
                                   partition_strategy=partition_strategy,
                                   max_norm=max_norm,
                                   name=name)

  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)

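
# A hedged usage sketch for the lookup above (TF1-style API), assuming eager
# execution; the shard values and the `_example_embedding_lookup_v1` name are
# invented purely for illustration.
def _example_embedding_lookup_v1():
  import tensorflow as tf  # local import keeps the sketch self-contained
  # Two shards of a 5x2 table under the default "mod" strategy: shard 0 holds
  # rows {0, 2, 4} and shard 1 holds rows {1, 3}.
  shard0 = tf.constant([[1., 2.], [5., 6.], [9., 10.]])
  shard1 = tf.constant([[3., 4.], [7., 8.]])
  ids = tf.constant([0, 3, 4])
  result = tf.compat.v1.nn.embedding_lookup([shard0, shard1], ids)
  return result  # expected: [[1., 2.], [7., 8.], [9., 10.]]
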

@tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`.  It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means we
  assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition strategy.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

    For instance, if `params` is a 5x2 matrix:

    ```python
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    ```

    or a list of matrices:

    ```python
    params[0]: [[1, 2], [3, 4]]
    params[1]: [[5, 6], [7, 8]]
    params[2]: [[9, 10]]
    ```

    and `ids` is:

    ```python
    [0, 3, 4]
    ```

    The output will be a 3x2 matrix:

    ```python
    [[1, 2], [7, 8], [9, 10]]
    ```

  Raises:
    ValueError: If `params` is empty.
  """
  return embedding_lookup(params, ids, "div", name, max_norm=max_norm)

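
# A hedged usage sketch mirroring the docstring example above, assuming eager
# execution; `_example_embedding_lookup_v2` is illustrative and not exported.
def _example_embedding_lookup_v2():
  import tensorflow as tf  # local import keeps the sketch self-contained
  params = [
      tf.constant([[1., 2.], [3., 4.]]),  # rows 0-1 under the "div" strategy
      tf.constant([[5., 6.], [7., 8.]]),  # rows 2-3
      tf.constant([[9., 10.]]),           # row 4
  ]
  ids = tf.constant([0, 3, 4])
  result = tf.nn.embedding_lookup(params, ids)
  return result  # expected: [[1., 2.], [7., 8.], [9., 10.]]
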

@tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError(
        f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError(f"sp_ids must be SparseTensor, got {type(sp_ids)}")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError(f"sp_weights must be either None or SparseTensor, "
                      f"got {type(sp_weights)}")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]

    ids = sp_ids.values
    ids, idx = array_ops.unique(ids)

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if not ignore_weights:
      if segment_ids.dtype != dtypes.int32:
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)

      weights = sp_weights.values
      embeddings = array_ops.gather(embeddings, idx)

      original_dtype = embeddings.dtype
      if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
        # Cast low-precision embeddings to float32 during the computation to
        # avoid numerical issues.
        embeddings = math_ops.cast(embeddings, dtypes.float32)
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Reshape weights to allow broadcast
      ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0)
      ones = array_ops.ones(ones_shape, dtype=dtypes.int32)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt, name=name)
      else:
        assert False, "Unrecognized combiner"
      if embeddings.dtype != original_dtype:
        embeddings = math_ops.cast(embeddings, original_dtype)
    else:
      if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings

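
# A small sketch of the weighted "mean" combine performed above, written with
# public tf ops and made-up values; `_example_weighted_mean_combine` is
# illustrative only and assumes weights already reshaped for broadcasting.
def _example_weighted_mean_combine():
  import tensorflow as tf  # local import keeps the sketch self-contained
  embeddings = tf.constant([[2., 2.], [4., 4.], [6., 6.]])  # one row per id
  weights = tf.constant([[2.0], [0.5], [1.0]])              # one weight per id
  segment_ids = tf.constant([0, 0, 1])  # the first two ids form output row 0
  summed = tf.math.segment_sum(embeddings * weights, segment_ids)
  weight_sum = tf.math.segment_sum(weights, segment_ids)
  # Row 0: (2.0 * [2, 2] + 0.5 * [4, 4]) / 2.5 = [2.4, 2.4]; row 1 stays [6, 6].
  return tf.math.divide_no_nan(summed, weight_sum)
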

@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  If `len(params) > 1`, each element of `sp_ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means we
  assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition strategy.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the embedding
      results for each row. "mean" is the weighted sum divided by the total
      weight. "sqrtn" is the weighted sum divided by the square root of the sum
      of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    name: Optional name for the op.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  return embedding_lookup_sparse(params, sp_ids, sp_weights, "div", name,
                                 combiner, max_norm)

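
# A hedged usage sketch for tf.nn.embedding_lookup_sparse, assuming eager
# execution; the tiny table and weights are invented for illustration and
# `_example_embedding_lookup_sparse` is not part of the public API.
def _example_embedding_lookup_sparse():
  import tensorflow as tf  # local import keeps the sketch self-contained
  params = tf.constant([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
  sp_ids = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0], [2, 3]],
      values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
      dense_shape=[3, 4])
  sp_weights = tf.sparse.SparseTensor(
      indices=sp_ids.indices,
      values=tf.constant([2.0, 0.5, 1.0, 3.0]),
      dense_shape=sp_ids.dense_shape)
  # Row 0 combines ids 1 and 3 with weights 2.0 and 0.5 under "mean":
  # (2.0 * [2, 2] + 0.5 * [4, 4]) / 2.5 = [2.4, 2.4].
  return tf.nn.embedding_lookup_sparse(
      params, sp_ids, sp_weights, combiner="mean")
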

@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse_v2(embedding_weights,
                                    sparse_ids,
                                    sparse_weights=None,
                                    combiner="mean",
                                    default_id=None,
                                    max_norm=None,
                                    name=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of the number of shards.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned
  between the elements of `embedding_weights` according to the "div" partition
  strategy, which means we assign ids to partitions in a contiguous manner. For
  instance, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(embedding_weights)` partitions will be assigned one
  more id.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors following "div"
      partition strategy.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights are
      assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features. Defaults to
      0-vector.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.
    name: A name for this operation (optional).

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
    the op looks up the embeddings for all ids in that row, multiplies them by
    the corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id -1, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

761
762    with `combiner`="mean", then the output will be a 3x20 matrix where
763

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy="div",
      max_norm=max_norm)

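
# A hedged usage sketch for tf.nn.safe_embedding_lookup_sparse, assuming eager
# execution; the values and the `_example_safe_embedding_lookup_sparse` name
# are invented for illustration only.
def _example_safe_embedding_lookup_sparse():
  import tensorflow as tf  # local import keeps the sketch self-contained
  weights = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
  # Row 1 only carries an invalid id (-1) and row 2 has no ids at all, so both
  # fall back to the 0-vector because no default_id is supplied.
  sparse_ids = tf.sparse.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=tf.constant([0, 2, -1], dtype=tf.int64),
      dense_shape=[3, 2])
  result = tf.nn.safe_embedding_lookup_sparse(weights, sparse_ids)
  # Expected with the default "mean" combiner: [[2., 2.], [0., 0.], [0., 0.]].
  return result
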

@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner="mean",
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of `P`.  `embedding_weights`
  may be a `PartitionedVariable` as returned by using
  `tf.compat.v1.get_variable()` with a
  partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding vector
  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights are
      assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy. Currently
      `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id -1, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    and `default_id` is 0.

    With `combiner`="mean", the output will then be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, "embedding_lookup", embedding_weights +
                      [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids into the higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(
        tensor_shape.unknown_shape(
            (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
                result.get_shape()[1:]))
    return final_result


def embedding_lookup_ragged(embedding_weights,
                            ragged_ids,
                            partition_strategy="mod",
                            max_norm=None,
                            name=None):
  """Look up the ragged ids in a list of embedding tensors.

  Args:
    embedding_weights: A tensor representing the complete embedding tensor
      having the shape [e1, ...eM]
    ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the ids
      to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values must be
      in the range '[0, embedding_weights.shape[0]]'.
    partition_strategy: A string specifying the partitioning strategy.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional)

  Returns:
    A ragged tensor of shape [r0, r1, ...rN, e1, ...eM].

  Raises:
    ValueError: If `embedding_weights` is empty or `ragged_ids` is not a
      `RaggedTensor`.
  """
  if embedding_weights is None:
    raise ValueError("The embedding weights must be specified.")
  if isinstance(embedding_weights, (list, tuple)) and not embedding_weights:
    raise ValueError("The embedding weights should not be empty.")
  if ragged_ids.dtype != dtypes.int32 and ragged_ids.dtype != dtypes.int64:
    raise ValueError("The values contained by the inputs have type "
                     f"{str(ragged_ids.dtype)}"
                     " and cannot be processed. All values"
                     " should be indices, either of type `int32` or `int64`.")

  with ops.name_scope(name, "embedding_lookup_ragged") as name:
    looked_up_ragged = ragged_functional_ops.map_flat_values(
        embedding_lookup,
        params=embedding_weights,
        ids=ragged_ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm)

    return looked_up_ragged

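
# A hedged usage sketch of a ragged lookup, assuming eager execution; the
# values are invented and `_example_embedding_lookup_ragged` is not exported.
def _example_embedding_lookup_ragged():
  import tensorflow as tf  # local import keeps the sketch self-contained
  weights = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
  ragged_ids = tf.ragged.constant([[0, 2], [1]], dtype=tf.int64)
  # tf.nn.embedding_lookup dispatches ragged ids to the ragged path above and
  # returns a RaggedTensor of shape [2, None, 2].
  return tf.nn.embedding_lookup(weights, ragged_ids)
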

def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (<= 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
