# xref: /aosp_15_r20/external/pytorch/torch/distributions/chi2.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from torch.distributions import constraints
from torch.distributions.gamma import Gamma


# Names exported by `from torch.distributions.chi2 import *`.
__all__ = ["Chi2"]


class Chi2(Gamma):
    r"""
    Creates a Chi-squared distribution parameterized by degrees of freedom :attr:`df`.
    This is exactly equivalent to ``Gamma(concentration=0.5 * df, rate=0.5)``.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Chi2(torch.tensor([1.0]))
        >>> m.sample()  # Chi2 distributed with df=1
        tensor([ 0.1046])

    Args:
        df (float or Tensor): degrees of freedom of the distribution (must be positive).
            Stored internally as the Gamma ``concentration`` ``0.5 * df``; the Gamma
            ``rate`` is fixed at ``0.5``.
        validate_args (bool, optional): forwarded unchanged to :class:`Gamma`.
    """

    # Replace Gamma's (concentration, rate) constraints with a single one on ``df``:
    # ``rate`` is a fixed constant here, so ``df`` is the only user-supplied parameter.
    arg_constraints = {"df": constraints.positive}

    def __init__(self, df, validate_args=None):
        # Chi2(df) == Gamma(concentration=df / 2, rate=1 / 2).
        super().__init__(0.5 * df, 0.5, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Allocate (or check) the target as a Chi2 so the expanded object keeps the
        # ``df`` property and constraint; Gamma.expand then populates its parameters.
        new = self._get_checked_instance(Chi2, _instance)
        return super().expand(batch_shape, new)

    @property
    def df(self):
        """Degrees of freedom, recovered from the Gamma concentration as ``2 * concentration``."""
        return self.concentration * 2
36