1# mypy: allow-untyped-defs 2from typing import Dict 3 4import torch 5from torch.distributions import constraints 6from torch.distributions.distribution import Distribution 7from torch.distributions.independent import Independent 8from torch.distributions.transforms import ComposeTransform, Transform 9from torch.distributions.utils import _sum_rightmost 10from torch.types import _size 11 12 13__all__ = ["TransformedDistribution"] 14 15 16class TransformedDistribution(Distribution): 17 r""" 18 Extension of the Distribution class, which applies a sequence of Transforms 19 to a base distribution. Let f be the composition of transforms applied:: 20 21 X ~ BaseDistribution 22 Y = f(X) ~ TransformedDistribution(BaseDistribution, f) 23 log p(Y) = log p(X) + log |det (dX/dY)| 24 25 Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the 26 maximum shape of its base distribution and its transforms, since transforms 27 can introduce correlations among events. 28 29 An example for the usage of :class:`TransformedDistribution` would be:: 30 31 # Building a Logistic Distribution 32 # X ~ Uniform(0, 1) 33 # f = a + b * logit(X) 34 # Y ~ f(X) ~ Logistic(a, b) 35 base_distribution = Uniform(0, 1) 36 transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)] 37 logistic = TransformedDistribution(base_distribution, transforms) 38 39 For more examples, please look at the implementations of 40 :class:`~torch.distributions.gumbel.Gumbel`, 41 :class:`~torch.distributions.half_cauchy.HalfCauchy`, 42 :class:`~torch.distributions.half_normal.HalfNormal`, 43 :class:`~torch.distributions.log_normal.LogNormal`, 44 :class:`~torch.distributions.pareto.Pareto`, 45 :class:`~torch.distributions.weibull.Weibull`, 46 :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and 47 :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical` 48 """ 49 arg_constraints: Dict[str, constraints.Constraint] = {} 50 51 def __init__(self, base_distribution, transforms, validate_args=None): 52 if isinstance(transforms, Transform): 53 self.transforms = [ 54 transforms, 55 ] 56 elif isinstance(transforms, list): 57 if not all(isinstance(t, Transform) for t in transforms): 58 raise ValueError( 59 "transforms must be a Transform or a list of Transforms" 60 ) 61 self.transforms = transforms 62 else: 63 raise ValueError( 64 f"transforms must be a Transform or list, but was {transforms}" 65 ) 66 67 # Reshape base_distribution according to transforms. 68 base_shape = base_distribution.batch_shape + base_distribution.event_shape 69 base_event_dim = len(base_distribution.event_shape) 70 transform = ComposeTransform(self.transforms) 71 if len(base_shape) < transform.domain.event_dim: 72 raise ValueError( 73 f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}." 74 ) 75 forward_shape = transform.forward_shape(base_shape) 76 expanded_base_shape = transform.inverse_shape(forward_shape) 77 if base_shape != expanded_base_shape: 78 base_batch_shape = expanded_base_shape[ 79 : len(expanded_base_shape) - base_event_dim 80 ] 81 base_distribution = base_distribution.expand(base_batch_shape) 82 reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim 83 if reinterpreted_batch_ndims > 0: 84 base_distribution = Independent( 85 base_distribution, reinterpreted_batch_ndims 86 ) 87 self.base_dist = base_distribution 88 89 # Compute shapes. 90 transform_change_in_event_dim = ( 91 transform.codomain.event_dim - transform.domain.event_dim 92 ) 93 event_dim = max( 94 transform.codomain.event_dim, # the transform is coupled 95 base_event_dim + transform_change_in_event_dim, # the base dist is coupled 96 ) 97 assert len(forward_shape) >= event_dim 98 cut = len(forward_shape) - event_dim 99 batch_shape = forward_shape[:cut] 100 event_shape = forward_shape[cut:] 101 super().__init__(batch_shape, event_shape, validate_args=validate_args) 102 103 def expand(self, batch_shape, _instance=None): 104 new = self._get_checked_instance(TransformedDistribution, _instance) 105 batch_shape = torch.Size(batch_shape) 106 shape = batch_shape + self.event_shape 107 for t in reversed(self.transforms): 108 shape = t.inverse_shape(shape) 109 base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)] 110 new.base_dist = self.base_dist.expand(base_batch_shape) 111 new.transforms = self.transforms 112 super(TransformedDistribution, new).__init__( 113 batch_shape, self.event_shape, validate_args=False 114 ) 115 new._validate_args = self._validate_args 116 return new 117 118 @constraints.dependent_property(is_discrete=False) 119 def support(self): 120 if not self.transforms: 121 return self.base_dist.support 122 support = self.transforms[-1].codomain 123 if len(self.event_shape) > support.event_dim: 124 support = constraints.independent( 125 support, len(self.event_shape) - support.event_dim 126 ) 127 return support 128 129 @property 130 def has_rsample(self): 131 return self.base_dist.has_rsample 132 133 def sample(self, sample_shape=torch.Size()): 134 """ 135 Generates a sample_shape shaped sample or sample_shape shaped batch of 136 samples if the distribution parameters are batched. Samples first from 137 base distribution and applies `transform()` for every transform in the 138 list. 139 """ 140 with torch.no_grad(): 141 x = self.base_dist.sample(sample_shape) 142 for transform in self.transforms: 143 x = transform(x) 144 return x 145 146 def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: 147 """ 148 Generates a sample_shape shaped reparameterized sample or sample_shape 149 shaped batch of reparameterized samples if the distribution parameters 150 are batched. Samples first from base distribution and applies 151 `transform()` for every transform in the list. 152 """ 153 x = self.base_dist.rsample(sample_shape) 154 for transform in self.transforms: 155 x = transform(x) 156 return x 157 158 def log_prob(self, value): 159 """ 160 Scores the sample by inverting the transform(s) and computing the score 161 using the score of the base distribution and the log abs det jacobian. 162 """ 163 if self._validate_args: 164 self._validate_sample(value) 165 event_dim = len(self.event_shape) 166 log_prob = 0.0 167 y = value 168 for transform in reversed(self.transforms): 169 x = transform.inv(y) 170 event_dim += transform.domain.event_dim - transform.codomain.event_dim 171 log_prob = log_prob - _sum_rightmost( 172 transform.log_abs_det_jacobian(x, y), 173 event_dim - transform.domain.event_dim, 174 ) 175 y = x 176 177 log_prob = log_prob + _sum_rightmost( 178 self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) 179 ) 180 return log_prob 181 182 def _monotonize_cdf(self, value): 183 """ 184 This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is 185 monotone increasing. 186 """ 187 sign = 1 188 for transform in self.transforms: 189 sign = sign * transform.sign 190 if isinstance(sign, int) and sign == 1: 191 return value 192 return sign * (value - 0.5) + 0.5 193 194 def cdf(self, value): 195 """ 196 Computes the cumulative distribution function by inverting the 197 transform(s) and computing the score of the base distribution. 198 """ 199 for transform in self.transforms[::-1]: 200 value = transform.inv(value) 201 if self._validate_args: 202 self.base_dist._validate_sample(value) 203 value = self.base_dist.cdf(value) 204 value = self._monotonize_cdf(value) 205 return value 206 207 def icdf(self, value): 208 """ 209 Computes the inverse cumulative distribution function using 210 transform(s) and computing the score of the base distribution. 211 """ 212 value = self._monotonize_cdf(value) 213 value = self.base_dist.icdf(value) 214 for transform in self.transforms: 215 value = transform(value) 216 return value 217