xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/distributions/bernoulli.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""The Bernoulli distribution class."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops import nn
23from tensorflow.python.ops import random_ops
24from tensorflow.python.ops.distributions import distribution
25from tensorflow.python.ops.distributions import kullback_leibler
26from tensorflow.python.ops.distributions import util as distribution_util
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
29
30
@tf_export(v1=["distributions.Bernoulli"])
class Bernoulli(distribution.Distribution):
  """Bernoulli distribution.

  A distribution over the two outcomes `0` and `1`. It is parameterized either
  by `probs`, the probability of a `1` outcome, or by `logits`, the log-odds of
  a `1` outcome.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` of log-odds of a `1` event. Every entry
        parameterizes an independent Bernoulli distribution whose probability
        of a `1` event is sigmoid(logits). Pass only one of `logits` or
        `probs`.
      probs: An N-D `Tensor` of probabilities of a `1` event. Every entry
        parameterizes an independent Bernoulli distribution. Pass only one of
        `logits` or `probs`.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither is
        passed.
    """
    # Must stay the first statement: it snapshots the constructor arguments
    # before any other local name is bound.
    parameters = dict(locals())
    with ops.name_scope(name) as scope_name:
      logits_and_probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=scope_name)
      self._logits, self._probs = logits_and_probs
    super().__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=scope_name)

  @staticmethod
  def _param_shapes(sample_shape):
    """Maps the `logits` parameter name to `sample_shape` as an int32 tensor."""
    logits_shape = ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)
    return {"logits": logits_shape}

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    return self._logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    return self._probs

  def _batch_shape_tensor(self):
    """Dynamic batch shape, read from the shape of the logits tensor."""
    return array_ops.shape(self._logits)

  def _batch_shape(self):
    """Static batch shape, read from the static shape of the logits tensor."""
    return self._logits.get_shape()

  def _event_shape_tensor(self):
    """Dynamic event shape: an empty int32 tensor."""
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    """Static event shape: an empty `TensorShape`."""
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    """Draws `n` samples by thresholding uniform noise against `probs`."""
    # Leading sample dimension of size `n`, followed by the batch dimensions.
    sample_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    draws = random_ops.random_uniform(
        sample_shape, seed=seed, dtype=self.probs.dtype)
    # The outcome is `1` wherever the uniform draw is below `probs`.
    return math_ops.cast(math_ops.less(draws, self.probs), self.dtype)

  def _log_prob(self, event):
    """Log-probability of `event` as negated sigmoid cross-entropy."""
    if self.validate_args:
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    logits = self.logits
    event = math_ops.cast(event, logits.dtype)

    # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so unless
    # both static shapes are fully known and equal, expand each operand to the
    # common shape by multiplying it with ones shaped like the other.
    event_shape = event.get_shape()
    logits_shape = logits.get_shape()
    shapes_already_match = (event_shape.is_fully_defined() and
                            logits_shape.is_fully_defined() and
                            event_shape == logits_shape)
    if not shapes_already_match:
      expanded_logits = array_ops.ones_like(event) * logits
      expanded_event = array_ops.ones_like(logits) * event
      logits, event = expanded_logits, expanded_event

    cross_entropy = nn.sigmoid_cross_entropy_with_logits(
        labels=event, logits=logits)
    return -cross_entropy

  def _entropy(self):
    """Entropy computed from the logits with sigmoid and softplus."""
    logits = self.logits
    # Equals logits * (1 - sigmoid(logits)).
    scaled_logits = -logits * (math_ops.sigmoid(logits) - 1)  # pylint: disable=invalid-unary-operand-type
    softplus_neg_logits = nn.softplus(-logits)  # pylint: disable=invalid-unary-operand-type
    return scaled_logits + softplus_neg_logits

  def _mean(self):
    """Mean of the distribution: an identity copy of `probs`."""
    return array_ops.identity(self.probs)

  def _variance(self):
    """Variance of the distribution: mean * (1 - probs)."""
    mean = self._mean()
    return mean * (1. - self.probs)

  def _mode(self):
    """Returns `1` where `probs > 0.5` and `0` otherwise, as `self.dtype`."""
    one_is_likelier = self.probs > 0.5
    return math_ops.cast(one_is_likelier, self.dtype)
163
164
@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.

  The divergence is computed from the logits of both distributions using
  softplus and sigmoid rather than from explicit probabilities.

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    a_logits = a.logits
    b_logits = b.logits
    # softplus(-x) == -log(sigmoid(x)), so this is log(p_a) - log(p_b).
    log_ratio_one = nn.softplus(-b_logits) - nn.softplus(-a_logits)
    # softplus(x) == -log(1 - sigmoid(x)), so this is
    # log(1 - p_a) - log(1 - p_b).
    log_ratio_zero = nn.softplus(b_logits) - nn.softplus(a_logits)
    # Weight each log-ratio by the probability `a` gives that outcome:
    # sigmoid(a_logits) for `1` and sigmoid(-a_logits) for `0`.
    weighted_one = math_ops.sigmoid(a_logits) * log_ratio_one
    weighted_zero = math_ops.sigmoid(-a_logits) * log_ratio_zero
    return weighted_one + weighted_zero
184