# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.ops.linalg_ops."""

import itertools

from absl.testing import parameterized
import numpy as np
import scipy.linalg

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import test


def _RandomPDMatrix(n, rng, dtype=np.float64):
  """Random positive definite matrix."""
  temp = rng.randn(n, n).astype(dtype)
  if dtype in [np.complex64, np.complex128]:
    temp.imag = rng.randn(n, n)
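  # conj(A) @ A.T is the elementwise conjugate of A @ A^H, hence Hermitian
  # with the same real, nonnegative (almost surely positive) eigenvalues.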
  return np.conj(temp).dot(temp.T)


class CholeskySolveTest(test.TestCase):

  def setUp(self):
    self.rng = np.random.RandomState(0)

  @test_util.run_deprecated_v1
  def test_works_with_five_different_random_pos_def_matrices(self):
    for n in range(1, 6):
      for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
        with self.session():
          # Create a batch of two n x n positive definite matrices.
          array = np.array(
              [_RandomPDMatrix(n, self.rng),
               _RandomPDMatrix(n, self.rng)]).astype(np_type)
          chol = linalg_ops.cholesky(array)
          for k in range(1, 3):
            with self.subTest(n=n, np_type=np_type, atol=atol, k=k):
              rhs = self.rng.randn(2, n, k).astype(np_type)
              x = linalg_ops.cholesky_solve(chol, rhs)
              self.assertAllClose(rhs, math_ops.matmul(array, x), atol=atol)


class LogdetTest(test.TestCase):

  def setUp(self):
    self.rng = np.random.RandomState(42)

  @test_util.run_deprecated_v1
  def test_works_with_five_different_random_pos_def_matrices(self):
    for n in range(1, 6):
      for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
                             (np.complex64, 0.05), (np.complex128, 1e-5)]:
        with self.subTest(n=n, np_dtype=np_dtype, atol=atol):
          matrix = _RandomPDMatrix(n, self.rng, np_dtype)
          _, logdet_np = np.linalg.slogdet(matrix)
          with self.session():
            logdet_tf = linalg.logdet(matrix)
            self.assertAllClose(logdet_np, self.evaluate(logdet_tf), atol=atol)

  def test_works_with_underflow_case(self):
    for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
                           (np.complex64, 0.05), (np.complex128, 1e-5)]:
      with self.subTest(np_dtype=np_dtype, atol=atol):
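        # det(1e-6 * eye(20)) = 1e-120 underflows single precision, but
        # logdet = 20 * log(1e-6) is finite in every dtype tested here.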
        matrix = (np.eye(20) * 1e-6).astype(np_dtype)
        _, logdet_np = np.linalg.slogdet(matrix)
        with self.session():
          logdet_tf = linalg.logdet(matrix)
          self.assertAllClose(logdet_np, self.evaluate(logdet_tf), atol=atol)


class SlogdetTest(test.TestCase):

  def setUp(self):
    self.rng = np.random.RandomState(42)

  @test_util.run_deprecated_v1
  def test_works_with_five_different_random_pos_def_matrices(self):
    for n in range(1, 6):
      for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
                             (np.complex64, 0.05), (np.complex128, 1e-5)]:
        with self.subTest(n=n, np_dtype=np_dtype, atol=atol):
          matrix = _RandomPDMatrix(n, self.rng, np_dtype)
          sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
          with self.session():
            sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
            self.assertAllClose(
                log_abs_det_np, self.evaluate(log_abs_det_tf), atol=atol)
            self.assertAllClose(sign_np, self.evaluate(sign_tf), atol=atol)

  def test_works_with_underflow_case(self):
    for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
                           (np.complex64, 0.05), (np.complex128, 1e-5)]:
      with self.subTest(np_dtype=np_dtype, atol=atol):
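        # As above, det underflows in single precision, while slogdet still
        # returns a finite log-magnitude together with its sign.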
        matrix = (np.eye(20) * 1e-6).astype(np_dtype)
        sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
        with self.session():
          sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
          self.assertAllClose(
              log_abs_det_np, self.evaluate(log_abs_det_tf), atol=atol)
          self.assertAllClose(sign_np, self.evaluate(sign_tf), atol=atol)


class AdjointTest(test.TestCase):

  def test_compare_to_numpy(self):
    for dtype in np.float32, np.float64, np.complex64, np.complex128:
      with self.subTest(dtype=dtype):
        matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j],
                              [4 + 4j, 5 + 5j, 6 + 6j]]).astype(dtype)
        expected_transposed = np.conj(matrix_np.T)
        with self.session():
          matrix = ops.convert_to_tensor(matrix_np)
          transposed = linalg.adjoint(matrix)
          self.assertEqual((3, 2), transposed.get_shape())
          self.assertAllEqual(expected_transposed, self.evaluate(transposed))


class EyeTest(parameterized.TestCase, test.TestCase):

  def testShapeInferenceNoBatch(self):
    self.assertEqual((2, 2), linalg_ops.eye(num_rows=2).shape)
    self.assertEqual((2, 3), linalg_ops.eye(num_rows=2, num_columns=3).shape)

  def testShapeInferenceStaticBatch(self):
    batch_shape = (2, 3)
    self.assertEqual(
        (2, 3, 2, 2),
        linalg_ops.eye(num_rows=2, batch_shape=batch_shape).shape)
    self.assertEqual(
        (2, 3, 2, 3),
        linalg_ops.eye(
            num_rows=2, num_columns=3, batch_shape=batch_shape).shape)

  @parameterized.named_parameters(
      ("DynamicRow",
       lambda: array_ops.placeholder_with_default(2, shape=None),
       lambda: None),
      ("DynamicRowStaticColumn",
       lambda: array_ops.placeholder_with_default(2, shape=None),
       lambda: 3),
      ("StaticRowDynamicColumn",
       lambda: 2,
       lambda: array_ops.placeholder_with_default(3, shape=None)),
      ("DynamicRowDynamicColumn",
       lambda: array_ops.placeholder_with_default(2, shape=None),
       lambda: array_ops.placeholder_with_default(3, shape=None)))
  def testShapeInferenceStaticBatchWithDynamicDims(
      self, num_rows_fn, num_columns_fn):
    num_rows = num_rows_fn()
    num_columns = num_columns_fn()
    batch_shape = (2, 3)
    identity_matrix = linalg_ops.eye(
        num_rows=num_rows,
        num_columns=num_columns,
        batch_shape=batch_shape)
    self.assertEqual(4, identity_matrix.shape.ndims)
    self.assertEqual((2, 3), identity_matrix.shape[:2])
    if num_rows is not None and not isinstance(num_rows, ops.Tensor):
      self.assertEqual(2, identity_matrix.shape[-2])

    if num_columns is not None and not isinstance(num_columns, ops.Tensor):
      self.assertEqual(3, identity_matrix.shape[-1])

  @parameterized.parameters(
      itertools.product(
          # num_rows
          [0, 1, 2, 5],
          # num_columns
          [None, 0, 1, 2, 5],
          # batch_shape
          [None, [], [2], [2, 3]],
          # dtype
          [
              dtypes.int32,
              dtypes.int64,
              dtypes.float32,
              dtypes.float64,
              dtypes.complex64,
              dtypes.complex128
          ])
      )
  def test_eye_no_placeholder(self, num_rows, num_columns, batch_shape, dtype):
    eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
    if batch_shape is not None:
      eye_np = np.tile(eye_np, batch_shape + [1, 1])
    eye_tf = self.evaluate(linalg_ops.eye(
        num_rows,
        num_columns=num_columns,
        batch_shape=batch_shape,
        dtype=dtype))
    self.assertAllEqual(eye_np, eye_tf)

  @parameterized.parameters(
      itertools.product(
          # num_rows
          [0, 1, 2, 5],
          # num_columns
          [0, 1, 2, 5],
          # batch_shape
          [[], [2], [2, 3]],
          # dtype
          [
              dtypes.int32,
              dtypes.int64,
              dtypes.float32,
              dtypes.float64,
              dtypes.complex64,
              dtypes.complex128
          ])
      )
  @test_util.run_deprecated_v1
  def test_eye_with_placeholder(
      self, num_rows, num_columns, batch_shape, dtype):
    eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
    eye_np = np.tile(eye_np, batch_shape + [1, 1])
    num_rows_placeholder = array_ops.placeholder(
        dtypes.int32, name="num_rows")
    num_columns_placeholder = array_ops.placeholder(
        dtypes.int32, name="num_columns")
    batch_shape_placeholder = array_ops.placeholder(
        dtypes.int32, name="batch_shape")
    eye = linalg_ops.eye(
        num_rows_placeholder,
        num_columns=num_columns_placeholder,
        batch_shape=batch_shape_placeholder,
        dtype=dtype)
    with self.session() as sess:
      eye_tf = sess.run(
          eye,
          feed_dict={
              num_rows_placeholder: num_rows,
              num_columns_placeholder: num_columns,
              batch_shape_placeholder: batch_shape
          })
    self.assertAllEqual(eye_np, eye_tf)


class _MatrixRankTest(object):

  def test_batch_default_tolerance(self):
    x_ = np.array(
        [
            [
                [2, 3, -2],  # = row2+row3
                [-1, 1, -2],
                [3, 2, 0]
            ],
            [
                [0, 2, 0],  # = 2*row2
                [0, 1, 0],
                [0, 3, 0]  # = 3*row2
            ],
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
        ],
        self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)
    self.assertAllEqual([2, 1, 3], self.evaluate(linalg.matrix_rank(x)))

  def test_custom_tolerance_broadcasts(self):
    q = linalg.qr(random_ops.random_uniform([3, 3], dtype=self.dtype))[0]
    e = constant_op.constant([0.1, 0.2, 0.3], dtype=self.dtype)
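    # With orthogonal q, a = q @ diag(e) @ q^T has eigenvalues e, so the
    # tolerances below straddle 0.1, 0.2, 0.3 and yield ranks 3, 2, 1, 0.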
    a = linalg.solve(q, linalg.transpose(a=e * q), adjoint=True)
    self.assertAllEqual([3, 2, 1, 0],
                        self.evaluate(
                            linalg.matrix_rank(
                                a, tol=[[0.09], [0.19], [0.29], [0.31]])))

  def test_nonsquare(self):
    x_ = np.array(
        [
            [
                [2, 3, -2, 2],  # = row2+row3
                [-1, 1, -2, 4],
                [3, 2, 0, -2]
            ],
            [
                [0, 2, 0, 6],  # = 2*row2
                [0, 1, 0, 3],
                [0, 3, 0, 9]  # = 3*row2
            ]
        ],
        self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)
    self.assertAllEqual([2, 1], self.evaluate(linalg.matrix_rank(x)))


@test_util.run_all_in_graph_and_eager_modes
class MatrixRankStatic32Test(test.TestCase, _MatrixRankTest):
  dtype = np.float32
  use_static_shape = True


@test_util.run_all_in_graph_and_eager_modes
class MatrixRankDynamic64Test(test.TestCase, _MatrixRankTest):
  dtype = np.float64
  use_static_shape = False


class _PinvTest(object):

  def expected_pinv(self, a, rcond):
    """Calls `np.linalg.pinv` but corrects its broken batch semantics."""
    if a.ndim < 3:
      return np.linalg.pinv(a, rcond)
    if rcond is None:
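      # Match the default cutoff used by linalg.pinv:
      # 10 * max(num_rows, num_cols) * eps.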
      rcond = 10. * max(a.shape[-2], a.shape[-1]) * np.finfo(a.dtype).eps
    s = np.concatenate([a.shape[:-2], [a.shape[-1], a.shape[-2]]])
    a_pinv = np.zeros(s, dtype=a.dtype)
    for i in np.ndindex(a.shape[:(a.ndim - 2)]):
      a_pinv[i] = np.linalg.pinv(
          a[i], rcond=rcond if isinstance(rcond, float) else rcond[i])
    return a_pinv

  def test_symmetric(self):
    a_ = self.dtype([[1., .4, .5], [.4, .2, .25], [.5, .25, .35]])
    a_ = np.stack([a_ + 1., a_], axis=0)  # Batch of matrices.
    a = array_ops.placeholder_with_default(
        a_, shape=a_.shape if self.use_static_shape else None)
    if self.use_default_rcond:
      rcond = None
    else:
      # Force the smallest singular value of the second matrix to zero.
      rcond = self.dtype([0., 0.01])
    expected_a_pinv_ = self.expected_pinv(a_, rcond)
    a_pinv = linalg.pinv(a, rcond, validate_args=True)
    a_pinv_ = self.evaluate(a_pinv)
    self.assertAllClose(expected_a_pinv_, a_pinv_, atol=2e-5, rtol=2e-5)
    if not self.use_static_shape:
      return
    self.assertAllEqual(expected_a_pinv_.shape, a_pinv.shape)

  def test_nonsquare(self):
    a_ = self.dtype([[1., .4, .5, 1.], [.4, .2, .25, 2.], [.5, .25, .35, 3.]])
    a_ = np.stack([a_ + 0.5, a_], axis=0)  # Batch of matrices.
    a = array_ops.placeholder_with_default(
        a_, shape=a_.shape if self.use_static_shape else None)
    if self.use_default_rcond:
      rcond = None
    else:
      # Force the two smallest singular values of the second matrix to zero.
      rcond = self.dtype([0., 0.25])
    expected_a_pinv_ = self.expected_pinv(a_, rcond)
    a_pinv = linalg.pinv(a, rcond, validate_args=True)
    a_pinv_ = self.evaluate(a_pinv)
    self.assertAllClose(expected_a_pinv_, a_pinv_, atol=1e-5, rtol=1e-4)
    if not self.use_static_shape:
      return
    self.assertAllEqual(expected_a_pinv_.shape, a_pinv.shape)


@test_util.run_all_in_graph_and_eager_modes
class PinvTestDynamic32DefaultRcond(test.TestCase, _PinvTest):
  dtype = np.float32
  use_static_shape = False
  use_default_rcond = True


@test_util.run_all_in_graph_and_eager_modes
class PinvTestStatic64DefaultRcond(test.TestCase, _PinvTest):
  dtype = np.float64
  use_static_shape = True
  use_default_rcond = True


@test_util.run_all_in_graph_and_eager_modes
class PinvTestDynamic32CustomRcond(test.TestCase, _PinvTest):
  dtype = np.float32
  use_static_shape = False
  use_default_rcond = False


@test_util.run_all_in_graph_and_eager_modes
class PinvTestStatic64CustomRcond(test.TestCase, _PinvTest):
  dtype = np.float64
  use_static_shape = True
  use_default_rcond = False


def make_tensor_hiding_attributes(value, hide_shape, hide_value=True):
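  """Wraps `value`, optionally hiding its static shape and/or value."""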
  if not hide_value:
    return ops.convert_to_tensor(value)

  shape = None if hide_shape else getattr(value, "shape", None)
  return array_ops.placeholder_with_default(value, shape=shape)


class _LUReconstruct(object):
  dtype = np.float32
  use_static_shape = True

  def test_non_batch(self):
    x_ = np.array([[3, 4], [1, 2]], dtype=self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)

    y = linalg.lu_reconstruct(*linalg.lu(x), validate_args=True)
    y_ = self.evaluate(y)

    if self.use_static_shape:
      self.assertAllEqual(x_.shape, y.shape)
    self.assertAllClose(x_, y_, atol=0., rtol=1e-3)

  def test_batch(self):
    x_ = np.array([
        [[3, 4], [1, 2]],
        [[7, 8], [3, 4]],
    ], dtype=self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)

    y = linalg.lu_reconstruct(*linalg.lu(x), validate_args=True)
    y_ = self.evaluate(y)

    if self.use_static_shape:
      self.assertAllEqual(x_.shape, y.shape)
    self.assertAllClose(x_, y_, atol=0., rtol=1e-3)


@test_util.run_all_in_graph_and_eager_modes
class LUReconstructStatic(test.TestCase, _LUReconstruct):
  use_static_shape = True


@test_util.run_all_in_graph_and_eager_modes
class LUReconstructDynamic(test.TestCase, _LUReconstruct):
  use_static_shape = False


class _LUMatrixInverse(object):
  dtype = np.float32
  use_static_shape = True

  def test_non_batch(self):
    x_ = np.array([[1, 2], [3, 4]], dtype=self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)

    y = linalg.lu_matrix_inverse(*linalg.lu(x), validate_args=True)
    y_ = self.evaluate(y)

    if self.use_static_shape:
      self.assertAllEqual(x_.shape, y.shape)
    self.assertAllClose(np.linalg.inv(x_), y_, atol=0., rtol=1e-3)

  def test_batch(self):
    x_ = np.array([
        [[1, 2], [3, 4]],
        [[7, 8], [3, 4]],
        [[0.25, 0.5], [0.75, -2.]],
    ], dtype=self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)

    y = linalg.lu_matrix_inverse(*linalg.lu(x), validate_args=True)
    y_ = self.evaluate(y)

    if self.use_static_shape:
      self.assertAllEqual(x_.shape, y.shape)
    self.assertAllClose(np.linalg.inv(x_), y_, atol=0., rtol=1e-3)


@test_util.run_all_in_graph_and_eager_modes
class LUMatrixInverseStatic(test.TestCase, _LUMatrixInverse):
  use_static_shape = True


@test_util.run_all_in_graph_and_eager_modes
class LUMatrixInverseDynamic(test.TestCase, _LUMatrixInverse):
  use_static_shape = False


class _LUSolve(object):
  dtype = np.float32
  use_static_shape = True

  def test_non_batch(self):
    x_ = np.array([[1, 2], [3, 4]], dtype=self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)
    rhs_ = np.array([[1, 1]], dtype=self.dtype).T
    rhs = array_ops.placeholder_with_default(
        rhs_, shape=rhs_.shape if self.use_static_shape else None)

    lower_upper, perm = linalg.lu(x)
    y = linalg.lu_solve(lower_upper, perm, rhs, validate_args=True)
    y_, perm_ = self.evaluate([y, perm])

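    # Partial pivoting puts the larger pivot (|3| > |1|) first, hence [1, 0].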
    self.assertAllEqual([1, 0], perm_)
    expected_ = np.linalg.solve(x_, rhs_)
    if self.use_static_shape:
      self.assertAllEqual(expected_.shape, y.shape)
    self.assertAllClose(expected_, y_, atol=0., rtol=1e-3)

  def test_batch_broadcast(self):
    x_ = np.array([
        [[1, 2], [3, 4]],
        [[7, 8], [3, 4]],
        [[0.25, 0.5], [0.75, -2.]],
    ], dtype=self.dtype)
    x = array_ops.placeholder_with_default(
        x_, shape=x_.shape if self.use_static_shape else None)
    rhs_ = np.array([[1, 1]], dtype=self.dtype).T
    rhs = array_ops.placeholder_with_default(
        rhs_, shape=rhs_.shape if self.use_static_shape else None)

    lower_upper, perm = linalg.lu(x)
    y = linalg.lu_solve(lower_upper, perm, rhs, validate_args=True)
    y_, perm_ = self.evaluate([y, perm])

    self.assertAllEqual([[1, 0], [0, 1], [1, 0]], perm_)
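    # A single rhs column is broadcast against the batch of three matrices.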
    expected_ = np.linalg.solve(x_, rhs_[np.newaxis])
    if self.use_static_shape:
      self.assertAllEqual(expected_.shape, y.shape)
    self.assertAllClose(expected_, y_, atol=0., rtol=1e-3)


@test_util.run_all_in_graph_and_eager_modes
class LUSolveStatic(test.TestCase, _LUSolve):
  use_static_shape = True


@test_util.run_all_in_graph_and_eager_modes
class LUSolveDynamic(test.TestCase, _LUSolve):
  use_static_shape = False


@test_util.run_all_in_graph_and_eager_modes
class EighTridiagonalTest(test.TestCase, parameterized.TestCase):

  def check_residual(self, matrix, eigvals, eigvectors, atol):
    # Test that A*eigvectors is close to eigvectors*diag(eigvals).
    l = math_ops.cast(linalg.diag(eigvals), dtype=eigvectors.dtype)
    av = math_ops.matmul(matrix, eigvectors)
    vl = math_ops.matmul(eigvectors, l)
    self.assertAllClose(av, vl, atol=atol)

  def check_orthogonality(self, eigvectors, tol):
    # Test that eigenvectors are orthogonal.
    k = array_ops.shape(eigvectors)[1]
    vtv = math_ops.matmul(
        eigvectors, eigvectors, adjoint_a=True) - linalg.eye(
            k, dtype=eigvectors.dtype)
    self.assertAllLess(math_ops.abs(vtv), tol)

  def run_test(self, alpha, beta, eigvals_only=True):
    n = alpha.shape[0]
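    # Assemble the dense Hermitian tridiagonal matrix: alpha on the main
    # diagonal, beta above it, and conj(beta) below it.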
    matrix = np.diag(alpha) + np.diag(beta, 1) + np.diag(np.conj(beta), -1)
    # scipy.linalg.eigh_tridiagonal doesn't support complex inputs, so for
    # those we fall back to the slower numpy.linalg.eigh.
    if np.issubdtype(alpha.dtype, np.complexfloating):
      eigvals_expected, _ = np.linalg.eigh(matrix)
    else:
      eigvals_expected = scipy.linalg.eigh_tridiagonal(
          alpha, beta, eigvals_only=True)
    eigvals = linalg.eigh_tridiagonal(alpha, beta, eigvals_only=eigvals_only)
    if not eigvals_only:
      eigvals, eigvectors = eigvals

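    # Tolerance scaled as n * eps * max|lambda|: a rough perturbation bound
    # for eigenvalues of a Hermitian matrix under eps-level rounding error.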
    eps = np.finfo(alpha.dtype).eps
    atol = n * eps * np.amax(np.abs(eigvals_expected))
    self.assertAllClose(eigvals_expected, eigvals, atol=atol)
    if not eigvals_only:
      self.check_orthogonality(eigvectors, 2 * np.sqrt(n) * eps)
      self.check_residual(matrix, eigvals, eigvectors, atol)

  @parameterized.parameters((np.float32), (np.float64), (np.complex64),
                            (np.complex128))
  def test_small(self, dtype):
    for n in [1, 2, 3]:
      alpha = np.ones([n], dtype=dtype)
      beta = np.ones([n - 1], dtype=dtype)
      if np.issubdtype(alpha.dtype, np.complexfloating):
        beta += 1j * beta
      self.run_test(alpha, beta)

  @parameterized.parameters((np.float32), (np.float64), (np.complex64),
                            (np.complex128))
  def test_toeplitz(self, dtype):
    n = 8
    for a, b in [[2, -1], [1, 0], [0, 1], [-1e10, 1e10], [-1e-10, 1e-10]]:
      alpha = a * np.ones([n], dtype=dtype)
      beta = b * np.ones([n - 1], dtype=dtype)
      if np.issubdtype(alpha.dtype, np.complexfloating):
        beta += 1j * beta
      self.run_test(alpha, beta)

  @parameterized.parameters((np.float32), (np.float64), (np.complex64),
                            (np.complex128))
  def test_random_uniform(self, dtype):
    for n in [8, 50]:
      alpha = np.random.uniform(size=(n,)).astype(dtype)
      beta = np.random.uniform(size=(n - 1,)).astype(dtype)
      if np.issubdtype(beta.dtype, np.complexfloating):
        beta += 1j * np.random.uniform(size=(n - 1,)).astype(dtype)
      self.run_test(alpha, beta)

  @parameterized.parameters((np.float32), (np.float64), (np.complex64),
                            (np.complex128))
  def test_select(self, dtype):
    n = 4
    alpha = np.random.uniform(size=(n,)).astype(dtype)
    beta = np.random.uniform(size=(n - 1,)).astype(dtype)
    eigvals_all = linalg.eigh_tridiagonal(alpha, beta, select="a")

    eps = np.finfo(alpha.dtype).eps
    atol = 2 * n * eps
    for first in range(n - 1):
      for last in range(first + 1, n - 1):
        # Check that we get the expected eigenvalues by selecting by
        # index range.
        eigvals_index = linalg.eigh_tridiagonal(
            alpha, beta, select="i", select_range=(first, last))
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_index, atol=atol)

        # Check that we get the expected eigenvalues by selecting by
        # value range.
        eigvals_value = linalg.eigh_tridiagonal(
            alpha,
            beta,
            select="v",
            select_range=(eigvals_all[first], eigvals_all[last]))
        self.assertAllClose(
            eigvals_all[first:(last + 1)], eigvals_value, atol=atol)

  @parameterized.parameters((np.float32), (np.float64), (np.complex64),
                            (np.complex128))
  def test_extreme_eigenvalues(self, dtype):
    huge = 0.33 * np.finfo(dtype).max
    tiny = 3 * np.finfo(dtype).tiny
    for (a, b) in [(tiny, tiny), (huge, np.sqrt(huge))]:
      alpha = np.array([-a, -np.sqrt(a), np.sqrt(a), a]).astype(dtype)
      beta = b * np.ones([3], dtype=dtype)
      if np.issubdtype(alpha.dtype, np.complexfloating):
        beta += 1j * beta
      # Presumed intent: actually exercise the solver on these extreme inputs.
      self.run_test(alpha, beta)

  @parameterized.parameters((np.float32), (np.float64), (np.complex64),
                            (np.complex128))
  def test_eigenvectors(self, dtype):
    if test.is_gpu_available(cuda_only=True) or test_util.is_xla_enabled():
      # CUDA and XLA do not yet expose the stabilized tridiagonal solver
      # needed for inverse iteration.
      return
    n = 8
    alpha = np.random.uniform(size=(n,)).astype(dtype)
    beta = np.random.uniform(size=(n - 1,)).astype(dtype)
    if np.issubdtype(beta.dtype, np.complexfloating):
      beta += 1j * np.random.uniform(size=(n - 1,)).astype(dtype)
    self.run_test(alpha, beta, eigvals_only=False)

    # Test that we can correctly generate an orthogonal basis for
    # a fully degenerate matrix.
    eps = np.finfo(dtype).eps
    alpha = np.ones(n).astype(dtype)
    beta = 0.01 * np.sqrt(eps) * np.ones((n - 1)).astype(dtype)
    self.run_test(alpha, beta, eigvals_only=False)


if __name__ == "__main__":
  test.main()