# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal utilities for `LinearOperator` classes."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.util import nest


################################################################################
# To make more friendly for TF2.
################################################################################


def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects
  except if the input has reference semantics. Reference semantics are
  characterized by `is_ref`: any object which is a `tf.Variable`, or an
  instance of `tf.Module` exposing `dtype` and `shape`. Such objects are
  returned unchanged (after a dtype check). This function accepts any input
  which `tf.convert_to_tensor` would also.

  Note: This function diverges from default Numpy behavior for `float` and
  `string` types when `None` is present in a Python list or scalar. Rather
  than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference. If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`, `value` itself if it has reference
      semantics, or `None` if `value` is `None`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`,
      or if `value` has reference semantics and its base dtype differs from
      `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.


  #### Examples:

  ```python

  x = tf.Variable(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(13.37, lambda x: x)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> False
  tf.equal(y, 13.37)
  # ==> True
  ```

  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    # Reference objects are passed through untouched; only validate dtype.
    if dtype is None:
      return value
    dtype_base = base_dtype(dtype)
    value_dtype_base = base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError(
          f"Argument `value` must be of dtype `{dtype_name(dtype_base)}`. "
          f"Received: `{dtype_name(value_dtype_base)}`.")
    return value
  return ops.convert_to_tensor_v2_with_dispatch(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)


def base_dtype(dtype):
  """Returns a non-reference `dtype` based on this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "base_dtype"):
    return dtype.base_dtype
  return dtype


def dtype_name(dtype):
  """Returns the string name for this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "name"):
    return dtype.name
  if hasattr(dtype, "__name__"):
    return dtype.__name__
  return str(dtype)


def check_dtype(arg, dtype):
  """Check that `arg.dtype.base_dtype == dtype`.

  Args:
    arg: Object with a `dtype` attribute (e.g. a `Tensor`).
    dtype: The expected (base) dtype.

  Raises:
    TypeError: If the base dtype of `arg` differs from `dtype`.
  """
  if arg.dtype.base_dtype != dtype:
    raise TypeError(
        f"Expected argument to have dtype {dtype}. Found: {arg.dtype} in "
        f"tensor {arg}.")


def is_ref(x):
  """Evaluates if the object has reference semantics.

  An object is deemed "reference" if it is a `tf.Variable` instance or is
  derived from a `tf.Module` with `dtype` and `shape` properties.

  Args:
    x: Any object.

  Returns:
    is_ref: Python `bool` indicating input has reference semantics, i.e.,
      is a `tf.Variable` or a `tf.Module` with `dtype` and `shape` properties.
  """
  return (
      # Variables always have reference semantics.
      isinstance(x, variables_module.Variable) or
      # Modules only count when they look tensor-like (expose dtype & shape).
      (isinstance(x, module.Module) and hasattr(x, "dtype") and
       hasattr(x, "shape")))


def assert_not_ref_type(x, arg_name):
  """Raises if `x` has reference semantics (as determined by `is_ref`).

  Args:
    x: Any object.
    arg_name: Python `str` name of the argument, used in the error message.

  Raises:
    TypeError: If `is_ref(x)` is `True`.
  """
  if is_ref(x):
    raise TypeError(
        f"Argument {arg_name} cannot be reference type. Found: {type(x)}.")


################################################################################
# Asserts.
################################################################################


def assert_no_entries_with_modulus_zero(
    x, message=None, name="assert_no_entries_with_modulus_zero"):
  """Returns `Op` that asserts Tensor `x` has no entries with modulus zero.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with modulus zero.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype
    should_be_nonzero = math_ops.abs(x)
    # The modulus is real-valued even for complex `x`, so compare against a
    # zero of `real_dtype`.
    zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
    return check_ops.assert_less(zero, should_be_nonzero, message=message)


def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts the imaginary part of every entry of `x` is zero
    (a `no_op` when `x` has a real floating-point dtype).
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype

    # Real floating-point tensors have no imaginary part; nothing to check.
    if dtype.is_floating:
      return control_flow_ops.no_op()

    zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)


def assert_compatible_matrix_dimensions(operator, x):
  """Assert that an argument to solve/matmul has proper domain dimension.

  If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
  `operator.matmul(x)` is defined only if `N = Q`.  This `Op` returns an
  `Assert` that "fires" if this is not the case.  Static checks are already
  done by the base class `LinearOperator`.

  Args:
    operator:  `LinearOperator`.
    x:  `Tensor`.

  Returns:
    `Assert` `Op`.
  """
  # Static checks are done in the base class.  Only tensor asserts here.
  assert_same_dd = check_ops.assert_equal(
      array_ops.shape(x)[-2],
      operator.domain_dimension_tensor(),
      # This error message made to look similar to error raised by static check
      # in the base class.
      message=("Dimensions are not compatible.  "
               "shape[-2] of argument to be the same as this operator"))

  return assert_same_dd


def assert_is_batch_matrix(tensor):
  """Static assert that `tensor` has rank `2` or higher.

  Tensors of unknown static rank pass this check.

  Raises:
    ValueError: If `tensor` is statically known to have rank less than 2.
  """
  sh = tensor.shape
  if sh.ndims is not None and sh.ndims < 2:
    raise ValueError(
        f"Expected [batch] matrix to have at least two dimensions. Found: "
        f"{tensor}.")


def shape_tensor(shape, name=None):
  """Convert Tensor using default type, unless empty list or tuple."""
  # Works just like random_ops._ShapeTensor.
  # An empty list/tuple carries no element to infer a dtype from, so pin it
  # to int32.
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int32
  else:
    dtype = None
  return ops.convert_to_tensor_v2_with_dispatch(shape, dtype=dtype, name=name)


################################################################################
# Broadcasting versions of common linear algebra functions.
# TODO(b/77519145) Do this more efficiently in some special cases.
################################################################################


def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor_v2_with_dispatch(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].shape[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.shape[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        # Only materialize a broadcast for matrices that actually need it.
        if mat.shape[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = array_ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices


def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve systems of linear equations, broadcasting batch dimensions.

  The batch dimensions of `matrix` and `rhs` are broadcast against each other
  (via `broadcast_matrix_batch_dims`) before calling
  `linalg_ops.matrix_solve`. When `rhs` has more dimensions than `matrix`,
  `_reshape_for_efficiency` may first fold the extra leading dims of `rhs`
  into its last dimension so that `matrix` need not be replicated.

  Args:
    matrix: `Tensor` with two or more dimensions, the (batch) matrix.
    rhs: `Tensor` right-hand side, converted to the dtype of `matrix`.
    adjoint: Python `bool`. If `True`, solve with the adjoint of `matrix`.
    name: A name scope for the created ops.

  Returns:
    The solution `Tensor`, reshaped back to the layout implied by `rhs`.
  """
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
    rhs = ops.convert_to_tensor_v2_with_dispatch(
        rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    # If `_reshape_for_efficiency` already applied the adjoint explicitly, it
    # must not be applied a second time here.
    solution = linalg_ops.matrix_solve(
        matrix, rhs, adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)


def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve.

  When `b` has more dimensions than `a`, broadcasting would replicate `a` up
  to `b`'s leading dims. Instead, this may move `b`'s extra leading dims to
  the end and squash them into `b`'s last dimension, so `a` can be used as-is.

  Args:
    a: `Tensor`, the left operand (matrix).
    b: `Tensor`, the right operand.
    transpose_a: Python `bool`. Whether `a` is to be transposed.
    transpose_b: Python `bool`. Whether `b` is to be transposed.
    adjoint_a: Python `bool`. Whether `a` is to be adjointed.
    adjoint_b: Python `bool`. Whether `b` is to be adjointed.

  Returns:
    a: `a`, possibly with its transpose/adjoint already applied.
    b: `b`, possibly with its transpose/adjoint applied and reshaped.
    reshape_inv: Callable mapping a result computed from the returned `a` and
      `b` back to the layout expected for the original inputs.
    still_need_to_transpose: Python `bool`. `False` when the requested
      transposes/adjoints were applied here, so the caller must not apply
      them again; `True` otherwise.
  """
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # No reason to flip unless the extra dims of b are big enough.  Why?
  # Assume adjoint/transpose = False.  Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory.  But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a.  This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code.  Why?  To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = array_ops.matrix_transpose(a, conjugate=True)
  elif transpose_a:
    a = array_ops.matrix_transpose(a, conjugate=False)
  if adjoint_b:
    b = array_ops.matrix_transpose(b, conjugate=True)
  elif transpose_b:
    b = array_ops.matrix_transpose(b, conjugate=False)
  still_need_to_transpose = False

  # Compute shapes only now, since the transpose/adjoint may have changed them.
  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose


################################################################################
# Helpers for hints.
################################################################################


def use_operator_or_provided_hint_unless_contradicting(
    operator, hint_attr_name, provided_hint_value, message):
  """Get combined hint in the case where operator.hint should equal hint.

  Args:
    operator:  LinearOperator that a meta-operator was initialized with.
    hint_attr_name:  String name for the attribute.
    provided_hint_value:  Bool or None. Value passed by user in initialization.
    message:  Error message to print if hints contradict.

  Returns:
    True, False, or None.

  Raises:
    ValueError: If hints contradict.
  """
  op_hint = getattr(operator, hint_attr_name)
  # pylint: disable=g-bool-id-comparison
  if op_hint is False and provided_hint_value:
    raise ValueError(message)
  if op_hint and provided_hint_value is False:
    raise ValueError(message)
  if op_hint or provided_hint_value:
    return True
  if op_hint is False or provided_hint_value is False:
    return False
  # pylint: enable=g-bool-id-comparison
  return None


################################################################################
# Utilities for blockwise operators.
################################################################################


def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
  """Detect if input should be interpreted as a list of blocks.

  Args:
    block_dimensions: Sized iterable of dimension objects (read via `.value`,
      which may be `None` when unknown), one per operator block.
    arg: The input to classify; possibly a tuple/list of per-block inputs.
    arg_split_dim: Python `int` axis of each element of `arg` that is compared
      against the block dimensions.

  Returns:
    Python `bool`: `True` if `arg` should be treated as a list of blocks.

  Raises:
    ValueError: If the structure is ambiguous, or the input dimensions do not
      match the operator dimensions.
  """
  # Tuples and lists of length equal to the number of operators may be
  # blockwise.
  if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
    # If the elements of the iterable are not nested, interpret the input as
    # blockwise.
    if not any(nest.is_nested(x) for x in arg):
      return True
    else:
      arg_dims = [ops.convert_to_tensor_v2_with_dispatch(
          x).shape[arg_split_dim] for x in arg]
      self_dims = [dim.value for dim in block_dimensions]

      # If none of the operator dimensions are known, interpret the input as
      # blockwise if its matching dimensions are unequal.
      if all(self_d is None for self_d in self_dims):

        # A nested tuple/list with a single outermost element is not blockwise
        if len(arg_dims) == 1:
          return False
        elif any(dim != arg_dims[0] for dim in arg_dims):
          return True
        else:
          raise ValueError(
              "Parsing of the input structure is ambiguous. Please input "
              "a blockwise iterable of `Tensor`s or a single `Tensor`.")

      # If input dimensions equal the respective (known) blockwise operator
      # dimensions, then the input is blockwise.
      if all(self_d == arg_d or self_d is None
             for self_d, arg_d in zip(self_dims, arg_dims)):
        return True

      # If the input dimensions are all equal, and are greater than or equal
      # to the sum of the known operator dimensions, the input is not
      # blockwise. An unknown (`None`) input dimension cannot be compared, so
      # it skips this case instead of raising a `TypeError`.
      self_dim = sum(self_d for self_d in self_dims if self_d is not None)
      first_arg_dim = arg_dims[0]
      if (first_arg_dim is not None and
          all(s == first_arg_dim for s in arg_dims) and
          first_arg_dim >= self_dim):
        return False

      # If none of these conditions is met, the input shape is mismatched.
      raise ValueError("Input dimension does not match operator dimension.")
  else:
    return False


def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
  """Split `arg` into blocks matching `operators`'s `domain_dimension`.

  Specifically, if we have a blockwise lower-triangular matrix, with block
  sizes along the diagonal `[M_j, M_j] j = 0, 1, ..., J - 1`,  this method
  splits `arg` on `axis` into `J` tensors, whose shape at `axis` is `M_j`.

  Args:
    block_dims: Iterable of `TensorShapes`.
    block_dims_fn: Callable returning an iterable of `Tensor`s. Only called
      when some static size in `block_dims` is unknown.
    arg: `Tensor`. `arg` is split into `J` tensors.
    axis: Python `Integer` representing the axis to split `arg` on.

  Returns:
    A list of `Tensor`s.
  """
  block_sizes = [dim.value for dim in block_dims]
  # Fall back to dynamic block sizes if any static size is unknown.
  if any(d is None for d in block_sizes):
    block_sizes = block_dims_fn()
  return array_ops.split(arg, block_sizes, axis=axis)