# mypy: allow-untyped-defs
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import SigmoidTransform
from torch.distributions.utils import (
    broadcast_all,
    clamp_probs,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)
from torch.types import _size


__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"]


class LogitRelaxedBernoulli(Distribution):
    r"""
    Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli
    distribution.

    Samples are logits of values in (0, 1). See [1] for more details.

    Args:
        temperature (Tensor): relaxation temperature
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
    Variables (Maddison et al., 2017)

    [2] Categorical Reparameterization with Gumbel-Softmax
    (Jang et al., 2017)
    """
    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.real

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        self.temperature = temperature
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            (self.probs,) = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            (self.logits,) = broadcast_all(logits)
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LogitRelaxedBernoulli, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        if "probs" in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if "logits" in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._param.new(*args, **kwargs)

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        return self._param.size()

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        shape = self._extended_shape(sample_shape)
        probs = clamp_probs(self.probs.expand(shape))
        uniforms = clamp_probs(
            torch.rand(shape, dtype=probs.dtype, device=probs.device)
        )
        return (
            uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()
        ) / self.temperature

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        diff = logits - value.mul(self.temperature)
        return self.temperature.log() + diff - 2 * diff.exp().log1p()


class RelaxedBernoulli(TransformedDistribution):
    r"""
    Creates a RelaxedBernoulli distribution, parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`
    (but not both). This is a relaxed version of the `Bernoulli` distribution,
    so the values are in (0, 1), and it has reparametrizable samples.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedBernoulli(torch.tensor([2.2]),
        ...                      torch.tensor([0.1, 0.2, 0.3, 0.99]))
        >>> m.sample()
        tensor([ 0.2951, 0.3442, 0.8918, 0.9021])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
    """
    arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
        super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(RelaxedBernoulli, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs
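

# The guarded block below is an illustrative sketch, not part of the library
# API: it demonstrates that a RelaxedBernoulli sample is the sigmoid of a
# reparameterized LogitRelaxedBernoulli sample (the SigmoidTransform applied
# above), and that gradients flow back to the logits through rsample().
# The temperature and logits values are arbitrary examples.
if __name__ == "__main__":
    temperature = torch.tensor(0.67)
    logits = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)

    # Sampling on the logit scale and squashing through sigmoid matches
    # sampling from RelaxedBernoulli directly when the same uniform noise
    # is drawn (hence the identical seeds).
    torch.manual_seed(0)
    logit_sample = LogitRelaxedBernoulli(temperature, logits=logits).rsample()
    torch.manual_seed(0)
    relaxed_sample = RelaxedBernoulli(temperature, logits=logits).rsample()
    assert torch.allclose(torch.sigmoid(logit_sample), relaxed_sample)

    # rsample() is reparameterized, so gradients reach the parameters.
    relaxed_sample.sum().backward()
    assert logits.grad is not None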