# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Implementation of Neural Net (NN) functions."""

import math

from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.platform import device_context
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export


@tf_export("nn.log_poisson_loss")
@dispatch.add_dispatch_support
def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
  """Computes log Poisson loss given `log_input`.

  Gives the log-likelihood loss between the prediction and the target under the
  assumption that the target has a Poisson distribution.
  Caveat: By default, this is not the exact loss, but the loss minus a
  constant term [log(z!)]. That has no effect for optimization, but
  does not play well with relative loss comparisons. To compute an
  approximation of the log factorial term, specify
  compute_full_loss=True to enable Stirling's Approximation.

  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
  loss is

        -log(exp(-x) * (x^z) / z!)
      = -log(exp(-x) * (x^z)) + log(z!)
      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
          [ Note the second term is the Stirling's Approximation for log(z!).
            It is invariant to x and does not affect optimization, though
            important for correct relative loss comparisons. It is only
            computed when compute_full_loss == True. ]
      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]

  Args:
    targets: A `Tensor` of the same type and shape as `log_input`.
    log_input: A `Tensor` of type `float32` or `float64`.
    compute_full_loss: whether to compute the full loss. If false, a constant
      term is dropped in favor of more efficient optimization.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `log_input` with the componentwise
    log-Poisson losses.

  Raises:
    ValueError: If `log_input` and `targets` do not have the same shape.
  """
  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
    log_input = ops.convert_to_tensor(log_input, name="log_input")
    targets = ops.convert_to_tensor(targets, name="targets")
    try:
      targets.get_shape().assert_is_compatible_with(log_input.get_shape())
    except ValueError:
      raise ValueError(
          "`log_input` and `targets` must have the same shape, received "
          f"({log_input.get_shape()} vs {targets.get_shape()}).")

    result = math_ops.exp(log_input) - log_input * targets
    if compute_full_loss:
      # need to create constant tensors here so that their dtypes can be matched
      # to that of the targets.
      point_five = constant_op.constant(0.5, dtype=targets.dtype)
      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)

      stirling_approx = (targets * math_ops.log(targets)) - targets + (
          point_five * math_ops.log(two_pi * targets))
      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
      ones = array_ops.ones_like(targets, dtype=targets.dtype)
      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
      result += array_ops.where(cond, zeros, stirling_approx)
    return result


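# Illustrative sketch (not part of the public API): a minimal check of the
# relationship documented above. With `compute_full_loss=False`, the value
# returned for prediction x = exp(c) and target z is exp(c) - z * c, i.e. the
# Poisson negative log-likelihood up to the dropped log(z!) term. The helper
# name below is hypothetical and exists only for illustration.
def _log_poisson_loss_sketch():
  """Recomputes the default (partial) log Poisson loss by hand."""
  targets = constant_op.constant([0., 1., 3.])
  log_input = constant_op.constant([0.5, 0.0, 1.2])
  # Dropped-constant form: exp(c) - z * c, matching the derivation above.
  by_hand = math_ops.exp(log_input) - targets * log_input
  via_op = log_poisson_loss(targets, log_input, compute_full_loss=False)
  return by_hand, via_op

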
@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
    _sentinel=None,
    labels=None,
    logits=None,
    name=None):
  """See sigmoid_cross_entropy_with_logits_v2."""
  # pylint: disable=protected-access
  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", _sentinel,
                           labels, logits)
  # pylint: enable=protected-access

  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    try:
      labels.get_shape().assert_is_compatible_with(logits.get_shape())
    except ValueError:
      raise ValueError("`logits` and `labels` must have the same shape, "
                       f"received ({logits.get_shape()} vs "
                       f"{labels.get_shape()}).")

    # The logistic loss formula from above is
    #   x - x * z + log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   -x * z + log(1 + exp(x))
    # Note that these two expressions can be combined into the following:
    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
    # To allow computing gradients at zero, we define custom versions of max
    # and abs functions.
    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
    cond = (logits >= zeros)
    relu_logits = array_ops.where(cond, logits, zeros)
    neg_abs_logits = array_ops.where(cond, -logits, logits)  # pylint: disable=invalid-unary-operand-type
    return math_ops.add(
        relu_logits - logits * labels,
        math_ops.log1p(math_ops.exp(neg_abs_logits)),
        name=name)


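# Illustrative sketch (not part of the public API): the stable formulation in
# the comments above, max(x, 0) - x * z + log(1 + exp(-abs(x))), agrees with
# the naive x - x * z + log(1 + exp(-x)) wherever the latter does not
# overflow. The helper name is hypothetical and exists only for illustration.
def _stable_logistic_loss_sketch():
  """Compares the stable and naive logistic-loss formulas on small logits."""
  logits = constant_op.constant([-3., -1., 0., 1., 3.])
  labels = constant_op.constant([0., 1., 0.5, 1., 0.])
  naive = logits - logits * labels + math_ops.log1p(math_ops.exp(-logits))
  stable = (nn_ops.relu(logits) - logits * labels +
            math_ops.log1p(math_ops.exp(-math_ops.abs(logits))))
  return naive, stable

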
# Note: intentionally calling this v2 to not allow existing code with indirect
# imports to ignore the sentinel behavior.
@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
@dispatch.register_binary_elementwise_api
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    name=None):
  r"""Computes sigmoid cross entropy given `logits`.

  Measures the probability error in tasks with two outcomes in which each
  outcome is independent and need not have a fully certain label. For instance,
  one could perform a regression where the probability of an event happening is
  known and used as a label. This loss may also be used for binary
  classification, where labels are either zero or one.

  For brevity, let `x = logits`, `z = labels`.  The logistic loss is

        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

  For x < 0, to avoid overflow in exp(-x), we reformulate the above

        x - x * z + log(1 + exp(-x))
      = log(exp(x)) - x * z + log(1 + exp(-x))
      = - x * z + log(1 + exp(x))

  Hence, to ensure stability and avoid overflow, the implementation uses this
  equivalent formulation

      max(x, 0) - x * z + log(1 + exp(-abs(x)))

  `logits` and `labels` must have the same type and shape.

  >>> logits = tf.constant([1., -1., 0., 1., -1., 0., 0.])
  >>> labels = tf.constant([0., 0., 0., 1., 1., 1., 0.5])
  >>> tf.nn.sigmoid_cross_entropy_with_logits(
  ...     labels=labels, logits=logits).numpy()
  array([1.3132617, 0.3132617, 0.6931472, 0.3132617, 1.3132617, 0.6931472,
         0.6931472], dtype=float32)

  Compared to the losses which handle multiple outcomes,
  `tf.nn.softmax_cross_entropy_with_logits` for general multi-class
  classification and `tf.nn.sparse_softmax_cross_entropy_with_logits` for more
  efficient multi-class classification with hard labels,
  `sigmoid_cross_entropy_with_logits` is a slight simplification for binary
  classification:

      sigmoid(x) = softmax([x, 0])[0]

  $$\frac{1}{1 + e^{-x}} = \frac{e^x}{e^x + e^0}$$

  While `sigmoid_cross_entropy_with_logits` works for soft binary labels
  (probabilities between 0 and 1), it can also be used for binary classification
  where the labels are hard. There is an equivalence between all three symbols
  in this case, with a probability 0 indicating the second class or 1 indicating
  the first class:

  >>> sigmoid_logits = tf.constant([1., -1., 0.])
  >>> softmax_logits = tf.stack([sigmoid_logits, tf.zeros_like(sigmoid_logits)],
  ...                           axis=-1)
  >>> soft_binary_labels = tf.constant([1., 1., 0.])
  >>> soft_multiclass_labels = tf.stack(
  ...     [soft_binary_labels, 1. - soft_binary_labels], axis=-1)
  >>> hard_labels = tf.constant([0, 0, 1])
  >>> tf.nn.sparse_softmax_cross_entropy_with_logits(
  ...     labels=hard_labels, logits=softmax_logits).numpy()
  array([0.31326166, 1.3132616 , 0.6931472 ], dtype=float32)
  >>> tf.nn.softmax_cross_entropy_with_logits(
  ...     labels=soft_multiclass_labels, logits=softmax_logits).numpy()
  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)
  >>> tf.nn.sigmoid_cross_entropy_with_logits(
  ...     labels=soft_binary_labels, logits=sigmoid_logits).numpy()
  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)

  Args:
    labels: A `Tensor` of the same type and shape as `logits`. Between 0 and 1,
      inclusive.
    logits: A `Tensor` of type `float32` or `float64`. Any real number.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  return sigmoid_cross_entropy_with_logits(
      logits=logits, labels=labels, name=name)


sigmoid_cross_entropy_with_logits.__doc__ = (
    sigmoid_cross_entropy_with_logits_v2.__doc__)


@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
                                          name=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  >>> labels = tf.constant([1., 0.5, 0.])
  >>> logits = tf.constant([1.5, -0.1, -10.])
  >>> tf.nn.weighted_cross_entropy_with_logits(
  ...     labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
  array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
  >>> tf.nn.weighted_cross_entropy_with_logits(
  ...     labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
  array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)

  Args:
    labels: A `Tensor` of the same type and shape as `logits`, with values
      between 0 and 1 inclusive.
    logits: A `Tensor` of type `float32` or `float64`, any real numbers.
    pos_weight: A coefficient to use on the positive examples, typically a
      scalar but otherwise broadcastable to the shape of `logits`. Its value
      should be non-negative.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    try:
      labels.get_shape().assert_is_compatible_with(logits.get_shape())
    except ValueError:
      raise ValueError("`logits` and `labels` must have the same shape, "
                       f"received ({logits.get_shape()} vs "
                       f"{labels.get_shape()}).")

    # The logistic loss formula from above is
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
    # To avoid branching, we use the combined version
    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
    log_weight = 1 + (pos_weight - 1) * labels
    return math_ops.add(
        (1 - labels) * logits,
        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
                      nn_ops.relu(-logits)),  # pylint: disable=invalid-unary-operand-type
        name=name)


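# Illustrative sketch (not part of the public API): the combined form used
# above, (1 - z) * x + (1 + (q - 1) * z) * (log(1 + exp(-|x|)) + max(-x, 0)),
# can be recomputed by hand; with q = 1 it reduces to the unweighted sigmoid
# cross entropy. The helper name is hypothetical and exists only for
# illustration.
def _weighted_cross_entropy_sketch():
  """Recomputes the weighted logistic loss by hand for a scalar pos_weight."""
  labels = constant_op.constant([1., 0.5, 0.])
  logits = constant_op.constant([1.5, -0.1, -10.])
  q = constant_op.constant(1.5)
  log_weight = 1 + (q - 1) * labels
  by_hand = ((1 - labels) * logits +
             log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
                           nn_ops.relu(-logits)))
  via_op = weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight=q)
  return by_hand, via_op

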
@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
def weighted_cross_entropy_with_logits(labels=None,
                                       logits=None,
                                       pos_weight=None,
                                       name=None,
                                       targets=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  Args:
    labels: A `Tensor` of the same type and shape as `logits`.
    logits: A `Tensor` of type `float32` or `float64`.
    pos_weight: A coefficient to use on the positive examples.
    name: A name for the operation (optional).
    targets: Deprecated alias for labels.

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  labels = deprecated_argument_lookup("labels", labels, "targets", targets)
  return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)


@tf_export("nn.compute_average_loss")
@dispatch.add_dispatch_support
def compute_average_loss(per_example_loss,
                         sample_weight=None,
                         global_batch_size=None):
  """Scales per-example losses with sample_weights and computes their average.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(labels, predictions, sample_weight=None):

      # If you are using a `Loss` class instead, set reduction to `NONE` so that
      # we can do the reduction afterwards and divide by global batch size.
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      return tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)
  ```

  Args:
    per_example_loss: Per-example loss.
    sample_weight: Optional weighting for each example.
    global_batch_size: Optional global batch size value. Defaults to (size of
      the first dimension of `per_example_loss`) * (number of replicas).

  Returns:
    Scalar loss value.
  """  # pylint: disable=g-doc-exception
  per_example_loss = ops.convert_to_tensor(per_example_loss)
  input_dtype = per_example_loss.dtype

  with losses_util.check_per_example_loss_rank(per_example_loss):
    if sample_weight is not None:
      sample_weight = ops.convert_to_tensor(sample_weight)
      per_example_loss = losses_util.scale_losses_by_sample_weight(
          per_example_loss, sample_weight)
    per_example_loss = math_ops.cast(per_example_loss, input_dtype)

    if global_batch_size is None:
      if ds.has_strategy() and ds.in_cross_replica_context():
        raise RuntimeError(
            "You are calling `compute_average_loss` in cross replica context, "
            "while it was expected to be called in replica context.")

      num_replicas = ds.get_strategy().num_replicas_in_sync
      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
      global_batch_size = per_replica_batch_size * num_replicas

    check_ops.assert_scalar_v2(
        global_batch_size, message="global_batch_size must be scalar.")
    check_ops.assert_integer_v2(
        global_batch_size,
        message="global_batch_size must be an integer.")
    check_ops.assert_positive_v2(
        global_batch_size, message="global_batch_size must be positive.")

    global_batch_size = math_ops.cast(global_batch_size, input_dtype)
    return math_ops.reduce_sum(per_example_loss) / global_batch_size


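# Illustrative sketch (not part of the public API): outside a distribution
# strategy (a single replica), `compute_average_loss` reduces to
# sum(sample_weight * per_example_loss) / global_batch_size, with the global
# batch size defaulting to the length of `per_example_loss`. The helper name
# and the concrete numbers are hypothetical and exist only for illustration.
def _compute_average_loss_sketch():
  """Recomputes the weighted average loss by hand for a single replica."""
  per_example_loss = constant_op.constant([1.0, 2.0, 3.0, 4.0])
  sample_weight = constant_op.constant([1.0, 0.5, 0.5, 1.0])
  # sum(weight * loss) / batch_size == (1 + 1 + 1.5 + 4) / 4 == 1.875
  by_hand = math_ops.reduce_sum(sample_weight * per_example_loss) / 4.0
  via_op = compute_average_loss(
      per_example_loss, sample_weight=sample_weight, global_batch_size=4)
  return by_hand, via_op

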
@tf_export("nn.scale_regularization_loss")
@dispatch.add_dispatch_support
def scale_regularization_loss(regularization_loss):
  """Scales the sum of the given regularization losses by number of replicas.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(labels, predictions, sample_weight=None):
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      loss = tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)

      # Add scaled regularization losses.
      loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
      return loss
  ```

  Args:
    regularization_loss: Regularization loss.

  Returns:
    Scalar loss value.
  """  # pylint: disable=g-doc-exception
  if ds.has_strategy() and ds.in_cross_replica_context():
    raise RuntimeError(
        "You are calling `scale_regularization_loss` in cross replica context, "
        "while it was expected to be called in replica context.")

  num_replicas = ds.get_strategy().num_replicas_in_sync
  return math_ops.reduce_sum(regularization_loss) / num_replicas


@tf_export(v1=["nn.relu_layer"])
@dispatch.add_dispatch_support
def relu_layer(x, weights, biases, name=None):
  """Computes Relu(x * weight + biases).

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "nn_relu_layer" is used.

  Returns:
    A 2-D Tensor computing relu(matmul(x, weights) + biases).
    Dimensions typically: batch, out_units.
  """
  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
    x = ops.convert_to_tensor(x, name="x")
    weights = ops.convert_to_tensor(weights, name="weights")
    biases = ops.convert_to_tensor(biases, name="biases")
    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
    return nn_ops.relu(xw_plus_b, name=name)


@tf_export("nn.silu", "nn.swish")
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def swish(features, beta=1.0):
  # pylint: disable=g-doc-args
  """Computes the SiLU or Swish activation function: `x * sigmoid(beta * x)`.

  beta : Hyperparameter for Swish activation function. Default value 1.0.

  The SiLU activation function was introduced in "Gaussian Error Linear Units
  (GELUs)" [Hendrycks et al. 2016](https://arxiv.org/abs/1606.08415) and
  "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in
  Reinforcement Learning"
  [Elfwing et al. 2017](https://arxiv.org/abs/1702.03118) and was independently
  discovered (and called swish) in "Searching for Activation Functions"
  [Ramachandran et al. 2017](https://arxiv.org/abs/1710.05941)

  Args:
    features: A `Tensor` representing preactivation values.
    beta: A `Tensor` representing the value of the beta hyperparameter.

  Returns:
    The activation value.
  """
  # pylint: enable=g-doc-args
  features = ops.convert_to_tensor(features, name="features")
  beta = ops.convert_to_tensor(beta, name="beta")
  beta = math_ops.cast(beta, features.dtype)

  @custom_gradient.custom_gradient
  def swish_impl(features):

    def grad(dy):
      """Gradient for the Swish activation function."""
      # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
      # around for backprop, effectively doubling the tensor's memory
      # consumption. We use a control dependency here so that sigmoid(features)
      # is re-computed during backprop (the control dep prevents it being
      # de-duped with the forward pass) and we can free the sigmoid(features)
      # expression immediately after use during the forward pass.
      with ops.control_dependencies([dy]):
        sigmoid_features = math_ops.sigmoid(beta * features)
      activation_grad = (
          sigmoid_features * (1.0 + (beta * features) *
                              (1.0 - sigmoid_features)))
      return dy * activation_grad

    return features * math_ops.sigmoid(beta * features), grad

  return swish_impl(features)


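# Illustrative sketch (not part of the public API): the SiLU/Swish forward
# value above is x * sigmoid(beta * x), and the gradient used in `grad` is
# sigmoid(beta * x) * (1 + beta * x * (1 - sigmoid(beta * x))). The helper
# name is hypothetical and exists only for illustration.
def _swish_sketch():
  """Recomputes the Swish forward value and its analytic gradient by hand."""
  x = constant_op.constant([-2., -0.5, 0., 0.5, 2.])
  beta = constant_op.constant(1.0)
  s = math_ops.sigmoid(beta * x)
  forward_by_hand = x * s
  grad_by_hand = s * (1.0 + beta * x * (1.0 - s))
  return forward_by_hand, grad_by_hand, swish(x, beta=1.0)

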
# pylint: disable=redefined-builtin
@tf_export("linalg.normalize")
@dispatch.add_dispatch_support
def normalize(tensor, ord="euclidean", axis=None, name=None):
  """Normalizes `tensor` along dimension `axis` using specified norm.

  This uses `tf.linalg.norm` to compute the norm along `axis`.

  This function can compute several different vector norms (the 1-norm, the
  Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
  matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).

  Args:
    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
    ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
      `2`, `np.inf` and any positive real number yielding the corresponding
      p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
      `tensor` is a matrix and equivalent to 2-norm for vectors.
      Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
        vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
        `'fro'`, `1`, `2`, `np.inf` are supported. See the description of `axis`
        on how to compute norms for a batch of vectors or matrices stored in a
        tensor.
    axis: If `axis` is `None` (the default), the input is considered a vector
      and a single vector norm is computed over the entire set of values in the
      tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
      `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
      input is considered a batch of vectors, and `axis` determines the axis in
      `tensor` over which to compute vector norms. If `axis` is a 2-tuple of
      Python integers it is considered a batch of matrices and `axis` determines
      the axes in `tensor` over which to compute a matrix norm.
      Negative indices are supported. Example: If you are passing a tensor that
        can be either a matrix or a batch of matrices at runtime, pass
        `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
        computed.
    name: The name of the op.

  Returns:
    normalized: A normalized `Tensor` with the same shape as `tensor`.
    norm: The computed norms with the same shape and dtype as `tensor` but the
      final axis is 1 instead. Same as running
      `tf.cast(tf.linalg.norm(tensor, ord, axis, keepdims=True), tensor.dtype)`.

  Raises:
    ValueError: If `ord` or `axis` is invalid.
  """
  with ops.name_scope(name, "normalize", [tensor]) as name:
    tensor = ops.convert_to_tensor(tensor)
    norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
    norm = math_ops.cast(norm, tensor.dtype)
    normalized = tensor / norm
    return normalized, norm


@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize",
           v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
  """Normalizes along dimension `axis` using an L2 norm.

  For a 1-D tensor with `axis = 0`, computes

      output = x / sqrt(max(sum(x**2), epsilon))

  For `x` with more dimensions, independently normalizes each 1-D slice along
  dimension `axis`.

  1-D tensor example:
  >>> x = tf.constant([3.0, 4.0])
  >>> tf.math.l2_normalize(x).numpy()
  array([0.6, 0.8], dtype=float32)

  2-D tensor example:
  >>> x = tf.constant([[3.0], [4.0]])
  >>> tf.math.l2_normalize(x, 0).numpy()
  array([[0.6],
         [0.8]], dtype=float32)

  >>> x = tf.constant([[3.0], [4.0]])
  >>> tf.math.l2_normalize(x, 1).numpy()
  array([[1.],
         [1.]], dtype=float32)

  Args:
    x: A `Tensor`.
    axis: Dimension along which to normalize. A scalar or a vector of
      integers.
    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
      divisor if `norm < sqrt(epsilon)`.
    name: A name for this operation (optional).
    dim: Deprecated, do not use.

  Returns:
    A `Tensor` with the same shape as `x`.
  """
  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
  with ops.name_scope(name, "l2_normalize", [x]) as name:
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_complex:
      square_real = math_ops.square(math_ops.real(x))
      square_imag = math_ops.square(math_ops.imag(x))
      square_sum = math_ops.real(
          math_ops.reduce_sum(square_real + square_imag, axis, keepdims=True))
      x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
      norm_real = math_ops.multiply(math_ops.real(x), x_inv_norm)
      norm_imag = math_ops.multiply(math_ops.imag(x), x_inv_norm)
      return math_ops.complex(norm_real, norm_imag, name=name)
    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
    return math_ops.multiply(x, x_inv_norm, name=name)


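# Illustrative sketch (not part of the public API): for a real-valued vector,
# `l2_normalize` is x * rsqrt(max(sum(x**2), epsilon)), so normalizing [3, 4]
# yields [0.6, 0.8]. The helper name is hypothetical and exists only for
# illustration.
def _l2_normalize_sketch():
  """Recomputes an L2-normalized vector by hand."""
  x = constant_op.constant([3.0, 4.0])
  square_sum = math_ops.reduce_sum(math_ops.square(x))
  by_hand = x * math_ops.rsqrt(math_ops.maximum(square_sum, 1e-12))
  return by_hand, l2_normalize(x)

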
def _count_nonzero(input_tensor, dtype=dtypes.int64):
  """Same as math_ops.count_nonzero.

  The reduction is done in dtype, which can be faster for 32-bit dtypes.

  Args:
    input_tensor: numeric tensor
    dtype: reduction dtype

  Returns:
    number of nonzero values with type dtype
  """
  with ops.name_scope("count_nonzero", values=[input_tensor]):
    zero = array_ops.zeros([], dtype=input_tensor.dtype)
    nonzero_count = math_ops.reduce_sum(
        math_ops.cast(
            math_ops.not_equal(input_tensor, zero),
            dtype=dtype), name="nonzero_count")
    return nonzero_count


@tf_export("math.zero_fraction", "nn.zero_fraction")
@dispatch.add_dispatch_support
def zero_fraction(value, name=None):
  """Returns the fraction of zeros in `value`.

  If `value` is empty, the result is `nan`.

  This is useful in summaries to measure and report sparsity. For example,

  ```python
  z = tf.nn.relu(...)
  summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
  ```

  Args:
    value: A tensor of numeric type.
    name: A name for the operation (optional).

  Returns:
    The fraction of zeros in `value`, with type `float32`.
  """
  with ops.name_scope(name, "zero_fraction", [value]):
    value = ops.convert_to_tensor(value, name="value")
    size = array_ops.size(value, out_type=dtypes.int64)
    # If the count is small, we can save memory/CPU with an int32 reduction.
    num_nonzero = control_flow_ops.cond(
        size <= dtypes.int32.max,
        # pylint: disable=g-long-lambda
        true_fn=lambda: math_ops.cast(
            _count_nonzero(value, dtype=dtypes.int32),
            dtype=dtypes.int64),
        false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))

    with ops.name_scope("counts_to_fraction"):
      num_zero = size - num_nonzero
      num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
      size_float32 = math_ops.cast(size, dtype=dtypes.float32)
      zero_fraction_float32 = num_zero_float32 / size_float32

    return array_ops.identity(zero_fraction_float32, "fraction")


# pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"])
@dispatch.add_dispatch_support
def depthwise_conv2d(input,
                     filter,
                     strides,
                     padding,
                     rate=None,
                     name=None,
                     data_format=None,
                     dilations=None):
  """Depthwise 2-D convolution.

  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
  and a filter tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`
  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
  applies a different filter to each input channel (expanding from 1 channel
  to `channel_multiplier` channels for each), then concatenates the results
  together. The output has `in_channels * channel_multiplier` channels.

  In detail, with the default NHWC format,

      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
          filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
                                          strides[2] * j + rate[1] * dj, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `rate` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Usage Example:

  >>> x = np.array([
  ...     [1., 2.],
  ...     [3., 4.],
  ...     [5., 6.]
  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
  >>> kernel = np.array([
  ...     [1., 2.],
  ...     [3., 4]
  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                                  padding='VALID').numpy()
  array([[[[10., 14.],
           [14., 20.]],
          [[18., 26.],
           [22., 32.]]]], dtype=float32)

  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                                  padding=[[0, 0], [1, 0], [1, 0], [0, 0]]
  ...                                 ).numpy()
  array([[[[ 0.,  0.],
           [ 3.,  4.],
           [ 6.,  8.]],
          [[ 0.,  0.],
           [10., 14.],
           [14., 20.]],
          [[ 0.,  0.],
           [18., 26.],
           [22., 32.]]]], dtype=float32)

  Args:
    input: 4-D with shape according to `data_format`.
    filter: 4-D with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4.  The stride of the sliding window for each
      dimension of `input`.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    rate: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: Alias of rate.

  Returns:
    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
    "NHWC" format, shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier].`
  """
  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
  with ops.name_scope(name, "depthwise", [input, filter]) as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    filter = ops.convert_to_tensor(filter, name="filter_in")
    if rate is None:
      rate = [1, 1]

    # Use depthwise_conv2d_native if executing on TPU.
    if device_context.enclosing_tpu_context() is not None:
      if data_format == "NCHW":
        dilations = [1, 1, rate[0], rate[1]]
      else:
        dilations = [1, rate[0], rate[1], 1]
      return nn_ops.depthwise_conv2d_native(
          input=input,
          filter=filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          dilations=dilations,
          name=name)

    def op(input_converted, _, padding):
      return nn_ops.depthwise_conv2d_native(
          input=input_converted,
          filter=filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          name=name)

    return nn_ops.with_space_to_batch(
        input=input,
        filter_shape=array_ops.shape(filter),
        dilation_rate=rate,
        padding=padding,
        data_format=data_format,
        op=op)


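# Illustrative sketch (not part of the public API): a depthwise convolution
# applies `channel_multiplier` filters to every input channel independently,
# so an input with `in_channels` channels produces
# `in_channels * channel_multiplier` output channels. The helper name and the
# concrete shapes are hypothetical and exist only for illustration.
def _depthwise_conv2d_shape_sketch():
  """Shows the channel expansion of a depthwise convolution on a toy input."""
  # NHWC input: batch=1, height=3, width=2, in_channels=1.
  x = array_ops.reshape(
      constant_op.constant([1., 2., 3., 4., 5., 6.]), [1, 3, 2, 1])
  # Filter: height=2, width=1, in_channels=1, channel_multiplier=2.
  kernel = array_ops.reshape(constant_op.constant([1., 2., 3., 4.]),
                             [2, 1, 1, 2])
  # Output shape is [1, 2, 2, 1 * 2] with 'VALID' padding and unit strides.
  return depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding="VALID")

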
@tf_export("nn.depthwise_conv2d", v1=[])
@dispatch.add_dispatch_support
def depthwise_conv2d_v2(input,
                        filter,
                        strides,
                        padding,
                        data_format=None,
                        dilations=None,
                        name=None):
  """Depthwise 2-D convolution.

  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
  and a filter tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`
  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
  applies a different filter to each input channel (expanding from 1 channel
  to `channel_multiplier` channels for each), then concatenates the results
  together. The output has `in_channels * channel_multiplier` channels.

  In detail, with the default NHWC format,

      output[b, i, j, k * channel_multiplier + q] =
          sum_{di, dj} filter[di, dj, k, q] *
                       input[b, strides[1] * i + dilations[0] * di,
                                strides[2] * j + dilations[1] * dj, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `dilations` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Usage Example:

  >>> x = np.array([
  ...     [1., 2.],
  ...     [3., 4.],
  ...     [5., 6.]
  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
  >>> kernel = np.array([
  ...     [1., 2.],
  ...     [3., 4]
  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                        padding='VALID').numpy()
  array([[[[10., 14.],
           [14., 20.]],
          [[18., 26.],
           [22., 32.]]]], dtype=float32)

  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                        padding=[[0, 0], [1, 0], [1, 0], [0, 0]]).numpy()
  array([[[[ 0.,  0.],
           [ 3.,  4.],
           [ 6.,  8.]],
          [[ 0.,  0.],
           [10., 14.],
           [14., 20.]],
          [[ 0.,  0.],
           [18., 26.],
           [22., 32.]]]], dtype=float32)

  Args:
    input: 4-D with shape according to `data_format`.
    filter: 4-D with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4.  The stride of the sliding window for each
      dimension of `input`.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding is used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
    "NHWC" format, shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier].`
  """
  return depthwise_conv2d(input=input,
                          filter=filter,
                          strides=strides,
                          padding=padding,
                          rate=dilations,
                          name=name,
                          data_format=data_format)

# pylint: enable=redefined-builtin


# pylint: disable=redefined-builtin,line-too-long
@tf_export(v1=["nn.separable_conv2d"])
@dispatch.add_dispatch_support
def separable_conv2d(input,
                     depthwise_filter,
                     pointwise_filter,
                     strides,
                     padding,
                     rate=None,
                     name=None,
                     data_format=None,
                     dilations=None):
  """2-D convolution with separable filters.

  Performs a depthwise convolution that acts separately on channels followed by
  a pointwise convolution that mixes channels.  Note that this is separability
  between dimensions `[1, 2]` and `3`, not spatial separability between
  dimensions `1` and `2`.

  In detail, with the default NHWC format,

      output[b, i, j, k] = sum_{di, dj, q, r}
          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
          depthwise_filter[di, dj, q, r] *
          pointwise_filter[0, 0, q * channel_multiplier + r, k]

  `strides` controls the strides for the depthwise convolution only, since
  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
  `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `rate` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D `Tensor` with shape according to `data_format`.
    depthwise_filter: 4-D `Tensor` with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
      Contains `in_channels` convolutional filters of depth 1.
    pointwise_filter: 4-D `Tensor` with shape
      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
      filter to mix channels after `depthwise_filter` has convolved spatially.
    strides: 1-D of size 4.  The strides for the depthwise convolution for
      each dimension of `input`.
    padding: Controls how to pad the image before applying the depthwise
      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
      of padding algorithm to use, or a Python list indicating the explicit
      paddings at the start and end of each dimension. When explicit padding is
      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
      padding is used and data_format is `"NCHW"`, this should be in the form
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    rate: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: Alias of rate.

  Returns:
    A 4-D `Tensor` with shape according to 'data_format'. For
    example, with data_format="NHWC", shape is [batch, out_height,
    out_width, out_channels].
  """
  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
  with ops.name_scope(name, "separable_conv2d",
                      [input, depthwise_filter, pointwise_filter]) as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    depthwise_filter = ops.convert_to_tensor(
        depthwise_filter, name="depthwise_filter")
    pointwise_filter = ops.convert_to_tensor(
        pointwise_filter, name="pointwise_filter")

    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
    pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
    pointwise_filter_shape.dims[1].assert_is_compatible_with(1)

    if rate is None:
      rate = [1, 1]

    # The layout of the ops in the graph are expected to be as follows:
    # depthwise_conv2d  // Conv2D op corresponding to native depthwise conv.
    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.

    def op(input_converted, _, padding):
      return nn_ops.depthwise_conv2d_native(
          input=input_converted,
          filter=depthwise_filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          name="depthwise")

    depthwise = nn_ops.with_space_to_batch(
        input=input,
        filter_shape=array_ops.shape(depthwise_filter),
        dilation_rate=rate,
        padding=padding,
        data_format=data_format,
        op=op)

    return nn_ops.conv2d(
        depthwise,
        pointwise_filter, [1, 1, 1, 1],
        padding="VALID",
        data_format=data_format,
        name=name)


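# Illustrative sketch (not part of the public API): a separable convolution is
# the depthwise convolution above followed by a 1x1 ("pointwise") standard
# convolution that mixes the expanded channels back down. The helper name and
# the concrete shapes are hypothetical and exist only for illustration.
def _separable_conv2d_sketch():
  """Composes a depthwise and a 1x1 convolution on a toy NHWC input."""
  x = array_ops.reshape(math_ops.range(1., 19.), [1, 3, 3, 2])
  # Depthwise filter: 2x2 window, in_channels=2, channel_multiplier=1.
  depthwise_filter = array_ops.reshape(math_ops.range(1., 9.), [2, 2, 2, 1])
  # Pointwise filter: 1x1, mixes 2 * 1 channels down to 3 output channels.
  pointwise_filter = array_ops.reshape(math_ops.range(1., 7.), [1, 1, 2, 3])
  return separable_conv2d(
      x, depthwise_filter, pointwise_filter,
      strides=[1, 1, 1, 1], padding="VALID")

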
@tf_export("nn.separable_conv2d", v1=[])
@dispatch.add_dispatch_support
def separable_conv2d_v2(
    input,
    depthwise_filter,
    pointwise_filter,
    strides,
    padding,
    data_format=None,
    dilations=None,
    name=None,
):
  """2-D convolution with separable filters.

  Performs a depthwise convolution that acts separately on channels followed by
  a pointwise convolution that mixes channels.  Note that this is separability
  between dimensions `[1, 2]` and `3`, not spatial separability between
  dimensions `1` and `2`.

  In detail, with the default NHWC format,

      output[b, i, j, k] = sum_{di, dj, q, r}
          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
          depthwise_filter[di, dj, q, r] *
          pointwise_filter[0, 0, q * channel_multiplier + r, k]

  `strides` controls the strides for the depthwise convolution only, since
  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
  `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `dilations` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D `Tensor` with shape according to `data_format`.
    depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
      in_channels, channel_multiplier]`. Contains `in_channels` convolutional
      filters of depth 1.
    pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
      in_channels, out_channels]`. Pointwise filter to mix channels after
      `depthwise_filter` has convolved spatially.
    strides: 1-D of size 4. The strides for the depthwise convolution for each
      dimension of `input`.
    padding: Controls how to pad the image before applying the depthwise
      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
      of padding algorithm to use, or a Python list indicating the explicit
      paddings at the start and end of each dimension. When explicit padding is
      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
      padding is used and data_format is `"NCHW"`, this should be in the form
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` with shape according to 'data_format'. For
    example, with data_format="NHWC", shape is [batch, out_height,
    out_width, out_channels].
  """
  return separable_conv2d(
      input,
      depthwise_filter,
      pointwise_filter,
      strides,
      padding,
      rate=dilations,
      name=name,
      data_format=data_format)

# pylint: enable=redefined-builtin,line-too-long


@tf_export(v1=["nn.sufficient_statistics"])
@dispatch.add_dispatch_support
def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
                          keepdims=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted. See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  For example:
  >>> t = [[1, 2, 3], [4, 5, 6]]
  >>> sufficient_statistics(t, [1])
  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
  >>> sufficient_statistics(t, [-1])
  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance. As in
      Python, the axes can also be negative numbers. A negative axis is
      interpreted as counting from the end of the rank, i.e., axis +
      rank(values)-th dimension.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` if no shift is to be performed. A shift
      close to the true mean provides the most numerically stable results.
    keep_dims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.
    keepdims: Alias for keep_dims.

  Returns:
    Four `Tensor` objects of the same type as `x`:

    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is None.
  """
  axes = list(set(axes))
  keep_dims = deprecated_argument_lookup(
      "keepdims", keepdims, "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    x_shape = x.get_shape()
    if x_shape.rank is not None and all(
        x_shape.dims[d].value is not None for d in axes):
      counts = 1
      for d in axes:
        counts *= x_shape.dims[d].value
      counts = constant_op.constant(counts, dtype=x.dtype)
    else:  # shape needs to be inferred at runtime.
      # Normalize axes to be positive. Required for gather.
      rank = array_ops.rank(x)
      positive_axes = [axis + rank if axis < 0 else axis for axis in axes]
      x_dims = array_ops.gather(
          math_ops.cast(array_ops.shape(x), x.dtype), positive_axes)
      counts = math_ops.reduce_prod(x_dims, name="count")
    if shift is not None:
      shift = ops.convert_to_tensor(shift, name="shift")
      m_ss = math_ops.subtract(x, shift)
      v_ss = math_ops.squared_difference(x, shift)
    else:  # no shift.
      m_ss = x
      v_ss = math_ops.square(x)
    m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
    v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
    return counts, m_ss, v_ss, shift


@tf_export("nn.sufficient_statistics", v1=[])
@dispatch.add_dispatch_support
def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted. See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` if no shift is to be performed. A shift
      close to the true mean provides the most numerically stable results.
    keepdims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.

  Returns:
    Four `Tensor` objects of the same type as `x`:

    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is None.
  """
  return sufficient_statistics(
      x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)


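# Illustrative sketch (not part of the public API): the four values returned
# above recover the moments as mean = shift + mean_ss / count and
# variance = var_ss / count - (mean_ss / count)**2 (with shift = 0 when no
# shift is given), which is what `normalize_moments` below computes. The
# helper name is hypothetical and exists only for illustration.
def _sufficient_statistics_sketch():
  """Recovers mean and variance from the sufficient statistics by hand."""
  x = constant_op.constant([[1., 2., 3.], [4., 5., 6.]])
  counts, mean_ss, var_ss, _ = sufficient_statistics_v2(x, axes=[1])
  mean = mean_ss / counts
  variance = var_ss / counts - math_ops.square(mean)
  return mean, variance

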
@tf_export("nn.normalize_moments")
@dispatch.add_dispatch_support
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
  """Calculate the mean and variance based on the sufficient statistics.

  Args:
    counts: A `Tensor` containing the total count of the data (one value).
    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
      shifted) sum of the elements to average over.
    variance_ss: A `Tensor` containing the variance sufficient statistics: the
      (possibly shifted) squared sum of the data to compute the variance over.
    shift: A `Tensor` containing the value by which the data is shifted for
      numerical stability, or `None` if no shift was performed.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
    divisor = math_ops.reciprocal(counts, name="divisor")
    if shift is not None:
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
      mean = math_ops.add(shifted_mean, shift, name="mean")
    else:  # no shift.
      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
      mean = shifted_mean
    variance = math_ops.subtract(
        math_ops.multiply(variance_ss, divisor),
        math_ops.square(shifted_mean),
        name="variance")
  return (mean, variance)


@tf_export(v1=["nn.moments"])
@dispatch.add_dispatch_support
def moments(
    x,
    axes,
    shift=None,  # pylint: disable=unused-argument
    name=None,
    keep_dims=None,
    keepdims=None):
  """Calculate the mean and variance of `x`.

  The mean and variance are calculated by aggregating the contents of `x`
  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
  and variance of a vector.

  Note: shift is currently not used; the true mean is computed and used.

  When using these moments for batch normalization (see
  `tf.nn.batch_normalization`):

  * for so-called "global normalization", used with convolutional filters with
    shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
  * for simple batch normalization pass `axes=[0]` (batch only).

  Args:
    x: A `Tensor`.
    axes: Array of ints.  Axes along which to compute mean and
      variance.
    shift: Not used in the current implementation.
    name: Name used to scope the operations that compute the moments.
    keep_dims: produce moments with the same dimensionality as the input.
    keepdims: Alias to keep_dims.

  Returns:
    Two `Tensor` objects: `mean` and `variance`.
  """
  keep_dims = deprecated_argument_lookup(
      "keepdims", keepdims, "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  with ops.name_scope(name, "moments", [x, axes]):
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16
    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
    # Compute true mean while keeping the dims for proper broadcasting.
    mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
    # sample variance, not unbiased variance
    # Note: stop_gradient does not change the gradient that gets
    #       backpropagated to the mean from the variance calculation,
    #       because that gradient is zero
    variance = math_ops.reduce_mean(
        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
        axes,
        keepdims=True,
        name="variance")
    if not keep_dims:
      mean = array_ops.squeeze(mean, axes)
      variance = array_ops.squeeze(variance, axes)
    if x.dtype == dtypes.float16:
      return (math_ops.cast(mean, dtypes.float16),
              math_ops.cast(variance, dtypes.float16))
    else:
      return (mean, variance)


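# Illustrative sketch (not part of the public API): `moments` is equivalent to
# a reduce_mean followed by the mean of squared deviations (the biased sample
# variance). The helper name is hypothetical and exists only for illustration.
def _moments_sketch():
  """Recomputes per-column mean and variance by hand."""
  x = constant_op.constant([[1., 2., 3.], [4., 5., 6.]])
  mean_by_hand = math_ops.reduce_mean(x, [0])
  variance_by_hand = math_ops.reduce_mean(
      math_ops.squared_difference(x, mean_by_hand), [0])
  return (mean_by_hand, variance_by_hand), moments(x, axes=[0])

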
1400 1401 When using these moments for batch normalization (see 1402 `tf.nn.batch_normalization`): 1403 1404 * for so-called "global normalization", used with convolutional filters with 1405 shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`. 1406 * for simple batch normalization pass `axes=[0]` (batch only). 1407 1408 Args: 1409 x: A `Tensor`. 1410 axes: Array of ints. Axes along which to compute mean and 1411 variance. 1412 shift: Not used in the current implementation. 1413 keepdims: produce moments with the same dimensionality as the input. 1414 name: Name used to scope the operations that compute the moments. 1415 1416 Returns: 1417 Two `Tensor` objects: `mean` and `variance`. 1418 """ 1419 return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims) 1420 1421 1422@tf_export(v1=["nn.weighted_moments"]) 1423@dispatch.add_dispatch_support 1424def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None, 1425 keepdims=None): 1426 """Returns the frequency-weighted mean and variance of `x`. 1427 1428 Args: 1429 x: A tensor. 1430 axes: 1-d tensor of int32 values; these are the axes along which 1431 to compute mean and variance. 1432 frequency_weights: A tensor of positive weights which can be 1433 broadcast with x. 1434 name: Name used to scope the operation. 1435 keep_dims: Produce moments with the same dimensionality as the input. 1436 keepdims: Alias of keep_dims. 1437 1438 Returns: 1439 Two tensors: `weighted_mean` and `weighted_variance`. 1440 """ 1441 keep_dims = deprecated_argument_lookup( 1442 "keepdims", keepdims, "keep_dims", keep_dims) 1443 if keep_dims is None: 1444 keep_dims = False 1445 with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]): 1446 x = ops.convert_to_tensor(x, name="x") 1447 frequency_weights = ops.convert_to_tensor( 1448 frequency_weights, name="frequency_weights") 1449 1450 # Unlike moments(), this just uses a simpler two-pass method. 1451 1452 # See comment in moments() WRT precision; it applies here too. 1453 needs_cast = x.dtype == dtypes.float16 1454 if needs_cast: 1455 x = math_ops.cast(x, dtypes.float32) 1456 1457 if frequency_weights.dtype != x.dtype: 1458 frequency_weights = math_ops.cast(frequency_weights, x.dtype) 1459 1460 # Note that we use keep_dims=True for our reductions regardless of the arg; 1461 # this is so that the results remain broadcast-compatible with the inputs. 1462 weighted_input_sum = math_ops.reduce_sum( 1463 frequency_weights * x, axes, name="weighted_input_sum", keepdims=True) 1464 1465 # The shape of the weights isn't necessarily the same as x's 1466 # shape, just broadcast-compatible with it -- so this expression 1467 # performs broadcasting to give a per-item weight, with the same 1468 # shape as (frequency_weights * x). This avoids having to reason 1469 # through all the broadcast logic to compute a correct 1470 # sum_of_weights. 
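    # For example, if x has shape [2, 3] and frequency_weights has shape [3],
    # broadcasted_weights below has shape [2, 3], matching x.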
1471 broadcasted_weights = frequency_weights + array_ops.zeros_like(x) 1472 1473 sum_of_weights = math_ops.reduce_sum( 1474 broadcasted_weights, axes, name="sum_of_weights", keepdims=True) 1475 1476 divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum") 1477 1478 weighted_mean = math_ops.multiply(weighted_input_sum, divisor) 1479 1480 # Have the weighted mean; now on to variance: 1481 weighted_distsq = math_ops.reduce_sum( 1482 frequency_weights * math_ops.squared_difference(x, weighted_mean), 1483 axes, 1484 name="weighted_distsq", 1485 keepdims=True) 1486 1487 weighted_variance = math_ops.multiply(weighted_distsq, divisor) 1488 1489 if not keep_dims: 1490 weighted_mean = array_ops.squeeze(weighted_mean, axis=axes) 1491 weighted_variance = array_ops.squeeze( 1492 weighted_variance, axis=axes) 1493 1494 if needs_cast: 1495 weighted_mean = math_ops.cast(weighted_mean, dtypes.float16) 1496 weighted_variance = math_ops.cast(weighted_variance, dtypes.float16) 1497 1498 return weighted_mean, weighted_variance 1499 1500 1501@tf_export("nn.weighted_moments", v1=[]) 1502@dispatch.add_dispatch_support 1503def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None): 1504 """Returns the frequency-weighted mean and variance of `x`. 1505 1506 Args: 1507 x: A tensor. 1508 axes: 1-d tensor of int32 values; these are the axes along which 1509 to compute mean and variance. 1510 frequency_weights: A tensor of positive weights which can be 1511 broadcast with x. 1512 keepdims: Produce moments with the same dimensionality as the input. 1513 name: Name used to scope the operation. 1514 1515 Returns: 1516 Two tensors: `weighted_mean` and `weighted_variance`. 1517 """ 1518 return weighted_moments( 1519 x=x, 1520 axes=axes, 1521 frequency_weights=frequency_weights, 1522 name=name, 1523 keep_dims=keepdims) 1524 1525 1526@tf_export("nn.batch_normalization") 1527@dispatch.add_dispatch_support 1528def batch_normalization(x, 1529 mean, 1530 variance, 1531 offset, 1532 scale, 1533 variance_epsilon, 1534 name=None): 1535 r"""Batch normalization. 1536 1537 Normalizes a tensor by `mean` and `variance`, and applies (optionally) a 1538 `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\): 1539 1540 \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\) 1541 1542 `mean`, `variance`, `offset` and `scale` are all expected to be of one of two 1543 shapes: 1544 1545 * In all generality, they can have the same number of dimensions as the 1546 input `x`, with identical sizes as `x` for the dimensions that are not 1547 normalized over (the 'depth' dimension(s)), and dimension 1 for the 1548 others which are being normalized over. 1549 `mean` and `variance` in this case would typically be the outputs of 1550 `tf.nn.moments(..., keepdims=True)` during training, or running averages 1551 thereof during inference. 1552 * In the common case where the 'depth' dimension is the last dimension in 1553 the input tensor `x`, they may be one dimensional tensors of the same 1554 size as the 'depth' dimension. 1555 This is the case for example for the common `[batch, depth]` layout of 1556 fully-connected layers, and `[batch, height, width, depth]` for 1557 convolutions. 1558 `mean` and `variance` in this case would typically be the outputs of 1559 `tf.nn.moments(..., keepdims=False)` during training, or running averages 1560 thereof during inference. 1561 1562 See equation 11 in Algorithm 2 of source: 1563 [Batch Normalization: Accelerating Deep Network Training by 1564 Reducing Internal Covariate Shift; S. Ioffe, C. 
Szegedy]
  (http://arxiv.org/abs/1502.03167).

  Args:
    x: Input `Tensor` of arbitrary dimensionality.
    mean: A mean `Tensor`.
    variance: A variance `Tensor`.
    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
      None. If present, will be added to the normalized tensor.
    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
      `None`. If present, the scale is applied to the normalized tensor.
    variance_epsilon: A small float number to avoid dividing by 0.
    name: A name for this operation (optional).

  Returns:
    the normalized, scaled, offset tensor.

  References:
    Batch Normalization - Accelerating Deep Network Training by Reducing
    Internal Covariate Shift:
      [Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
  """
  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
    inv = math_ops.rsqrt(variance + variance_epsilon)
    if scale is not None:
      inv *= scale
    # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
    # the precise order of ops that are generated by the expression below.
    return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
        offset - mean * inv if offset is not None else -mean * inv, x.dtype)


@tf_export(v1=["nn.fused_batch_norm"])
@dispatch.add_dispatch_support
def fused_batch_norm(
    x,
    scale,
    offset,  # pylint: disable=invalid-name
    mean=None,
    variance=None,
    epsilon=0.001,
    data_format="NHWC",
    is_training=True,
    name=None,
    exponential_avg_factor=1.0):
  r"""Batch normalization.

  See Source: [Batch Normalization: Accelerating Deep Network Training by
  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
  (http://arxiv.org/abs/1502.03167).

  Args:
    x: Input `Tensor` of 4 or 5 dimensions.
    scale: A `Tensor` of 1 dimension for scaling.
    offset: A `Tensor` of 1 dimension for bias.
    mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
      of this argument depends on the value of is_training and
      exponential_avg_factor as follows:
      is_training==False (inference):
        Mean must be a `Tensor` of the same shape as scale containing the
        estimated population mean computed during training.
      is_training==True and exponential_avg_factor == 1.0:
        Mean must be None.
      is_training==True and exponential_avg_factor != 1.0:
        Mean must be a `Tensor` of the same shape as scale containing the
        exponential running mean.
    variance: A `Tensor` of 1 dimension for population variance. The shape and
      meaning of this argument depends on the value of is_training and
      exponential_avg_factor as follows:
      is_training==False (inference):
        Variance must be a `Tensor` of the same shape as scale containing
        the estimated population variance computed during training.
      is_training==True and exponential_avg_factor == 1.0:
        Variance must be None.
      is_training==True and exponential_avg_factor != 1.0:
        Variance must be a `Tensor` of the same shape as scale containing
        the exponential running variance.
    epsilon: A small float number added to the variance of x.
    data_format: The data format for x. Supports "NHWC" (default) or "NCHW" for
      4D tensors and "NDHWC" or "NCDHW" for 5D tensors.
    is_training: A bool value to specify if the operation is used for
      training or inference.
    name: A name for this operation (optional).
    exponential_avg_factor: A float number (usually between 0 and 1) used
      for controlling the decay of the running
      population average of mean and variance.
      If set to 1.0, the current batch average is
      returned.

  Returns:
    y: A 4D or 5D Tensor for the normalized, scaled, offset x.
    running_mean: A 1D Tensor for the exponential running mean of x.
      The output value is (1 - exponential_avg_factor) * mean +
      exponential_avg_factor * batch_mean, where batch_mean
      is the mean of the current batch in x.
    running_var: A 1D Tensor for the exponential running variance of x.
      The output value is (1 - exponential_avg_factor) * variance +
      exponential_avg_factor * batch_variance, where batch_variance
      is the variance of the current batch in x.

  References:
    Batch Normalization - Accelerating Deep Network Training by Reducing
    Internal Covariate Shift:
      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
  """
  if (not is_training or exponential_avg_factor != 1.0) and (
      (mean is None) or (variance is None)):
    raise ValueError("Both `mean` and `variance` must be a 1D tensor when "
                     "`is_training` is False or `exponential_avg_factor` != "
                     f"1.0. Received: `mean` {mean!r} and `variance` "
                     f"{variance!r}")
  x = ops.convert_to_tensor(x, name="input")
  scale = ops.convert_to_tensor(scale, name="scale")
  offset = ops.convert_to_tensor(offset, name="offset")
  if mean is None:
    mean = constant_op.constant([])
  if variance is None:
    variance = constant_op.constant([])

  # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
  # prevent exception (see cudnn.h).
  min_epsilon = 1.001e-5
  epsilon = epsilon if epsilon > min_epsilon else min_epsilon

  y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
      x,
      scale,
      offset,
      mean,
      variance,
      epsilon=epsilon,
      exponential_avg_factor=exponential_avg_factor,
      data_format=data_format,
      is_training=is_training,
      name=name)
  return y, running_mean, running_var


@tf_export(v1=["nn.batch_norm_with_global_normalization"])
@dispatch.add_dispatch_support
def batch_norm_with_global_normalization(t=None,
                                         m=None,
                                         v=None,
                                         beta=None,
                                         gamma=None,
                                         variance_epsilon=None,
                                         scale_after_normalization=None,
                                         name=None,
                                         input=None,  # pylint: disable=redefined-builtin
                                         mean=None,
                                         variance=None):
  """Batch normalization.

  This op is deprecated. See `tf.nn.batch_normalization`.

  Args:
    t: A 4D input Tensor.
    m: A 1D mean Tensor with size matching the last dimension of t.
      This is the first output from tf.nn.moments,
      or a saved moving average thereof.
    v: A 1D variance Tensor with size matching the last dimension of t.
      This is the second output from tf.nn.moments,
      or a saved moving average thereof.
    beta: A 1D beta Tensor with size matching the last dimension of t.
      An offset to be added to the normalized tensor.
    gamma: A 1D gamma Tensor with size matching the last dimension of t.
      If "scale_after_normalization" is true, this tensor will be multiplied
      with the normalized tensor.
1735 variance_epsilon: A small float number to avoid dividing by 0. 1736 scale_after_normalization: A bool indicating whether the resulted tensor 1737 needs to be multiplied with gamma. 1738 name: A name for this operation (optional). 1739 input: Alias for t. 1740 mean: Alias for m. 1741 variance: Alias for v. 1742 1743 Returns: 1744 A batch-normalized `t`. 1745 1746 References: 1747 Batch Normalization - Accelerating Deep Network Training by Reducing 1748 Internal Covariate Shift: 1749 [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html) 1750 ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf)) 1751 """ 1752 t = deprecated_argument_lookup("input", input, "t", t) 1753 m = deprecated_argument_lookup("mean", mean, "m", m) 1754 v = deprecated_argument_lookup("variance", variance, "v", v) 1755 return batch_normalization(t, m, v, beta, gamma if scale_after_normalization 1756 else None, variance_epsilon, name) 1757 1758 1759# pylint: disable=redefined-builtin,line-too-long 1760@tf_export("nn.batch_norm_with_global_normalization", v1=[]) 1761@dispatch.add_dispatch_support 1762def batch_norm_with_global_normalization_v2(input, 1763 mean, 1764 variance, 1765 beta, 1766 gamma, 1767 variance_epsilon, 1768 scale_after_normalization, 1769 name=None): 1770 """Batch normalization. 1771 1772 This op is deprecated. See `tf.nn.batch_normalization`. 1773 1774 Args: 1775 input: A 4D input Tensor. 1776 mean: A 1D mean Tensor with size matching the last dimension of t. 1777 This is the first output from tf.nn.moments, 1778 or a saved moving average thereof. 1779 variance: A 1D variance Tensor with size matching the last dimension of t. 1780 This is the second output from tf.nn.moments, 1781 or a saved moving average thereof. 1782 beta: A 1D beta Tensor with size matching the last dimension of t. 1783 An offset to be added to the normalized tensor. 1784 gamma: A 1D gamma Tensor with size matching the last dimension of t. 1785 If "scale_after_normalization" is true, this tensor will be multiplied 1786 with the normalized tensor. 1787 variance_epsilon: A small float number to avoid dividing by 0. 1788 scale_after_normalization: A bool indicating whether the resulted tensor 1789 needs to be multiplied with gamma. 1790 name: A name for this operation (optional). 1791 1792 Returns: 1793 A batch-normalized `t`. 1794 1795 References: 1796 Batch Normalization - Accelerating Deep Network Training by Reducing Internal Covariate Shift: 1797 [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html) 1798 ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf)) 1799 """ 1800 return batch_norm_with_global_normalization(t=input, 1801 m=mean, 1802 v=variance, 1803 beta=beta, 1804 gamma=gamma, 1805 variance_epsilon=variance_epsilon, 1806 scale_after_normalization=scale_after_normalization, 1807 name=name) 1808 1809# pylint: enable=redefined-builtin,line-too-long 1810 1811 1812def _sum_rows(x): 1813 """Returns a vector summing up each row of the matrix x.""" 1814 # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is 1815 # a matrix. The gradient of _sum_rows(x) is more efficient than 1816 # reduce_sum(x, 1)'s gradient in today's implementation. Therefore, 1817 # we use _sum_rows(x) in the nce_loss() computation since the loss 1818 # is mostly used for training. 
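  # For example, for x = [[1., 2.], [3., 4.]] this returns [3., 7.]:
  # multiplying by a column of ones sums each row via a single matmul.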
1819 cols = array_ops.shape(x)[1] 1820 ones_shape = array_ops.stack([cols, 1]) 1821 ones = array_ops.ones(ones_shape, x.dtype) 1822 return array_ops.reshape(math_ops.matmul(x, ones), [-1]) 1823 1824 1825def _compute_sampled_logits(weights, 1826 biases, 1827 labels, 1828 inputs, 1829 num_sampled, 1830 num_classes, 1831 num_true=1, 1832 sampled_values=None, 1833 subtract_log_q=True, 1834 remove_accidental_hits=False, 1835 partition_strategy="mod", 1836 name=None, 1837 seed=None): 1838 """Helper function for nce_loss and sampled_softmax_loss functions. 1839 1840 Computes sampled output training logits and labels suitable for implementing 1841 e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see 1842 sampled_softmax_loss). 1843 1844 Note: In the case where num_true > 1, we assign to each target class 1845 the target probability 1 / num_true so that the target probabilities 1846 sum to 1 per-example. 1847 1848 Args: 1849 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 1850 objects whose concatenation along dimension 0 has shape 1851 `[num_classes, dim]`. The (possibly-partitioned) class embeddings. 1852 biases: A `Tensor` of shape `[num_classes]`. The (possibly-partitioned) 1853 class biases. 1854 labels: A `Tensor` of type `int64` and shape `[batch_size, 1855 num_true]`. The target classes. Note that this format differs from 1856 the `labels` argument of `nn.softmax_cross_entropy_with_logits`. 1857 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward 1858 activations of the input network. 1859 num_sampled: An `int`. The number of classes to randomly sample per batch. 1860 num_classes: An `int`. The number of possible classes. 1861 num_true: An `int`. The number of target classes per training example. 1862 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 1863 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 1864 (if None, we default to `log_uniform_candidate_sampler`) 1865 subtract_log_q: A `bool`. whether to subtract the log expected count of 1866 the labels in the sample to get the logits of the true labels. 1867 Default is True. Turn off for Negative Sampling. 1868 remove_accidental_hits: A `bool`. whether to remove "accidental hits" 1869 where a sampled class equals one of the target classes. Default is 1870 False. 1871 partition_strategy: A string specifying the partitioning strategy, relevant 1872 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. 1873 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. 1874 name: A name for the operation (optional). 1875 seed: random seed for candidate sampling. Default to None, which doesn't set 1876 the op-level random seed for candidate sampling. 1877 Returns: 1878 out_logits: `Tensor` object with shape 1879 `[batch_size, num_true + num_sampled]`, for passing to either 1880 `nn.sigmoid_cross_entropy_with_logits` (NCE) or 1881 `nn.softmax_cross_entropy_with_logits` (sampled softmax). 1882 out_labels: A Tensor object with the same shape as `out_logits`. 1883 """ 1884 1885 if isinstance(weights, variables.PartitionedVariable): 1886 weights = list(weights) 1887 if not isinstance(weights, list): 1888 weights = [weights] 1889 1890 with ops.name_scope(name, "compute_sampled_logits", 1891 weights + [biases, inputs, labels]): 1892 if labels.dtype != dtypes.int64: 1893 labels = math_ops.cast(labels, dtypes.int64) 1894 labels_flat = array_ops.reshape(labels, [-1]) 1895 1896 # Sample the negative labels. 
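    # If `sampled_values` is None, `log_uniform_candidate_sampler` below is
    # used; it assumes class IDs are ordered by decreasing frequency
    # (a Zipf-like distribution), as noted in the nce_loss docstrings.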
1897 # sampled shape: [num_sampled] tensor 1898 # true_expected_count shape = [batch_size, 1] tensor 1899 # sampled_expected_count shape = [num_sampled] tensor 1900 if sampled_values is None: 1901 sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler( 1902 true_classes=labels, 1903 num_true=num_true, 1904 num_sampled=num_sampled, 1905 unique=True, 1906 range_max=num_classes, 1907 seed=seed) 1908 # NOTE: pylint cannot tell that 'sampled_values' is a sequence 1909 # pylint: disable=unpacking-non-sequence 1910 sampled, true_expected_count, sampled_expected_count = ( 1911 array_ops.stop_gradient(s) for s in sampled_values) 1912 # pylint: enable=unpacking-non-sequence 1913 sampled = math_ops.cast(sampled, dtypes.int64) 1914 1915 # labels_flat is a [batch_size * num_true] tensor 1916 # sampled is a [num_sampled] int tensor 1917 all_ids = array_ops.concat([labels_flat, sampled], 0) 1918 1919 # Retrieve the true weights and the logits of the sampled weights. 1920 1921 # weights shape is [num_classes, dim] 1922 all_w = embedding_ops.embedding_lookup( 1923 weights, all_ids, partition_strategy=partition_strategy) 1924 if all_w.dtype != inputs.dtype: 1925 all_w = math_ops.cast(all_w, inputs.dtype) 1926 1927 # true_w shape is [batch_size * num_true, dim] 1928 true_w = array_ops.slice(all_w, [0, 0], 1929 array_ops.stack( 1930 [array_ops.shape(labels_flat)[0], -1])) 1931 1932 sampled_w = array_ops.slice( 1933 all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1]) 1934 # inputs has shape [batch_size, dim] 1935 # sampled_w has shape [num_sampled, dim] 1936 # Apply X*W', which yields [batch_size, num_sampled] 1937 sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True) 1938 1939 # Retrieve the true and sampled biases, compute the true logits, and 1940 # add the biases to the true and sampled logits. 1941 all_b = embedding_ops.embedding_lookup( 1942 biases, all_ids, partition_strategy=partition_strategy) 1943 if all_b.dtype != inputs.dtype: 1944 all_b = math_ops.cast(all_b, inputs.dtype) 1945 # true_b is a [batch_size * num_true] tensor 1946 # sampled_b is a [num_sampled] float tensor 1947 true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat)) 1948 sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1]) 1949 1950 # inputs shape is [batch_size, dim] 1951 # true_w shape is [batch_size * num_true, dim] 1952 # row_wise_dots is [batch_size, num_true, dim] 1953 dim = array_ops.shape(true_w)[1:2] 1954 new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0) 1955 row_wise_dots = math_ops.multiply( 1956 array_ops.expand_dims(inputs, 1), 1957 array_ops.reshape(true_w, new_true_w_shape)) 1958 # We want the row-wise dot plus biases which yields a 1959 # [batch_size, num_true] tensor of true_logits. 1960 dots_as_matrix = array_ops.reshape(row_wise_dots, 1961 array_ops.concat([[-1], dim], 0)) 1962 true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true]) 1963 true_b = array_ops.reshape(true_b, [-1, num_true]) 1964 true_logits += true_b 1965 sampled_logits += sampled_b 1966 1967 if remove_accidental_hits: 1968 acc_hits = candidate_sampling_ops.compute_accidental_hits( 1969 labels, sampled, num_true=num_true) 1970 acc_indices, acc_ids, acc_weights = acc_hits 1971 1972 # This is how SparseToDense expects the indices. 
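      # Each accidental hit carries a large negative weight, so adding the
      # dense matrix built below to sampled_logits effectively removes
      # sampled classes that coincide with one of the true labels.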
1973 acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1]) 1974 acc_ids_2d_int32 = array_ops.reshape( 1975 math_ops.cast(acc_ids, dtypes.int32), [-1, 1]) 1976 sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1, 1977 "sparse_indices") 1978 # Create sampled_logits_shape = [batch_size, num_sampled] 1979 sampled_logits_shape = array_ops.concat( 1980 [array_ops.shape(labels)[:1], 1981 array_ops.expand_dims(num_sampled, 0)], 0) 1982 if sampled_logits.dtype != acc_weights.dtype: 1983 acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype) 1984 sampled_logits += gen_sparse_ops.sparse_to_dense( 1985 sparse_indices, 1986 sampled_logits_shape, 1987 acc_weights, 1988 default_value=0.0, 1989 validate_indices=False) 1990 1991 if subtract_log_q: 1992 # Subtract log of Q(l), prior probability that l appears in sampled. 1993 true_logits -= math_ops.log(true_expected_count) 1994 sampled_logits -= math_ops.log(sampled_expected_count) 1995 1996 # Construct output logits and labels. The true labels/logits start at col 0. 1997 out_logits = array_ops.concat([true_logits, sampled_logits], 1) 1998 1999 # true_logits is a float tensor, ones_like(true_logits) is a float 2000 # tensor of ones. We then divide by num_true to ensure the per-example 2001 # labels sum to 1.0, i.e. form a proper probability distribution. 2002 out_labels = array_ops.concat([ 2003 array_ops.ones_like(true_logits) / num_true, 2004 array_ops.zeros_like(sampled_logits) 2005 ], 1) 2006 2007 return out_logits, out_labels 2008 2009 2010@tf_export("nn.nce_loss", v1=[]) 2011@dispatch.add_dispatch_support 2012def nce_loss_v2(weights, 2013 biases, 2014 labels, 2015 inputs, 2016 num_sampled, 2017 num_classes, 2018 num_true=1, 2019 sampled_values=None, 2020 remove_accidental_hits=False, 2021 name="nce_loss"): 2022 """Computes and returns the noise-contrastive estimation training loss. 2023 2024 See [Noise-contrastive estimation: A new estimation principle for 2025 unnormalized statistical 2026 models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf). 2027 Also see our [Candidate Sampling Algorithms 2028 Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf) 2029 2030 A common use case is to use this method for training, and calculate the full 2031 sigmoid loss for evaluation or inference as in the following example: 2032 2033 ```python 2034 if mode == "train": 2035 loss = tf.nn.nce_loss( 2036 weights=weights, 2037 biases=biases, 2038 labels=labels, 2039 inputs=inputs, 2040 ...) 2041 elif mode == "eval": 2042 logits = tf.matmul(inputs, tf.transpose(weights)) 2043 logits = tf.nn.bias_add(logits, biases) 2044 labels_one_hot = tf.one_hot(labels, n_classes) 2045 loss = tf.nn.sigmoid_cross_entropy_with_logits( 2046 labels=labels_one_hot, 2047 logits=logits) 2048 loss = tf.reduce_sum(loss, axis=1) 2049 ``` 2050 2051 Note: when doing embedding lookup on `weights` and `bias`, "div" partition 2052 strategy will be used. Support for other partition strategy will be added 2053 later. 2054 2055 Note: By default this uses a log-uniform (Zipfian) distribution for sampling, 2056 so your labels must be sorted in order of decreasing frequency to achieve 2057 good results. For more details, see 2058 `tf.random.log_uniform_candidate_sampler`. 2059 2060 Note: In the case where `num_true` > 1, we assign to each target class 2061 the target probability 1 / `num_true` so that the target probabilities 2062 sum to 1 per-example. 
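  For example, with `num_true=2`, each of the two target classes for an
  example is assigned a target probability of 0.5.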
2063 2064 Note: It would be useful to allow a variable number of target classes per 2065 example. We hope to provide this functionality in a future release. 2066 For now, if you have a variable number of target classes, you can pad them 2067 out to a constant number by either repeating them or by padding 2068 with an otherwise unused class. 2069 2070 Args: 2071 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 2072 objects whose concatenation along dimension 0 has shape [num_classes, 2073 dim]. The (possibly-partitioned) class embeddings. 2074 biases: A `Tensor` of shape `[num_classes]`. The class biases. 2075 labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The 2076 target classes. 2077 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of 2078 the input network. 2079 num_sampled: An `int`. The number of negative classes to randomly sample 2080 per batch. This single sample of negative classes is evaluated for each 2081 element in the batch. 2082 num_classes: An `int`. The number of possible classes. 2083 num_true: An `int`. The number of target classes per training example. 2084 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 2085 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 2086 (if None, we default to `log_uniform_candidate_sampler`) 2087 remove_accidental_hits: A `bool`. Whether to remove "accidental hits" 2088 where a sampled class equals one of the target classes. If set to `True`, 2089 this is a "Sampled Logistic" loss instead of NCE, and we are learning to 2090 generate log-odds instead of log probabilities. See our [Candidate 2091 Sampling Algorithms Reference] 2092 (https://www.tensorflow.org/extras/candidate_sampling.pdf). Default is 2093 False. 2094 name: A name for the operation (optional). 2095 2096 Returns: 2097 A `batch_size` 1-D tensor of per-example NCE losses. 2098 """ 2099 # TODO(yuefengz): get partition_strategy from either variables or distribution 2100 # strategies. 2101 return nce_loss( 2102 weights, 2103 biases, 2104 labels, 2105 inputs, 2106 num_sampled, 2107 num_classes, 2108 num_true=num_true, 2109 sampled_values=sampled_values, 2110 remove_accidental_hits=remove_accidental_hits, 2111 partition_strategy="div", 2112 name=name) 2113 2114 2115@tf_export(v1=["nn.nce_loss"]) 2116@dispatch.add_dispatch_support 2117def nce_loss(weights, 2118 biases, 2119 labels, 2120 inputs, 2121 num_sampled, 2122 num_classes, 2123 num_true=1, 2124 sampled_values=None, 2125 remove_accidental_hits=False, 2126 partition_strategy="mod", 2127 name="nce_loss"): 2128 """Computes and returns the noise-contrastive estimation training loss. 2129 2130 A common use case is to use this method for training, and calculate the full 2131 sigmoid loss for evaluation or inference. 
In this case, you must set 2132 `partition_strategy="div"` for the two losses to be consistent, as in the 2133 following example: 2134 2135 ```python 2136 if mode == "train": 2137 loss = tf.nn.nce_loss( 2138 weights=weights, 2139 biases=biases, 2140 labels=labels, 2141 inputs=inputs, 2142 ..., 2143 partition_strategy="div") 2144 elif mode == "eval": 2145 logits = tf.matmul(inputs, tf.transpose(weights)) 2146 logits = tf.nn.bias_add(logits, biases) 2147 labels_one_hot = tf.one_hot(labels, n_classes) 2148 loss = tf.nn.sigmoid_cross_entropy_with_logits( 2149 labels=labels_one_hot, 2150 logits=logits) 2151 loss = tf.reduce_sum(loss, axis=1) 2152 ``` 2153 2154 Note: By default this uses a log-uniform (Zipfian) distribution for sampling, 2155 so your labels must be sorted in order of decreasing frequency to achieve 2156 good results. For more details, see 2157 `tf.random.log_uniform_candidate_sampler`. 2158 2159 Note: In the case where `num_true` > 1, we assign to each target class 2160 the target probability 1 / `num_true` so that the target probabilities 2161 sum to 1 per-example. 2162 2163 Note: It would be useful to allow a variable number of target classes per 2164 example. We hope to provide this functionality in a future release. 2165 For now, if you have a variable number of target classes, you can pad them 2166 out to a constant number by either repeating them or by padding 2167 with an otherwise unused class. 2168 2169 Args: 2170 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 2171 objects whose concatenation along dimension 0 has shape 2172 [num_classes, dim]. The (possibly-partitioned) class embeddings. 2173 biases: A `Tensor` of shape `[num_classes]`. The class biases. 2174 labels: A `Tensor` of type `int64` and shape `[batch_size, 2175 num_true]`. The target classes. 2176 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward 2177 activations of the input network. 2178 num_sampled: An `int`. The number of negative classes to randomly sample 2179 per batch. This single sample of negative classes is evaluated for each 2180 element in the batch. 2181 num_classes: An `int`. The number of possible classes. 2182 num_true: An `int`. The number of target classes per training example. 2183 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 2184 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 2185 (if None, we default to `log_uniform_candidate_sampler`) 2186 remove_accidental_hits: A `bool`. Whether to remove "accidental hits" 2187 where a sampled class equals one of the target classes. If set to 2188 `True`, this is a "Sampled Logistic" loss instead of NCE, and we are 2189 learning to generate log-odds instead of log probabilities. See 2190 our Candidate Sampling Algorithms Reference 2191 ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)). 2192 Default is False. 2193 partition_strategy: A string specifying the partitioning strategy, relevant 2194 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. 2195 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. 2196 name: A name for the operation (optional). 2197 2198 Returns: 2199 A `batch_size` 1-D tensor of per-example NCE losses. 
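    A typical training setup reduces these per-example losses to a scalar
    objective, e.g. `loss = tf.reduce_mean(per_example_loss)` (illustrative).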
2200 2201 References: 2202 Noise-contrastive estimation - A new estimation principle for unnormalized 2203 statistical models: 2204 [Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a) 2205 ([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf)) 2206 """ 2207 logits, labels = _compute_sampled_logits( 2208 weights=weights, 2209 biases=biases, 2210 labels=labels, 2211 inputs=inputs, 2212 num_sampled=num_sampled, 2213 num_classes=num_classes, 2214 num_true=num_true, 2215 sampled_values=sampled_values, 2216 subtract_log_q=True, 2217 remove_accidental_hits=remove_accidental_hits, 2218 partition_strategy=partition_strategy, 2219 name=name) 2220 sampled_losses = sigmoid_cross_entropy_with_logits( 2221 labels=labels, logits=logits, name="sampled_losses") 2222 # sampled_losses is batch_size x {true_loss, sampled_losses...} 2223 # We sum out true and sampled losses. 2224 return _sum_rows(sampled_losses) 2225 2226 2227@tf_export("nn.sampled_softmax_loss", v1=[]) 2228@dispatch.add_dispatch_support 2229def sampled_softmax_loss_v2(weights, 2230 biases, 2231 labels, 2232 inputs, 2233 num_sampled, 2234 num_classes, 2235 num_true=1, 2236 sampled_values=None, 2237 remove_accidental_hits=True, 2238 seed=None, 2239 name="sampled_softmax_loss"): 2240 """Computes and returns the sampled softmax training loss. 2241 2242 This is a faster way to train a softmax classifier over a huge number of 2243 classes. 2244 2245 This operation is for training only. It is generally an underestimate of 2246 the full softmax loss. 2247 2248 A common use case is to use this method for training, and calculate the full 2249 softmax loss for evaluation or inference as in the following example: 2250 2251 ```python 2252 if mode == "train": 2253 loss = tf.nn.sampled_softmax_loss( 2254 weights=weights, 2255 biases=biases, 2256 labels=labels, 2257 inputs=inputs, 2258 ...) 2259 elif mode == "eval": 2260 logits = tf.matmul(inputs, tf.transpose(weights)) 2261 logits = tf.nn.bias_add(logits, biases) 2262 labels_one_hot = tf.one_hot(labels, n_classes) 2263 loss = tf.nn.softmax_cross_entropy_with_logits( 2264 labels=labels_one_hot, 2265 logits=logits) 2266 ``` 2267 2268 See our [Candidate Sampling Algorithms Reference] 2269 (https://www.tensorflow.org/extras/candidate_sampling.pdf) 2270 2271 Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007) 2272 ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math. 2273 2274 Note: when doing embedding lookup on `weights` and `bias`, "div" partition 2275 strategy will be used. Support for other partition strategy will be added 2276 later. 2277 2278 Args: 2279 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 2280 objects whose concatenation along dimension 0 has shape [num_classes, 2281 dim]. The (possibly-sharded) class embeddings. 2282 biases: A `Tensor` of shape `[num_classes]`. The class biases. 2283 labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The 2284 target classes. Note that this format differs from the `labels` argument 2285 of `nn.softmax_cross_entropy_with_logits`. 2286 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of 2287 the input network. 2288 num_sampled: An `int`. The number of classes to randomly sample per batch. 2289 num_classes: An `int`. The number of possible classes. 2290 num_true: An `int`. The number of target classes per training example. 
2291 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 2292 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 2293 (if None, we default to `log_uniform_candidate_sampler`) 2294 remove_accidental_hits: A `bool`. whether to remove "accidental hits" 2295 where a sampled class equals one of the target classes. Default is True. 2296 seed: random seed for candidate sampling. Default to None, which doesn't set 2297 the op-level random seed for candidate sampling. 2298 name: A name for the operation (optional). 2299 2300 Returns: 2301 A `batch_size` 1-D tensor of per-example sampled softmax losses. 2302 2303 """ 2304 return sampled_softmax_loss( 2305 weights, 2306 biases, 2307 labels, 2308 inputs, 2309 num_sampled, 2310 num_classes, 2311 num_true=num_true, 2312 sampled_values=sampled_values, 2313 remove_accidental_hits=remove_accidental_hits, 2314 partition_strategy="div", 2315 name=name, 2316 seed=seed) 2317 2318 2319@tf_export(v1=["nn.sampled_softmax_loss"]) 2320@dispatch.add_dispatch_support 2321def sampled_softmax_loss(weights, 2322 biases, 2323 labels, 2324 inputs, 2325 num_sampled, 2326 num_classes, 2327 num_true=1, 2328 sampled_values=None, 2329 remove_accidental_hits=True, 2330 partition_strategy="mod", 2331 name="sampled_softmax_loss", 2332 seed=None): 2333 """Computes and returns the sampled softmax training loss. 2334 2335 This is a faster way to train a softmax classifier over a huge number of 2336 classes. 2337 2338 This operation is for training only. It is generally an underestimate of 2339 the full softmax loss. 2340 2341 A common use case is to use this method for training, and calculate the full 2342 softmax loss for evaluation or inference. In this case, you must set 2343 `partition_strategy="div"` for the two losses to be consistent, as in the 2344 following example: 2345 2346 ```python 2347 if mode == "train": 2348 loss = tf.nn.sampled_softmax_loss( 2349 weights=weights, 2350 biases=biases, 2351 labels=labels, 2352 inputs=inputs, 2353 ..., 2354 partition_strategy="div") 2355 elif mode == "eval": 2356 logits = tf.matmul(inputs, tf.transpose(weights)) 2357 logits = tf.nn.bias_add(logits, biases) 2358 labels_one_hot = tf.one_hot(labels, n_classes) 2359 loss = tf.nn.softmax_cross_entropy_with_logits( 2360 labels=labels_one_hot, 2361 logits=logits) 2362 ``` 2363 2364 See our Candidate Sampling Algorithms Reference 2365 ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)). 2366 Also see Section 3 of (Jean et al., 2014) for the math. 2367 2368 Args: 2369 weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` 2370 objects whose concatenation along dimension 0 has shape 2371 [num_classes, dim]. The (possibly-sharded) class embeddings. 2372 biases: A `Tensor` of shape `[num_classes]`. The class biases. 2373 labels: A `Tensor` of type `int64` and shape `[batch_size, 2374 num_true]`. The target classes. Note that this format differs from 2375 the `labels` argument of `nn.softmax_cross_entropy_with_logits`. 2376 inputs: A `Tensor` of shape `[batch_size, dim]`. The forward 2377 activations of the input network. 2378 num_sampled: An `int`. The number of classes to randomly sample per batch. 2379 num_classes: An `int`. The number of possible classes. 2380 num_true: An `int`. The number of target classes per training example. 2381 sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, 2382 `sampled_expected_count`) returned by a `*_candidate_sampler` function. 
2383 (if None, we default to `log_uniform_candidate_sampler`) 2384 remove_accidental_hits: A `bool`. whether to remove "accidental hits" 2385 where a sampled class equals one of the target classes. Default is 2386 True. 2387 partition_strategy: A string specifying the partitioning strategy, relevant 2388 if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported. 2389 Default is `"mod"`. See `tf.nn.embedding_lookup` for more details. 2390 name: A name for the operation (optional). 2391 seed: random seed for candidate sampling. Default to None, which doesn't set 2392 the op-level random seed for candidate sampling. 2393 2394 Returns: 2395 A `batch_size` 1-D tensor of per-example sampled softmax losses. 2396 2397 References: 2398 On Using Very Large Target Vocabulary for Neural Machine Translation: 2399 [Jean et al., 2014] 2400 (https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001) 2401 ([pdf](http://aclweb.org/anthology/P15-1001)) 2402 """ 2403 logits, labels = _compute_sampled_logits( 2404 weights=weights, 2405 biases=biases, 2406 labels=labels, 2407 inputs=inputs, 2408 num_sampled=num_sampled, 2409 num_classes=num_classes, 2410 num_true=num_true, 2411 sampled_values=sampled_values, 2412 subtract_log_q=True, 2413 remove_accidental_hits=remove_accidental_hits, 2414 partition_strategy=partition_strategy, 2415 name=name, 2416 seed=seed) 2417 labels = array_ops.stop_gradient(labels, name="labels_stop_gradient") 2418 sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2( 2419 labels=labels, logits=logits) 2420 # sampled_losses is a [batch_size] tensor. 2421 return sampled_losses 2422
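

# A minimal usage sketch for the sampled losses above (illustrative only: the
# tensor names below are hypothetical and `tf` refers to the public TensorFlow
# API, e.g. `import tensorflow as tf`):
#
#   weights = tf.Variable(tf.random.normal([num_classes, dim]))
#   biases = tf.Variable(tf.zeros([num_classes]))
#   per_example_loss = tf.nn.sampled_softmax_loss(
#       weights=weights, biases=biases, labels=labels, inputs=inputs,
#       num_sampled=64, num_classes=num_classes)
#   loss = tf.reduce_mean(per_example_loss)
#
# For evaluation or inference, compute the full softmax loss instead (see the
# docstring above), since the sampled loss is only a training-time
# approximation.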