1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for probability distributions."""
16
17import functools
18import hashlib
19
20import numpy as np
21
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_shape
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import check_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import linalg_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import nn
33from tensorflow.python.util import tf_inspect
34
35
36def assert_integer_form(x,
37                        data=None,
38                        summarize=None,
39                        message=None,
40                        int_dtype=None,
41                        name="assert_integer_form"):
42  """Assert that x has integer components (or floats equal to integers).
43
44  Args:
45    x: Floating-point `Tensor`
46    data: The tensors to print out if the condition is `False`. Defaults to
47      the error message and first few entries of `x`.
48    summarize: Print this many entries of each tensor.
49    message: A string to prefix to the default message.
50    int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
51      implies the smallest possible signed int will be used for casting.
52    name: A name for this operation (optional).
53
54  Returns:
55    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
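
  Example (a small sketch, assuming eager execution and `tf` as TensorFlow):

  ```python
  x = tf.constant([1., 2., 3.])
  assert_op = assert_integer_form(x)  # Passes: every entry is integral.
  # assert_integer_form(tf.constant([1.5])) raises `InvalidArgumentError`
  # when the assertion is evaluated.
  ```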
56  """
57  with ops.name_scope(name, values=[x, data]):
58    x = ops.convert_to_tensor(x, name="x")
59    if x.dtype.is_integer:
60      return control_flow_ops.no_op()
61    message = message or "{} has non-integer components".format(x)
62    if int_dtype is None:
63      try:
64        int_dtype = {
65            dtypes.float16: dtypes.int16,
66            dtypes.float32: dtypes.int32,
67            dtypes.float64: dtypes.int64,
68        }[x.dtype.base_dtype]
69      except KeyError:
70        raise TypeError("Unrecognized type {}".format(x.dtype.name))
71    return check_ops.assert_equal(
72        x,
73        math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
74        data=data,
75        summarize=summarize,
76        message=message,
77        name=name)
78
79
80def assert_symmetric(matrix):
81  matrix_t = array_ops.matrix_transpose(matrix)
82  return control_flow_ops.with_dependencies(
83      [check_ops.assert_equal(matrix, matrix_t)], matrix)
84
85
86def embed_check_nonnegative_integer_form(
87    x, name="embed_check_nonnegative_integer_form"):
88  """Assert x is a non-negative tensor, and optionally of integers."""
89  with ops.name_scope(name, values=[x]):
90    x = ops.convert_to_tensor(x, name="x")
91    assertions = [
92        check_ops.assert_non_negative(
93            x, message="'{}' must be non-negative.".format(x)),
94    ]
95    if not x.dtype.is_integer:
96      assertions += [
97          assert_integer_form(
98              x,
99              message="'{}' cannot contain fractional components.".format(x)),
100      ]
101    return control_flow_ops.with_dependencies(assertions, x)
102
103
104def same_dynamic_shape(a, b):
105  """Returns whether a and b have the same dynamic shape.
106
107  Args:
108    a: `Tensor`
109    b: `Tensor`
110
111  Returns:
112    `bool` `Tensor` representing if both tensors have the same shape.
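
  Example (a small sketch, assuming eager execution and `tf` as TensorFlow):

  ```python
  same_dynamic_shape(tf.zeros([2, 3]), tf.ones([2, 3]))  # ==> True
  same_dynamic_shape(tf.zeros([2, 3]), tf.ones([3, 2]))  # ==> False
  same_dynamic_shape(tf.zeros([2, 3]), tf.ones([2]))     # ==> False
  ```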
113  """
114  a = ops.convert_to_tensor(a, name="a")
115  b = ops.convert_to_tensor(b, name="b")
116
117  # Here we can't just do math_ops.equal(a.shape, b.shape), since
118  # static shape inference may break the equality comparison between
119  # shape(a) and shape(b) in math_ops.equal.
120  def all_shapes_equal():
121    return math_ops.reduce_all(
122        math_ops.equal(
123            array_ops.concat(
124                [array_ops.shape(a), array_ops.shape(b)], 0),
125            array_ops.concat(
126                [array_ops.shape(b), array_ops.shape(a)], 0)))
127
128  # One of the shapes isn't fully defined, so we need to use the dynamic
129  # shape.
130  return control_flow_ops.cond(
131      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
132      all_shapes_equal, lambda: constant_op.constant(False))
133
134
135def maybe_get_static_value(x, dtype=None):
136  """Helper which tries to return a static value.
137
138  Given `x`, extract its value statically, optionally casting to a specific
139  dtype. If this is not possible, None is returned.
140
141  Args:
142    x: `Tensor` for which to extract a value statically.
143    dtype: Optional dtype to cast to.
144
145  Returns:
146    Statically inferred value if possible, otherwise None.
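
  Example (a small sketch, assuming `tf`/`np` are TensorFlow/NumPy):

  ```python
  maybe_get_static_value(tf.constant([1., 2.]))            # ==> np.array([1., 2.])
  maybe_get_static_value(tf.constant(2.), dtype=np.int32)  # ==> np.array(2)
  # A `Tensor` whose value cannot be inferred statically yields `None`.
  ```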
147  """
148  if x is None:
149    return x
150  try:
151    # This returns an np.ndarray.
152    x_ = tensor_util.constant_value(x)
153  except TypeError:
154    x_ = x
155  if x_ is None or dtype is None:
156    return x_
157  return np.array(x_, dtype)
158
159
160def get_logits_and_probs(logits=None,
161                         probs=None,
162                         multidimensional=False,
163                         validate_args=False,
164                         name="get_logits_and_probs",
165                         dtype=None):
166  """Converts logit to probabilities (or vice-versa), and returns both.
167
168  Args:
169    logits: Floating-point `Tensor` representing log-odds.
170    probs: Floating-point `Tensor` representing probabilities.
171    multidimensional: Python `bool`, default `False`. If `True`, `logits` or
172      `probs` is treated as a `[N1, N2, ..., k]`-shaped tensor whose last
173      dimension indexes the logit or probability of each of the
174      `k = shape[-1]` classes.
175    validate_args: Python `bool`, default `False`. When `True`, either assert `0
176      <= probs <= 1` (if not `multidimensional`) or that the last dimension of
177      `probs` sums to one.
178    name: A name for this operation (optional).
179    dtype: `tf.DType` to prefer when converting args to `Tensor`s.
180
181  Returns:
182    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
183      `1`, then the corresponding entry in the returned logit will be `-Inf` and
184      `Inf` respectively.
185
186  Raises:
187    ValueError: if neither `probs` nor `logits` were passed in, or both were.
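
  Example (a small sketch; values are approximate):

  ```python
  logits, probs = get_logits_and_probs(probs=[0.25, 0.5])
  # logits ==> [log(1/3), 0.] ~= [-1.0986, 0.]

  logits, probs = get_logits_and_probs(logits=[0., 0.], multidimensional=True)
  # probs ==> [0.5, 0.5]
  ```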
188  """
189  with ops.name_scope(name, values=[probs, logits]):
190    if (probs is None) == (logits is None):
191      raise ValueError("Must pass probs or logits, but not both.")
192
193    if probs is None:
194      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
195      if not logits.dtype.is_floating:
196        raise TypeError("logits must having floating type.")
197      # We can early return since we constructed probs and therefore know
198      # they're valid.
199      if multidimensional:
200        if validate_args:
201          logits = embed_check_categorical_event_shape(logits)
202        return logits, nn.softmax(logits, name="probs")
203      return logits, math_ops.sigmoid(logits, name="probs")
204
205    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
206    if not probs.dtype.is_floating:
207      raise TypeError("probs must having floating type.")
208
209    if validate_args:
210      with ops.name_scope("validate_probs"):
211        one = constant_op.constant(1., probs.dtype)
212        dependencies = [check_ops.assert_non_negative(probs)]
213        if multidimensional:
214          probs = embed_check_categorical_event_shape(probs)
215          dependencies += [
216              check_ops.assert_near(
217                  math_ops.reduce_sum(probs, -1),
218                  one,
219                  message="probs does not sum to 1.")
220          ]
221        else:
222          dependencies += [
223              check_ops.assert_less_equal(
224                  probs, one, message="probs has components greater than 1.")
225          ]
226        probs = control_flow_ops.with_dependencies(dependencies, probs)
227
228    with ops.name_scope("logits"):
229      if multidimensional:
230        # Here we don't compute the multidimensional case in a manner
231        # consistent with the unidimensional case; instead we follow the TF
232        # convention. Typically, you might expect to see
233        # logits = log(probs) - log(probs[pivot]). A side-effect of
234        # being consistent with the TF approach is that the unidimensional case
235        # implicitly handles the second dimension but the multidimensional case
236        # explicitly keeps the pivot dimension.
237        return math_ops.log(probs), probs
238      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
239
240
241def _is_known_unsigned_by_dtype(dt):
242  """Helper returning True if dtype is known to be unsigned."""
243  return {
244      dtypes.bool: True,
245      dtypes.uint8: True,
246      dtypes.uint16: True,
247  }.get(dt.base_dtype, False)
248
249
250def _is_known_signed_by_dtype(dt):
251  """Helper returning True if dtype is known to be signed."""
252  return {
253      dtypes.float16: True,
254      dtypes.float32: True,
255      dtypes.float64: True,
256      dtypes.int8: True,
257      dtypes.int16: True,
258      dtypes.int32: True,
259      dtypes.int64: True,
260  }.get(dt.base_dtype, False)
261
262
263def _is_known_dtype(dt):
264  """Helper returning True if dtype is known."""
265  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)
266
267
268def _largest_integer_by_dtype(dt):
269  """Helper returning the largest integer exactly representable by dtype."""
270  if not _is_known_dtype(dt):
271    raise TypeError("Unrecognized dtype: {}".format(dt.name))
272  if dt.is_floating:
273    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
274  if dt.is_integer:
275    return np.iinfo(dt.as_numpy_dtype).max
276  if dt.base_dtype == dtypes.bool:
277    return int(1)
278  # We actually can't land here but keep the case for completeness.
279  raise TypeError("Unrecognized dtype: {}".format(dt.name))
280
281
282def _smallest_integer_by_dtype(dt):
283  """Helper returning the smallest integer exactly representable by dtype."""
284  if not _is_known_dtype(dt):
285    raise TypeError("Unrecognized dtype: {}".format(dt.name))
286  if _is_known_unsigned_by_dtype(dt):
287    return 0
288  return -1 * _largest_integer_by_dtype(dt)
289
290
291def _is_integer_like_by_dtype(dt):
292  """Helper returning True if dtype.is_integer or is `bool`."""
293  if not _is_known_dtype(dt):
294    raise TypeError("Unrecognized dtype: {}".format(dt.name))
295  return dt.is_integer or dt.base_dtype == dtypes.bool
296
297
298def embed_check_categorical_event_shape(
299    categorical_param, name="embed_check_categorical_event_shape"):
300  """Embeds checks that categorical distributions don't have too many classes.
301
302  A categorical-type distribution is one which, e.g., returns the class label
303  rather than a one-hot encoding.  E.g., `Categorical(probs)`.
304
305  Since distributions output samples in the same dtype as the parameters, we
306  must ensure that casting doesn't lose precision. That is, the
307  `parameter.dtype` implies a maximum number of classes. However, since shape is
308  `int32` and categorical variables are presumed to be indexes into a `Tensor`,
309  we must also ensure that the number of classes is no larger than the largest
310  possible `int32` index, i.e., `2**31-1`.
311
312  In other words the number of classes, `K`, must satisfy the following
313  condition:
314
315  ```python
316  K <= min(
317      int(2**31 - 1),  # Largest index representable by int32.
318      {
319          dtypes.float16: int(2**11),   # Largest int as a float16.
320          dtypes.float32: int(2**24),
321          dtypes.float64: int(2**53),
322      }.get(categorical_param.dtype.base_dtype, 0))
323  ```
324
325  Args:
326    categorical_param: Floating-point `Tensor` representing parameters of
327      distribution over categories. The rightmost shape is presumed to be the
328      number of categories.
329    name: A name for this operation (optional).
330
331  Returns:
332    categorical_param: Input `Tensor` with appropriate assertions embedded.
333
334  Raises:
335    TypeError: if `categorical_param` has an unknown `dtype`.
336    ValueError: if we can statically identify `categorical_param` as being too
337      large (for being closed under int32/float casting).
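
  Example (a small sketch, assuming `tf` is TensorFlow):

  ```python
  param = tf.zeros([3, 4], dtype=tf.float16)          # 4 classes.
  param = embed_check_categorical_event_shape(param)  # OK: 2 <= 4 <= 2**11.
  # A float16 parameter with more than 2**11 classes fails this check.
  ```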
338  """
339  with ops.name_scope(name, values=[categorical_param]):
340    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
341    # The size must not exceed both of:
342    # - The largest possible int32 (since categorical values are presumed to be
343    #   indexes into a Tensor).
344    # - The largest possible integer exactly representable under the given
345    #   floating-point dtype (since we need to cast to/from).
346    #
347    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
348    # For more details, see:
349    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
350    x_dtype = x.dtype.base_dtype
351    max_event_size = (
352        _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
353    if max_event_size == 0:
354      raise TypeError("Unable to validate size of unrecognized dtype "
355                      "({}).".format(x_dtype.name))
356    try:
357      x_shape_static = x.get_shape().with_rank_at_least(1)
358    except ValueError:
359      raise ValueError("A categorical-distribution parameter must have "
360                       "at least 1 dimension.")
361    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
362      event_size = x_shape_static.dims[-1].value
363      if event_size < 2:
364        raise ValueError("A categorical-distribution parameter must have at "
365                         "least 2 events.")
366      if event_size > max_event_size:
367        raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
368                         "{} implies shape ({}) cannot exceed {}.".format(
369                             x_dtype.name, event_size, max_event_size))
370      return x
371    else:
372      event_size = array_ops.shape(x, name="x_shape")[-1]
373      return control_flow_ops.with_dependencies([
374          check_ops.assert_rank_at_least(
375              x,
376              1,
377              message=("A categorical-distribution parameter must have "
378                       "at least 1 dimension.")),
379          check_ops.assert_greater_equal(
380              array_ops.shape(x)[-1],
381              2,
382              message=("A categorical-distribution parameter must have at "
383                       "least 2 events.")),
384          check_ops.assert_less_equal(
385              event_size,
386              max_event_size,
387              message="Number of classes exceeds `dtype` precision, "
388              "i.e., {} dtype cannot exceed {} shape.".format(
389                  x_dtype.name, max_event_size)),
390      ], x)
391
392
393def embed_check_integer_casting_closed(x,
394                                       target_dtype,
395                                       assert_nonnegative=True,
396                                       name="embed_check_casting_closed"):
397  """Ensures integers remain unaffected despite casting to/from int/float types.
398
399  Example integer-types: `uint8`, `int32`, `bool`.
400  Example floating-types: `float32`, `float64`.
401
402  The largest possible integer representable by an IEEE754 floating-point is
403  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
404  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have
405  integer-form values can be cast to some other type without loss of precision.
406
407  The smallest representable integer is the negative of the largest
408  representable integer, except for types: `uint8`, `uint16`, `bool`. For these
409  types, the smallest representable integer is `0`.
410
411  Args:
412    x: `Tensor` representing integer-form values.
413    target_dtype: TF `dtype` under which `x` should have identical values.
414    assert_nonnegative: `bool` indicating `x` should contain nonnegative values.
415    name: A name for this operation (optional).
416
417  Returns:
418    x: Input `Tensor` with appropriate assertions embedded.
419
420  Raises:
421    TypeError: if `x` is neither integer- nor floating-type.
422    TypeError: if `target_dtype` is neither integer- nor floating-type.
423    TypeError: if neither `x` nor `target_dtype` are integer-type.
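
  Example (a small sketch, assuming eager execution and `tf` as TensorFlow):

  ```python
  x = tf.constant([0., 1., 2.])
  x = embed_check_integer_casting_closed(x, target_dtype=tf.int32)  # Passes.
  # Inputs such as [1.5] or [-1.] fail the embedded assertions when evaluated.
  ```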
424  """
425
426  with ops.name_scope(name, values=[x]):
427    x = ops.convert_to_tensor(x, name="x")
428    if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
429      raise TypeError("{}.dtype must be floating- or "
430                      "integer-type.".format(x.dtype.name))
431    if (not _is_integer_like_by_dtype(target_dtype) and
432        not target_dtype.is_floating):
433      raise TypeError("target_dtype ({}) must be floating- or "
434                      "integer-type.".format(target_dtype.name))
435    if (not _is_integer_like_by_dtype(x.dtype) and
436        not _is_integer_like_by_dtype(target_dtype)):
437      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
438                      "must be integer-type.".format(x, x.dtype.name,
439                                                     target_dtype.name))
440
441    assertions = []
442    if assert_nonnegative:
443      assertions += [
444          check_ops.assert_non_negative(
445              x, message="Elements must be non-negative."),
446      ]
447
448    if x.dtype.is_floating:
449      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
450      # Since this check implies the magnitude checks below, it suffices.
451      assertions += [
452          assert_integer_form(
453              x,
454              int_dtype=target_dtype,
455              message="Elements must be {}-equivalent.".format(
456                  target_dtype.name)),
457      ]
458    else:
459      if (_largest_integer_by_dtype(x.dtype) >
460          _largest_integer_by_dtype(target_dtype)):
461        # Cast may lose integer precision.
462        assertions += [
463            check_ops.assert_less_equal(
464                x,
465                _largest_integer_by_dtype(target_dtype),
466                message=("Elements cannot exceed {}.".format(
467                    _largest_integer_by_dtype(target_dtype)))),
468        ]
469      if (not assert_nonnegative and (_smallest_integer_by_dtype(
470          x.dtype) < _smallest_integer_by_dtype(target_dtype))):
471        assertions += [
472            check_ops.assert_greater_equal(
473                x,
474                _smallest_integer_by_dtype(target_dtype),
475                message=("Elements cannot be smaller than {}.".format(
476                    _smallest_integer_by_dtype(target_dtype)))),
477        ]
478
479    if not assertions:
480      return x
481    return control_flow_ops.with_dependencies(assertions, x)
482
483
484def log_combinations(n, counts, name="log_combinations"):
485  """Multinomial coefficient.
486
487  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
488  the multinomial coefficient as:
489
490  ```n! / sum_i n_i!```
491
492  where `i` runs over all `k` classes.
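
  For example (a small sketch; the value is approximate):

  ```python
  log_combinations(n=3., counts=[1., 2.])
  # ==> log(3! / (1! * 2!)) = log(3) ~= 1.0986
  ```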
493
494  Args:
495    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
496      outcomes.
497    counts: Floating-point `Tensor` broadcastable with `n`. This represents
498      counts in `k` classes, where `k` is the last dimension of the tensor.
499    name: A name for this operation (optional).
500
501  Returns:
502    `Tensor` representing the multinomial coefficient between `n` and `counts`.
503  """
504  # First a bit about the number of ways counts could have come in:
505  # E.g. if counts = [1, 2], then this is 3 choose 2.
506  # In general, this is (sum counts)! / sum(counts!)
507  # The sum should be along the last dimension of counts. This is the
508  # "distribution" dimension. Here n a priori represents the sum of counts.
509  with ops.name_scope(name, values=[n, counts]):
510    n = ops.convert_to_tensor(n, name="n")
511    counts = ops.convert_to_tensor(counts, name="counts")
512    total_permutations = math_ops.lgamma(n + 1)
513    counts_factorial = math_ops.lgamma(counts + 1)
514    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
515    return total_permutations - redundant_permutations
516
517
518def matrix_diag_transform(matrix, transform=None, name=None):
519  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.
520
521  Create a trainable covariance defined by a Cholesky factor:
522
523  ```python
524  # Transform network layer into 2 x 2 array.
525  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
526  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
527
528  # Make the diagonal positive. If the upper triangle was zero, this would be a
529  # valid Cholesky factor.
530  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)
531
532  # LinearOperatorLowerTriangular ignores the upper triangle.
533  operator = LinearOperatorLowerTriangular(chol)
534  ```
535
536  Example of heteroskedastic 2-D linear regression.
537
538  ```python
539  tfd = tfp.distributions
540
541  # Get a trainable Cholesky factor.
542  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
543  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
544  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)
545
546  # Get a trainable mean.
547  mu = tf.contrib.layers.fully_connected(activations, 2)
548
549  # This is a fully trainable multivariate normal!
550  dist = tfd.MultivariateNormalTriL(mu, chol)
551
552  # Standard log loss. Minimizing this will "train" mu and chol, and then dist
553  # will be a distribution predicting labels as multivariate Gaussians.
554  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
555  ```
556
557  Args:
558    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
559      equal.
560    transform:  Element-wise function mapping `Tensors` to `Tensors`. To be
561      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
562      unchanged. Defaults to `None`.
563    name:  A name to give created ops. Defaults to "matrix_diag_transform".
564
565  Returns:
566    A `Tensor` with same shape and `dtype` as `matrix`.
567  """
568  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
569    matrix = ops.convert_to_tensor(matrix, name="matrix")
570    if transform is None:
571      return matrix
572    # Replace the diag with transformed diag.
573    diag = array_ops.matrix_diag_part(matrix)
574    transformed_diag = transform(diag)
575    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)
576
577  return transformed_mat
578
579
580def rotate_transpose(x, shift, name="rotate_transpose"):
581  """Circularly moves dims left or right.
582
583  Effectively identical to:
584
585  ```python
586  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
587  ```
588
589  When `shift` is not known statically, additional graph-runtime checks are
590  performed. These checks can entail moving data from GPU to CPU.
591
592  Example:
593
594  ```python
595  x = tf.random.normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
596  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
597  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
598  rotate_transpose(x,  1).shape == [4, 1, 2, 3]
599  rotate_transpose(x,  2).shape == [3, 4, 1, 2]
600  rotate_transpose(x,  7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
601  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
602  ```
603
604  Args:
605    x: `Tensor`.
606    shift: `Tensor`. Number of dimensions to transpose left (shift<0) or
607      transpose right (shift>0).
608    name: Python `str`. The name to give this op.
609
610  Returns:
611    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.
612
613  Raises:
614    TypeError: if shift is not integer type.
615  """
616  with ops.name_scope(name, values=[x, shift]):
617    x = ops.convert_to_tensor(x, name="x")
618    shift = ops.convert_to_tensor(shift, name="shift")
619    # We do not assign back to preserve constant-ness.
620    check_ops.assert_integer(shift)
621    shift_value_static = tensor_util.constant_value(shift)
622    ndims = x.get_shape().ndims
623    if ndims is not None and shift_value_static is not None:
624      if ndims < 2:
625        return x
626      shift_value_static = np.sign(shift_value_static) * (
627          abs(shift_value_static) % ndims)
628      if shift_value_static == 0:
629        return x
630      perm = np.roll(np.arange(ndims), shift_value_static)
631      return array_ops.transpose(x, perm=perm)
632    else:
633      # Consider if we always had a positive shift, and some specified
634      # direction.
635      # When shifting left we want the new array:
636      #   last(x, n-shift) + first(x, shift)
637      # and if shifting right then we want:
638      #   last(x, shift) + first(x, n-shift)
639      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
640      # Also, we can encode direction and shift as one: direction * shift.
641      # Combining these facts, we have:
642      #   a = cond(shift<0, -shift, n-shift)
643      #   last(x, n-a) + first(x, a) == x[a:n] + x[0:a]
644      # Finally, we transform shift by modulo length so it can be specified
645      # independently from the array upon which it operates (like python).
646      ndims = array_ops.rank(x)
647      shift = array_ops.where_v2(
648          math_ops.less(shift, 0),
649          math_ops.mod(-shift, ndims),  # pylint: disable=invalid-unary-operand-type
650          ndims - math_ops.mod(shift, ndims))
651      first = math_ops.range(0, shift)
652      last = math_ops.range(shift, ndims)
653      perm = array_ops.concat([last, first], 0)
654      return array_ops.transpose(x, perm=perm)
655
656
657def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
658  """Picks possibly different length row `Tensor`s based on condition.
659
660  Value `Tensor`s should have exactly one dimension.
661
662  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
663  `false_vector` is immediately returned. I.e., no graph nodes are created and
664  no validation happens.
665
666  Args:
667    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
668    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
669    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
670    name: Python `str`. The name to give this op.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```
674
675  Returns:
676    true_or_false_vector: `Tensor`.
677
678  Raises:
679    TypeError: if `cond.dtype != tf.bool`
680    TypeError: if `cond` is not a constant and
681      `true_vector.dtype != false_vector.dtype`
682  """
683  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
684    cond = ops.convert_to_tensor(cond, name="cond")
685    if cond.dtype != dtypes.bool:
686      raise TypeError("%s.dtype=%s which is not %s" %
687                      (cond, cond.dtype, dtypes.bool))
688    cond_value_static = tensor_util.constant_value(cond)
689    if cond_value_static is not None:
690      return true_vector if cond_value_static else false_vector
691    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
692    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
693    if true_vector.dtype != false_vector.dtype:
694      raise TypeError(
695          "%s.dtype=%s does not match %s.dtype=%s" %
696          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
697    n = array_ops.shape(true_vector)[0]
698    return array_ops.slice(
699        array_ops.concat([true_vector, false_vector], 0),
700        [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])
701
702
703def prefer_static_broadcast_shape(shape1,
704                                  shape2,
705                                  name="prefer_static_broadcast_shape"):
706  """Convenience function which statically broadcasts shape when possible.
707
708  Args:
709    shape1:  `1-D` integer `Tensor`.  Already converted to tensor!
710    shape2:  `1-D` integer `Tensor`.  Already converted to tensor!
711    name:  A string name to prepend to created ops.
712
713  Returns:
714    The broadcast shape, either as `TensorShape` (if broadcast can be done
715      statically), or as a `Tensor`.
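
  Example (a small sketch, assuming `tf` is TensorFlow):

  ```python
  prefer_static_broadcast_shape(
      tf.constant([2, 1], dtype=tf.int32),
      tf.constant([1, 3], dtype=tf.int32))
  # ==> TensorShape([2, 3]), since both shapes are statically known.
  ```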
716  """
717  with ops.name_scope(name, values=[shape1, shape2]):
718
719    def make_shape_tensor(x):
720      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)
721
722    def get_tensor_shape(s):
723      if isinstance(s, tensor_shape.TensorShape):
724        return s
725      s_ = tensor_util.constant_value(make_shape_tensor(s))
726      if s_ is not None:
727        return tensor_shape.TensorShape(s_)
728      return None
729
730    def get_shape_tensor(s):
731      if not isinstance(s, tensor_shape.TensorShape):
732        return make_shape_tensor(s)
733      if s.is_fully_defined():
734        return make_shape_tensor(s.as_list())
735      raise ValueError("Cannot broadcast from partially "
736                       "defined `TensorShape`.")
737
738    shape1_ = get_tensor_shape(shape1)
739    shape2_ = get_tensor_shape(shape2)
740    if shape1_ is not None and shape2_ is not None:
741      return array_ops.broadcast_static_shape(shape1_, shape2_)
742
743    shape1_ = get_shape_tensor(shape1)
744    shape2_ = get_shape_tensor(shape2)
745    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)
746
747
748def prefer_static_rank(x):
749  """Return static rank of tensor `x` if available, else `tf.rank(x)`.
750
751  Args:
752    x: `Tensor` (already converted).
753
754  Returns:
755    Numpy array (if static rank is obtainable), else `Tensor`.
756  """
757  return prefer_static_value(array_ops.rank(x))
758
759
760def prefer_static_shape(x):
761  """Return static shape of tensor `x` if available, else `tf.shape(x)`.
762
763  Args:
764    x: `Tensor` (already converted).
765
766  Returns:
767    Numpy array (if static shape is obtainable), else `Tensor`.
768  """
769  return prefer_static_value(array_ops.shape(x))
770
771
772def prefer_static_value(x):
773  """Return static value of tensor `x` if available, else `x`.
774
775  Args:
776    x: `Tensor` (already converted).
777
778  Returns:
779    Numpy array (if static value is obtainable), else `Tensor`.
780  """
781  static_x = tensor_util.constant_value(x)
782  if static_x is not None:
783    return static_x
784  return x
785
786
787def gen_new_seed(seed, salt):
788  """Generate a new seed, from the given seed and salt."""
789  if seed is None:
790    return None
791  string = (str(seed) + salt).encode("utf-8")
792  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
793
794
795def fill_triangular(x, upper=False, name=None):
796  """Creates a (batch of) triangular matrix from a vector of inputs.
797
798  Created matrix can be lower- or upper-triangular. (It is more efficient to
799  create the matrix as upper or lower, rather than transpose.)
800
801  Triangular matrix elements are filled in a clockwise spiral. See example,
802  below.
803
804  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
805  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
806  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
807
808  Example:
809
810  ```python
811  fill_triangular([1, 2, 3, 4, 5, 6])
812  # ==> [[4, 0, 0],
813  #      [6, 5, 0],
814  #      [3, 2, 1]]
815
816  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
817  # ==> [[1, 2, 3],
818  #      [0, 5, 6],
819  #      [0, 0, 4]]
820  ```
821
822  For comparison, a pure numpy version of this function can be found in
823  `util_test.py`, function `_fill_triangular`.
824
825  Args:
826    x: `Tensor` representing lower (or upper) triangular elements.
827    upper: Python `bool` representing whether output matrix should be upper
828      triangular (`True`) or lower triangular (`False`, default).
829    name: Python `str`. The name to give this op.
830
831  Returns:
832    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
833
834  Raises:
835    ValueError: if `x` cannot be mapped to a triangular matrix.
836  """
837
838  with ops.name_scope(name, "fill_triangular", values=[x]):
839    x = ops.convert_to_tensor(x, name="x")
840    if tensor_shape.dimension_value(
841        x.shape.with_rank_at_least(1)[-1]) is not None:
842      # Formula derived by solving for n: m = n(n+1)/2.
843      m = np.int32(x.shape.dims[-1].value)
844      n = np.sqrt(0.25 + 2. * m) - 0.5
845      if n != np.floor(n):
846        raise ValueError("Input right-most shape ({}) does not "
847                         "correspond to a triangular matrix.".format(m))
848      n = np.int32(n)
849      static_final_shape = x.shape[:-1].concatenate([n, n])
850    else:
851      m = array_ops.shape(x)[-1]
852      # For derivation, see above. Casting automatically lops off the 0.5, so we
853      # omit it.  We don't validate n is an integer because this has
854      # graph-execution cost; an error will be thrown from the reshape, below.
855      n = math_ops.cast(
856          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
857          dtype=dtypes.int32)
858      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
859          [None, None])
860    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
861    #
862    # We do this based on the insight that the input `x` provides `ceil(n/2)`
863    # rows of an `n x n` matrix, some of which will get zeroed out being on the
864    # wrong side of the diagonal. The first row will not get zeroed out at all,
865    # and we need `floor(n/2)` more rows, so the first is what we omit from
866    # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
867    # rows provided by a reversed tail, it is exactly the other set of elements
868    # of the reversed tail which will be zeroed out for being on the wrong side
869    # of the diagonal further up/down the matrix. And, in doing-so, we've filled
870    # the triangular matrix in a clock-wise spiral pattern. Neat!
871    #
872    # Try it out in numpy:
873    #  n = 3
874    #  x = np.arange(n * (n + 1) / 2)
875    #  m = x.shape[0]
876    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
877    #  x_tail = x[(m - (n**2 - m)):]
878    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
879    #  # ==> array([[3, 4, 5],
880    #               [5, 4, 3],
881    #               [2, 1, 0]])
882    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
883    #  # ==> array([[0, 1, 2],
884    #               [3, 4, 5],
885    #               [5, 4, 3]])
886    #
887    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
888    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
889    # Furthermore observe that:
890    #   m - (n**2 - m)
891    #   = n**2 / 2 + n / 2 - (n**2 - (n**2 / 2 + n / 2))
892    #   = 2 (n**2 / 2 + n / 2) - n**2
893    #   = n**2 + n - n**2
894    #   = n
895    ndims = prefer_static_rank(x)
896    if upper:
897      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
898    else:
899      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
900    new_shape = (
901        static_final_shape.as_list() if static_final_shape.is_fully_defined()
902        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
903    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
904    x = array_ops.matrix_band_part(
905        x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
906    x.set_shape(static_final_shape)
907    return x
908
909
910def fill_triangular_inverse(x, upper=False, name=None):
911  """Creates a vector from a (batch of) triangular matrix.
912
913  The vector is created from the lower-triangular or upper-triangular portion
914  depending on the value of the parameter `upper`.
915
916  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
917  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
918
919  Example:
920
921  ```python
922  fill_triangular_inverse(
923    [[4, 0, 0],
924     [6, 5, 0],
925     [3, 2, 1]])
926
927  # ==> [1, 2, 3, 4, 5, 6]
928
929  fill_triangular_inverse(
930    [[1, 2, 3],
931     [0, 5, 6],
932     [0, 0, 4]], upper=True)
933
934  # ==> [1, 2, 3, 4, 5, 6]
935  ```
936
937  Args:
938    x: `Tensor` representing lower (or upper) triangular elements.
939    upper: Python `bool` representing whether output matrix should be upper
940      triangular (`True`) or lower triangular (`False`, default).
941    name: Python `str`. The name to give this op.
942
943  Returns:
944    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
945      (or upper) triangular elements from `x`.
946  """
947
948  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
949    x = ops.convert_to_tensor(x, name="x")
950    if tensor_shape.dimension_value(
951        x.shape.with_rank_at_least(2)[-1]) is not None:
952      n = np.int32(x.shape.dims[-1].value)
953      m = np.int32((n * (n + 1)) // 2)
954      static_final_shape = x.shape[:-2].concatenate([m])
955    else:
956      n = array_ops.shape(x)[-1]
957      m = (n * (n + 1)) // 2
958      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
959          [None])
960    ndims = prefer_static_rank(x)
961    if upper:
962      initial_elements = x[..., 0, :]
963      triangular_portion = x[..., 1:, :]
964    else:
965      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
966      triangular_portion = x[..., :-1, :]
967    rotated_triangular_portion = array_ops.reverse(
968        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
969        axis=[ndims - 2])
970    consolidated_matrix = triangular_portion + rotated_triangular_portion
971    end_sequence = array_ops.reshape(
972        consolidated_matrix,
973        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
974    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
975    y.set_shape(static_final_shape)
976    return y
977
978
979def tridiag(below=None, diag=None, above=None, name=None):
980  """Creates a matrix with values set above, below, and on the diagonal.
981
982  Example:
983
984  ```python
985  tridiag(below=[1., 2., 3.],
986          diag=[4., 5., 6., 7.],
987          above=[8., 9., 10.])
988  # ==> array([[  4.,   8.,   0.,   0.],
989  #            [  1.,   5.,   9.,   0.],
990  #            [  0.,   2.,   6.,  10.],
991  #            [  0.,   0.,   3.,   7.]], dtype=float32)
992  ```
993
994  Warning: This Op is intended for convenience, not efficiency.
995
996  Args:
997    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
998      diagonal part. `None` is logically equivalent to `below = 0`.
999    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
1000      part.  `None` is logically equivalent to `diag = 0`.
1001    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
1002      diagonal part.  `None` is logically equivalent to `above = 0`.
1003    name: Python `str`. The name to give this op.
1004
1005  Returns:
1006    tridiag: `Tensor` with values set above, below and on the diagonal.
1007
1008  Raises:
1009    ValueError: if all inputs are `None`.
1010  """
1011
1012  def _pad(x):
1013    """Prepends and appends a zero to every vector in a batch of vectors."""
1014    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
1015    z = array_ops.zeros(shape, dtype=x.dtype)
1016    return array_ops.concat([z, x, z], axis=-1)
1017
1018  def _add(*x):
1019    """Adds list of Tensors, ignoring `None`."""
1020    s = None
1021    for y in x:
1022      if y is None:
1023        continue
1024      elif s is None:
1025        s = y
1026      else:
1027        s += y
1028    if s is None:
1029      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
1030    return s
1031
1032  with ops.name_scope(name, "tridiag", [below, diag, above]):
1033    if below is not None:
1034      below = ops.convert_to_tensor(below, name="below")
1035      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
1036    if diag is not None:
1037      diag = ops.convert_to_tensor(diag, name="diag")
1038      diag = array_ops.matrix_diag(diag)
1039    if above is not None:
1040      above = ops.convert_to_tensor(above, name="above")
1041      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
1042    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
1043    # matrices.
1044    return _add(below, diag, above)
1045
1046
1047def reduce_weighted_logsumexp(logx,
1048                              w=None,
1049                              axis=None,
1050                              keep_dims=False,
1051                              return_sign=False,
1052                              name=None):
1053  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.
1054
1055  If all weights `w` are known to be positive, it is more efficient to directly
1056  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is
1057  more efficient than `du.reduce_weighted_logsumexp(logx, w)`.
1059
1060  Reduces `logx` along the dimensions given in `axis`.
1061  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
1062  entry in `axis`. If `keep_dims` is true, the reduced dimensions
1063  are retained with length 1.
1064
1065  If `axis` has no entries, all dimensions are reduced, and a
1066  tensor with a single element is returned.
1067
1068  This function is more numerically stable than log(sum(w * exp(input))). It
1069  avoids overflows caused by taking the exp of large inputs and underflows
1070  caused by taking the log of small inputs.
1071
1072  For example:
1073
1074  ```python
1075  x = tf.constant([[0., 0, 0],
1076                   [0, 0, 0]])
1077
1078  w = tf.constant([[-1., 1, 1],
1079                   [1, 1, 1]])
1080
1081  du.reduce_weighted_logsumexp(x, w)
1082  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)
1083
1084  du.reduce_weighted_logsumexp(x, w, axis=0)
1085  # ==> [log(-1+1), log(1+1), log(1+1)]
1086
1087  du.reduce_weighted_logsumexp(x, w, axis=1)
1088  # ==> [log(-1+1+1), log(1+1+1)]
1089
1090  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
1091  # ==> [[log(-1+1+1)], [log(1+1+1)]]
1092
1093  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
1094  # ==> log(-1+5)
1095  ```
1096
1097  Args:
1098    logx: The tensor to reduce. Should have numeric type.
1099    w: The weight tensor. Should have numeric type identical to `logx`.
1100    axis: The dimensions to reduce. If `None` (the default), reduces all
1101      dimensions. Must be in the range `[-rank(input_tensor),
1102      rank(input_tensor))`.
1103    keep_dims: If true, retains reduced dimensions with length 1.
1104    return_sign: If `True`, returns the sign of the result.
1105    name: A name for the operation (optional).
1106
1107  Returns:
1108    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
1109    sign: (Optional) The sign of `sum(weight * exp(x))`.
1110  """
1111  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
1112    logx = ops.convert_to_tensor(logx, name="logx")
1113    if w is None:
1114      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
1115      if return_sign:
1116        sgn = array_ops.ones_like(lswe)
1117        return lswe, sgn
1118      return lswe
1119    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
1120    log_absw_x = logx + math_ops.log(math_ops.abs(w))
1121    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
1122    # If the largest element is `-inf` or `inf` then we don't bother subtracting
1123    # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
1124    # this is ok follows from the fact that we're actually free to subtract any
1125    # value we like, so long as we add it back after taking the `log(sum(...))`.
1126    max_log_absw_x = array_ops.where_v2(
1127        math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
1128        max_log_absw_x)
1129    wx_over_max_absw_x = (
1130        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
1131    sum_wx_over_max_absw_x = math_ops.reduce_sum(
1132        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
1133    if not keep_dims:
1134      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
1135    sgn = math_ops.sign(sum_wx_over_max_absw_x)
1136    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
1137    if return_sign:
1138      return lswe, sgn
1139    return lswe
1140
1141
1142# TODO(jvdillon): Merge this test back into:
1143# tensorflow/python/ops/softplus_op_test.py
1144# once TF core is accepting new ops.
1145def softplus_inverse(x, name=None):
1146  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).
1147
1148  Mathematically this op is equivalent to:
1149
1150  ```none
1151  softplus_inverse = log(exp(x) - 1.)
1152  ```
1153
1154  Args:
1155    x: `Tensor`. Non-negative (not enforced), floating-point.
1156    name: A name for the operation (optional).
1157
1158  Returns:
1159    `Tensor`. Has the same type/shape as input `x`.
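
  Example (a numerical sketch; values are approximate, `tf` is TensorFlow):

  ```python
  softplus_inverse(tf.math.softplus(3.))  # ==> 3.
  softplus_inverse(0.6931472)             # ==> ~0., since softplus(0.) = log(2).
  ```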
1160  """
1161  with ops.name_scope(name, "softplus_inverse", values=[x]):
1162    x = ops.convert_to_tensor(x, name="x")
1163    # We begin by deriving a more numerically stable softplus_inverse:
1164    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
1165    # ==> exp{x} = 1 + exp{y}                                (1)
1166    # ==> y = Log[exp{x} - 1]                                (2)
1167    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
1168    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
1169    #       = Log[1 - exp{-x}] + x                           (3)
1170    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
1171    # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
1172    # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
1173    #
1174    # In addition to the numerically stable derivation above, we clamp
1175    # small/large values to be congruent with the logic in:
1176    # tensorflow/core/kernels/softplus_op.h
1177    #
1178    # Finally, we set the input to one whenever the input is too large or too
1179    # small. This ensures that no unchosen codepath is +/- inf. This is
1180    # necessary to ensure the gradient doesn't get NaNs. Recall that the
1181    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
1182    # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
1183    # to overwrite `x` with ones only when we will never actually use this
1184    # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
1185    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
1186    is_too_small = math_ops.less(x, np.exp(threshold))
1187    is_too_large = math_ops.greater(x, -threshold)
1188    too_small_value = math_ops.log(x)
1189    too_large_value = x
1190    # This `where` will ultimately be a NOP because we won't select this
1191    # codepath whenever we used the surrogate `ones_like`.
1192    x = array_ops.where_v2(
1193        math_ops.logical_or(is_too_small, is_too_large), array_ops.ones_like(x),
1194        x)
1195    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
1196    return array_ops.where_v2(
1197        is_too_small, too_small_value,
1198        array_ops.where_v2(is_too_large, too_large_value, y))
1199
1200
1201# TODO(b/35290280): Add unit-tests.
1202def dimension_size(x, axis):
1203  """Returns the size of a specific dimension."""
1204  # Since tf.gather isn't "constant-in, constant-out", we must first check the
1205  # static shape or fallback to dynamic shape.
1206  s = tensor_shape.dimension_value(
1207      x.shape.with_rank_at_least(np.abs(axis))[axis])
1208  if s is not None:
1209    return s
1210  return array_ops.shape(x)[axis]
1211
1212
1213def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
1214                                      dtype,
1215                                      validate_args,
1216                                      name=None):
1217  """Validates quadrature grid, probs or computes them as necessary.
1218
1219  Args:
1220    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
1221      representing the sample points and the corresponding (possibly
1222      normalized) weight.  When `None`, defaults to:
1223        `np.polynomial.hermite.hermgauss(deg=8)`.
1224    dtype: The expected `dtype` of `grid` and `probs`.
1225    validate_args: Python `bool`, default `False`. When `True` distribution
1226      parameters are checked for validity despite possibly degrading runtime
1227      performance. When `False` invalid inputs may silently render incorrect
1228      outputs.
1229    name: Python `str` name prefixed to Ops created by this class.
1230
1231  Returns:
1232     quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
1233      representing the sample points and the corresponding (possibly
1234      normalized) weight.
1235
1236  Raises:
1237    ValueError: if `quadrature_grid_and_probs is not None` and
1238      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
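
  Example (a small sketch, assuming `tf` is TensorFlow):

  ```python
  grid, probs = process_quadrature_grid_and_probs(
      None, dtype=tf.float64, validate_args=False)
  # grid.shape ==> [8] (Gauss-Hermite abscissae); probs is normalized to sum to 1.
  ```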
1239  """
1240  with ops.name_scope(name, "process_quadrature_grid_and_probs",
1241                      [quadrature_grid_and_probs]):
1242    if quadrature_grid_and_probs is None:
1243      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
1244      grid = grid.astype(dtype.as_numpy_dtype)
1245      probs = probs.astype(dtype.as_numpy_dtype)
1246      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
1247      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
1248      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
1249      return grid, probs
1250
1251    grid, probs = tuple(quadrature_grid_and_probs)
1252    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
1253    probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype)
1254    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")
1255
1256    def _static_event_size(x):
1257      """Returns the static size of a specific dimension or `None`."""
1258      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])
1259
1260    m, n = _static_event_size(probs), _static_event_size(grid)
1261    if m is not None and n is not None:
1262      if m != n:
1263        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
1264                         "same-length zero-th-dimension `Tensor`s "
1265                         "(saw lengths {}, {})".format(m, n))
1266    elif validate_args:
1267      assertions = [
1268          check_ops.assert_equal(
1269              dimension_size(probs, axis=-1),
1270              dimension_size(grid, axis=-1),
1271              message=("`quadrature_grid_and_probs` must be a `tuple` of "
1272                       "same-length zero-th-dimension `Tensor`s")),
1273      ]
1274      with ops.control_dependencies(assertions):
1275        grid = array_ops.identity(grid)
1276        probs = array_ops.identity(probs)
1277    return grid, probs
1278
1279
1280def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
1281  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.
1282
1283  Args:
1284    x: `Tensor` input.
1285    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
1286      (Negative indexing is supported.)
1287    front: Python `bool`; if `True` the beginning of the `axis` dimension is
1288      padded with `value`, `count` times. If `False` no front padding is made.
1289    back: Python `bool`; if `True` the end of the `axis` dimension is padded
1290      with `value`, `count` times. If `False` no end padding is made.
1291    value: Scalar `int`-like `Tensor` representing the actual value added to the
1292      front and/or back of the `axis` dimension of `x`.
1293    count: Scalar `int`-like `Tensor` representing number of elements added to
1294      the front and/or back of the `axis` dimension of `x`. E.g., if `front =
1295      back = True` then `2 * count` elements are added.
1296    name: Python `str` name prefixed to Ops created by this function.
1297
1298  Returns:
1299    pad: The padded version of input `x`.
1300
1301  Raises:
1302    ValueError: if both `front` and `back` are `False`.
1303    TypeError: if `count` is not `int`-like.
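
  Example (a small sketch, assuming eager execution and `tf` as TensorFlow):

  ```python
  x = tf.constant([1., 2., 3.])
  pad(x, axis=0, back=True, value=0., count=2)              # ==> [1., 2., 3., 0., 0.]
  pad(x, axis=0, front=True, back=True, value=9., count=1)  # ==> [9., 1., 2., 3., 9.]
  ```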
1304  """
1305  with ops.name_scope(name, "pad", [x, value, count]):
1306    x = ops.convert_to_tensor(x, name="x")
1307    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
1308    count = ops.convert_to_tensor(count, name="count")
1309    if not count.dtype.is_integer:
1310      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
1311          count.dtype.name))
1312    if not front and not back:
1313      raise ValueError("At least one of `front`, `back` must be `True`.")
1314    ndims = (
1315        x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
1316            x, name="ndims"))
1317    axis = ops.convert_to_tensor(axis, name="axis")
1318    axis_ = tensor_util.constant_value(axis)
1319    if axis_ is not None:
1320      axis = axis_
1321      if axis < 0:
1322        axis = ndims + axis
1323      count_ = tensor_util.constant_value(count)
1324      if axis_ >= 0 or x.shape.ndims is not None:
1325        head = x.shape[:axis]
1326        middle = tensor_shape.TensorShape(None if count_ is None else (
1327            tensor_shape.dimension_at_index(x.shape, axis) + count_ *
1328            (front + back)))
1329        tail = x.shape[axis + 1:]
1330        final_shape = head.concatenate(middle.concatenate(tail))
1331      else:
1332        final_shape = None
1333    else:
1334      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
1335      final_shape = None
1336    x = array_ops.pad(
1337        x,
1338        paddings=array_ops.one_hot(
1339            indices=array_ops.stack(
1340                [axis if front else -1, axis if back else -1]),
1341            depth=ndims,
1342            axis=0,
1343            on_value=count,
1344            dtype=dtypes.int32),
1345        constant_values=value)
1346    if final_shape is not None:
1347      x.set_shape(final_shape)
1348    return x
1349
1350
1351def parent_frame_arguments():
1352  """Returns parent frame arguments.
1353
1354  When called inside a function, returns a dictionary with the caller's function
1355  arguments. These are positional arguments and keyword arguments (**kwargs),
1356  while variable arguments (*varargs) are excluded.
1357
1358  When called at global scope, this will return an empty dictionary, since there
1359  are no arguments.
1360
1361  WARNING: If caller function argument names are overloaded before invoking
1362  this method, then values will reflect the overloaded value. For this reason,
1363  we recommend calling `parent_frame_arguments` at the beginning of the
1364  function.
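
  Example (a small sketch):

  ```python
  def my_fn(a, b=2, **kwargs):
    return parent_frame_arguments()

  my_fn(1, c=3)  # ==> {'a': 1, 'b': 2, 'c': 3}
  ```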
1365  """
1366  # All arguments and the names used for *varargs, and **kwargs
1367  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
1368      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
1369          # Get the first frame of the caller of this method.
1370          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access
1371
1372  # Remove the *varargs, and flatten the **kwargs. Both are
1373  # nested lists.
1374  local_vars.pop(variable_arg_name, {})
1375  keyword_args = local_vars.pop(keyword_arg_name, {})
1376
1377  final_args = {}
1378  # Copy over arguments and their values. In general, local_vars
1379  # may contain more than just the arguments, since this method
1380  # can be called anywhere in a function.
1381  for arg_name in arg_names:
1382    final_args[arg_name] = local_vars.pop(arg_name)
1383  final_args.update(keyword_args)
1384
1385  return final_args
1386
1387
1388class AppendDocstring:
1389  """Helper class to promote private subclass docstring to public counterpart.
1390
1391  Example:
1392
1393  ```python
1394  class TransformedDistribution(Distribution):
1395    @distribution_util.AppendDocstring(
1396      additional_note="A special note!",
1397      kwargs_dict={"foo": "An extra arg."})
1398    def _prob(self, y, foo=None):
1399      pass
1400  ```
1401
1402  In this case, the `AppendDocstring` decorator appends the `additional_note` to
1403  the docstring of `prob` (not `_prob`) and adds a new `kwargs`
1404  section with each dictionary item as a bullet-point.
1405
1406  For a more detailed example, see `TransformedDistribution`.
1407  """
1408
1409  def __init__(self, additional_note="", kwargs_dict=None):
1410    """Initializes the AppendDocstring object.
1411
1412    Args:
1413      additional_note: Python string added as additional docstring to public
1414        version of function.
1415      kwargs_dict: Python string/string dictionary representing specific kwargs
1416        expanded from the **kwargs input.
1417
1418    Raises:
1419      ValueError: if kwargs_dict.key contains whitespace.
1420      ValueError: if kwargs_dict.value contains newlines.
1421    """
1422    self._additional_note = additional_note
1423    if kwargs_dict:
1424      bullets = []
1425      for key in sorted(kwargs_dict.keys()):
1426        value = kwargs_dict[key]
1427        if any(x.isspace() for x in key):
1428          raise ValueError("Parameter name \"%s\" contains whitespace." % key)
1429        value = value.lstrip()
1430        if "\n" in value:
1431          raise ValueError(
1432              "Parameter description for \"%s\" contains newlines." % key)
1433        bullets.append("*  `%s`: %s" % (key, value))
1434      self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))
1435
1436  def __call__(self, fn):
1437
1438    @functools.wraps(fn)
1439    def _fn(*args, **kwargs):
1440      return fn(*args, **kwargs)
1441
1442    if _fn.__doc__ is None:
1443      _fn.__doc__ = self._additional_note
1444    else:
1445      _fn.__doc__ += "\n%s" % self._additional_note
1446    return _fn
1447