# mypy: allow-untyped-defs
from typing import Dict

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution


__all__ = ["MixtureSameFamily"]


class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting a component.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes the components.
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        """Validate the two distributions and record the mixture's shapes.

        Raises:
            ValueError: if `mixture_distribution` is not a `Categorical`, if
                `component_distribution` is not a `Distribution`, if the two
                batch shapes are not broadcast-compatible, or if the number
                of categories differs from the component count.
        """
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                " The Mixture distribution needs to be an "
                " instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution need to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that batch size matches (compared from the right, i.e. with
        # the usual broadcasting alignment; size-1 dims are always accepted).
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape`({cdbs})"
                )

        # Check that the number of mixture component matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution component` ({km}) does not"
                " equal `component_distribution.batch_shape[-1]`"
                f" ({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        # The mixture's batch shape is the component batch shape without its
        # trailing component-index dimension.
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new mixture whose batch dimensions are expanded to `batch_shape`.

        Both wrapped distributions are expanded; the component distribution
        keeps its trailing component-index dimension.
        """
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        # Skip re-validation while building the expanded instance, then carry
        # over the validation flag of the source instance.
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        """The `Categorical` distribution that selects a component."""
        return self._mixture_distribution

    @property
    def component_distribution(self):
        """The distribution whose rightmost batch dimension indexes components."""
        return self._component_distribution

    @property
    def mean(self):
        """Probability-weighted sum of the component means."""
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        """Mixture variance computed with the law of total variance."""
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        """Probability-weighted sum of the component CDFs evaluated at `x`."""
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs

        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        """Log-density of `x`: logsumexp over components of weight + density."""
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape `sample_shape + batch_shape + event_shape`.

        A component index is drawn for every batch element, one sample is
        drawn from every component, and the selected component's sample is
        gathered along the component dimension. Runs under `torch.no_grad()`.
        """
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # `torch.gather` needs an index with as many dims as its input,
            # so a mixture with fewer batch dims than the components (e.g. a
            # scalar batch shape) is first expanded to the full batch shape.
            # Expanding the distribution (rather than its samples) yields an
            # independent component draw per batch element.
            mix_dist = self.mixture_distribution
            if mix_dist.batch_shape != self.batch_shape:
                mix_dist = mix_dist.expand(self.batch_shape)

            # mixture samples [n, B]
            mix_sample = mix_dist.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        """Insert a singleton component dimension just before the event dims."""
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        """Reshape mixture weights so they broadcast against component tensors.

        `x` is reshaped to `[1] * pad + x.shape + [1] * event_ndims`, where
        `pad` is the number of batch dims the mixture distribution lacks
        relative to this distribution. Padding goes in front so that the
        mixture batch dims stay right-aligned with the component batch dims,
        matching the alignment checked in `__init__` and used by `log_prob`.
        """
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = max(dist_batch_ndims - cat_batch_ndims, 0)
        xs = x.shape
        x = x.reshape(
            torch.Size(pad_ndims * [1]) + xs + torch.Size(self._event_ndims * [1])
        )
        return x

    def __repr__(self):
        args_string = (
            f"\n {self.mixture_distribution},\n {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"