xref: /aosp_15_r20/external/pytorch/torch/distributions/kumaraswamy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerfrom torch import nan
4*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import constraints
5*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.transformed_distribution import TransformedDistribution
6*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.transforms import AffineTransform, PowerTransform
7*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.uniform import Uniform
8*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import broadcast_all, euler_constant
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker__all__ = ["Kumaraswamy"]
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
def _moments(a, b, n):
    """
    Compute the ``n``-th raw moment of a Kumaraswamy distribution.

    The moment is ``b * B(1 + n / a, b)``, where ``B`` is the Beta function.
    The Beta function is evaluated through ``torch.lgamma`` in log-space and
    exponentiated only at the end, so large arguments do not overflow the
    intermediate Gamma values.

    Args:
        a (Tensor): first shape parameter (``concentration1`` in
            :class:`Kumaraswamy`).
        b (Tensor): second shape parameter (``concentration0`` in
            :class:`Kumaraswamy`).
        n (int or float): order of the moment.

    Returns:
        Tensor: the ``n``-th moment, broadcast over ``a`` and ``b``.
    """
    arg1 = 1 + n / a
    # log B(arg1, b) = lgamma(arg1) + lgamma(b) - lgamma(arg1 + b)
    log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b)
    return b * torch.exp(log_value)
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker
class Kumaraswamy(TransformedDistribution):
    r"""
    Samples from a Kumaraswamy distribution.

    With ``a = concentration1`` and ``b = concentration0`` the density on the
    unit interval is :math:`f(x) = a b x^{a - 1} (1 - x^a)^{b - 1}`. Samples
    are produced by pushing a ``Uniform(0, 1)`` draw ``u`` through the
    transform :math:`x = (1 - u^{1/b})^{1/a}`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1
        tensor([ 0.1729])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
        validate_args (bool, optional): forwarded to the base ``Uniform``
            distribution and to ``TransformedDistribution``.
    """
    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(self, concentration1, concentration0, validate_args=None):
        self.concentration1, self.concentration0 = broadcast_all(
            concentration1, concentration0
        )
        # Base distribution is Uniform(0, 1) with the same shape, dtype and
        # device as the (broadcast) parameters.
        base_dist = Uniform(
            torch.full_like(self.concentration0, 0),
            torch.full_like(self.concentration0, 1),
            validate_args=validate_args,
        )
        # u -> u^(1/b) -> 1 - u^(1/b) -> (1 - u^(1/b))^(1/a)
        transforms = [
            PowerTransform(exponent=self.concentration0.reciprocal()),
            AffineTransform(loc=1.0, scale=-1.0),
            PowerTransform(exponent=self.concentration1.reciprocal()),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Return a new distribution whose parameters are expanded to
        ``batch_shape``; the remaining state is expanded by the parent class.
        """
        new = self._get_checked_instance(Kumaraswamy, _instance)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    @property
    def mean(self):
        """First raw moment, ``b * B(1 + 1 / a, b)``."""
        return _moments(self.concentration1, self.concentration0, 1)

    @property
    def mode(self):
        """
        Mode of the distribution, ``((a - 1) / (a * b - 1)) ** (1 / a)``.

        The result is ``nan`` where either concentration is below 1 (the
        density has no interior maximum there) and where both equal 1
        (the distribution is uniform, so no unique mode exists).
        """
        a = self.concentration1
        b = self.concentration0
        # Evaluate in log-space for numerical stability.
        # a == 1 with b > 1 gives log(0) = -inf, which exponentiates to the
        # correct mode of 0; a == b == 1 gives -inf - (-inf) = nan.
        log_mode = ((a - 1).log() - (a * b - 1).log()) / a
        # Out-of-place fill so the freshly computed tensor is not mutated.
        log_mode = log_mode.masked_fill((b < 1) | (a < 1), nan)
        return log_mode.exp()

    @property
    def variance(self):
        """Second raw moment minus the squared mean."""
        return _moments(self.concentration1, self.concentration0, 2) - torch.pow(
            self.mean, 2
        )

    def entropy(self):
        """
        Differential entropy,
        ``(1 - 1/b) + (1 - 1/a) * H_b - log(a) - log(b)``, where
        ``H_b = digamma(b + 1) + euler_constant`` is the harmonic number
        extended to real ``b``.
        """
        t1 = 1 - self.concentration1.reciprocal()
        t0 = 1 - self.concentration0.reciprocal()
        H0 = torch.digamma(self.concentration0 + 1) + euler_constant
        return (
            t0
            + t1 * H0
            - torch.log(self.concentration1)
            - torch.log(self.concentration0)
        )
100