# mypy: allow-untyped-defs
import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import ExpTransform
from torch.distributions.utils import broadcast_all, clamp_probs
from torch.types import _size


__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]


class ExpRelaxedCategorical(Distribution):
    r"""
    Creates an ExpRelaxedCategorical parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Returns the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature. A Python scalar is also
            accepted; :meth:`log_prob` converts it to a tensor on the fly.
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
        validate_args (bool, optional): forwarded both to the wrapped
            :class:`Categorical` and to :class:`Distribution`.

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al., 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al., 2017)
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # Forward ``validate_args`` so the wrapped Categorical honours the
        # caller's choice instead of always using the global default.
        self._categorical = Categorical(probs, logits, validate_args=validate_args)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        # The last parameter dimension indexes the events.
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution with the batch dimensions expanded to ``batch_shape``.

        The temperature is shared with the original instance; only the wrapped
        categorical is expanded.
        """
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        # Skip validation while initialising, then copy the original's flag.
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Delegate tensor creation to the wrapped categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        """Parameter shape of the wrapped categorical."""
        return self._categorical.param_shape

    @property
    def logits(self):
        """Logits of the wrapped categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Probabilities of the wrapped categorical."""
        return self._categorical.probs

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample in log-space.

        Gumbel noise is added to the logits, the result is divided by the
        temperature and then normalized with ``logsumexp`` over the last
        dimension.
        """
        shape = self._extended_shape(sample_shape)
        # Clamp the uniforms so that neither of the two logs below sees 0 or 1.
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        return scores - scores.logsumexp(dim=-1, keepdim=True)

    def log_prob(self, value):
        """Return the log-density of ``value`` (a log-space point).

        Args:
            value (Tensor): log of a point in the simplex, broadcastable
                against :attr:`logits`.
        """
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        temperature = self.temperature
        if not isinstance(temperature, torch.Tensor):
            # ``rsample`` already works with Python scalars; make ``log_prob``
            # accept them too rather than failing in ``full_like``/``.log()``.
            temperature = torch.as_tensor(
                temperature, dtype=logits.dtype, device=logits.device
            )
        # lgamma(K) + (K - 1) * log(temperature)
        log_scale = torch.full_like(
            temperature, float(K)
        ).lgamma() - temperature.log().mul(-(K - 1))
        score = logits - value.mul(temperature)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale


class RelaxedOneHotCategorical(TransformedDistribution):
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples are on the simplex, and are reparametrizable.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # The base distribution lives in log-space; ExpTransform maps its
        # samples onto the simplex.
        base_dist = ExpRelaxedCategorical(
            temperature, probs, logits, validate_args=validate_args
        )
        super().__init__(base_dist, ExpTransform(), validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new distribution with the batch dimensions expanded to ``batch_shape``."""
        new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=new)

    @property
    def temperature(self):
        """Relaxation temperature of the base distribution."""
        return self.base_dist.temperature

    @property
    def logits(self):
        """Logits of the base distribution."""
        return self.base_dist.logits

    @property
    def probs(self):
        """Probabilities of the base distribution."""
        return self.base_dist.probs