xref: /aosp_15_r20/external/pytorch/torch/distributions/multinomial.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerfrom torch import inf
4*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import Categorical, constraints
5*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.binomial import Binomial
6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.distribution import Distribution
7*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import broadcast_all
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
# Public API of this module: only the Multinomial distribution is exported.
__all__ = ["Multinomial"]
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
class Multinomial(Distribution):
    r"""
    Creates a Multinomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
    :attr:`probs` indexes over categories. All other dimensions index over batches.

    Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
    called: :meth:`log_prob` recovers the number of trials from ``value.sum(-1)``.
    Sample validation, however, still checks values against the support implied by
    :attr:`total_count` (which defaults to 1), so construct the distribution with
    ``validate_args=False`` in that case (see example below).

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    -   :meth:`sample` requires a single shared `total_count` for all
        parameters and samples.
    -   :meth:`log_prob` allows different `total_count` for each parameter and
        sample.

    Example::

        >>> # xdoctest: +SKIP("sampling is random")
        >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
        >>> x = m.sample()  # equal probability of 0, 1, 2, 3
        >>> x
        tensor([ 21.,  24.,  30.,  25.])

        >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.]), validate_args=False).log_prob(x)
        tensor(-7.7168)

    Args:
        total_count (int): number of trials
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
        validate_args (bool, optional): whether to validate parameters and samples;
            forwarded to :class:`~torch.distributions.distribution.Distribution`
    """

    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    total_count: int

    @property
    def mean(self):
        """Expected count per category: ``total_count * probs``."""
        return self.probs * self.total_count

    @property
    def variance(self):
        """Per-category count variance: ``total_count * probs * (1 - probs)``."""
        return self.total_count * self.probs * (1 - self.probs)

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        # sample() draws exactly `total_count` categorical indices, so only a single
        # Python int shared by every batch element is accepted.
        if not isinstance(total_count, int):
            raise NotImplementedError("inhomogeneous total_count is not supported")
        self.total_count = total_count
        # The Categorical normalizes probs/logits as described in the class docstring.
        # NOTE: validate_args is not forwarded to the inner Categorical/Binomial.
        self._categorical = Categorical(probs=probs, logits=logits)
        # Per-category marginals X_i ~ Binomial(total_count, p_i); used by entropy().
        self._binomial = Binomial(total_count=total_count, probs=self.probs)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a view of this distribution with a new ``batch_shape``."""
        new = self._get_checked_instance(Multinomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count
        new._categorical = self._categorical.expand(batch_shape)
        # The binomial marginals carry one entry per category, i.e. their batch
        # shape is batch_shape + event_shape. They must be expanded as well,
        # otherwise entropy() on the expanded instance has no `_binomial`.
        new._binomial = self._binomial.expand(batch_shape + self.event_shape)
        super(Multinomial, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        # Delegate tensor construction to the underlying Categorical.
        return self._categorical._new(*args, **kwargs)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self):
        """Count vectors whose constraint is ``constraints.multinomial(total_count)``."""
        return constraints.multinomial(self.total_count)

    @property
    def logits(self):
        """Normalized log-probabilities (taken from the inner Categorical)."""
        return self._categorical.logits

    @property
    def probs(self):
        """Normalized probabilities (taken from the inner Categorical)."""
        return self._categorical.probs

    @property
    def param_shape(self):
        """Shape of the probs/logits parameter (batch shape + number of categories)."""
        return self._categorical.param_shape

    def sample(self, sample_shape=torch.Size()):
        """
        Draw count vectors of shape ``sample_shape + batch_shape + event_shape``.

        Each count vector sums to :attr:`total_count`. The result has the dtype
        of :attr:`probs`.
        """
        sample_shape = torch.Size(sample_shape)
        # One categorical index per trial: (total_count, *sample_shape, *batch_shape).
        indices = self._categorical.sample(
            torch.Size((self.total_count,)) + sample_shape
        )
        # Move the trial dimension last: (*sample_shape, *batch_shape, total_count).
        indices = indices.movedim(0, -1)
        # Tally how often each category index occurs along the trial dimension.
        counts = indices.new_zeros(self._extended_shape(sample_shape))
        counts.scatter_add_(-1, indices, torch.ones_like(indices))
        return counts.type_as(self.probs)

    def entropy(self):
        r"""
        Entropy of the distribution, computed as

        .. math::
            n H(\mathrm{Cat}(p)) - \log n! + \sum_i \mathbb{E}[\log X_i!],
            \qquad X_i \sim \mathrm{Binomial}(n, p_i).
        """
        cat_entropy = self._categorical.entropy()
        # Build n with the entropy's dtype/device so that lgamma(n + 1) is evaluated
        # at the distribution's precision: term1 and term2 are large and cancel, so
        # computing lgamma in float32 would ruin float64 results.
        n = cat_entropy.new_tensor(self.total_count)
        term1 = n * cat_entropy - torch.lgamma(n + 1)

        # The k = 0 support point is dropped because lgamma(0 + 1) == 0.
        support = self._binomial.enumerate_support(expand=False)[1:]
        binomial_probs = torch.exp(self._binomial.log_prob(support))
        weights = torch.lgamma(support + 1)
        # Sum over the support dimension (0) and the category dimension (-1).
        term2 = (binomial_probs * weights).sum([0, -1])

        return term1 + term2

    def log_prob(self, value):
        """
        Log-probability of the count vector(s) ``value``.

        The number of trials is taken from ``value.sum(-1)`` rather than
        :attr:`total_count`, so totals may differ per element (subject to sample
        validation when it is enabled).
        """
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        # broadcast_all may return expanded views; clone into a private contiguous
        # buffer so the in-place masking below is safe.
        logits = logits.clone(memory_format=torch.contiguous_format)
        log_factorial_n = torch.lgamma(value.sum(-1) + 1)
        log_factorial_xs = torch.lgamma(value + 1).sum(-1)
        # Zero counts of impossible categories contribute nothing (avoid 0 * -inf = nan).
        logits[(value == 0) & (logits == -inf)] = 0
        log_powers = (logits * value).sum(-1)
        return log_factorial_n - log_factorial_xs + log_powers
138