xref: /aosp_15_r20/external/pytorch/torch/distributions/distribution.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport warnings
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Dict, Optional
4*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import deprecated
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions import constraints
8*da0073e9SAndroid Build Coastguard Workerfrom torch.distributions.utils import lazy_property
9*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _size
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker__all__ = ["Distribution"]
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
class Distribution:
    r"""
    Distribution is the abstract base class for probability distributions.

    Subclasses are expected to override the members that raise
    :class:`NotImplementedError` here (``arg_constraints``, ``support``,
    ``rsample``, ``log_prob``, ...) as appropriate for the distribution.
    """

    # Whether the distribution provides a reparameterized sampler.
    has_rsample = False
    # Whether the distribution provides `enumerate_support`.
    has_enumerate_support = False
    # Class-wide default for argument validation; follows Python's
    # `assert` semantics (off under `python -O`).
    _validate_args = __debug__

    @staticmethod
    def set_default_validate_args(value: bool) -> None:
        """
        Sets whether validation is enabled or disabled.

        The default behavior mimics Python's ``assert`` statement: validation
        is on by default, but is disabled if Python is run in optimized mode
        (via ``python -O``). Validation may be expensive, so you may want to
        disable it once a model is working.

        Args:
            value (bool): Whether to enable validation.

        Raises:
            ValueError: if `value` is neither ``True`` nor ``False``.
        """
        if value not in (True, False):
            raise ValueError(
                f"Expected `value` to be True or False, but got {value!r}"
            )
        Distribution._validate_args = value

    def __init__(
        self,
        batch_shape: torch.Size = torch.Size(),
        event_shape: torch.Size = torch.Size(),
        validate_args: Optional[bool] = None,
    ):
        """
        Stores the batch/event shapes and, when validation is enabled,
        checks every parameter named in `arg_constraints`.

        Args:
            batch_shape (torch.Size): the shape over which parameters are
                batched.
            event_shape (torch.Size): the shape of a single sample.
            validate_args (bool, optional): per-instance override of the
                class-wide validation default; ``None`` keeps the default.

        Raises:
            ValueError: if validation is enabled and a parameter violates
                its constraint.
        """
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            # Shadows the class-level default on this instance only.
            self._validate_args = validate_args
        if self._validate_args:
            try:
                arg_constraints = self.arg_constraints
            except NotImplementedError:
                # Nothing to check, but tell the user how to silence this.
                arg_constraints = {}
                warnings.warn(
                    f"{self.__class__} does not define `arg_constraints`. "
                    + "Please set `arg_constraints = {}` or initialize the distribution "
                    + "with `validate_args=False` to turn off validation."
                )
            for param, constraint in arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(
                    getattr(type(self), param), lazy_property
                ):
                    # Reading the attribute would force its computation.
                    continue  # skip checking lazily-constructed args
                value = getattr(self, param)
                valid = constraint.check(value)
                if not valid.all():
                    raise ValueError(
                        f"Expected parameter {param} "
                        f"({type(value).__name__} of shape {tuple(value.shape)}) "
                        f"of distribution {repr(self)} "
                        f"to satisfy the constraint {repr(constraint)}, "
                        f"but found invalid values:\n{value}"
                    )
        super().__init__()

    def expand(self, batch_shape: _size, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting in
        `__init__`, when an instance is first created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_shape`.
        """
        raise NotImplementedError

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the shape over which parameters are batched.
        """
        return self._batch_shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the shape of a single sample (without batching).
        """
        return self._event_shape

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        """
        Returns a dictionary from argument names to
        :class:`~torch.distributions.constraints.Constraint` objects that
        should be satisfied by each argument of this distribution. Args that
        are not tensors need not appear in this dict.
        """
        raise NotImplementedError

    @property
    def support(self) -> Optional[Any]:
        """
        Returns a :class:`~torch.distributions.constraints.Constraint` object
        representing this distribution's support.
        """
        raise NotImplementedError

    @property
    def mean(self) -> torch.Tensor:
        """
        Returns the mean of the distribution.
        """
        raise NotImplementedError

    @property
    def mode(self) -> torch.Tensor:
        """
        Returns the mode of the distribution.
        """
        raise NotImplementedError(f"{self.__class__} does not implement mode")

    @property
    def variance(self) -> torch.Tensor:
        """
        Returns the variance of the distribution.
        """
        raise NotImplementedError

    @property
    def stddev(self) -> torch.Tensor:
        """
        Returns the standard deviation of the distribution, computed as the
        square root of `variance`.
        """
        return self.variance.sqrt()

    def sample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched.

        The base implementation calls `rsample` with gradient tracking
        disabled.
        """
        with torch.no_grad():
            return self.rsample(sample_shape)

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        """
        raise NotImplementedError

    @deprecated(
        "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.",
        category=FutureWarning,
    )
    def sample_n(self, n: int) -> torch.Tensor:
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.

        Equivalent to ``sample(torch.Size((n,)))``.
        """
        return self.sample(torch.Size((n,)))

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the inverse cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def enumerate_support(self, expand: bool = True) -> torch.Tensor:
        """
        Returns tensor containing all values supported by a discrete
        distribution. The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ...]`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Returns:
            Tensor iterating over dimension 0.
        """
        raise NotImplementedError

    def entropy(self) -> torch.Tensor:
        """
        Returns entropy of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        raise NotImplementedError

    def perplexity(self) -> torch.Tensor:
        """
        Returns perplexity of distribution, batched over batch_shape.

        Computed as the exponential of `entropy()`.

        Returns:
            Tensor of shape batch_shape.
        """
        return torch.exp(self.entropy())

    def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size:
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`. Note, that the batch and event shapes of a distribution
        instance are fixed at the time of construction. The result is the
        concatenation `sample_shape + batch_shape + event_shape`.

        Args:
            sample_shape (torch.Size): the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            # Accept any sequence of ints (list, tuple, ...).
            sample_shape = torch.Size(sample_shape)
        return torch.Size(sample_shape + self._batch_shape + self._event_shape)

    def _validate_sample(self, value: torch.Tensor) -> None:
        """
        Argument validation for distribution methods such as `log_prob`,
        `cdf` and `icdf`. The rightmost dimensions of a value to be
        scored via these methods must agree with the distribution's batch
        and event shapes.

        Args:
            value (Tensor): the tensor whose log probability is to be
                computed by the `log_prob` method.

        Raises:
            ValueError: when `value` is not a Tensor, when the rightmost
                dimensions of `value` do not match the distribution's batch
                and event shapes, or when `value` lies outside the support.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError("The value argument to log_prob must be a Tensor")

        # The trailing dims of `value` must equal event_shape exactly.
        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError(
                f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}."
            )

        # The remaining dims only need to be broadcastable, compared
        # right-to-left.
        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError(
                    f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}."
                )
        try:
            support = self.support
        except NotImplementedError:
            # Without a declared support the membership check is skipped.
            warnings.warn(
                f"{self.__class__} does not define `support` to enable "
                + "sample validation. Please initialize the distribution with "
                + "`validate_args=False` to turn off validation."
            )
            return
        assert support is not None
        valid = support.check(value)
        if not valid.all():
            raise ValueError(
                "Expected value argument "
                f"({type(value).__name__} of shape {tuple(value.shape)}) "
                f"to be within the support ({repr(support)}) "
                f"of the distribution {repr(self)}, "
                f"but found invalid values:\n{value}"
            )

    def _get_checked_instance(self, cls, _instance=None):
        """
        Returns the instance that an `expand` implementation should populate.

        Args:
            cls: the class whose `__init__` is compared against the
                `__init__` of ``type(self)``.
            _instance: an instance supplied by the caller, or ``None``.

        Returns:
            `_instance` when given, otherwise a new instance of
            ``type(self)`` created via `__new__` (so `__init__` is not run).

        Raises:
            NotImplementedError: when `_instance` is ``None`` and
                ``type(self)`` uses a different `__init__` than `cls`.
        """
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError(
                f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method "
                "must also define a custom .expand() method."
            )
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self) -> str:
        """
        Returns ``ClassName(param: value, ...)`` for the constrained
        parameters stored on the instance. A parameter with exactly one
        element is shown by value, any other tensor by its size.
        """
        try:
            constrained = self.arg_constraints
        except NotImplementedError:
            # repr() is used inside validation error messages, so it must not
            # fail for subclasses that do not declare `arg_constraints`.
            constrained = {}
        parts = []
        for name in constrained:
            if name not in self.__dict__:
                continue  # e.g. lazily computed parameters
            param = self.__dict__[name]
            # Plain Python numbers have no `numel`; show them by value.
            if hasattr(param, "numel") and param.numel() != 1:
                shown = param.size()
            else:
                shown = param
            parts.append(f"{name}: {shown}")
        return self.__class__.__name__ + "(" + ", ".join(parts) + ")"
341