# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Implementation of Neural Net (NN) functions."""

import math

from tensorflow.python.distribute import distribution_strategy_context as ds
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.platform import device_context
from tensorflow.python.util import dispatch
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export


@tf_export("nn.log_poisson_loss")
@dispatch.add_dispatch_support
def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
  """Computes log Poisson loss given `log_input`.

  Gives the log-likelihood loss between the prediction and the target under the
  assumption that the target has a Poisson distribution.
  Caveat: By default, this is not the exact loss, but the loss minus a
    constant term [log(z!)]. That has no effect for optimization, but
    does not play well with relative loss comparisons. To compute an
    approximation of the log factorial term, specify
    compute_full_loss=True to enable Stirling's Approximation.

  For brevity, let `c = log(x) = log_input`, `z = targets`.  The log Poisson
  loss is

        -log(exp(-x) * (x^z) / z!)
      = -log(exp(-x) * (x^z)) + log(z!)
      ~ -log(exp(-x)) - log(x^z) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
          [ Note the second term is the Stirling's Approximation for log(z!).
            It is invariant to x and does not affect optimization, though
            important for correct relative loss comparisons. It is only
            computed when compute_full_loss == True. ]
      = x - z * log(x) [+ z * log(z) - z + 0.5 * log(2 * pi * z)]
      = exp(c) - z * c [+ z * log(z) - z + 0.5 * log(2 * pi * z)]

  Args:
    targets: A `Tensor` of the same type and shape as `log_input`.
    log_input: A `Tensor` of type `float32` or `float64`.
    compute_full_loss: whether to compute the full loss. If false, a constant
      term is dropped in favor of more efficient optimization.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `log_input` with the componentwise
    log Poisson losses.

  Raises:
    ValueError: If `log_input` and `targets` do not have the same shape.
  """
  with ops.name_scope(name, "log_poisson_loss", [log_input, targets]) as name:
    log_input = ops.convert_to_tensor(log_input, name="log_input")
    targets = ops.convert_to_tensor(targets, name="targets")
    try:
      targets.get_shape().assert_is_compatible_with(log_input.get_shape())
    except ValueError:
      raise ValueError(
          "`log_input` and `targets` must have the same shape, received "
          f"({log_input.get_shape()} vs {targets.get_shape()}).")

    result = math_ops.exp(log_input) - log_input * targets
    if compute_full_loss:
      # need to create constant tensors here so that their dtypes can be matched
      # to that of the targets.
      point_five = constant_op.constant(0.5, dtype=targets.dtype)
      two_pi = constant_op.constant(2 * math.pi, dtype=targets.dtype)

      stirling_approx = (targets * math_ops.log(targets)) - targets + (
          point_five * math_ops.log(two_pi * targets))
      zeros = array_ops.zeros_like(targets, dtype=targets.dtype)
      ones = array_ops.ones_like(targets, dtype=targets.dtype)
      cond = math_ops.logical_and(targets >= zeros, targets <= ones)
      result += array_ops.where(cond, zeros, stirling_approx)
    return result


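# Illustrative usage sketch (not part of the original module): a minimal,
# hypothetical helper showing how `log_poisson_loss` above might be called
# eagerly. The values are arbitrary and chosen only for demonstration.
def _log_poisson_loss_example():  # Hypothetical helper; never called here.
  targets = constant_op.constant([1.0, 2.0, 3.0])
  log_input = constant_op.constant([0.5, 0.5, 0.5])
  # Default: the constant log(z!) term is dropped, which is fine for
  # optimization but not for comparing loss magnitudes across targets.
  partial = log_poisson_loss(targets, log_input)
  # With compute_full_loss=True, Stirling's approximation of log(z!) is added
  # for targets outside [0, 1], making relative loss comparisons meaningful.
  full = log_poisson_loss(targets, log_input, compute_full_loss=True)
  return partial, full

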
@tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits(  # pylint: disable=invalid-name
    _sentinel=None,
    labels=None,
    logits=None,
    name=None):
  """See sigmoid_cross_entropy_with_logits_v2."""
  # pylint: disable=protected-access
  nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits", _sentinel,
                           labels, logits)
  # pylint: enable=protected-access

  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    try:
      labels.get_shape().assert_is_compatible_with(logits.get_shape())
    except ValueError:
      raise ValueError("`logits` and `labels` must have the same shape, "
                       f"received ({logits.get_shape()} vs "
                       f"{labels.get_shape()}).")

    # The logistic loss formula from above is
    #   x - x * z + log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   -x * z + log(1 + exp(x))
    # Note that these two expressions can be combined into the following:
    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
    # To allow computing gradients at zero, we define custom versions of max and
    # abs functions.
    zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
    cond = (logits >= zeros)
    relu_logits = array_ops.where(cond, logits, zeros)
    neg_abs_logits = array_ops.where(cond, -logits, logits)  # pylint: disable=invalid-unary-operand-type
    return math_ops.add(
        relu_logits - logits * labels,
        math_ops.log1p(math_ops.exp(neg_abs_logits)),
        name=name)


# Note: intentionally calling this v2 to not allow existing code with indirect
# imports to ignore the sentinel behavior.
@tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[])
@dispatch.register_binary_elementwise_api
@dispatch.add_dispatch_support
def sigmoid_cross_entropy_with_logits_v2(  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    name=None):
  r"""Computes sigmoid cross entropy given `logits`.

  Measures the probability error in tasks with two outcomes in which each
  outcome is independent and need not have a fully certain label. For instance,
  one could perform a regression where the probability of an event happening is
  known and used as a label. This loss may also be used for binary
  classification, where labels are either zero or one.

  For brevity, let `x = logits`, `z = labels`.  The logistic loss is

        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

  For x < 0, to avoid overflow in exp(-x), we reformulate the above

        x - x * z + log(1 + exp(-x))
      = log(exp(x)) - x * z + log(1 + exp(-x))
      = - x * z + log(1 + exp(x))

  Hence, to ensure stability and avoid overflow, the implementation uses this
  equivalent formulation

      max(x, 0) - x * z + log(1 + exp(-abs(x)))

  `logits` and `labels` must have the same type and shape.

  >>> logits = tf.constant([1., -1., 0., 1., -1., 0., 0.])
  >>> labels = tf.constant([0., 0., 0., 1., 1., 1., 0.5])
  >>> tf.nn.sigmoid_cross_entropy_with_logits(
  ...     labels=labels, logits=logits).numpy()
  array([1.3132617, 0.3132617, 0.6931472, 0.3132617, 1.3132617, 0.6931472,
         0.6931472], dtype=float32)

  Compared to the losses which handle multiple outcomes,
  `tf.nn.softmax_cross_entropy_with_logits` for general multi-class
  classification and `tf.nn.sparse_softmax_cross_entropy_with_logits` for more
  efficient multi-class classification with hard labels,
  `sigmoid_cross_entropy_with_logits` is a slight simplification for binary
  classification:

        sigmoid(x) = softmax([x, 0])[0]

  $$\frac{1}{1 + e^{-x}} = \frac{e^x}{e^x + e^0}$$

  While `sigmoid_cross_entropy_with_logits` works for soft binary labels
  (probabilities between 0 and 1), it can also be used for binary classification
  where the labels are hard. There is an equivalence between all three losses
  in this case, with a probability 0 indicating the second class or 1 indicating
  the first class:

  >>> sigmoid_logits = tf.constant([1., -1., 0.])
  >>> softmax_logits = tf.stack([sigmoid_logits, tf.zeros_like(sigmoid_logits)],
  ...                           axis=-1)
  >>> soft_binary_labels = tf.constant([1., 1., 0.])
  >>> soft_multiclass_labels = tf.stack(
  ...     [soft_binary_labels, 1. - soft_binary_labels], axis=-1)
  >>> hard_labels = tf.constant([0, 0, 1])
  >>> tf.nn.sparse_softmax_cross_entropy_with_logits(
  ...     labels=hard_labels, logits=softmax_logits).numpy()
  array([0.31326166, 1.3132616 , 0.6931472 ], dtype=float32)
  >>> tf.nn.softmax_cross_entropy_with_logits(
  ...     labels=soft_multiclass_labels, logits=softmax_logits).numpy()
  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)
  >>> tf.nn.sigmoid_cross_entropy_with_logits(
  ...     labels=soft_binary_labels, logits=sigmoid_logits).numpy()
  array([0.31326166, 1.3132616, 0.6931472], dtype=float32)

  Args:
    labels: A `Tensor` of the same type and shape as `logits`. Between 0 and 1,
      inclusive.
    logits: A `Tensor` of type `float32` or `float64`. Any real number.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  return sigmoid_cross_entropy_with_logits(
      logits=logits, labels=labels, name=name)


sigmoid_cross_entropy_with_logits.__doc__ = (
    sigmoid_cross_entropy_with_logits_v2.__doc__)


@tf_export("nn.weighted_cross_entropy_with_logits", v1=[])
@dispatch.add_dispatch_support
def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight,
                                          name=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  >>> labels = tf.constant([1., 0.5, 0.])
  >>> logits = tf.constant([1.5, -0.1, -10.])
  >>> tf.nn.weighted_cross_entropy_with_logits(
  ...     labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
  array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
  >>> tf.nn.weighted_cross_entropy_with_logits(
  ...     labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
  array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)

  Args:
    labels: A `Tensor` of the same type and shape as `logits`, with values
      between 0 and 1 inclusive.
    logits: A `Tensor` of type `float32` or `float64`, any real numbers.
    pos_weight: A coefficient to use on the positive examples, typically a
      scalar but otherwise broadcastable to the shape of `logits`. Its value
      should be non-negative.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    labels = ops.convert_to_tensor(labels, name="labels")
    try:
      labels.get_shape().assert_is_compatible_with(logits.get_shape())
    except ValueError:
      raise ValueError("`logits` and `labels` must have the same shape, "
                       f"received ({logits.get_shape()} vs "
                       f"{labels.get_shape()}).")

    # The logistic loss formula from above is
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(x)) - l * x
    # To avoid branching, we use the combined version
    #   (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
    log_weight = 1 + (pos_weight - 1) * labels
    return math_ops.add(
        (1 - labels) * logits,
        log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
                      nn_ops.relu(-logits)),  # pylint: disable=invalid-unary-operand-type
        name=name)


@tf_export(v1=["nn.weighted_cross_entropy_with_logits"])
@dispatch.add_dispatch_support
@deprecated_args(None, "targets is deprecated, use labels instead", "targets")
def weighted_cross_entropy_with_logits(labels=None,
                                       logits=None,
                                       pos_weight=None,
                                       name=None,
                                       targets=None):
  """Computes a weighted cross entropy.

  This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`
  allows one to trade off recall and precision by up- or down-weighting the
  cost of a positive error relative to a negative error.

  The usual cross-entropy cost is defined as:

      labels * -log(sigmoid(logits)) +
          (1 - labels) * -log(1 - sigmoid(logits))

  A value `pos_weight > 1` decreases the false negative count, hence increasing
  the recall.
  Conversely setting `pos_weight < 1` decreases the false positive count and
  increases the precision.
  This can be seen from the fact that `pos_weight` is introduced as a
  multiplicative coefficient for the positive labels term
  in the loss expression:

      labels * -log(sigmoid(logits)) * pos_weight +
          (1 - labels) * -log(1 - sigmoid(logits))

  For brevity, let `x = logits`, `z = labels`, `q = pos_weight`.
  The loss is:

        qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
      = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
      = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))

  Setting `l = (1 + (q - 1) * z)`, to ensure stability and avoid overflow,
  the implementation uses

      (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))

  `logits` and `labels` must have the same type and shape.

  Args:
    labels: A `Tensor` of the same type and shape as `logits`.
    logits: A `Tensor` of type `float32` or `float64`.
    pos_weight: A coefficient to use on the positive examples.
    name: A name for the operation (optional).
    targets: Deprecated alias for labels.

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    weighted logistic losses.

  Raises:
    ValueError: If `logits` and `labels` do not have the same shape.
  """
  labels = deprecated_argument_lookup("labels", labels, "targets", targets)
  return weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name)


@tf_export("nn.compute_average_loss")
@dispatch.add_dispatch_support
def compute_average_loss(per_example_loss,
                         sample_weight=None,
                         global_batch_size=None):
  """Scales per-example losses with sample_weights and computes their average.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(labels, predictions, sample_weight=None):

      # If you are using a `Loss` class instead, set reduction to `NONE` so that
      # we can do the reduction afterwards and divide by global batch size.
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      return tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)
  ```

  Args:
    per_example_loss: Per-example loss.
    sample_weight: Optional weighting for each example.
    global_batch_size: Optional global batch size value. Defaults to (size of
      first dimension of `per_example_loss`) * (number of replicas).

  Returns:
    Scalar loss value.
  """  # pylint: disable=g-doc-exception
  per_example_loss = ops.convert_to_tensor(per_example_loss)
  input_dtype = per_example_loss.dtype

  with losses_util.check_per_example_loss_rank(per_example_loss):
    if sample_weight is not None:
      sample_weight = ops.convert_to_tensor(sample_weight)
      per_example_loss = losses_util.scale_losses_by_sample_weight(
          per_example_loss, sample_weight)
    per_example_loss = math_ops.cast(per_example_loss, input_dtype)

    if global_batch_size is None:
      if ds.has_strategy() and ds.in_cross_replica_context():
        raise RuntimeError(
            "You are calling `compute_average_loss` in cross replica context, "
            "while it was expected to be called in replica context.")

      num_replicas = ds.get_strategy().num_replicas_in_sync
      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
      global_batch_size = per_replica_batch_size * num_replicas

    check_ops.assert_scalar_v2(
        global_batch_size, message="global_batch_size must be scalar.")
    check_ops.assert_integer_v2(
        global_batch_size,
        message="global_batch_size must be an integer.")
    check_ops.assert_positive_v2(
        global_batch_size, message="global_batch_size must be positive.")

    global_batch_size = math_ops.cast(global_batch_size, input_dtype)
    return math_ops.reduce_sum(per_example_loss) / global_batch_size


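# Illustrative usage sketch (not part of the original module): a hypothetical
# helper showing how the divisor in `compute_average_loss` above is chosen.
# The values are arbitrary.
def _compute_average_loss_example():  # Hypothetical helper; never called here.
  per_example_loss = constant_op.constant([2.0, 4.0, 6.0, 8.0])
  # Without `global_batch_size`, the divisor is (local batch size) * (number of
  # replicas in sync); under the default (no) strategy that is simply 4 here.
  default_scaled = compute_average_loss(per_example_loss)
  # Passing an explicit global batch size divides by that value instead, which
  # is what a multi-replica custom training loop typically wants.
  explicit_scaled = compute_average_loss(per_example_loss, global_batch_size=8)
  return default_scaled, explicit_scaled

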
@tf_export("nn.scale_regularization_loss")
@dispatch.add_dispatch_support
def scale_regularization_loss(regularization_loss):
  """Scales the sum of the given regularization losses by number of replicas.

  Usage with distribution strategy and custom training loop:

  ```python
  with strategy.scope():
    def compute_loss(self, labels, predictions, sample_weight=None):
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, predictions)

      # Compute loss that is scaled by sample_weight and by global batch size.
      loss = tf.nn.compute_average_loss(
          per_example_loss,
          sample_weight=sample_weight,
          global_batch_size=GLOBAL_BATCH_SIZE)

      # Add scaled regularization losses.
      loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
      return loss
  ```

  Args:
    regularization_loss: Regularization loss.

  Returns:
    Scalar loss value.
  """  # pylint: disable=g-doc-exception
  if ds.has_strategy() and ds.in_cross_replica_context():
    raise RuntimeError(
        "You are calling `scale_regularization_loss` in cross replica context, "
        "while it was expected to be called in replica context.")

  num_replicas = ds.get_strategy().num_replicas_in_sync
  return math_ops.reduce_sum(regularization_loss) / num_replicas


@tf_export(v1=["nn.relu_layer"])
@dispatch.add_dispatch_support
def relu_layer(x, weights, biases, name=None):
517  """Computes Relu(x * weight + biases).

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "nn_relu_layer" is used.

  Returns:
    A 2-D Tensor computing relu(matmul(x, weights) + biases).
    Dimensions typically: batch, out_units.
  """
  with ops.name_scope(name, "relu_layer", [x, weights, biases]) as name:
    x = ops.convert_to_tensor(x, name="x")
    weights = ops.convert_to_tensor(weights, name="weights")
    biases = ops.convert_to_tensor(biases, name="biases")
    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
    return nn_ops.relu(xw_plus_b, name=name)


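# Illustrative usage sketch (not part of the original module): `relu_layer`
# above is just relu(matmul(x, weights) + biases); the shapes and values below
# are arbitrary and only for demonstration.
def _relu_layer_example():  # Hypothetical helper; never called here.
  x = constant_op.constant([[1.0, -2.0]])            # [batch=1, in_units=2]
  weights = constant_op.constant([[1.0, 0.0, -1.0],
                                  [0.0, 1.0, -1.0]])  # [in_units=2, out_units=3]
  biases = constant_op.constant([0.0, 0.5, 0.0])      # [out_units=3]
  # Equivalent to nn_ops.relu(nn_ops.bias_add(math_ops.matmul(x, weights),
  # biases)); the result has shape [batch, out_units] = [1, 3].
  return relu_layer(x, weights, biases)

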
@tf_export("nn.silu", "nn.swish")
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def swish(features, beta=1.0):
  # pylint: disable=g-doc-args
  """Computes the SiLU or Swish activation function: `x * sigmoid(beta * x)`.

  beta : Hyperparameter for Swish activation function. Default value 1.0.

  The SiLU activation function was introduced in "Gaussian Error Linear Units
  (GELUs)" [Hendrycks et al. 2016](https://arxiv.org/abs/1606.08415) and
  "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in
  Reinforcement Learning"
  [Elfwing et al. 2017](https://arxiv.org/abs/1702.03118) and was independently
  discovered (and called swish) in "Searching for Activation Functions"
  [Ramachandran et al. 2017](https://arxiv.org/abs/1710.05941)

  Args:
    features: A `Tensor` representing preactivation values.
    beta: A `Tensor` representing the value of the beta hyperparameter.

  Returns:
    The activation value.
  """
  # pylint: enable=g-doc-args
  features = ops.convert_to_tensor(features, name="features")
  beta = ops.convert_to_tensor(beta, name="beta")
  beta = math_ops.cast(beta, features.dtype)

  @custom_gradient.custom_gradient
  def swish_impl(features):

    def grad(dy):
      """Gradient for the Swish activation function."""
      # Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x)
      # around for backprop, effectively doubling the tensor's memory
      # consumption. We use a control dependency here so that sigmoid(features)
      # is re-computed during backprop (the control dep prevents it being
      # de-duped with the forward pass) and we can free the sigmoid(features)
      # expression immediately after use during the forward pass.
      with ops.control_dependencies([dy]):
        sigmoid_features = math_ops.sigmoid(beta * features)
      activation_grad = (
          sigmoid_features * (1.0 + (beta * features) *
                              (1.0 - sigmoid_features)))
      return dy * activation_grad

    return features * math_ops.sigmoid(beta * features), grad

  return swish_impl(features)


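# Illustrative usage sketch (not part of the original module): SiLU/Swish with
# the default and a non-default `beta`. With beta=1 this is x * sigmoid(x).
# Values below are arbitrary.
def _swish_example():  # Hypothetical helper; never called here.
  features = constant_op.constant([-1.0, 0.0, 1.0])
  silu = swish(features)               # x * sigmoid(x)
  sharper = swish(features, beta=4.0)  # a larger beta pushes toward relu(x)
  return silu, sharper

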
# pylint: disable=redefined-builtin
@tf_export("linalg.normalize")
@dispatch.add_dispatch_support
def normalize(tensor, ord="euclidean", axis=None, name=None):
  """Normalizes `tensor` along dimension `axis` using specified norm.

  This uses `tf.linalg.norm` to compute the norm along `axis`.

  This function can compute several different vector norms (the 1-norm, the
  Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
  matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).

  Args:
    tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
    ord: Order of the norm. Supported values are `'fro'`, `'euclidean'`, `1`,
      `2`, `np.inf` and any positive real number yielding the corresponding
      p-norm. Default is `'euclidean'` which is equivalent to Frobenius norm if
      `tensor` is a matrix and equivalent to 2-norm for vectors.
      Some restrictions apply: a) The Frobenius norm `'fro'` is not defined for
        vectors, b) If axis is a 2-tuple (matrix norm), only `'euclidean'`,
        `'fro'`, `1`, `2`, `np.inf` are supported. See the description of `axis`
        on how to compute norms for a batch of vectors or matrices stored in a
        tensor.
    axis: If `axis` is `None` (the default), the input is considered a vector
      and a single vector norm is computed over the entire set of values in the
      tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
      `norm(reshape(tensor, [-1]), ord=ord)`. If `axis` is a Python integer, the
      input is considered a batch of vectors, and `axis` determines the axis in
      `tensor` over which to compute vector norms. If `axis` is a 2-tuple of
      Python integers it is considered a batch of matrices and `axis` determines
      the axes in `tensor` over which to compute a matrix norm.
      Negative indices are supported. Example: If you are passing a tensor that
        can be either a matrix or a batch of matrices at runtime, pass
        `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
        computed.
    name: The name of the op.

  Returns:
    normalized: A normalized `Tensor` with the same shape as `tensor`.
    norm: The computed norms with the same shape and dtype as `tensor` but the
      final axis is 1 instead. Same as running
      `tf.cast(tf.linalg.norm(tensor, ord, axis, keepdims=True), tensor.dtype)`.

  Raises:
    ValueError: If `ord` or `axis` is invalid.
  """
  with ops.name_scope(name, "normalize", [tensor]) as name:
    tensor = ops.convert_to_tensor(tensor)
    norm = linalg_ops.norm(tensor, ord, axis, keepdims=True)
    norm = math_ops.cast(norm, tensor.dtype)
    normalized = tensor / norm
    return normalized, norm


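# Illustrative usage sketch (not part of the original module): `normalize`
# above returns both the normalized tensor and the norm it divided by. Values
# below are arbitrary.
def _normalize_example():  # Hypothetical helper; never called here.
  t = constant_op.constant([[3.0, 4.0]])
  # With the defaults, the Euclidean norm over the whole tensor is 5, so
  # `normalized` is [[0.6, 0.8]] and `norm` is [[5.0]] (keepdims shape).
  normalized, norm = normalize(t)
  # A matrix 1-norm instead, computed over the last two axes.
  normalized_1, norm_1 = normalize(t, ord=1, axis=[-2, -1])
  return normalized, norm, normalized_1, norm_1

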
@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize",
           v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@dispatch.add_dispatch_support
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
  """Normalizes along dimension `axis` using an L2 norm.

  For a 1-D tensor with `axis = 0`, computes

      output = x / sqrt(max(sum(x**2), epsilon))

  For `x` with more dimensions, independently normalizes each 1-D slice along
  dimension `axis`.

  1-D tensor example:
  >>> x = tf.constant([3.0, 4.0])
  >>> tf.math.l2_normalize(x).numpy()
  array([0.6, 0.8], dtype=float32)

  2-D tensor example:
  >>> x = tf.constant([[3.0], [4.0]])
  >>> tf.math.l2_normalize(x, 0).numpy()
  array([[0.6],
       [0.8]], dtype=float32)

  >>> x = tf.constant([[3.0], [4.0]])
  >>> tf.math.l2_normalize(x, 1).numpy()
  array([[1.],
       [1.]], dtype=float32)

  Args:
    x: A `Tensor`.
    axis: Dimension along which to normalize.  A scalar or a vector of
      integers.
    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
      divisor if `norm < sqrt(epsilon)`.
    name: A name for this operation (optional).
    dim: Deprecated, do not use.

  Returns:
    A `Tensor` with the same shape as `x`.
  """
  axis = deprecated_argument_lookup("axis", axis, "dim", dim)
  with ops.name_scope(name, "l2_normalize", [x]) as name:
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_complex:
      square_real = math_ops.square(math_ops.real(x))
      square_imag = math_ops.square(math_ops.imag(x))
      square_sum = math_ops.real(
          math_ops.reduce_sum(square_real + square_imag, axis, keepdims=True))
      x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
      norm_real = math_ops.multiply(math_ops.real(x), x_inv_norm)
      norm_imag = math_ops.multiply(math_ops.imag(x), x_inv_norm)
      return math_ops.complex(norm_real, norm_imag, name=name)
    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
    return math_ops.multiply(x, x_inv_norm, name=name)


def _count_nonzero(input_tensor, dtype=dtypes.int64):
  """Same as math_ops.count_nonzero.

  The reduction is done in dtype, which can be faster for 32-bit dtypes.

  Args:
      input_tensor: numeric tensor
      dtype: reduction dtype

  Returns:
      number of nonzero values with type dtype
  """
  with ops.name_scope("count_nonzero", values=[input_tensor]):
    zero = array_ops.zeros([], dtype=input_tensor.dtype)
    nonzero_count = math_ops.reduce_sum(
        math_ops.cast(
            math_ops.not_equal(input_tensor, zero),
            dtype=dtype), name="nonzero_count")
    return nonzero_count


@tf_export("math.zero_fraction", "nn.zero_fraction")
@dispatch.add_dispatch_support
def zero_fraction(value, name=None):
  """Returns the fraction of zeros in `value`.

  If `value` is empty, the result is `nan`.

  This is useful in summaries to measure and report sparsity.  For example,

  ```python
      z = tf.nn.relu(...)
      summ = tf.compat.v1.summary.scalar('sparsity', tf.nn.zero_fraction(z))
  ```

  Args:
    value: A tensor of numeric type.
    name: A name for the operation (optional).

  Returns:
    The fraction of zeros in `value`, with type `float32`.
  """
  with ops.name_scope(name, "zero_fraction", [value]):
    value = ops.convert_to_tensor(value, name="value")
    size = array_ops.size(value, out_type=dtypes.int64)
    # If the count is small, we can save memory/CPU with an int32 reduction.
    num_nonzero = control_flow_ops.cond(
        size <= dtypes.int32.max,
        # pylint: disable=g-long-lambda
        true_fn=lambda: math_ops.cast(
            _count_nonzero(value, dtype=dtypes.int32),
            dtype=dtypes.int64),
        false_fn=lambda: _count_nonzero(value, dtype=dtypes.int64))

    with ops.name_scope("counts_to_fraction"):
      num_zero = size - num_nonzero
      num_zero_float32 = math_ops.cast(num_zero, dtype=dtypes.float32)
      size_float32 = math_ops.cast(size, dtype=dtypes.float32)
      zero_fraction_float32 = num_zero_float32 / size_float32

    return array_ops.identity(zero_fraction_float32, "fraction")


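# Illustrative usage sketch (not part of the original module): `zero_fraction`
# as a sparsity measurement, e.g. on a relu activation. Values are arbitrary.
def _zero_fraction_example():  # Hypothetical helper; never called here.
  activations = nn_ops.relu(constant_op.constant([-2.0, -1.0, 0.0, 1.0]))
  # Three of the four values are zero after relu, so the fraction is 0.75.
  return zero_fraction(activations)

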
# pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"])
@dispatch.add_dispatch_support
def depthwise_conv2d(input,
                     filter,
                     strides,
                     padding,
                     rate=None,
                     name=None,
                     data_format=None,
                     dilations=None):
  """Depthwise 2-D convolution.

  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
  and a filter tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`
  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
  applies a different filter to each input channel (expanding from 1 channel
  to `channel_multiplier` channels for each), then concatenates the results
  together.  The output has `in_channels * channel_multiplier` channels.

  In detail, with the default NHWC format,

      output[b, i, j, k * channel_multiplier + q] = sum_{di, dj}
           filter[di, dj, k, q] * input[b, strides[1] * i + rate[0] * di,
                                           strides[2] * j + rate[1] * dj, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `rate` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Usage Example:

  >>> x = np.array([
  ...     [1., 2.],
  ...     [3., 4.],
  ...     [5., 6.]
  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
  >>> kernel = np.array([
  ...     [1., 2.],
  ...     [3., 4]
  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                                  padding='VALID').numpy()
    array([[[[10., 14.],
             [14., 20.]],
            [[18., 26.],
             [22., 32.]]]], dtype=float32)

  >>> tf.compat.v1.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                                  padding=[[0, 0], [1, 0], [1, 0], [0, 0]]
  ...                                 ).numpy()
    array([[[[ 0.,  0.],
             [ 3.,  4.],
             [ 6.,  8.]],
            [[ 0.,  0.],
             [10., 14.],
             [14., 20.]],
            [[ 0.,  0.],
             [18., 26.],
             [22., 32.]]]], dtype=float32)

  Args:
    input: 4-D with shape according to `data_format`.
    filter: 4-D with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4.  The stride of the sliding window for each
      dimension of `input`.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    rate: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: Alias of rate.

  Returns:
    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
    "NHWC" format, shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier].`
  """
  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
  with ops.name_scope(name, "depthwise", [input, filter]) as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    filter = ops.convert_to_tensor(filter, name="filter_in")
    if rate is None:
      rate = [1, 1]

    # Use depthwise_conv2d_native if executing on TPU.
    if device_context.enclosing_tpu_context() is not None:
      if data_format == "NCHW":
        dilations = [1, 1, rate[0], rate[1]]
      else:
        dilations = [1, rate[0], rate[1], 1]
      return nn_ops.depthwise_conv2d_native(
          input=input,
          filter=filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          dilations=dilations,
          name=name)

    def op(input_converted, _, padding):
      return nn_ops.depthwise_conv2d_native(
          input=input_converted,
          filter=filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          name=name)

    return nn_ops.with_space_to_batch(
        input=input,
        filter_shape=array_ops.shape(filter),
        dilation_rate=rate,
        padding=padding,
        data_format=data_format,
        op=op)


@tf_export("nn.depthwise_conv2d", v1=[])
@dispatch.add_dispatch_support
def depthwise_conv2d_v2(input,
                        filter,
                        strides,
                        padding,
                        data_format=None,
                        dilations=None,
                        name=None):
  """Depthwise 2-D convolution.

  Given a 4D input tensor ('NHWC' or 'NCHW' data formats)
  and a filter tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`
  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
  applies a different filter to each input channel (expanding from 1 channel
  to `channel_multiplier` channels for each), then concatenates the results
  together.  The output has `in_channels * channel_multiplier` channels.

  In detail, with the default NHWC format,

      output[b, i, j, k * channel_multiplier + q] =
          sum_{di, dj} filter[di, dj, k, q] *
                       input[b, strides[1] * i + dilations[0] * di,
                                strides[2] * j + dilations[1] * dj, k]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `dilations` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Usage Example:

  >>> x = np.array([
  ...     [1., 2.],
  ...     [3., 4.],
  ...     [5., 6.]
  ... ], dtype=np.float32).reshape((1, 3, 2, 1))
  >>> kernel = np.array([
  ...     [1., 2.],
  ...     [3., 4]
  ... ], dtype=np.float32).reshape((2, 1, 1, 2))
  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                        padding='VALID').numpy()
    array([[[[10., 14.],
             [14., 20.]],
            [[18., 26.],
             [22., 32.]]]], dtype=float32)

  >>> tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1],
  ...                        padding=[[0, 0], [1, 0], [1, 0], [0, 0]]).numpy()
    array([[[[ 0.,  0.],
             [ 3.,  4.],
             [ 6.,  8.]],
            [[ 0.,  0.],
             [10., 14.],
             [14., 20.]],
            [[ 0.,  0.],
             [18., 26.],
             [22., 32.]]]], dtype=float32)

  Args:
    input: 4-D with shape according to `data_format`.
    filter: 4-D with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4.  The stride of the sliding window for each
      dimension of `input`.
    padding: Controls how to pad the image before applying the convolution. Can
      be the string `"SAME"` or `"VALID"` indicating the type of padding
      algorithm to use, or a list indicating the explicit paddings at the start
      and end of each dimension. See
      [here](https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2)
      for more information. When explicit padding is used and data_format
      is `"NHWC"`, this should be in the form `[[0, 0], [pad_top, pad_bottom],
      [pad_left, pad_right], [0, 0]]`. When explicit padding used and
      data_format is `"NCHW"`, this should be in the form `[[0, 0], [0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` with shape according to `data_format`.  E.g., for
    "NHWC" format, shape is
    `[batch, out_height, out_width, in_channels * channel_multiplier].`
  """
  return depthwise_conv2d(input=input,
                          filter=filter,
                          strides=strides,
                          padding=padding,
                          rate=dilations,
                          name=name,
                          data_format=data_format)

# pylint: enable=redefined-builtin


# pylint: disable=redefined-builtin,line-too-long
@tf_export(v1=["nn.separable_conv2d"])
@dispatch.add_dispatch_support
def separable_conv2d(input,
                     depthwise_filter,
                     pointwise_filter,
                     strides,
                     padding,
                     rate=None,
                     name=None,
                     data_format=None,
                     dilations=None):
  """2-D convolution with separable filters.

  Performs a depthwise convolution that acts separately on channels followed by
  a pointwise convolution that mixes channels.  Note that this is separability
  between dimensions `[1, 2]` and `3`, not spatial separability between
  dimensions `1` and `2`.

  In detail, with the default NHWC format,

      output[b, i, j, k] = sum_{di, dj, q, r}
          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
          depthwise_filter[di, dj, q, r] *
          pointwise_filter[0, 0, q * channel_multiplier + r, k]

  `strides` controls the strides for the depthwise convolution only, since
  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
  `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `rate` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D `Tensor` with shape according to `data_format`.
    depthwise_filter: 4-D `Tensor` with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
      Contains `in_channels` convolutional filters of depth 1.
    pointwise_filter: 4-D `Tensor` with shape
      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
      filter to mix channels after `depthwise_filter` has convolved spatially.
    strides: 1-D of size 4.  The strides for the depthwise convolution for
      each dimension of `input`.
    padding: Controls how to pad the image before applying the depthwise
      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
      of padding algorithm to use, or a Python list indicating the explicit
      paddings at the start and end of each dimension. When explicit padding is
      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
      padding used and data_format is `"NCHW"`, this should be in the form
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    rate: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: Alias of rate.

  Returns:
    A 4-D `Tensor` with shape according to 'data_format'. For
      example, with data_format="NHWC", shape is [batch, out_height,
      out_width, out_channels].
  """
  rate = deprecated_argument_lookup("dilations", dilations, "rate", rate)
  with ops.name_scope(name, "separable_conv2d",
                      [input, depthwise_filter, pointwise_filter]) as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    depthwise_filter = ops.convert_to_tensor(
        depthwise_filter, name="depthwise_filter")
    pointwise_filter = ops.convert_to_tensor(
        pointwise_filter, name="pointwise_filter")

    pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
    pointwise_filter_shape.dims[0].assert_is_compatible_with(1)
    pointwise_filter_shape.dims[1].assert_is_compatible_with(1)

    if rate is None:
      rate = [1, 1]

    # The layout of the ops in the graph are expected to be as follows:
    # depthwise_conv2d  // Conv2D op corresponding to native depthwise conv.
    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.

    def op(input_converted, _, padding):
      return nn_ops.depthwise_conv2d_native(
          input=input_converted,
          filter=depthwise_filter,
          strides=strides,
          padding=padding,
          data_format=data_format,
          name="depthwise")

    depthwise = nn_ops.with_space_to_batch(
        input=input,
        filter_shape=array_ops.shape(depthwise_filter),
        dilation_rate=rate,
        padding=padding,
        data_format=data_format,
        op=op)

    return nn_ops.conv2d(
        depthwise,
        pointwise_filter, [1, 1, 1, 1],
        padding="VALID",
        data_format=data_format,
        name=name)


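# Illustrative usage sketch (not part of the original module): the shapes
# involved in a separable convolution. All names and values below are
# arbitrary, for demonstration only.
def _separable_conv2d_example():  # Hypothetical helper; never called here.
  batch, height, width, in_channels = 1, 4, 4, 2
  channel_multiplier, out_channels = 3, 5
  image = array_ops.ones([batch, height, width, in_channels])
  # Depthwise stage: one 2x2 filter of depth 1 per input channel, expanded by
  # `channel_multiplier`.
  depthwise_filter = array_ops.ones([2, 2, in_channels, channel_multiplier])
  # Pointwise stage: a 1x1 convolution mixing the
  # in_channels * channel_multiplier intermediate channels into out_channels.
  pointwise_filter = array_ops.ones(
      [1, 1, in_channels * channel_multiplier, out_channels])
  # With 'VALID' padding and unit strides the output shape is [1, 3, 3, 5].
  return separable_conv2d(image, depthwise_filter, pointwise_filter,
                          strides=[1, 1, 1, 1], padding="VALID")

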
@tf_export("nn.separable_conv2d", v1=[])
@dispatch.add_dispatch_support
def separable_conv2d_v2(
    input,
    depthwise_filter,
    pointwise_filter,
    strides,
    padding,
    data_format=None,
    dilations=None,
    name=None,
):
  """2-D convolution with separable filters.

  Performs a depthwise convolution that acts separately on channels followed by
  a pointwise convolution that mixes channels.  Note that this is separability
  between dimensions `[1, 2]` and `3`, not spatial separability between
  dimensions `1` and `2`.

  In detail, with the default NHWC format,

      output[b, i, j, k] = sum_{di, dj, q, r}
          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
          depthwise_filter[di, dj, q, r] *
          pointwise_filter[0, 0, q * channel_multiplier + r, k]

  `strides` controls the strides for the depthwise convolution only, since
  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
  `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
  If any value in `rate` is greater than 1, we perform atrous depthwise
  convolution, in which case all values in the `strides` tensor must be equal
  to 1.

  Args:
    input: 4-D `Tensor` with shape according to `data_format`.
    depthwise_filter: 4-D `Tensor` with shape `[filter_height, filter_width,
      in_channels, channel_multiplier]`. Contains `in_channels` convolutional
      filters of depth 1.
    pointwise_filter: 4-D `Tensor` with shape `[1, 1, channel_multiplier *
      in_channels, out_channels]`.  Pointwise filter to mix channels after
      `depthwise_filter` has convolved spatially.
    strides: 1-D of size 4.  The strides for the depthwise convolution for each
      dimension of `input`.
    padding: Controls how to pad the image before applying the depthwise
      convolution. Can be the string `"SAME"` or `"VALID"` indicating the type
      of padding algorithm to use, or a Python list indicating the explicit
      paddings at the start and end of each dimension. When explicit padding is
      used and data_format is `"NHWC"`, this should be in the form `[[0, 0],
      [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]`. When explicit
      padding used and data_format is `"NCHW"`, this should be in the form
      `[[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]`.
    data_format: The data format for input. Either "NHWC" (default) or "NCHW".
    dilations: 1-D of size 2. The dilation rate in which we sample input values
      across the `height` and `width` dimensions in atrous convolution. If it is
      greater than 1, then all values of strides must be 1.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` with shape according to 'data_format'. For
      example, with data_format="NHWC", shape is [batch, out_height,
      out_width, out_channels].
  """
  return separable_conv2d(
      input,
      depthwise_filter,
      pointwise_filter,
      strides,
      padding,
      rate=dilations,
      name=name,
      data_format=data_format)

# pylint: enable=redefined-builtin,line-too-long


@tf_export(v1=["nn.sufficient_statistics"])
@dispatch.add_dispatch_support
def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
                          keepdims=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted. See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  For example:
  >>> t = [[1, 2, 3], [4, 5, 6]]
  >>> sufficient_statistics(t, [1])
  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)
  >>> sufficient_statistics(t, [-1])
  (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([ 6, 15], dtype=int32)>, <tf.Tensor: shape=(2,),
  dtype=int32, numpy=array([14, 77], dtype=int32)>, None)

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance. As in
      Python, the axes can also be negative numbers. A negative axis is
      interpreted as counting from the end of the rank, i.e., axis +
      rank(values)-th dimension.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` if no shift is to be performed. A shift
      close to the true mean provides the most numerically stable results.
    keep_dims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.
    keepdims: Alias for keep_dims.

  Returns:
    Four `Tensor` objects of the same type as `x`:

    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is None.
  """
  axes = list(set(axes))
  keep_dims = deprecated_argument_lookup(
      "keepdims", keepdims, "keep_dims", keep_dims)
  if keep_dims is None:
    keep_dims = False
  with ops.name_scope(name, "sufficient_statistics", [x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    x_shape = x.get_shape()
    if x_shape.rank is not None and all(
        x_shape.dims[d].value is not None for d in axes):
      counts = 1
      for d in axes:
        counts *= x_shape.dims[d].value
      counts = constant_op.constant(counts, dtype=x.dtype)
    else:  # shape needs to be inferred at runtime.
      # Normalize axes to be positive. Required for gather.
      rank = array_ops.rank(x)
      positive_axes = [axis + rank if axis < 0 else axis for axis in axes]
      x_dims = array_ops.gather(
          math_ops.cast(array_ops.shape(x), x.dtype), positive_axes)
      counts = math_ops.reduce_prod(x_dims, name="count")
    if shift is not None:
      shift = ops.convert_to_tensor(shift, name="shift")
      m_ss = math_ops.subtract(x, shift)
      v_ss = math_ops.squared_difference(x, shift)
    else:  # no shift.
      m_ss = x
      v_ss = math_ops.square(x)
    m_ss = math_ops.reduce_sum(m_ss, axes, keepdims=keep_dims, name="mean_ss")
    v_ss = math_ops.reduce_sum(v_ss, axes, keepdims=keep_dims, name="var_ss")
  return counts, m_ss, v_ss, shift


@tf_export("nn.sufficient_statistics", v1=[])
@dispatch.add_dispatch_support
def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None):
  """Calculate the sufficient statistics for the mean and variance of `x`.

  These sufficient statistics are computed using the one pass algorithm on
  an input that's optionally shifted. See:
  https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data

  Args:
    x: A `Tensor`.
    axes: Array of ints. Axes along which to compute mean and variance.
    shift: A `Tensor` containing the value by which to shift the data for
      numerical stability, or `None` if no shift is to be performed. A shift
      close to the true mean provides the most numerically stable results.
    keepdims: produce statistics with the same dimensionality as the input.
    name: Name used to scope the operations that compute the sufficient stats.

  Returns:
    Four `Tensor` objects of the same type as `x`:

    * the count (number of elements to average over).
    * the (possibly shifted) sum of the elements in the array.
    * the (possibly shifted) sum of squares of the elements in the array.
    * the shift by which the mean must be corrected or None if `shift` is None.
  """
  return sufficient_statistics(
      x=x, axes=axes, shift=shift, keep_dims=keepdims, name=name)


@tf_export("nn.normalize_moments")
@dispatch.add_dispatch_support
def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
1289  """Calculate the mean and variance of based on the sufficient statistics.
1290
1291  Args:
1292    counts: A `Tensor` containing the total count of the data (one value).
1293    mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
1294      shifted) sum of the elements to average over.
1295    variance_ss: A `Tensor` containing the variance sufficient statistics: the
1296      (possibly shifted) sum of squares of the data to compute the variance over.
1297    shift: A `Tensor` containing the value by which the data is shifted for
1298      numerical stability, or `None` if no shift was performed.
1299    name: Name used to scope the operations that compute the moments.
1300
1301  Returns:
1302    Two `Tensor` objects: `mean` and `variance`.
1303  """
1304  with ops.name_scope(name, "normalize", [counts, mean_ss, variance_ss, shift]):
1305    divisor = math_ops.reciprocal(counts, name="divisor")
1306    if shift is not None:
1307      shifted_mean = math_ops.multiply(mean_ss, divisor, name="shifted_mean")
1308      mean = math_ops.add(shifted_mean, shift, name="mean")
1309    else:  # no shift.
1310      shifted_mean = math_ops.multiply(mean_ss, divisor, name="mean")
1311      mean = shifted_mean
1312    variance = math_ops.subtract(
1313        math_ops.multiply(variance_ss, divisor),
1314        math_ops.square(shifted_mean),
1315        name="variance")
1316  return (mean, variance)
1317
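# A minimal illustrative sketch (not part of the TensorFlow API) showing how
# `normalize_moments` above turns the output of `sufficient_statistics` back
# into moments; the values and the `_example_normalize_moments` name are
# hypothetical.
def _example_normalize_moments():
  """Recovers mean/variance from shifted sufficient statistics."""
  x = constant_op.constant([[1., 2., 3.], [4., 5., 6.]])
  counts, mean_ss, var_ss, shift = sufficient_statistics(x, axes=[0], shift=3.)
  mean, variance = normalize_moments(counts, mean_ss, var_ss, shift)
  # mean     == [2.5, 3.5, 4.5]   (the true per-column means)
  # variance == [2.25, 2.25, 2.25]
  return mean, variance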
1318
1319@tf_export(v1=["nn.moments"])
1320@dispatch.add_dispatch_support
1321def moments(
1322    x,
1323    axes,
1324    shift=None,  # pylint: disable=unused-argument
1325    name=None,
1326    keep_dims=None,
1327    keepdims=None):
1328  """Calculate the mean and variance of `x`.
1329
1330  The mean and variance are calculated by aggregating the contents of `x`
1331  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1332  and variance of a vector.
1333
1334  Note: shift is currently not used; the true mean is computed and used.
1335
1336  When using these moments for batch normalization (see
1337  `tf.nn.batch_normalization`):
1338
1339   * for so-called "global normalization", used with convolutional filters with
1340     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1341   * for simple batch normalization pass `axes=[0]` (batch only).
1342
1343  Args:
1344    x: A `Tensor`.
1345    axes: Array of ints.  Axes along which to compute mean and
1346      variance.
1347    shift: Not used in the current implementation.
1348    name: Name used to scope the operations that compute the moments.
1349    keep_dims: produce moments with the same dimensionality as the input.
1350    keepdims: Alias for keep_dims.
1351
1352  Returns:
1353    Two `Tensor` objects: `mean` and `variance`.
1354  """
1355  keep_dims = deprecated_argument_lookup(
1356      "keepdims", keepdims, "keep_dims", keep_dims)
1357  if keep_dims is None:
1358    keep_dims = False
1359  with ops.name_scope(name, "moments", [x, axes]):
1360    # The dynamic range of fp16 is too limited to support the collection of
1361    # sufficient statistics. As a workaround we simply perform the operations
1362    # on 32-bit floats before converting the mean and variance back to fp16
1363    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
1364    # Compute true mean while keeping the dims for proper broadcasting.
1365    mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
1366    # sample variance, not unbiased variance
1367    # Note: stop_gradient does not change the gradient that gets
1368    #       backpropagated to the mean from the variance calculation,
1369    #       because that gradient is zero
1370    variance = math_ops.reduce_mean(
1371        math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
1372        axes,
1373        keepdims=True,
1374        name="variance")
1375    if not keep_dims:
1376      mean = array_ops.squeeze(mean, axes)
1377      variance = array_ops.squeeze(variance, axes)
1378    if x.dtype == dtypes.float16:
1379      return (math_ops.cast(mean, dtypes.float16),
1380              math_ops.cast(variance, dtypes.float16))
1381    else:
1382      return (mean, variance)
1383
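# A minimal illustrative sketch (not part of the TensorFlow API) of the v1
# `moments` above; the tensor values and the `_example_moments` name are
# hypothetical.
def _example_moments():
  """Per-column mean/variance of a tiny matrix."""
  x = constant_op.constant([[1., 2.], [3., 4.]])
  mean, variance = moments(x, axes=[0])
  # mean     == [2., 3.] with shape [2]
  # variance == [1., 1.] (the biased/sample variance, not the unbiased one)
  # With keep_dims=True both results would instead have shape [1, 2].
  return mean, variance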
1384
1385@tf_export("nn.moments", v1=[])
1386@dispatch.add_dispatch_support
1387def moments_v2(
1388    x,
1389    axes,
1390    shift=None,
1391    keepdims=False,
1392    name=None):
1393  """Calculates the mean and variance of `x`.
1394
1395  The mean and variance are calculated by aggregating the contents of `x`
1396  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
1397  and variance of a vector.
1398
1399  Note: shift is currently not used; the true mean is computed and used.
1400
1401  When using these moments for batch normalization (see
1402  `tf.nn.batch_normalization`):
1403
1404   * for so-called "global normalization", used with convolutional filters with
1405     shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
1406   * for simple batch normalization pass `axes=[0]` (batch only).
1407
1408  Args:
1409    x: A `Tensor`.
1410    axes: Array of ints.  Axes along which to compute mean and
1411      variance.
1412    shift: Not used in the current implementation.
1413    keepdims: produce moments with the same dimensionality as the input.
1414    name: Name used to scope the operations that compute the moments.
1415
1416  Returns:
1417    Two `Tensor` objects: `mean` and `variance`.
1418  """
1419  return moments(x=x, axes=axes, shift=shift, name=name, keep_dims=keepdims)
1420
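# A minimal illustrative sketch (not part of the TensorFlow API) of the
# "global normalization" case described in the docstring above: per-channel
# moments of an NHWC activation. The shape [2, 4, 4, 3] and the
# `_example_global_moments` name are hypothetical.
def _example_global_moments():
  """Reduces over batch, height and width, leaving one moment per channel."""
  x = array_ops.ones([2, 4, 4, 3])  # [batch, height, width, depth]
  mean, variance = moments_v2(x, axes=[0, 1, 2])
  # mean and variance both have shape [3], one value per depth channel,
  # which is the form expected by `batch_normalization` below.
  return mean, variance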
1421
1422@tf_export(v1=["nn.weighted_moments"])
1423@dispatch.add_dispatch_support
1424def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
1425                     keepdims=None):
1426  """Returns the frequency-weighted mean and variance of `x`.
1427
1428  Args:
1429    x: A tensor.
1430    axes: 1-d tensor of int32 values; these are the axes along which
1431      to compute mean and variance.
1432    frequency_weights: A tensor of positive weights which can be
1433      broadcast with x.
1434    name: Name used to scope the operation.
1435    keep_dims: Produce moments with the same dimensionality as the input.
1436    keepdims: Alias of keep_dims.
1437
1438  Returns:
1439    Two tensors: `weighted_mean` and `weighted_variance`.
1440  """
1441  keep_dims = deprecated_argument_lookup(
1442      "keepdims", keepdims, "keep_dims", keep_dims)
1443  if keep_dims is None:
1444    keep_dims = False
1445  with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
1446    x = ops.convert_to_tensor(x, name="x")
1447    frequency_weights = ops.convert_to_tensor(
1448        frequency_weights, name="frequency_weights")
1449
1450    # Unlike moments(), this just uses a simpler two-pass method.
1451
1452    # See comment in moments() WRT precision; it applies here too.
1453    needs_cast = x.dtype == dtypes.float16
1454    if needs_cast:
1455      x = math_ops.cast(x, dtypes.float32)
1456
1457    if frequency_weights.dtype != x.dtype:
1458      frequency_weights = math_ops.cast(frequency_weights, x.dtype)
1459
1460    # Note that we use keep_dims=True for our reductions regardless of the arg;
1461    # this is so that the results remain broadcast-compatible with the inputs.
1462    weighted_input_sum = math_ops.reduce_sum(
1463        frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)
1464
1465    # The shape of the weights isn't necessarily the same as x's
1466    # shape, just broadcast-compatible with it -- so this expression
1467    # performs broadcasting to give a per-item weight, with the same
1468    # shape as (frequency_weights * x). This avoids having to reason
1469    # through all the broadcast logic to compute a correct
1470    # sum_of_weights.
1471    broadcasted_weights = frequency_weights + array_ops.zeros_like(x)
1472
1473    sum_of_weights = math_ops.reduce_sum(
1474        broadcasted_weights, axes, name="sum_of_weights", keepdims=True)
1475
1476    divisor = math_ops.reciprocal(sum_of_weights, name="inv_weight_sum")
1477
1478    weighted_mean = math_ops.multiply(weighted_input_sum, divisor)
1479
1480    # Have the weighted mean; now on to variance:
1481    weighted_distsq = math_ops.reduce_sum(
1482        frequency_weights * math_ops.squared_difference(x, weighted_mean),
1483        axes,
1484        name="weighted_distsq",
1485        keepdims=True)
1486
1487    weighted_variance = math_ops.multiply(weighted_distsq, divisor)
1488
1489    if not keep_dims:
1490      weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
1491      weighted_variance = array_ops.squeeze(
1492          weighted_variance, axis=axes)
1493
1494    if needs_cast:
1495      weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
1496      weighted_variance = math_ops.cast(weighted_variance, dtypes.float16)
1497
1498    return weighted_mean, weighted_variance
1499
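# A minimal illustrative sketch (not part of the TensorFlow API) of
# `weighted_moments` above, treating the frequency weights as repeat counts;
# the values and the `_example_weighted_moments` name are hypothetical.
def _example_weighted_moments():
  """Zero-weighted elements do not contribute to the moments."""
  x = constant_op.constant([1., 2., 3., 4.])
  w = constant_op.constant([1., 1., 0., 0.])
  mean, variance = weighted_moments(x, axes=[0], frequency_weights=w)
  # Only the first two elements contribute:
  # mean     == 1.5
  # variance == 0.25
  return mean, variance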
1500
1501@tf_export("nn.weighted_moments", v1=[])
1502@dispatch.add_dispatch_support
1503def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None):
1504  """Returns the frequency-weighted mean and variance of `x`.
1505
1506  Args:
1507    x: A tensor.
1508    axes: 1-d tensor of int32 values; these are the axes along which
1509      to compute mean and variance.
1510    frequency_weights: A tensor of positive weights which can be
1511      broadcast with x.
1512    keepdims: Produce moments with the same dimensionality as the input.
1513    name: Name used to scope the operation.
1514
1515  Returns:
1516    Two tensors: `weighted_mean` and `weighted_variance`.
1517  """
1518  return weighted_moments(
1519      x=x,
1520      axes=axes,
1521      frequency_weights=frequency_weights,
1522      name=name,
1523      keep_dims=keepdims)
1524
1525
1526@tf_export("nn.batch_normalization")
1527@dispatch.add_dispatch_support
1528def batch_normalization(x,
1529                        mean,
1530                        variance,
1531                        offset,
1532                        scale,
1533                        variance_epsilon,
1534                        name=None):
1535  r"""Batch normalization.
1536
1537  Normalizes a tensor by `mean` and `variance`, and applies (optionally) a
1538  `scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\):
1539
1540  \\(\frac{\gamma(x-\mu)}{\sigma}+\beta\\)
1541
1542  `mean`, `variance`, `offset` and `scale` are all expected to be of one of two
1543  shapes:
1544
1545    * In all generality, they can have the same number of dimensions as the
1546      input `x`, with identical sizes as `x` for the dimensions that are not
1547      normalized over (the 'depth' dimension(s)), and dimension 1 for the
1548      others which are being normalized over.
1549      `mean` and `variance` in this case would typically be the outputs of
1550      `tf.nn.moments(..., keepdims=True)` during training, or running averages
1551      thereof during inference.
1552    * In the common case where the 'depth' dimension is the last dimension in
1553      the input tensor `x`, they may be one dimensional tensors of the same
1554      size as the 'depth' dimension.
1555      This is the case for example for the common `[batch, depth]` layout of
1556      fully-connected layers, and `[batch, height, width, depth]` for
1557      convolutions.
1558      `mean` and `variance` in this case would typically be the outputs of
1559      `tf.nn.moments(..., keepdims=False)` during training, or running averages
1560      thereof during inference.
1561
1562  See equation 11 in Algorithm 2 of source:
1563  [Batch Normalization: Accelerating Deep Network Training by
1564  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1565  (http://arxiv.org/abs/1502.03167).
1566
1567  Args:
1568    x: Input `Tensor` of arbitrary dimensionality.
1569    mean: A mean `Tensor`.
1570    variance: A variance `Tensor`.
1571    offset: An offset `Tensor`, often denoted \\(\beta\\) in equations, or
1572      None. If present, will be added to the normalized tensor.
1573    scale: A scale `Tensor`, often denoted \\(\gamma\\) in equations, or
1574      `None`. If present, the scale is applied to the normalized tensor.
1575    variance_epsilon: A small float number to avoid dividing by 0.
1576    name: A name for this operation (optional).
1577
1578  Returns:
1579    the normalized, scaled, offset tensor.
1580
1581  References:
1582    Batch Normalization - Accelerating Deep Network Training by Reducing
1583    Internal Covariate Shift:
1584      [Ioffe et al., 2015](http://arxiv.org/abs/1502.03167)
1585      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1586  """
1587  with ops.name_scope(name, "batchnorm", [x, mean, variance, scale, offset]):
1588    inv = math_ops.rsqrt(variance + variance_epsilon)
1589    if scale is not None:
1590      inv *= scale
1591    # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on
1592    # the precise order of ops that are generated by the expression below.
1593    return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
1594        offset - mean * inv if offset is not None else -mean * inv, x.dtype)
1595
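# A minimal illustrative sketch (not part of the TensorFlow API) of the common
# `[batch, depth]` case described in the docstring above: compute batch
# statistics with `moments` and normalize with `batch_normalization`. The
# shapes, the epsilon value and the `_example_batch_normalization` name are
# hypothetical.
def _example_batch_normalization():
  """Normalizes a fully-connected activation over the batch dimension."""
  x = constant_op.constant([[1., 2.], [3., 4.], [5., 6.]])  # [batch, depth]
  gamma = array_ops.ones([2])   # scale, one per depth unit
  beta = array_ops.zeros([2])   # offset, one per depth unit
  mean, variance = moments(x, axes=[0])
  return batch_normalization(
      x, mean, variance, offset=beta, scale=gamma, variance_epsilon=1e-3)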
1596
1597@tf_export(v1=["nn.fused_batch_norm"])
1598@dispatch.add_dispatch_support
1599def fused_batch_norm(
1600    x,
1601    scale,
1602    offset,  # pylint: disable=invalid-name
1603    mean=None,
1604    variance=None,
1605    epsilon=0.001,
1606    data_format="NHWC",
1607    is_training=True,
1608    name=None,
1609    exponential_avg_factor=1.0):
1610  r"""Batch normalization.
1611
1612
1613  See Source: [Batch Normalization: Accelerating Deep Network Training by
1614  Reducing Internal Covariate Shift; S. Ioffe, C. Szegedy]
1615  (http://arxiv.org/abs/1502.03167).
1616
1617  Args:
1618    x: Input `Tensor` of 4 or 5 dimensions.
1619    scale: A `Tensor` of 1 dimension for scaling.
1620    offset: A `Tensor` of 1 dimension for bias.
1621    mean: A `Tensor` of 1 dimension for population mean. The shape and meaning
1622          of this argument depends on the value of is_training and
1623          exponential_avg_factor as follows:
1624          is_training==False (inference):
1625            Mean must be a `Tensor` of the same shape as scale containing the
1626            estimated population mean computed during training.
1627          is_training==True and exponential_avg_factor == 1.0:
1628            Mean must be None.
1629          is_training==True and exponential_avg_factor != 1.0:
1630            Mean must be a `Tensor` of the same shape as scale containing the
1631            exponential running mean.
1632    variance: A `Tensor` of 1 dimension for population variance. The shape and
1633          meaning of this argument depends on the value of is_training and
1634          exponential_avg_factor as follows:
1635          is_training==False (inference):
1636            Variance must be a `Tensor` of the same shape as scale containing
1637            the estimated population variance computed during training.
1638          is_training==True and exponential_avg_factor == 1.0:
1639            Variance must be None.
1640          is_training==True and exponential_avg_factor != 1.0:
1641            Variance must be a `Tensor` of the same shape as scale containing
1642            the exponential running variance.
1643    epsilon: A small float number added to the variance of x.
1644    data_format: The data format for x. Support "NHWC" (default) or "NCHW" for
1645                 4D tensors and "NDHWC" or "NCDHW" for 5D tensors.
1646    is_training: A bool value to specify if the operation is used for
1647                 training or inference.
1648    name: A name for this operation (optional).
1649    exponential_avg_factor: A float number (usually between 0 and 1) used
1650                            for controlling the decay of the running
1651                            population average of mean and variance.
1652                            If set to 1.0, the current batch average is
1653                            returned.
1654
1655  Returns:
1656    y: A 4D or 5D Tensor for the normalized, scaled, offset x.
1657    running_mean: A 1D Tensor for the exponential running mean of x.
1658                  The output value is (1 - exponential_avg_factor) * mean +
1659                  exponential_avg_factor * batch_mean, where batch_mean
1660                  is the mean of the current batch in x.
1661    running_var: A 1D Tensor for the exponential running variance of x.
1662                 The output value is (1 - exponential_avg_factor) * variance +
1663                 exponential_avg_factor * batch_variance, where batch_variance
1664                 is the variance of the current batch in x.
1665
1666  References:
1667    Batch Normalization - Accelerating Deep Network Training by Reducing
1668    Internal Covariate Shift:
1669      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1670      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1671  """
1672  if (not is_training or exponential_avg_factor != 1.0) and (
1673      (mean is None) or (variance is None)):
1674    raise ValueError("Both `mean` and `variance` must be a 1D tensor when "
1675                     "`is_training` is False or `exponential_avg_factor` != "
1676                     f"1.0. Received: `mean` {mean!r} and `variance` "
1677                     f"{variance!r}")
1678  x = ops.convert_to_tensor(x, name="input")
1679  scale = ops.convert_to_tensor(scale, name="scale")
1680  offset = ops.convert_to_tensor(offset, name="offset")
1681  if mean is None:
1682    mean = constant_op.constant([])
1683  if variance is None:
1684    variance = constant_op.constant([])
1685
1686  # Set a minimum epsilon to 1.001e-5, which is a requirement by CUDNN to
1687  # prevent exception (see cudnn.h).
1688  min_epsilon = 1.001e-5
1689  epsilon = epsilon if epsilon > min_epsilon else min_epsilon
1690
1691  y, running_mean, running_var, _, _, _ = gen_nn_ops.fused_batch_norm_v3(
1692      x,
1693      scale,
1694      offset,
1695      mean,
1696      variance,
1697      epsilon=epsilon,
1698      exponential_avg_factor=exponential_avg_factor,
1699      data_format=data_format,
1700      is_training=is_training,
1701      name=name)
1702  return y, running_mean, running_var
1703
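# A minimal illustrative sketch (not part of the TensorFlow API) of the two
# simplest modes of `fused_batch_norm` above; the NHWC shape [2, 4, 4, 3] and
# the `_example_fused_batch_norm` name are hypothetical.
def _example_fused_batch_norm():
  """Training with batch statistics, then inference with population stats."""
  x = array_ops.ones([2, 4, 4, 3])  # NHWC
  scale = array_ops.ones([3])
  offset = array_ops.zeros([3])
  # is_training=True with the default exponential_avg_factor=1.0: mean and
  # variance must be None, and the current batch statistics are returned.
  y, batch_mean, batch_var = fused_batch_norm(
      x, scale, offset, is_training=True)
  # is_training=False: population statistics (here simply the batch statistics
  # from above) must be supplied.
  y_infer, _, _ = fused_batch_norm(
      x, scale, offset, mean=batch_mean, variance=batch_var,
      is_training=False)
  return y, y_infer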
1704
1705@tf_export(v1=["nn.batch_norm_with_global_normalization"])
1706@dispatch.add_dispatch_support
1707def batch_norm_with_global_normalization(t=None,
1708                                         m=None,
1709                                         v=None,
1710                                         beta=None,
1711                                         gamma=None,
1712                                         variance_epsilon=None,
1713                                         scale_after_normalization=None,
1714                                         name=None,
1715                                         input=None,  # pylint: disable=redefined-builtin
1716                                         mean=None,
1717                                         variance=None):
1718  """Batch normalization.
1719
1720  This op is deprecated. See `tf.nn.batch_normalization`.
1721
1722  Args:
1723    t: A 4D input Tensor.
1724    m: A 1D mean Tensor with size matching the last dimension of t.
1725      This is the first output from tf.nn.moments,
1726      or a saved moving average thereof.
1727    v: A 1D variance Tensor with size matching the last dimension of t.
1728      This is the second output from tf.nn.moments,
1729      or a saved moving average thereof.
1730    beta: A 1D beta Tensor with size matching the last dimension of t.
1731      An offset to be added to the normalized tensor.
1732    gamma: A 1D gamma Tensor with size matching the last dimension of t.
1733      If "scale_after_normalization" is true, this tensor will be multiplied
1734      with the normalized tensor.
1735    variance_epsilon: A small float number to avoid dividing by 0.
1736    scale_after_normalization: A bool indicating whether the resulting tensor
1737      needs to be multiplied with gamma.
1738    name: A name for this operation (optional).
1739    input: Alias for t.
1740    mean: Alias for m.
1741    variance: Alias for v.
1742
1743  Returns:
1744     A batch-normalized `t`.
1745
1746  References:
1747    Batch Normalization - Accelerating Deep Network Training by Reducing
1748    Internal Covariate Shift:
1749      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1750      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1751  """
1752  t = deprecated_argument_lookup("input", input, "t", t)
1753  m = deprecated_argument_lookup("mean", mean, "m", m)
1754  v = deprecated_argument_lookup("variance", variance, "v", v)
1755  return batch_normalization(t, m, v, beta, gamma if scale_after_normalization
1756                             else None, variance_epsilon, name)
1757
1758
1759# pylint: disable=redefined-builtin,line-too-long
1760@tf_export("nn.batch_norm_with_global_normalization", v1=[])
1761@dispatch.add_dispatch_support
1762def batch_norm_with_global_normalization_v2(input,
1763                                            mean,
1764                                            variance,
1765                                            beta,
1766                                            gamma,
1767                                            variance_epsilon,
1768                                            scale_after_normalization,
1769                                            name=None):
1770  """Batch normalization.
1771
1772  This op is deprecated. See `tf.nn.batch_normalization`.
1773
1774  Args:
1775    input: A 4D input Tensor.
1776    mean: A 1D mean Tensor with size matching the last dimension of t.
1777      This is the first output from tf.nn.moments,
1778      or a saved moving average thereof.
1779    variance: A 1D variance Tensor with size matching the last dimension of t.
1780      This is the second output from tf.nn.moments,
1781      or a saved moving average thereof.
1782    beta: A 1D beta Tensor with size matching the last dimension of t.
1783      An offset to be added to the normalized tensor.
1784    gamma: A 1D gamma Tensor with size matching the last dimension of t.
1785      If "scale_after_normalization" is true, this tensor will be multiplied
1786      with the normalized tensor.
1787    variance_epsilon: A small float number to avoid dividing by 0.
1788    scale_after_normalization: A bool indicating whether the resulting tensor
1789      needs to be multiplied with gamma.
1790    name: A name for this operation (optional).
1791
1792  Returns:
1793     A batch-normalized version of `input`.
1794
1795  References:
1796    Batch Normalization - Accelerating Deep Network Training by Reducing Internal Covariate Shift:
1797      [Ioffe et al., 2015](http://proceedings.mlr.press/v37/ioffe15.html)
1798      ([pdf](http://proceedings.mlr.press/v37/ioffe15.pdf))
1799  """
1800  return batch_norm_with_global_normalization(t=input,
1801                                              m=mean,
1802                                              v=variance,
1803                                              beta=beta,
1804                                              gamma=gamma,
1805                                              variance_epsilon=variance_epsilon,
1806                                              scale_after_normalization=scale_after_normalization,
1807                                              name=name)
1808
1809# pylint: enable=redefined-builtin,line-too-long
1810
1811
1812def _sum_rows(x):
1813  """Returns a vector summing up each row of the matrix x."""
1814  # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
1815  # a matrix.  The gradient of _sum_rows(x) is more efficient than
1816  # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
1817  # we use _sum_rows(x) in the nce_loss() computation since the loss
1818  # is mostly used for training.
1819  cols = array_ops.shape(x)[1]
1820  ones_shape = array_ops.stack([cols, 1])
1821  ones = array_ops.ones(ones_shape, x.dtype)
1822  return array_ops.reshape(math_ops.matmul(x, ones), [-1])
1823
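# A minimal illustrative sketch (not part of the TensorFlow API) of the
# `_sum_rows` helper above; the values and the `_example_sum_rows` name are
# hypothetical.
def _example_sum_rows():
  """`_sum_rows` matches `reduce_sum(x, 1)` but uses a matmul with ones."""
  x = constant_op.constant([[1., 2., 3.], [4., 5., 6.]])
  return _sum_rows(x)  # == [6., 15.]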
1824
1825def _compute_sampled_logits(weights,
1826                            biases,
1827                            labels,
1828                            inputs,
1829                            num_sampled,
1830                            num_classes,
1831                            num_true=1,
1832                            sampled_values=None,
1833                            subtract_log_q=True,
1834                            remove_accidental_hits=False,
1835                            partition_strategy="mod",
1836                            name=None,
1837                            seed=None):
1838  """Helper function for nce_loss and sampled_softmax_loss functions.
1839
1840  Computes sampled output training logits and labels suitable for implementing
1841  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
1842  sampled_softmax_loss).
1843
1844  Note: In the case where num_true > 1, we assign to each target class
1845  the target probability 1 / num_true so that the target probabilities
1846  sum to 1 per-example.
1847
1848  Args:
1849    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
1850        objects whose concatenation along dimension 0 has shape
1851        `[num_classes, dim]`.  The (possibly-partitioned) class embeddings.
1852    biases: A `Tensor` of shape `[num_classes]`.  The (possibly-partitioned)
1853        class biases.
1854    labels: A `Tensor` of type `int64` and shape `[batch_size,
1855        num_true]`. The target classes.  Note that this format differs from
1856        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
1857    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
1858        activations of the input network.
1859    num_sampled: An `int`.  The number of classes to randomly sample per batch.
1860    num_classes: An `int`. The number of possible classes.
1861    num_true: An `int`.  The number of target classes per training example.
1862    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
1863        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
1864        (if None, we default to `log_uniform_candidate_sampler`)
1865    subtract_log_q: A `bool`.  Whether to subtract the log expected count of
1866        the labels in the sample to get the logits of the true labels.
1867        Default is True.  Turn off for Negative Sampling.
1868    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
1869        where a sampled class equals one of the target classes.  Default is
1870        False.
1871    partition_strategy: A string specifying the partitioning strategy, relevant
1872        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
1873        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
1874    name: A name for the operation (optional).
1875    seed: random seed for candidate sampling. Defaults to None, which doesn't set
1876        the op-level random seed for candidate sampling.
1877  Returns:
1878    out_logits: `Tensor` object with shape
1879        `[batch_size, num_true + num_sampled]`, for passing to either
1880        `nn.sigmoid_cross_entropy_with_logits` (NCE) or
1881        `nn.softmax_cross_entropy_with_logits` (sampled softmax).
1882    out_labels: A Tensor object with the same shape as `out_logits`.
1883  """
1884
1885  if isinstance(weights, variables.PartitionedVariable):
1886    weights = list(weights)
1887  if not isinstance(weights, list):
1888    weights = [weights]
1889
1890  with ops.name_scope(name, "compute_sampled_logits",
1891                      weights + [biases, inputs, labels]):
1892    if labels.dtype != dtypes.int64:
1893      labels = math_ops.cast(labels, dtypes.int64)
1894    labels_flat = array_ops.reshape(labels, [-1])
1895
1896    # Sample the negative labels.
1897    #   sampled shape: [num_sampled] tensor
1898    #   true_expected_count shape = [batch_size, 1] tensor
1899    #   sampled_expected_count shape = [num_sampled] tensor
1900    if sampled_values is None:
1901      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
1902          true_classes=labels,
1903          num_true=num_true,
1904          num_sampled=num_sampled,
1905          unique=True,
1906          range_max=num_classes,
1907          seed=seed)
1908    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
1909    # pylint: disable=unpacking-non-sequence
1910    sampled, true_expected_count, sampled_expected_count = (
1911        array_ops.stop_gradient(s) for s in sampled_values)
1912    # pylint: enable=unpacking-non-sequence
1913    sampled = math_ops.cast(sampled, dtypes.int64)
1914
1915    # labels_flat is a [batch_size * num_true] tensor
1916    # sampled is a [num_sampled] int tensor
1917    all_ids = array_ops.concat([labels_flat, sampled], 0)
1918
1919    # Retrieve the true weights and the logits of the sampled weights.
1920
1921    # weights shape is [num_classes, dim]
1922    all_w = embedding_ops.embedding_lookup(
1923        weights, all_ids, partition_strategy=partition_strategy)
1924    if all_w.dtype != inputs.dtype:
1925      all_w = math_ops.cast(all_w, inputs.dtype)
1926
1927    # true_w shape is [batch_size * num_true, dim]
1928    true_w = array_ops.slice(all_w, [0, 0],
1929                             array_ops.stack(
1930                                 [array_ops.shape(labels_flat)[0], -1]))
1931
1932    sampled_w = array_ops.slice(
1933        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
1934    # inputs has shape [batch_size, dim]
1935    # sampled_w has shape [num_sampled, dim]
1936    # Apply X*W', which yields [batch_size, num_sampled]
1937    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
1938
1939    # Retrieve the true and sampled biases, compute the true logits, and
1940    # add the biases to the true and sampled logits.
1941    all_b = embedding_ops.embedding_lookup(
1942        biases, all_ids, partition_strategy=partition_strategy)
1943    if all_b.dtype != inputs.dtype:
1944      all_b = math_ops.cast(all_b, inputs.dtype)
1945    # true_b is a [batch_size * num_true] tensor
1946    # sampled_b is a [num_sampled] float tensor
1947    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
1948    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
1949
1950    # inputs shape is [batch_size, dim]
1951    # true_w shape is [batch_size * num_true, dim]
1952    # row_wise_dots is [batch_size, num_true, dim]
1953    dim = array_ops.shape(true_w)[1:2]
1954    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
1955    row_wise_dots = math_ops.multiply(
1956        array_ops.expand_dims(inputs, 1),
1957        array_ops.reshape(true_w, new_true_w_shape))
1958    # We want the row-wise dot plus biases which yields a
1959    # [batch_size, num_true] tensor of true_logits.
1960    dots_as_matrix = array_ops.reshape(row_wise_dots,
1961                                       array_ops.concat([[-1], dim], 0))
1962    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
1963    true_b = array_ops.reshape(true_b, [-1, num_true])
1964    true_logits += true_b
1965    sampled_logits += sampled_b
1966
1967    if remove_accidental_hits:
1968      acc_hits = candidate_sampling_ops.compute_accidental_hits(
1969          labels, sampled, num_true=num_true)
1970      acc_indices, acc_ids, acc_weights = acc_hits
1971
1972      # This is how SparseToDense expects the indices.
1973      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
1974      acc_ids_2d_int32 = array_ops.reshape(
1975          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
1976      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,
1977                                        "sparse_indices")
1978      # Create sampled_logits_shape = [batch_size, num_sampled]
1979      sampled_logits_shape = array_ops.concat(
1980          [array_ops.shape(labels)[:1],
1981           array_ops.expand_dims(num_sampled, 0)], 0)
1982      if sampled_logits.dtype != acc_weights.dtype:
1983        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
1984      sampled_logits += gen_sparse_ops.sparse_to_dense(
1985          sparse_indices,
1986          sampled_logits_shape,
1987          acc_weights,
1988          default_value=0.0,
1989          validate_indices=False)
1990
1991    if subtract_log_q:
1992      # Subtract log of Q(l), prior probability that l appears in sampled.
1993      true_logits -= math_ops.log(true_expected_count)
1994      sampled_logits -= math_ops.log(sampled_expected_count)
1995
1996    # Construct output logits and labels. The true labels/logits start at col 0.
1997    out_logits = array_ops.concat([true_logits, sampled_logits], 1)
1998
1999    # true_logits is a float tensor, ones_like(true_logits) is a float
2000    # tensor of ones. We then divide by num_true to ensure the per-example
2001    # labels sum to 1.0, i.e. form a proper probability distribution.
2002    out_labels = array_ops.concat([
2003        array_ops.ones_like(true_logits) / num_true,
2004        array_ops.zeros_like(sampled_logits)
2005    ], 1)
2006
2007    return out_logits, out_labels
2008
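# A minimal illustrative sketch (not part of the TensorFlow API) of the shapes
# flowing through `_compute_sampled_logits` above; every size and the
# `_example_compute_sampled_logits` name are hypothetical.
def _example_compute_sampled_logits():
  """Produces logits/labels of shape [batch_size, num_true + num_sampled]."""
  num_classes, dim, batch_size, num_sampled = 10, 4, 3, 5
  weights = array_ops.ones([num_classes, dim])   # class embeddings
  biases = array_ops.zeros([num_classes])        # class biases
  labels = constant_op.constant([[1], [2], [3]], dtype=dtypes.int64)
  inputs = array_ops.ones([batch_size, dim])     # forward activations
  logits, labels_out = _compute_sampled_logits(
      weights, biases, labels, inputs,
      num_sampled=num_sampled, num_classes=num_classes)
  # Both outputs have shape [3, 1 + 5]; column 0 holds the true class and the
  # remaining columns hold the sampled negatives.
  return logits, labels_out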
2009
2010@tf_export("nn.nce_loss", v1=[])
2011@dispatch.add_dispatch_support
2012def nce_loss_v2(weights,
2013                biases,
2014                labels,
2015                inputs,
2016                num_sampled,
2017                num_classes,
2018                num_true=1,
2019                sampled_values=None,
2020                remove_accidental_hits=False,
2021                name="nce_loss"):
2022  """Computes and returns the noise-contrastive estimation training loss.
2023
2024  See [Noise-contrastive estimation: A new estimation principle for
2025  unnormalized statistical
2026  models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
2027  Also see our [Candidate Sampling Algorithms
2028  Reference](https://www.tensorflow.org/extras/candidate_sampling.pdf)
2029
2030  A common use case is to use this method for training, and calculate the full
2031  sigmoid loss for evaluation or inference as in the following example:
2032
2033  ```python
2034  if mode == "train":
2035    loss = tf.nn.nce_loss(
2036        weights=weights,
2037        biases=biases,
2038        labels=labels,
2039        inputs=inputs,
2040        ...)
2041  elif mode == "eval":
2042    logits = tf.matmul(inputs, tf.transpose(weights))
2043    logits = tf.nn.bias_add(logits, biases)
2044    labels_one_hot = tf.one_hot(labels, n_classes)
2045    loss = tf.nn.sigmoid_cross_entropy_with_logits(
2046        labels=labels_one_hot,
2047        logits=logits)
2048    loss = tf.reduce_sum(loss, axis=1)
2049  ```
2050
2051  Note: when doing embedding lookup on `weights` and `biases`, "div" partition
2052  strategy will be used. Support for other partition strategies will be added
2053  later.
2054
2055  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2056  so your labels must be sorted in order of decreasing frequency to achieve
2057  good results.  For more details, see
2058  `tf.random.log_uniform_candidate_sampler`.
2059
2060  Note: In the case where `num_true` > 1, we assign to each target class
2061  the target probability 1 / `num_true` so that the target probabilities
2062  sum to 1 per-example.
2063
2064  Note: It would be useful to allow a variable number of target classes per
2065  example.  We hope to provide this functionality in a future release.
2066  For now, if you have a variable number of target classes, you can pad them
2067  out to a constant number by either repeating them or by padding
2068  with an otherwise unused class.
2069
2070  Args:
2071    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2072      objects whose concatenation along dimension 0 has shape [num_classes,
2073      dim].  The (possibly-partitioned) class embeddings.
2074    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2075    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2076      target classes.
2077    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
2078      the input network.
2079    num_sampled: An `int`.  The number of negative classes to randomly sample
2080      per batch. This single sample of negative classes is evaluated for each
2081      element in the batch.
2082    num_classes: An `int`. The number of possible classes.
2083    num_true: An `int`.  The number of target classes per training example.
2084    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2085      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2086      (if None, we default to `log_uniform_candidate_sampler`)
2087    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2088      where a sampled class equals one of the target classes.  If set to `True`,
2089      this is a "Sampled Logistic" loss instead of NCE, and we are learning to
2090      generate log-odds instead of log probabilities.  See our [Candidate
2091      Sampling Algorithms Reference]
2092        (https://www.tensorflow.org/extras/candidate_sampling.pdf). Default is
2093          False.
2094    name: A name for the operation (optional).
2095
2096  Returns:
2097    A `batch_size` 1-D tensor of per-example NCE losses.
2098  """
2099  # TODO(yuefengz): get partition_strategy from either variables or distribution
2100  # strategies.
2101  return nce_loss(
2102      weights,
2103      biases,
2104      labels,
2105      inputs,
2106      num_sampled,
2107      num_classes,
2108      num_true=num_true,
2109      sampled_values=sampled_values,
2110      remove_accidental_hits=remove_accidental_hits,
2111      partition_strategy="div",
2112      name=name)
2113
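# A minimal illustrative sketch (not part of the TensorFlow API) of a complete
# call to `nce_loss_v2` above with concrete, hypothetical sizes; the
# `_example_nce_loss` name is also hypothetical.
def _example_nce_loss():
  """Per-example NCE losses for a toy vocabulary."""
  num_classes, dim, batch_size = 100, 8, 4
  weights = array_ops.ones([num_classes, dim])   # class embeddings
  biases = array_ops.zeros([num_classes])        # class biases
  labels = constant_op.constant([[0], [1], [2], [3]], dtype=dtypes.int64)
  inputs = array_ops.ones([batch_size, dim])     # forward activations
  loss = nce_loss_v2(
      weights, biases, labels, inputs, num_sampled=10, num_classes=num_classes)
  # `loss` has shape [4]: one NCE loss per example in the batch.
  return loss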
2114
2115@tf_export(v1=["nn.nce_loss"])
2116@dispatch.add_dispatch_support
2117def nce_loss(weights,
2118             biases,
2119             labels,
2120             inputs,
2121             num_sampled,
2122             num_classes,
2123             num_true=1,
2124             sampled_values=None,
2125             remove_accidental_hits=False,
2126             partition_strategy="mod",
2127             name="nce_loss"):
2128  """Computes and returns the noise-contrastive estimation training loss.
2129
2130  A common use case is to use this method for training, and calculate the full
2131  sigmoid loss for evaluation or inference. In this case, you must set
2132  `partition_strategy="div"` for the two losses to be consistent, as in the
2133  following example:
2134
2135  ```python
2136  if mode == "train":
2137    loss = tf.nn.nce_loss(
2138        weights=weights,
2139        biases=biases,
2140        labels=labels,
2141        inputs=inputs,
2142        ...,
2143        partition_strategy="div")
2144  elif mode == "eval":
2145    logits = tf.matmul(inputs, tf.transpose(weights))
2146    logits = tf.nn.bias_add(logits, biases)
2147    labels_one_hot = tf.one_hot(labels, n_classes)
2148    loss = tf.nn.sigmoid_cross_entropy_with_logits(
2149        labels=labels_one_hot,
2150        logits=logits)
2151    loss = tf.reduce_sum(loss, axis=1)
2152  ```
2153
2154  Note: By default this uses a log-uniform (Zipfian) distribution for sampling,
2155  so your labels must be sorted in order of decreasing frequency to achieve
2156  good results.  For more details, see
2157  `tf.random.log_uniform_candidate_sampler`.
2158
2159  Note: In the case where `num_true` > 1, we assign to each target class
2160  the target probability 1 / `num_true` so that the target probabilities
2161  sum to 1 per-example.
2162
2163  Note: It would be useful to allow a variable number of target classes per
2164  example.  We hope to provide this functionality in a future release.
2165  For now, if you have a variable number of target classes, you can pad them
2166  out to a constant number by either repeating them or by padding
2167  with an otherwise unused class.
2168
2169  Args:
2170    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2171        objects whose concatenation along dimension 0 has shape
2172        [num_classes, dim].  The (possibly-partitioned) class embeddings.
2173    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2174    labels: A `Tensor` of type `int64` and shape `[batch_size,
2175        num_true]`. The target classes.
2176    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
2177        activations of the input network.
2178    num_sampled: An `int`.  The number of negative classes to randomly sample
2179        per batch. This single sample of negative classes is evaluated for each
2180        element in the batch.
2181    num_classes: An `int`. The number of possible classes.
2182    num_true: An `int`.  The number of target classes per training example.
2183    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2184        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2185        (if None, we default to `log_uniform_candidate_sampler`)
2186    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2187        where a sampled class equals one of the target classes.  If set to
2188        `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
2189        learning to generate log-odds instead of log probabilities. See
2190        our Candidate Sampling Algorithms Reference
2191        ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2192        Default is False.
2193    partition_strategy: A string specifying the partitioning strategy, relevant
2194        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2195        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2196    name: A name for the operation (optional).
2197
2198  Returns:
2199    A `batch_size` 1-D tensor of per-example NCE losses.
2200
2201  References:
2202    Noise-contrastive estimation - A new estimation principle for unnormalized
2203    statistical models:
2204      [Gutmann et al., 2010](http://proceedings.mlr.press/v9/gutmann10a)
2205      ([pdf](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf))
2206  """
2207  logits, labels = _compute_sampled_logits(
2208      weights=weights,
2209      biases=biases,
2210      labels=labels,
2211      inputs=inputs,
2212      num_sampled=num_sampled,
2213      num_classes=num_classes,
2214      num_true=num_true,
2215      sampled_values=sampled_values,
2216      subtract_log_q=True,
2217      remove_accidental_hits=remove_accidental_hits,
2218      partition_strategy=partition_strategy,
2219      name=name)
2220  sampled_losses = sigmoid_cross_entropy_with_logits(
2221      labels=labels, logits=logits, name="sampled_losses")
2222  # sampled_losses is batch_size x {true_loss, sampled_losses...}
2223  # We sum out true and sampled losses.
2224  return _sum_rows(sampled_losses)
2225
2226
2227@tf_export("nn.sampled_softmax_loss", v1=[])
2228@dispatch.add_dispatch_support
2229def sampled_softmax_loss_v2(weights,
2230                            biases,
2231                            labels,
2232                            inputs,
2233                            num_sampled,
2234                            num_classes,
2235                            num_true=1,
2236                            sampled_values=None,
2237                            remove_accidental_hits=True,
2238                            seed=None,
2239                            name="sampled_softmax_loss"):
2240  """Computes and returns the sampled softmax training loss.
2241
2242  This is a faster way to train a softmax classifier over a huge number of
2243  classes.
2244
2245  This operation is for training only.  It is generally an underestimate of
2246  the full softmax loss.
2247
2248  A common use case is to use this method for training, and calculate the full
2249  softmax loss for evaluation or inference as in the following example:
2250
2251  ```python
2252  if mode == "train":
2253    loss = tf.nn.sampled_softmax_loss(
2254        weights=weights,
2255        biases=biases,
2256        labels=labels,
2257        inputs=inputs,
2258        ...)
2259  elif mode == "eval":
2260    logits = tf.matmul(inputs, tf.transpose(weights))
2261    logits = tf.nn.bias_add(logits, biases)
2262    labels_one_hot = tf.one_hot(labels, n_classes)
2263    loss = tf.nn.softmax_cross_entropy_with_logits(
2264        labels=labels_one_hot,
2265        logits=logits)
2266  ```
2267
2268  See our [Candidate Sampling Algorithms Reference]
2269  (https://www.tensorflow.org/extras/candidate_sampling.pdf)
2270
2271  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
2272  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
2273
2274  Note: when doing embedding lookup on `weights` and `biases`, "div" partition
2275  strategy will be used. Support for other partition strategies will be added
2276  later.
2277
2278  Args:
2279    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2280      objects whose concatenation along dimension 0 has shape [num_classes,
2281      dim].  The (possibly-sharded) class embeddings.
2282    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2283    labels: A `Tensor` of type `int64` and shape `[batch_size, num_true]`. The
2284      target classes.  Note that this format differs from the `labels` argument
2285      of `nn.softmax_cross_entropy_with_logits`.
2286    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward activations of
2287      the input network.
2288    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2289    num_classes: An `int`. The number of possible classes.
2290    num_true: An `int`.  The number of target classes per training example.
2291    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2292      `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2293      (if None, we default to `log_uniform_candidate_sampler`)
2294    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2295      where a sampled class equals one of the target classes.  Default is True.
2296    seed: random seed for candidate sampling. Defaults to None, which doesn't set
2297      the op-level random seed for candidate sampling.
2298    name: A name for the operation (optional).
2299
2300  Returns:
2301    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2302
2303  """
2304  return sampled_softmax_loss(
2305      weights,
2306      biases,
2307      labels,
2308      inputs,
2309      num_sampled,
2310      num_classes,
2311      num_true=num_true,
2312      sampled_values=sampled_values,
2313      remove_accidental_hits=remove_accidental_hits,
2314      partition_strategy="div",
2315      name=name,
2316      seed=seed)
2317
2318
2319@tf_export(v1=["nn.sampled_softmax_loss"])
2320@dispatch.add_dispatch_support
2321def sampled_softmax_loss(weights,
2322                         biases,
2323                         labels,
2324                         inputs,
2325                         num_sampled,
2326                         num_classes,
2327                         num_true=1,
2328                         sampled_values=None,
2329                         remove_accidental_hits=True,
2330                         partition_strategy="mod",
2331                         name="sampled_softmax_loss",
2332                         seed=None):
2333  """Computes and returns the sampled softmax training loss.
2334
2335  This is a faster way to train a softmax classifier over a huge number of
2336  classes.
2337
2338  This operation is for training only.  It is generally an underestimate of
2339  the full softmax loss.
2340
2341  A common use case is to use this method for training, and calculate the full
2342  softmax loss for evaluation or inference. In this case, you must set
2343  `partition_strategy="div"` for the two losses to be consistent, as in the
2344  following example:
2345
2346  ```python
2347  if mode == "train":
2348    loss = tf.nn.sampled_softmax_loss(
2349        weights=weights,
2350        biases=biases,
2351        labels=labels,
2352        inputs=inputs,
2353        ...,
2354        partition_strategy="div")
2355  elif mode == "eval":
2356    logits = tf.matmul(inputs, tf.transpose(weights))
2357    logits = tf.nn.bias_add(logits, biases)
2358    labels_one_hot = tf.one_hot(labels, n_classes)
2359    loss = tf.nn.softmax_cross_entropy_with_logits(
2360        labels=labels_one_hot,
2361        logits=logits)
2362  ```
2363
2364  See our Candidate Sampling Algorithms Reference
2365  ([pdf](https://www.tensorflow.org/extras/candidate_sampling.pdf)).
2366  Also see Section 3 of (Jean et al., 2014) for the math.
2367
2368  Args:
2369    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
2370        objects whose concatenation along dimension 0 has shape
2371        [num_classes, dim].  The (possibly-sharded) class embeddings.
2372    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
2373    labels: A `Tensor` of type `int64` and shape `[batch_size,
2374        num_true]`. The target classes.  Note that this format differs from
2375        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
2376    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
2377        activations of the input network.
2378    num_sampled: An `int`.  The number of classes to randomly sample per batch.
2379    num_classes: An `int`. The number of possible classes.
2380    num_true: An `int`.  The number of target classes per training example.
2381    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
2382        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
2383        (if None, we default to `log_uniform_candidate_sampler`)
2384    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
2385        where a sampled class equals one of the target classes.  Default is
2386        True.
2387    partition_strategy: A string specifying the partitioning strategy, relevant
2388        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
2389        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
2390    name: A name for the operation (optional).
2391    seed: random seed for candidate sampling. Defaults to None, which doesn't set
2392        the op-level random seed for candidate sampling.
2393
2394  Returns:
2395    A `batch_size` 1-D tensor of per-example sampled softmax losses.
2396
2397  References:
2398    On Using Very Large Target Vocabulary for Neural Machine Translation:
2399      [Jean et al., 2014]
2400      (https://aclanthology.coli.uni-saarland.de/papers/P15-1001/p15-1001)
2401      ([pdf](http://aclweb.org/anthology/P15-1001))
2402  """
2403  logits, labels = _compute_sampled_logits(
2404      weights=weights,
2405      biases=biases,
2406      labels=labels,
2407      inputs=inputs,
2408      num_sampled=num_sampled,
2409      num_classes=num_classes,
2410      num_true=num_true,
2411      sampled_values=sampled_values,
2412      subtract_log_q=True,
2413      remove_accidental_hits=remove_accidental_hits,
2414      partition_strategy=partition_strategy,
2415      name=name,
2416      seed=seed)
2417  labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
2418  sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
2419      labels=labels, logits=logits)
2420  # sampled_losses is a [batch_size] tensor.
2421  return sampled_losses
2422
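# A minimal illustrative sketch (not part of the TensorFlow API) of the v1
# `sampled_softmax_loss` above with sharded class embeddings, matching the
# docstring's advice to use partition_strategy="div" when the same weights are
# reused for a full softmax at eval time; all sizes and the
# `_example_sampled_softmax_loss` name are hypothetical.
def _example_sampled_softmax_loss():
  """Per-example sampled softmax losses with a two-shard weight list."""
  num_classes, dim, batch_size = 100, 8, 4
  # Two shards whose concatenation along dimension 0 is [num_classes, dim].
  weights = [array_ops.ones([50, dim]), array_ops.ones([50, dim])]
  biases = array_ops.zeros([num_classes])
  labels = constant_op.constant([[0], [1], [2], [3]], dtype=dtypes.int64)
  inputs = array_ops.ones([batch_size, dim])
  loss = sampled_softmax_loss(
      weights, biases, labels, inputs, num_sampled=10,
      num_classes=num_classes, partition_strategy="div")
  # `loss` has shape [4]: one sampled softmax loss per example in the batch.
  return loss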