# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Bernoulli distribution class."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["distributions.Bernoulli"])
class Bernoulli(distribution.Distribution):
  """Bernoulli distribution.

  The Bernoulli distribution with `probs` parameter, i.e., the probability of a
  `1` outcome (vs a `0` outcome).

  The distribution may be parameterized either by `probs` or by `logits`
  (log-odds, with `probs = sigmoid(logits)`); only one of the two should be
  passed to the constructor. Each entry of the parameter `Tensor` defines an
  independent Bernoulli distribution, so the batch shape is the parameter's
  shape and the event shape is scalar.
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event. Each
        entry in the `Tensor` parametrizes an independent Bernoulli distribution
        where the probability of an event is sigmoid(logits). Only one of
        `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1`
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be passed
        in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither is
        passed.
    """
    # Snapshot of the constructor's local bindings (its arguments). This must
    # remain the first statement so that no other locals are captured.
    parameters = dict(locals())
    with ops.name_scope(name) as name:
      # Derive whichever of logits/probs was not supplied from the other one.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=name)
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    """Maps the `logits` parameter to the requested sample shape."""
    return {"logits": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    return self._logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    return self._probs

  def _batch_shape_tensor(self):
    """Dynamic batch shape: the runtime shape of `logits`."""
    return array_ops.shape(self._logits)

  def _batch_shape(self):
    """Static batch shape: the static shape of `logits`."""
    return self._logits.get_shape()

  def _event_shape_tensor(self):
    """Each event is a scalar, so the dynamic event shape is empty."""
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    """Each event is a scalar, so the static event shape is `[]`."""
    return tensor_shape.TensorShape([])

  def _sample_n(self, n, seed=None):
    """Draws `n` samples of shape `[n] + batch_shape`, cast to `self.dtype`."""
    new_shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    # A uniform draw on [0, 1) falls below `probs` with probability `probs`,
    # so the comparison yields a Bernoulli(probs) outcome.
    uniform = random_ops.random_uniform(
        new_shape, seed=seed, dtype=self.probs.dtype)
    sample = math_ops.less(uniform, self.probs)
    return math_ops.cast(sample, self.dtype)

  def _log_prob(self, event):
    """Log-probability of `event`, computed from `logits`.

    For a label y in {0, 1}, log p(y) = y * log(sigmoid(l)) +
    (1 - y) * log(1 - sigmoid(l)), which is exactly the negative sigmoid
    cross-entropy with `labels=y` and `logits=l`.
    """
    if self.validate_args:
      # Presumably asserts that `event` only holds values castable to bool
      # without loss (i.e. 0/1) — see distribution_util for the exact check.
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits
    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
    # so we do this here.

    def _broadcast(logits, event):
      # Multiplying by ones_like of the other tensor broadcasts both operands
      # to their common shape.
      return (array_ops.ones_like(event) * logits,
              array_ops.ones_like(logits) * event)

    # Skip the explicit broadcast only when both static shapes are fully
    # known and already identical.
    if not (event.get_shape().is_fully_defined() and
            logits.get_shape().is_fully_defined() and
            event.get_shape() == logits.get_shape()):
      logits, event = _broadcast(logits, event)
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)

  def _entropy(self):
    """Entropy in nats, computed stably from `logits`.

    With p = sigmoid(l), -p*log(p) - (1-p)*log(1-p) simplifies to
    softplus(-l) - l * (sigmoid(l) - 1), which avoids evaluating log(p)
    directly.
    """
    return (-self.logits * (math_ops.sigmoid(self.logits) - 1) +  # pylint: disable=invalid-unary-operand-type
            nn.softplus(-self.logits))  # pylint: disable=invalid-unary-operand-type

  def _mean(self):
    """Mean of a Bernoulli is `probs`."""
    return array_ops.identity(self.probs)

  def _variance(self):
    """Variance is `probs * (1 - probs)`."""
    return self._mean() * (1. - self.probs)

  def _mode(self):
    """Returns `1` if `prob > 0.5` and `0` otherwise."""
    return math_ops.cast(self.probs > 0.5, self.dtype)


@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.

  Args:
    a: instance of a Bernoulli distribution object.
    b: instance of a Bernoulli distribution object.
    name: (optional) Name to use for created operations.
      default is "kl_bernoulli_bernoulli".

  Returns:
    Batchwise KL(a || b)
  """
  with ops.name_scope(name, "kl_bernoulli_bernoulli",
                      values=[a.logits, b.logits]):
    # KL = p_a * (log p_a - log p_b) + (1 - p_a) * (log(1-p_a) - log(1-p_b)).
    # Using log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) =
    # -softplus(x), both log-ratios are expressed directly in logits:
    #   delta_probs0 = log p_a - log p_b
    #   delta_probs1 = log(1 - p_a) - log(1 - p_b)
    delta_probs0 = nn.softplus(-b.logits) - nn.softplus(-a.logits)
    delta_probs1 = nn.softplus(b.logits) - nn.softplus(a.logits)
    return (math_ops.sigmoid(a.logits) * delta_probs0
            + math_ops.sigmoid(-a.logits) * delta_probs1)