# mypy: allow-untyped-defs
r"""
The following constraints are implemented:

- ``constraints.boolean``
- ``constraints.cat(cseq, dim, lengths)``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.multinomial(upper_bound)``
- ``constraints.nonnegative``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.square``
- ``constraints.stack(cseq, dim)``
- ``constraints.symmetric``
- ``constraints.unit_interval``
"""

import torch


__all__ = [
    "Constraint",
    "boolean",
    "cat",
    "corr_cholesky",
    "dependent",
    "dependent_property",
    "greater_than",
    "greater_than_eq",
    "independent",
    "integer_interval",
    "interval",
    "half_open_interval",
    "is_dependent",
    "less_than",
    "lower_cholesky",
    "lower_triangular",
    "multinomial",
    "nonnegative",
    "nonnegative_integer",
    "one_hot",
    "positive",
    "positive_semidefinite",
    "positive_definite",
    "positive_integer",
    "real",
    "real_vector",
    "simplex",
    "square",
    "stack",
    "symmetric",
    "unit_interval",
]


def _public_class_name(obj):
    """
    Return the class name of ``obj`` with a single leading underscore removed.

    The concrete constraint classes in this module are private (``_Real``,
    ``_Interval``, ...) and are displayed without the underscore. Names that
    do not start with an underscore (e.g. ``Constraint`` itself or a
    user-defined subclass) are returned unchanged instead of losing their
    first letter.
    """
    name = obj.__class__.__name__
    return name[1:] if name.startswith("_") else name


class Constraint:
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.

    Attributes:
        is_discrete (bool): Whether constrained space is discrete.
            Defaults to False.
        event_dim (int): Number of rightmost dimensions that together define
            an event. The :meth:`check` method will remove this many dimensions
            when computing validity.
    """

    is_discrete = False  # Default to continuous.
    event_dim = 0  # Default to univariate.

    def check(self, value):
        """
        Returns a boolean tensor of ``sample_shape + batch_shape`` indicating
        whether each event in value satisfies this constraint.

        Subclasses must override this method; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def __repr__(self):
        return _public_class_name(self) + "()"


class _Dependent(Constraint):
    """
    Placeholder for variables whose support depends on other variables.
    These variables obey no simple coordinate-wise constraints.

    Args:
        is_discrete (bool): Optional value of ``.is_discrete`` in case this
            can be computed statically. If not provided, access to the
            ``.is_discrete`` attribute will raise a NotImplementedError.
        event_dim (int): Optional value of ``.event_dim`` in case this
            can be computed statically. If not provided, access to the
            ``.event_dim`` attribute will raise a NotImplementedError.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        # ``NotImplemented`` is the "not statically known" sentinel.
        self._is_discrete = is_discrete
        self._event_dim = event_dim
        super().__init__()

    @property
    def is_discrete(self):
        if self._is_discrete is NotImplemented:
            raise NotImplementedError(".is_discrete cannot be determined statically")
        return self._is_discrete

    @property
    def event_dim(self):
        if self._event_dim is NotImplemented:
            raise NotImplementedError(".event_dim cannot be determined statically")
        return self._event_dim

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Support for syntax to customize static attributes::

            constraints.dependent(is_discrete=True, event_dim=1)

        Attributes that are not passed are inherited from ``self``; a new
        ``_Dependent`` instance is returned and ``self`` is left untouched.
        """
        if is_discrete is NotImplemented:
            is_discrete = self._is_discrete
        if event_dim is NotImplemented:
            event_dim = self._event_dim
        return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

    def check(self, x):
        """Always raises ``ValueError``: validity cannot be decided here."""
        raise ValueError("Cannot determine validity of dependent constraint")


def is_dependent(constraint):
    """
    Checks if ``constraint`` is a ``_Dependent`` object.

    Args:
        constraint : A ``Constraint`` object.

    Returns:
        ``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise.

    Examples:
        >>> import torch
        >>> from torch.distributions import Bernoulli
        >>> from torch.distributions.constraints import is_dependent

        >>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))
        >>> constraint1 = dist.arg_constraints["probs"]
        >>> constraint2 = dist.arg_constraints["logits"]

        >>> for constraint in [constraint1, constraint2]:
        ...     if is_dependent(constraint):
        ...         continue
    """
    return isinstance(constraint, _Dependent)


class _DependentProperty(property, _Dependent):
    """
    Decorator that extends @property to act like a `Dependent` constraint when
    called on a class and act like a property when called on an object.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high
            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The function to be decorated.
        is_discrete (bool): Optional value of ``.is_discrete`` in case this
            can be computed statically. If not provided, access to the
            ``.is_discrete`` attribute will raise a NotImplementedError.
        event_dim (int): Optional value of ``.event_dim`` in case this
            can be computed statically. If not provided, access to the
            ``.event_dim`` attribute will raise a NotImplementedError.
    """

    def __init__(
        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
    ):
        # ``super()`` resolves to ``property`` here, so ``fn`` becomes the getter.
        super().__init__(fn)
        self._is_discrete = is_discrete
        self._event_dim = event_dim

    def __call__(self, fn):
        """
        Support for syntax to customize static attributes::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self):
                ...
        """
        return _DependentProperty(
            fn, is_discrete=self._is_discrete, event_dim=self._event_dim
        )


class _IndependentConstraint(Constraint):
    """
    Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
    dims in :meth:`check`, so that an event is valid only if all its
    independent entries are valid.

    Args:
        base_constraint (Constraint): The constraint being wrapped.
        reinterpreted_batch_ndims (int): Non-negative number of rightmost
            dims of the base check result to reduce over.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def is_discrete(self):
        return self.base_constraint.is_discrete

    @property
    def event_dim(self):
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def check(self, value):
        """
        Check ``value`` with the base constraint and reduce the rightmost
        ``reinterpreted_batch_ndims`` dims of the result with a logical "all".

        Raises:
            ValueError: If the base check result has fewer dims than
                ``reinterpreted_batch_ndims``.
        """
        result = self.base_constraint.check(value)
        if result.dim() < self.reinterpreted_batch_ndims:
            expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
            raise ValueError(
                f"Expected value.dim() >= {expected} but got {value.dim()}"
            )
        split = result.dim() - self.reinterpreted_batch_ndims
        # Compute the flattened event size explicitly rather than asking
        # ``reshape`` to infer it with -1: inference is ambiguous (and raises)
        # when a leading batch dim has size zero.
        event_numel = 1
        for size in result.shape[split:]:
            event_numel = event_numel * size
        result = result.reshape(result.shape[:split] + (event_numel,))
        return result.all(-1)

    def __repr__(self):
        return (
            f"{_public_class_name(self)}"
            f"({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
        )


class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    is_discrete = True

    def check(self, value):
        return (value == 0) | (value == 1)


class _OneHot(Constraint):
    """
    Constrain to one-hot vectors.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        # Every entry must be 0 or 1, and the last dim must sum to exactly 1.
        is_boolean = (value == 0) | (value == 1)
        is_normalized = value.sum(-1).eq(1)
        return is_boolean.all(-1) & is_normalized


class _IntegerInterval(Constraint):
    """
    Constrain to an integer interval `[lower_bound, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        # ``value % 1 == 0`` is the integrality test.
        return (
            (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
        )

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _IntegerLessThan(Constraint):
    """
    Constrain to an integer interval `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (value % 1 == 0) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string


class _IntegerGreaterThan(Constraint):
    """
    Constrain to an integer interval `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return (value % 1 == 0) & (value >= self.lower_bound)

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    """

    def check(self, value):
        return value == value  # False for NANs.


class _GreaterThan(Constraint):
    """
    Constrain to a real half line `(lower_bound, inf]`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound < value

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _GreaterThanEq(Constraint):
    """
    Constrain to a real half line `[lower_bound, inf)`.
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound <= value

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += f"(lower_bound={self.lower_bound})"
        return fmt_string


class _LessThan(Constraint):
    """
    Constrain to a real half line `[-inf, upper_bound)`.
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return value < self.upper_bound

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string


class _Interval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (self.lower_bound <= value) & (value <= self.upper_bound)

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _HalfOpenInterval(Constraint):
    """
    Constrain to a real interval `[lower_bound, upper_bound)`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return (self.lower_bound <= value) & (value < self.upper_bound)

    def __repr__(self):
        fmt_string = _public_class_name(self)
        fmt_string += (
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
        return fmt_string


class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension.
    Specifically: `x >= 0` and `x.sum(-1) == 1`.
    """

    event_dim = 1

    def check(self, value):
        # The sum-to-one test is approximate: absolute tolerance of 1e-6.
        return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)


class _Multinomial(Constraint):
    """
    Constrain to nonnegative integer values summing to at most an upper bound.

    Note due to limitations of the Multinomial distribution, this currently
    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
    this may be strengthened to ``value.sum(-1) == upper_bound``.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, x):
        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)


class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.
    """

    event_dim = 2

    def check(self, value):
        # A matrix is lower triangular iff zeroing its strict upper triangle
        # leaves it unchanged. Reducing with ``all`` over both matrix dims
        # (rather than flattening with ``view`` and taking ``min``) also
        # handles zero-size matrices and does not depend on memory layout.
        return (value.tril() == value).all(-2).all(-1)


class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.
    """

    event_dim = 2

    def check(self, value):
        lower_triangular = (value.tril() == value).all(-2).all(-1)
        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).all(-1)
        return lower_triangular & positive_diagonal


class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals and each
    row vector being of unit length.
    """

    event_dim = 2

    def check(self, value):
        # Tolerance scales with the dtype's machine epsilon and the row length.
        tol = (
            torch.finfo(value.dtype).eps * value.size(-1) * 10
        )  # 10 is an adjustable fudge factor
        row_norm = torch.linalg.norm(value.detach(), dim=-1)
        unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1)
        return _LowerCholesky().check(value) & unit_row_norm


class _Square(Constraint):
    """
    Constrain to square matrices.
    """

    event_dim = 2

    def check(self, value):
        # Squareness depends only on the shape, so the same boolean is
        # broadcast over every batch element.
        return torch.full(
            size=value.shape[:-2],
            fill_value=(value.shape[-2] == value.shape[-1]),
            dtype=torch.bool,
            device=value.device,
        )


class _Symmetric(_Square):
    """
    Constrain to Symmetric square matrices.
    """

    def check(self, value):
        square_check = super().check(value)
        if not square_check.all():
            # Not square: return the (all-False) square check as-is.
            return square_check
        return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)


class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.all():
            # Skip the eigenvalue computation unless every matrix is symmetric.
            return sym_check
        return torch.linalg.eigvalsh(value).ge(0).all(-1)


class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices.
    """

    def check(self, value):
        sym_check = super().check(value)
        if not sym_check.all():
            # Skip the factorization unless every matrix is symmetric.
            return sym_check
        # ``info == 0`` marks the matrices whose Cholesky factorization succeeded.
        return torch.linalg.cholesky_ex(value).info.eq(0)


class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.

    Args:
        cseq: Sequence of ``Constraint`` objects.
        dim (int): Dimension along which ``value`` is split. Defaults to 0.
        lengths: Optional sequence with one slice length per constraint.
            Defaults to length 1 for every constraint.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        """
        Check consecutive slices of ``value`` along ``dim`` against the
        corresponding constraints and concatenate the results along ``dim``.
        """
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)


class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints
    `cseq` at the submatrices at dimension `dim`,
    in a way compatible with :func:`torch.stack`.

    Args:
        cseq: Sequence of ``Constraint`` objects.
        dim (int): Dimension along which ``value`` is unstacked. Defaults to 0.
    """

    def __init__(self, cseq, dim=0):
        assert all(isinstance(c, Constraint) for c in cseq)
        self.cseq = list(cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        # Largest event_dim among the components, bumped by one when
        # ``self.dim + dim`` is negative.
        dim = max(c.event_dim for c in self.cseq)
        if self.dim + dim < 0:
            dim += 1
        return dim

    def check(self, value):
        """
        Check each slice of ``value`` along ``dim`` against the corresponding
        constraint and stack the results along ``dim``.
        """
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        return torch.stack(
            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
        )


# Public interface.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
real_vector = independent(real, 1)
positive = _GreaterThan(0.0)
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
multinomial = _Multinomial
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack