# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank)) if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _colocate_with(param):
  if ops.inside_function() and hasattr(param, "handle"):
    # `ops.colocate_with` would hard-code a device string if `param.device` is
    # known, which would then break serving. We capture the handle here so
    # that it produces a tensor without a device.
    return ops.colocate_with(ops.get_default_graph().capture(param.handle))
  else:
    return ops.colocate_with(param)


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is done
  through the `transform_fn` argument. If provided, the function is applied to
  each partitioned tensor of retrieved embeddings, colocated with the
  embeddings. This function will be called with a single `Tensor` argument of
  the same type as the `params` tensor and should return a `Tensor`. The shape
  of the argument will be the same as `params` except for the size of the
  first dimension. The first dimension of the result's shape must be the same
  size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.

  Raises:
    ValueError: If `params` is empty.
  """
  if params is None:
    raise ValueError("params must be specified")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("Length of params is currently 0. "
                     "Need at least one param.")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.BaseResourceVariable)
        for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with _colocate_with(params[0]):
        result = _clip(
            array_ops.gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with _colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
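        # As a worked example of this "div" assignment (illustrative numbers
        # only): with num_total_ids = 13 and np = 5, ids_per_partition = 2 and
        # extras = 3, so ids 0..8 land in partitions 0..2 (three ids each) and
        # ids 9..12 land in partitions 3..4 (two ids each), matching the "div"
        # example in the embedding_lookup docstring.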
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in range(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value,
                                               flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in range(np):
            param_p_dim = tensor_shape.dimension_value(
                params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with _colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError(
            f"Unrecognized partition strategy: {partition_strategy}. "
            "Must be one of either `mod` or `div`.")

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in range(np):
        pids = gather_ids[p]
        with ops.device_v2(None):
          with _colocate_with(params[p]):
            result = array_ops.gather(params[p], pids)
            if transform_fn:
              # If transform_fn is provided, the clip_by_norm precedes
              # the transform and hence must be co-located. See below
              # for the counterpart if transform_fn is not provided.
              result = transform_fn(_clip(result, pids, max_norm))
            partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with _colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for the given `ids` from a list of tensors.

  This function is used to perform parallel lookups on the list of tensors in
  `params`. It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor. `params` may be
  a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
  with a partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  If the input ids are ragged tensors, partitioned variables are not supported
  and the partition strategy and the max_norm are ignored.
  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` or a `RaggedTensor` with type `int32` or `int64` containing
      the ids to be looked up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported.
      Default is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range. If assigned to
      GPU, out-of-bound indices result in safe but unspecified behavior, which
      may include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` or a `RaggedTensor`, depending on the input, with the same type
    as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
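
  For example (an illustrative sketch using a single `params` tensor, so the
  partition strategy does not matter), assuming eager execution:

  ```python
  params = tf.constant([[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.]])
  ids = tf.constant([0, 3, 4])
  tf.compat.v1.nn.embedding_lookup(params, ids)
  # -> [[1., 2.], [7., 8.], [9., 10.]]
  ```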
323 """ 324 if isinstance(ids, ragged_tensor.RaggedTensor): 325 return embedding_lookup_ragged(params, ids, 326 partition_strategy=partition_strategy, 327 max_norm=max_norm, 328 name=name) 329 330 return _embedding_lookup_and_transform( 331 params=params, 332 ids=ids, 333 partition_strategy=partition_strategy, 334 name=name, 335 max_norm=max_norm, 336 transform_fn=None) 337 338 339@tf_export("nn.embedding_lookup", v1=[]) 340@dispatch.add_dispatch_support 341def embedding_lookup_v2(params, ids, max_norm=None, name=None): 342 """Looks up embeddings for the given `ids` from a list of tensors. 343 344 This function is used to perform parallel lookups on the list of tensors in 345 `params`. It is a generalization of `tf.gather`, where `params` is 346 interpreted as a partitioning of a large embedding tensor. 347 348 If `len(params) > 1`, each element `id` of `ids` is partitioned between the 349 elements of `params` according to the "div" partition strategy, which means we 350 assign ids to partitions in a contiguous manner. For instance, 13 ids are 351 split across 5 partitions as: 352 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 353 354 If the id space does not evenly divide the number of partitions, each of the 355 first `(max_id + 1) % len(params)` partitions will be assigned one more id. 356 357 The results of the lookup are concatenated into a dense 358 tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. 359 360 Args: 361 params: A single tensor representing the complete embedding tensor, or a 362 list of tensors all of same shape except for the first dimension, 363 representing sharded embedding tensors following "div" partition strategy. 364 ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked 365 up in `params`. 366 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 367 than this value. 368 name: A name for the operation (optional). 369 370 Returns: 371 A `Tensor` with the same type as the tensors in `params`. 372 373 For instance, if `params` is a 5x2 matrix: 374 375 ```python 376 [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]] 377 ``` 378 379 or a list of matrices: 380 381 ```python 382 params[0]: [[1, 2], [3, 4]] 383 params[1]: [[5, 6], [7, 8]] 384 params[2]: [[9, 10]] 385 ``` 386 387 and `ids` is: 388 389 ```python 390 [0, 3, 4] 391 ``` 392 393 The output will be a 3x2 matrix: 394 395 ```python 396 [[1, 2], [7, 8], [9, 10]] 397 ``` 398 399 Raises: 400 ValueError: If `params` is empty. 401 """ 402 return embedding_lookup(params, ids, "div", name, max_norm=max_norm) 403 404 405@tf_export(v1=["nn.embedding_lookup_sparse"]) 406@dispatch.add_dispatch_support 407def embedding_lookup_sparse(params, 408 sp_ids, 409 sp_weights, 410 partition_strategy="mod", 411 name=None, 412 combiner=None, 413 max_norm=None): 414 """Looks up embeddings for the given ids and weights from a list of tensors. 415 416 This op assumes that there is at least one id for each row in the dense tensor 417 represented by sp_ids (i.e. there are no rows with empty features), and that 418 all the indices of sp_ids are in canonical row-major order. 419 420 `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2. 421 Embeddings are always aggregated along the last dimension. 422 423 It also assumes that all id values lie in the range [0, p0), where p0 424 is the sum of the size of params along dimension 0. 

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None`
      to indicate all weights should be taken to be 1. If specified,
      `sp_weights` must have exactly the same shape and indices as `sp_ids`.
    partition_strategy: A string specifying the partitioning strategy,
      relevant if `len(params) > 1`. Currently `"div"` and `"mod"` are
      supported. Default is `"mod"`. See `tf.nn.embedding_lookup` for more
      details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. "sum" computes the weighted sum of the
      embedding results for each row. "mean" is the weighted sum divided by
      the total weight. "sqrtn" is the weighted sum divided by the square root
      of the sum of the squares of the weights. Defaults to `mean`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id 0, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError(
        f"combiner must be one of 'mean', 'sqrtn' or 'sum', got {combiner}")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError(f"sp_ids must be SparseTensor, got {type(sp_ids)}")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor, "
                      f"got {type(sp_weights)}")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]

    ids = sp_ids.values
    ids, idx = array_ops.unique(ids)

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if not ignore_weights:
      if segment_ids.dtype != dtypes.int32:
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)

      weights = sp_weights.values
      embeddings = array_ops.gather(embeddings, idx)

      original_dtype = embeddings.dtype
      if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
        # Cast low-precision embeddings to float32 during the computation to
        # avoid numerical issues.
        embeddings = math_ops.cast(embeddings, dtypes.float32)
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Reshape weights to allow broadcast
      ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0)
      ones = array_ops.ones(ones_shape, dtype=dtypes.int32)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
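      # (Illustrative shapes: the reshape above turns weights of shape [n]
      # into [n, 1, ..., 1], with one trailing 1 per embedding dimension, so
      # the weights broadcast against embeddings of shape [n, d1, ..., dk].)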
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div_no_nan(embeddings, weight_sum_sqrt,
                                         name=name)
      else:
        assert False, "Unrecognized combiner"
      if embeddings.dtype != original_dtype:
        embeddings = math_ops.cast(embeddings, original_dtype)
    else:
      if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings


@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Looks up embeddings for the given ids and weights from a list of tensors.

  This op assumes that there is at least one id for each row in the dense
  tensor represented by sp_ids (i.e. there are no rows with empty features),
  and that all the indices of sp_ids are in canonical row-major order.

  `sp_ids` and `sp_weights` (if not None) are `SparseTensor`s with rank of 2.
  Embeddings are always aggregated along the last dimension.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  If `len(params) > 1`, each element of `sp_ids` is partitioned between the
  elements of `params` according to the "div" partition strategy, which means
  we assign ids to partitions in a contiguous manner. For instance, 13 ids are
  split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not evenly divide the number of partitions, each of the
  first `(max_id + 1) % len(params)` partitions will be assigned one more id.

  Args:
    params: A single tensor representing the complete embedding tensor, or a
      list of tensors all of same shape except for the first dimension,
      representing sharded embedding tensors following "div" partition
      strategy.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None`
      to indicate all weights should be taken to be 1. If specified,
      `sp_weights` must have exactly the same shape and indices as `sp_ids`.
    combiner: A string specifying the reduction op.
Currently "mean", "sqrtn" 630 and "sum" are supported. "sum" computes the weighted sum of the embedding 631 results for each row. "mean" is the weighted sum divided by the total 632 weight. "sqrtn" is the weighted sum divided by the square root of the sum 633 of the squares of the weights. Defaults to `mean`. 634 max_norm: If not `None`, each embedding is clipped if its l2-norm is larger 635 than this value, before combining. 636 name: Optional name for the op. 637 638 Returns: 639 A dense tensor representing the combined embeddings for the 640 sparse ids. For each row in the dense tensor represented by `sp_ids`, the op 641 looks up the embeddings for all ids in that row, multiplies them by the 642 corresponding weight, and combines these embeddings as specified. 643 644 In other words, if 645 646 `shape(combined params) = [p0, p1, ..., pm]` 647 648 and 649 650 `shape(sp_ids) = shape(sp_weights) = [d0, d1]` 651 652 then 653 654 `shape(output) = [d0, p1, ..., pm]`. 655 656 For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are 657 658 ```python 659 [0, 0]: id 1, weight 2.0 660 [0, 1]: id 3, weight 0.5 661 [1, 0]: id 0, weight 1.0 662 [2, 3]: id 1, weight 3.0 663 ``` 664 665 with `combiner`="mean", then the output will be a 3x20 matrix where 666 667 ```python 668 output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) 669 output[1, :] = (params[0, :] * 1.0) / 1.0 670 output[2, :] = (params[1, :] * 3.0) / 3.0 671 ``` 672 673 Raises: 674 TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is 675 neither `None` nor `SparseTensor`. 676 ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}. 677 """ 678 return embedding_lookup_sparse(params, sp_ids, sp_weights, "div", name, 679 combiner, max_norm) 680 681 682@tf_export("nn.safe_embedding_lookup_sparse", v1=[]) 683@dispatch.add_dispatch_support 684def safe_embedding_lookup_sparse_v2(embedding_weights, 685 sparse_ids, 686 sparse_weights=None, 687 combiner="mean", 688 default_id=None, 689 max_norm=None, 690 name=None): 691 """Lookup embedding results, accounting for invalid IDs and empty features. 692 693 The partitioned embedding in `embedding_weights` must all be the same shape 694 except for the first dimension. The first dimension is allowed to vary as the 695 vocabulary size is not necessarily a multiple of num of shards. 696 697 Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs 698 with non-positive weight. For an entry with no features, the embedding vector 699 for `default_id` is returned, or the 0-vector if `default_id` is not supplied. 700 701 The ids and weights may be multi-dimensional. Embeddings are always aggregated 702 along the last dimension. 703 704 If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned 705 between the elements of `embedding_weights` according to the "div" partition 706 strategy, which means we assign ids to partitions in a contiguous manner. For 707 instance, 13 ids are split across 5 partitions as: 708 `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`. 709 710 If the id space does not evenly divide the number of partitions, each of the 711 first `(max_id + 1) % len(embedding_weights)` partitions will be assigned one 712 more id. 713 714 Args: 715 embedding_weights: A single tensor representing the complete embedding 716 tensor, or a list of tensors all of same shape except for the first 717 dimension, representing sharded embedding tensors following "div" 718 partition strategy. 
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features. Defaults to
      0-vector.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.
    name: A name for this operation (optional).

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
    the op looks up the embeddings for all ids in that row, multiplies them by
    the corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id -1, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    and `default_id` is 0,

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      sparse_weights=sparse_weights,
      combiner=combiner,
      default_id=default_id,
      name=name,
      partition_strategy="div",
      max_norm=max_norm)


@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner="mean",
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding tensors in `embedding_weights` must all be the
  same shape except for the first dimension. The first dimension is allowed to
  vary as the vocabulary size is not necessarily a multiple of `P`.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.compat.v1.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
  with non-positive weight. For an entry with no features, the embedding
  vector for `default_id` is returned, or the 0-vector if `default_id` is not
  supplied.

  The ids and weights may be multi-dimensional. Embeddings are always
  aggregated along the last dimension.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0.
      Each element must be appropriately sized for the given
      `partition_strategy`.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the
    op looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

    ```python
    [0, 0]: id 1, weight 2.0
    [0, 1]: id 3, weight 0.5
    [1, 0]: id -1, weight 1.0
    [2, 3]: id 1, weight 3.0
    ```

    and `default_id` is 0,

    with `combiner`="mean", then the output will be a 3x20 matrix where

    ```python
    output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
    output[1, :] = (params[0, :] * 1.0) / 1.0
    output[2, :] = (params[1, :] * 3.0) / 3.0
    ```

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  if embedding_weights is None:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError(f"Missing embedding_weights {embedding_weights}.")

  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, "embedding_lookup", embedding_weights +
                      [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
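    # (Illustrative shapes: sparse ids with dense_shape [d0, d1, d2] are
    # reshaped to dense_shape [d0 * d1, d2] so that embedding_lookup_sparse
    # sees a rank-2 SparseTensor; the result is reshaped back further below.)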
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(
          sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids into the higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(
        tensor_shape.unknown_shape(
            (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
                result.get_shape()[1:]))
    return final_result


def embedding_lookup_ragged(embedding_weights,
                            ragged_ids,
                            partition_strategy="mod",
                            max_norm=None,
                            name=None):
  """Look up the ragged ids in a list of embedding tensors.

  Args:
    embedding_weights: A tensor representing the complete embedding tensor
      having the shape [e1, ...eM].
    ragged_ids: A `RaggedTensor` with type `int32` or `int64` containing the
      ids to be looked up in `embedding_weights` of shape [r0, ..rN]. Values
      must be in the range `[0, embedding_weights.shape[0]]`.
    partition_strategy: A string specifying the partitioning strategy.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.
    name: A name for the operation (optional).

  Returns:
    A ragged tensor of shape [r0, r1, ...rN, e1, ...eM].

  Raises:
    ValueError: If the embedding_weights is empty or the ragged_ids is not a
      RaggedTensor.
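
  For example (an illustrative sketch; `tf.nn.embedding_lookup` dispatches to
  this function when `ids` is ragged), assuming eager execution:

  ```python
  embedding_weights = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
  ragged_ids = tf.ragged.constant([[0, 2], [1]])
  tf.nn.embedding_lookup(embedding_weights, ragged_ids)
  # -> <tf.RaggedTensor [[[1., 2.], [5., 6.]], [[3., 4.]]]>
  ```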
978 """ 979 if embedding_weights is None: 980 raise ValueError("The embedding weights must be specified.") 981 if isinstance(embedding_weights, (list, tuple)) and not embedding_weights: 982 raise ValueError("The embedding weights should not be empty.") 983 if ragged_ids.dtype != dtypes.int32 and ragged_ids.dtype != dtypes.int64: 984 raise ValueError("The values contained by the inputs have type " 985 f"{str(ragged_ids.dtype)}" 986 " and cannot be processed. All values" 987 " should be indices, either of type `in32` or `int64`.") 988 989 with ops.name_scope(name, "embedding_lookup_ragged") as name: 990 looked_up_ragged = ragged_functional_ops.map_flat_values( 991 embedding_lookup, 992 params=embedding_weights, 993 ids=ragged_ids, 994 partition_strategy=partition_strategy, 995 max_norm=max_norm) 996 997 return looked_up_ragged 998 999 1000def _prune_invalid_ids(sparse_ids, sparse_weights): 1001 """Prune invalid IDs (< 0) from the input ids and weights.""" 1002 is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) 1003 if sparse_weights is not None: 1004 is_id_valid = math_ops.logical_and( 1005 is_id_valid, 1006 array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) 1007 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) 1008 if sparse_weights is not None: 1009 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) 1010 return sparse_ids, sparse_weights 1011 1012 1013def _prune_invalid_weights(sparse_ids, sparse_weights): 1014 """Prune invalid weights (< 0) from the input ids and weights.""" 1015 if sparse_weights is not None: 1016 is_weights_valid = math_ops.greater(sparse_weights.values, 0) 1017 sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) 1018 sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) 1019 return sparse_ids, sparse_weights 1020