1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Operations for linear algebra.""" 16 17import numpy as np 18 19from tensorflow.python.framework import constant_op 20from tensorflow.python.framework import dtypes 21from tensorflow.python.framework import ops 22from tensorflow.python.framework import tensor_shape 23from tensorflow.python.ops import array_ops 24from tensorflow.python.ops import check_ops 25from tensorflow.python.ops import control_flow_ops 26from tensorflow.python.ops import gen_linalg_ops 27from tensorflow.python.ops import linalg_ops 28from tensorflow.python.ops import map_fn 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops import special_math_ops 31from tensorflow.python.ops import stateless_random_ops 32from tensorflow.python.util import dispatch 33from tensorflow.python.util.tf_export import tf_export 34 35# Linear algebra ops. 36band_part = array_ops.matrix_band_part 37cholesky = linalg_ops.cholesky 38cholesky_solve = linalg_ops.cholesky_solve 39det = linalg_ops.matrix_determinant 40slogdet = gen_linalg_ops.log_matrix_determinant 41tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet)) 42diag = array_ops.matrix_diag 43diag_part = array_ops.matrix_diag_part 44eigh = linalg_ops.self_adjoint_eig 45eigvalsh = linalg_ops.self_adjoint_eigvals 46einsum = special_math_ops.einsum 47eye = linalg_ops.eye 48inv = linalg_ops.matrix_inverse 49logm = gen_linalg_ops.matrix_logarithm 50lu = gen_linalg_ops.lu 51tf_export('linalg.logm')(dispatch.add_dispatch_support(logm)) 52lstsq = linalg_ops.matrix_solve_ls 53norm = linalg_ops.norm 54qr = linalg_ops.qr 55set_diag = array_ops.matrix_set_diag 56solve = linalg_ops.matrix_solve 57sqrtm = linalg_ops.matrix_square_root 58svd = linalg_ops.svd 59tensordot = math_ops.tensordot 60trace = math_ops.trace 61transpose = array_ops.matrix_transpose 62triangular_solve = linalg_ops.matrix_triangular_solve 63 64 65@tf_export('linalg.logdet') 66@dispatch.add_dispatch_support 67def logdet(matrix, name=None): 68 """Computes log of the determinant of a hermitian positive definite matrix. 69 70 ```python 71 # Compute the determinant of a matrix while reducing the chance of over- or 72 underflow: 73 A = ... # shape 10 x 10 74 det = tf.exp(tf.linalg.logdet(A)) # scalar 75 ``` 76 77 Args: 78 matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, 79 or `complex128` with shape `[..., M, M]`. 80 name: A name to give this `Op`. Defaults to `logdet`. 81 82 Returns: 83 The natural log of the determinant of `matrix`. 84 85 @compatibility(numpy) 86 Equivalent to numpy.linalg.slogdet, although no sign is returned since only 87 hermitian positive definite matrices are supported. 88 @end_compatibility 89 """ 90 # This uses the property that the log det(A) = 2*sum(log(real(diag(C)))) 91 # where C is the cholesky decomposition of A. 92 with ops.name_scope(name, 'logdet', [matrix]): 93 chol = gen_linalg_ops.cholesky(matrix) 94 return 2.0 * math_ops.reduce_sum( 95 math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))), 96 axis=[-1]) 97 98 99@tf_export('linalg.adjoint') 100@dispatch.add_dispatch_support 101def adjoint(matrix, name=None): 102 """Transposes the last two dimensions of and conjugates tensor `matrix`. 103 104 For example: 105 106 ```python 107 x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j], 108 [4 + 4j, 5 + 5j, 6 + 6j]]) 109 tf.linalg.adjoint(x) # [[1 - 1j, 4 - 4j], 110 # [2 - 2j, 5 - 5j], 111 # [3 - 3j, 6 - 6j]] 112 ``` 113 114 Args: 115 matrix: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, 116 or `complex128` with shape `[..., M, M]`. 117 name: A name to give this `Op` (optional). 118 119 Returns: 120 The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of 121 matrix. 122 """ 123 with ops.name_scope(name, 'adjoint', [matrix]): 124 matrix = ops.convert_to_tensor(matrix, name='matrix') 125 return array_ops.matrix_transpose(matrix, conjugate=True) 126 127 128# This section is ported nearly verbatim from Eigen's implementation: 129# https://eigen.tuxfamily.org/dox/unsupported/MatrixExponential_8h_source.html 130def _matrix_exp_pade3(matrix): 131 """3rd-order Pade approximant for matrix exponential.""" 132 b = [120.0, 60.0, 12.0] 133 b = [constant_op.constant(x, matrix.dtype) for x in b] 134 ident = linalg_ops.eye( 135 array_ops.shape(matrix)[-2], 136 batch_shape=array_ops.shape(matrix)[:-2], 137 dtype=matrix.dtype) 138 matrix_2 = math_ops.matmul(matrix, matrix) 139 tmp = matrix_2 + b[1] * ident 140 matrix_u = math_ops.matmul(matrix, tmp) 141 matrix_v = b[2] * matrix_2 + b[0] * ident 142 return matrix_u, matrix_v 143 144 145def _matrix_exp_pade5(matrix): 146 """5th-order Pade approximant for matrix exponential.""" 147 b = [30240.0, 15120.0, 3360.0, 420.0, 30.0] 148 b = [constant_op.constant(x, matrix.dtype) for x in b] 149 ident = linalg_ops.eye( 150 array_ops.shape(matrix)[-2], 151 batch_shape=array_ops.shape(matrix)[:-2], 152 dtype=matrix.dtype) 153 matrix_2 = math_ops.matmul(matrix, matrix) 154 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 155 tmp = matrix_4 + b[3] * matrix_2 + b[1] * ident 156 matrix_u = math_ops.matmul(matrix, tmp) 157 matrix_v = b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident 158 return matrix_u, matrix_v 159 160 161def _matrix_exp_pade7(matrix): 162 """7th-order Pade approximant for matrix exponential.""" 163 b = [17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0] 164 b = [constant_op.constant(x, matrix.dtype) for x in b] 165 ident = linalg_ops.eye( 166 array_ops.shape(matrix)[-2], 167 batch_shape=array_ops.shape(matrix)[:-2], 168 dtype=matrix.dtype) 169 matrix_2 = math_ops.matmul(matrix, matrix) 170 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 171 matrix_6 = math_ops.matmul(matrix_4, matrix_2) 172 tmp = matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident 173 matrix_u = math_ops.matmul(matrix, tmp) 174 matrix_v = b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + b[0] * ident 175 return matrix_u, matrix_v 176 177 178def _matrix_exp_pade9(matrix): 179 """9th-order Pade approximant for matrix exponential.""" 180 b = [ 181 17643225600.0, 8821612800.0, 2075673600.0, 302702400.0, 30270240.0, 182 2162160.0, 110880.0, 3960.0, 90.0 183 ] 184 b = [constant_op.constant(x, matrix.dtype) for x in b] 185 ident = linalg_ops.eye( 186 array_ops.shape(matrix)[-2], 187 batch_shape=array_ops.shape(matrix)[:-2], 188 dtype=matrix.dtype) 189 matrix_2 = math_ops.matmul(matrix, matrix) 190 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 191 matrix_6 = math_ops.matmul(matrix_4, matrix_2) 192 matrix_8 = math_ops.matmul(matrix_6, matrix_2) 193 tmp = ( 194 matrix_8 + b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + 195 b[1] * ident) 196 matrix_u = math_ops.matmul(matrix, tmp) 197 matrix_v = ( 198 b[8] * matrix_8 + b[6] * matrix_6 + b[4] * matrix_4 + b[2] * matrix_2 + 199 b[0] * ident) 200 return matrix_u, matrix_v 201 202 203def _matrix_exp_pade13(matrix): 204 """13th-order Pade approximant for matrix exponential.""" 205 b = [ 206 64764752532480000.0, 32382376266240000.0, 7771770303897600.0, 207 1187353796428800.0, 129060195264000.0, 10559470521600.0, 670442572800.0, 208 33522128640.0, 1323241920.0, 40840800.0, 960960.0, 16380.0, 182.0 209 ] 210 b = [constant_op.constant(x, matrix.dtype) for x in b] 211 ident = linalg_ops.eye( 212 array_ops.shape(matrix)[-2], 213 batch_shape=array_ops.shape(matrix)[:-2], 214 dtype=matrix.dtype) 215 matrix_2 = math_ops.matmul(matrix, matrix) 216 matrix_4 = math_ops.matmul(matrix_2, matrix_2) 217 matrix_6 = math_ops.matmul(matrix_4, matrix_2) 218 tmp_u = ( 219 math_ops.matmul(matrix_6, matrix_6 + b[11] * matrix_4 + b[9] * matrix_2) + 220 b[7] * matrix_6 + b[5] * matrix_4 + b[3] * matrix_2 + b[1] * ident) 221 matrix_u = math_ops.matmul(matrix, tmp_u) 222 tmp_v = b[12] * matrix_6 + b[10] * matrix_4 + b[8] * matrix_2 223 matrix_v = ( 224 math_ops.matmul(matrix_6, tmp_v) + b[6] * matrix_6 + b[4] * matrix_4 + 225 b[2] * matrix_2 + b[0] * ident) 226 return matrix_u, matrix_v 227 228 229@tf_export('linalg.expm') 230@dispatch.add_dispatch_support 231def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin 232 r"""Computes the matrix exponential of one or more square matrices. 233 234 $$exp(A) = \sum_{n=0}^\infty A^n/n!$$ 235 236 The exponential is computed using a combination of the scaling and squaring 237 method and the Pade approximation. Details can be found in: 238 Nicholas J. Higham, "The scaling and squaring method for the matrix 239 exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005. 240 241 The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions 242 form square matrices. The output is a tensor of the same shape as the input 243 containing the exponential for all input submatrices `[..., :, :]`. 244 245 Args: 246 input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or 247 `complex128` with shape `[..., M, M]`. 248 name: A name to give this `Op` (optional). 249 250 Returns: 251 the matrix exponential of the input. 252 253 Raises: 254 ValueError: An unsupported type is provided as input. 255 256 @compatibility(scipy) 257 Equivalent to scipy.linalg.expm 258 @end_compatibility 259 """ 260 with ops.name_scope(name, 'matrix_exponential', [input]): 261 matrix = ops.convert_to_tensor(input, name='input') 262 if matrix.shape[-2:] == [0, 0]: 263 return matrix 264 batch_shape = matrix.shape[:-2] 265 if not batch_shape.is_fully_defined(): 266 batch_shape = array_ops.shape(matrix)[:-2] 267 268 # reshaping the batch makes the where statements work better 269 matrix = array_ops.reshape( 270 matrix, array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0)) 271 l1_norm = math_ops.reduce_max( 272 math_ops.reduce_sum( 273 math_ops.abs(matrix), 274 axis=array_ops.size(array_ops.shape(matrix)) - 2), 275 axis=-1)[..., array_ops.newaxis, array_ops.newaxis] 276 277 const = lambda x: constant_op.constant(x, l1_norm.dtype) 278 279 def _nest_where(vals, cases): 280 assert len(vals) == len(cases) - 1 281 if len(vals) == 1: 282 return array_ops.where_v2( 283 math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1]) 284 else: 285 return array_ops.where_v2( 286 math_ops.less(l1_norm, const(vals[0])), cases[0], 287 _nest_where(vals[1:], cases[1:])) 288 289 if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]: 290 maxnorm = const(3.925724783138660) 291 squarings = math_ops.maximum( 292 math_ops.floor( 293 math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) 294 u3, v3 = _matrix_exp_pade3(matrix) 295 u5, v5 = _matrix_exp_pade5(matrix) 296 u7, v7 = _matrix_exp_pade7( 297 matrix / 298 math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype)) 299 conds = (4.258730016922831e-001, 1.880152677804762e+000) 300 u = _nest_where(conds, (u3, u5, u7)) 301 v = _nest_where(conds, (v3, v5, v7)) 302 elif matrix.dtype in [dtypes.float64, dtypes.complex128]: 303 maxnorm = const(5.371920351148152) 304 squarings = math_ops.maximum( 305 math_ops.floor( 306 math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) 307 u3, v3 = _matrix_exp_pade3(matrix) 308 u5, v5 = _matrix_exp_pade5(matrix) 309 u7, v7 = _matrix_exp_pade7(matrix) 310 u9, v9 = _matrix_exp_pade9(matrix) 311 u13, v13 = _matrix_exp_pade13( 312 matrix / 313 math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype)) 314 conds = (1.495585217958292e-002, 2.539398330063230e-001, 315 9.504178996162932e-001, 2.097847961257068e+000) 316 u = _nest_where(conds, (u3, u5, u7, u9, u13)) 317 v = _nest_where(conds, (v3, v5, v7, v9, v13)) 318 else: 319 raise ValueError('tf.linalg.expm does not support matrices of type %s' % 320 matrix.dtype) 321 322 is_finite = math_ops.is_finite(math_ops.reduce_max(l1_norm)) 323 nan = constant_op.constant(np.nan, matrix.dtype) 324 result = control_flow_ops.cond( 325 is_finite, lambda: linalg_ops.matrix_solve(-u + v, u + v), 326 lambda: array_ops.fill(array_ops.shape(matrix), nan)) 327 max_squarings = math_ops.reduce_max(squarings) 328 i = const(0.0) 329 330 def c(i, _): 331 return control_flow_ops.cond(is_finite, 332 lambda: math_ops.less(i, max_squarings), 333 lambda: constant_op.constant(False)) 334 335 def b(i, r): 336 return i + 1, array_ops.where_v2( 337 math_ops.less(i, squarings), math_ops.matmul(r, r), r) 338 339 _, result = control_flow_ops.while_loop(c, b, [i, result]) 340 if not matrix.shape.is_fully_defined(): 341 return array_ops.reshape( 342 result, 343 array_ops.concat((batch_shape, array_ops.shape(result)[-2:]), axis=0)) 344 return array_ops.reshape(result, batch_shape.concatenate(result.shape[-2:])) 345 346 347@tf_export('linalg.banded_triangular_solve', v1=[]) 348def banded_triangular_solve( 349 bands, 350 rhs, 351 lower=True, 352 adjoint=False, # pylint: disable=redefined-outer-name 353 name=None): 354 r"""Solve triangular systems of equations with a banded solver. 355 356 `bands` is a tensor of shape `[..., K, M]`, where `K` represents the number 357 of bands stored. This corresponds to a batch of `M` by `M` matrices, whose 358 `K` subdiagonals (when `lower` is `True`) are stored. 359 360 This operator broadcasts the batch dimensions of `bands` and the batch 361 dimensions of `rhs`. 362 363 364 Examples: 365 366 Storing 2 bands of a 3x3 matrix. 367 Note that first element in the second row is ignored due to 368 the 'LEFT_RIGHT' padding. 369 370 >>> x = [[2., 3., 4.], [1., 2., 3.]] 371 >>> x2 = [[2., 3., 4.], [10000., 2., 3.]] 372 >>> y = tf.zeros([3, 3]) 373 >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(-1, 0)) 374 >>> z 375 <tf.Tensor: shape=(3, 3), dtype=float32, numpy= 376 array([[2., 0., 0.], 377 [2., 3., 0.], 378 [0., 3., 4.]], dtype=float32)> 379 >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([3, 1])) 380 >>> soln 381 <tf.Tensor: shape=(3, 1), dtype=float32, numpy= 382 array([[0.5 ], 383 [0. ], 384 [0.25]], dtype=float32)> 385 >>> are_equal = soln == tf.linalg.banded_triangular_solve(x2, tf.ones([3, 1])) 386 >>> tf.reduce_all(are_equal).numpy() 387 True 388 >>> are_equal = soln == tf.linalg.triangular_solve(z, tf.ones([3, 1])) 389 >>> tf.reduce_all(are_equal).numpy() 390 True 391 392 Storing 2 superdiagonals of a 4x4 matrix. Because of the 'LEFT_RIGHT' padding 393 the last element of the first row is ignored. 394 395 >>> x = [[2., 3., 4., 5.], [-1., -2., -3., -4.]] 396 >>> y = tf.zeros([4, 4]) 397 >>> z = tf.linalg.set_diag(y, x, align='LEFT_RIGHT', k=(0, 1)) 398 >>> z 399 <tf.Tensor: shape=(4, 4), dtype=float32, numpy= 400 array([[-1., 2., 0., 0.], 401 [ 0., -2., 3., 0.], 402 [ 0., 0., -3., 4.], 403 [ 0., 0., -0., -4.]], dtype=float32)> 404 >>> soln = tf.linalg.banded_triangular_solve(x, tf.ones([4, 1]), lower=False) 405 >>> soln 406 <tf.Tensor: shape=(4, 1), dtype=float32, numpy= 407 array([[-4. ], 408 [-1.5 ], 409 [-0.6666667], 410 [-0.25 ]], dtype=float32)> 411 >>> are_equal = (soln == tf.linalg.triangular_solve( 412 ... z, tf.ones([4, 1]), lower=False)) 413 >>> tf.reduce_all(are_equal).numpy() 414 True 415 416 417 Args: 418 bands: A `Tensor` describing the bands of the left hand side, with shape 419 `[..., K, M]`. The `K` rows correspond to the diagonal to the `K - 1`-th 420 diagonal (the diagonal is the top row) when `lower` is `True` and 421 otherwise the `K - 1`-th superdiagonal to the diagonal (the diagonal is 422 the bottom row) when `lower` is `False`. The bands are stored with 423 'LEFT_RIGHT' alignment, where the superdiagonals are padded on the right 424 and subdiagonals are padded on the left. This is the alignment cuSPARSE 425 uses. See `tf.linalg.set_diag` for more details. 426 rhs: A `Tensor` of shape [..., M] or [..., M, N] and with the same dtype as 427 `diagonals`. Note that if the shape of `rhs` and/or `diags` isn't known 428 statically, `rhs` will be treated as a matrix rather than a vector. 429 lower: An optional `bool`. Defaults to `True`. Boolean indicating whether 430 `bands` represents a lower or upper triangular matrix. 431 adjoint: An optional `bool`. Defaults to `False`. Boolean indicating whether 432 to solve with the matrix's block-wise adjoint. 433 name: A name to give this `Op` (optional). 434 435 Returns: 436 A `Tensor` of shape [..., M] or [..., M, N] containing the solutions. 437 """ 438 with ops.name_scope(name, 'banded_triangular_solve', [bands, rhs]): 439 return gen_linalg_ops.banded_triangular_solve( 440 bands, rhs, lower=lower, adjoint=adjoint) 441 442 443@tf_export('linalg.tridiagonal_solve') 444@dispatch.add_dispatch_support 445def tridiagonal_solve(diagonals, 446 rhs, 447 diagonals_format='compact', 448 transpose_rhs=False, 449 conjugate_rhs=False, 450 name=None, 451 partial_pivoting=True, 452 perturb_singular=False): 453 r"""Solves tridiagonal systems of equations. 454 455 The input can be supplied in various formats: `matrix`, `sequence` and 456 `compact`, specified by the `diagonals_format` arg. 457 458 In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with 459 two inner-most dimensions representing the square tridiagonal matrices. 460 Elements outside of the three diagonals will be ignored. 461 462 In `sequence` format, `diagonals` are supplied as a tuple or list of three 463 tensors of shapes `[..., N]`, `[..., M]`, `[..., N]` representing 464 superdiagonals, diagonals, and subdiagonals, respectively. `N` can be either 465 `M-1` or `M`; in the latter case, the last element of superdiagonal and the 466 first element of subdiagonal will be ignored. 467 468 In `compact` format the three diagonals are brought together into one tensor 469 of shape `[..., 3, M]`, with last two dimensions containing superdiagonals, 470 diagonals, and subdiagonals, in order. Similarly to `sequence` format, 471 elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored. 472 473 The `compact` format is recommended as the one with best performance. In case 474 you need to cast a tensor into a compact format manually, use `tf.gather_nd`. 475 An example for a tensor of shape [m, m]: 476 477 ```python 478 rhs = tf.constant([...]) 479 matrix = tf.constant([[...]]) 480 m = matrix.shape[0] 481 dummy_idx = [0, 0] # An arbitrary element to use as a dummy 482 indices = [[[i, i + 1] for i in range(m - 1)] + [dummy_idx], # Superdiagonal 483 [[i, i] for i in range(m)], # Diagonal 484 [dummy_idx] + [[i + 1, i] for i in range(m - 1)]] # Subdiagonal 485 diagonals=tf.gather_nd(matrix, indices) 486 x = tf.linalg.tridiagonal_solve(diagonals, rhs) 487 ``` 488 489 Regardless of the `diagonals_format`, `rhs` is a tensor of shape `[..., M]` or 490 `[..., M, K]`. The latter allows to simultaneously solve K systems with the 491 same left-hand sides and K different right-hand sides. If `transpose_rhs` 492 is set to `True` the expected shape is `[..., M]` or `[..., K, M]`. 493 494 The batch dimensions, denoted as `...`, must be the same in `diagonals` and 495 `rhs`. 496 497 The output is a tensor of the same shape as `rhs`: either `[..., M]` or 498 `[..., M, K]`. 499 500 The op isn't guaranteed to raise an error if the input matrix is not 501 invertible. `tf.debugging.check_numerics` can be applied to the output to 502 detect invertibility problems. 503 504 **Note**: with large batch sizes, the computation on the GPU may be slow, if 505 either `partial_pivoting=True` or there are multiple right-hand sides 506 (`K > 1`). If this issue arises, consider if it's possible to disable pivoting 507 and have `K = 1`, or, alternatively, consider using CPU. 508 509 On CPU, solution is computed via Gaussian elimination with or without partial 510 pivoting, depending on `partial_pivoting` parameter. On GPU, Nvidia's cuSPARSE 511 library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv 512 513 Args: 514 diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The 515 shape depends of `diagonals_format`, see description above. Must be 516 `float32`, `float64`, `complex64`, or `complex128`. 517 rhs: A `Tensor` of shape [..., M] or [..., M, K] and with the same dtype as 518 `diagonals`. Note that if the shape of `rhs` and/or `diags` isn't known 519 statically, `rhs` will be treated as a matrix rather than a vector. 520 diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is 521 `compact`. 522 transpose_rhs: If `True`, `rhs` is transposed before solving (has no effect 523 if the shape of rhs is [..., M]). 524 conjugate_rhs: If `True`, `rhs` is conjugated before solving. 525 name: A name to give this `Op` (optional). 526 partial_pivoting: whether to perform partial pivoting. `True` by default. 527 Partial pivoting makes the procedure more stable, but slower. Partial 528 pivoting is unnecessary in some cases, including diagonally dominant and 529 symmetric positive definite matrices (see e.g. theorem 9.12 in [1]). 530 perturb_singular: whether to perturb singular matrices to return a finite 531 result. `False` by default. If true, solutions to systems involving 532 a singular matrix will be computed by perturbing near-zero pivots in 533 the partially pivoted LU decomposition. Specifically, tiny pivots are 534 perturbed by an amount of order `eps * max_{ij} |U(i,j)|` to avoid 535 overflow. Here `U` is the upper triangular part of the LU decomposition, 536 and `eps` is the machine precision. This is useful for solving 537 numerically singular systems when computing eigenvectors by inverse 538 iteration. 539 If `partial_pivoting` is `False`, `perturb_singular` must be `False` as 540 well. 541 542 Returns: 543 A `Tensor` of shape [..., M] or [..., M, K] containing the solutions. 544 If the input matrix is singular, the result is undefined. 545 546 Raises: 547 ValueError: Is raised if any of the following conditions hold: 548 1. An unsupported type is provided as input, 549 2. the input tensors have incorrect shapes, 550 3. `perturb_singular` is `True` but `partial_pivoting` is not. 551 UnimplementedError: Whenever `partial_pivoting` is true and the backend is 552 XLA, or whenever `perturb_singular` is true and the backend is 553 XLA or GPU. 554 555 [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical Algorithms: 556 Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7. 557 558 """ 559 if perturb_singular and not partial_pivoting: 560 raise ValueError('partial_pivoting must be True if perturb_singular is.') 561 562 if diagonals_format == 'compact': 563 return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 564 conjugate_rhs, partial_pivoting, 565 perturb_singular, name) 566 567 if diagonals_format == 'sequence': 568 if not isinstance(diagonals, (tuple, list)) or len(diagonals) != 3: 569 raise ValueError('Expected diagonals to be a sequence of length 3.') 570 571 superdiag, maindiag, subdiag = diagonals 572 if (not subdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1]) or 573 not superdiag.shape[:-1].is_compatible_with(maindiag.shape[:-1])): 574 raise ValueError( 575 'Tensors representing the three diagonals must have the same shape,' 576 'except for the last dimension, got {}, {}, {}'.format( 577 subdiag.shape, maindiag.shape, superdiag.shape)) 578 579 m = tensor_shape.dimension_value(maindiag.shape[-1]) 580 581 def pad_if_necessary(t, name, last_dim_padding): 582 n = tensor_shape.dimension_value(t.shape[-1]) 583 if not n or n == m: 584 return t 585 if n == m - 1: 586 paddings = ([[0, 0] for _ in range(len(t.shape) - 1)] + 587 [last_dim_padding]) 588 return array_ops.pad(t, paddings) 589 raise ValueError('Expected {} to be have length {} or {}, got {}.'.format( 590 name, m, m - 1, n)) 591 592 subdiag = pad_if_necessary(subdiag, 'subdiagonal', [1, 0]) 593 superdiag = pad_if_necessary(superdiag, 'superdiagonal', [0, 1]) 594 595 diagonals = array_ops.stack((superdiag, maindiag, subdiag), axis=-2) 596 return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 597 conjugate_rhs, partial_pivoting, 598 perturb_singular, name) 599 600 if diagonals_format == 'matrix': 601 m1 = tensor_shape.dimension_value(diagonals.shape[-1]) 602 m2 = tensor_shape.dimension_value(diagonals.shape[-2]) 603 if m1 and m2 and m1 != m2: 604 raise ValueError( 605 'Expected last two dimensions of diagonals to be same, got {} and {}' 606 .format(m1, m2)) 607 m = m1 or m2 608 diagonals = array_ops.matrix_diag_part( 609 diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT') 610 return _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 611 conjugate_rhs, partial_pivoting, 612 perturb_singular, name) 613 614 raise ValueError('Unrecognized diagonals_format: {}'.format(diagonals_format)) 615 616 617def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, 618 conjugate_rhs, partial_pivoting, 619 perturb_singular, name): 620 """Helper function used after the input has been cast to compact form.""" 621 diags_rank, rhs_rank = diagonals.shape.rank, rhs.shape.rank 622 623 # If we know the rank of the diagonal tensor, do some static checking. 624 if diags_rank: 625 if diags_rank < 2: 626 raise ValueError( 627 'Expected diagonals to have rank at least 2, got {}'.format( 628 diags_rank)) 629 if rhs_rank and rhs_rank != diags_rank and rhs_rank != diags_rank - 1: 630 raise ValueError('Expected the rank of rhs to be {} or {}, got {}'.format( 631 diags_rank - 1, diags_rank, rhs_rank)) 632 if (rhs_rank and not diagonals.shape[:-2].is_compatible_with( 633 rhs.shape[:diags_rank - 2])): 634 raise ValueError('Batch shapes {} and {} are incompatible'.format( 635 diagonals.shape[:-2], rhs.shape[:diags_rank - 2])) 636 637 if diagonals.shape[-2] and diagonals.shape[-2] != 3: 638 raise ValueError('Expected 3 diagonals got {}'.format(diagonals.shape[-2])) 639 640 def check_num_lhs_matches_num_rhs(): 641 if (diagonals.shape[-1] and rhs.shape[-2] and 642 diagonals.shape[-1] != rhs.shape[-2]): 643 raise ValueError('Expected number of left-hand sided and right-hand ' 644 'sides to be equal, got {} and {}'.format( 645 diagonals.shape[-1], rhs.shape[-2])) 646 647 if rhs_rank and diags_rank and rhs_rank == diags_rank - 1: 648 # Rhs provided as a vector, ignoring transpose_rhs 649 if conjugate_rhs: 650 rhs = math_ops.conj(rhs) 651 rhs = array_ops.expand_dims(rhs, -1) 652 check_num_lhs_matches_num_rhs() 653 return array_ops.squeeze( 654 linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, 655 perturb_singular, name), -1) 656 657 if transpose_rhs: 658 rhs = array_ops.matrix_transpose(rhs, conjugate=conjugate_rhs) 659 elif conjugate_rhs: 660 rhs = math_ops.conj(rhs) 661 662 check_num_lhs_matches_num_rhs() 663 return linalg_ops.tridiagonal_solve(diagonals, rhs, partial_pivoting, 664 perturb_singular, name) 665 666 667@tf_export('linalg.tridiagonal_matmul') 668@dispatch.add_dispatch_support 669def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None): 670 r"""Multiplies tridiagonal matrix by matrix. 671 672 `diagonals` is representation of 3-diagonal NxN matrix, which depends on 673 `diagonals_format`. 674 675 In `matrix` format, `diagonals` must be a tensor of shape `[..., M, M]`, with 676 two inner-most dimensions representing the square tridiagonal matrices. 677 Elements outside of the three diagonals will be ignored. 678 679 If `sequence` format, `diagonals` is list or tuple of three tensors: 680 `[superdiag, maindiag, subdiag]`, each having shape [..., M]. Last element 681 of `superdiag` first element of `subdiag` are ignored. 682 683 In `compact` format the three diagonals are brought together into one tensor 684 of shape `[..., 3, M]`, with last two dimensions containing superdiagonals, 685 diagonals, and subdiagonals, in order. Similarly to `sequence` format, 686 elements `diagonals[..., 0, M-1]` and `diagonals[..., 2, 0]` are ignored. 687 688 The `sequence` format is recommended as the one with the best performance. 689 690 `rhs` is matrix to the right of multiplication. It has shape `[..., M, N]`. 691 692 Example: 693 694 ```python 695 superdiag = tf.constant([-1, -1, 0], dtype=tf.float64) 696 maindiag = tf.constant([2, 2, 2], dtype=tf.float64) 697 subdiag = tf.constant([0, -1, -1], dtype=tf.float64) 698 diagonals = [superdiag, maindiag, subdiag] 699 rhs = tf.constant([[1, 1], [1, 1], [1, 1]], dtype=tf.float64) 700 x = tf.linalg.tridiagonal_matmul(diagonals, rhs, diagonals_format='sequence') 701 ``` 702 703 Args: 704 diagonals: A `Tensor` or tuple of `Tensor`s describing left-hand sides. The 705 shape depends of `diagonals_format`, see description above. Must be 706 `float32`, `float64`, `complex64`, or `complex128`. 707 rhs: A `Tensor` of shape [..., M, N] and with the same dtype as `diagonals`. 708 diagonals_format: one of `sequence`, or `compact`. Default is `compact`. 709 name: A name to give this `Op` (optional). 710 711 Returns: 712 A `Tensor` of shape [..., M, N] containing the result of multiplication. 713 714 Raises: 715 ValueError: An unsupported type is provided as input, or when the input 716 tensors have incorrect shapes. 717 """ 718 if diagonals_format == 'compact': 719 superdiag = diagonals[..., 0, :] 720 maindiag = diagonals[..., 1, :] 721 subdiag = diagonals[..., 2, :] 722 elif diagonals_format == 'sequence': 723 superdiag, maindiag, subdiag = diagonals 724 elif diagonals_format == 'matrix': 725 m1 = tensor_shape.dimension_value(diagonals.shape[-1]) 726 m2 = tensor_shape.dimension_value(diagonals.shape[-2]) 727 if m1 and m2 and m1 != m2: 728 raise ValueError( 729 'Expected last two dimensions of diagonals to be same, got {} and {}' 730 .format(m1, m2)) 731 diags = array_ops.matrix_diag_part( 732 diagonals, k=(-1, 1), padding_value=0., align='LEFT_RIGHT') 733 superdiag = diags[..., 0, :] 734 maindiag = diags[..., 1, :] 735 subdiag = diags[..., 2, :] 736 else: 737 raise ValueError('Unrecognized diagonals_format: %s' % diagonals_format) 738 739 # C++ backend requires matrices. 740 # Converting 1-dimensional vectors to matrices with 1 row. 741 superdiag = array_ops.expand_dims(superdiag, -2) 742 maindiag = array_ops.expand_dims(maindiag, -2) 743 subdiag = array_ops.expand_dims(subdiag, -2) 744 745 return linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs, name) 746 747 748def _maybe_validate_matrix(a, validate_args): 749 """Checks that input is a `float` matrix.""" 750 assertions = [] 751 if not a.dtype.is_floating: 752 raise TypeError('Input `a` must have `float`-like `dtype` ' 753 '(saw {}).'.format(a.dtype.name)) 754 if a.shape is not None and a.shape.rank is not None: 755 if a.shape.rank < 2: 756 raise ValueError('Input `a` must have at least 2 dimensions ' 757 '(saw: {}).'.format(a.shape.rank)) 758 elif validate_args: 759 assertions.append( 760 check_ops.assert_rank_at_least( 761 a, rank=2, message='Input `a` must have at least 2 dimensions.')) 762 return assertions 763 764 765@tf_export('linalg.matrix_rank') 766@dispatch.add_dispatch_support 767def matrix_rank(a, tol=None, validate_args=False, name=None): 768 """Compute the matrix rank of one or more matrices. 769 770 Args: 771 a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be 772 pseudo-inverted. 773 tol: Threshold below which the singular value is counted as 'zero'. 774 Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`). 775 validate_args: When `True`, additional assertions might be embedded in the 776 graph. 777 Default value: `False` (i.e., no graph assertions are added). 778 name: Python `str` prefixed to ops created by this function. 779 Default value: 'matrix_rank'. 780 781 Returns: 782 matrix_rank: (Batch of) `int32` scalars representing the number of non-zero 783 singular values. 784 """ 785 with ops.name_scope(name or 'matrix_rank'): 786 a = ops.convert_to_tensor(a, dtype_hint=dtypes.float32, name='a') 787 assertions = _maybe_validate_matrix(a, validate_args) 788 if assertions: 789 with ops.control_dependencies(assertions): 790 a = array_ops.identity(a) 791 s = svd(a, compute_uv=False) 792 if tol is None: 793 if (a.shape[-2:]).is_fully_defined(): 794 m = np.max(a.shape[-2:].as_list()) 795 else: 796 m = math_ops.reduce_max(array_ops.shape(a)[-2:]) 797 eps = np.finfo(a.dtype.as_numpy_dtype).eps 798 tol = ( 799 eps * math_ops.cast(m, a.dtype) * 800 math_ops.reduce_max(s, axis=-1, keepdims=True)) 801 return math_ops.reduce_sum(math_ops.cast(s > tol, dtypes.int32), axis=-1) 802 803 804@tf_export('linalg.pinv') 805@dispatch.add_dispatch_support 806def pinv(a, rcond=None, validate_args=False, name=None): 807 """Compute the Moore-Penrose pseudo-inverse of one or more matrices. 808 809 Calculate the [generalized inverse of a matrix]( 810 https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) using its 811 singular-value decomposition (SVD) and including all large singular values. 812 813 The pseudo-inverse of a matrix `A`, is defined as: 'the matrix that 'solves' 814 [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution, then 815 `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown that if 816 `U @ Sigma @ V.T = A` is the singular value decomposition of `A`, then 817 `A_pinv = V @ inv(Sigma) U^T`. [(Strang, 1980)][1] 818 819 This function is analogous to [`numpy.linalg.pinv`]( 820 https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html). 821 It differs only in default value of `rcond`. In `numpy.linalg.pinv`, the 822 default `rcond` is `1e-15`. Here the default is 823 `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`. 824 825 Args: 826 a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be 827 pseudo-inverted. 828 rcond: `Tensor` of small singular value cutoffs. Singular values smaller 829 (in modulus) than `rcond` * largest_singular_value (again, in modulus) are 830 set to zero. Must broadcast against `tf.shape(a)[:-2]`. 831 Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`. 832 validate_args: When `True`, additional assertions might be embedded in the 833 graph. 834 Default value: `False` (i.e., no graph assertions are added). 835 name: Python `str` prefixed to ops created by this function. 836 Default value: 'pinv'. 837 838 Returns: 839 a_pinv: (Batch of) pseudo-inverse of input `a`. Has same shape as `a` except 840 rightmost two dimensions are transposed. 841 842 Raises: 843 TypeError: if input `a` does not have `float`-like `dtype`. 844 ValueError: if input `a` has fewer than 2 dimensions. 845 846 #### Examples 847 848 ```python 849 import tensorflow as tf 850 import tensorflow_probability as tfp 851 852 a = tf.constant([[1., 0.4, 0.5], 853 [0.4, 0.2, 0.25], 854 [0.5, 0.25, 0.35]]) 855 tf.matmul(tf.linalg.pinv(a), a) 856 # ==> array([[1., 0., 0.], 857 [0., 1., 0.], 858 [0., 0., 1.]], dtype=float32) 859 860 a = tf.constant([[1., 0.4, 0.5, 1.], 861 [0.4, 0.2, 0.25, 2.], 862 [0.5, 0.25, 0.35, 3.]]) 863 tf.matmul(tf.linalg.pinv(a), a) 864 # ==> array([[ 0.76, 0.37, 0.21, -0.02], 865 [ 0.37, 0.43, -0.33, 0.02], 866 [ 0.21, -0.33, 0.81, 0.01], 867 [-0.02, 0.02, 0.01, 1. ]], dtype=float32) 868 ``` 869 870 #### References 871 872 [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press, 873 Inc., 1980, pp. 139-142. 874 """ 875 with ops.name_scope(name or 'pinv'): 876 a = ops.convert_to_tensor(a, name='a') 877 878 assertions = _maybe_validate_matrix(a, validate_args) 879 if assertions: 880 with ops.control_dependencies(assertions): 881 a = array_ops.identity(a) 882 883 dtype = a.dtype.as_numpy_dtype 884 885 if rcond is None: 886 887 def get_dim_size(dim): 888 dim_val = tensor_shape.dimension_value(a.shape[dim]) 889 if dim_val is not None: 890 return dim_val 891 return array_ops.shape(a)[dim] 892 893 num_rows = get_dim_size(-2) 894 num_cols = get_dim_size(-1) 895 if isinstance(num_rows, int) and isinstance(num_cols, int): 896 max_rows_cols = float(max(num_rows, num_cols)) 897 else: 898 max_rows_cols = math_ops.cast( 899 math_ops.maximum(num_rows, num_cols), dtype) 900 rcond = 10. * max_rows_cols * np.finfo(dtype).eps 901 902 rcond = ops.convert_to_tensor(rcond, dtype=dtype, name='rcond') 903 904 # Calculate pseudo inverse via SVD. 905 # Note: if a is Hermitian then u == v. (We might observe additional 906 # performance by explicitly setting `v = u` in such cases.) 907 [ 908 singular_values, # Sigma 909 left_singular_vectors, # U 910 right_singular_vectors, # V 911 ] = svd( 912 a, full_matrices=False, compute_uv=True) 913 914 # Saturate small singular values to inf. This has the effect of make 915 # `1. / s = 0.` while not resulting in `NaN` gradients. 916 cutoff = rcond * math_ops.reduce_max(singular_values, axis=-1) 917 singular_values = array_ops.where_v2( 918 singular_values > array_ops.expand_dims_v2(cutoff, -1), singular_values, 919 np.array(np.inf, dtype)) 920 921 # By the definition of the SVD, `a == u @ s @ v^H`, and the pseudo-inverse 922 # is defined as `pinv(a) == v @ inv(s) @ u^H`. 923 a_pinv = math_ops.matmul( 924 right_singular_vectors / array_ops.expand_dims_v2(singular_values, -2), 925 left_singular_vectors, 926 adjoint_b=True) 927 928 if a.shape is not None and a.shape.rank is not None: 929 a_pinv.set_shape(a.shape[:-2].concatenate([a.shape[-1], a.shape[-2]])) 930 931 return a_pinv 932 933 934@tf_export('linalg.lu_solve') 935@dispatch.add_dispatch_support 936def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None): 937 """Solves systems of linear eqns `A X = RHS`, given LU factorizations. 938 939 Note: this function does not verify the implied matrix is actually invertible 940 nor is this condition checked even when `validate_args=True`. 941 942 Args: 943 lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, 944 matmul(L, U)) = X` then `lower_upper = L + U - eye`. 945 perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) = 946 X` then `perm = argmax(P)`. 947 rhs: Matrix-shaped float `Tensor` representing targets for which to solve; 948 `A X = RHS`. To handle vector cases, use: `lu_solve(..., rhs[..., 949 tf.newaxis])[..., 0]`. 950 validate_args: Python `bool` indicating whether arguments should be checked 951 for correctness. Note: this function does not verify the implied matrix is 952 actually invertible, even when `validate_args=True`. 953 Default value: `False` (i.e., don't validate arguments). 954 name: Python `str` name given to ops managed by this object. 955 Default value: `None` (i.e., 'lu_solve'). 956 957 Returns: 958 x: The `X` in `A @ X = RHS`. 959 960 #### Examples 961 962 ```python 963 import numpy as np 964 import tensorflow as tf 965 import tensorflow_probability as tfp 966 967 x = [[[1., 2], 968 [3, 4]], 969 [[7, 8], 970 [3, 4]]] 971 inv_x = tf.linalg.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2)) 972 tf.assert_near(tf.matrix_inverse(x), inv_x) 973 # ==> True 974 ``` 975 976 """ 977 978 with ops.name_scope(name or 'lu_solve'): 979 lower_upper = ops.convert_to_tensor( 980 lower_upper, dtype_hint=dtypes.float32, name='lower_upper') 981 perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm') 982 rhs = ops.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs') 983 984 assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args) 985 if assertions: 986 with ops.control_dependencies(assertions): 987 lower_upper = array_ops.identity(lower_upper) 988 perm = array_ops.identity(perm) 989 rhs = array_ops.identity(rhs) 990 991 if (rhs.shape.rank == 2 and perm.shape.rank == 1): 992 # Both rhs and perm have scalar batch_shape. 993 permuted_rhs = array_ops.gather(rhs, perm, axis=-2) 994 else: 995 # Either rhs or perm have non-scalar batch_shape or we can't determine 996 # this information statically. 997 rhs_shape = array_ops.shape(rhs) 998 broadcast_batch_shape = array_ops.broadcast_dynamic_shape( 999 rhs_shape[:-2], 1000 array_ops.shape(perm)[:-1]) 1001 d, m = rhs_shape[-2], rhs_shape[-1] 1002 rhs_broadcast_shape = array_ops.concat([broadcast_batch_shape, [d, m]], 1003 axis=0) 1004 1005 # Tile out rhs. 1006 broadcast_rhs = array_ops.broadcast_to(rhs, rhs_broadcast_shape) 1007 broadcast_rhs = array_ops.reshape(broadcast_rhs, [-1, d, m]) 1008 1009 # Tile out perm and add batch indices. 1010 broadcast_perm = array_ops.broadcast_to(perm, rhs_broadcast_shape[:-1]) 1011 broadcast_perm = array_ops.reshape(broadcast_perm, [-1, d]) 1012 broadcast_batch_size = math_ops.reduce_prod(broadcast_batch_shape) 1013 broadcast_batch_indices = array_ops.broadcast_to( 1014 math_ops.range(broadcast_batch_size)[:, array_ops.newaxis], 1015 [broadcast_batch_size, d]) 1016 broadcast_perm = array_ops.stack( 1017 [broadcast_batch_indices, broadcast_perm], axis=-1) 1018 1019 permuted_rhs = array_ops.gather_nd(broadcast_rhs, broadcast_perm) 1020 permuted_rhs = array_ops.reshape(permuted_rhs, rhs_broadcast_shape) 1021 1022 lower = set_diag( 1023 band_part(lower_upper, num_lower=-1, num_upper=0), 1024 array_ops.ones( 1025 array_ops.shape(lower_upper)[:-1], dtype=lower_upper.dtype)) 1026 return triangular_solve( 1027 lower_upper, # Only upper is accessed. 1028 triangular_solve(lower, permuted_rhs), 1029 lower=False) 1030 1031 1032@tf_export('linalg.lu_matrix_inverse') 1033@dispatch.add_dispatch_support 1034def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None): 1035 """Computes the inverse given the LU decomposition(s) of one or more matrices. 1036 1037 This op is conceptually identical to, 1038 1039 ```python 1040 inv_X = tf.lu_matrix_inverse(*tf.linalg.lu(X)) 1041 tf.assert_near(tf.matrix_inverse(X), inv_X) 1042 # ==> True 1043 ``` 1044 1045 Note: this function does not verify the implied matrix is actually invertible 1046 nor is this condition checked even when `validate_args=True`. 1047 1048 Args: 1049 lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, 1050 matmul(L, U)) = X` then `lower_upper = L + U - eye`. 1051 perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) = 1052 X` then `perm = argmax(P)`. 1053 validate_args: Python `bool` indicating whether arguments should be checked 1054 for correctness. Note: this function does not verify the implied matrix is 1055 actually invertible, even when `validate_args=True`. 1056 Default value: `False` (i.e., don't validate arguments). 1057 name: Python `str` name given to ops managed by this object. 1058 Default value: `None` (i.e., 'lu_matrix_inverse'). 1059 1060 Returns: 1061 inv_x: The matrix_inv, i.e., 1062 `tf.matrix_inverse(tf.linalg.lu_reconstruct(lu, perm))`. 1063 1064 #### Examples 1065 1066 ```python 1067 import numpy as np 1068 import tensorflow as tf 1069 import tensorflow_probability as tfp 1070 1071 x = [[[3., 4], [1, 2]], 1072 [[7., 8], [3, 4]]] 1073 inv_x = tf.linalg.lu_matrix_inverse(*tf.linalg.lu(x)) 1074 tf.assert_near(tf.matrix_inverse(x), inv_x) 1075 # ==> True 1076 ``` 1077 1078 """ 1079 1080 with ops.name_scope(name or 'lu_matrix_inverse'): 1081 lower_upper = ops.convert_to_tensor( 1082 lower_upper, dtype_hint=dtypes.float32, name='lower_upper') 1083 perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm') 1084 assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args) 1085 if assertions: 1086 with ops.control_dependencies(assertions): 1087 lower_upper = array_ops.identity(lower_upper) 1088 perm = array_ops.identity(perm) 1089 shape = array_ops.shape(lower_upper) 1090 return lu_solve( 1091 lower_upper, 1092 perm, 1093 rhs=eye(shape[-1], batch_shape=shape[:-2], dtype=lower_upper.dtype), 1094 validate_args=False) 1095 1096 1097@tf_export('linalg.lu_reconstruct') 1098@dispatch.add_dispatch_support 1099def lu_reconstruct(lower_upper, perm, validate_args=False, name=None): 1100 """The reconstruct one or more matrices from their LU decomposition(s). 1101 1102 Args: 1103 lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if `matmul(P, 1104 matmul(L, U)) = X` then `lower_upper = L + U - eye`. 1105 perm: `p` as returned by `tf.linag.lu`, i.e., if `matmul(P, matmul(L, U)) = 1106 X` then `perm = argmax(P)`. 1107 validate_args: Python `bool` indicating whether arguments should be checked 1108 for correctness. 1109 Default value: `False` (i.e., don't validate arguments). 1110 name: Python `str` name given to ops managed by this object. 1111 Default value: `None` (i.e., 'lu_reconstruct'). 1112 1113 Returns: 1114 x: The original input to `tf.linalg.lu`, i.e., `x` as in, 1115 `lu_reconstruct(*tf.linalg.lu(x))`. 1116 1117 #### Examples 1118 1119 ```python 1120 import numpy as np 1121 import tensorflow as tf 1122 import tensorflow_probability as tfp 1123 1124 x = [[[3., 4], [1, 2]], 1125 [[7., 8], [3, 4]]] 1126 x_reconstructed = tf.linalg.lu_reconstruct(*tf.linalg.lu(x)) 1127 tf.assert_near(x, x_reconstructed) 1128 # ==> True 1129 ``` 1130 1131 """ 1132 with ops.name_scope(name or 'lu_reconstruct'): 1133 lower_upper = ops.convert_to_tensor( 1134 lower_upper, dtype_hint=dtypes.float32, name='lower_upper') 1135 perm = ops.convert_to_tensor(perm, dtype_hint=dtypes.int32, name='perm') 1136 1137 assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args) 1138 if assertions: 1139 with ops.control_dependencies(assertions): 1140 lower_upper = array_ops.identity(lower_upper) 1141 perm = array_ops.identity(perm) 1142 1143 shape = array_ops.shape(lower_upper) 1144 1145 lower = set_diag( 1146 band_part(lower_upper, num_lower=-1, num_upper=0), 1147 array_ops.ones(shape[:-1], dtype=lower_upper.dtype)) 1148 upper = band_part(lower_upper, num_lower=0, num_upper=-1) 1149 x = math_ops.matmul(lower, upper) 1150 1151 if (lower_upper.shape is None or lower_upper.shape.rank is None or 1152 lower_upper.shape.rank != 2): 1153 # We either don't know the batch rank or there are >0 batch dims. 1154 batch_size = math_ops.reduce_prod(shape[:-2]) 1155 d = shape[-1] 1156 x = array_ops.reshape(x, [batch_size, d, d]) 1157 perm = array_ops.reshape(perm, [batch_size, d]) 1158 perm = map_fn.map_fn(array_ops.invert_permutation, perm) 1159 batch_indices = array_ops.broadcast_to( 1160 math_ops.range(batch_size)[:, array_ops.newaxis], [batch_size, d]) 1161 x = array_ops.gather_nd(x, array_ops.stack([batch_indices, perm], 1162 axis=-1)) 1163 x = array_ops.reshape(x, shape) 1164 else: 1165 x = array_ops.gather(x, array_ops.invert_permutation(perm)) 1166 1167 x.set_shape(lower_upper.shape) 1168 return x 1169 1170 1171def lu_reconstruct_assertions(lower_upper, perm, validate_args): 1172 """Returns list of assertions related to `lu_reconstruct` assumptions.""" 1173 assertions = [] 1174 1175 message = 'Input `lower_upper` must have at least 2 dimensions.' 1176 if lower_upper.shape.rank is not None and lower_upper.shape.rank < 2: 1177 raise ValueError(message) 1178 elif validate_args: 1179 assertions.append( 1180 check_ops.assert_rank_at_least_v2(lower_upper, rank=2, message=message)) 1181 1182 message = '`rank(lower_upper)` must equal `rank(perm) + 1`' 1183 if lower_upper.shape.rank is not None and perm.shape.rank is not None: 1184 if lower_upper.shape.rank != perm.shape.rank + 1: 1185 raise ValueError(message) 1186 elif validate_args: 1187 assertions.append( 1188 check_ops.assert_rank( 1189 lower_upper, rank=array_ops.rank(perm) + 1, message=message)) 1190 1191 message = '`lower_upper` must be square.' 1192 if lower_upper.shape[:-2].is_fully_defined(): 1193 if lower_upper.shape[-2] != lower_upper.shape[-1]: 1194 raise ValueError(message) 1195 elif validate_args: 1196 m, n = array_ops.split( 1197 array_ops.shape(lower_upper)[-2:], num_or_size_splits=2) 1198 assertions.append(check_ops.assert_equal(m, n, message=message)) 1199 1200 return assertions 1201 1202 1203def _lu_solve_assertions(lower_upper, perm, rhs, validate_args): 1204 """Returns list of assertions related to `lu_solve` assumptions.""" 1205 assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args) 1206 1207 message = 'Input `rhs` must have at least 2 dimensions.' 1208 if rhs.shape.ndims is not None: 1209 if rhs.shape.ndims < 2: 1210 raise ValueError(message) 1211 elif validate_args: 1212 assertions.append( 1213 check_ops.assert_rank_at_least(rhs, rank=2, message=message)) 1214 1215 message = '`lower_upper.shape[-1]` must equal `rhs.shape[-1]`.' 1216 if (lower_upper.shape[-1] is not None and rhs.shape[-2] is not None): 1217 if lower_upper.shape[-1] != rhs.shape[-2]: 1218 raise ValueError(message) 1219 elif validate_args: 1220 assertions.append( 1221 check_ops.assert_equal( 1222 array_ops.shape(lower_upper)[-1], 1223 array_ops.shape(rhs)[-2], 1224 message=message)) 1225 1226 return assertions 1227 1228 1229@tf_export('linalg.eigh_tridiagonal') 1230@dispatch.add_dispatch_support 1231def eigh_tridiagonal(alpha, 1232 beta, 1233 eigvals_only=True, 1234 select='a', 1235 select_range=None, 1236 tol=None, 1237 name=None): 1238 """Computes the eigenvalues of a Hermitian tridiagonal matrix. 1239 1240 Args: 1241 alpha: A real or complex tensor of shape (n), the diagonal elements of the 1242 matrix. NOTE: If alpha is complex, the imaginary part is ignored (assumed 1243 zero) to satisfy the requirement that the matrix be Hermitian. 1244 beta: A real or complex tensor of shape (n-1), containing the elements of 1245 the first super-diagonal of the matrix. If beta is complex, the first 1246 sub-diagonal of the matrix is assumed to be the conjugate of beta to 1247 satisfy the requirement that the matrix be Hermitian 1248 eigvals_only: If False, both eigenvalues and corresponding eigenvectors are 1249 computed. If True, only eigenvalues are computed. Default is True. 1250 select: Optional string with values in {‘a’, ‘v’, ‘i’} (default is 'a') that 1251 determines which eigenvalues to calculate: 1252 'a': all eigenvalues. 1253 ‘v’: eigenvalues in the interval (min, max] given by `select_range`. 1254 'i’: eigenvalues with indices min <= i <= max. 1255 select_range: Size 2 tuple or list or tensor specifying the range of 1256 eigenvalues to compute together with select. If select is 'a', 1257 select_range is ignored. 1258 tol: Optional scalar. The absolute tolerance to which each eigenvalue is 1259 required. An eigenvalue (or cluster) is considered to have converged if it 1260 lies in an interval of this width. If tol is None (default), the value 1261 eps*|T|_2 is used where eps is the machine precision, and |T|_2 is the 1262 2-norm of the matrix T. 1263 name: Optional name of the op. 1264 1265 Returns: 1266 eig_vals: The eigenvalues of the matrix in non-decreasing order. 1267 eig_vectors: If `eigvals_only` is False the eigenvectors are returned in 1268 the second output argument. 1269 1270 Raises: 1271 ValueError: If input values are invalid. 1272 NotImplemented: Computing eigenvectors for `eigvals_only` = False is 1273 not implemented yet. 1274 1275 This op implements a subset of the functionality of 1276 scipy.linalg.eigh_tridiagonal. 1277 1278 Note: The result is undefined if the input contains +/-inf or NaN, or if 1279 any value in beta has a magnitude greater than 1280 `numpy.sqrt(numpy.finfo(beta.dtype.as_numpy_dtype).max)`. 1281 1282 1283 TODO(b/187527398): 1284 Add support for outer batch dimensions. 1285 1286 #### Examples 1287 1288 ```python 1289 import numpy 1290 eigvals = tf.linalg.eigh_tridiagonal([0.0, 0.0, 0.0], [1.0, 1.0]) 1291 eigvals_expected = [-numpy.sqrt(2.0), 0.0, numpy.sqrt(2.0)] 1292 tf.assert_near(eigvals_expected, eigvals) 1293 # ==> True 1294 ``` 1295 1296 """ 1297 with ops.name_scope(name or 'eigh_tridiagonal'): 1298 1299 def _compute_eigenvalues(alpha, beta): 1300 """Computes all eigenvalues of a Hermitian tridiagonal matrix.""" 1301 1302 def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x): 1303 """Implements the Sturm sequence recurrence.""" 1304 with ops.name_scope('sturm'): 1305 n = alpha.shape[0] 1306 zeros = array_ops.zeros(array_ops.shape(x), dtype=dtypes.int32) 1307 ones = array_ops.ones(array_ops.shape(x), dtype=dtypes.int32) 1308 1309 # The first step in the Sturm sequence recurrence 1310 # requires special care if x is equal to alpha[0]. 1311 def sturm_step0(): 1312 q = alpha[0] - x 1313 count = array_ops.where(q < 0, ones, zeros) 1314 q = array_ops.where( 1315 math_ops.equal(alpha[0], x), alpha0_perturbation, q) 1316 return q, count 1317 1318 # Subsequent steps all take this form: 1319 def sturm_step(i, q, count): 1320 q = alpha[i] - beta_sq[i - 1] / q - x 1321 count = array_ops.where(q <= pivmin, count + 1, count) 1322 q = array_ops.where(q <= pivmin, math_ops.minimum(q, -pivmin), q) 1323 return q, count 1324 1325 # The first step initializes q and count. 1326 q, count = sturm_step0() 1327 1328 # Peel off ((n-1) % blocksize) steps from the main loop, so we can run 1329 # the bulk of the iterations unrolled by a factor of blocksize. 1330 blocksize = 16 1331 i = 1 1332 peel = (n - 1) % blocksize 1333 unroll_cnt = peel 1334 1335 def unrolled_steps(start, q, count): 1336 for j in range(unroll_cnt): 1337 q, count = sturm_step(start + j, q, count) 1338 return start + unroll_cnt, q, count 1339 1340 i, q, count = unrolled_steps(i, q, count) 1341 1342 # Run the remaining steps of the Sturm sequence using a partially 1343 # unrolled while loop. 1344 unroll_cnt = blocksize 1345 cond = lambda i, q, count: math_ops.less(i, n) 1346 _, _, count = control_flow_ops.while_loop( 1347 cond, unrolled_steps, [i, q, count], back_prop=False) 1348 return count 1349 1350 with ops.name_scope('compute_eigenvalues'): 1351 if alpha.dtype.is_complex: 1352 alpha = math_ops.real(alpha) 1353 beta_sq = math_ops.real(math_ops.conj(beta) * beta) 1354 beta_abs = math_ops.sqrt(beta_sq) 1355 else: 1356 beta_sq = math_ops.square(beta) 1357 beta_abs = math_ops.abs(beta) 1358 1359 # Estimate the largest and smallest eigenvalues of T using the 1360 # Gershgorin circle theorem. 1361 finfo = np.finfo(alpha.dtype.as_numpy_dtype) 1362 off_diag_abs_row_sum = array_ops.concat( 1363 [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0) 1364 lambda_est_max = math_ops.minimum( 1365 finfo.max, math_ops.reduce_max(alpha + off_diag_abs_row_sum)) 1366 lambda_est_min = math_ops.maximum( 1367 finfo.min, math_ops.reduce_min(alpha - off_diag_abs_row_sum)) 1368 # Upper bound on 2-norm of T. 1369 t_norm = math_ops.maximum( 1370 math_ops.abs(lambda_est_min), math_ops.abs(lambda_est_max)) 1371 1372 # Compute the smallest allowed pivot in the Sturm sequence to avoid 1373 # overflow. 1374 one = np.ones([], dtype=alpha.dtype.as_numpy_dtype) 1375 safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny) 1376 pivmin = safemin * math_ops.maximum(one, math_ops.reduce_max(beta_sq)) 1377 alpha0_perturbation = math_ops.square(finfo.eps * beta_abs[0]) 1378 abs_tol = finfo.eps * t_norm 1379 if tol: 1380 abs_tol = math_ops.maximum(tol, abs_tol) 1381 # In the worst case, when the absolute tolerance is eps*lambda_est_max 1382 # and lambda_est_max = -lambda_est_min, we have to take as many 1383 # bisection steps as there are bits in the mantissa plus 1. 1384 max_it = finfo.nmant + 1 1385 1386 # Determine the indices of the desired eigenvalues, based on select 1387 # and select_range. 1388 asserts = None 1389 if select == 'a': 1390 target_counts = math_ops.range(n) 1391 elif select == 'i': 1392 asserts = check_ops.assert_less_equal( 1393 select_range[0], 1394 select_range[1], 1395 message='Got empty index range in select_range.') 1396 target_counts = math_ops.range(select_range[0], select_range[1] + 1) 1397 elif select == 'v': 1398 asserts = check_ops.assert_less( 1399 select_range[0], 1400 select_range[1], 1401 message='Got empty interval in select_range.') 1402 else: 1403 raise ValueError("'select must have a value in {'a', 'i', 'v'}.") 1404 1405 if asserts: 1406 with ops.control_dependencies([asserts]): 1407 alpha = array_ops.identity(alpha) 1408 1409 # Run binary search for all desired eigenvalues in parallel, starting 1410 # from an interval slightly wider than the estimated 1411 # [lambda_est_min, lambda_est_max]. 1412 fudge = 2.1 # We widen starting interval the Gershgorin interval a bit. 1413 norm_slack = math_ops.cast(n, alpha.dtype) * fudge * finfo.eps * t_norm 1414 if select in {'a', 'i'}: 1415 lower = lambda_est_min - norm_slack - 2 * fudge * pivmin 1416 upper = lambda_est_max + norm_slack + fudge * pivmin 1417 else: 1418 # Count the number of eigenvalues in the given range. 1419 lower = select_range[0] - norm_slack - 2 * fudge * pivmin 1420 upper = select_range[1] + norm_slack + fudge * pivmin 1421 first = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, lower) 1422 last = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, upper) 1423 target_counts = math_ops.range(first, last) 1424 1425 # Pre-broadcast the scalars used in the Sturm sequence for improved 1426 # performance. 1427 upper = math_ops.minimum(upper, finfo.max) 1428 lower = math_ops.maximum(lower, finfo.min) 1429 target_shape = array_ops.shape(target_counts) 1430 lower = array_ops.broadcast_to(lower, shape=target_shape) 1431 upper = array_ops.broadcast_to(upper, shape=target_shape) 1432 pivmin = array_ops.broadcast_to(pivmin, target_shape) 1433 alpha0_perturbation = array_ops.broadcast_to(alpha0_perturbation, 1434 target_shape) 1435 1436 # We compute the midpoint as 0.5*lower + 0.5*upper to avoid overflow in 1437 # (lower + upper) or (upper - lower) when the matrix has eigenvalues 1438 # with magnitude greater than finfo.max / 2. 1439 def midpoint(lower, upper): 1440 return (0.5 * lower) + (0.5 * upper) 1441 1442 def continue_binary_search(i, lower, upper): 1443 return math_ops.logical_and( 1444 math_ops.less(i, max_it), 1445 math_ops.less(abs_tol, math_ops.reduce_max(upper - lower))) 1446 1447 def binary_search_step(i, lower, upper): 1448 mid = midpoint(lower, upper) 1449 counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid) 1450 lower = array_ops.where(counts <= target_counts, mid, lower) 1451 upper = array_ops.where(counts > target_counts, mid, upper) 1452 return i + 1, lower, upper 1453 1454 # Start parallel binary searches. 1455 _, lower, upper = control_flow_ops.while_loop(continue_binary_search, 1456 binary_search_step, 1457 [0, lower, upper]) 1458 return midpoint(lower, upper) 1459 1460 def _compute_eigenvectors(alpha, beta, eigvals): 1461 """Implements inverse iteration to compute eigenvectors.""" 1462 with ops.name_scope('compute_eigenvectors'): 1463 k = array_ops.size(eigvals) 1464 n = array_ops.size(alpha) 1465 alpha = math_ops.cast(alpha, dtype=beta.dtype) 1466 1467 # Eigenvectors corresponding to cluster of close eigenvalues are 1468 # not unique and need to be explicitly orthogonalized. Here we 1469 # identify such clusters. Note: This function assumes that 1470 # eigenvalues are sorted in non-decreasing order. 1471 gap = eigvals[1:] - eigvals[:-1] 1472 eps = np.finfo(eigvals.dtype.as_numpy_dtype).eps 1473 t_norm = math_ops.maximum( 1474 math_ops.abs(eigvals[0]), math_ops.abs(eigvals[-1])) 1475 gaptol = np.sqrt(eps) * t_norm 1476 # Find the beginning and end of runs of eigenvectors corresponding 1477 # to eigenvalues closer than "gaptol", which will need to be 1478 # orthogonalized against each other. 1479 close = math_ops.less(gap, gaptol) 1480 left_neighbor_close = array_ops.concat([[False], close], axis=0) 1481 right_neighbor_close = array_ops.concat([close, [False]], axis=0) 1482 ortho_interval_start = math_ops.logical_and( 1483 math_ops.logical_not(left_neighbor_close), right_neighbor_close) 1484 ortho_interval_start = array_ops.squeeze( 1485 array_ops.where_v2(ortho_interval_start), axis=-1) 1486 ortho_interval_end = math_ops.logical_and( 1487 left_neighbor_close, math_ops.logical_not(right_neighbor_close)) 1488 ortho_interval_end = array_ops.squeeze( 1489 array_ops.where_v2(ortho_interval_end), axis=-1) + 1 1490 num_clusters = array_ops.size(ortho_interval_end) 1491 1492 # We perform inverse iteration for all eigenvectors in parallel, 1493 # starting from a random set of vectors, until all have converged. 1494 v0 = math_ops.cast( 1495 stateless_random_ops.stateless_random_normal( 1496 shape=(k, n), seed=[7, 42]), 1497 dtype=beta.dtype) 1498 nrm_v = norm(v0, axis=1) 1499 v0 = v0 / nrm_v[:, array_ops.newaxis] 1500 zero_nrm = constant_op.constant(0, shape=nrm_v.shape, dtype=nrm_v.dtype) 1501 1502 # Replicate alpha-eigvals(ik) and beta across the k eigenvectors so we 1503 # can solve the k systems 1504 # [T - eigvals(i)*eye(n)] x_i = r_i 1505 # simultaneously using the batching mechanism. 1506 eigvals_cast = math_ops.cast(eigvals, dtype=beta.dtype) 1507 alpha_shifted = ( 1508 alpha[array_ops.newaxis, :] - eigvals_cast[:, array_ops.newaxis]) 1509 beta = array_ops.tile(beta[array_ops.newaxis, :], [k, 1]) 1510 diags = [beta, alpha_shifted, math_ops.conj(beta)] 1511 1512 def orthogonalize_close_eigenvectors(eigenvectors): 1513 # Eigenvectors corresponding to a cluster of close eigenvalues are not 1514 # uniquely defined, but the subspace they span is. To avoid numerical 1515 # instability, we explicitly mutually orthogonalize such eigenvectors 1516 # after each step of inverse iteration. It is customary to use 1517 # modified Gram-Schmidt for this, but this is not very efficient 1518 # on some platforms, so here we defer to the QR decomposition in 1519 # TensorFlow. 1520 def orthogonalize_cluster(cluster_idx, eigenvectors): 1521 start = ortho_interval_start[cluster_idx] 1522 end = ortho_interval_end[cluster_idx] 1523 update_indices = array_ops.expand_dims( 1524 math_ops.range(start, end), -1) 1525 vectors_in_cluster = eigenvectors[start:end, :] 1526 # We use the builtin QR factorization to orthonormalize the 1527 # vectors in the cluster. 1528 q, _ = qr(transpose(vectors_in_cluster)) 1529 vectors_to_update = transpose(q) 1530 eigenvectors = array_ops.tensor_scatter_nd_update( 1531 eigenvectors, update_indices, vectors_to_update) 1532 return cluster_idx + 1, eigenvectors 1533 1534 _, eigenvectors = control_flow_ops.while_loop( 1535 lambda i, ev: math_ops.less(i, num_clusters), 1536 orthogonalize_cluster, [0, eigenvectors]) 1537 return eigenvectors 1538 1539 def continue_iteration(i, _, nrm_v, nrm_v_old): 1540 max_it = 5 # Taken from LAPACK xSTEIN. 1541 min_norm_growth = 0.1 1542 norm_growth_factor = constant_op.constant( 1543 1 + min_norm_growth, dtype=nrm_v.dtype) 1544 # We stop the inverse iteration when we reach the maximum number of 1545 # iterations or the norm growths is less than 10%. 1546 return math_ops.logical_and( 1547 math_ops.less(i, max_it), 1548 math_ops.reduce_any( 1549 math_ops.greater_equal( 1550 math_ops.real(nrm_v), 1551 math_ops.real(norm_growth_factor * nrm_v_old)))) 1552 1553 def inverse_iteration_step(i, v, nrm_v, nrm_v_old): 1554 v = tridiagonal_solve( 1555 diags, 1556 v, 1557 diagonals_format='sequence', 1558 partial_pivoting=True, 1559 perturb_singular=True) 1560 nrm_v_old = nrm_v 1561 nrm_v = norm(v, axis=1) 1562 v = v / nrm_v[:, array_ops.newaxis] 1563 v = orthogonalize_close_eigenvectors(v) 1564 return i + 1, v, nrm_v, nrm_v_old 1565 1566 _, v, nrm_v, _ = control_flow_ops.while_loop(continue_iteration, 1567 inverse_iteration_step, 1568 [0, v0, nrm_v, zero_nrm]) 1569 return transpose(v) 1570 1571 alpha = ops.convert_to_tensor(alpha, name='alpha') 1572 n = alpha.shape[0] 1573 if n <= 1: 1574 return math_ops.real(alpha) 1575 beta = ops.convert_to_tensor(beta, name='beta') 1576 1577 if alpha.dtype != beta.dtype: 1578 raise ValueError("'alpha' and 'beta' must have the same type.") 1579 1580 eigvals = _compute_eigenvalues(alpha, beta) 1581 if eigvals_only: 1582 return eigvals 1583 1584 eigvectors = _compute_eigenvectors(alpha, beta, eigvals) 1585 return eigvals, eigvectors 1586