# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for linear operators."""

import abc
import contextlib

import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.ops.linalg import slicing
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export

__all__ = ["LinearOperator"]


# TODO(langmore) Use matrix_solve_ls for singular or non-square matrices.
@tf_export("linalg.LinearOperator")
class LinearOperator(
    module.Module, composite_tensor.CompositeTensor, metaclass=abc.ABCMeta):
  """Base class defining a [batch of] linear operator[s].

  Subclasses of `LinearOperator` provide access to common methods on a
  (batch) matrix, without the need to materialize the matrix.  This allows:

  * Matrix free computations
  * Operators that take advantage of special structure, while providing a
    consistent API to users.

  #### Subclassing

  To enable a public method, subclasses should implement the leading-underscore
  version of the method.  The argument signature should be identical except for
  the omission of `name="..."`.  For example, to enable
  `matmul(x, adjoint=False, name="matmul")` a subclass should implement
  `_matmul(x, adjoint=False)`.

  #### Performance contract

  Subclasses should only implement the assert methods
  (e.g. `assert_non_singular`) if they can be done in less than `O(N^3)`
  time.

  Class docstrings should contain an explanation of computational complexity.
  Since this is a high-performance library, attention should be paid to detail,
  and explanations can include constants as well as Big-O notation.

  #### Shape compatibility

  `LinearOperator` subclasses should operate on a [batch] matrix with
  compatible shape.  Class docstrings should define what is meant by compatible
  shape.  Some subclasses may not support batching.

  Examples:

  `x` is a batch matrix with compatible shape for `matmul` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  x.shape = [B1,...,Bb] + [N, R]
  ```

  `rhs` is a batch matrix with compatible shape for `solve` if

  ```
  operator.shape = [B1,...,Bb] + [M, N],  b >= 0,
  rhs.shape = [B1,...,Bb] + [M, R]
  ```

  #### Example docstring for subclasses.

  This operator acts like a (batch) matrix `A` with shape
  `[B1,...,Bb, M, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `M x N` matrix.  Again, this matrix `A` may not be materialized, but for
  purposes of identifying and working with compatible arguments the shape is
  relevant.

  Examples:

  ```python
  some_tensor = ... shape = ????
  operator = MyLinOp(some_tensor)

  operator.shape()
  ==> [2, 4, 4]

  operator.log_abs_determinant()
  ==> Shape [2] Tensor

  x = ... Shape [2, 4, 5] Tensor

  operator.matmul(x)
  ==> Shape [2, 4, 5] Tensor
  ```

  #### Shape compatibility

  This operator acts on batch matrices with compatible shape.
  FILL IN WHAT IS MEANT BY COMPATIBLE SHAPE

  #### Performance

  FILL THIS IN

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
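
  For example (a sketch, assuming the concrete subclass
  `tf.linalg.LinearOperatorFullMatrix`; the hints are promises made by the
  caller, not properties verified at runtime):

  ```python
  operator = tf.linalg.LinearOperatorFullMatrix(
      [[2., 1.], [1., 2.]],
      is_self_adjoint=True,
      is_positive_definite=True)
  # Methods such as `cholesky` may now rely on these hints.
  ```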

  #### Initialization parameters

  All subclasses of `LinearOperator` are expected to pass a `parameters`
  argument to `super().__init__()`.  This should be a `dict` containing
  the unadulterated arguments passed to the subclass `__init__`.  For example,
  `MyLinearOperator` with an initializer should look like:

  ```python
  def __init__(self, operator, is_square=False, name=None):
     parameters = dict(
         operator=operator,
         is_square=is_square,
         name=name
     )
     ...
     super().__init__(..., parameters=parameters)
  ```

  Users can then access `my_linear_operator.parameters` to see all arguments
  passed to its initializer.
  """

  # TODO(b/143910018) Remove graph_parents in V3.
  @deprecation.deprecated_args(None, "Do not pass `graph_parents`.  They will "
                               "no longer be used.", "graph_parents")
  def __init__(self,
               dtype,
               graph_parents=None,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None,
               parameters=None):
    """Initialize the `LinearOperator`.

    **This is a private method for subclass use.**
    **Subclasses should copy-paste this `__init__` documentation.**

    Args:
      dtype: The type of this `LinearOperator`.  Arguments to `matmul` and
        `solve` will have to be this type.
      graph_parents: (Deprecated) Python list of graph prerequisites of this
        `LinearOperator`.  Typically tensors that are passed during
        initialization.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `dtype` is real, this is equivalent to being symmetric.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.
      parameters: Python `dict` of parameters used to instantiate this
        `LinearOperator`.

    Raises:
      ValueError:  If any member of graph_parents is `None` or not a `Tensor`.
      ValueError:  If hints are set incorrectly.
    """
    # Check and auto-set flags.
    if is_positive_definite:
      if is_non_singular is False:
        raise ValueError("A positive definite matrix is always non-singular.")
      is_non_singular = True

    if is_non_singular:
      if is_square is False:
        raise ValueError("A non-singular matrix is always square.")
      is_square = True

    if is_self_adjoint:
      if is_square is False:
        raise ValueError("A self-adjoint matrix is always square.")
      is_square = True

    self._is_square_set_or_implied_by_hints = is_square

    if graph_parents is not None:
      self._set_graph_parents(graph_parents)
    else:
      self._graph_parents = []
    self._dtype = dtypes.as_dtype(dtype).base_dtype if dtype else dtype
    self._is_non_singular = is_non_singular
    self._is_self_adjoint = is_self_adjoint
    self._is_positive_definite = is_positive_definite
    self._parameters = self._no_dependency(parameters)
    self._parameters_sanitized = False
    self._name = name or type(self).__name__

  @contextlib.contextmanager
  def _name_scope(self, name=None):  # pylint: disable=method-hidden
    """Helper function to standardize op scope."""
    full_name = self.name
    if name is not None:
      full_name += "/" + name
    with ops.name_scope(full_name) as scope:
      yield scope

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `LinearOperator`."""
    return dict(self._parameters)

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `LinearOperator`."""
    return self._dtype

  @property
  def name(self):
    """Name prepended to all ops created by this `LinearOperator`."""
    return self._name

  @property
  @deprecation.deprecated(None, "Do not call `graph_parents`.")
  def graph_parents(self):
    """List of graph dependencies of this `LinearOperator`."""
    return self._graph_parents

  @property
  def is_non_singular(self):
    return self._is_non_singular

  @property
  def is_self_adjoint(self):
    return self._is_self_adjoint

  @property
  def is_positive_definite(self):
    return self._is_positive_definite

  @property
  def is_square(self):
    """Return `True/False` depending on whether this operator is square."""
    # Static checks done after __init__.  Why?  Because domain/range dimension
    # sometimes requires lots of work done in the derived class after init.
    auto_square_check = self.domain_dimension == self.range_dimension
    if self._is_square_set_or_implied_by_hints is False and auto_square_check:
      raise ValueError(
          "User set is_square hint to False, but the operator was square.")
    if self._is_square_set_or_implied_by_hints is None:
      return auto_square_check

    return self._is_square_set_or_implied_by_hints

  @abc.abstractmethod
  def _shape(self):
    # Write this in derived class to enable all static shape methods.
    raise NotImplementedError("_shape is not implemented.")

  @property
  def shape(self):
    """`TensorShape` of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb, M, N])`, equivalent to `A.shape`.
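
    For example (a sketch, using `LinearOperatorDiag` as in other examples in
    this class):

    ```python
    operator = LinearOperatorDiag([1., 2.])
    operator.shape
    ==> TensorShape([2, 2])
    ```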

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    return self._shape()

  def _shape_tensor(self):
    # This is not an abstractmethod, since we want derived classes to be able to
    # override this with optional kwargs, which can reduce the number of
    # `convert_to_tensor` calls.  See derived classes for examples.
    raise NotImplementedError("_shape_tensor is not implemented.")

  def shape_tensor(self, name="shape_tensor"):
    """Shape of this `LinearOperator`, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb, M, N]`, equivalent to `tf.shape(A)`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      # Prefer to use statically defined shape if available.
      if self.shape.is_fully_defined():
        return linear_operator_util.shape_tensor(self.shape.as_list())
      else:
        return self._shape_tensor()

  @property
  def batch_shape(self):
    """`TensorShape` of batch dimensions of this `LinearOperator`.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns
    `TensorShape([B1,...,Bb])`, equivalent to `A.shape[:-2]`.

    Returns:
      `TensorShape`, statically determined, may be undefined.
    """
    # Derived classes get this "for free" once .shape is implemented.
    return self.shape[:-2]

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of batch dimensions of this operator, determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns a `Tensor` holding
    `[B1,...,Bb]`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._batch_shape_tensor()

  def _batch_shape_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.batch_shape.is_fully_defined():
      return linear_operator_util.shape_tensor(
          self.batch_shape.as_list(), name="batch_shape")
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[:-2]

  @property
  def tensor_rank(self, name="tensor_rank"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name:  A name for this `Op`.

    Returns:
      Python integer, or None if the tensor rank is undefined.
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self.shape.ndims

  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
    """Rank (in the sense of tensors) of matrix corresponding to this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `b + 2`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`, determined at runtime.
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._tensor_rank_tensor()

  def _tensor_rank_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    if self.tensor_rank is not None:
      return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return array_ops.size(shape)

  @property
  def domain_dimension(self):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.rank is None:
      return tensor_shape.Dimension(None)
    else:
      return self.shape.dims[-1]

  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the domain of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `N`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._domain_dimension_tensor()

  def _domain_dimension_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-1]

  @property
  def range_dimension(self):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Returns:
      `Dimension` object.
    """
    # Derived classes get this "for free" once .shape is implemented.
    if self.shape.dims:
      return self.shape.dims[-2]
    else:
      return tensor_shape.Dimension(None)

  def range_dimension_tensor(self, name="range_dimension_tensor"):
    """Dimension (in the sense of vector spaces) of the range of this operator.

    Determined at runtime.

    If this operator acts like the batch matrix `A` with
    `A.shape = [B1,...,Bb, M, N]`, then this returns `M`.

    Args:
      name:  A name for this `Op`.

    Returns:
      `int32` `Tensor`
    """
    # Derived classes get this "for free" once .shape() is implemented.
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._range_dimension_tensor()

  def _range_dimension_tensor(self, shape=None):
    # `shape` may be passed in if this can be pre-computed in a
    # more efficient manner, e.g. without excessive Tensor conversions.
    dim_value = tensor_shape.dimension_value(self.range_dimension)
    if dim_value is not None:
      return ops.convert_to_tensor_v2_with_dispatch(dim_value)
    else:
      shape = self.shape_tensor() if shape is None else shape
      return shape[-2]

  def _assert_non_singular(self):
    """Private default implementation of _assert_non_singular."""
    logging.warn(
        "Using (possibly slow) default implementation of assert_non_singular."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return self.assert_positive_definite()
    else:
      singular_values = linalg_ops.svd(self.to_dense(), compute_uv=False)
      # TODO(langmore) Add .eig and .cond as methods.
      cond = (math_ops.reduce_max(singular_values, axis=-1) /
              math_ops.reduce_min(singular_values, axis=-1))
      return check_ops.assert_less(
          cond,
          self._max_condition_number_to_be_non_singular(),
          message="Singular matrix up to precision epsilon.")

  def _max_condition_number_to_be_non_singular(self):
    """Return the maximum condition number that we consider nonsingular."""
    with ops.name_scope("max_nonsingular_condition_number"):
      dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
      eps = math_ops.cast(
          math_ops.reduce_max([
              100.,
              math_ops.cast(self.range_dimension_tensor(), self.dtype),
              math_ops.cast(self.domain_dimension_tensor(), self.dtype)
          ]), self.dtype) * dtype_eps
      return 1. / eps

  def assert_non_singular(self, name="assert_non_singular"):
    """Returns an `Op` that asserts this operator is non-singular.

    This operator is considered non-singular if

    ```
    ConditionNumber < 1 / (max{100, range_dimension, domain_dimension} * eps),
    eps := np.finfo(self.dtype.as_numpy_dtype).eps
    ```

    Args:
      name:  A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is singular.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_non_singular()

  def _assert_positive_definite(self):
    """Default implementation of _assert_positive_definite."""
    logging.warn(
        "Using (possibly slow) default implementation of "
        "assert_positive_definite."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    # If the operator is self-adjoint, then checking that
    # Cholesky decomposition succeeds + results in positive diag is necessary
    # and sufficient.
    if self.is_self_adjoint:
      return check_ops.assert_positive(
          array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense())),
          message="Matrix was not positive definite.")
    # We have no generic check for positive definite.
    raise NotImplementedError("assert_positive_definite is not implemented.")

  def assert_positive_definite(self, name="assert_positive_definite"):
    """Returns an `Op` that asserts this operator is positive definite.

    Here, positive definite means that the quadratic form `x^H A x` has positive
    real part for all nonzero `x`.  Note that we do not require the operator to
    be self-adjoint to be positive definite.

    Args:
      name:  A name to give this `Op`.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is not positive definite.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_positive_definite()

  def _assert_self_adjoint(self):
    dense = self.to_dense()
    logging.warn(
        "Using (possibly slow) default implementation of assert_self_adjoint."
        "  Requires conversion to a dense matrix.")
    return check_ops.assert_equal(
        dense,
        linalg.adjoint(dense),
        message="Matrix was not equal to its adjoint.")

  def assert_self_adjoint(self, name="assert_self_adjoint"):
    """Returns an `Op` that asserts this operator is self-adjoint.

    Here we check that this operator is *exactly* equal to its hermitian
    transpose.

    Args:
      name:  A string name to prepend to created ops.

    Returns:
      An `Assert` `Op`, that, when run, will raise an `InvalidArgumentError` if
        the operator is not self-adjoint.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._assert_self_adjoint()

  def _check_input_dtype(self, arg):
    """Check that arg.dtype == self.dtype."""
    if arg.dtype.base_dtype != self.dtype:
      raise TypeError(
          "Expected argument to have dtype %s.  Found: %s in tensor %s" %
          (self.dtype, arg.dtype, arg))

  @abc.abstractmethod
  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    raise NotImplementedError("_matmul is not implemented.")

  def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
    """Transform [batch] matrix `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    X = ... # shape [..., N, R], batch matrix, R > 0.

    Y = operator.matmul(X)
    Y.shape
    ==> [..., M, R]

    Y[..., :, r] = sum_j A[..., :, j] X[..., j, r]
    ```

    Args:
      x: `LinearOperator` or `Tensor` with compatible shape and same `dtype` as
        `self`. See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      adjoint_arg:  Python `bool`.  If `True`, compute `A x^H` where `x^H` is
        the hermitian transpose (transposition and complex conjugation).
      name:  A name for this `Op`.

    Returns:
      A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
        as `self`.
    """
    if isinstance(x, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = x.adjoint() if adjoint_arg else x

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `x` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.matmul(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)

      self_dim = -2 if adjoint else -1
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              x.shape[arg_dim])

      return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def __matmul__(self, other):
    return self.matmul(other)

  def _matvec(self, x, adjoint=False):
    x_mat = array_ops.expand_dims(x, axis=-1)
    y_mat = self.matmul(x_mat, adjoint=adjoint)
    return array_ops.squeeze(y_mat, axis=-1)

  def matvec(self, x, adjoint=False, name="matvec"):
    """Transform [batch] vector `x` with left multiplication:  `x --> Ax`.

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)

    X = ... # shape [..., N], batch vector

    Y = operator.matvec(X)
    Y.shape
    ==> [..., M]

    Y[..., :] = sum_j A[..., :, j] X[..., j]
    ```

    Args:
      x: `Tensor` with compatible shape and same `dtype` as `self`.
        `x` is treated as a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, left multiply by the adjoint: `A^H x`.
      name:  A name for this `Op`.

    Returns:
      A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      self_dim = -2 if adjoint else -1
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(x.shape[-1])
      return self._matvec(x, adjoint=adjoint)

  def _determinant(self):
    logging.warn(
        "Using (possibly slow) default implementation of determinant."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      return math_ops.exp(self.log_abs_determinant())
    return linalg_ops.matrix_determinant(self.to_dense())

  def determinant(self, name="det"):
    """Determinant for every batch member.

    Args:
      name:  A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError:  If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._determinant()

  def _log_abs_determinant(self):
    logging.warn(
        "Using (possibly slow) default implementation of log_abs_determinant."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    if self._can_use_cholesky():
      diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
      return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
    _, log_abs_det = linalg.slogdet(self.to_dense())
    return log_abs_det

  def log_abs_determinant(self, name="log_abs_det"):
    """Log absolute value of determinant for every batch member.

    Args:
      name:  A name for this `Op`.

    Returns:
      `Tensor` with shape `self.batch_shape` and same `dtype` as `self`.

    Raises:
      NotImplementedError:  If `self.is_square` is `False`.
    """
    if self.is_square is False:
      raise NotImplementedError(
          "Determinant not implemented for an operator that is expected to "
          "not be square.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._log_abs_determinant()

  def _dense_solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve by conversion to a dense matrix."""
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise NotImplementedError(
          "Solve is not yet implemented for non-square operators.")
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    if self._can_use_cholesky():
      return linalg_ops.cholesky_solve(
          linalg_ops.cholesky(self.to_dense()), rhs)
    return linear_operator_util.matrix_solve_with_broadcast(
        self.to_dense(), rhs, adjoint=adjoint)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Default implementation of _solve."""
    logging.warn(
        "Using (possibly slow) default implementation of solve."
        "  Requires conversion to a dense matrix and O(N^3) operations.")
    return self._dense_solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
    """Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve R > 0 linear systems for every member of the batch.
    RHS = ... # shape [..., M, R]

    X = operator.solve(RHS)
    # X[..., :, r] is the solution to the r'th linear system
    # sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]

    operator.matmul(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator and compatible shape.
        `rhs` is treated like a [batch] matrix meaning for every set of leading
        dimensions, the last two dimensions define a matrix.
        See class docstring for definition of compatibility.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      adjoint_arg:  Python `bool`.  If `True`, solve `A X = rhs^H` where `rhs^H`
        is the hermitian transpose (transposition and complex conjugation).
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N, R]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    if self.is_non_singular is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "be singular.")
    if self.is_square is False:
      raise NotImplementedError(
          "Exact solve not implemented for an operator that is expected to "
          "not be square.")
    if isinstance(rhs, LinearOperator):
      left_operator = self.adjoint() if adjoint else self
      right_operator = rhs.adjoint() if adjoint_arg else rhs

      if (right_operator.range_dimension is not None and
          left_operator.domain_dimension is not None and
          right_operator.range_dimension != left_operator.domain_dimension):
        raise ValueError(
            "Operators are incompatible. Expected `rhs` to have dimension"
            " {} but got {}.".format(
                left_operator.domain_dimension, right_operator.range_dimension))
      with self._name_scope(name):  # pylint: disable=not-callable
        return linear_operator_algebra.solve(left_operator, right_operator)

    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)

      self_dim = -1 if adjoint else -2
      arg_dim = -1 if adjoint_arg else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(
              rhs.shape[arg_dim])

      return self._solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)

  def _solvevec(self, rhs, adjoint=False):
    """Default implementation of _solvevec."""
    rhs_mat = array_ops.expand_dims(rhs, axis=-1)
    solution_mat = self.solve(rhs_mat, adjoint=adjoint)
    return array_ops.squeeze(solution_mat, axis=-1)

  def solvevec(self, rhs, adjoint=False, name="solve"):
    """Solve a single equation with best effort: `A X = rhs`.

    The returned `Tensor` will be close to an exact solution if `A` is well
    conditioned. Otherwise closeness will vary. See class docstring for details.

    Examples:

    ```python
    # Make an operator acting like batch matrix A.  Assume A.shape = [..., M, N]
    operator = LinearOperator(...)
    operator.shape = [..., M, N]

    # Solve one linear system for every member of the batch.
    RHS = ... # shape [..., M]

    X = operator.solvevec(RHS)
    # X is the solution to the linear system
    # sum_j A[..., :, j] X[..., j] = RHS[..., :]

    operator.matvec(X)
    ==> RHS
    ```

    Args:
      rhs: `Tensor` with same `dtype` as this operator.
        `rhs` is treated like a [batch] vector meaning for every set of leading
        dimensions, the last dimension defines a vector.  See class docstring
        for definition of compatibility regarding batch dimensions.
      adjoint: Python `bool`.  If `True`, solve the system involving the adjoint
        of this `LinearOperator`:  `A^H X = rhs`.
      name:  A name scope to use for ops added by this method.

    Returns:
      `Tensor` with shape `[..., N]` and same `dtype` as `rhs`.

    Raises:
      NotImplementedError:  If `self.is_non_singular` or `is_square` is False.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
      self._check_input_dtype(rhs)
      self_dim = -1 if adjoint else -2
      tensor_shape.dimension_at_index(
          self.shape, self_dim).assert_is_compatible_with(rhs.shape[-1])

      return self._solvevec(rhs, adjoint=adjoint)

  def adjoint(self, name="adjoint"):
    """Returns the adjoint of the current `LinearOperator`.

    Given `A` representing this `LinearOperator`, return `A*`.
    Note that calling `self.adjoint()` and `self.H` are equivalent.
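
    For example (a sketch, assuming the concrete subclass
    `tf.linalg.LinearOperatorFullMatrix`):

    ```python
    operator = tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
    operator.adjoint().to_dense()
    ==> [[1., 3.],
         [2., 4.]]
    ```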

    Args:
      name:  A name for this `Op`.

    Returns:
      `LinearOperator` which represents the adjoint of this `LinearOperator`.
    """
    if self.is_self_adjoint is True:  # pylint: disable=g-bool-id-comparison
      return self
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.adjoint(self)

  # self.H is equivalent to self.adjoint().
  H = property(adjoint, None)

  def inverse(self, name="inverse"):
    """Returns the inverse of this `LinearOperator`.

    Given `A` representing this `LinearOperator`, return a `LinearOperator`
    representing `A^-1`.
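
    For example (a sketch, assuming the concrete subclass
    `tf.linalg.LinearOperatorFullMatrix`):

    ```python
    operator = tf.linalg.LinearOperatorFullMatrix(
        [[2., 0.], [0., 4.]], is_non_singular=True)
    inverse = operator.inverse()  # Lazy: no matrix inversion happens here.
    inverse.to_dense()
    ==> [[0.5, 0.], [0., 0.25]]
    ```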

    Args:
      name: A name scope to use for ops added by this method.

    Returns:
      `LinearOperator` representing inverse of this matrix.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be `non_singular`.
    """
    if self.is_square is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the inverse: This operator represents "
                       "a non-square matrix.")
    if self.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Cannot take the inverse: This operator represents "
                       "a singular matrix.")

    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.inverse(self)

  def cholesky(self, name="cholesky"):
    """Returns a Cholesky factor as a `LinearOperator`.

    Given `A` representing this `LinearOperator`, if `A` is positive definite
    self-adjoint, return `L`, where `A = L L^H`, i.e. the Cholesky
    decomposition.
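
    For example (a sketch, assuming the concrete subclass
    `tf.linalg.LinearOperatorFullMatrix`):

    ```python
    operator = tf.linalg.LinearOperatorFullMatrix(
        [[4., 0.], [0., 9.]],
        is_self_adjoint=True, is_positive_definite=True)
    operator.cholesky().to_dense()
    ==> [[2., 0.], [0., 3.]]
    ```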

    Args:
      name:  A name for this `Op`.

    Returns:
      `LinearOperator` which represents the lower triangular matrix
      in the Cholesky decomposition.

    Raises:
      ValueError: When the `LinearOperator` is not hinted to be positive
        definite and self-adjoint.
    """

    if not self._can_use_cholesky():
      raise ValueError("Cannot take the Cholesky decomposition: "
                       "Not a positive definite self-adjoint matrix.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return linear_operator_algebra.cholesky(self)

  def _to_dense(self):
    """Generic and often inefficient implementation.  Override often."""
    if self.batch_shape.is_fully_defined():
      batch_shape = self.batch_shape
    else:
      batch_shape = self.batch_shape_tensor()

    dim_value = tensor_shape.dimension_value(self.domain_dimension)
    if dim_value is not None:
      n = dim_value
    else:
      n = self.domain_dimension_tensor()

    eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
    return self.matmul(eye)

  def to_dense(self, name="to_dense"):
    """Return a dense (batch) matrix representing this operator."""
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._to_dense()

  def _diag_part(self):
    """Generic and often inefficient implementation.  Override often."""
    return array_ops.matrix_diag_part(self.to_dense())

  def diag_part(self, name="diag_part"):
    """Efficiently get the [batch] diagonal part of this operator.

    If this operator has shape `[B1,...,Bb, M, N]`, this returns a
    `Tensor` `diagonal`, of shape `[B1,...,Bb, min(M, N)]`, where
    `diagonal[b1,...,bb, i] = self.to_dense()[b1,...,bb, i, i]`.

    ```
    my_operator = LinearOperatorDiag([1., 2.])

    # Efficiently get the diagonal
    my_operator.diag_part()
    ==> [1., 2.]

    # Equivalent, but inefficient method
    tf.linalg.diag_part(my_operator.to_dense())
    ==> [1., 2.]
    ```

    Args:
      name:  A name for this `Op`.

    Returns:
      diag_part:  A `Tensor` of same `dtype` as self.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._diag_part()

  def _trace(self):
    return math_ops.reduce_sum(self.diag_part(), axis=-1)

  def trace(self, name="trace"):
    """Trace of the linear operator, equal to sum of `self.diag_part()`.

    If the operator is square, this is also the sum of the eigenvalues.
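
    For example (a sketch, using `LinearOperatorDiag` as in other examples in
    this class):

    ```python
    operator = LinearOperatorDiag([1., 4.])
    operator.trace()
    ==> 5.
    ```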

    Args:
      name:  A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._trace()

  def _add_to_tensor(self, x):
    # Override if a more efficient implementation is available.
    return self.to_dense() + x

  def add_to_tensor(self, x, name="add_to_tensor"):
    """Add matrix represented by this operator to `x`.  Equivalent to `A + x`.

    Args:
      x:  `Tensor` with same `dtype` and shape broadcastable to `self.shape`.
      name:  A name to give this `Op`.

    Returns:
      A `Tensor` with broadcast shape and same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
      self._check_input_dtype(x)
      return self._add_to_tensor(x)

  def _eigvals(self):
    return linalg_ops.self_adjoint_eigvals(self.to_dense())

  def eigvals(self, name="eigvals"):
    """Returns the eigenvalues of this linear operator.

    If the operator is marked as self-adjoint (via `is_self_adjoint`)
    this computation can be more efficient.

    Note: This currently only supports self-adjoint operators.
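
    For example (a sketch, assuming the concrete subclass
    `tf.linalg.LinearOperatorFullMatrix`):

    ```python
    operator = tf.linalg.LinearOperatorFullMatrix(
        [[2., 0.], [0., 3.]], is_self_adjoint=True)
    operator.eigvals()
    ==> [2., 3.]
    ```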

    Args:
      name:  A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb, N]` `Tensor` of same `dtype` as `self`.
    """
    if not self.is_self_adjoint:
      raise NotImplementedError("Only self-adjoint matrices are supported.")
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._eigvals()

  def _cond(self):
    if not self.is_self_adjoint:
      # In general the condition number is the ratio of the
      # absolute value of the largest and smallest singular values.
      vals = linalg_ops.svd(self.to_dense(), compute_uv=False)
    else:
      # For self-adjoint matrices, and in general normal matrices,
      # we can use eigenvalues.
      vals = math_ops.abs(self._eigvals())

    return (math_ops.reduce_max(vals, axis=-1) /
            math_ops.reduce_min(vals, axis=-1))

  def cond(self, name="cond"):
    """Returns the condition number of this linear operator.

    Args:
      name:  A name for this `Op`.

    Returns:
      Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`.
    """
    with self._name_scope(name):  # pylint: disable=not-callable
      return self._cond()

  def _can_use_cholesky(self):
    return self.is_self_adjoint and self.is_positive_definite

  def _set_graph_parents(self, graph_parents):
    """Set self._graph_parents.  Called during derived class init.

    This method allows derived classes to set graph_parents, without triggering
    a deprecation warning (which is invoked if `graph_parents` is passed during
    `__init__`).

    Args:
      graph_parents: Iterable over Tensors.
    """
    # TODO(b/143910018) Remove this function in V3.
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not (linear_operator_util.is_ref(t) or
                           tensor_util.is_tf_type(t)):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._graph_parents = graph_parents

  @property
  def _composite_tensor_fields(self):
    """A tuple of parameter names to rebuild the `LinearOperator`.

    The tuple contains the names of kwargs to the `LinearOperator`'s constructor
    that the `TypeSpec` needs to rebuild the `LinearOperator` instance.

    "is_non_singular", "is_self_adjoint", "is_positive_definite", and
    "is_square" are common to all `LinearOperator` subclasses and may be
    omitted.
    """
    return ()

  @property
  def _composite_tensor_prefer_static_fields(self):
    """A tuple of names referring to parameters that may be treated statically.

    This is a subset of `_composite_tensor_fields`, and contains the names of
    `Tensor`-like args to the `LinearOperator`'s constructor that may be
    stored as static values, if they are statically known. These are typically
    shapes or axis values.
    """
    return ()

  @property
  def _type_spec(self):
    # This property will be overwritten by the `@make_composite_tensor`
    # decorator. However, we need it so that a valid subclass of the `ABCMeta`
    # class `CompositeTensor` can be constructed and passed to the
    # `@make_composite_tensor` decorator.
    pass

  def _convert_variables_to_tensors(self):
    """Recursively converts ResourceVariables in the LinearOperator to Tensors.

    The usage of `self._type_spec._from_components` violates the contract of
    `CompositeTensor`, since it is called on a different nested structure
    (one containing only `Tensor`s) than `self.type_spec` specifies (one that
    may contain `ResourceVariable`s). Since `LinearOperator`'s
    `_from_components` method just passes the contents of the nested structure
    to `__init__` to rebuild the operator, and any `LinearOperator` that may be
    instantiated with `ResourceVariables` may also be instantiated with
    `Tensor`s, this usage is valid.

    Returns:
      tensor_operator: `self` with all internal Variables converted to Tensors.
    """
    # pylint: disable=protected-access
    components = self._type_spec._to_components(self)
    tensor_components = variable_utils.convert_variables_to_tensors(
        components)
    return self._type_spec._from_components(tensor_components)
    # pylint: enable=protected-access

  def __getitem__(self, slices):
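    """Slices this `LinearOperator` along its [batch] dimensions.

    A sketch: for an operator with batch shape `[4, 3]`, `operator[1:3]`
    returns a `LinearOperator` with batch shape `[2, 3]`.
    """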
    return slicing.batch_slice(self, params_overrides={}, slices=slices)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    """A dict of names to number of dimensions contributing to an operator.

    This is a dictionary of parameter names to `int`s specifying the
    number of right-most dimensions contributing to the **matrix** shape of the
    densified operator.
    If the parameter is a `Tensor`, this is mapped to an `int`.
    If the parameter is a `LinearOperator` (called `A`), this specifies the
    number of batch dimensions of `A` contributing to this `LinearOperator`'s
    matrix shape.
    If the parameter is a structure, this is a structure of the same type
    containing `int`s.
    """
    return ()


class _LinearOperatorSpec(type_spec.BatchableTypeSpec):
  """A tf.TypeSpec for `LinearOperator` objects."""

  __slots__ = ("_param_specs", "_non_tensor_params", "_prefer_static_fields")

  def __init__(self, param_specs, non_tensor_params, prefer_static_fields):
    """Initializes a new `_LinearOperatorSpec`.

    Args:
      param_specs: Python `dict` of `tf.TypeSpec` instances that describe
        kwargs to the `LinearOperator`'s constructor that are `Tensor`-like or
        `CompositeTensor` subclasses.
      non_tensor_params: Python `dict` containing non-`Tensor` and non-
        `CompositeTensor` kwargs to the `LinearOperator`'s constructor.
      prefer_static_fields: Python `tuple` of strings corresponding to the names
        of `Tensor`-like args to the `LinearOperator`'s constructor that may be
        stored as static values, if known. These are typically shapes, indices,
        or axis values.
    """
    self._param_specs = param_specs
    self._non_tensor_params = non_tensor_params
    self._prefer_static_fields = prefer_static_fields

  @classmethod
  def from_operator(cls, operator):
    """Builds a `_LinearOperatorSpec` from a `LinearOperator` instance.

    Args:
      operator: An instance of `LinearOperator`.

    Returns:
      linear_operator_spec: An instance of `_LinearOperatorSpec` to be used as
        the `TypeSpec` of `operator`.
    """
    validation_fields = ("is_non_singular", "is_self_adjoint",
                         "is_positive_definite", "is_square")
    kwargs = _extract_attrs(
        operator,
        keys=set(operator._composite_tensor_fields + validation_fields))  # pylint: disable=protected-access

    non_tensor_params = {}
    param_specs = {}
    for k, v in list(kwargs.items()):
      type_spec_or_v = _extract_type_spec_recursively(v)
      is_tensor = [isinstance(x, type_spec.TypeSpec)
                   for x in nest.flatten(type_spec_or_v)]
      if all(is_tensor):
        param_specs[k] = type_spec_or_v
      elif not any(is_tensor):
        non_tensor_params[k] = v
      else:
        raise NotImplementedError(f"Field {k} contains a mix of `Tensor` and "
                                  f"non-`Tensor` values.")

    return cls(
        param_specs=param_specs,
        non_tensor_params=non_tensor_params,
        prefer_static_fields=operator._composite_tensor_prefer_static_fields)  # pylint: disable=protected-access

  def _to_components(self, obj):
    return _extract_attrs(obj, keys=list(self._param_specs))

  def _from_components(self, components):
    kwargs = dict(self._non_tensor_params, **components)
    return self.value_type(**kwargs)

  @property
  def _component_specs(self):
    return self._param_specs

  def _serialize(self):
    return (self._param_specs,
            self._non_tensor_params,
            self._prefer_static_fields)

  def _copy(self, **overrides):
    kwargs = {
        "param_specs": self._param_specs,
        "non_tensor_params": self._non_tensor_params,
        "prefer_static_fields": self._prefer_static_fields
    }
    kwargs.update(overrides)
    return type(self)(**kwargs)

  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec."""
    return self._copy(
        param_specs=nest.map_structure(
            lambda spec: spec._batch(batch_size),  # pylint: disable=protected-access
            self._param_specs))

  def _unbatch(self, batch_size):
    """Returns a TypeSpec representing a single element of this TypeSpec."""
    return self._copy(
        param_specs=nest.map_structure(
            lambda spec: spec._unbatch(),  # pylint: disable=protected-access
            self._param_specs))


def make_composite_tensor(cls, module_name="tf.linalg"):
  """Class decorator to convert `LinearOperator`s to `CompositeTensor`."""

  spec_name = "{}Spec".format(cls.__name__)
  spec_type = type(spec_name, (_LinearOperatorSpec,), {"value_type": cls})
  type_spec.register("{}.{}".format(module_name, spec_name))(spec_type)
  cls._type_spec = property(spec_type.from_operator)  # pylint: disable=protected-access
  return cls
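

# Example usage of `make_composite_tensor` (a sketch; `MyOperator` is a
# hypothetical subclass):
#
#   @make_composite_tensor
#   class MyOperator(LinearOperator):
#     ...
#
# Instances of `MyOperator` then expose the registered `MyOperatorSpec` as
# their `TypeSpec` and can be treated as composite tensors, e.g. inside
# `tf.function`.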


def _extract_attrs(op, keys):
  """Extract constructor kwargs to reconstruct `op`.

  Args:
    op: A `LinearOperator` instance.
    keys: A Python `tuple` of strings indicating the names of the constructor
      kwargs to extract from `op`.

  Returns:
    kwargs: A Python `dict` of kwargs to `op`'s constructor, keyed by `keys`.
  """

  kwargs = {}
  not_found = object()
  for k in keys:
    srcs = [
        getattr(op, k, not_found), getattr(op, "_" + k, not_found),
        getattr(op, "parameters", {}).get(k, not_found),
    ]
    if any(v is not not_found for v in srcs):
      kwargs[k] = [v for v in srcs if v is not not_found][0]
    else:
      raise ValueError(
          f"Could not determine an appropriate value for field `{k}` in object "
          f"`{op}`. Looked for\n"
          f" 1. an attr called `{k}`,\n"
          f" 2. an attr called `_{k}`,\n"
          f" 3. an entry in `op.parameters` with key '{k}'.")
    if k in op._composite_tensor_prefer_static_fields and kwargs[k] is not None:  # pylint: disable=protected-access
      if tensor_util.is_tensor(kwargs[k]):
        static_val = tensor_util.constant_value(kwargs[k])
        if static_val is not None:
          kwargs[k] = static_val
    if isinstance(kwargs[k], (np.ndarray, np.generic)):
      kwargs[k] = kwargs[k].tolist()
  return kwargs


def _extract_type_spec_recursively(value):
  """Return (collection of) `TypeSpec`(s) for `value` if it includes `Tensor`s.

  If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
  `value` is a collection containing `Tensor` values, recursively supplant them
  with their respective `TypeSpec`s in a collection of parallel structure.

  If `value` is none of the above, return it unchanged.

  Args:
    value: a Python `object` to (possibly) turn into a (collection of)
    `tf.TypeSpec`(s).

  Returns:
    spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
    or `value`, if no `Tensor`s are found.
  """
  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access
  if isinstance(value, variables.Variable):
    return resource_variable_ops.VariableSpec(
        value.shape, dtype=value.dtype, trainable=value.trainable)
  if tensor_util.is_tensor(value):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  # Unwrap trackable data structures to comply with `TypeSpec._serialize`
  # requirements. `ListWrapper`s are converted to `list`s, and for other
  # trackable data structures, the `__wrapped__` attribute is used.
  if isinstance(value, list):
    return list(_extract_type_spec_recursively(v) for v in value)
  if isinstance(value, data_structures.TrackableDataStructure):
    return _extract_type_spec_recursively(value.__wrapped__)
  if isinstance(value, tuple):
    return type(value)(_extract_type_spec_recursively(x) for x in value)
  if isinstance(value, dict):
    return type(value)((k, _extract_type_spec_recursively(v))
                       for k, v in value.items())
  return value


# Overrides for tf.linalg functions. This allows a LinearOperator to be used in
# place of a Tensor.
# For instance tf.trace(linop) and linop.trace() both work.
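#
# A sketch:
#
#   operator = tf.linalg.LinearOperatorDiag([1., 2.])
#   tf.linalg.trace(operator)  # Dispatches to operator.trace().
#   ==> 3.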


@dispatch.dispatch_for_types(linalg.adjoint, LinearOperator)
def _adjoint(matrix, name=None):
  return matrix.adjoint(name)


@dispatch.dispatch_for_types(linalg.cholesky, LinearOperator)
def _cholesky(input, name=None):   # pylint:disable=redefined-builtin
  return input.cholesky(name)


# The signature has to match the one in python/ops/array_ops.py,
# so we have k, padding_value, and align even though we don't use them here.
# pylint:disable=unused-argument
@dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
def _diag_part(
    input,  # pylint:disable=redefined-builtin
    name="diag_part",
    k=0,
    padding_value=0,
    align="RIGHT_LEFT"):
  return input.diag_part(name)
# pylint:enable=unused-argument


@dispatch.dispatch_for_types(linalg.det, LinearOperator)
def _det(input, name=None):  # pylint:disable=redefined-builtin
  return input.determinant(name)


@dispatch.dispatch_for_types(linalg.inv, LinearOperator)
def _inverse(input, adjoint=False, name=None):   # pylint:disable=redefined-builtin
  inv = input.inverse(name)
  if adjoint:
    inv = inv.adjoint()
  return inv


@dispatch.dispatch_for_types(linalg.logdet, LinearOperator)
def _logdet(matrix, name=None):
  if matrix.is_positive_definite and matrix.is_self_adjoint:
    return matrix.log_abs_determinant(name)
  raise ValueError("Expected matrix to be self-adjoint positive definite.")


@dispatch.dispatch_for_types(math_ops.matmul, LinearOperator)
def _matmul(  # pylint:disable=missing-docstring
    a,
    b,
    transpose_a=False,
    transpose_b=False,
    adjoint_a=False,
    adjoint_b=False,
    a_is_sparse=False,
    b_is_sparse=False,
    output_type=None,  # pylint: disable=unused-argument
    name=None):
  if transpose_a or transpose_b:
    raise ValueError("Transposing not supported at this time.")
  if a_is_sparse or b_is_sparse:
    raise ValueError("Sparse methods not supported at this time.")
  if not isinstance(a, LinearOperator):
    # We use the identity (B^H A^H)^H = AB
    adjoint_matmul = b.matmul(
        a,
        adjoint=(not adjoint_b),
        adjoint_arg=(not adjoint_a),
        name=name)
    return linalg.adjoint(adjoint_matmul)
  return a.matmul(
      b, adjoint=adjoint_a, adjoint_arg=adjoint_b, name=name)


@dispatch.dispatch_for_types(linalg.solve, LinearOperator)
def _solve(
    matrix,
    rhs,
    adjoint=False,
    name=None):
  if not isinstance(matrix, LinearOperator):
    raise ValueError("Passing in `matrix` as a Tensor and `rhs` as a "
                     "LinearOperator is not supported.")
  return matrix.solve(rhs, adjoint=adjoint, name=name)


@dispatch.dispatch_for_types(linalg.trace, LinearOperator)
def _trace(x, name=None):
  return x.trace(name)