# mypy: allow-untyped-defs
from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost
from torch.types import _size


__all__ = ["Independent"]


class Independent(Distribution):
    r"""
    Reinterprets some of the batch dims of a distribution as event dims.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`. For example, to create a diagonal Normal distribution with
    the same shape as a Multivariate Normal distribution (so they are
    interchangeable), you can::

        >>> from torch.distributions.multivariate_normal import MultivariateNormal
        >>> from torch.distributions.normal import Normal
        >>> loc = torch.zeros(3)
        >>> scale = torch.ones(3)
        >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
        >>> [mvn.batch_shape, mvn.event_shape]
        [torch.Size([]), torch.Size([3])]
        >>> normal = Normal(loc, scale)
        >>> [normal.batch_shape, normal.event_shape]
        [torch.Size([3]), torch.Size([])]
        >>> diagn = Independent(normal, 1)
        >>> [diagn.batch_shape, diagn.event_shape]
        [torch.Size([]), torch.Size([3])]

    Args:
        base_distribution (torch.distributions.distribution.Distribution): a
            base distribution
        reinterpreted_batch_ndims (int): the number of batch dims to
            reinterpret as event dims
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(
        self, base_distribution, reinterpreted_batch_ndims, validate_args=None
    ):
        if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
            raise ValueError(
                "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
                f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}"
            )
        shape = base_distribution.batch_shape + base_distribution.event_shape
        event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
        batch_shape = shape[: len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim :]
        self.base_dist = base_distribution
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Independent, _instance)
        batch_shape = torch.Size(batch_shape)
        new.base_dist = self.base_dist.expand(
            batch_shape + self.event_shape[: self.reinterpreted_batch_ndims]
        )
        new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
        super(Independent, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    @property
    def has_enumerate_support(self):
        if self.reinterpreted_batch_ndims > 0:
            return False
        return self.base_dist.has_enumerate_support

    @constraints.dependent_property
    def support(self):
        result = self.base_dist.support
        if self.reinterpreted_batch_ndims:
            result = constraints.independent(result, self.reinterpreted_batch_ndims)
        return result

    @property
    def mean(self):
        return self.base_dist.mean

    @property
    def mode(self):
        return self.base_dist.mode

    @property
    def variance(self):
        return self.base_dist.variance
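
    # Sampling delegates directly to the base distribution: reinterpreting
    # batch dims as event dims changes only how reductions such as log_prob
    # and entropy sum over trailing dims, not how values are drawn.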
    def sample(self, sample_shape=torch.Size()):
        return self.base_dist.sample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        return self.base_dist.rsample(sample_shape)

    def log_prob(self, value):
        log_prob = self.base_dist.log_prob(value)
        return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

    def entropy(self):
        entropy = self.base_dist.entropy()
        return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

    def enumerate_support(self, expand=True):
        if self.reinterpreted_batch_ndims > 0:
            raise NotImplementedError(
                "Enumeration over cartesian product is not implemented"
            )
        return self.base_dist.enumerate_support(expand=expand)

    def __repr__(self):
        return (
            self.__class__.__name__
            + f"({self.base_dist}, {self.reinterpreted_batch_ndims})"
        )
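

# A minimal usage sketch (an illustrative addition, not part of the module's
# API): wrapping a batch of Normals with Independent makes log_prob sum over
# the rightmost dim, mirroring the diagonal MultivariateNormal example in the
# class docstring above.
if __name__ == "__main__":
    from torch.distributions.normal import Normal

    base = Normal(torch.zeros(5, 3), torch.ones(5, 3))  # batch_shape (5, 3)
    diag = Independent(base, 1)  # batch_shape (5,), event_shape (3,)
    value = torch.zeros(5, 3)
    # Per-element log-probs vs. log-probs summed over the event dim.
    assert base.log_prob(value).shape == torch.Size([5, 3])
    assert diag.log_prob(value).shape == torch.Size([5])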