# mypy: allow-untyped-defs
from torch.distributions import constraints
from torch.distributions.gamma import Gamma


__all__ = ["Chi2"]


class Chi2(Gamma):
    r"""
    Creates a Chi-squared distribution parameterized by its degrees of freedom :attr:`df`.

    It is a thin subclass of :class:`Gamma` and is exactly equivalent to
    ``Gamma(concentration=0.5 * df, rate=0.5)``.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Chi2(torch.tensor([1.0]))
        >>> m.sample()  # Chi2 distributed with degrees of freedom df=1
        tensor([ 0.1046])

    Args:
        df (float or Tensor): degrees of freedom (the shape parameter); must be positive
        validate_args (bool, optional): passed through unchanged to :class:`Gamma`
    """

    # ``df`` is not stored on the instance; it is the derived property defined
    # at the bottom of this class.
    arg_constraints = {"df": constraints.positive}

    def __init__(self, df, validate_args=None):
        # Chi2(df) is Gamma with half of df as concentration and a fixed rate of 1/2;
        # only those Gamma parameters are kept.
        half_df = 0.5 * df
        super().__init__(half_df, 0.5, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        # Obtain (or check) a Chi2 instance first so the expanded copy keeps this
        # subclass, then let Gamma.expand populate it.
        return super().expand(batch_shape, self._get_checked_instance(Chi2, _instance))

    @property
    def df(self):
        """Degrees of freedom, recovered as twice the Gamma concentration."""
        return 2 * self.concentration