# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Exponential distribution class."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Exponential",
    "ExponentialWithSoftplusRate",
]


@tf_export(v1=["distributions.Exponential"])
class Exponential(gamma.Gamma):
  """Exponential distribution.

  The Exponential distribution is parameterized by an event `rate` parameter.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; lambda, x > 0) = exp(-lambda x) / Z
  Z = 1 / lambda
  ```

  where `rate = lambda` and `Z` is the normalizing constant.

  The Exponential distribution is a special case of the Gamma distribution,
  i.e.,

  ```python
  Exponential(rate) = Gamma(concentration=1., rate)
  ```
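
  A minimal sketch checking this correspondence (it assumes the deprecated
  `tf.compat.v1.distributions` endpoints exported by this module are still
  available):

  ```python
  import tensorflow.compat.v1 as tf

  exponential = tf.distributions.Exponential(rate=2.)
  equivalent_gamma = tf.distributions.Gamma(concentration=1., rate=2.)
  # Both evaluate to log(2) - 2 * 1.5.
  exponential.log_prob(1.5)
  equivalent_gamma.log_prob(1.5)
  ```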

  The Exponential distribution uses a `rate` parameter, or "inverse scale",
  which can be intuited as,

  ```none
  X ~ Exponential(rate=1)
  Y = X / rate
  ```
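
  A minimal sketch of this scaling property (same assumptions as above):

  ```python
  import tensorflow.compat.v1 as tf

  rate = 4.
  standard = tf.distributions.Exponential(rate=1.)
  # Dividing standard Exponential samples by `rate` yields samples distributed
  # as Exponential(rate); their sample mean is roughly 1. / rate.
  scaled_samples = standard.sample(10000, seed=0) / rate
  ```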

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Exponential"):
83    """Construct Exponential distribution with parameter `rate`.
84
85    Args:
86      rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
87        positive values.
88      validate_args: Python `bool`, default `False`. When `True` distribution
89        parameters are checked for validity despite possibly degrading runtime
90        performance. When `False` invalid inputs may silently render incorrect
91        outputs.
92      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
93        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
94        result is undefined. When `False`, an exception is raised if one or
95        more of the statistic's batch members are undefined.
96      name: Python `str` name prefixed to Ops created by this class.
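
    A minimal construction sketch (the parameter values are illustrative
    only):

    ```python
    # A batch of three Exponential distributions, one per rate.
    dist = Exponential(rate=[0.5, 1., 2.], validate_args=True)
    ```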
97    """
98    parameters = dict(locals())
    # Even though all statistics are defined for valid inputs, this is not
    # true in the parent class "Gamma." Therefore, passing
    # `allow_nan_stats=True` through to the parent class results in
    # unnecessary asserts.
    with ops.name_scope(name, values=[rate]) as name:
      self._rate = ops.convert_to_tensor(rate, name="rate")
    super(Exponential, self).__init__(
        concentration=array_ops.ones([], dtype=self._rate.dtype),
        rate=self._rate,
        allow_nan_stats=allow_nan_stats,
        validate_args=validate_args,
        name=name)
    self._parameters = parameters
    self._graph_parents += [self._rate]

  @staticmethod
  def _param_shapes(sample_shape):
    return {"rate": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def rate(self):
    return self._rate

  def _log_survival_function(self, value):
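    # For the Exponential, S(x) = exp(-rate * x), so
    # log S(x) = -rate * x = log_prob(x) - log(rate).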
    return self._log_prob(value) - math_ops.log(self._rate)

  def _sample_n(self, n, seed=None):
    shape = array_ops.concat([[n], array_ops.shape(self._rate)], 0)
    # Uniform variates must be sampled from the open interval `(0, 1)` rather
    # than `[0, 1)`. To do so, we use `np.finfo(self.dtype.as_numpy_dtype).tiny`
    # because it is the smallest, positive, "normal" number. A "normal" number
    # is such that the mantissa has an implicit leading 1. Normal, positive
    # numbers x, y have the reasonable property that `x + y >= max(x, y)`. By
    # contrast, a subnormal lower bound (e.g., one obtained via `np.nextafter`)
    # could cause us to sample 0.
    sampled = random_ops.random_uniform(
        shape,
        minval=np.finfo(self.dtype.as_numpy_dtype).tiny,
        maxval=1.,
        seed=seed,
        dtype=self.dtype)
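    # Inverse-transform sampling: the Exponential CDF is
    # F(x) = 1 - exp(-rate * x), so F^{-1}(u) = -log(1 - u) / rate. Since
    # 1 - U and U are identically distributed for U ~ Uniform(0, 1),
    # -log(U) / rate is an Exponential(rate) sample.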
    return -math_ops.log(sampled) / self._rate


class ExponentialWithSoftplusRate(Exponential):
  """Exponential with softplus transform on `rate`."""

  @deprecation.deprecated(
      "2019-01-01",
148      "Use `tfd.Exponential(tf.nn.softplus(rate)).",
      warn_once=True)
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="ExponentialWithSoftplusRate"):
    parameters = dict(locals())
    with ops.name_scope(name, values=[rate]) as name:
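      # softplus(rate) = log(1 + exp(rate)) is strictly positive for any real
      # `rate`, which satisfies the positivity requirement of the parent
      # Exponential/Gamma parameterization.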
      super(ExponentialWithSoftplusRate, self).__init__(
          rate=nn.softplus(rate, name="softplus_rate"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters