xref: /aosp_15_r20/external/pytorch/torch/distributions/relaxed_categorical.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch.distributions import constraints
4from torch.distributions.categorical import Categorical
5from torch.distributions.distribution import Distribution
6from torch.distributions.transformed_distribution import TransformedDistribution
7from torch.distributions.transforms import ExpTransform
8from torch.distributions.utils import broadcast_all, clamp_probs
9from torch.types import _size
10
11
# Public names of this module: the log-space relaxed categorical and its
# simplex-valued (exp-transformed) counterpart, both defined below.
__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"]
13
14
class ExpRelaxedCategorical(Distribution):
    r"""
    Creates an ExpRelaxedCategorical distribution parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Samples are the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor or Number): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
        validate_args (bool, optional): whether to validate arguments; it is
            forwarded to the wrapped :class:`Categorical` as well

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al., 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al., 2017)
    """
    arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
    support = (
        constraints.real_vector
    )  # The true support is actually a submanifold of this.
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # Forward ``validate_args`` so that an explicit ``validate_args=False``
        # also switches off the parameter checks of the wrapped Categorical
        # (otherwise it would still validate under the global default).
        self._categorical = Categorical(probs, logits, validate_args=validate_args)
        self.temperature = temperature
        batch_shape = self._categorical.batch_shape
        # The event is the trailing (category) dimension of the parameters.
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution with batch dims expanded to ``batch_shape``.

        The temperature is shared as-is; only the wrapped categorical is expanded.
        """
        new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        new._categorical = self._categorical.expand(batch_shape)
        # Parameters were already checked (if requested) when ``self`` was
        # built, so skip validation here and just copy the flag afterwards.
        super(ExpRelaxedCategorical, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        """Delegate tensor creation to the wrapped categorical's ``_new``."""
        return self._categorical._new(*args, **kwargs)

    @property
    def param_shape(self):
        """Shape of the parameters of the wrapped categorical."""
        return self._categorical.param_shape

    @property
    def logits(self):
        """Logits of the wrapped categorical."""
        return self._categorical.logits

    @property
    def probs(self):
        """Probabilities of the wrapped categorical."""
        return self._categorical.probs

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """Draw a reparameterized sample in log-space.

        Gumbel noise is added to the logits, the sum is divided by the
        temperature and the result is normalized with ``logsumexp`` over the
        last dimension, so ``exp`` of the returned tensor sums to one there.
        """
        shape = self._extended_shape(sample_shape)
        # clamp_probs keeps the uniform draws away from the endpoints
        # (presumably so the double log below stays finite).
        uniforms = clamp_probs(
            torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        )
        gumbels = -((-(uniforms.log())).log())
        scores = (self.logits + gumbels) / self.temperature
        return scores - scores.logsumexp(dim=-1, keepdim=True)

    def log_prob(self, value):
        """Return the log-density of ``value``, a log-space point.

        Args:
            value (Tensor): point(s) whose last dimension indexes the events.

        Returns:
            Tensor: log-density with the event dimension summed out.
        """
        K = self._categorical._num_events
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        temperature = self.temperature
        if not isinstance(temperature, torch.Tensor):
            # ``full_like`` / ``.log()`` below need a tensor; accept a plain
            # Python number too, matching the logits' dtype and device.
            temperature = torch.as_tensor(
                temperature, dtype=logits.dtype, device=logits.device
            )
        # log((K - 1)!) + (K - 1) * log(temperature)
        log_scale = torch.full_like(
            temperature, float(K)
        ).lgamma() - temperature.log().mul(-(K - 1))
        score = logits - value.mul(temperature)
        # sum_k score_k - K * logsumexp(score)
        score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
        return score + log_scale
96
97
class RelaxedOneHotCategorical(TransformedDistribution):
    r"""
    Relaxed one-hot categorical distribution, parametrized by
    :attr:`temperature` together with either :attr:`probs` or :attr:`logits`.

    It is built as an :class:`ExpRelaxedCategorical` base distribution pushed
    through an :class:`ExpTransform`. Being a relaxation of
    :class:`OneHotCategorical`, its samples lie on the simplex and are
    reparametrizable.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294,  0.2324,  0.3859,  0.2523])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): unnormalized log probability for each event
    """
    arg_constraints = {
        "probs": constraints.simplex,
        "logits": constraints.real_vector,
    }
    support = constraints.simplex
    has_rsample = True

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # Log-space base distribution; exponentiating it lands on the simplex.
        super().__init__(
            ExpRelaxedCategorical(
                temperature, probs, logits, validate_args=validate_args
            ),
            ExpTransform(),
            validate_args=validate_args,
        )

    def expand(self, batch_shape, _instance=None):
        """Return this distribution expanded to ``batch_shape``."""
        target = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
        return super().expand(batch_shape, _instance=target)

    @property
    def temperature(self):
        """Temperature stored on the base distribution."""
        base = self.base_dist
        return base.temperature

    @property
    def logits(self):
        """Logits of the base distribution."""
        base = self.base_dist
        return base.logits

    @property
    def probs(self):
        """Probabilities of the base distribution."""
        base = self.base_dist
        return base.probs
143