# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for linear algebra."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export

# Linear algebra ops.
band_part = array_ops.matrix_band_part
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet))
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
lu = gen_linalg_ops.lu
tf_export('linalg.logm')(dispatch.add_dispatch_support(logm))
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
sqrtm = linalg_ops.matrix_square_root
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve


@tf_export('linalg.logdet')
@dispatch.add_dispatch_support
def logdet(matrix, name=None):
  """Computes log of the determinant of a hermitian positive definite matrix.

  ```python
  # Compute the determinant of a matrix while reducing the chance of over- or
  # underflow:
  A = ...  # shape 10 x 10
  det = tf.exp(tf.linalg.logdet(A))  # scalar
  ```

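  A concrete sketch (example values assumed for illustration; the printed
  value is approximate):

  ```python
  A = 2.0 * tf.eye(2)     # det(A) = 4
  tf.linalg.logdet(A)     # ~1.3862944 == log(4.)
  ```
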
  Args:
    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op`.  Defaults to `logdet`.

  Returns:
    The natural log of the determinant of `matrix`.

  @compatibility(numpy)
  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
  hermitian positive definite matrices are supported.
  @end_compatibility
  """
  # This uses the property that the log det(A) = 2*sum(log(real(diag(C))))
  # where C is the cholesky decomposition of A.
  with ops.name_scope(name, 'logdet', [matrix]):
    chol = gen_linalg_ops.cholesky(matrix)
    return 2.0 * math_ops.reduce_sum(
        math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
        axis=[-1])


@tf_export('linalg.adjoint')
@dispatch.add_dispatch_support
def adjoint(matrix, name=None):
  """Transposes the last two dimensions of and conjugates tensor `matrix`.

  For example:

  ```python
  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.linalg.adjoint(x)  # [[1 - 1j, 4 - 4j],
                        #  [2 - 2j, 5 - 5j],
                        #  [3 - 3j, 6 - 6j]]
  ```

  Args:
    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op` (optional).

  Returns:
    The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
    matrix.
  """
  with ops.name_scope(name, 'adjoint', [matrix]):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    return array_ops.matrix_transpose(matrix, conjugate=True)


# This section is ported nearly verbatim from Eigen's implementation:
# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html
def _matrix_exp_pade3(matrix):
  """3rd-order Pade approximant for matrix exponential."""
  b = [120.0, 60.0, 12.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  tmp = matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade5(matrix):
  """5th-order Pade approximant for matrix exponential."""
  b = [30240.0, 15120.0, 3360.0, 420.0, 30.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade7(matrix):
  """7th-order Pade approximant for matrix exponential."""
  b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident
  return matrix_u, matrix_v


def _matrix_exp_pade9(matrix):
  """9th-order Pade approximant for matrix exponential."""
  b = [
      17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0,
      2162160.0, 110880.0, 3960.0, 90.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  matrix_8 = math_ops.matmul(matrix_6, matrix_2)
  tmp = (
      matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 +
      b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp)
  matrix_v = (
      b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 +
      b[0] * ident)
  return matrix_u, matrix_v


def _matrix_exp_pade13(matrix):
  """13th-order Pade approximant for matrix exponential."""
  b = [
      64764752532480000.0, 32382376266240000.0, 7771770303897600.0,
      1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0,
      33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0
  ]
  b = [constant_op.constant(x, matrix.dtype) for x in b]
  ident = linalg_ops.eye(
      array_ops.shape(matrix)[-2],
      batch_shape=array_ops.shape(matrix)[:-2],
      dtype=matrix.dtype)
  matrix_2 = math_ops.matmul(matrix, matrix)
  matrix_4 = math_ops.matmul(matrix_2, matrix_2)
  matrix_6 = math_ops.matmul(matrix_4, matrix_2)
  tmp_u = (
      math_ops.matmul(matrix_6, matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) +
      b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident)
  matrix_u = math_ops.matmul(matrix, tmp_u)
  tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2
  matrix_v = (
      math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 +
      b[2] * matrix_2 + b[0] * ident)
  return matrix_u, matrix_v


@tf_export('linalg.expm')
@dispatch.add_dispatch_support
def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
  r"""Computes the matrix exponential of one or more square matrices.

  $$exp(A) = \sum_{n=0}^\infty A^n/n!$$

  The exponential is computed using a combination of the scaling and squaring
  method and the Pade approximation. Details can be found in:
  Nicholas J. Higham, "The scaling and squaring method for the matrix
  exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.

  The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
  form square matrices. The output is a tensor of the same shape as the input
  containing the exponential for all input submatrices `[..., :, :]`.

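  For example, a minimal sketch (example values assumed; for this nilpotent
  matrix the result happens to be exact):

  ```python
  A = tf.constant([[0., 1.], [0., 0.]])
  tf.linalg.expm(A)  # ==> [[1., 1.], [0., 1.]]
  ```
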
  Args:
    input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or
      `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op` (optional).

  Returns:
    the matrix exponential of the input.

  Raises:
    ValueError: An unsupported type is provided as input.

  @compatibility(scipy)
  Equivalent to scipy.linalg.expm
  @end_compatibility
  """
  with ops.name_scope(name, 'matrix_exponential', [input]):
    matrix = ops.convert_to_tensor(input, name='input')
    if matrix.shape[-2:] == [0, 0]:
      return matrix
    batch_shape = matrix.shape[:-2]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(matrix)[:-2]

    # reshaping the batch makes the where statements work better
    matrix = array_ops.reshape(
        matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
    l1_norm = math_ops.reduce_max(
        math_ops.reduce_sum(
            math_ops.abs(matrix),
            axis=array_ops.size(array_ops.shape(matrix)) - 2),
        axis=-1)[..., array_ops.newaxis, array_ops.newaxis]

    const = lambda x: constant_op.constant(x, l1_norm.dtype)

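    # Selects between the Pade approximants computed below: `cases[i]` is used
    # where `l1_norm < vals[i]`, and the last entry of `cases` is the fallback
    # for the largest norms.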
    def _nest_where(vals, cases):
      assert len(vals) == len(cases) - 1
      if len(vals) == 1:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
      else:
        return array_ops.where_v2(
            math_ops.less(l1_norm, const(vals[0])), cases[0],
            _nest_where(vals[1:], cases[1:]))

    if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
      maxnorm = const(3.925724783138660)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (4.258730016922831e-001, 1.880152677804762e+000)
      u = _nest_where(conds, (u3, u5, u7))
      v = _nest_where(conds, (v3, v5, v7))
    elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
      maxnorm = const(5.371920351148152)
      squarings = math_ops.maximum(
          math_ops.floor(
              math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
      u3, v3 = _matrix_exp_pade3(matrix)
      u5, v5 = _matrix_exp_pade5(matrix)
      u7, v7 = _matrix_exp_pade7(matrix)
      u9, v9 = _matrix_exp_pade9(matrix)
      u13, v13 = _matrix_exp_pade13(
          matrix /
          math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
      conds = (1.495585217958292e-002, 2.539398330063230e-001,
               9.504178996162932e-001, 2.097847961257068e+000)
      u = _nest_where(conds, (u3, u5, u7, u9, u13))
      v = _nest_where(conds, (v3, v5, v7, v9, v13))
    else:
      raise ValueError('tf.linalg.expm does not support matrices of type %s' %
                       matrix.dtype)

    is_finite = math_ops.is_finite(math_ops.reduce_max(l1_norm))
    nan = constant_op.constant(np.nan, matrix.dtype)
    result = control_flow_ops.cond(
        is_finite, lambda: linalg_ops.matrix_solve(-u + v, u + v),
        lambda: array_ops.fill(array_ops.shape(matrix), nan))
    max_squarings = math_ops.reduce_max(squarings)
    i = const(0.0)

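    # Squaring phase of scaling-and-squaring: square the Pade result
    # `squarings` times to undo the earlier scaling by 2**(-squarings).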
    def c(i, _):
      return control_flow_ops.cond(is_finite,
                                   lambda: math_ops.less(i, max_squarings),
                                   lambda: constant_op.constant(False))

    def b(i, r):
      return i + 1, array_ops.where_v2(
          math_ops.less(i, squarings), math_ops.matmul(r, r), r)

    _, result = control_flow_ops.while_loop(c, b, [i, result])
    if not matrix.shape.is_fully_defined():
      return array_ops.reshape(
          result,
          array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0))
    return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:]))


@tf_export('linalg.banded_triangular_solve', v1=[])
def banded_triangular_solve(
    bands,
    rhs,
    lower=True,
    adjoint=False,  # pylint: disable=redefined-outer-name
    name=None):
  r"""Solve triangular systems of equations with a banded solver.

  `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number
  of bands stored. This corresponds to a batch of `M` by `M` matrices, whose
  `K` subdiagonals (when `lower` is `True`) are stored.

  This operator broadcasts the batch dimensions of `bands` and the batch
  dimensions of `rhs`.


  Examples:

  Storing 2 bands of a 3x3 matrix.
  Note that the first element in the second row is ignored due to
  the 'LEFT_RIGHT' padding.

  >>> x = [[2., 3., 4.], [1., 2., 3.]]
  >>> x2 = [[2., 3., 4.], [10000., 2., 3.]]
  >>> y = tf.zeros([3, 3])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0))
  >>> z
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[2., 0., 0.],
         [2., 3., 0.],
         [0., 3., 4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1]))
  >>> soln
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[0.5 ],
         [0.  ],
         [0.25]], dtype=float32)>
  >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True
  >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1]))
  >>> tf.reduce_all(are_equal).numpy()
  True

  Storing 2 superdiagonals of a 4x4 matrix. Because of the 'LEFT_RIGHT' padding
  the last element of the first row is ignored.

  >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]]
  >>> y = tf.zeros([4, 4])
  >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1))
  >>> z
  <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
  array([[-1.,  2.,  0.,  0.],
         [ 0., -2.,  3.,  0.],
         [ 0.,  0., -3.,  4.],
         [ 0.,  0., -0., -4.]], dtype=float32)>
  >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False)
  >>> soln
  <tf.Tensor: shape=(4, 1), dtype=float32, numpy=
  array([[-4.       ],
         [-1.5      ],
         [-0.6666667],
         [-0.25     ]], dtype=float32)>
  >>> are_equal = (soln == tf.linalg.triangular_solve(
  ...   z, tf.ones([4, 1]), lower=False))
  >>> tf.reduce_all(are_equal).numpy()
  True


  Args:
    bands: A `Tensor` describing the bands of the left hand side, with shape
      `[..., K, M]`. The `K` rows correspond to the diagonal to the `K - 1`-th
      diagonal (the diagonal is the top row) when `lower` is `True` and
      otherwise the `K - 1`-th superdiagonal to the diagonal (the diagonal is
      the bottom row) when `lower` is `False`. The bands are stored with
      'LEFT_RIGHT' alignment, where the superdiagonals are padded on the right
      and subdiagonals are padded on the left. This is the alignment cuSPARSE
      uses.  See  `tf.linalg.set_diag` for more details.
    rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as
      `bands`. Note that if the shape of `rhs` and/or `bands` isn't known
      statically, `rhs` will be treated as a matrix rather than a vector.
    lower: An optional `bool`. Defaults to `True`. Boolean indicating whether
      `bands` represents a lower or upper triangular matrix.
    adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether
      to solve with the matrix's block-wise adjoint.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M] or [..., M, N] containing the solutions.
  """
  with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]):
    return gen_linalg_ops.banded_triangular_solve(
        bands, rhs, lower=lower, adjoint=adjoint)


@tf_export('linalg.tridiagonal_solve')
@dispatch.add_dispatch_support
def tridiagonal_solve(diagonals,
                      rhs,
                      diagonals_format='compact',
                      transpose_rhs=False,
                      conjugate_rhs=False,
                      name=None,
                      partial_pivoting=True,
                      perturb_singular=False):
  r"""Solves tridiagonal systems of equations.

  The input can be supplied in various formats: `matrix`, `sequence` and
  `compact`, specified by the `diagonals_format` arg.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` are supplied as a tuple or list of three
  tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing
  superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either
  `M-1` or `M`; in the latter case, the last element of superdiagonal and the
  first element of subdiagonal will be ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `compact` format is recommended as the one with the best performance. In
  case you need to cast a tensor into a compact format manually, use
  `tf.gather_nd`. An example for a tensor of shape [m, m]:

  ```python
  rhs = tf.constant([...])
  matrix = tf.constant([[...]])
  m = matrix.shape[0]
  dummy_idx = [0, 0]  # An arbitrary element to use as a dummy
  indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx],  # Superdiagonal
             [[i, i] for i in range(m)],                        # Diagonal
             [dummy_idx] + [[i + 1, i] for i in range(m - 1)]]  # Subdiagonal
  diagonals = tf.gather_nd(matrix, indices)
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)
  ```

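  A concrete sketch in `compact` format (example values assumed; the solution
  of this particular system is exactly `[1., 1., 1.]`):

  ```python
  diagonals = tf.constant([[-1., -1., 0.],   # superdiagonal (last entry ignored)
                           [2., 2., 2.],     # diagonal
                           [0., -1., -1.]])  # subdiagonal (first entry ignored)
  rhs = tf.constant([1., 0., 1.])
  x = tf.linalg.tridiagonal_solve(diagonals, rhs)  # ==> [1., 1., 1.]
  ```
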
  Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]`
  or `[..., M, K]`. The latter allows simultaneously solving K systems with the
  same left-hand sides and K different right-hand sides. If `transpose_rhs`
  is set to `True` the expected shape is `[..., M]` or `[..., K, M]`.

  The batch dimensions, denoted as `...`, must be the same in `diagonals` and
  `rhs`.

  The output is a tensor of the same shape as `rhs`: either `[..., M]` or
  `[..., M, K]`.

  The op isn't guaranteed to raise an error if the input matrix is not
  invertible. `tf.debugging.check_numerics` can be applied to the output to
  detect invertibility problems.

  **Note**: with large batch sizes, the computation on the GPU may be slow if
  either `partial_pivoting=True` or there are multiple right-hand sides
  (`K > 1`). If this issue arises, consider if it's possible to disable pivoting
  and have `K = 1`, or, alternatively, consider using CPU.

  On CPU, the solution is computed via Gaussian elimination with or without
  partial pivoting, depending on the `partial_pivoting` parameter. On GPU,
  Nvidia's cuSPARSE library is used:
  https://docs.nvidia.com/cuda/cusparse/index.html#gtsv

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as
      `diagonals`. Note that if the shape of `rhs` and/or `diagonals` isn't
      known statically, `rhs` will be treated as a matrix rather than a vector.
    diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
      `compact`.
    transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect
      if the shape of rhs is [..., M]).
    conjugate_rhs: If `True`, `rhs` is conjugated before solving.
    name:  A name to give this `Op` (optional).
    partial_pivoting: whether to perform partial pivoting. `True` by default.
      Partial pivoting makes the procedure more stable, but slower. Partial
      pivoting is unnecessary in some cases, including diagonally dominant and
      symmetric positive definite matrices (see e.g. theorem 9.12 in [1]).
    perturb_singular: whether to perturb singular matrices to return a finite
      result. `False` by default. If true, solutions to systems involving
      a singular matrix will be computed by perturbing near-zero pivots in
      the partially pivoted LU decomposition. Specifically, tiny pivots are
      perturbed by an amount of order `eps * max_{ij} |U(i,j)|` to avoid
      overflow. Here `U` is the upper triangular part of the LU decomposition,
      and `eps` is the machine precision. This is useful for solving
      numerically singular systems when computing eigenvectors by inverse
      iteration.
      If `partial_pivoting` is `False`, `perturb_singular` must be `False` as
      well.

  Returns:
    A `Tensor` of shape [..., M] or [..., M, K] containing the solutions.
    If the input matrix is singular, the result is undefined.

  Raises:
    ValueError: Is raised if any of the following conditions hold:
      1. An unsupported type is provided as input,
      2. the input tensors have incorrect shapes,
      3. `perturb_singular` is `True` but `partial_pivoting` is not.
    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
      XLA, or whenever `perturb_singular` is true and the backend is
      XLA or GPU.

  [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms:
  Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.

  """
  if perturb_singular and not partial_pivoting:
    raise ValueError('partial_pivoting must be True if perturb_singular is.')

  if diagonals_format == 'compact':
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             perturb_singular, name)

  if diagonals_format == 'sequence':
    if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3:
      raise ValueError('Expected diagonals to be a sequence of length 3.')

    superdiag, maindiag, subdiag = diagonals
    if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or
        not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])):
      raise ValueError(
          'Tensors representing the three diagonals must have the same shape, '
          'except for the last dimension, got {}, {}, {}'.format(
              subdiag.shape, maindiag.shape, superdiag.shape))

    m = tensor_shape.dimension_value(maindiag.shape[-1])

    def pad_if_necessary(t, name, last_dim_padding):
      n = tensor_shape.dimension_value(t.shape[-1])
      if not n or n == m:
        return t
      if n == m - 1:
        paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] +
                    [last_dim_padding])
        return array_ops.pad(t, paddings)
      raise ValueError('Expected {} to have length {} or {}, got {}.'.format(
          name, m, m - 1, n))

    subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0])
    superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1])

    diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2)
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             perturb_singular, name)

  if diagonals_format == 'matrix':
    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
    if m1 and m2 and m1 != m2:
      raise ValueError(
          'Expected last two dimensions of diagonals to be same, got {} and {}'
          .format(m1, m2))
    m = m1 or m2
    diagonals = array_ops.matrix_diag_part(
        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
    return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                             conjugate_rhs, partial_pivoting,
                                             perturb_singular, name)

  raise ValueError('Unrecognized diagonals_format: {}'.format(diagonals_format))


def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs,
                                      conjugate_rhs, partial_pivoting,
                                      perturb_singular, name):
  """Helper function used after the input has been cast to compact form."""
  diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank

  # If we know the rank of the diagonal tensor, do some static checking.
  if diags_rank:
    if diags_rank < 2:
      raise ValueError(
          'Expected diagonals to have rank at least 2, got {}'.format(
              diags_rank))
    if rhs_rank and rhs_rank != diags_rank and rhs_rank != diags_rank - 1:
      raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format(
          diags_rank - 1, diags_rank, rhs_rank))
    if (rhs_rank and not diagonals.shape[:-2].is_compatible_with(
        rhs.shape[:diags_rank - 2])):
      raise ValueError('Batch shapes {} and {} are incompatible'.format(
          diagonals.shape[:-2], rhs.shape[:diags_rank - 2]))

  if diagonals.shape[-2] and diagonals.shape[-2] != 3:
    raise ValueError('Expected 3 diagonals got {}'.format(diagonals.shape[-2]))

  def check_num_lhs_matches_num_rhs():
    if (diagonals.shape[-1] and rhs.shape[-2] and
        diagonals.shape[-1] != rhs.shape[-2]):
      raise ValueError('Expected number of left-hand sides and right-hand '
                       'sides to be equal, got {} and {}'.format(
                           diagonals.shape[-1], rhs.shape[-2]))

  if rhs_rank and diags_rank and rhs_rank == diags_rank - 1:
    # Rhs provided as a vector, ignoring transpose_rhs
    if conjugate_rhs:
      rhs = math_ops.conj(rhs)
    rhs = array_ops.expand_dims(rhs, -1)
    check_num_lhs_matches_num_rhs()
    return array_ops.squeeze(
        linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting,
                                     perturb_singular, name), -1)

  if transpose_rhs:
    rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs)
  elif conjugate_rhs:
    rhs = math_ops.conj(rhs)

  check_num_lhs_matches_num_rhs()
  return linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting,
                                      perturb_singular, name)


@tf_export('linalg.tridiagonal_matmul')
@dispatch.add_dispatch_support
def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None):
  r"""Multiplies tridiagonal matrix by matrix.

  `diagonals` is a representation of a tridiagonal NxN matrix, whose format
  depends on `diagonals_format`.

  In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with
  two inner-most dimensions representing the square tridiagonal matrices.
  Elements outside of the three diagonals will be ignored.

  In `sequence` format, `diagonals` is a list or tuple of three tensors:
  `[superdiag, maindiag, subdiag]`, each having shape [..., M]. The last
  element of `superdiag` and the first element of `subdiag` are ignored.

  In `compact` format the three diagonals are brought together into one tensor
  of shape `[..., 3, M]`, with last two dimensions containing superdiagonals,
  diagonals, and subdiagonals, in order. Similarly to `sequence` format,
  elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored.

  The `sequence` format is recommended as the one with the best performance.

  `rhs` is the matrix to the right of the multiplication. It has shape
  `[..., M, N]`.

  Example:

  ```python
  superdiag = tf.constant([-1, -1, 0], dtype=tf.float64)
  maindiag = tf.constant([2, 2, 2], dtype=tf.float64)
  subdiag = tf.constant([0, -1, -1], dtype=tf.float64)
  diagonals = [superdiag, maindiag, subdiag]
  rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64)
  x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence')
  ```

  Args:
    diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The
      shape depends on `diagonals_format`, see description above. Must be
      `float32`, `float64`, `complex64`, or `complex128`.
    rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`.
    diagonals_format: one of `sequence`, or `compact`. Default is `compact`.
    name:  A name to give this `Op` (optional).

  Returns:
    A `Tensor` of shape [..., M, N] containing the result of multiplication.

  Raises:
    ValueError: An unsupported type is provided as input, or the input
    tensors have incorrect shapes.
717  """
718  if diagonals_format == 'compact':
719    superdiag = diagonals[..., 0, :]
720    maindiag = diagonals[..., 1, :]
721    subdiag = diagonals[..., 2, :]
722  elif diagonals_format == 'sequence':
723    superdiag, maindiag, subdiag = diagonals
724  elif diagonals_format == 'matrix':
725    m1 = tensor_shape.dimension_value(diagonals.shape[-1])
726    m2 = tensor_shape.dimension_value(diagonals.shape[-2])
727    if m1 and m2 and m1 != m2:
728      raise ValueError(
729          'Expected last two dimensions of diagonals to be same, got {} and {}'
730          .format(m1, m2))
731    diags = array_ops.matrix_diag_part(
732        diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT')
733    superdiag = diags[..., 0, :]
734    maindiag = diags[..., 1, :]
735    subdiag = diags[..., 2, :]
736  else:
737    raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format)
738
739  # C++ backend requires matrices.
740  # Converting 1-dimensional vectors to matrices with 1 row.
741  superdiag = array_ops.expand_dims(superdiag, -2)
742  maindiag = array_ops.expand_dims(maindiag, -2)
743  subdiag = array_ops.expand_dims(subdiag, -2)
744
745  return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name)
746
747
748def _maybe_validate_matrix(a, validate_args):
749  """Checks that input is a `float` matrix."""
750  assertions = []
751  if not a.dtype.is_floating:
752    raise TypeError('Input `a` must have `float`-like `dtype` '
753                    '(saw {}).'.format(a.dtype.name))
754  if a.shape is not None and a.shape.rank is not None:
755    if a.shape.rank < 2:
756      raise ValueError('Input `a` must have at least 2 dimensions '
757                       '(saw: {}).'.format(a.shape.rank))
758  elif validate_args:
759    assertions.append(
760        check_ops.assert_rank_at_least(
761            a, rank=2, message='Input `a` must have at least 2 dimensions.'))
762  return assertions
763
764
765@tf_export('linalg.matrix_rank')
766@dispatch.add_dispatch_support
767def matrix_rank(a, tol=None, validate_args=False, name=None):
768  """Compute the matrix rank of one or more matrices.
769
770  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold below which the singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
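
  #### Examples

  A minimal sketch (example values assumed): a matrix built from an outer
  product has rank 1.

  ```python
  v = tf.constant([[1.], [2.], [3.]])
  a = tf.matmul(v, v, transpose_b=True)  # 3x3 matrix of rank 1
  tf.linalg.matrix_rank(a)               # ==> 1
  ```
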
785  with ops.name_scope(name or 'matrix_rank'):
786    a = ops.convert_to_tensor(a, dtype_hint=dtypes.float32, name='a')
787    assertions = _maybe_validate_matrix(a, validate_args)
788    if assertions:
789      with ops.control_dependencies(assertions):
790        a = array_ops.identity(a)
791    s = svd(a, compute_uv=False)
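    # Default tolerance mirrors the docstring:
    # eps * max(num_rows, num_cols) * largest singular value.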
    if tol is None:
      if (a.shape[-2:]).is_fully_defined():
        m = np.max(a.shape[-2:].as_list())
      else:
        m = math_ops.reduce_max(array_ops.shape(a)[-2:])
      eps = np.finfo(a.dtype.as_numpy_dtype).eps
      tol = (
          eps * math_ops.cast(m, a.dtype) *
          math_ops.reduce_max(s, axis=-1, keepdims=True))
    return math_ops.reduce_sum(math_ops.cast(s > tol, dtypes.int32), axis=-1)


@tf_export('linalg.pinv')
@dispatch.add_dispatch_support
def pinv(a, rcond=None, validate_args=False, name=None):
  """Compute the Moore-Penrose pseudo-inverse of one or more matrices.

  Calculate the [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its
  singular-value decomposition (SVD) and including all large singular values.

  The pseudo-inverse of a matrix `A` is defined as 'the matrix that 'solves'
  [the least-squares problem] `A @ x = b`', i.e., if `x_hat` is a solution, then
  `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if
  `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then
  `A_pinv = V @ inv(Sigma) @ U.T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the
  default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs.  Singular values smaller
      (in modulus) than `rcond` * largest_singular_value (again, in modulus) are
      set to zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: (Batch of) pseudo-inverse of input `a`. Has same shape as `a` except
      rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[1., 0., 0.],
  #            [0., 1., 0.],
  #            [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tf.linalg.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
  #            [ 0.37,  0.43, -0.33,  0.02],
  #            [ 0.21, -0.33,  0.81,  0.01],
  #            [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press,
       Inc., 1980, pp. 139-142.
  """
  with ops.name_scope(name or 'pinv'):
    a = ops.convert_to_tensor(a, name='a')

    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        a = array_ops.identity(a)

    dtype = a.dtype.as_numpy_dtype

    if rcond is None:

      def get_dim_size(dim):
        dim_val = tensor_shape.dimension_value(a.shape[dim])
        if dim_val is not None:
          return dim_val
        return array_ops.shape(a)[dim]

      num_rows = get_dim_size(-2)
      num_cols = get_dim_size(-1)
      if isinstance(num_rows, int) and isinstance(num_cols, int):
        max_rows_cols = float(max(num_rows, num_cols))
      else:
        max_rows_cols = math_ops.cast(
            math_ops.maximum(num_rows, num_cols), dtype)
      rcond = 10. * max_rows_cols * np.finfo(dtype).eps

    rcond = ops.convert_to_tensor(rcond, dtype=dtype, name='rcond')

    # Calculate pseudo inverse via SVD.
    # Note: if a is Hermitian then u == v. (We might observe additional
    # performance by explicitly setting `v = u` in such cases.)
    [
        singular_values,  # Sigma
        left_singular_vectors,  # U
        right_singular_vectors,  # V
    ] = svd(
        a, full_matrices=False, compute_uv=True)

    # Saturate small singular values to inf. This has the effect of making
    # `1. / s = 0.` while not resulting in `NaN` gradients.
    cutoff = rcond * math_ops.reduce_max(singular_values, axis=-1)
    singular_values = array_ops.where_v2(
        singular_values > array_ops.expand_dims_v2(cutoff, -1), singular_values,
        np.array(np.inf, dtype))

    # By the definition of the SVD, `a == u @ s @ v^H`, and the pseudo-inverse
    # is defined as `pinv(a) == v @ inv(s) @ u^H`.
    a_pinv = math_ops.matmul(
        right_singular_vectors / array_ops.expand_dims_v2(singular_values, -2),
        left_singular_vectors,
        adjoint_b=True)

    if a.shape is not None and a.shape.rank is not None:
      a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]]))

    return a_pinv


@tf_export('linalg.lu_solve')
@dispatch.add_dispatch_support
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[...,
        tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
        actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_solve'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
        rhs = array_ops.identity(rhs)

    if (rhs.shape.rank == 2 and perm.shape.rank == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = array_ops.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = array_ops.shape(rhs)
      broadcast_batch_shape = array_ops.broadcast_dynamic_shape(
          rhs_shape[:-2],
          array_ops.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]],
                                             axis=0)

      # Tile out rhs.
      broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = array_ops.broadcast_to(
          math_ops.range(broadcast_batch_size)[:, array_ops.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = array_ops.stack(
          [broadcast_batch_indices, broadcast_perm], axis=-1)

      permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(
            array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return triangular_solve(
        lower_upper,  # Only upper is accessed.
        triangular_solve(lower, permuted_rhs),
        lower=False)


@tf_export('linalg.lu_matrix_inverse')
@dispatch.add_dispatch_support
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
  """Computes the inverse given the LU decomposition(s) of one or more matrices.

  This op is conceptually identical to,

  ```python
  inv_X = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(X))
  tf.assert_near(tf.linalg.inv(X), inv_X)
  # ==> True
  ```

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
        actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_matrix_inverse').

  Returns:
    inv_x: The matrix inverse, i.e.,
      `tf.linalg.inv(tf.linalg.lu_reconstruct(lu, perm))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x))
  tf.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```

  """

  with ops.name_scope(name or 'lu_matrix_inverse'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')
    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)
    shape = array_ops.shape(lower_upper)
    return lu_solve(
        lower_upper,
        perm,
        rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype),
        validate_args=False)


@tf_export('linalg.lu_reconstruct')
@dispatch.add_dispatch_support
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
1100  """The reconstruct one or more matrices from their LU decomposition(s).

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P,
      matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if `matmul(P, matmul(L, U)) =
      X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with ops.name_scope(name or 'lu_reconstruct'):
    lower_upper = ops.convert_to_tensor(
        lower_upper, dtype_hint=dtypes.float32, name='lower_upper')
    perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with ops.control_dependencies(assertions):
        lower_upper = array_ops.identity(lower_upper)
        perm = array_ops.identity(perm)

    shape = array_ops.shape(lower_upper)

    lower = set_diag(
        band_part(lower_upper, num_lower=-1, num_upper=0),
        array_ops.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = band_part(lower_upper, num_lower=0, num_upper=-1)
    x = math_ops.matmul(lower, upper)

    if (lower_upper.shape is None or lower_upper.shape.rank is None or
        lower_upper.shape.rank != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = math_ops.reduce_prod(shape[:-2])
      d = shape[-1]
      x = array_ops.reshape(x, [batch_size, d, d])
      perm = array_ops.reshape(perm, [batch_size, d])
      perm = map_fn.map_fn(array_ops.invert_permutation, perm)
      batch_indices = array_ops.broadcast_to(
          math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d])
      x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm],
                                                 axis=-1))
      x = array_ops.reshape(x, shape)
    else:
      x = array_ops.gather(x, array_ops.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x


def lu_reconstruct_assertions(lower_upper, perm, validate_args):
  """Returns list of assertions related to `lu_reconstruct` assumptions."""
  assertions = []

  message = 'Input `lower_upper` must have at least 2 dimensions.'
  if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2:
    raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message))

  message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
  if lower_upper.shape.rank is not None and perm.shape.rank is not None:
    if lower_upper.shape.rank != perm.shape.rank + 1:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank(
            lower_upper, rank=array_ops.rank(perm) + 1, message=message))

  message = '`lower_upper` must be square.'
  if lower_upper.shape[-2:].is_fully_defined():
    if lower_upper.shape[-2] != lower_upper.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    m, n = array_ops.split(
        array_ops.shape(lower_upper)[-2:], num_or_size_splits=2)
    assertions.append(check_ops.assert_equal(m, n, message=message))

  return assertions


def _lu_solve_assertions(lower_upper, perm, rhs, validate_args):
  """Returns list of assertions related to `lu_solve` assumptions."""
  assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)

  message = 'Input `rhs` must have at least 2 dimensions.'
  if rhs.shape.ndims is not None:
    if rhs.shape.ndims < 2:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_rank_at_least(rhs, rank=2, message=message))

  message = '`lower_upper.shape[-1]` must equal `rhs.shape[-2]`.'
  if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None):
    if lower_upper.shape[-1] != rhs.shape[-2]:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        check_ops.assert_equal(
            array_ops.shape(lower_upper)[-1],
            array_ops.shape(rhs)[-2],
            message=message))

  return assertions


@tf_export('linalg.eigh_tridiagonal')
@dispatch.add_dispatch_support
def eigh_tridiagonal(alpha,
                     beta,
                     eigvals_only=True,
                     select='a',
                     select_range=None,
                     tol=None,
                     name=None):
  """Computes the eigenvalues of a Hermitian tridiagonal matrix.

  Args:
    alpha: A real or complex tensor of shape (n), the diagonal elements of the
      matrix. NOTE: If alpha is complex, the imaginary part is ignored (assumed
        zero) to satisfy the requirement that the matrix be Hermitian.
    beta: A real or complex tensor of shape (n-1), containing the elements of
      the first super-diagonal of the matrix. If beta is complex, the first
      sub-diagonal of the matrix is assumed to be the conjugate of beta to
      satisfy the requirement that the matrix be Hermitian.
    eigvals_only: If False, both eigenvalues and corresponding eigenvectors are
      computed. If True, only eigenvalues are computed. Default is True.
    select: Optional string with values in {'a', 'v', 'i'} (default is 'a') that
      determines which eigenvalues to calculate:
        'a': all eigenvalues.
        'v': eigenvalues in the interval (min, max] given by `select_range`.
        'i': eigenvalues with indices min <= i <= max.
    select_range: Size 2 tuple or list or tensor specifying the range of
      eigenvalues to compute together with select. If select is 'a',
      select_range is ignored.
    tol: Optional scalar. The absolute tolerance to which each eigenvalue is
      required. An eigenvalue (or cluster) is considered to have converged if it
      lies in an interval of this width. If tol is None (default), the value
      eps*|T|_2 is used where eps is the machine precision, and |T|_2 is the
      2-norm of the matrix T.
    name: Optional name of the op.

  Returns:
    eig_vals: The eigenvalues of the matrix in non-decreasing order.
    eig_vectors: If `eigvals_only` is False the eigenvectors are returned in
      the second output argument.

  Raises:
     ValueError: If input values are invalid.
     NotImplementedError: Computing eigenvectors for `eigvals_only` = False is
       not implemented yet.
1274
1275  This op implements a subset of the functionality of
1276  scipy.linalg.eigh_tridiagonal.
1277
1278  Note: The result is undefined if the input contains +/-inf or NaN, or if
1279  any value in beta has a magnitude greater than
1280  `numpy.sqrt(numpy.finfo(beta.dtype.as_numpy_dtype).max)`.
1281
1282
1283  TODO(b/187527398):
1284    Add support for outer batch dimensions.
1285
1286  #### Examples
1287
1288  ```python
1289  import numpy
1290  eigvals = tf.linalg.eigh_tridiagonal([0.0, 0.0, 0.0], [1.0, 1.0])
1291  eigvals_expected = [-numpy.sqrt(2.0), 0.0, numpy.sqrt(2.0)]
1292  tf.assert_near(eigvals_expected, eigvals)
1293  # ==> True
1294  ```
1295
1296  """
1297  with ops.name_scope(name or 'eigh_tridiagonal'):
1298
1299    def _compute_eigenvalues(alpha, beta):
1300      """Computes all eigenvalues of a Hermitian tridiagonal matrix."""
1301
1302      def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
1303        """Implements the Sturm sequence recurrence."""
1304        with ops.name_scope('sturm'):
1305          n = alpha.shape[0]
1306          zeros = array_ops.zeros(array_ops.shape(x), dtype=dtypes.int32)
1307          ones = array_ops.ones(array_ops.shape(x), dtype=dtypes.int32)
1308
1309          # The first step in the Sturm sequence recurrence
1310          # requires special care if x is equal to alpha[0].
1311          def sturm_step0():
1312            q = alpha[0] - x
1313            count = array_ops.where(q < 0, ones, zeros)
1314            q = array_ops.where(
1315                math_ops.equal(alpha[0], x), alpha0_perturbation, q)
1316            return q, count
1317
1318          # Subsequent steps all take this form:
1319          def sturm_step(i, q, count):
1320            q = alpha[i] - beta_sq[i - 1] / q - x
1321            count = array_ops.where(q <= pivmin, count + 1, count)
1322            q = array_ops.where(q <= pivmin, math_ops.minimum(q, -pivmin), q)
1323            return q, count
1324
1325          # The first step initializes q and count.
1326          q, count = sturm_step0()
1327
1328          # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
1329          # the bulk of the iterations unrolled by a factor of blocksize.
1330          blocksize = 16
1331          i = 1
1332          peel = (n - 1) % blocksize
1333          unroll_cnt = peel
1334
1335          def unrolled_steps(start, q, count):
1336            for j in range(unroll_cnt):
1337              q, count = sturm_step(start + j, q, count)
1338            return start + unroll_cnt, q, count
1339
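          # Run the peel steps first; the same helper is reused below as the
          # while_loop body once unroll_cnt has been rebound to blocksize.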
1340          i, q, count = unrolled_steps(i, q, count)
1341
1342          # Run the remaining steps of the Sturm sequence using a partially
1343          # unrolled while loop.
1344          unroll_cnt = blocksize
1345          cond = lambda i, q, count: math_ops.less(i, n)
1346          _, _, count = control_flow_ops.while_loop(
1347              cond, unrolled_steps, [i, q, count], back_prop=False)
1348          return count
1349
1350      with ops.name_scope('compute_eigenvalues'):
1351        if alpha.dtype.is_complex:
1352          alpha = math_ops.real(alpha)
1353          beta_sq = math_ops.real(math_ops.conj(beta) * beta)
1354          beta_abs = math_ops.sqrt(beta_sq)
1355        else:
1356          beta_sq = math_ops.square(beta)
1357          beta_abs = math_ops.abs(beta)
1358
1359        # Estimate the largest and smallest eigenvalues of T using the
1360        # Gershgorin circle theorem.
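        # Each eigenvalue lies in at least one disc centered at alpha[i] with
        # radius equal to the sum of the magnitudes of the off-diagonal
        # entries in row i, so the extreme disc endpoints bound the spectrum.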
1361        finfo = np.finfo(alpha.dtype.as_numpy_dtype)
1362        off_diag_abs_row_sum = array_ops.concat(
1363            [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
1364        lambda_est_max = math_ops.minimum(
1365            finfo.max, math_ops.reduce_max(alpha + off_diag_abs_row_sum))
1366        lambda_est_min = math_ops.maximum(
1367            finfo.min, math_ops.reduce_min(alpha - off_diag_abs_row_sum))
1368        # Upper bound on 2-norm of T.
1369        t_norm = math_ops.maximum(
1370            math_ops.abs(lambda_est_min), math_ops.abs(lambda_est_max))
1371
1372        # Compute the smallest allowed pivot in the Sturm sequence to avoid
1373        # overflow.
1374        one = np.ones([], dtype=alpha.dtype.as_numpy_dtype)
1375        safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
1376        pivmin = safemin * math_ops.maximum(one, math_ops.reduce_max(beta_sq))
1377        alpha0_perturbation = math_ops.square(finfo.eps * beta_abs[0])
1378        abs_tol = finfo.eps * t_norm
1379        if tol:
1380          abs_tol = math_ops.maximum(tol, abs_tol)
1381        # In the worst case, when the absolute tolerance is eps*lambda_est_max
1382        # and lambda_est_max = -lambda_est_min, we have to take as many
1383        # bisection steps as there are bits in the mantissa plus 1.
1384        max_it = finfo.nmant + 1
1385
1386        # Determine the indices of the desired eigenvalues, based on select
1387        # and select_range.
1388        asserts = None
1389        if select == 'a':
1390          target_counts = math_ops.range(n)
1391        elif select == 'i':
1392          asserts = check_ops.assert_less_equal(
1393              select_range[0],
1394              select_range[1],
1395              message='Got empty index range in select_range.')
1396          target_counts = math_ops.range(select_range[0], select_range[1] + 1)
1397        elif select == 'v':
1398          asserts = check_ops.assert_less(
1399              select_range[0],
1400              select_range[1],
1401              message='Got empty interval in select_range.')
1402        else:
1403          raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")
1404
1405        if asserts:
1406          with ops.control_dependencies([asserts]):
1407            alpha = array_ops.identity(alpha)
1408
1409        # Run binary search for all desired eigenvalues in parallel, starting
1410        # from  an interval slightly wider than the estimated
1411        # [lambda_est_min, lambda_est_max].
1412        fudge = 2.1  # We widen the Gershgorin starting interval a bit.
1413        norm_slack = math_ops.cast(n, alpha.dtype) * fudge * finfo.eps * t_norm
1414        if select in {'a', 'i'}:
1415          lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
1416          upper = lambda_est_max + norm_slack + fudge * pivmin
1417        else:
1418          # Count the number of eigenvalues in the given range.
1419          lower = select_range[0] - norm_slack - 2 * fudge * pivmin
1420          upper = select_range[1] + norm_slack + fudge * pivmin
1421          first = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, lower)
1422          last = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, upper)
1423          target_counts = math_ops.range(first, last)
1424
1425        # Pre-broadcast the scalars used in the Sturm sequence for improved
1426        # performance.
1427        upper = math_ops.minimum(upper, finfo.max)
1428        lower = math_ops.maximum(lower, finfo.min)
1429        target_shape = array_ops.shape(target_counts)
1430        lower = array_ops.broadcast_to(lower, shape=target_shape)
1431        upper = array_ops.broadcast_to(upper, shape=target_shape)
1432        pivmin = array_ops.broadcast_to(pivmin, target_shape)
1433        alpha0_perturbation = array_ops.broadcast_to(alpha0_perturbation,
1434                                                     target_shape)
1435
1436        # We compute the midpoint as 0.5*lower + 0.5*upper to avoid overflow in
1437        # (lower + upper) or (upper - lower) when the matrix has eigenvalues
1438        # with magnitude greater than finfo.max / 2.
1439        def midpoint(lower, upper):
1440          return (0.5 * lower) + (0.5 * upper)
1441
1442        def continue_binary_search(i, lower, upper):
1443          return math_ops.logical_and(
1444              math_ops.less(i, max_it),
1445              math_ops.less(abs_tol, math_ops.reduce_max(upper - lower)))
1446
1447        def binary_search_step(i, lower, upper):
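          # Invariant: eigenvalue number target_counts[j] lies in
          # [lower[j], upper[j]]. If the Sturm count at the midpoint does not
          # exceed the target index, the eigenvalue is in the upper half of
          # the interval; otherwise it is in the lower half.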
1448          mid = midpoint(lower, upper)
1449          counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
1450          lower = array_ops.where(counts <= target_counts, mid, lower)
1451          upper = array_ops.where(counts > target_counts, mid, upper)
1452          return i + 1, lower, upper
1453
1454        # Start parallel binary searches.
1455        _, lower, upper = control_flow_ops.while_loop(continue_binary_search,
1456                                                      binary_search_step,
1457                                                      [0, lower, upper])
1458        return midpoint(lower, upper)
1459
1460    def _compute_eigenvectors(alpha, beta, eigvals):
1461      """Implements inverse iteration to compute eigenvectors."""
1462      with ops.name_scope('compute_eigenvectors'):
1463        k = array_ops.size(eigvals)
1464        n = array_ops.size(alpha)
1465        alpha = math_ops.cast(alpha, dtype=beta.dtype)
1466
1467        # Eigenvectors corresponding to a cluster of close eigenvalues are
1468        # not unique and need to be explicitly orthogonalized. Here we
1469        # identify such clusters. Note: This function assumes that
1470        # eigenvalues are sorted in non-decreasing order.
1471        gap = eigvals[1:] - eigvals[:-1]
1472        eps = np.finfo(eigvals.dtype.as_numpy_dtype).eps
1473        t_norm = math_ops.maximum(
1474            math_ops.abs(eigvals[0]), math_ops.abs(eigvals[-1]))
1475        gaptol = np.sqrt(eps) * t_norm
1476        # Find the beginning and end of runs of eigenvectors corresponding
1477        # to eigenvalues closer than "gaptol", which will need to be
1478        # orthogonalized against each other.
1479        close = math_ops.less(gap, gaptol)
1480        left_neighbor_close = array_ops.concat([[False], close], axis=0)
1481        right_neighbor_close = array_ops.concat([close, [False]], axis=0)
1482        ortho_interval_start = math_ops.logical_and(
1483            math_ops.logical_not(left_neighbor_close), right_neighbor_close)
1484        ortho_interval_start = array_ops.squeeze(
1485            array_ops.where_v2(ortho_interval_start), axis=-1)
1486        ortho_interval_end = math_ops.logical_and(
1487            left_neighbor_close, math_ops.logical_not(right_neighbor_close))
1488        ortho_interval_end = array_ops.squeeze(
1489            array_ops.where_v2(ortho_interval_end), axis=-1) + 1
1490        num_clusters = array_ops.size(ortho_interval_end)
1491
1492        # We perform inverse iteration for all eigenvectors in parallel,
1493        # starting from a random set of vectors, until all have converged.
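        # Each solve with the shifted matrix T - eigval*I amplifies the
        # component of the iterate along the eigenvector whose eigenvalue is
        # closest to the shift, so the normalized iterates converge to that
        # eigenvector.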
1494        v0 = math_ops.cast(
1495            stateless_random_ops.stateless_random_normal(
1496                shape=(k, n), seed=[7, 42]),
1497            dtype=beta.dtype)
1498        nrm_v = norm(v0, axis=1)
1499        v0 = v0 / nrm_v[:, array_ops.newaxis]
1500        zero_nrm = constant_op.constant(0, shape=nrm_v.shape, dtype=nrm_v.dtype)
1501
1502        # Replicate alpha - eigvals(i) and beta across the k eigenvectors so we
1503        # can solve the k systems
1504        #    [T - eigvals(i)*eye(n)] x_i = r_i
1505        # simultaneously using the batching mechanism.
1506        eigvals_cast = math_ops.cast(eigvals, dtype=beta.dtype)
1507        alpha_shifted = (
1508            alpha[array_ops.newaxis, :] - eigvals_cast[:, array_ops.newaxis])
1509        beta = array_ops.tile(beta[array_ops.newaxis, :], [k, 1])
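        # tridiagonal_solve's 'sequence' format takes the superdiagonal, main
        # diagonal and subdiagonal in that order; using conj(beta) for the
        # subdiagonal keeps each shifted system Hermitian.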
1510        diags = [beta, alpha_shifted, math_ops.conj(beta)]
1511
1512        def orthogonalize_close_eigenvectors(eigenvectors):
1513          # Eigenvectors corresponding to a cluster of close eigenvalues are not
1514          # uniquely defined, but the subspace they span is. To avoid numerical
1515          # instability, we explicitly mutually orthogonalize such eigenvectors
1516          # after each step of inverse iteration. It is customary to use
1517          # modified Gram-Schmidt for this, but this is not very efficient
1518          # on some platforms, so here we defer to the QR decomposition in
1519          # TensorFlow.
1520          def orthogonalize_cluster(cluster_idx, eigenvectors):
1521            start = ortho_interval_start[cluster_idx]
1522            end = ortho_interval_end[cluster_idx]
1523            update_indices = array_ops.expand_dims(
1524                math_ops.range(start, end), -1)
1525            vectors_in_cluster = eigenvectors[start:end, :]
1526            # We use the builtin QR factorization to orthonormalize the
1527            # vectors in the cluster.
1528            q, _ = qr(transpose(vectors_in_cluster))
1529            vectors_to_update = transpose(q)
1530            eigenvectors = array_ops.tensor_scatter_nd_update(
1531                eigenvectors, update_indices, vectors_to_update)
1532            return cluster_idx + 1, eigenvectors
1533
1534          _, eigenvectors = control_flow_ops.while_loop(
1535              lambda i, ev: math_ops.less(i, num_clusters),
1536              orthogonalize_cluster, [0, eigenvectors])
1537          return eigenvectors
1538
1539        def continue_iteration(i, _, nrm_v, nrm_v_old):
1540          max_it = 5  # Taken from LAPACK xSTEIN.
1541          min_norm_growth = 0.1
1542          norm_growth_factor = constant_op.constant(
1543              1 + min_norm_growth, dtype=nrm_v.dtype)
1544          # We stop the inverse iteration when we reach the maximum number of
1545          # iterations or the norm growth is less than 10%.
1546          return math_ops.logical_and(
1547              math_ops.less(i, max_it),
1548              math_ops.reduce_any(
1549                  math_ops.greater_equal(
1550                      math_ops.real(nrm_v),
1551                      math_ops.real(norm_growth_factor * nrm_v_old))))
1552
1553        def inverse_iteration_step(i, v, nrm_v, nrm_v_old):
1554          v = tridiagonal_solve(
1555              diags,
1556              v,
1557              diagonals_format='sequence',
1558              partial_pivoting=True,
1559              perturb_singular=True)
1560          nrm_v_old = nrm_v
1561          nrm_v = norm(v, axis=1)
1562          v = v / nrm_v[:, array_ops.newaxis]
1563          v = orthogonalize_close_eigenvectors(v)
1564          return i + 1, v, nrm_v, nrm_v_old
1565
1566        _, v, nrm_v, _ = control_flow_ops.while_loop(continue_iteration,
1567                                                     inverse_iteration_step,
1568                                                     [0, v0, nrm_v, zero_nrm])
1569        return transpose(v)
1570
1571    alpha = ops.convert_to_tensor(alpha, name='alpha')
1572    n = alpha.shape[0]
1573    if n <= 1:
1574      return math_ops.real(alpha)
1575    beta = ops.convert_to_tensor(beta, name='beta')
1576
1577    if alpha.dtype != beta.dtype:
1578      raise ValueError("'alpha' and 'beta' must have the same type.")
1579
1580    eigvals = _compute_eigenvalues(alpha, beta)
1581    if eigvals_only:
1582      return eigvals
1583
1584    eigvectors = _compute_eigenvectors(alpha, beta, eigvals)
1585    return eigvals, eigvectors
1586