# mypy: allow-untyped-defs
import torch
from torch import inf
from torch.distributions import Categorical, constraints
from torch.distributions.binomial import Binomial
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all


__all__ = ["Multinomial"]


class Multinomial(Distribution):
    r"""
    Creates a Multinomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
    :attr:`probs` indexes over categories. All other dimensions index over batches.

    Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
    called (see example below)

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    -   :meth:`sample` requires a single shared `total_count` for all
        parameters and samples.
    -   :meth:`log_prob` allows different `total_count` for each parameter and
        sample.

    Example::

        >>> # xdoctest: +SKIP("FIXME: found invalid values")
        >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 21.,  24.,  30.,  25.])

        >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
        tensor([-4.1338])

    Args:
        total_count (int): number of trials
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    total_count: int

    @property
    def mean(self):
        """Per-category expected counts: ``probs * total_count``."""
        return self.probs * self.total_count

    @property
    def variance(self):
        """Per-category variance: ``total_count * probs * (1 - probs)``."""
        return self.total_count * self.probs * (1 - self.probs)

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        """Build the distribution from a trial count and ``probs`` or ``logits``.

        Raises:
            NotImplementedError: if ``total_count`` is not a Python ``int``.
        """
        if not isinstance(total_count, int):
            raise NotImplementedError("inhomogeneous total_count is not supported")
        self.total_count = total_count
        # The Categorical owns (and normalizes) the probs/logits parameters;
        # the ``probs``/``logits``/``param_shape`` properties delegate to it.
        self._categorical = Categorical(probs=probs, logits=logits)
        # Per-category Binomial marginals; read by ``entropy``.
        self._binomial = Binomial(total_count=total_count, probs=self.probs)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Multinomial`` with batch dimensions expanded to ``batch_shape``.

        Args:
            batch_shape: the desired expanded batch shape.
            _instance: optional pre-allocated instance to populate (forwarded
                to ``_get_checked_instance``).

        Returns:
            The expanded distribution, carrying over ``total_count`` and this
            instance's ``_validate_args`` setting.
        """
        new = self._get_checked_instance(Multinomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count
        new._categorical = self._categorical.expand(batch_shape)
        # The helper Binomial is parameterized by the full probs tensor, so its
        # batch shape is (batch_shape + event_shape). It must be carried over,
        # otherwise ``entropy()`` on the expanded instance fails with
        # AttributeError because ``_binomial`` is only assigned in ``__init__``.
        new._binomial = self._binomial.expand(batch_shape + self.event_shape)
        super(Multinomial, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Delegate tensor construction to the underlying Categorical."""
        return self._categorical._new(*args, **kwargs)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self):
        return constraints.multinomial(self.total_count)

    @property
    def logits(self):
        """Logits as stored by the underlying Categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Probabilities as stored by the underlying Categorical."""
        return self._categorical.probs

    @property
    def param_shape(self):
        """Parameter shape of the underlying Categorical."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """Draw count vectors by tallying ``total_count`` categorical draws.

        Args:
            sample_shape: leading shape of the returned sample.

        Returns:
            A tensor of per-category counts with shape
            ``self._extended_shape(sample_shape)`` and the dtype of ``probs``.
        """
        sample_shape = torch.Size(sample_shape)
        samples = self._categorical.sample(
            torch.Size((self.total_count,)) + sample_shape
        )
        # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
        # (sample_shape, batch_shape, total_count)
        shifted_idx = list(range(samples.dim()))
        shifted_idx.append(shifted_idx.pop(0))
        samples = samples.permute(*shifted_idx)
        # Histogram the drawn category indices along the last dimension.
        counts = samples.new_zeros(self._extended_shape(sample_shape))
        counts.scatter_add_(-1, samples, torch.ones_like(samples))
        return counts.type_as(self.probs)

    def entropy(self):
        """Return the entropy of the distribution.

        Computed as
        ``n * H(categorical) - lgamma(n + 1) + sum_i sum_{x=1..n} Binomial(x; n, p_i) * lgamma(x + 1)``
        where ``n`` is ``total_count``.
        """
        n = torch.tensor(self.total_count)

        cat_entropy = self._categorical.entropy()
        term1 = n * cat_entropy - torch.lgamma(n + 1)

        # x = 0 is dropped from the enumeration: its weight lgamma(1) is 0.
        support = self._binomial.enumerate_support(expand=False)[1:]
        binomial_probs = torch.exp(self._binomial.log_prob(support))
        weights = torch.lgamma(support + 1)
        # Reduce over the enumerated counts (dim 0) and the categories (dim -1).
        term2 = (binomial_probs * weights).sum([0, -1])

        return term1 + term2

    def log_prob(self, value):
        """Return the log-probability of the count vectors in ``value``.

        The number of trials is taken from ``value.sum(-1)`` rather than from
        ``self.total_count``.

        Args:
            value: tensor of per-category counts, broadcastable against
                ``logits``.
        """
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # Clone so the in-place masking below never writes into self.logits.
        logits = logits.clone(memory_format=torch.contiguous_format)
        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
        # A zero count for a zero-probability category would give 0 * -inf = nan;
        # zero out those logits so the term contributes 0 instead.
        logits[(value == 0) & (logits == -inf)] = 0
        log_powers = (logits * value).sum(-1)
        return log_factorial_n - log_factorial_xs + log_powers