1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerimport warnings 3*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, Optional 4*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import deprecated 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import constraints 8*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import lazy_property 9*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _size 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker__all__ = ["Distribution"] 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerclass Distribution: 16*da0073e9SAndroid Build Coastguard Worker r""" 17*da0073e9SAndroid Build Coastguard Worker Distribution is the abstract base class for probability distributions. 18*da0073e9SAndroid Build Coastguard Worker """ 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker has_rsample = False 21*da0073e9SAndroid Build Coastguard Worker has_enumerate_support = False 22*da0073e9SAndroid Build Coastguard Worker _validate_args = __debug__ 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker @staticmethod 25*da0073e9SAndroid Build Coastguard Worker def set_default_validate_args(value: bool) -> None: 26*da0073e9SAndroid Build Coastguard Worker """ 27*da0073e9SAndroid Build Coastguard Worker Sets whether validation is enabled or disabled. 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker The default behavior mimics Python's ``assert`` statement: validation 30*da0073e9SAndroid Build Coastguard Worker is on by default, but is disabled if Python is run in optimized mode 31*da0073e9SAndroid Build Coastguard Worker (via ``python -O``). Validation may be expensive, so you may want to 32*da0073e9SAndroid Build Coastguard Worker disable it once a model is working. 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker Args: 35*da0073e9SAndroid Build Coastguard Worker value (bool): Whether to enable validation. 36*da0073e9SAndroid Build Coastguard Worker """ 37*da0073e9SAndroid Build Coastguard Worker if value not in [True, False]: 38*da0073e9SAndroid Build Coastguard Worker raise ValueError 39*da0073e9SAndroid Build Coastguard Worker Distribution._validate_args = value 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def __init__( 42*da0073e9SAndroid Build Coastguard Worker self, 43*da0073e9SAndroid Build Coastguard Worker batch_shape: torch.Size = torch.Size(), 44*da0073e9SAndroid Build Coastguard Worker event_shape: torch.Size = torch.Size(), 45*da0073e9SAndroid Build Coastguard Worker validate_args: Optional[bool] = None, 46*da0073e9SAndroid Build Coastguard Worker ): 47*da0073e9SAndroid Build Coastguard Worker self._batch_shape = batch_shape 48*da0073e9SAndroid Build Coastguard Worker self._event_shape = event_shape 49*da0073e9SAndroid Build Coastguard Worker if validate_args is not None: 50*da0073e9SAndroid Build Coastguard Worker self._validate_args = validate_args 51*da0073e9SAndroid Build Coastguard Worker if self._validate_args: 52*da0073e9SAndroid Build Coastguard Worker try: 53*da0073e9SAndroid Build Coastguard Worker arg_constraints = self.arg_constraints 54*da0073e9SAndroid Build Coastguard Worker except NotImplementedError: 55*da0073e9SAndroid Build Coastguard Worker arg_constraints = {} 56*da0073e9SAndroid Build Coastguard Worker warnings.warn( 57*da0073e9SAndroid Build Coastguard Worker f"{self.__class__} does not define `arg_constraints`. " 58*da0073e9SAndroid Build Coastguard Worker + "Please set `arg_constraints = {}` or initialize the distribution " 59*da0073e9SAndroid Build Coastguard Worker + "with `validate_args=False` to turn off validation." 60*da0073e9SAndroid Build Coastguard Worker ) 61*da0073e9SAndroid Build Coastguard Worker for param, constraint in arg_constraints.items(): 62*da0073e9SAndroid Build Coastguard Worker if constraints.is_dependent(constraint): 63*da0073e9SAndroid Build Coastguard Worker continue # skip constraints that cannot be checked 64*da0073e9SAndroid Build Coastguard Worker if param not in self.__dict__ and isinstance( 65*da0073e9SAndroid Build Coastguard Worker getattr(type(self), param), lazy_property 66*da0073e9SAndroid Build Coastguard Worker ): 67*da0073e9SAndroid Build Coastguard Worker continue # skip checking lazily-constructed args 68*da0073e9SAndroid Build Coastguard Worker value = getattr(self, param) 69*da0073e9SAndroid Build Coastguard Worker valid = constraint.check(value) 70*da0073e9SAndroid Build Coastguard Worker if not valid.all(): 71*da0073e9SAndroid Build Coastguard Worker raise ValueError( 72*da0073e9SAndroid Build Coastguard Worker f"Expected parameter {param} " 73*da0073e9SAndroid Build Coastguard Worker f"({type(value).__name__} of shape {tuple(value.shape)}) " 74*da0073e9SAndroid Build Coastguard Worker f"of distribution {repr(self)} " 75*da0073e9SAndroid Build Coastguard Worker f"to satisfy the constraint {repr(constraint)}, " 76*da0073e9SAndroid Build Coastguard Worker f"but found invalid values:\n{value}" 77*da0073e9SAndroid Build Coastguard Worker ) 78*da0073e9SAndroid Build Coastguard Worker super().__init__() 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker def expand(self, batch_shape: _size, _instance=None): 81*da0073e9SAndroid Build Coastguard Worker """ 82*da0073e9SAndroid Build Coastguard Worker Returns a new distribution instance (or populates an existing instance 83*da0073e9SAndroid Build Coastguard Worker provided by a derived class) with batch dimensions expanded to 84*da0073e9SAndroid Build Coastguard Worker `batch_shape`. This method calls :class:`~torch.Tensor.expand` on 85*da0073e9SAndroid Build Coastguard Worker the distribution's parameters. As such, this does not allocate new 86*da0073e9SAndroid Build Coastguard Worker memory for the expanded distribution instance. Additionally, 87*da0073e9SAndroid Build Coastguard Worker this does not repeat any args checking or parameter broadcasting in 88*da0073e9SAndroid Build Coastguard Worker `__init__.py`, when an instance is first created. 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker Args: 91*da0073e9SAndroid Build Coastguard Worker batch_shape (torch.Size): the desired expanded size. 92*da0073e9SAndroid Build Coastguard Worker _instance: new instance provided by subclasses that 93*da0073e9SAndroid Build Coastguard Worker need to override `.expand`. 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker Returns: 96*da0073e9SAndroid Build Coastguard Worker New distribution instance with batch dimensions expanded to 97*da0073e9SAndroid Build Coastguard Worker `batch_size`. 98*da0073e9SAndroid Build Coastguard Worker """ 99*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker @property 102*da0073e9SAndroid Build Coastguard Worker def batch_shape(self) -> torch.Size: 103*da0073e9SAndroid Build Coastguard Worker """ 104*da0073e9SAndroid Build Coastguard Worker Returns the shape over which parameters are batched. 105*da0073e9SAndroid Build Coastguard Worker """ 106*da0073e9SAndroid Build Coastguard Worker return self._batch_shape 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker @property 109*da0073e9SAndroid Build Coastguard Worker def event_shape(self) -> torch.Size: 110*da0073e9SAndroid Build Coastguard Worker """ 111*da0073e9SAndroid Build Coastguard Worker Returns the shape of a single sample (without batching). 112*da0073e9SAndroid Build Coastguard Worker """ 113*da0073e9SAndroid Build Coastguard Worker return self._event_shape 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker @property 116*da0073e9SAndroid Build Coastguard Worker def arg_constraints(self) -> Dict[str, constraints.Constraint]: 117*da0073e9SAndroid Build Coastguard Worker """ 118*da0073e9SAndroid Build Coastguard Worker Returns a dictionary from argument names to 119*da0073e9SAndroid Build Coastguard Worker :class:`~torch.distributions.constraints.Constraint` objects that 120*da0073e9SAndroid Build Coastguard Worker should be satisfied by each argument of this distribution. Args that 121*da0073e9SAndroid Build Coastguard Worker are not tensors need not appear in this dict. 122*da0073e9SAndroid Build Coastguard Worker """ 123*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker @property 126*da0073e9SAndroid Build Coastguard Worker def support(self) -> Optional[Any]: 127*da0073e9SAndroid Build Coastguard Worker """ 128*da0073e9SAndroid Build Coastguard Worker Returns a :class:`~torch.distributions.constraints.Constraint` object 129*da0073e9SAndroid Build Coastguard Worker representing this distribution's support. 130*da0073e9SAndroid Build Coastguard Worker """ 131*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker @property 134*da0073e9SAndroid Build Coastguard Worker def mean(self) -> torch.Tensor: 135*da0073e9SAndroid Build Coastguard Worker """ 136*da0073e9SAndroid Build Coastguard Worker Returns the mean of the distribution. 137*da0073e9SAndroid Build Coastguard Worker """ 138*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker @property 141*da0073e9SAndroid Build Coastguard Worker def mode(self) -> torch.Tensor: 142*da0073e9SAndroid Build Coastguard Worker """ 143*da0073e9SAndroid Build Coastguard Worker Returns the mode of the distribution. 144*da0073e9SAndroid Build Coastguard Worker """ 145*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(f"{self.__class__} does not implement mode") 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker @property 148*da0073e9SAndroid Build Coastguard Worker def variance(self) -> torch.Tensor: 149*da0073e9SAndroid Build Coastguard Worker """ 150*da0073e9SAndroid Build Coastguard Worker Returns the variance of the distribution. 151*da0073e9SAndroid Build Coastguard Worker """ 152*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker @property 155*da0073e9SAndroid Build Coastguard Worker def stddev(self) -> torch.Tensor: 156*da0073e9SAndroid Build Coastguard Worker """ 157*da0073e9SAndroid Build Coastguard Worker Returns the standard deviation of the distribution. 158*da0073e9SAndroid Build Coastguard Worker """ 159*da0073e9SAndroid Build Coastguard Worker return self.variance.sqrt() 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker def sample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: 162*da0073e9SAndroid Build Coastguard Worker """ 163*da0073e9SAndroid Build Coastguard Worker Generates a sample_shape shaped sample or sample_shape shaped batch of 164*da0073e9SAndroid Build Coastguard Worker samples if the distribution parameters are batched. 165*da0073e9SAndroid Build Coastguard Worker """ 166*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 167*da0073e9SAndroid Build Coastguard Worker return self.rsample(sample_shape) 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: 170*da0073e9SAndroid Build Coastguard Worker """ 171*da0073e9SAndroid Build Coastguard Worker Generates a sample_shape shaped reparameterized sample or sample_shape 172*da0073e9SAndroid Build Coastguard Worker shaped batch of reparameterized samples if the distribution parameters 173*da0073e9SAndroid Build Coastguard Worker are batched. 174*da0073e9SAndroid Build Coastguard Worker """ 175*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker @deprecated( 178*da0073e9SAndroid Build Coastguard Worker "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.", 179*da0073e9SAndroid Build Coastguard Worker category=FutureWarning, 180*da0073e9SAndroid Build Coastguard Worker ) 181*da0073e9SAndroid Build Coastguard Worker def sample_n(self, n: int) -> torch.Tensor: 182*da0073e9SAndroid Build Coastguard Worker """ 183*da0073e9SAndroid Build Coastguard Worker Generates n samples or n batches of samples if the distribution 184*da0073e9SAndroid Build Coastguard Worker parameters are batched. 185*da0073e9SAndroid Build Coastguard Worker """ 186*da0073e9SAndroid Build Coastguard Worker return self.sample(torch.Size((n,))) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker def log_prob(self, value: torch.Tensor) -> torch.Tensor: 189*da0073e9SAndroid Build Coastguard Worker """ 190*da0073e9SAndroid Build Coastguard Worker Returns the log of the probability density/mass function evaluated at 191*da0073e9SAndroid Build Coastguard Worker `value`. 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker Args: 194*da0073e9SAndroid Build Coastguard Worker value (Tensor): 195*da0073e9SAndroid Build Coastguard Worker """ 196*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker def cdf(self, value: torch.Tensor) -> torch.Tensor: 199*da0073e9SAndroid Build Coastguard Worker """ 200*da0073e9SAndroid Build Coastguard Worker Returns the cumulative density/mass function evaluated at 201*da0073e9SAndroid Build Coastguard Worker `value`. 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker Args: 204*da0073e9SAndroid Build Coastguard Worker value (Tensor): 205*da0073e9SAndroid Build Coastguard Worker """ 206*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker def icdf(self, value: torch.Tensor) -> torch.Tensor: 209*da0073e9SAndroid Build Coastguard Worker """ 210*da0073e9SAndroid Build Coastguard Worker Returns the inverse cumulative density/mass function evaluated at 211*da0073e9SAndroid Build Coastguard Worker `value`. 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker Args: 214*da0073e9SAndroid Build Coastguard Worker value (Tensor): 215*da0073e9SAndroid Build Coastguard Worker """ 216*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker def enumerate_support(self, expand: bool = True) -> torch.Tensor: 219*da0073e9SAndroid Build Coastguard Worker """ 220*da0073e9SAndroid Build Coastguard Worker Returns tensor containing all values supported by a discrete 221*da0073e9SAndroid Build Coastguard Worker distribution. The result will enumerate over dimension 0, so the shape 222*da0073e9SAndroid Build Coastguard Worker of the result will be `(cardinality,) + batch_shape + event_shape` 223*da0073e9SAndroid Build Coastguard Worker (where `event_shape = ()` for univariate distributions). 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker Note that this enumerates over all batched tensors in lock-step 226*da0073e9SAndroid Build Coastguard Worker `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens 227*da0073e9SAndroid Build Coastguard Worker along dim 0, but with the remaining batch dimensions being 228*da0073e9SAndroid Build Coastguard Worker singleton dimensions, `[[0], [1], ..`. 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker To iterate over the full Cartesian product use 231*da0073e9SAndroid Build Coastguard Worker `itertools.product(m.enumerate_support())`. 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker Args: 234*da0073e9SAndroid Build Coastguard Worker expand (bool): whether to expand the support over the 235*da0073e9SAndroid Build Coastguard Worker batch dims to match the distribution's `batch_shape`. 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker Returns: 238*da0073e9SAndroid Build Coastguard Worker Tensor iterating over dimension 0. 239*da0073e9SAndroid Build Coastguard Worker """ 240*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker def entropy(self) -> torch.Tensor: 243*da0073e9SAndroid Build Coastguard Worker """ 244*da0073e9SAndroid Build Coastguard Worker Returns entropy of distribution, batched over batch_shape. 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker Returns: 247*da0073e9SAndroid Build Coastguard Worker Tensor of shape batch_shape. 248*da0073e9SAndroid Build Coastguard Worker """ 249*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker def perplexity(self) -> torch.Tensor: 252*da0073e9SAndroid Build Coastguard Worker """ 253*da0073e9SAndroid Build Coastguard Worker Returns perplexity of distribution, batched over batch_shape. 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker Returns: 256*da0073e9SAndroid Build Coastguard Worker Tensor of shape batch_shape. 257*da0073e9SAndroid Build Coastguard Worker """ 258*da0073e9SAndroid Build Coastguard Worker return torch.exp(self.entropy()) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size: 261*da0073e9SAndroid Build Coastguard Worker """ 262*da0073e9SAndroid Build Coastguard Worker Returns the size of the sample returned by the distribution, given 263*da0073e9SAndroid Build Coastguard Worker a `sample_shape`. Note, that the batch and event shapes of a distribution 264*da0073e9SAndroid Build Coastguard Worker instance are fixed at the time of construction. If this is empty, the 265*da0073e9SAndroid Build Coastguard Worker returned shape is upcast to (1,). 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker Args: 268*da0073e9SAndroid Build Coastguard Worker sample_shape (torch.Size): the size of the sample to be drawn. 269*da0073e9SAndroid Build Coastguard Worker """ 270*da0073e9SAndroid Build Coastguard Worker if not isinstance(sample_shape, torch.Size): 271*da0073e9SAndroid Build Coastguard Worker sample_shape = torch.Size(sample_shape) 272*da0073e9SAndroid Build Coastguard Worker return torch.Size(sample_shape + self._batch_shape + self._event_shape) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker def _validate_sample(self, value: torch.Tensor) -> None: 275*da0073e9SAndroid Build Coastguard Worker """ 276*da0073e9SAndroid Build Coastguard Worker Argument validation for distribution methods such as `log_prob`, 277*da0073e9SAndroid Build Coastguard Worker `cdf` and `icdf`. The rightmost dimensions of a value to be 278*da0073e9SAndroid Build Coastguard Worker scored via these methods must agree with the distribution's batch 279*da0073e9SAndroid Build Coastguard Worker and event shapes. 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker Args: 282*da0073e9SAndroid Build Coastguard Worker value (Tensor): the tensor whose log probability is to be 283*da0073e9SAndroid Build Coastguard Worker computed by the `log_prob` method. 284*da0073e9SAndroid Build Coastguard Worker Raises 285*da0073e9SAndroid Build Coastguard Worker ValueError: when the rightmost dimensions of `value` do not match the 286*da0073e9SAndroid Build Coastguard Worker distribution's batch and event shapes. 287*da0073e9SAndroid Build Coastguard Worker """ 288*da0073e9SAndroid Build Coastguard Worker if not isinstance(value, torch.Tensor): 289*da0073e9SAndroid Build Coastguard Worker raise ValueError("The value argument to log_prob must be a Tensor") 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker event_dim_start = len(value.size()) - len(self._event_shape) 292*da0073e9SAndroid Build Coastguard Worker if value.size()[event_dim_start:] != self._event_shape: 293*da0073e9SAndroid Build Coastguard Worker raise ValueError( 294*da0073e9SAndroid Build Coastguard Worker f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}." 295*da0073e9SAndroid Build Coastguard Worker ) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker actual_shape = value.size() 298*da0073e9SAndroid Build Coastguard Worker expected_shape = self._batch_shape + self._event_shape 299*da0073e9SAndroid Build Coastguard Worker for i, j in zip(reversed(actual_shape), reversed(expected_shape)): 300*da0073e9SAndroid Build Coastguard Worker if i != 1 and j != 1 and i != j: 301*da0073e9SAndroid Build Coastguard Worker raise ValueError( 302*da0073e9SAndroid Build Coastguard Worker f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}." 303*da0073e9SAndroid Build Coastguard Worker ) 304*da0073e9SAndroid Build Coastguard Worker try: 305*da0073e9SAndroid Build Coastguard Worker support = self.support 306*da0073e9SAndroid Build Coastguard Worker except NotImplementedError: 307*da0073e9SAndroid Build Coastguard Worker warnings.warn( 308*da0073e9SAndroid Build Coastguard Worker f"{self.__class__} does not define `support` to enable " 309*da0073e9SAndroid Build Coastguard Worker + "sample validation. Please initialize the distribution with " 310*da0073e9SAndroid Build Coastguard Worker + "`validate_args=False` to turn off validation." 311*da0073e9SAndroid Build Coastguard Worker ) 312*da0073e9SAndroid Build Coastguard Worker return 313*da0073e9SAndroid Build Coastguard Worker assert support is not None 314*da0073e9SAndroid Build Coastguard Worker valid = support.check(value) 315*da0073e9SAndroid Build Coastguard Worker if not valid.all(): 316*da0073e9SAndroid Build Coastguard Worker raise ValueError( 317*da0073e9SAndroid Build Coastguard Worker "Expected value argument " 318*da0073e9SAndroid Build Coastguard Worker f"({type(value).__name__} of shape {tuple(value.shape)}) " 319*da0073e9SAndroid Build Coastguard Worker f"to be within the support ({repr(support)}) " 320*da0073e9SAndroid Build Coastguard Worker f"of the distribution {repr(self)}, " 321*da0073e9SAndroid Build Coastguard Worker f"but found invalid values:\n{value}" 322*da0073e9SAndroid Build Coastguard Worker ) 323*da0073e9SAndroid Build Coastguard Worker 324*da0073e9SAndroid Build Coastguard Worker def _get_checked_instance(self, cls, _instance=None): 325*da0073e9SAndroid Build Coastguard Worker if _instance is None and type(self).__init__ != cls.__init__: 326*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError( 327*da0073e9SAndroid Build Coastguard Worker f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method " 328*da0073e9SAndroid Build Coastguard Worker "must also define a custom .expand() method." 329*da0073e9SAndroid Build Coastguard Worker ) 330*da0073e9SAndroid Build Coastguard Worker return self.__new__(type(self)) if _instance is None else _instance 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker def __repr__(self) -> str: 333*da0073e9SAndroid Build Coastguard Worker param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] 334*da0073e9SAndroid Build Coastguard Worker args_string = ", ".join( 335*da0073e9SAndroid Build Coastguard Worker [ 336*da0073e9SAndroid Build Coastguard Worker f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}" 337*da0073e9SAndroid Build Coastguard Worker for p in param_names 338*da0073e9SAndroid Build Coastguard Worker ] 339*da0073e9SAndroid Build Coastguard Worker ) 340*da0073e9SAndroid Build Coastguard Worker return self.__class__.__name__ + "(" + args_string + ")" 341