1# mypy: allow-untyped-defs 2import warnings 3from typing import Any, Dict, Optional 4from typing_extensions import deprecated 5 6import torch 7from torch.distributions import constraints 8from torch.distributions.utils import lazy_property 9from torch.types import _size 10 11 12__all__ = ["Distribution"] 13 14 15class Distribution: 16 r""" 17 Distribution is the abstract base class for probability distributions. 18 """ 19 20 has_rsample = False 21 has_enumerate_support = False 22 _validate_args = __debug__ 23 24 @staticmethod 25 def set_default_validate_args(value: bool) -> None: 26 """ 27 Sets whether validation is enabled or disabled. 28 29 The default behavior mimics Python's ``assert`` statement: validation 30 is on by default, but is disabled if Python is run in optimized mode 31 (via ``python -O``). Validation may be expensive, so you may want to 32 disable it once a model is working. 33 34 Args: 35 value (bool): Whether to enable validation. 36 """ 37 if value not in [True, False]: 38 raise ValueError 39 Distribution._validate_args = value 40 41 def __init__( 42 self, 43 batch_shape: torch.Size = torch.Size(), 44 event_shape: torch.Size = torch.Size(), 45 validate_args: Optional[bool] = None, 46 ): 47 self._batch_shape = batch_shape 48 self._event_shape = event_shape 49 if validate_args is not None: 50 self._validate_args = validate_args 51 if self._validate_args: 52 try: 53 arg_constraints = self.arg_constraints 54 except NotImplementedError: 55 arg_constraints = {} 56 warnings.warn( 57 f"{self.__class__} does not define `arg_constraints`. " 58 + "Please set `arg_constraints = {}` or initialize the distribution " 59 + "with `validate_args=False` to turn off validation." 60 ) 61 for param, constraint in arg_constraints.items(): 62 if constraints.is_dependent(constraint): 63 continue # skip constraints that cannot be checked 64 if param not in self.__dict__ and isinstance( 65 getattr(type(self), param), lazy_property 66 ): 67 continue # skip checking lazily-constructed args 68 value = getattr(self, param) 69 valid = constraint.check(value) 70 if not valid.all(): 71 raise ValueError( 72 f"Expected parameter {param} " 73 f"({type(value).__name__} of shape {tuple(value.shape)}) " 74 f"of distribution {repr(self)} " 75 f"to satisfy the constraint {repr(constraint)}, " 76 f"but found invalid values:\n{value}" 77 ) 78 super().__init__() 79 80 def expand(self, batch_shape: _size, _instance=None): 81 """ 82 Returns a new distribution instance (or populates an existing instance 83 provided by a derived class) with batch dimensions expanded to 84 `batch_shape`. This method calls :class:`~torch.Tensor.expand` on 85 the distribution's parameters. As such, this does not allocate new 86 memory for the expanded distribution instance. Additionally, 87 this does not repeat any args checking or parameter broadcasting in 88 `__init__.py`, when an instance is first created. 89 90 Args: 91 batch_shape (torch.Size): the desired expanded size. 92 _instance: new instance provided by subclasses that 93 need to override `.expand`. 94 95 Returns: 96 New distribution instance with batch dimensions expanded to 97 `batch_size`. 98 """ 99 raise NotImplementedError 100 101 @property 102 def batch_shape(self) -> torch.Size: 103 """ 104 Returns the shape over which parameters are batched. 105 """ 106 return self._batch_shape 107 108 @property 109 def event_shape(self) -> torch.Size: 110 """ 111 Returns the shape of a single sample (without batching). 112 """ 113 return self._event_shape 114 115 @property 116 def arg_constraints(self) -> Dict[str, constraints.Constraint]: 117 """ 118 Returns a dictionary from argument names to 119 :class:`~torch.distributions.constraints.Constraint` objects that 120 should be satisfied by each argument of this distribution. Args that 121 are not tensors need not appear in this dict. 122 """ 123 raise NotImplementedError 124 125 @property 126 def support(self) -> Optional[Any]: 127 """ 128 Returns a :class:`~torch.distributions.constraints.Constraint` object 129 representing this distribution's support. 130 """ 131 raise NotImplementedError 132 133 @property 134 def mean(self) -> torch.Tensor: 135 """ 136 Returns the mean of the distribution. 137 """ 138 raise NotImplementedError 139 140 @property 141 def mode(self) -> torch.Tensor: 142 """ 143 Returns the mode of the distribution. 144 """ 145 raise NotImplementedError(f"{self.__class__} does not implement mode") 146 147 @property 148 def variance(self) -> torch.Tensor: 149 """ 150 Returns the variance of the distribution. 151 """ 152 raise NotImplementedError 153 154 @property 155 def stddev(self) -> torch.Tensor: 156 """ 157 Returns the standard deviation of the distribution. 158 """ 159 return self.variance.sqrt() 160 161 def sample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: 162 """ 163 Generates a sample_shape shaped sample or sample_shape shaped batch of 164 samples if the distribution parameters are batched. 165 """ 166 with torch.no_grad(): 167 return self.rsample(sample_shape) 168 169 def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: 170 """ 171 Generates a sample_shape shaped reparameterized sample or sample_shape 172 shaped batch of reparameterized samples if the distribution parameters 173 are batched. 174 """ 175 raise NotImplementedError 176 177 @deprecated( 178 "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.", 179 category=FutureWarning, 180 ) 181 def sample_n(self, n: int) -> torch.Tensor: 182 """ 183 Generates n samples or n batches of samples if the distribution 184 parameters are batched. 185 """ 186 return self.sample(torch.Size((n,))) 187 188 def log_prob(self, value: torch.Tensor) -> torch.Tensor: 189 """ 190 Returns the log of the probability density/mass function evaluated at 191 `value`. 192 193 Args: 194 value (Tensor): 195 """ 196 raise NotImplementedError 197 198 def cdf(self, value: torch.Tensor) -> torch.Tensor: 199 """ 200 Returns the cumulative density/mass function evaluated at 201 `value`. 202 203 Args: 204 value (Tensor): 205 """ 206 raise NotImplementedError 207 208 def icdf(self, value: torch.Tensor) -> torch.Tensor: 209 """ 210 Returns the inverse cumulative density/mass function evaluated at 211 `value`. 212 213 Args: 214 value (Tensor): 215 """ 216 raise NotImplementedError 217 218 def enumerate_support(self, expand: bool = True) -> torch.Tensor: 219 """ 220 Returns tensor containing all values supported by a discrete 221 distribution. The result will enumerate over dimension 0, so the shape 222 of the result will be `(cardinality,) + batch_shape + event_shape` 223 (where `event_shape = ()` for univariate distributions). 224 225 Note that this enumerates over all batched tensors in lock-step 226 `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens 227 along dim 0, but with the remaining batch dimensions being 228 singleton dimensions, `[[0], [1], ..`. 229 230 To iterate over the full Cartesian product use 231 `itertools.product(m.enumerate_support())`. 232 233 Args: 234 expand (bool): whether to expand the support over the 235 batch dims to match the distribution's `batch_shape`. 236 237 Returns: 238 Tensor iterating over dimension 0. 239 """ 240 raise NotImplementedError 241 242 def entropy(self) -> torch.Tensor: 243 """ 244 Returns entropy of distribution, batched over batch_shape. 245 246 Returns: 247 Tensor of shape batch_shape. 248 """ 249 raise NotImplementedError 250 251 def perplexity(self) -> torch.Tensor: 252 """ 253 Returns perplexity of distribution, batched over batch_shape. 254 255 Returns: 256 Tensor of shape batch_shape. 257 """ 258 return torch.exp(self.entropy()) 259 260 def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size: 261 """ 262 Returns the size of the sample returned by the distribution, given 263 a `sample_shape`. Note, that the batch and event shapes of a distribution 264 instance are fixed at the time of construction. If this is empty, the 265 returned shape is upcast to (1,). 266 267 Args: 268 sample_shape (torch.Size): the size of the sample to be drawn. 269 """ 270 if not isinstance(sample_shape, torch.Size): 271 sample_shape = torch.Size(sample_shape) 272 return torch.Size(sample_shape + self._batch_shape + self._event_shape) 273 274 def _validate_sample(self, value: torch.Tensor) -> None: 275 """ 276 Argument validation for distribution methods such as `log_prob`, 277 `cdf` and `icdf`. The rightmost dimensions of a value to be 278 scored via these methods must agree with the distribution's batch 279 and event shapes. 280 281 Args: 282 value (Tensor): the tensor whose log probability is to be 283 computed by the `log_prob` method. 284 Raises 285 ValueError: when the rightmost dimensions of `value` do not match the 286 distribution's batch and event shapes. 287 """ 288 if not isinstance(value, torch.Tensor): 289 raise ValueError("The value argument to log_prob must be a Tensor") 290 291 event_dim_start = len(value.size()) - len(self._event_shape) 292 if value.size()[event_dim_start:] != self._event_shape: 293 raise ValueError( 294 f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}." 295 ) 296 297 actual_shape = value.size() 298 expected_shape = self._batch_shape + self._event_shape 299 for i, j in zip(reversed(actual_shape), reversed(expected_shape)): 300 if i != 1 and j != 1 and i != j: 301 raise ValueError( 302 f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}." 303 ) 304 try: 305 support = self.support 306 except NotImplementedError: 307 warnings.warn( 308 f"{self.__class__} does not define `support` to enable " 309 + "sample validation. Please initialize the distribution with " 310 + "`validate_args=False` to turn off validation." 311 ) 312 return 313 assert support is not None 314 valid = support.check(value) 315 if not valid.all(): 316 raise ValueError( 317 "Expected value argument " 318 f"({type(value).__name__} of shape {tuple(value.shape)}) " 319 f"to be within the support ({repr(support)}) " 320 f"of the distribution {repr(self)}, " 321 f"but found invalid values:\n{value}" 322 ) 323 324 def _get_checked_instance(self, cls, _instance=None): 325 if _instance is None and type(self).__init__ != cls.__init__: 326 raise NotImplementedError( 327 f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method " 328 "must also define a custom .expand() method." 329 ) 330 return self.__new__(type(self)) if _instance is None else _instance 331 332 def __repr__(self) -> str: 333 param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] 334 args_string = ", ".join( 335 [ 336 f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}" 337 for p in param_names 338 ] 339 ) 340 return self.__class__.__name__ + "(" + args_string + ")" 341