1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base class for linear operators.""" 16 17import abc 18import contextlib 19 20import numpy as np 21 22from tensorflow.python.framework import composite_tensor 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.framework import tensor_spec 27from tensorflow.python.framework import tensor_util 28from tensorflow.python.framework import type_spec 29from tensorflow.python.module import module 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import check_ops 32from tensorflow.python.ops import linalg_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import resource_variable_ops 35from tensorflow.python.ops import variables 36from tensorflow.python.ops.linalg import linalg_impl as linalg 37from tensorflow.python.ops.linalg import linear_operator_algebra 38from tensorflow.python.ops.linalg import linear_operator_util 39from tensorflow.python.ops.linalg import slicing 40from tensorflow.python.platform import tf_logging as logging 41from tensorflow.python.trackable import data_structures 42from tensorflow.python.util import deprecation 43from tensorflow.python.util import dispatch 44from 
tensorflow.python.util import nest 45from tensorflow.python.util import variable_utils 46from tensorflow.python.util.tf_export import tf_export 47 48__all__ = ["LinearOperator"] 49 50 51# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices. 52@tf_export("linalg.LinearOperator") 53class LinearOperator( 54 module.Module, composite_tensor.CompositeTensor, metaclass=abc.ABCMeta): 55 """Base class defining a [batch of] linear operator[s]. 56 57 Subclasses of `LinearOperator` provide access to common methods on a 58 (batch) matrix, without the need to materialize the matrix. This allows: 59 60 * Matrix free computations 61 * Operators that take advantage of special structure, while providing a 62 consistent API to users. 63 64 #### Subclassing 65 66 To enable a public method, subclasses should implement the leading-underscore 67 version of the method. The argument signature should be identical except for 68 the omission of `name="..."`. For example, to enable 69 `matmul(x, adjoint=False, name="matmul")` a subclass should implement 70 `_matmul(x, adjoint=False)`. 71 72 #### Performance contract 73 74 Subclasses should only implement the assert methods 75 (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)` 76 time. 77 78 Class docstrings should contain an explanation of computational complexity. 79 Since this is a high-performance library, attention should be paid to detail, 80 and explanations can include constants as well as Big-O notation. 81 82 #### Shape compatibility 83 84 `LinearOperator` subclasses should operate on a [batch] matrix with 85 compatible shape. Class docstrings should define what is meant by compatible 86 shape. Some subclasses may not support batching. 
87 88 Examples: 89 90 `x` is a batch matrix with compatible shape for `matmul` if 91 92 ``` 93 operator.shape = [B1,...,Bb] + [M, N], b >= 0, 94 x.shape = [B1,...,Bb] + [N, R] 95 ``` 96 97 `rhs` is a batch matrix with compatible shape for `solve` if 98 99 ``` 100 operator.shape = [B1,...,Bb] + [M, N], b >= 0, 101 rhs.shape = [B1,...,Bb] + [M, R] 102 ``` 103 104 #### Example docstring for subclasses. 105 106 This operator acts like a (batch) matrix `A` with shape 107 `[B1,...,Bb, M, N]` for some `b >= 0`. The first `b` indices index a 108 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is 109 an `m x n` matrix. Again, this matrix `A` may not be materialized, but for 110 purposes of identifying and working with compatible arguments the shape is 111 relevant. 112 113 Examples: 114 115 ```python 116 some_tensor = ... shape = ???? 117 operator = MyLinOp(some_tensor) 118 119 operator.shape() 120 ==> [2, 4, 4] 121 122 operator.log_abs_determinant() 123 ==> Shape [2] Tensor 124 125 x = ... Shape [2, 4, 5] Tensor 126 127 operator.matmul(x) 128 ==> Shape [2, 4, 5] Tensor 129 ``` 130 131 #### Shape compatibility 132 133 This operator acts on batch matrices with compatible shape. 134 FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE 135 136 #### Performance 137 138 FILL THIS IN 139 140 #### Matrix property hints 141 142 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 143 for `X = non_singular, self_adjoint, positive_definite, square`. 144 These have the following meaning: 145 146 * If `is_X == True`, callers should expect the operator to have the 147 property `X`. This is a promise that should be fulfilled, but is *not* a 148 runtime assert. For example, finite floating point precision may result 149 in these promises being violated. 150 * If `is_X == False`, callers should expect the operator to not have `X`. 151 * If `is_X == None` (the default), callers should have no expectation either 152 way. 
153 154 #### Initialization parameters 155 156 All subclasses of `LinearOperator` are expected to pass a `parameters` 157 argument to `super().__init__()`. This should be a `dict` containing 158 the unadulterated arguments passed to the subclass `__init__`. For example, 159 `MyLinearOperator` with an initializer should look like: 160 161 ```python 162 def __init__(self, operator, is_square=False, name=None): 163 parameters = dict( 164 operator=operator, 165 is_square=is_square, 166 name=name 167 ) 168 ... 169 super().__init__(..., parameters=parameters) 170 ``` 171 172 Users can then access `my_linear_operator.parameters` to see all arguments 173 passed to its initializer. 174 """ 175 176 # TODO(b/143910018) Remove graph_parents in V3. 177 @deprecation.deprecated_args(None, "Do not pass `graph_parents`. They will " 178 " no longer be used.", "graph_parents") 179 def __init__(self, 180 dtype, 181 graph_parents=None, 182 is_non_singular=None, 183 is_self_adjoint=None, 184 is_positive_definite=None, 185 is_square=None, 186 name=None, 187 parameters=None): 188 """Initialize the `LinearOperator`. 189 190 **This is a private method for subclass use.** 191 **Subclasses should copy-paste this `__init__` documentation.** 192 193 Args: 194 dtype: The type of the this `LinearOperator`. Arguments to `matmul` and 195 `solve` will have to be this type. 196 graph_parents: (Deprecated) Python list of graph prerequisites of this 197 `LinearOperator` Typically tensors that are passed during initialization 198 is_non_singular: Expect that this operator is non-singular. 199 is_self_adjoint: Expect that this operator is equal to its hermitian 200 transpose. If `dtype` is real, this is equivalent to being symmetric. 201 is_positive_definite: Expect that this operator is positive definite, 202 meaning the quadratic form `x^H A x` has positive real part for all 203 nonzero `x`. Note that we do not require the operator to be 204 self-adjoint to be positive-definite. 
See: 205 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 206 is_square: Expect that this operator acts like square [batch] matrices. 207 name: A name for this `LinearOperator`. 208 parameters: Python `dict` of parameters used to instantiate this 209 `LinearOperator`. 210 211 Raises: 212 ValueError: If any member of graph_parents is `None` or not a `Tensor`. 213 ValueError: If hints are set incorrectly. 214 """ 215 # Check and auto-set flags. 216 if is_positive_definite: 217 if is_non_singular is False: 218 raise ValueError("A positive definite matrix is always non-singular.") 219 is_non_singular = True 220 221 if is_non_singular: 222 if is_square is False: 223 raise ValueError("A non-singular matrix is always square.") 224 is_square = True 225 226 if is_self_adjoint: 227 if is_square is False: 228 raise ValueError("A self-adjoint matrix is always square.") 229 is_square = True 230 231 self._is_square_set_or_implied_by_hints = is_square 232 233 if graph_parents is not None: 234 self._set_graph_parents(graph_parents) 235 else: 236 self._graph_parents = [] 237 self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype 238 self._is_non_singular = is_non_singular 239 self._is_self_adjoint = is_self_adjoint 240 self._is_positive_definite = is_positive_definite 241 self._parameters = self._no_dependency(parameters) 242 self._parameters_sanitized = False 243 self._name = name or type(self).__name__ 244 245 @contextlib.contextmanager 246 def _name_scope(self, name=None): # pylint: disable=method-hidden 247 """Helper function to standardize op scope.""" 248 full_name = self.name 249 if name is not None: 250 full_name += "/" + name 251 with ops.name_scope(full_name) as scope: 252 yield scope 253 254 @property 255 def parameters(self): 256 """Dictionary of parameters used to instantiate this `LinearOperator`.""" 257 return dict(self._parameters) 258 259 @property 260 def dtype(self): 261 """The `DType` of `Tensor`s handled 
by this `LinearOperator`.""" 262 return self._dtype 263 264 @property 265 def name(self): 266 """Name prepended to all ops created by this `LinearOperator`.""" 267 return self._name 268 269 @property 270 @deprecation.deprecated(None, "Do not call `graph_parents`.") 271 def graph_parents(self): 272 """List of graph dependencies of this `LinearOperator`.""" 273 return self._graph_parents 274 275 @property 276 def is_non_singular(self): 277 return self._is_non_singular 278 279 @property 280 def is_self_adjoint(self): 281 return self._is_self_adjoint 282 283 @property 284 def is_positive_definite(self): 285 return self._is_positive_definite 286 287 @property 288 def is_square(self): 289 """Return `True/False` depending on if this operator is square.""" 290 # Static checks done after __init__. Why? Because domain/range dimension 291 # sometimes requires lots of work done in the derived class after init. 292 auto_square_check = self.domain_dimension == self.range_dimension 293 if self._is_square_set_or_implied_by_hints is False and auto_square_check: 294 raise ValueError( 295 "User set is_square hint to False, but the operator was square.") 296 if self._is_square_set_or_implied_by_hints is None: 297 return auto_square_check 298 299 return self._is_square_set_or_implied_by_hints 300 301 @abc.abstractmethod 302 def _shape(self): 303 # Write this in derived class to enable all static shape methods. 304 raise NotImplementedError("_shape is not implemented.") 305 306 @property 307 def shape(self): 308 """`TensorShape` of this `LinearOperator`. 309 310 If this operator acts like the batch matrix `A` with 311 `A.shape = [B1,...,Bb, M, N]`, then this returns 312 `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.shape`. 313 314 Returns: 315 `TensorShape`, statically determined, may be undefined. 
316 """ 317 return self._shape() 318 319 def _shape_tensor(self): 320 # This is not an abstractmethod, since we want derived classes to be able to 321 # override this with optional kwargs, which can reduce the number of 322 # `convert_to_tensor` calls. See derived classes for examples. 323 raise NotImplementedError("_shape_tensor is not implemented.") 324 325 def shape_tensor(self, name="shape_tensor"): 326 """Shape of this `LinearOperator`, determined at runtime. 327 328 If this operator acts like the batch matrix `A` with 329 `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding 330 `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`. 331 332 Args: 333 name: A name for this `Op`. 334 335 Returns: 336 `int32` `Tensor` 337 """ 338 with self._name_scope(name): # pylint: disable=not-callable 339 # Prefer to use statically defined shape if available. 340 if self.shape.is_fully_defined(): 341 return linear_operator_util.shape_tensor(self.shape.as_list()) 342 else: 343 return self._shape_tensor() 344 345 @property 346 def batch_shape(self): 347 """`TensorShape` of batch dimensions of this `LinearOperator`. 348 349 If this operator acts like the batch matrix `A` with 350 `A.shape = [B1,...,Bb, M, N]`, then this returns 351 `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]` 352 353 Returns: 354 `TensorShape`, statically determined, may be undefined. 355 """ 356 # Derived classes get this "for free" once .shape is implemented. 357 return self.shape[:-2] 358 359 def batch_shape_tensor(self, name="batch_shape_tensor"): 360 """Shape of batch dimensions of this operator, determined at runtime. 361 362 If this operator acts like the batch matrix `A` with 363 `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding 364 `[B1,...,Bb]`. 365 366 Args: 367 name: A name for this `Op`. 368 369 Returns: 370 `int32` `Tensor` 371 """ 372 # Derived classes get this "for free" once .shape() is implemented. 
373 with self._name_scope(name): # pylint: disable=not-callable 374 return self._batch_shape_tensor() 375 376 def _batch_shape_tensor(self, shape=None): 377 # `shape` may be passed in if this can be pre-computed in a 378 # more efficient manner, e.g. without excessive Tensor conversions. 379 if self.batch_shape.is_fully_defined(): 380 return linear_operator_util.shape_tensor( 381 self.batch_shape.as_list(), name="batch_shape") 382 else: 383 shape = self.shape_tensor() if shape is None else shape 384 return shape[:-2] 385 386 @property 387 def tensor_rank(self, name="tensor_rank"): 388 """Rank (in the sense of tensors) of matrix corresponding to this operator. 389 390 If this operator acts like the batch matrix `A` with 391 `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. 392 393 Args: 394 name: A name for this `Op`. 395 396 Returns: 397 Python integer, or None if the tensor rank is undefined. 398 """ 399 # Derived classes get this "for free" once .shape() is implemented. 400 with self._name_scope(name): # pylint: disable=not-callable 401 return self.shape.ndims 402 403 def tensor_rank_tensor(self, name="tensor_rank_tensor"): 404 """Rank (in the sense of tensors) of matrix corresponding to this operator. 405 406 If this operator acts like the batch matrix `A` with 407 `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`. 408 409 Args: 410 name: A name for this `Op`. 411 412 Returns: 413 `int32` `Tensor`, determined at runtime. 414 """ 415 # Derived classes get this "for free" once .shape() is implemented. 416 with self._name_scope(name): # pylint: disable=not-callable 417 return self._tensor_rank_tensor() 418 419 def _tensor_rank_tensor(self, shape=None): 420 # `shape` may be passed in if this can be pre-computed in a 421 # more efficient manner, e.g. without excessive Tensor conversions. 
422 if self.tensor_rank is not None: 423 return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank) 424 else: 425 shape = self.shape_tensor() if shape is None else shape 426 return array_ops.size(shape) 427 428 @property 429 def domain_dimension(self): 430 """Dimension (in the sense of vector spaces) of the domain of this operator. 431 432 If this operator acts like the batch matrix `A` with 433 `A.shape = [B1,...,Bb, M, N]`, then this returns `N`. 434 435 Returns: 436 `Dimension` object. 437 """ 438 # Derived classes get this "for free" once .shape is implemented. 439 if self.shape.rank is None: 440 return tensor_shape.Dimension(None) 441 else: 442 return self.shape.dims[-1] 443 444 def domain_dimension_tensor(self, name="domain_dimension_tensor"): 445 """Dimension (in the sense of vector spaces) of the domain of this operator. 446 447 Determined at runtime. 448 449 If this operator acts like the batch matrix `A` with 450 `A.shape = [B1,...,Bb, M, N]`, then this returns `N`. 451 452 Args: 453 name: A name for this `Op`. 454 455 Returns: 456 `int32` `Tensor` 457 """ 458 # Derived classes get this "for free" once .shape() is implemented. 459 with self._name_scope(name): # pylint: disable=not-callable 460 return self._domain_dimension_tensor() 461 462 def _domain_dimension_tensor(self, shape=None): 463 # `shape` may be passed in if this can be pre-computed in a 464 # more efficient manner, e.g. without excessive Tensor conversions. 465 dim_value = tensor_shape.dimension_value(self.domain_dimension) 466 if dim_value is not None: 467 return ops.convert_to_tensor_v2_with_dispatch(dim_value) 468 else: 469 shape = self.shape_tensor() if shape is None else shape 470 return shape[-1] 471 472 @property 473 def range_dimension(self): 474 """Dimension (in the sense of vector spaces) of the range of this operator. 475 476 If this operator acts like the batch matrix `A` with 477 `A.shape = [B1,...,Bb, M, N]`, then this returns `M`. 478 479 Returns: 480 `Dimension` object. 
481 """ 482 # Derived classes get this "for free" once .shape is implemented. 483 if self.shape.dims: 484 return self.shape.dims[-2] 485 else: 486 return tensor_shape.Dimension(None) 487 488 def range_dimension_tensor(self, name="range_dimension_tensor"): 489 """Dimension (in the sense of vector spaces) of the range of this operator. 490 491 Determined at runtime. 492 493 If this operator acts like the batch matrix `A` with 494 `A.shape = [B1,...,Bb, M, N]`, then this returns `M`. 495 496 Args: 497 name: A name for this `Op`. 498 499 Returns: 500 `int32` `Tensor` 501 """ 502 # Derived classes get this "for free" once .shape() is implemented. 503 with self._name_scope(name): # pylint: disable=not-callable 504 return self._range_dimension_tensor() 505 506 def _range_dimension_tensor(self, shape=None): 507 # `shape` may be passed in if this can be pre-computed in a 508 # more efficient manner, e.g. without excessive Tensor conversions. 509 dim_value = tensor_shape.dimension_value(self.range_dimension) 510 if dim_value is not None: 511 return ops.convert_to_tensor_v2_with_dispatch(dim_value) 512 else: 513 shape = self.shape_tensor() if shape is None else shape 514 return shape[-2] 515 516 def _assert_non_singular(self): 517 """Private default implementation of _assert_non_singular.""" 518 logging.warn( 519 "Using (possibly slow) default implementation of assert_non_singular." 520 " Requires conversion to a dense matrix and O(N^3) operations.") 521 if self._can_use_cholesky(): 522 return self.assert_positive_definite() 523 else: 524 singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False) 525 # TODO(langmore) Add .eig and .cond as methods. 
526 cond = (math_ops.reduce_max(singular_values, axis=-1) / 527 math_ops.reduce_min(singular_values, axis=-1)) 528 return check_ops.assert_less( 529 cond, 530 self._max_condition_number_to_be_non_singular(), 531 message="Singular matrix up to precision epsilon.") 532 533 def _max_condition_number_to_be_non_singular(self): 534 """Return the maximum condition number that we consider nonsingular.""" 535 with ops.name_scope("max_nonsingular_condition_number"): 536 dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps 537 eps = math_ops.cast( 538 math_ops.reduce_max([ 539 100., 540 math_ops.cast(self.range_dimension_tensor(), self.dtype), 541 math_ops.cast(self.domain_dimension_tensor(), self.dtype) 542 ]), self.dtype) * dtype_eps 543 return 1. / eps 544 545 def assert_non_singular(self, name="assert_non_singular"): 546 """Returns an `Op` that asserts this operator is non singular. 547 548 This operator is considered non-singular if 549 550 ``` 551 ConditionNumber < max{100, range_dimension, domain_dimension} * eps, 552 eps := np.finfo(self.dtype.as_numpy_dtype).eps 553 ``` 554 555 Args: 556 name: A string name to prepend to created ops. 557 558 Returns: 559 An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if 560 the operator is singular. 561 """ 562 with self._name_scope(name): # pylint: disable=not-callable 563 return self._assert_non_singular() 564 565 def _assert_positive_definite(self): 566 """Default implementation of _assert_positive_definite.""" 567 logging.warn( 568 "Using (possibly slow) default implementation of " 569 "assert_positive_definite." 570 " Requires conversion to a dense matrix and O(N^3) operations.") 571 # If the operator is self-adjoint, then checking that 572 # Cholesky decomposition succeeds + results in positive diag is necessary 573 # and sufficient. 
574 if self.is_self_adjoint: 575 return check_ops.assert_positive( 576 array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())), 577 message="Matrix was not positive definite.") 578 # We have no generic check for positive definite. 579 raise NotImplementedError("assert_positive_definite is not implemented.") 580 581 def assert_positive_definite(self, name="assert_positive_definite"): 582 """Returns an `Op` that asserts this operator is positive definite. 583 584 Here, positive definite means that the quadratic form `x^H A x` has positive 585 real part for all nonzero `x`. Note that we do not require the operator to 586 be self-adjoint to be positive definite. 587 588 Args: 589 name: A name to give this `Op`. 590 591 Returns: 592 An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if 593 the operator is not positive definite. 594 """ 595 with self._name_scope(name): # pylint: disable=not-callable 596 return self._assert_positive_definite() 597 598 def _assert_self_adjoint(self): 599 dense = self.to_dense() 600 logging.warn( 601 "Using (possibly slow) default implementation of assert_self_adjoint." 602 " Requires conversion to a dense matrix.") 603 return check_ops.assert_equal( 604 dense, 605 linalg.adjoint(dense), 606 message="Matrix was not equal to its adjoint.") 607 608 def assert_self_adjoint(self, name="assert_self_adjoint"): 609 """Returns an `Op` that asserts this operator is self-adjoint. 610 611 Here we check that this operator is *exactly* equal to its hermitian 612 transpose. 613 614 Args: 615 name: A string name to prepend to created ops. 616 617 Returns: 618 An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if 619 the operator is not self-adjoint. 
620 """ 621 with self._name_scope(name): # pylint: disable=not-callable 622 return self._assert_self_adjoint() 623 624 def _check_input_dtype(self, arg): 625 """Check that arg.dtype == self.dtype.""" 626 if arg.dtype.base_dtype != self.dtype: 627 raise TypeError( 628 "Expected argument to have dtype %s. Found: %s in tensor %s" % 629 (self.dtype, arg.dtype, arg)) 630 631 @abc.abstractmethod 632 def _matmul(self, x, adjoint=False, adjoint_arg=False): 633 raise NotImplementedError("_matmul is not implemented.") 634 635 def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): 636 """Transform [batch] matrix `x` with left multiplication: `x --> Ax`. 637 638 ```python 639 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 640 operator = LinearOperator(...) 641 operator.shape = [..., M, N] 642 643 X = ... # shape [..., N, R], batch matrix, R > 0. 644 645 Y = operator.matmul(X) 646 Y.shape 647 ==> [..., M, R] 648 649 Y[..., :, r] = sum_j A[..., :, j] X[j, r] 650 ``` 651 652 Args: 653 x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as 654 `self`. See class docstring for definition of compatibility. 655 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 656 adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is 657 the hermitian transpose (transposition and complex conjugation). 658 name: A name for this `Op`. 659 660 Returns: 661 A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype` 662 as `self`. 663 """ 664 if isinstance(x, LinearOperator): 665 left_operator = self.adjoint() if adjoint else self 666 right_operator = x.adjoint() if adjoint_arg else x 667 668 if (right_operator.range_dimension is not None and 669 left_operator.domain_dimension is not None and 670 right_operator.range_dimension != left_operator.domain_dimension): 671 raise ValueError( 672 "Operators are incompatible. 
Expected `x` to have dimension" 673 " {} but got {}.".format( 674 left_operator.domain_dimension, right_operator.range_dimension)) 675 with self._name_scope(name): # pylint: disable=not-callable 676 return linear_operator_algebra.matmul(left_operator, right_operator) 677 678 with self._name_scope(name): # pylint: disable=not-callable 679 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 680 self._check_input_dtype(x) 681 682 self_dim = -2 if adjoint else -1 683 arg_dim = -1 if adjoint_arg else -2 684 tensor_shape.dimension_at_index( 685 self.shape, self_dim).assert_is_compatible_with( 686 x.shape[arg_dim]) 687 688 return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) 689 690 def __matmul__(self, other): 691 return self.matmul(other) 692 693 def _matvec(self, x, adjoint=False): 694 x_mat = array_ops.expand_dims(x, axis=-1) 695 y_mat = self.matmul(x_mat, adjoint=adjoint) 696 return array_ops.squeeze(y_mat, axis=-1) 697 698 def matvec(self, x, adjoint=False, name="matvec"): 699 """Transform [batch] vector `x` with left multiplication: `x --> Ax`. 700 701 ```python 702 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 703 operator = LinearOperator(...) 704 705 X = ... # shape [..., N], batch vector 706 707 Y = operator.matvec(X) 708 Y.shape 709 ==> [..., M] 710 711 Y[..., :] = sum_j A[..., :, j] X[..., j] 712 ``` 713 714 Args: 715 x: `Tensor` with compatible shape and same `dtype` as `self`. 716 `x` is treated as a [batch] vector meaning for every set of leading 717 dimensions, the last dimension defines a vector. 718 See class docstring for definition of compatibility. 719 adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`. 720 name: A name for this `Op`. 721 722 Returns: 723 A `Tensor` with shape `[..., M]` and same `dtype` as `self`. 
724 """ 725 with self._name_scope(name): # pylint: disable=not-callable 726 x = ops.convert_to_tensor_v2_with_dispatch(x, name="x") 727 self._check_input_dtype(x) 728 self_dim = -2 if adjoint else -1 729 tensor_shape.dimension_at_index( 730 self.shape, self_dim).assert_is_compatible_with(x.shape[-1]) 731 return self._matvec(x, adjoint=adjoint) 732 733 def _determinant(self): 734 logging.warn( 735 "Using (possibly slow) default implementation of determinant." 736 " Requires conversion to a dense matrix and O(N^3) operations.") 737 if self._can_use_cholesky(): 738 return math_ops.exp(self.log_abs_determinant()) 739 return linalg_ops.matrix_determinant(self.to_dense()) 740 741 def determinant(self, name="det"): 742 """Determinant for every batch member. 743 744 Args: 745 name: A name for this `Op`. 746 747 Returns: 748 `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. 749 750 Raises: 751 NotImplementedError: If `self.is_square` is `False`. 752 """ 753 if self.is_square is False: 754 raise NotImplementedError( 755 "Determinant not implemented for an operator that is expected to " 756 "not be square.") 757 with self._name_scope(name): # pylint: disable=not-callable 758 return self._determinant() 759 760 def _log_abs_determinant(self): 761 logging.warn( 762 "Using (possibly slow) default implementation of determinant." 763 " Requires conversion to a dense matrix and O(N^3) operations.") 764 if self._can_use_cholesky(): 765 diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())) 766 return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1]) 767 _, log_abs_det = linalg.slogdet(self.to_dense()) 768 return log_abs_det 769 770 def log_abs_determinant(self, name="log_abs_det"): 771 """Log absolute value of determinant for every batch member. 772 773 Args: 774 name: A name for this `Op`. 775 776 Returns: 777 `Tensor` with shape `self.batch_shape` and same `dtype` as `self`. 
778 779 Raises: 780 NotImplementedError: If `self.is_square` is `False`. 781 """ 782 if self.is_square is False: 783 raise NotImplementedError( 784 "Determinant not implemented for an operator that is expected to " 785 "not be square.") 786 with self._name_scope(name): # pylint: disable=not-callable 787 return self._log_abs_determinant() 788 789 def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False): 790 """Solve by conversion to a dense matrix.""" 791 if self.is_square is False: # pylint: disable=g-bool-id-comparison 792 raise NotImplementedError( 793 "Solve is not yet implemented for non-square operators.") 794 rhs = linalg.adjoint(rhs) if adjoint_arg else rhs 795 if self._can_use_cholesky(): 796 return linalg_ops.cholesky_solve( 797 linalg_ops.cholesky(self.to_dense()), rhs) 798 return linear_operator_util.matrix_solve_with_broadcast( 799 self.to_dense(), rhs, adjoint=adjoint) 800 801 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 802 """Default implementation of _solve.""" 803 logging.warn( 804 "Using (possibly slow) default implementation of solve." 805 " Requires conversion to a dense matrix and O(N^3) operations.") 806 return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) 807 808 def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"): 809 """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`. 810 811 The returned `Tensor` will be close to an exact solution if `A` is well 812 conditioned. Otherwise closeness will vary. See class docstring for details. 813 814 Examples: 815 816 ```python 817 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 818 operator = LinearOperator(...) 819 operator.shape = [..., M, N] 820 821 # Solve R > 0 linear systems for every member of the batch. 822 RHS = ... 
# shape [..., M, R] 823 824 X = operator.solve(RHS) 825 # X[..., :, r] is the solution to the r'th linear system 826 # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r] 827 828 operator.matmul(X) 829 ==> RHS 830 ``` 831 832 Args: 833 rhs: `Tensor` with same `dtype` as this operator and compatible shape. 834 `rhs` is treated like a [batch] matrix meaning for every set of leading 835 dimensions, the last two dimensions defines a matrix. 836 See class docstring for definition of compatibility. 837 adjoint: Python `bool`. If `True`, solve the system involving the adjoint 838 of this `LinearOperator`: `A^H X = rhs`. 839 adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H` 840 is the hermitian transpose (transposition and complex conjugation). 841 name: A name scope to use for ops added by this method. 842 843 Returns: 844 `Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`. 845 846 Raises: 847 NotImplementedError: If `self.is_non_singular` or `is_square` is False. 848 """ 849 if self.is_non_singular is False: 850 raise NotImplementedError( 851 "Exact solve not implemented for an operator that is expected to " 852 "be singular.") 853 if self.is_square is False: 854 raise NotImplementedError( 855 "Exact solve not implemented for an operator that is expected to " 856 "not be square.") 857 if isinstance(rhs, LinearOperator): 858 left_operator = self.adjoint() if adjoint else self 859 right_operator = rhs.adjoint() if adjoint_arg else rhs 860 861 if (right_operator.range_dimension is not None and 862 left_operator.domain_dimension is not None and 863 right_operator.range_dimension != left_operator.domain_dimension): 864 raise ValueError( 865 "Operators are incompatible. 
Expected `rhs` to have dimension" 866 " {} but got {}.".format( 867 left_operator.domain_dimension, right_operator.range_dimension)) 868 with self._name_scope(name): # pylint: disable=not-callable 869 return linear_operator_algebra.solve(left_operator, right_operator) 870 871 with self._name_scope(name): # pylint: disable=not-callable 872 rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs") 873 self._check_input_dtype(rhs) 874 875 self_dim = -1 if adjoint else -2 876 arg_dim = -1 if adjoint_arg else -2 877 tensor_shape.dimension_at_index( 878 self.shape, self_dim).assert_is_compatible_with( 879 rhs.shape[arg_dim]) 880 881 return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg) 882 883 def _solvevec(self, rhs, adjoint=False): 884 """Default implementation of _solvevec.""" 885 rhs_mat = array_ops.expand_dims(rhs, axis=-1) 886 solution_mat = self.solve(rhs_mat, adjoint=adjoint) 887 return array_ops.squeeze(solution_mat, axis=-1) 888 889 def solvevec(self, rhs, adjoint=False, name="solve"): 890 """Solve single equation with best effort: `A X = rhs`. 891 892 The returned `Tensor` will be close to an exact solution if `A` is well 893 conditioned. Otherwise closeness will vary. See class docstring for details. 894 895 Examples: 896 897 ```python 898 # Make an operator acting like batch matrix A. Assume A.shape = [..., M, N] 899 operator = LinearOperator(...) 900 operator.shape = [..., M, N] 901 902 # Solve one linear system for every member of the batch. 903 RHS = ... # shape [..., M] 904 905 X = operator.solvevec(RHS) 906 # X is the solution to the linear system 907 # sum_j A[..., :, j] X[..., j] = RHS[..., :] 908 909 operator.matvec(X) 910 ==> RHS 911 ``` 912 913 Args: 914 rhs: `Tensor` with same `dtype` as this operator. 915 `rhs` is treated like a [batch] vector meaning for every set of leading 916 dimensions, the last dimension defines a vector. See class docstring 917 for definition of compatibility regarding batch dimensions. 
def adjoint(self, name="adjoint"):
  """Returns the adjoint of the current `LinearOperator`.

  With `A` denoting this operator, the result represents `A*`.
  `self.adjoint()` and `self.H` are equivalent.

  Args:
    name: A name for this `Op`.

  Returns:
    `LinearOperator` which represents the adjoint of this `LinearOperator`.
    When `self.is_self_adjoint` is exactly `True`, `self` is returned as-is.
  """
  # Only an explicit `True` hint takes the shortcut; every other value
  # (including `None` and `False`) goes through `linear_operator_algebra`.
  shortcut = self.is_self_adjoint is True  # pylint: disable=g-bool-id-comparison
  if not shortcut:
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.adjoint(self)
  return self
def cholesky(self, name="cholesky"):
  """Returns a Cholesky factor as a `LinearOperator`.

  With `A` denoting this operator, and `A` positive definite and
  self-adjoint, the result is `L` such that `A = L L^T`, i.e. the Cholesky
  decomposition.

  Args:
    name: A name for this `Op`.

  Returns:
    `LinearOperator` which represents the lower triangular matrix
    in the Cholesky decomposition.

  Raises:
    ValueError: When `self._can_use_cholesky()` is falsy, i.e. the
      `LinearOperator` is not hinted to be positive definite and self
      adjoint.
  """
  if self._can_use_cholesky():
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.cholesky(self)
  raise ValueError("Cannot take the Cholesky decomposition: "
                   "Not a positive definite self adjoint matrix.")
def diag_part(self, name="diag_part"):
  """Get the [batch] diagonal part of this operator.

  For an operator of shape `[B1,...,Bb, M, N]` the result is a `Tensor`
  `diagonal` of shape `[B1,...,Bb, min(M, N)]` with
  `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.

  The work is done by `self._diag_part()`, so how efficient this is depends
  on whether the subclass overrides it.

  ```
  my_operator = LinearOperatorDiag([1., 2.])

  # Efficiently get the diagonal
  my_operator.diag_part()
  ==> [1., 2.]

  # Equivalent, but inefficient method
  tf.linalg.diag_part(my_operator.to_dense())
  ==> [1., 2.]
  ```

  Args:
    name: A name for this `Op`.

  Returns:
    diag_part: A `Tensor` of same `dtype` as self.
  """
  with self._name_scope(name):  # pylint: disable=not-callable
    diagonal = self._diag_part()
  return diagonal
def eigvals(self, name="eigvals"):
  """Returns the eigenvalues of this linear operator.

  Note: This currently only supports self-adjoint operators, i.e. operators
  whose `is_self_adjoint` hint is truthy.

  Args:
    name: A name for this `Op`.

  Returns:
    Shape `[B1,...,Bb, N]` `Tensor` of same `dtype` as `self`.

  Raises:
    NotImplementedError: If `self.is_self_adjoint` is falsy (`False` or
      `None`).
  """
  if self.is_self_adjoint:
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._eigvals()
  raise NotImplementedError("Only self-adjoint matrices are supported.")
1126 vals = math_ops.abs(self._eigvals()) 1127 1128 return (math_ops.reduce_max(vals, axis=-1) / 1129 math_ops.reduce_min(vals, axis=-1)) 1130 1131 def cond(self, name="cond"): 1132 """Returns the condition number of this linear operator. 1133 1134 Args: 1135 name: A name for this `Op`. 1136 1137 Returns: 1138 Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`. 1139 """ 1140 with self._name_scope(name): # pylint: disable=not-callable 1141 return self._cond() 1142 1143 def _can_use_cholesky(self): 1144 return self.is_self_adjoint and self.is_positive_definite 1145 1146 def _set_graph_parents(self, graph_parents): 1147 """Set self._graph_parents. Called during derived class init. 1148 1149 This method allows derived classes to set graph_parents, without triggering 1150 a deprecation warning (which is invoked if `graph_parents` is passed during 1151 `__init__`. 1152 1153 Args: 1154 graph_parents: Iterable over Tensors. 1155 """ 1156 # TODO(b/143910018) Remove this function in V3. 1157 graph_parents = [] if graph_parents is None else graph_parents 1158 for i, t in enumerate(graph_parents): 1159 if t is None or not (linear_operator_util.is_ref(t) or 1160 tensor_util.is_tf_type(t)): 1161 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 1162 self._graph_parents = graph_parents 1163 1164 @property 1165 def _composite_tensor_fields(self): 1166 """A tuple of parameter names to rebuild the `LinearOperator`. 1167 1168 The tuple contains the names of kwargs to the `LinearOperator`'s constructor 1169 that the `TypeSpec` needs to rebuild the `LinearOperator` instance. 1170 1171 "is_non_singular", "is_self_adjoint", "is_positive_definite", and 1172 "is_square" are common to all `LinearOperator` subclasses and may be 1173 omitted. 1174 """ 1175 return () 1176 1177 @property 1178 def _composite_tensor_prefer_static_fields(self): 1179 """A tuple of names referring to parameters that may be treated statically. 
def _convert_variables_to_tensors(self):
  """Recursively converts ResourceVariables in the LinearOperator to Tensors.

  The operator is taken apart with `self._type_spec._to_components`, the
  components are passed through
  `variable_utils.convert_variables_to_tensors`, and the operator is rebuilt
  with `self._type_spec._from_components`.

  Calling `_from_components` this way violates the contract of
  `CompositeTensor`: it receives a nested structure containing only
  `Tensor`s, whereas `self.type_spec` describes one that may contain
  `ResourceVariable`s. The usage is nevertheless valid, because
  `LinearOperator`'s `_from_components` simply forwards the structure to
  `__init__`, and any `LinearOperator` that can be built from
  `ResourceVariables` can also be built from `Tensor`s.

  Returns:
    tensor_operator: `self` with all internal Variables converted to Tensors.
  """
  # pylint: disable=protected-access
  flat_parts = self._type_spec._to_components(self)
  tensor_parts = variable_utils.convert_variables_to_tensors(flat_parts)
  rebuilt = self._type_spec._from_components(tensor_parts)
  # pylint: enable=protected-access
  return rebuilt
class _LinearOperatorSpec(type_spec.BatchableTypeSpec):
  """A tf.TypeSpec for `LinearOperator` objects.

  The spec splits the constructor kwargs of an operator into two groups:
  `Tensor`-like / `CompositeTensor` parameters, stored as `TypeSpec`s in
  `_param_specs`, and everything else, stored verbatim in
  `_non_tensor_params`. Subclasses are expected to provide a `value_type`
  class attribute naming the operator class to rebuild.
  """

  __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields")

  def __init__(self, param_specs, non_tensor_params, prefer_static_fields):
    """Initializes a new `_LinearOperatorSpec`.

    Args:
      param_specs: Python `dict` of `tf.TypeSpec` instances that describe
        kwargs to the `LinearOperator`'s constructor that are `Tensor`-like or
        `CompositeTensor` subclasses.
      non_tensor_params: Python `dict` containing non-`Tensor` and non-
        `CompositeTensor` kwargs to the `LinearOperator`'s constructor.
      prefer_static_fields: Python `tuple` of strings corresponding to the names
        of `Tensor`-like args to the `LinearOperator`s constructor that may be
        stored as static values, if known. These are typically shapes, indices,
        or axis values.
    """
    self._param_specs = param_specs
    self._non_tensor_params = non_tensor_params
    self._prefer_static_fields = prefer_static_fields

  @classmethod
  def from_operator(cls, operator):
    """Builds a `_LinearOperatorSpec` from a `LinearOperator` instance.

    Args:
      operator: An instance of `LinearOperator`.

    Returns:
      linear_operator_spec: An instance of `_LinearOperatorSpec` to be used as
        the `TypeSpec` of `operator`.

    Raises:
      NotImplementedError: If a single field holds both `Tensor`-like and
        non-`Tensor` values.
    """
    validation_fields = ("is_non_singular", "is_self_adjoint",
                         "is_positive_definite", "is_square")
    kwargs = _extract_attrs(
        operator,
        keys=set(operator._composite_tensor_fields + validation_fields))  # pylint: disable=protected-access

    non_tensor_params = {}
    param_specs = {}
    for k, v in kwargs.items():
      type_spec_or_v = _extract_type_spec_recursively(v)
      # A field is "tensor-like" only if every leaf was turned into a TypeSpec.
      is_tensor = [isinstance(x, type_spec.TypeSpec)
                   for x in nest.flatten(type_spec_or_v)]
      if all(is_tensor):
        param_specs[k] = type_spec_or_v
      elif not any(is_tensor):
        non_tensor_params[k] = v
      else:
        raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and "
                                  f" non-`Tensor` values.")

    return cls(
        param_specs=param_specs,
        non_tensor_params=non_tensor_params,
        prefer_static_fields=operator._composite_tensor_prefer_static_fields)  # pylint: disable=protected-access

  def _to_components(self, obj):
    """Extracts from `obj` the kwargs named by the keys of `_param_specs`."""
    return _extract_attrs(obj, keys=list(self._param_specs))

  def _from_components(self, components):
    """Rebuilds a `value_type` from `components` plus the non-tensor params."""
    # Entries in `components` take precedence over `_non_tensor_params`.
    kwargs = dict(self._non_tensor_params, **components)
    return self.value_type(**kwargs)

  @property
  def _component_specs(self):
    """The `dict` of parameter `TypeSpec`s."""
    return self._param_specs

  def _serialize(self):
    """Returns the three constructor arguments of this spec as a tuple."""
    return (self._param_specs,
            self._non_tensor_params,
            self._prefer_static_fields)

  def _copy(self, **overrides):
    """Returns a new spec of the same type with `overrides` applied.

    Args:
      **overrides: Replacements for any of `param_specs`, `non_tensor_params`
        and `prefer_static_fields`.
    """
    kwargs = {
        "param_specs": self._param_specs,
        "non_tensor_params": self._non_tensor_params,
        "prefer_static_fields": self._prefer_static_fields
    }
    kwargs.update(overrides)
    return type(self)(**kwargs)

  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec."""
    return self._copy(
        param_specs=nest.map_structure(
            lambda spec: spec._batch(batch_size),  # pylint: disable=protected-access
            self._param_specs))

  def _unbatch(self, batch_size=None):
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Args:
      batch_size: Unused. Kept, with a default, for backward compatibility
        with callers that pass it. The default is required because nested
        specs are unbatched below with a zero-argument `spec._unbatch()`
        call, and a nested spec may itself be a `_LinearOperatorSpec`.
    """
    del batch_size  # Unused.
    return self._copy(
        param_specs=nest.map_structure(
            lambda spec: spec._unbatch(),  # pylint: disable=protected-access
            self._param_specs))
def _extract_type_spec_recursively(value):
  """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s.

  If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
  `value` is a collection containing `Tensor` values, recursively supplant them
  with their respective `TypeSpec`s in a collection of parallel structure.

  If `value` is none of the above, return it unchanged.

  Args:
    value: a Python `object` to (possibly) turn into a (collection of)
    `tf.TypeSpec`(s).

  Returns:
    spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
    or `value`, if no `Tensor`s are found.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  if isinstance(value, variables.Variable):
    return resource_variable_ops.VariableSpec(
        value.shape, dtype=value.dtype, trainable=value.trainable)
  if tensor_util.is_tensor(value):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  # Unwrap trackable data structures to comply with `Type_Spec._serialize`
  # requirements. `ListWrapper`s are converted to `list`s, and for other
  # trackable data structures, the `__wrapped__` attribute is used.
  # The `list` check must stay ahead of the `TrackableDataStructure` check so
  # that list-like wrappers come back as plain `list`s.
  if isinstance(value, list):
    return [_extract_type_spec_recursively(v) for v in value]
  if isinstance(value, data_structures.TrackableDataStructure):
    return _extract_type_spec_recursively(value.__wrapped__)
  if isinstance(value, tuple):
    converted = (_extract_type_spec_recursively(x) for x in value)
    if hasattr(value, "_fields"):
      # Namedtuple constructors take one positional argument per field, not a
      # single iterable, so the converted items have to be unpacked.
      return type(value)(*converted)
    return type(value)(converted)
  if isinstance(value, dict):
    return type(value)((k, _extract_type_spec_recursively(v))
                       for k, v in value.items())
  return value
@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    output_type=None,  # pylint: disable=unused-argument
    name=None):
  """`math_ops.matmul` override used when an argument is a `LinearOperator`.

  Args:
    a: Left factor. May or may not be a `LinearOperator`.
    b: Right factor. When `a` is not a `LinearOperator`, `b.matmul` is called.
    transpose_a: Must be falsy; transposition is not supported.
    transpose_b: Must be falsy; transposition is not supported.
    adjoint_a: Whether to use the adjoint of `a`.
    adjoint_b: Whether to use the adjoint of `b`.
    a_is_sparse: Must be falsy; sparse methods are not supported.
    b_is_sparse: Must be falsy; sparse methods are not supported.
    output_type: Unused; present only to mirror the overridden signature.
    name: Name forwarded to the operator's `matmul`.

  Returns:
    The result of `a.matmul(b, ...)` when `a` is a `LinearOperator`;
    otherwise `linalg.adjoint` applied to the result of `b.matmul(a, ...)`.

  Raises:
    ValueError: If a transpose or sparse flag is set.
  """
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")

  if isinstance(a, LinearOperator):
    return a.matmul(
        b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)

  # `a` is not an operator, so drive the product from `b` instead using the
  # identity (B^H A^H)^H = A B: flip both adjoint flags, multiply, and take
  # the adjoint of the product.
  flipped_product = b.matmul(
      a,
      adjoint=(not adjoint_b),
      adjoint_arg=(not adjoint_a),
      name=name)
  return linalg.adjoint(flipped_product)
@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
  """`linalg.trace` override that forwards to `LinearOperator.trace`.

  Args:
    x: The `LinearOperator` whose trace is requested.
    name: Name passed positionally to `x.trace`.

  Returns:
    Whatever `x.trace(name)` returns.
  """
  traced = x.trace(name)
  return traced