# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.Cholesky."""

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test


# Different gradient implementations for benchmark purposes.
def _GradWithInverseL(l, l_inverse, grad):
  # Shared core of the Cholesky gradient, given an explicitly formed inverse
  # of l; see the formula in TriAngSolveCompositeGrad below.
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)
  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
  grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
  return grad_a * 0.5

def TriAngSolveCompositeGrad(l, grad):
  # For a = l @ l^H, the reverse-mode gradient is
  #   grad_a = l^{-H} @ ((l^{H} @ grad) * (tril(ones) - 1/2*eye)) @ l^{-1},
  # symmetrized; this follows from differentiating a = l @ l^H while
  # restricting variations of l to lower-triangular matrices.

  # Compute middle = (l^{H} @ grad) * (tril(ones) - 1/2*eye).
  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  # Compute z = l^{-H} @ middle.
  l_inverse_middle = linalg_ops.matrix_triangular_solve(l, middle, adjoint=True)

  # We need z @ l^{-1}, but matrix_triangular_solve actually gives us its
  # adjoint, (z @ l^{-1})^{H} = l^{-H} @ z^{H}. Since the result is
  # symmetrized below by adding its adjoint, the final conjugate transpose
  # can be omitted.
  z_h = math_ops.conj(array_ops.matrix_transpose(l_inverse_middle))
  grad_a = linalg_ops.matrix_triangular_solve(l, z_h, adjoint=True)
  grad_a += linalg.adjoint(grad_a)
  return grad_a * 0.5


def MatrixInverseCompositeGrad(l, grad):
  l_inverse = linalg_ops.matrix_inverse(l)
  return _GradWithInverseL(l, l_inverse, grad)


def TriAngInvCompositeGrad(l, grad):
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))
  return _GradWithInverseL(l, l_inverse, grad)

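
# A minimal cross-check sketch, illustrative only (the helper name is ours
# and it is not wired into the tests or benchmarks; .numpy() assumes eager
# execution): for a small, well-conditioned SPD input, the three composite
# gradient variants above should agree to within float64 round-off.
def _CheckGradVariantsAgree():
  a = np.array([[4., -1.], [-1., 6.]])
  l = constant_op.constant(np.linalg.cholesky(a))
  grad = constant_op.constant(np.ones((2, 2)))
  g1 = TriAngSolveCompositeGrad(l, grad)
  g2 = MatrixInverseCompositeGrad(l, grad)
  g3 = TriAngInvCompositeGrad(l, grad)
  np.testing.assert_allclose(g1.numpy(), g2.numpy(), atol=1e-12)
  np.testing.assert_allclose(g1.numpy(), g3.numpy(), atol=1e-12)
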

class CholeskyOpTest(test.TestCase):

  def _verifyCholeskyBase(self, x, chol, verification):
    chol_np, verification_np = self.evaluate([chol, verification])
    self.assertAllClose(x, verification_np)
    self.assertShapeEqual(x, chol)
    # Check that the Cholesky factor is lower triangular and has a strictly
    # positive diagonal.
    if chol_np.shape[-1] > 0:
      chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2],
                                           chol_np.shape[-1]))
      for chol_matrix in chol_reshaped:
        self.assertAllClose(chol_matrix, np.tril(chol_matrix))
        self.assertTrue((np.diag(chol_matrix) > 0.0).all())

  def _verifyCholesky(self, x):
    # Verify that x = L @ L^H, where L = cholesky(x).
    chol = linalg_ops.cholesky(x)
    verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True)
    self._verifyCholeskyBase(x, chol, verification)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testBasic(self):
    data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])
    for dtype in (np.float32, np.float64):
      with self.subTest(dtype=dtype):
        self._verifyCholesky(data.astype(dtype))
    for dtype in (np.complex64, np.complex128):
      with self.subTest(dtype=dtype):
        complex_data = np.tril(1j * data, -1).astype(dtype)
        complex_data += np.triu(-1j * data, 1).astype(dtype)
        complex_data += data
        self._verifyCholesky(complex_data)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testBatch(self):
    simple_array = np.array([[[1., 0.], [0., 5.]]])  # shape (1, 2, 2)
    self._verifyCholesky(simple_array)
    self._verifyCholesky(np.vstack((simple_array, simple_array)))
    odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
    self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))

    # Generate random positive-definite matrices.
    matrices = np.random.rand(10, 5, 5)
    for i in range(10):
      with self.subTest(i=i):
        matrices[i] = np.dot(matrices[i].T, matrices[i])
    self._verifyCholesky(matrices)

    # Generate random complex-valued positive-definite matrices.
    matrices = np.random.rand(10, 5, 5) + 1j * np.random.rand(10, 5, 5)
    for i in range(10):
      with self.subTest(i=i):
        matrices[i] = np.dot(matrices[i].T.conj(), matrices[i])
    self._verifyCholesky(matrices)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testNonSquareMatrix(self):
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      linalg_ops.cholesky(
          np.array([[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]
                   ]))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testWrongDimensions(self):
    tensor3 = constant_op.constant([1., 2.])
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      linalg_ops.cholesky(tensor3)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testNotInvertibleCpu(self):
    # Non-invertible inputs result in lower-triangular NaNs.
    x = constant_op.constant([[1., -1., 0.], [-1., 1., -1.], [0., -1., 1.]])
    chol = linalg_ops.cholesky(x)
    # Extract the lower-triangular elements.
    lower_mask = array_ops.matrix_band_part(
        constant_op.constant(True, shape=x.shape), -1, 0)
    chol_lower = array_ops.boolean_mask(chol, lower_mask)
    # Assert all NaN.
    all_nan = self.evaluate(math_ops.reduce_all(math_ops.is_nan(chol_lower)))
    self.assertTrue(all_nan)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testEmpty(self):
    self._verifyCholesky(np.empty([0, 2, 2]))
    self._verifyCholesky(np.empty([2, 0, 0]))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testConcurrentExecutesWithoutError(self):
    seed = [42, 24]
    matrix_shape = [5, 5]
    matrix1 = stateless_random_ops.stateless_random_normal(matrix_shape, seed)
    matrix2 = stateless_random_ops.stateless_random_normal(matrix_shape, seed)
    matrix1 = math_ops.matmul(matrix1, matrix1, adjoint_a=True)
    matrix2 = math_ops.matmul(matrix2, matrix2, adjoint_a=True)
    c1 = linalg_ops.cholesky(matrix1)
    c2 = linalg_ops.cholesky(matrix2)
    c1_val, c2_val = self.evaluate([c1, c2])
    self.assertAllClose(c1_val, c2_val)

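
# An illustrative pure-NumPy sketch (the helper name is ours; it is not run
# as part of the suite) of the invariants _verifyCholeskyBase asserts: the
# factor reconstructs x as l @ l^H, is lower triangular, and has a strictly
# positive diagonal.
def _NumpyCholeskyPropertySketch():
  x = np.array([[4., -1., 2.], [-1., 6., 0.], [2., 0., 5.]])
  l = np.linalg.cholesky(x)
  np.testing.assert_allclose(np.dot(l, l.T), x, atol=1e-12)
  np.testing.assert_allclose(l, np.tril(l))
  assert (np.diag(l) > 0.0).all()
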

class CholeskyGradTest(test.TestCase):
  _backprop_block_size = 16

  def getShapes(self, shapeList):
    return ((elem, int(np.floor(1.2 * elem))) for elem in shapeList)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testSmallMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testSmallMatricesComplex(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testOneBlockMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes,
        dtypes=(dtypes_lib.float32, dtypes_lib.float64),
        scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixFloat(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float32,), scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float64,), scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixComplexFloat(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64,), scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixComplexDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex128,), scalar_test=True)

  def _runOneTest(self, shape, dtype, batch, scalar_test):
    if dtype == dtypes_lib.float64:
      tol = 1e-5
    elif dtype == dtypes_lib.complex128:
      tol = 5e-5
    else:
      tol = 5e-3
    epsilon = np.finfo(dtype.as_numpy_dtype).eps
    # eps**(1/3) is the standard central-difference step size, balancing
    # truncation error against floating-point round-off.
    delta = epsilon**(1.0 / 3.0)

    def RandomInput():
      a = np.random.randn(shape[0], shape[1]).astype(dtype.as_numpy_dtype)
      if dtype.is_complex:
        a += 1j * np.random.randn(shape[0], shape[1]).astype(
            dtype.as_numpy_dtype)
      return a

    def Compute(x):
      # Turn the random matrix x into a Hermitian matrix by forming the
      # product x @ x^H, which is positive definite for full-rank x.
      a = test_util.matmul_without_tf32(
          x, math_ops.conj(array_ops.matrix_transpose(x))) / shape[0]
      if batch:
        a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1])
      # Finally, take the Cholesky decomposition of the Hermitian matrix.
      c = linalg_ops.cholesky(a)
      if scalar_test:
        # Reduce to a single scalar output to speed up test.
        c = math_ops.reduce_mean(c)
      return c

    theoretical, numerical = gradient_checker_v2.compute_gradient(
        Compute, [RandomInput()], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  def runFiniteDifferences(self,
                           shapes,
                           dtypes=(dtypes_lib.float32, dtypes_lib.float64,
                                   dtypes_lib.complex64, dtypes_lib.complex128),
                           scalar_test=False):
    for shape_ in shapes:
      for dtype_ in dtypes:
        for batch_ in False, True:
          self._runOneTest(shape_, dtype_, batch_, scalar_test)

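
# A standalone sketch of the check runFiniteDifferences automates above
# (illustrative; the helper names are ours and eager execution is assumed):
# build an SPD input from an unconstrained matrix, so finite-difference
# perturbations stay inside the valid domain, then compare the analytic
# Cholesky gradient against central differences.
def _ManualGradCheckSketch():
  np.random.seed(0)
  x0 = np.random.randn(3, 3)

  def _CholeskyOfGram(x):
    # a = x @ x^T / n is symmetric positive definite for full-rank x.
    a = math_ops.matmul(x, x, adjoint_b=True) / 3.0
    return linalg_ops.cholesky(a)

  delta = np.finfo(np.float64).eps**(1.0 / 3.0)
  theoretical, numerical = gradient_checker_v2.compute_gradient(
      _CholeskyOfGram, [x0], delta=delta)
  np.testing.assert_allclose(theoretical[0], numerical[0], atol=1e-5,
                             rtol=1e-5)
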

class CholeskyBenchmark(test.Benchmark):

  shapes = [
      (4, 4),
      (10, 10),
      (16, 16),
      (101, 101),
      (256, 256),
      (1000, 1000),
      (1024, 1024),
      (2048, 2048),
      (513, 2, 2),
      (513, 8, 8),
      (513, 256, 256),
      (4, 513, 2, 2),
  ]

  def _GenerateMatrix(self, shape):
    batch_shape = shape[:-2]
    shape = shape[-2:]
    assert shape[0] == shape[1]
    n = shape[0]
    # Symmetric and strictly diagonally dominant with a positive diagonal,
    # hence positive definite.
    matrix = np.ones(shape).astype(np.float32) / (
        2.0 * n) + np.diag(np.ones(n).astype(np.float32))
    return np.tile(matrix, batch_shape + (1, 1))

  def benchmarkCholeskyOp(self):
    for shape in self.shapes:
      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device("/cpu:0"):
        matrix = variables.Variable(self._GenerateMatrix(shape))
        l = linalg_ops.cholesky(matrix)
        self.evaluate(variables.global_variables_initializer())
        self.run_op_benchmark(
            sess,
            control_flow_ops.group(l),
            min_iters=25,
            name="cholesky_cpu_{shape}".format(shape=shape))

      if test.is_gpu_available(True):
        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device("/device:GPU:0"):
          matrix = variables.Variable(self._GenerateMatrix(shape))
          l = linalg_ops.cholesky(matrix)
          self.evaluate(variables.global_variables_initializer())
          self.run_op_benchmark(
              sess,
              control_flow_ops.group(l),
              min_iters=25,
              name="cholesky_gpu_{shape}".format(shape=shape))

  def benchmarkGradVariants(self):

    def _BenchmarkGrad(grad_fn, name, device):
      for shape in self.shapes:
        matrix = self._GenerateMatrix(shape)
        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device(device):
          l = variables.Variable(np.linalg.cholesky(matrix))
          grad_matrix = variables.Variable(
              np.random.randn(*matrix.shape).astype(np.float32))
          grad = grad_fn(l, grad_matrix)
          self.evaluate(variables.global_variables_initializer())
          self.run_op_benchmark(
              sess,
              control_flow_ops.group(grad),
              min_iters=25,
              name="{name}_{dev}_{shape}".format(
                  name=name, dev=grad.device, shape=shape))

    if test.is_gpu_available(True):
      _BenchmarkGrad(MatrixInverseCompositeGrad, "composite_matrix_inverse",
                     "/device:GPU:0")
      _BenchmarkGrad(TriAngInvCompositeGrad, "composite_tri_ang_inverse",
                     "/device:GPU:0")
      _BenchmarkGrad(TriAngSolveCompositeGrad, "composite_triangular_solve",
                     "/device:GPU:0")

    _BenchmarkGrad(MatrixInverseCompositeGrad, "composite_matrix_inverse",
                   "/cpu:0")
    _BenchmarkGrad(TriAngInvCompositeGrad, "composite_tri_ang_inverse",
                   "/cpu:0")
    _BenchmarkGrad(TriAngSolveCompositeGrad, "composite_triangular_solve",
                   "/cpu:0")


if __name__ == "__main__":
  test.main()