# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probability distributions."""

import functools
import hashlib

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect


def assert_integer_form(x,
                        data=None,
                        summarize=None,
                        message=None,
                        int_dtype=None,
                        name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`.
    data: The tensors to print out if the condition is `False`. Defaults to the
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    int_dtype: A `tf.dtype` to cast `x` to. The default (`None`) selects the
      smallest signed integer dtype matching the width of `x.dtype`.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x)
    if int_dtype is None:
      try:
        int_dtype = {
            dtypes.float16: dtypes.int16,
            dtypes.float32: dtypes.int32,
            dtypes.float64: dtypes.int64,
        }[x.dtype.base_dtype]
      except KeyError:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    return check_ops.assert_equal(
        x,
        math_ops.cast(math_ops.cast(x, int_dtype), x.dtype),
        data=data,
        summarize=summarize,
        message=message,
        name=name)


def assert_symmetric(matrix):
  matrix_t = array_ops.matrix_transpose(matrix)
  return control_flow_ops.with_dependencies(
      [check_ops.assert_equal(matrix, matrix_t)], matrix)


def embed_check_nonnegative_integer_form(
    x, name="embed_check_nonnegative_integer_form"):
  """Assert x is a non-negative tensor with integer-valued entries."""
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    assertions = [
        check_ops.assert_non_negative(
            x, message="'{}' must be non-negative.".format(x)),
    ]
    if not x.dtype.is_integer:
      assertions += [
          assert_integer_form(
              x,
              message="'{}' cannot contain fractional components.".format(x)),
      ]
    return control_flow_ops.with_dependencies(assertions, x)


def same_dynamic_shape(a, b):
  """Returns whether a and b have the same dynamic shape.

  Args:
    a: `Tensor`
    b: `Tensor`

  Returns:
    `bool` `Tensor` representing if both tensors have the same shape.
  """
  a = ops.convert_to_tensor(a, name="a")
  b = ops.convert_to_tensor(b, name="b")

  # Here we can't just do math_ops.equal(a.shape, b.shape), since
  # static shape inference may break the equality comparison between
  # shape(a) and shape(b) in math_ops.equal.
  def all_shapes_equal():
    return math_ops.reduce_all(
        math_ops.equal(
            array_ops.concat(
                [array_ops.shape(a), array_ops.shape(b)], 0),
            array_ops.concat(
                [array_ops.shape(b), array_ops.shape(a)], 0)))

  # One of the shapes isn't fully defined, so we need to use the dynamic
  # shape.
  return control_flow_ops.cond(
      math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
      all_shapes_equal, lambda: constant_op.constant(False))


def maybe_get_static_value(x, dtype=None):
  """Helper which tries to return a static value.

  Given `x`, extract its value statically, optionally casting to a specific
  dtype. If this is not possible, None is returned.

  Args:
    x: `Tensor` for which to extract a value statically.
    dtype: Optional dtype to cast to.

  Returns:
    Statically inferred value if possible, otherwise None.
  """
  if x is None:
    return x
  try:
    # This returns an np.ndarray.
    x_ = tensor_util.constant_value(x)
  except TypeError:
    x_ = x
  if x_ is None or dtype is None:
    return x_
  return np.array(x_, dtype)


def get_logits_and_probs(logits=None,
                         probs=None,
                         multidimensional=False,
                         validate_args=False,
                         name="get_logits_and_probs",
                         dtype=None):
  """Converts logits to probabilities (or vice versa), and returns both.

  Args:
    logits: Floating-point `Tensor` representing log-odds.
    probs: Floating-point `Tensor` representing probabilities.
    multidimensional: Python `bool`, default `False`. If `True`, `logits` and
      `probs` are treated as `[N1, N2, ..., k]`-shaped tensors whose last
      dimension indexes the logit or probability of each of `k` classes.
    validate_args: Python `bool`, default `False`. When `True`, either assert
      `0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
      of `probs` sums to one.
    name: A name for this operation (optional).
    dtype: `tf.DType` to prefer when converting args to `Tensor`s.

  Returns:
    logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
      `1`, then the corresponding entry in the returned logits will be `-Inf`
      or `Inf` respectively.

  Raises:
    ValueError: if neither `probs` nor `logits` were passed in, or both were.
  """
  with ops.name_scope(name, values=[probs, logits]):
    if (probs is None) == (logits is None):
      raise ValueError("Must pass probs or logits, but not both.")

    if probs is None:
      logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
      if not logits.dtype.is_floating:
        raise TypeError("logits must have floating type.")
      # We can return early since we constructed probs and therefore know
      # they're valid.
      if multidimensional:
        if validate_args:
          logits = embed_check_categorical_event_shape(logits)
        return logits, nn.softmax(logits, name="probs")
      return logits, math_ops.sigmoid(logits, name="probs")

    probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
    if not probs.dtype.is_floating:
      raise TypeError("probs must have floating type.")

    if validate_args:
      with ops.name_scope("validate_probs"):
        one = constant_op.constant(1., probs.dtype)
        dependencies = [check_ops.assert_non_negative(probs)]
        if multidimensional:
          probs = embed_check_categorical_event_shape(probs)
          dependencies += [
              check_ops.assert_near(
                  math_ops.reduce_sum(probs, -1),
                  one,
                  message="probs does not sum to 1.")
          ]
        else:
          dependencies += [
              check_ops.assert_less_equal(
                  probs, one, message="probs has components greater than 1.")
          ]
        probs = control_flow_ops.with_dependencies(dependencies, probs)

    with ops.name_scope("logits"):
      if multidimensional:
        # Here we don't compute the multidimensional case in a manner
        # consistent with the unidimensional case; instead we follow the TF
        # convention. Typically, you might expect to see
        # logits = log(probs) - log(probs[pivot]). A side-effect of following
        # the TF approach is that the unidimensional case implicitly handles
        # the second dimension but the multidimensional case explicitly keeps
        # the pivot dimension.
        return math_ops.log(probs), probs
      return math_ops.log(probs) - math_ops.log1p(-1. * probs), probs
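
# Illustrative sketch (not part of the original code): in the unidimensional
# case `get_logits_and_probs` simply pairs a tensor with its sigmoid (or, given
# probabilities, with their logit). Values below are rounded and assume eager
# execution:
#
#   logits, probs = get_logits_and_probs(logits=[0., 2.])
#   # probs  ==> [0.5, ~0.881]
#   logits, probs = get_logits_and_probs(probs=[0.5, 0.881])
#   # logits ==> [0., ~2.]   (i.e., log(p) - log1p(-p))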


def _is_known_unsigned_by_dtype(dt):
  """Helper returning True if dtype is known to be unsigned."""
  return {
      dtypes.bool: True,
      dtypes.uint8: True,
      dtypes.uint16: True,
  }.get(dt.base_dtype, False)


def _is_known_signed_by_dtype(dt):
  """Helper returning True if dtype is known to be signed."""
  return {
      dtypes.float16: True,
      dtypes.float32: True,
      dtypes.float64: True,
      dtypes.int8: True,
      dtypes.int16: True,
      dtypes.int32: True,
      dtypes.int64: True,
  }.get(dt.base_dtype, False)


def _is_known_dtype(dt):
  """Helper returning True if dtype is known."""
  return _is_known_unsigned_by_dtype(dt) or _is_known_signed_by_dtype(dt)


def _largest_integer_by_dtype(dt):
  """Helper returning the largest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if dt.is_floating:
    return int(2**(np.finfo(dt.as_numpy_dtype).nmant + 1))
  if dt.is_integer:
    return np.iinfo(dt.as_numpy_dtype).max
  if dt.base_dtype == dtypes.bool:
    return int(1)
  # We actually can't land here but keep the case for completeness.
  raise TypeError("Unrecognized dtype: {}".format(dt.name))


def _smallest_integer_by_dtype(dt):
  """Helper returning the smallest integer exactly representable by dtype."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  if _is_known_unsigned_by_dtype(dt):
    return 0
  return -1 * _largest_integer_by_dtype(dt)


def _is_integer_like_by_dtype(dt):
  """Helper returning True if dtype.is_integer or is `bool`."""
  if not _is_known_dtype(dt):
    raise TypeError("Unrecognized dtype: {}".format(dt.name))
  return dt.is_integer or dt.base_dtype == dtypes.bool
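
# Illustrative values for the helpers above (explanatory comment only; assumes
# standard IEEE754 dtypes):
#   _largest_integer_by_dtype(dtypes.float16)  # ==> 2**11 == 2048
#   _largest_integer_by_dtype(dtypes.float32)  # ==> 2**24 == 16777216
#   _largest_integer_by_dtype(dtypes.float64)  # ==> 2**53
#   _largest_integer_by_dtype(dtypes.int32)    # ==> 2**31 - 1
#   _smallest_integer_by_dtype(dtypes.uint8)   # ==> 0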


def embed_check_categorical_event_shape(
    categorical_param, name="embed_check_categorical_event_shape"):
  """Embeds checks that categorical distributions don't have too many classes.

  A categorical-type distribution is one which, e.g., returns the class label
  rather than a one-hot encoding. E.g., `Categorical(probs)`.

  Since distributions output samples in the same dtype as the parameters, we
  must ensure that casting doesn't lose precision. That is, the
  `parameter.dtype` implies a maximum number of classes. However, since shape
  is `int32` and categorical variables are presumed to be indexes into a
  `Tensor`, we must also ensure that the number of classes is no larger than
  the largest possible `int32` index, i.e., `2**31 - 1`.

  In other words, the number of classes, `K`, must satisfy the following
  condition:

  ```python
  K <= min(
      int(2**31 - 1),  # Largest possible int32 index.
      {
          dtypes.float16: int(2**11),  # Largest int as a float16.
          dtypes.float32: int(2**24),
          dtypes.float64: int(2**53),
      }.get(categorical_param.dtype.base_dtype, 0))
  ```

  Args:
    categorical_param: Floating-point `Tensor` representing parameters of
      distribution over categories. The rightmost shape is presumed to be the
      number of categories.
    name: A name for this operation (optional).

  Returns:
    categorical_param: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `categorical_param` has an unknown `dtype`.
    ValueError: if we can statically identify `categorical_param` as being too
      large (for being closed under int32/float casting).
  """
  with ops.name_scope(name, values=[categorical_param]):
    x = ops.convert_to_tensor(categorical_param, name="categorical_param")
    # The number of classes must not exceed either of:
    # - The largest possible int32 (since categorical values are presumed to
    #   be indexes into a Tensor).
    # - The largest possible integer exactly representable under the given
    #   floating-point dtype (since we need to cast to/from).
    #
    # The chosen floating-point thresholds are 2**(1 + mantissa_bits).
    # For more details, see:
    # https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
    x_dtype = x.dtype.base_dtype
    max_event_size = (
        _largest_integer_by_dtype(x_dtype) if x_dtype.is_floating else 0)
    if max_event_size == 0:
      raise TypeError("Unable to validate size of unrecognized dtype "
                      "({}).".format(x_dtype.name))
    try:
      x_shape_static = x.get_shape().with_rank_at_least(1)
    except ValueError:
      raise ValueError("A categorical-distribution parameter must have "
                       "at least 1 dimension.")
    if tensor_shape.dimension_value(x_shape_static[-1]) is not None:
      event_size = x_shape_static.dims[-1].value
      if event_size < 2:
        raise ValueError("A categorical-distribution parameter must have at "
                         "least 2 events.")
      if event_size > max_event_size:
        raise ValueError("Number of classes exceeds `dtype` precision, i.e., "
                         "{} implies shape ({}) cannot exceed {}.".format(
                             x_dtype.name, event_size, max_event_size))
      return x
    else:
      event_size = array_ops.shape(x, name="x_shape")[-1]
      return control_flow_ops.with_dependencies([
          check_ops.assert_rank_at_least(
              x,
              1,
              message=("A categorical-distribution parameter must have "
                       "at least 1 dimension.")),
          check_ops.assert_greater_equal(
              array_ops.shape(x)[-1],
              2,
              message=("A categorical-distribution parameter must have at "
                       "least 2 events.")),
          check_ops.assert_less_equal(
              event_size,
              max_event_size,
              message="Number of classes exceeds `dtype` precision, "
              "i.e., {} dtype implies the number of classes cannot "
              "exceed {}.".format(x_dtype.name, max_event_size)),
      ], x)


def embed_check_integer_casting_closed(x,
                                       target_dtype,
                                       assert_nonnegative=True,
                                       name="embed_check_casting_closed"):
  """Ensures integers remain unaffected despite casting to/from int/float types.

  Example integer-types: `uint8`, `int32`, `bool`.
  Example floating-types: `float32`, `float64`.

  The largest possible integer representable by an IEEE754 floating-point is
  `2**(1 + mantissa_bits)` yet the largest possible integer as an int-type is
  `2**(bits - 1) - 1`. This function ensures that a `Tensor` purporting to have
  integer-form values can be cast to some other type without loss of precision.

  The smallest representable integer is the negative of the largest
  representable integer, except for types: `uint8`, `uint16`, `bool`. For these
  types, the smallest representable integer is `0`.

  Args:
    x: `Tensor` representing integer-form values.
    target_dtype: TF `dtype` under which `x` should have identical values.
    assert_nonnegative: `bool` indicating `x` should contain nonnegative
      values.
    name: A name for this operation (optional).

  Returns:
    x: Input `Tensor` with appropriate assertions embedded.

  Raises:
    TypeError: if `x` is neither integer- nor floating-type.
    TypeError: if `target_dtype` is neither integer- nor floating-type.
    TypeError: if neither `x` nor `target_dtype` are integer-type.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if (not _is_integer_like_by_dtype(x.dtype) and not x.dtype.is_floating):
      raise TypeError("{}.dtype must be floating- or "
                      "integer-type.".format(x.dtype.name))
    if (not _is_integer_like_by_dtype(target_dtype) and
        not target_dtype.is_floating):
      raise TypeError("target_dtype ({}) must be floating- or "
                      "integer-type.".format(target_dtype.name))
    if (not _is_integer_like_by_dtype(x.dtype) and
        not _is_integer_like_by_dtype(target_dtype)):
      raise TypeError("At least one of {}.dtype ({}) and target_dtype ({}) "
                      "must be integer-type.".format(x, x.dtype.name,
                                                     target_dtype.name))

    assertions = []
    if assert_nonnegative:
      assertions += [
          check_ops.assert_non_negative(
              x, message="Elements must be non-negative."),
      ]

    if x.dtype.is_floating:
      # Being here means _is_integer_like_by_dtype(target_dtype) = True.
      # Since this check implies the magnitude check below, we need only it.
      assertions += [
          assert_integer_form(
              x,
              int_dtype=target_dtype,
              message="Elements must be {}-equivalent.".format(
                  target_dtype.name)),
      ]
    else:
      if (_largest_integer_by_dtype(x.dtype) >
          _largest_integer_by_dtype(target_dtype)):
        # Cast may lose integer precision.
        assertions += [
            check_ops.assert_less_equal(
                x,
                _largest_integer_by_dtype(target_dtype),
                message=("Elements cannot exceed {}.".format(
                    _largest_integer_by_dtype(target_dtype)))),
        ]
      if (not assert_nonnegative and (_smallest_integer_by_dtype(
          x.dtype) < _smallest_integer_by_dtype(target_dtype))):
        assertions += [
            check_ops.assert_greater_equal(
                x,
                _smallest_integer_by_dtype(target_dtype),
                message=("Elements cannot be smaller than {}.".format(
                    _smallest_integer_by_dtype(target_dtype)))),
        ]

    if not assertions:
      return x
    return control_flow_ops.with_dependencies(assertions, x)


def log_combinations(n, counts, name="log_combinations"):
  """Log multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, we compute
  the multinomial coefficient as:

  ```n! / (prod_i n_i!)```

  where `i` runs over all `k` classes.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    `Tensor` representing the log of the multinomial coefficient between `n`
    and `counts`.
  """
  # First a bit about the number of ways counts could have come in:
  # E.g. if counts = [1, 2], then this is 3 choose 2.
  # In general, this is (sum counts)! / prod(counts!)
  # The sum should be along the last dimension of counts. This is the
  # "distribution" dimension. Here n a priori represents the sum of counts.
  with ops.name_scope(name, values=[n, counts]):
    n = ops.convert_to_tensor(n, name="n")
    counts = ops.convert_to_tensor(counts, name="counts")
    total_permutations = math_ops.lgamma(n + 1)
    counts_factorial = math_ops.lgamma(counts + 1)
    redundant_permutations = math_ops.reduce_sum(counts_factorial, axis=[-1])
    return total_permutations - redundant_permutations
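
# Worked example (illustrative only, not part of the original code): with
# n = 3 and counts = [1, 2], the multinomial coefficient is
# 3! / (1! * 2!) = 3, so `log_combinations(3., [1., 2.])` evaluates to
# log(3) ~= 1.0986.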


def matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive. If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # LinearOperatorLowerTriangular ignores the upper triangle.
  operator = LinearOperatorLowerTriangular(chol)
  ```

  Example of heteroskedastic 2-D linear regression:

  ```python
  tfd = tfp.distributions

  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tfd.MultivariateNormalTriL(mu, chol)

  # Standard log loss. Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_prob(labels))
  ```

  Args:
    matrix: Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform: Element-wise function mapping `Tensors` to `Tensors`. To be
      applied to the diagonal of `matrix`. If `None`, `matrix` is returned
      unchanged. Defaults to `None`.
    name: A name to give created ops. Defaults to "matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.matrix_set_diag(matrix, transformed_diag)

  return transformed_mat


def rotate_transpose(x, shift, name="rotate_transpose"):
  """Circularly moves dims left or right.

  Effectively identical to:

  ```python
  numpy.transpose(x, numpy.roll(numpy.arange(len(x.shape)), shift))
  ```

  When `validate_args=False` additional graph-runtime checks are
  performed. These checks entail moving data from GPU to CPU.

  Example:

  ```python
  x = tf.random.normal([1, 2, 3, 4])  # Tensor of shape [1, 2, 3, 4].
  rotate_transpose(x, -1).shape == [2, 3, 4, 1]
  rotate_transpose(x, -2).shape == [3, 4, 1, 2]
  rotate_transpose(x, 1).shape == [4, 1, 2, 3]
  rotate_transpose(x, 2).shape == [3, 4, 1, 2]
  rotate_transpose(x, 7).shape == rotate_transpose(x, 3).shape  # [2, 3, 4, 1]
  rotate_transpose(x, -7).shape == rotate_transpose(x, -3).shape  # [4, 1, 2, 3]
  ```

  Args:
    x: `Tensor`.
    shift: `Tensor`. Number of dimensions to transpose left (shift < 0) or
      transpose right (shift > 0).
    name: Python `str`. The name to give this op.

  Returns:
    rotated_x: Input `Tensor` with dimensions circularly rotated by shift.

  Raises:
    TypeError: if shift is not integer type.
  """
  with ops.name_scope(name, values=[x, shift]):
    x = ops.convert_to_tensor(x, name="x")
    shift = ops.convert_to_tensor(shift, name="shift")
    # We do not assign back to preserve constant-ness.
    check_ops.assert_integer(shift)
    shift_value_static = tensor_util.constant_value(shift)
    ndims = x.get_shape().ndims
    if ndims is not None and shift_value_static is not None:
      if ndims < 2:
        return x
      shift_value_static = np.sign(shift_value_static) * (
          abs(shift_value_static) % ndims)
      if shift_value_static == 0:
        return x
      perm = np.roll(np.arange(ndims), shift_value_static)
      return array_ops.transpose(x, perm=perm)
    else:
      # Consider if we always had a positive shift, and some specified
      # direction.
      # When shifting left we want the new array:
      #   last(x, n - shift) + first(x, shift)
      # and if shifting right then we want:
      #   last(x, shift) + first(x, n - shift)
      # Observe that last(a) == slice(a, n) and first(a) == slice(0, a).
      # Also, we can encode direction and shift as one: direction * shift.
      # Combining these facts, we have:
      #   a = cond(shift < 0, -shift, n - shift)
      #   last(x, n - a) + first(x, a) == x[a:n] + x[0:a]
      # Finally, we transform shift by modulo length so it can be specified
      # independently from the array upon which it operates (like python).
      ndims = array_ops.rank(x)
      shift = array_ops.where_v2(
          math_ops.less(shift, 0),
          math_ops.mod(-shift, ndims),  # pylint: disable=invalid-unary-operand-type
          ndims - math_ops.mod(shift, ndims))
      first = math_ops.range(0, shift)
      last = math_ops.range(shift, ndims)
      perm = array_ops.concat([last, first], 0)
      return array_ops.transpose(x, perm=perm)


def pick_vector(cond, true_vector, false_vector, name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created and
  no validation happens.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))  # [10, 11]
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))  # [15, 16, 17]
  ```

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: Python `str`. The name to give this op.

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.name_scope(name, values=(cond, true_vector, false_vector)):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s" %
          (true_vector, true_vector.dtype, false_vector, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(
        array_ops.concat([true_vector, false_vector], 0),
        [array_ops.where_v2(cond, 0, n)], [array_ops.where(cond, n, -1)])


def prefer_static_broadcast_shape(shape1,
                                  shape2,
                                  name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1: `1-D` integer `Tensor`. Already converted to tensor!
    shape2: `1-D` integer `Tensor`. Already converted to tensor!
    name: A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
    statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):

    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)


def prefer_static_rank(x):
  """Return static rank of tensor `x` if available, else `tf.rank(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static rank is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.rank(x))


def prefer_static_shape(x):
  """Return static shape of tensor `x` if available, else `tf.shape(x)`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static shape is obtainable), else `Tensor`.
  """
  return prefer_static_value(array_ops.shape(x))
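
# Illustrative note (not part of the original code): the `prefer_static_*`
# helpers return a numpy value whenever TensorFlow can infer the answer at
# graph-construction time, and otherwise fall back to a `Tensor`, e.g.
#   prefer_static_rank(tf.ones([2, 3]))   # ==> 2, as a static numpy value
#   prefer_static_shape(tf.ones([2, 3]))  # ==> [2, 3], as a static numpy value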


def prefer_static_value(x):
  """Return static value of tensor `x` if available, else `x`.

  Args:
    x: `Tensor` (already converted).

  Returns:
    Numpy array (if static value is obtainable), else `Tensor`.
  """
  static_x = tensor_util.constant_value(x)
  if static_x is not None:
    return static_x
  return x


def gen_new_seed(seed, salt):
  """Generate a new seed, from the given seed and salt."""
  if seed is None:
    return None
  string = (str(seed) + salt).encode("utf-8")
  return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF


def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example
  below.

  If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  For comparison, a pure numpy version of this function can be found in
  `util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """

  with ops.name_scope(name, "fill_triangular", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(1)[-1]) is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(x.shape.dims[-1].value)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError("Input right-most shape ({}) does not "
                         "correspond to a triangular matrix.".format(m))
      n = np.int32(n)
      static_final_shape = x.shape[:-1].concatenate([n, n])
    else:
      m = array_ops.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so
      # we omit it. We don't validate that n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = math_ops.cast(
          math_ops.sqrt(0.25 + math_ops.cast(2 * m, dtype=dtypes.float32)),
          dtype=dtypes.int32)
      static_final_shape = x.shape.with_rank_at_least(1)[:-1].concatenate(
          [None, None])
    # We now concatenate the "tail" of `x` to `x` (and reverse one of them).
    #
    # We do this based on the insight that the input `x` provides `ceil(n/2)`
    # rows of an `n x n` matrix, some of which will get zeroed out being on the
    # wrong side of the diagonal. The first row will not get zeroed out at all,
    # and we need `floor(n/2)` more rows, so the first is what we omit from
    # `x_tail`. If we then stack those `ceil(n/2)` rows with the `floor(n/2)`
    # rows provided by a reversed tail, it is exactly the other set of elements
    # of the reversed tail which will be zeroed out for being on the wrong side
    # of the diagonal further up/down the matrix. And, in doing so, we've
    # filled the triangular matrix in a clockwise spiral pattern. Neat!
    #
    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #  #            [5, 4, 3],
    #  #            [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #  #            [3, 4, 5],
    #  #            [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 - n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static_rank(x)
    if upper:
      x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])]
    new_shape = (
        static_final_shape.as_list() if static_final_shape.is_fully_defined()
        else array_ops.concat([array_ops.shape(x)[:-1], [n, n]], axis=0))
    x = array_ops.reshape(array_ops.concat(x_list, axis=-1), new_shape)
    x = array_ops.matrix_band_part(
        x, num_lower=(0 if upper else -1), num_upper=(-1 if upper else 0))
    x.set_shape(static_final_shape)
    return x


def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether `x` is upper triangular (`True`)
      or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with ops.name_scope(name, "fill_triangular_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if tensor_shape.dimension_value(
        x.shape.with_rank_at_least(2)[-1]) is not None:
      n = np.int32(x.shape.dims[-1].value)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = x.shape[:-2].concatenate([m])
    else:
      n = array_ops.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate(
          [None])
    ndims = prefer_static_rank(x)
    if upper:
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    rotated_triangular_portion = array_ops.reverse(
        array_ops.reverse(triangular_portion, axis=[ndims - 1]),
        axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    end_sequence = array_ops.reshape(
        consolidated_matrix,
        array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    y.set_shape(static_final_shape)
    return y
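
# Quick sanity check (illustrative only, not part of the original code):
# `fill_triangular_inverse` undoes `fill_triangular`, e.g.
#   x = tf.constant([1., 2., 3., 4., 5., 6.])
#   fill_triangular_inverse(fill_triangular(x))  # ==> [1., 2., 3., 4., 5., 6.]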


def tridiag(below=None, diag=None, above=None, name=None):
  """Creates a matrix with values set above, below, and on the diagonal.

  Example:

  ```python
  tridiag(below=[1., 2., 3.],
          diag=[4., 5., 6., 7.],
          above=[8., 9., 10.])
  # ==> array([[  4.,   8.,   0.,   0.],
  #            [  1.,   5.,   9.,   0.],
  #            [  0.,   2.,   6.,  10.],
  #            [  0.,   0.,   3.,   7.]], dtype=float32)
  ```

  Warning: This Op is intended for convenience, not efficiency.

  Args:
    below: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the below
      diagonal part. `None` is logically equivalent to `below = 0`.
    diag: `Tensor` of shape `[B1, ..., Bb, d]` corresponding to the diagonal
      part. `None` is logically equivalent to `diag = 0`.
    above: `Tensor` of shape `[B1, ..., Bb, d-1]` corresponding to the above
      diagonal part. `None` is logically equivalent to `above = 0`.
    name: Python `str`. The name to give this op.

  Returns:
    tridiag: `Tensor` with values set above, below and on the diagonal.

  Raises:
    ValueError: if all inputs are `None`.
  """

  def _pad(x):
    """Prepends and appends a zero to every vector in a batch of vectors."""
    shape = array_ops.concat([array_ops.shape(x)[:-1], [1]], axis=0)
    z = array_ops.zeros(shape, dtype=x.dtype)
    return array_ops.concat([z, x, z], axis=-1)

  def _add(*x):
    """Adds list of Tensors, ignoring `None`."""
    s = None
    for y in x:
      if y is None:
        continue
      elif s is None:
        s = y
      else:
        s += y
    if s is None:
      raise ValueError("Must specify at least one of `below`, `diag`, `above`.")
    return s

  with ops.name_scope(name, "tridiag", [below, diag, above]):
    if below is not None:
      below = ops.convert_to_tensor(below, name="below")
      below = array_ops.matrix_diag(_pad(below))[..., :-1, 1:]
    if diag is not None:
      diag = ops.convert_to_tensor(diag, name="diag")
      diag = array_ops.matrix_diag(diag)
    if above is not None:
      above = ops.convert_to_tensor(above, name="above")
      above = array_ops.matrix_diag(_pad(above))[..., 1:, :-1]
    # TODO(jvdillon): Consider using scatter_nd instead of creating three full
    # matrices.
    return _add(below, diag, above)


def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.math.log(w))` is
  more efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than `log(sum(w * exp(input)))`. It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
    logx = ops.convert_to_tensor(logx, name="logx")
    if w is None:
      lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if return_sign:
        sgn = array_ops.ones_like(lswe)
        return lswe, sgn
      return lswe
    w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
    log_absw_x = logx + math_ops.log(math_ops.abs(w))
    max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother subtracting
    # off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
    # this is ok follows from the fact that we're actually free to subtract any
    # value we like, so long as we add it back after taking the `log(sum(...))`.
    max_log_absw_x = array_ops.where_v2(
        math_ops.is_inf(max_log_absw_x), array_ops.zeros_like(max_log_absw_x),
        max_log_absw_x)
    wx_over_max_absw_x = (
        math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
    sum_wx_over_max_absw_x = math_ops.reduce_sum(
        wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
    sgn = math_ops.sign(sum_wx_over_max_absw_x)
    lswe = max_log_absw_x + math_ops.log(sgn * sum_wx_over_max_absw_x)
    if return_sign:
      return lswe, sgn
    return lswe


# TODO(jvdillon): Merge this test back into:
# tensorflow/python/ops/softplus_op_test.py
# once TF core is accepting new ops.
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
    # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
    # be zero. To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
    # to overwrite `x` with ones only when we will never actually use this
    # value. Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath wherever we use the surrogate `ones_like`.
    x = array_ops.where_v2(
        math_ops.logical_or(is_too_small, is_too_large), array_ops.ones_like(x),
        x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where_v2(
        is_too_small, too_small_value,
        array_ops.where_v2(is_too_large, too_large_value, y))
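
# Illustrative round-trip (not part of the original code): away from the
# clamped regions, `softplus_inverse` inverts `tf.nn.softplus`, e.g.
#   x = tf.constant([0.1, 1., 10.])
#   softplus_inverse(tf.nn.softplus(x))  # ==> ~[0.1, 1., 10.], up to
#                                        #     floating-point error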


# TODO(b/35290280): Add unit-tests.
def dimension_size(x, axis):
  """Returns the size of a specific dimension."""
  # Since tf.gather isn't "constant-in, constant-out", we must first check the
  # static shape or fallback to dynamic shape.
  s = tensor_shape.dimension_value(
      x.shape.with_rank_at_least(np.abs(axis))[axis])
  if s is not None:
    return s
  return array_ops.shape(x)[axis]


def process_quadrature_grid_and_probs(quadrature_grid_and_probs,
                                      dtype,
                                      validate_args,
                                      name=None):
  """Validates quadrature grid, probs or computes them as necessary.

  Args:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weights. When `None`, defaults to:
      `np.polynomial.hermite.hermgauss(deg=8)`.
    dtype: The expected `dtype` of `grid` and `probs`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
      representing the sample points and the corresponding (possibly
      normalized) weights.

  Raises:
    ValueError: if `quadrature_grid_and_probs is not None` and
      `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
  """
  with ops.name_scope(name, "process_quadrature_grid_and_probs",
                      [quadrature_grid_and_probs]):
    if quadrature_grid_and_probs is None:
      grid, probs = np.polynomial.hermite.hermgauss(deg=8)
      grid = grid.astype(dtype.as_numpy_dtype)
      probs = probs.astype(dtype.as_numpy_dtype)
      probs /= np.linalg.norm(probs, ord=1, keepdims=True)
      grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
      probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
      return grid, probs

    grid, probs = tuple(quadrature_grid_and_probs)
    grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
    probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype)
    probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")

    def _static_event_size(x):
      """Returns the static size of a specific dimension or `None`."""
      return tensor_shape.dimension_value(x.shape.with_rank_at_least(1)[-1])

    m, n = _static_event_size(probs), _static_event_size(grid)
    if m is not None and n is not None:
      if m != n:
        raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
                         "same-length zero-th-dimension `Tensor`s "
                         "(saw lengths {}, {})".format(m, n))
    elif validate_args:
      assertions = [
          check_ops.assert_equal(
              dimension_size(probs, axis=-1),
              dimension_size(grid, axis=-1),
              message=("`quadrature_grid_and_probs` must be a `tuple` of "
                       "same-length zero-th-dimension `Tensor`s")),
      ]
      with ops.control_dependencies(assertions):
        grid = array_ops.identity(grid)
        probs = array_ops.identity(probs)
    return grid, probs


def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
  """Pads `value` to the front and/or back of a `Tensor` dim, `count` times.

  Args:
    x: `Tensor` input.
    axis: Scalar `int`-like `Tensor` representing the single dimension to pad.
      (Negative indexing is supported.)
    front: Python `bool`; if `True` the beginning of the `axis` dimension is
      padded with `value`, `count` times. If `False` no front padding is made.
    back: Python `bool`; if `True` the end of the `axis` dimension is padded
      with `value`, `count` times. If `False` no end padding is made.
    value: Scalar `int`-like `Tensor` representing the actual value added to
      the front and/or back of the `axis` dimension of `x`.
    count: Scalar `int`-like `Tensor` representing number of elements added to
      the front and/or back of the `axis` dimension of `x`. E.g., if
      `front = back = True` then `2 * count` elements are added.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    pad: The padded version of input `x`.

  Raises:
    ValueError: if both `front` and `back` are `False`.
    TypeError: if `count` is not `int`-like.
  """
  with ops.name_scope(name, "pad", [x, value, count]):
    x = ops.convert_to_tensor(x, name="x")
    value = ops.convert_to_tensor(value, dtype=x.dtype, name="value")
    count = ops.convert_to_tensor(count, name="count")
    if not count.dtype.is_integer:
      raise TypeError("`count.dtype` (`{}`) must be `int`-like.".format(
          count.dtype.name))
    if not front and not back:
      raise ValueError("At least one of `front`, `back` must be `True`.")
    ndims = (
        x.shape.ndims if x.shape.ndims is not None else array_ops.rank(
            x, name="ndims"))
    axis = ops.convert_to_tensor(axis, name="axis")
    axis_ = tensor_util.constant_value(axis)
    if axis_ is not None:
      axis = axis_
      if axis < 0:
        axis = ndims + axis
      count_ = tensor_util.constant_value(count)
      if axis_ >= 0 or x.shape.ndims is not None:
        head = x.shape[:axis]
        middle = tensor_shape.TensorShape(None if count_ is None else (
            tensor_shape.dimension_at_index(x.shape, axis) + count_ *
            (front + back)))
        tail = x.shape[axis + 1:]
        final_shape = head.concatenate(middle.concatenate(tail))
      else:
        final_shape = None
    else:
      axis = array_ops.where_v2(axis < 0, ndims + axis, axis)
      final_shape = None
    x = array_ops.pad(
        x,
        paddings=array_ops.one_hot(
            indices=array_ops.stack(
                [axis if front else -1, axis if back else -1]),
            depth=ndims,
            axis=0,
            on_value=count,
            dtype=dtypes.int32),
        constant_values=value)
    if final_shape is not None:
      x.set_shape(final_shape)
    return x
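
# Illustrative example of `pad` (not part of the original code; values assume
# eager execution):
#   pad(tf.constant([1., 2., 3.]), axis=0, front=True, value=7., count=2)
#   # ==> [7., 7., 1., 2., 3.]
#   pad(tf.constant([1., 2., 3.]), axis=0, back=True)
#   # ==> [1., 2., 3., 0.]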


def parent_frame_arguments():
  """Returns parent frame arguments.

  When called inside a function, returns a dictionary with the caller's
  function arguments. These are positional arguments and keyword arguments
  (**kwargs), while variable arguments (*varargs) are excluded.

  When called at global scope, this will return an empty dictionary, since
  there are no arguments.

  WARNING: If caller function argument names are overwritten before invoking
  this method, then values will reflect the overwritten value. For this reason,
  we recommend calling `parent_frame_arguments` at the beginning of the
  function.
  """
  # All arguments and the names used for *varargs, and **kwargs.
  arg_names, variable_arg_name, keyword_arg_name, local_vars = (
      tf_inspect._inspect.getargvalues(  # pylint: disable=protected-access
          # Get the first frame of the caller of this method.
          tf_inspect._inspect.stack()[1][0]))  # pylint: disable=protected-access

  # Remove the *varargs, and flatten the **kwargs into the result.
  local_vars.pop(variable_arg_name, {})
  keyword_args = local_vars.pop(keyword_arg_name, {})

  final_args = {}
  # Copy over arguments and their values. In general, local_vars
  # may contain more than just the arguments, since this method
  # can be called anywhere in a function.
  for arg_name in arg_names:
    final_args[arg_name] = local_vars.pop(arg_name)
  final_args.update(keyword_args)

  return final_args


class AppendDocstring:
  """Helper class to promote private subclass docstring to public counterpart.

  Example:

  ```python
  class TransformedDistribution(Distribution):
    @distribution_util.AppendDocstring(
        additional_note="A special note!",
        kwargs_dict={"foo": "An extra arg."})
    def _prob(self, y, foo=None):
      pass
  ```

  In this case, the `AppendDocstring` decorator appends the `additional_note`
  to the docstring of `prob` (not `_prob`) and adds a new `kwargs` section with
  each dictionary item as a bullet-point.

  For a more detailed example, see `TransformedDistribution`.
  """

  def __init__(self, additional_note="", kwargs_dict=None):
    """Initializes the AppendDocstring object.

    Args:
      additional_note: Python string added as additional docstring to public
        version of function.
      kwargs_dict: Python string/string dictionary representing specific kwargs
        expanded from the **kwargs input.

    Raises:
      ValueError: if kwargs_dict.key contains whitespace.
      ValueError: if kwargs_dict.value contains newlines.
    """
    self._additional_note = additional_note
    if kwargs_dict:
      bullets = []
      for key in sorted(kwargs_dict.keys()):
        value = kwargs_dict[key]
        if any(x.isspace() for x in key):
          raise ValueError("Parameter name \"%s\" contains whitespace." % key)
        value = value.lstrip()
        if "\n" in value:
          raise ValueError(
              "Parameter description for \"%s\" contains newlines." % key)
        bullets.append("* `%s`: %s" % (key, value))
      self._additional_note += ("\n\n##### `kwargs`:\n\n" + "\n".join(bullets))

  def __call__(self, fn):

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
      return fn(*args, **kwargs)

    if _fn.__doc__ is None:
      _fn.__doc__ = self._additional_note
    else:
      _fn.__doc__ += "\n%s" % self._additional_note
    return _fn