xref: /aosp_15_r20/external/pytorch/torch/distributions/constraint_registry.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""
3PyTorch provides two global :class:`ConstraintRegistry` objects that link
4:class:`~torch.distributions.constraints.Constraint` objects to
5:class:`~torch.distributions.transforms.Transform` objects. These objects both
6input constraints and return transforms, but they have different guarantees on
7bijectivity.
8
91. ``biject_to(constraint)`` looks up a bijective
10   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
11   to the given ``constraint``. The returned transform is guaranteed to have
12   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
132. ``transform_to(constraint)`` looks up a not-necessarily bijective
14   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
15   to the given ``constraint``. The returned transform is not guaranteed to
16   implement ``.log_abs_det_jacobian()``.
17
18The ``transform_to()`` registry is useful for performing unconstrained
19optimization on constrained parameters of probability distributions, which are
20indicated by each distribution's ``.arg_constraints`` dict. These transforms often
21overparameterize a space in order to avoid rotation; they are thus more
22suitable for coordinate-wise optimization algorithms like Adam::
23
24    loc = torch.zeros(100, requires_grad=True)
25    unconstrained = torch.zeros(100, requires_grad=True)
26    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
27    loss = -Normal(loc, scale).log_prob(data).sum()
28
29The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
30samples from a probability distribution with constrained ``.support`` are
31propagated in an unconstrained space, and algorithms are typically rotation
invariant::
33
34    dist = Exponential(rate)
35    unconstrained = torch.zeros(100, requires_grad=True)
36    sample = biject_to(dist.support)(unconstrained)
37    potential_energy = -dist.log_prob(sample).sum()
38
39.. note::
40
41    An example where ``transform_to`` and ``biject_to`` differ is
42    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
43    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
44    exponentiates and normalizes its inputs; this is a cheap and mostly
45    coordinate-wise operation appropriate for algorithms like SVI. In
46    contrast, ``biject_to(constraints.simplex)`` returns a
47    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform but is needed for algorithms
    like HMC.
51
52The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
53constraints and transforms using their ``.register()`` method either as a
54function on singleton constraints::
55
56    transform_to.register(my_constraint, my_transform)
57
58or as a decorator on parameterized constraints::
59
60    @transform_to.register(MyConstraintClass)
61    def my_factory(constraint):
62        assert isinstance(constraint, MyConstraintClass)
63        return MyTransform(constraint.param1, constraint.param2)
64
65You can create your own registry by creating a new :class:`ConstraintRegistry`
66object.
67"""
68
69import numbers
70
71from torch.distributions import constraints, transforms
72
73
# Public API of this module: the registry class and the two global registry
# instances created below.
__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]
79
80
class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Factories are stored per constraint class. Calling the registry with a
    constraint object looks up the factory registered for the exact type of
    that object and returns whatever the factory produces for it.
    """

    def __init__(self):
        # Maps a Constraint subclass to the factory callable registered for it.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a factory for a
        :class:`~torch.distributions.constraints.Constraint` subclass in this
        registry. It can be called directly::

            my_registry.register(my_constraint, my_factory)

        or used as a decorator::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class. An instance is
                replaced by its class before being stored.
            factory (Callable): A callable that inputs a constraint object and returns
                a  :class:`~torch.distributions.transforms.Transform` object.
                When omitted, a decorator is returned that completes the
                registration once it is applied to the factory.

        Returns:
            The ``factory`` that was registered, or a one-argument decorator
            when ``factory`` is ``None``.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass
                nor a ``Constraint`` instance.
        """
        # Decorator form: defer the real registration until the factory is known.
        if factory is None:

            def decorator(fn):
                return self.register(constraint, fn)

            return decorator

        # Singleton instances are registered under their class.
        key = (
            type(constraint)
            if isinstance(constraint, constraints.Constraint)
            else constraint
        )

        is_constraint_class = isinstance(key, type) and issubclass(
            key, constraints.Constraint
        )
        if not is_constraint_class:
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {key}"
            )

        # A later registration for the same class replaces the earlier one.
        self._registry[key] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            `NotImplementedError` if no transform has been registered.
        """
        # The lookup key is the exact class of the constraint object; factories
        # registered for parent classes are not consulted.
        kind = type(constraint)
        try:
            factory = self._registry[kind]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {kind.__name__} constraints"
            ) from None
        else:
            return factory(constraint)
152
153
# Global registry whose transforms are documented (module docstring) as
# bijective from ``constraints.real`` to the given constraint.
biject_to = ConstraintRegistry()
# Global registry whose transforms are documented (module docstring) as
# not necessarily bijective.
transform_to = ConstraintRegistry()
156
157
158################################################################################
159# Registration Table
160################################################################################
161
162
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """
    Factory for ``constraints.real`` in both registries.

    Returns the module-level ``transforms.identity_transform`` object; the
    ``constraint`` argument is not inspected.
    """
    return transforms.identity_transform
167
168
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """
    Factory for ``constraints.independent`` in the ``biject_to`` registry.

    Looks up the ``biject_to`` transform of ``constraint.base_constraint`` and
    wraps it in an ``IndependentTransform`` carrying the constraint's
    ``reinterpreted_batch_ndims``.
    """
    return transforms.IndependentTransform(
        biject_to(constraint.base_constraint),
        constraint.reinterpreted_batch_ndims,
    )
175
176
@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """
    Factory for ``constraints.independent`` in the ``transform_to`` registry.

    Looks up the ``transform_to`` transform of ``constraint.base_constraint``
    and wraps it in an ``IndependentTransform`` carrying the constraint's
    ``reinterpreted_batch_ndims``.
    """
    return transforms.IndependentTransform(
        transform_to(constraint.base_constraint),
        constraint.reinterpreted_batch_ndims,
    )
183
184
@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """
    Factory for ``constraints.positive`` and ``constraints.nonnegative`` in
    both registries.

    Returns a fresh ``ExpTransform``; the ``constraint`` argument is not
    inspected.
    """
    # NOTE(review): registration is keyed on the constraint's class, so if
    # ``positive``/``nonnegative`` share a class with a constraint registered
    # later in this file, that later factory replaces this one -- confirm
    # against torch.distributions.constraints.
    return transforms.ExpTransform()
191
192
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """
    Factory for ``constraints.greater_than`` and
    ``constraints.greater_than_eq`` in both registries.

    Composes an ``ExpTransform`` with an ``AffineTransform`` whose location is
    ``constraint.lower_bound`` and whose scale is ``1``, in that order.
    """
    exponentiate = transforms.ExpTransform()
    shift_up = transforms.AffineTransform(constraint.lower_bound, 1)
    return transforms.ComposeTransform([exponentiate, shift_up])
204
205
@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """
    Factory for ``constraints.less_than`` in both registries.

    Composes an ``ExpTransform`` with an ``AffineTransform`` whose location is
    ``constraint.upper_bound`` and whose scale is ``-1``, in that order.
    """
    exponentiate = transforms.ExpTransform()
    # The negative scale flips the exponential before the shift is applied.
    flip_and_shift = transforms.AffineTransform(constraint.upper_bound, -1)
    return transforms.ComposeTransform([exponentiate, flip_and_shift])
215
216
@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """
    Factory for ``constraints.interval`` and ``constraints.half_open_interval``
    in both registries.

    Returns a bare ``SigmoidTransform`` when both bounds are plain Python
    numbers equal to ``0`` and ``1``. Otherwise returns a ``SigmoidTransform``
    composed with an ``AffineTransform`` whose location is the lower bound and
    whose scale is the width ``upper_bound - lower_bound``.
    """
    lower = constraint.lower_bound
    upper = constraint.upper_bound

    # The unit-interval shortcut only applies to numeric bounds; the
    # ``numbers.Number`` check keeps non-number bounds out of the ``==`` test.
    is_unit_interval = (
        isinstance(lower, numbers.Number)
        and lower == 0
        and isinstance(upper, numbers.Number)
        and upper == 1
    )
    squash = transforms.SigmoidTransform()
    if is_unit_interval:
        return squash

    rescale = transforms.AffineTransform(lower, upper - lower)
    return transforms.ComposeTransform([squash, rescale])
239
240
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """
    Factory for ``constraints.simplex`` in the ``biject_to`` registry.

    Returns a fresh ``StickBreakingTransform``; the ``constraint`` argument is
    not inspected.
    """
    return transforms.StickBreakingTransform()
244
245
@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """
    Factory for ``constraints.simplex`` in the ``transform_to`` registry.

    Returns a fresh ``SoftmaxTransform``; the ``constraint`` argument is not
    inspected.
    """
    return transforms.SoftmaxTransform()
249
250
251# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """
    Factory for ``constraints.lower_cholesky`` in the ``transform_to``
    registry only.

    Returns a fresh ``LowerCholeskyTransform``; the ``constraint`` argument is
    not inspected.
    """
    return transforms.LowerCholeskyTransform()
255
256
@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """
    Factory for ``constraints.positive_definite`` and
    ``constraints.positive_semidefinite`` in the ``transform_to`` registry
    only.

    Returns a fresh ``PositiveDefiniteTransform``; the ``constraint`` argument
    is not inspected.
    """
    return transforms.PositiveDefiniteTransform()
261
262
@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """
    Factory for ``constraints.corr_cholesky`` in both registries.

    Returns a fresh ``CorrCholeskyTransform``; the ``constraint`` argument is
    not inspected.
    """
    return transforms.CorrCholeskyTransform()
267
268
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """
    Factory for ``constraints.cat`` in the ``biject_to`` registry.

    Looks up the ``biject_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``CatTransform`` using the
    constraint's ``dim`` and ``lengths``.
    """
    component_transforms = [biject_to(part) for part in constraint.cseq]
    return transforms.CatTransform(
        component_transforms, constraint.dim, constraint.lengths
    )
274
275
@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """
    Factory for ``constraints.cat`` in the ``transform_to`` registry.

    Looks up the ``transform_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``CatTransform`` using the
    constraint's ``dim`` and ``lengths``.
    """
    component_transforms = [transform_to(part) for part in constraint.cseq]
    return transforms.CatTransform(
        component_transforms, constraint.dim, constraint.lengths
    )
281
282
@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """
    Factory for ``constraints.stack`` in the ``biject_to`` registry.

    Looks up the ``biject_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``StackTransform`` along the
    constraint's ``dim``.
    """
    component_transforms = [biject_to(part) for part in constraint.cseq]
    return transforms.StackTransform(component_transforms, constraint.dim)
288
289
@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """
    Factory for ``constraints.stack`` in the ``transform_to`` registry.

    Looks up the ``transform_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``StackTransform`` along the
    constraint's ``dim``.
    """
    component_transforms = [transform_to(part) for part in constraint.cseq]
    return transforms.StackTransform(component_transforms, constraint.dim)
295