# mypy: allow-untyped-defs
r"""
PyTorch provides two global :class:`ConstraintRegistry` objects that link
:class:`~torch.distributions.constraints.Constraint` objects to
:class:`~torch.distributions.transforms.Transform` objects. These objects both
input constraints and return transforms, but they have different guarantees on
bijectivity.

1. ``biject_to(constraint)`` looks up a bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is guaranteed to have
   ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
   :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
   to the given ``constraint``. The returned transform is not guaranteed to
   implement ``.log_abs_det_jacobian()``.

The ``transform_to()`` registry is useful for performing unconstrained
optimization on constrained parameters of probability distributions, which are
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
overparameterize a space in order to avoid rotation; they are thus more
suitable for coordinate-wise optimization algorithms like Adam::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
samples from a probability distribution with constrained ``.support`` are
propagated in an unconstrained space, and algorithms are typically rotation
invariant::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. note::

    An example where ``transform_to`` and ``biject_to`` differ is
    ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
    exponentiates and normalizes its inputs; this is a cheap and mostly
    coordinate-wise operation appropriate for algorithms like SVI. In
    contrast, ``biject_to(constraints.simplex)`` returns a
    :class:`~torch.distributions.transforms.StickBreakingTransform` that
    bijects its input down to a one-fewer-dimensional space; this is a more
    expensive, less numerically stable transform but is needed for algorithms
    like HMC.

The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
constraints and transforms using their ``.register()`` method either as a
function on singleton constraints::

    transform_to.register(my_constraint, lambda constraint: my_transform)

or as a decorator on parameterized constraints::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

In both forms the registered object is a *factory*: a callable that receives
the constraint being looked up and returns the transform to use for it.

You can create your own registry by creating a new :class:`ConstraintRegistry`
object.
"""

import numbers

from torch.distributions import constraints, transforms


__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]


class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Factories are stored per :class:`~torch.distributions.constraints.Constraint`
    subclass. Calling the registry with a constraint object looks up the factory
    registered for the exact type of that object and returns ``factory(constraint)``.
    """

    def __init__(self):
        # Maps a Constraint subclass to the factory registered for it.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Registering the same constraint class again replaces the previously
        registered factory.

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.

        Returns:
            ``factory`` itself when it is given, so stacked decorators all see the
            undecorated function. When ``factory`` is omitted, a one-argument
            decorator that performs the registration.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass nor
                a ``Constraint`` instance.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances: the registry is keyed by class.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        The lookup uses the exact type of ``constraint``; factories registered
        for a base class are not consulted for instances of its subclasses.

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            `NotImplementedError` if no transform has been registered.
        """
        # Look up by Constraint subclass.
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback.
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)


biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()


################################################################################
# Registration Table
################################################################################


@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Return the shared identity transform; ``constraint`` is not inspected."""
    return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Wrap the bijection of the base constraint in an ``IndependentTransform``."""
    base_transform = biject_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """Wrap the transform of the base constraint in an ``IndependentTransform``."""
    base_transform = transform_to(constraint.base_constraint)
    return transforms.IndependentTransform(
        base_transform, constraint.reinterpreted_batch_ndims
    )


@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """Return a fresh ``ExpTransform``; ``constraint`` is not inspected."""
    return transforms.ExpTransform()


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Compose ``exp`` with a shift: ``x -> lower_bound + exp(x)``."""
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.lower_bound, 1),
        ]
    )


@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """Compose ``exp`` with a reflected shift: ``x -> upper_bound - exp(x)``."""
    return transforms.ComposeTransform(
        [
            transforms.ExpTransform(),
            transforms.AffineTransform(constraint.upper_bound, -1),
        ]
    )


@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """
    Return a sigmoid rescaled onto ``[lower_bound, upper_bound]``.

    When both bounds are plain numbers equal to 0 and 1 the affine step is
    skipped and a bare ``SigmoidTransform`` is returned.
    """
    # Handle the special case of the unit interval. Bounds that are not
    # ``numbers.Number`` instances (e.g. tensors) always take the general path.
    lower_is_0 = (
        isinstance(constraint.lower_bound, numbers.Number)
        and constraint.lower_bound == 0
    )
    upper_is_1 = (
        isinstance(constraint.upper_bound, numbers.Number)
        and constraint.upper_bound == 1
    )
    if lower_is_0 and upper_is_1:
        return transforms.SigmoidTransform()

    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)]
    )


@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """Return a fresh ``StickBreakingTransform``; ``constraint`` is not inspected."""
    return transforms.StickBreakingTransform()


@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Return a fresh ``SoftmaxTransform``; ``constraint`` is not inspected."""
    return transforms.SoftmaxTransform()


# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """Return a fresh ``LowerCholeskyTransform`` (``transform_to`` only)."""
    return transforms.LowerCholeskyTransform()


@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """Return a fresh ``PositiveDefiniteTransform`` (``transform_to`` only)."""
    return transforms.PositiveDefiniteTransform()


@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Return a fresh ``CorrCholeskyTransform``; ``constraint`` is not inspected."""
    return transforms.CorrCholeskyTransform()
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """Build a ``CatTransform`` from the bijection of each sub-constraint."""
    parts = [biject_to(sub) for sub in constraint.cseq]
    return transforms.CatTransform(parts, constraint.dim, constraint.lengths)


@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """Build a ``CatTransform`` from the transform of each sub-constraint."""
    parts = [transform_to(sub) for sub in constraint.cseq]
    return transforms.CatTransform(parts, constraint.dim, constraint.lengths)


@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """Build a ``StackTransform`` from the bijection of each sub-constraint."""
    parts = [biject_to(sub) for sub in constraint.cseq]
    return transforms.StackTransform(parts, constraint.dim)


@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """Build a ``StackTransform`` from the transform of each sub-constraint."""
    parts = [transform_to(sub) for sub in constraint.cseq]
    return transforms.StackTransform(parts, constraint.dim)