# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.linalg.linalg_impl.tridiagonal_solve."""

import itertools

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test

_sample_diags = np.array([[2, 1, 4, 0], [1, 3, 2, 2], [0, 1, -1, 1]])
_sample_rhs = np.array([1, 2, 3, 4])
_sample_result = np.array([-9, 5, -4, 4])
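# In compact format, the rows of _sample_diags are the superdiagonal (padded
# on the right), the main diagonal, and the subdiagonal (padded on the left)
# of the 4x4 tridiagonal matrix
#   [[1, 2, 0, 0],
#    [1, 3, 1, 0],
#    [0, -1, 2, 4],
#    [0, 0, 1, 2]],
# for which matrix @ _sample_result == _sample_rhs (see testMatrixFormat).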

# Flag indicating that the test should be run only with partial_pivoting=True.
FLAG_REQUIRES_PIVOTING = "FLAG_REQUIRES_PIVOT"

# Flag indicating that the test shouldn't be parameterized by different values
# of partial_pivoting, etc.
FLAG_NO_PARAMETERIZATION = "FLAG_NO_PARAMETERIZATION"


def flags(*args):
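  """Returns a decorator that sets the given flag attributes on a test."""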

  def decorator(f):
    for flag in args:
      setattr(f, flag, True)
    return f

  return decorator


def _tfconst(array):
  if array is not None:
    return constant_op.constant(array, dtypes.float64)


def _tf_ones(shape):
  return array_ops.ones(shape, dtype=dtypes.float64)


class TridiagonalSolveOpTest(test.TestCase):

  def _test(self,
            diags,
            rhs,
            expected,
            diags_format="compact",
            transpose_rhs=False,
            conjugate_rhs=False):
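    """Solves the given system and checks the result against `expected`.

    When `expected` is None, the system is assumed to be non-invertible and
    the result must contain no finite values.
    """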
    with self.cached_session():
      pivoting = True
      if hasattr(self, "pivoting"):
        pivoting = self.pivoting
      if test_util.is_xla_enabled() and pivoting:
        # Pivoting is not supported by xla backends.
        return
      result = linalg_impl.tridiagonal_solve(
          diags,
          rhs,
          diags_format,
          transpose_rhs,
          conjugate_rhs,
          partial_pivoting=pivoting)
      result = self.evaluate(result)
      if expected is None:
        self.assertAllEqual(
            np.zeros_like(result, dtype=np.bool_), np.isfinite(result))
      else:
        self.assertAllClose(result, expected)

  def _testWithLists(self,
                     diags,
                     rhs,
                     expected=None,
                     diags_format="compact",
                     transpose_rhs=False,
                     conjugate_rhs=False):
    self._test(
        _tfconst(diags), _tfconst(rhs), _tfconst(expected), diags_format,
        transpose_rhs, conjugate_rhs)

  def _assertRaises(self, diags, rhs, diags_format="compact"):
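    """Asserts that tridiagonal_solve raises ValueError for these inputs."""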
    pivoting = True
    if hasattr(self, "pivoting"):
      pivoting = self.pivoting
    if test_util.is_xla_enabled() and pivoting:
      # Pivoting is not supported by xla backends.
      return
    with self.assertRaises(ValueError):
      linalg_impl.tridiagonal_solve(
          diags, rhs, diags_format, partial_pivoting=pivoting)

  # Tests with various dtypes

  def testReal(self):
    for dtype in dtypes.float32, dtypes.float64:
      self._test(
          diags=constant_op.constant(_sample_diags, dtype),
          rhs=constant_op.constant(_sample_rhs, dtype),
          expected=constant_op.constant(_sample_result, dtype))

  def testComplex(self):
    for dtype in dtypes.complex64, dtypes.complex128:
      self._test(
          diags=constant_op.constant(_sample_diags, dtype) * (1 + 1j),
          rhs=constant_op.constant(_sample_rhs, dtype) * (1 - 1j),
          expected=constant_op.constant(_sample_result, dtype) * (1 - 1j) /
          (1 + 1j))

  # Tests with small matrix sizes

  def test3x3(self):
    self._testWithLists(
        diags=[[2, -1, 0], [1, 3, 1], [0, -1, -2]],
        rhs=[1, 2, 3],
        expected=[-3, 2, 7])

  def test2x2(self):
    self._testWithLists(
        diags=[[2, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[-5, 3])

  def test2x2Complex(self):
    for dtype in dtypes.complex64, dtypes.complex128:
      self._test(
          diags=constant_op.constant([[2j, 0j], [1j, 3j], [0j, 1j]], dtype),
          rhs=constant_op.constant([1 - 1j, 4 - 4j], dtype),
          expected=constant_op.constant([5 + 5j, -3 - 3j], dtype))

  def test1x1(self):
    self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2])

  def test0x0(self):
    if test_util.is_xla_enabled():
      # The following test crashes with XLA due to slicing 0 length tensors.
      return
    self._test(
        diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32),
        rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32),
        expected=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32))

  def test2x2WithMultipleRhs(self):
    self._testWithLists(
        diags=[[2, 0], [1, 3], [0, 1]],
        rhs=[[1, 2, 3], [4, 8, 12]],
        expected=[[-5, -10, -15], [3, 6, 9]])

  def test1x1WithMultipleRhs(self):
    self._testWithLists(
        diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])

  def test1x1NotInvertible(self):
    if test_util.is_xla_enabled():
      # XLA implementation does not check invertibility.
      return
    self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]])

  def test2x2NotInvertible(self):
    if test_util.is_xla_enabled():
      # XLA implementation does not check invertibility.
      return
    self._testWithLists(diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4])

  # Other edge cases

  @flags(FLAG_REQUIRES_PIVOTING)
  def testCaseRequiringPivoting(self):
    # Without partial pivoting (e.g. Thomas algorithm) this would fail.
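    # Eliminating the first column turns the second diagonal element into
    # 4 - 2 * 2 = 0, so the unpivoted algorithm would divide by zero.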
    self._testWithLists(
        diags=[[2, -1, 1, 0], [1, 4, 1, -1], [0, 2, -2, 3]],
        rhs=[1, 2, 3, 4],
        expected=[8, -3.5, 0, -4])

  @flags(FLAG_REQUIRES_PIVOTING)
  def testCaseRequiringPivotingLastRows(self):
    self._testWithLists(
        diags=[[2, 1, -1, 0], [1, -1, 2, 1], [0, 1, -6, 1]],
        rhs=[1, 2, -1, -2],
        expected=[5, -2, -5, 3])

  def testNotInvertible(self):
    if test_util.is_xla_enabled():
      return
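    # The underlying 4x4 matrix is singular, so the result is expected to
    # contain no finite values (expected defaults to None in _testWithLists).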
    self._testWithLists(
        diags=[[2, -1, 1, 0], [1, 4, 1, -1], [0, 2, 0, 3]], rhs=[1, 2, 3, 4])

  def testDiagonal(self):
    self._testWithLists(
        diags=[[0, 0, 0, 0], [1, 2, -1, -2], [0, 0, 0, 0]],
        rhs=[1, 2, 3, 4],
        expected=[1, 1, -3, -2])

  def testUpperTriangular(self):
    self._testWithLists(
        diags=[[2, 4, -1, 0], [1, 3, 1, 2], [0, 0, 0, 0]],
        rhs=[1, 6, 4, 4],
        expected=[13, -6, 6, 2])

  def testLowerTriangular(self):
    self._testWithLists(
        diags=[[0, 0, 0, 0], [2, -1, 3, 1], [0, 1, 4, 2]],
        rhs=[4, 5, 6, 1],
        expected=[2, -3, 6, -11])

  # Multiple right-hand sides and batching

  def testWithTwoRightHandSides(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs, 2 * _sample_rhs]),
        expected=np.transpose([_sample_result, 2 * _sample_result]))

  def testBatching(self):
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]))

  def testWithTwoBatchingDimensions(self):
    self._testWithLists(
        diags=np.array([[_sample_diags, -_sample_diags, _sample_diags],
                        [-_sample_diags, _sample_diags, -_sample_diags]]),
        rhs=np.array([[_sample_rhs, 2 * _sample_rhs, 3 * _sample_rhs],
                      [4 * _sample_rhs, 5 * _sample_rhs, 6 * _sample_rhs]]),
        expected=np.array(
            [[_sample_result, -2 * _sample_result, 3 * _sample_result],
             [-4 * _sample_result, 5 * _sample_result, -6 * _sample_result]]))

  def testBatchingAndTwoRightHandSides(self):
    rhs = np.transpose([_sample_rhs, 2 * _sample_rhs])
    expected_result = np.transpose([_sample_result, 2 * _sample_result])
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([rhs, 2 * rhs]),
        expected=np.array([expected_result, -2 * expected_result]))

  # Various input formats

  def testSequenceFormat(self):
    self._test(
        diags=(_tfconst([2, 1, 4]),
               _tfconst([1, 3, 2, 2]),
               _tfconst([1, -1, 1])),
        rhs=_tfconst([1, 2, 3, 4]),
        expected=_tfconst([-9, 5, -4, 4]),
        diags_format="sequence")

  def testSequenceFormatWithDummyElements(self):
    dummy = 20
    self._test(
        diags=(_tfconst([2, 1, 4, dummy]),
               _tfconst([1, 3, 2, 2]),
               _tfconst([dummy, 1, -1, 1])),
        rhs=_tfconst([1, 2, 3, 4]),
        expected=_tfconst([-9, 5, -4, 4]),
        diags_format="sequence")

  def testSequenceFormatWithBatching(self):
    self._test(
        diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]),
               _tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]),
               _tfconst([[1, -1, 1], [-1, 1, -1]])),
        rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]),
        expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]),
        diags_format="sequence")

  def testMatrixFormat(self):
    self._testWithLists(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="matrix")

  def testMatrixFormatWithMultipleRightHandSides(self):
    self._testWithLists(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[[1, -1], [2, -2], [3, -3], [4, -4]],
        expected=[[-9, 9], [5, -5], [-4, 4], [4, -4]],
        diags_format="matrix")

  def testMatrixFormatWithBatching(self):
    self._testWithLists(
        diags=[[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
               [[-1, -2, 0, 0], [-1, -3, -1, 0], [0, 1, -2, -4],
                [0, 0, -1, -2]]],
        rhs=[[1, 2, 3, 4], [1, 2, 3, 4]],
        expected=[[-9, 5, -4, 4], [9, -5, 4, -4]],
        diags_format="matrix")

  def testRightHandSideAsColumn(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs]),
        expected=np.transpose([_sample_result]),
        diags_format="compact")

  # Tests with transpose and adjoint

  def testTransposeRhs(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, 2 * _sample_result]).T,
        transpose_rhs=True)

  def testConjugateRhs(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs * (1 + 1j), _sample_rhs * (1 - 2j)]),
        expected=np.transpose(
            [_sample_result * (1 - 1j), _sample_result * (1 + 2j)]),
        conjugate_rhs=True)

  def testAdjointRhs(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.array([_sample_rhs * (1 + 1j), _sample_rhs * (1 - 2j)]),
        expected=np.array(
            [_sample_result * (1 - 1j), _sample_result * (1 + 2j)]).T,
        transpose_rhs=True,
        conjugate_rhs=True)

  def testTransposeRhsWithBatching(self):
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([[_sample_rhs, 2 * _sample_rhs],
                      [3 * _sample_rhs, 4 * _sample_rhs]]),
        expected=np.array([[_sample_result, 2 * _sample_result],
                           [-3 * _sample_result,
                            -4 * _sample_result]]).transpose(0, 2, 1),
        transpose_rhs=True)

  def testTransposeRhsWithRhsAsVector(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=_sample_rhs,
        expected=_sample_result,
        transpose_rhs=True)

  def testConjugateRhsWithRhsAsVector(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=_sample_rhs * (1 + 1j),
        expected=_sample_result * (1 - 1j),
        conjugate_rhs=True)

  def testTransposeRhsWithRhsAsVectorAndBatching(self):
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]),
        transpose_rhs=True)

  # Gradient tests

  def _gradientTest(
      self,
      diags,
      rhs,
      y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
      expected_grad_diags,  # expected gradient of output w.r.t. diags
      expected_grad_rhs,  # expected gradient of output w.r.t. rhs
      diags_format="compact",
      transpose_rhs=False,
      conjugate_rhs=False,
      feed_dict=None):
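    """Checks gradients of reduce_sum(y * tridiagonal_solve(diags, rhs)).

    The gradients with respect to `diags` and `rhs` are compared against the
    expected values.
    """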
    expected_grad_diags = _tfconst(expected_grad_diags)
    expected_grad_rhs = _tfconst(expected_grad_rhs)
    with backprop.GradientTape() as tape_diags:
      with backprop.GradientTape() as tape_rhs:
        tape_diags.watch(diags)
        tape_rhs.watch(rhs)
        if test_util.is_xla_enabled():
          # Pivoting is not supported by xla backends.
          return
        x = linalg_impl.tridiagonal_solve(
            diags,
            rhs,
            diagonals_format=diags_format,
            transpose_rhs=transpose_rhs,
            conjugate_rhs=conjugate_rhs)
        res = math_ops.reduce_sum(x * y)
    with self.cached_session() as sess:
      actual_grad_diags = sess.run(
          tape_diags.gradient(res, diags), feed_dict=feed_dict)
      actual_grad_rhs = sess.run(
          tape_rhs.gradient(res, rhs), feed_dict=feed_dict)
    self.assertAllClose(expected_grad_diags, actual_grad_diags)
    self.assertAllClose(expected_grad_rhs, actual_grad_rhs)

  def _gradientTestWithLists(self,
                             diags,
                             rhs,
                             y,
                             expected_grad_diags,
                             expected_grad_rhs,
                             diags_format="compact",
                             transpose_rhs=False,
                             conjugate_rhs=False):
    self._gradientTest(
        _tfconst(diags), _tfconst(rhs), _tfconst(y), expected_grad_diags,
        expected_grad_rhs, diags_format, transpose_rhs, conjugate_rhs)

  def testGradientSimple(self):
    self._gradientTestWithLists(
        diags=_sample_diags,
        rhs=_sample_rhs,
        y=[1, 3, 2, 4],
        expected_grad_diags=[[-5, 0, 4, 0], [9, 0, -4, -16], [0, 0, 5, 16]],
        expected_grad_rhs=[1, 0, -1, 4])

  def testGradientWithMultipleRhs(self):
    self._gradientTestWithLists(
        diags=_sample_diags,
        rhs=[[1, 2], [2, 4], [3, 6], [4, 8]],
        y=[[1, 5], [2, 6], [3, 7], [4, 8]],
        expected_grad_diags=([[-20, 28, -60, 0], [36, -35, 60, 80],
                              [0, 63, -75, -80]]),
        expected_grad_rhs=[[0, 2], [1, 3], [1, 7], [0, -10]])

  def _makeDataForGradientWithBatching(self):
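    """Tiles the simple gradient test case into a 2x3 batch.

    Diagonals are scaled and right-hand sides sign-flipped per batch element,
    and the expected gradients are scaled accordingly.
    """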
    y = np.array([1, 3, 2, 4])
    grad_diags = np.array([[-5, 0, 4, 0], [9, 0, -4, -16], [0, 0, 5, 16]])
    grad_rhs = np.array([1, 0, -1, 4])

    diags_batched = np.array(
        [[_sample_diags, 2 * _sample_diags, 3 * _sample_diags],
         [4 * _sample_diags, 5 * _sample_diags, 6 * _sample_diags]])
    rhs_batched = np.array([[_sample_rhs, -_sample_rhs, _sample_rhs],
                            [-_sample_rhs, _sample_rhs, -_sample_rhs]])
    y_batched = np.array([[y, y, y], [y, y, y]])
    expected_grad_diags_batched = np.array(
        [[grad_diags, -grad_diags / 4, grad_diags / 9],
         [-grad_diags / 16, grad_diags / 25, -grad_diags / 36]])
    expected_grad_rhs_batched = np.array(
        [[grad_rhs, grad_rhs / 2, grad_rhs / 3],
         [grad_rhs / 4, grad_rhs / 5, grad_rhs / 6]])

    return (y_batched, diags_batched, rhs_batched, expected_grad_diags_batched,
            expected_grad_rhs_batched)

  def testGradientWithBatchDims(self):
    y, diags, rhs, expected_grad_diags, expected_grad_rhs = \
      self._makeDataForGradientWithBatching()

    self._gradientTestWithLists(
        diags=diags,
        rhs=rhs,
        y=y,
        expected_grad_diags=expected_grad_diags,
        expected_grad_rhs=expected_grad_rhs)

  @test_util.run_deprecated_v1
  def testGradientWithUnknownShapes(self):

    def placeholder(rank):
      return array_ops.placeholder(
          dtypes.float64, shape=(None for _ in range(rank)))

    y, diags, rhs, expected_grad_diags, expected_grad_rhs = \
      self._makeDataForGradientWithBatching()

    diags_placeholder = placeholder(rank=4)
    rhs_placeholder = placeholder(rank=3)
    y_placeholder = placeholder(rank=3)

    self._gradientTest(
        diags=diags_placeholder,
        rhs=rhs_placeholder,
        y=y_placeholder,
        expected_grad_diags=expected_grad_diags,
        expected_grad_rhs=expected_grad_rhs,
        feed_dict={
            diags_placeholder: diags,
            rhs_placeholder: rhs,
            y_placeholder: y
        })

  # Invalid input shapes

  @flags(FLAG_NO_PARAMETERIZATION)
  def testInvalidShapesCompactFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")

    test_raises((5, 4, 4), (5, 4))
    test_raises((5, 3, 4), (4, 5))
    test_raises((5, 3, 4), (5))
    test_raises((5), (5, 4))

  @flags(FLAG_NO_PARAMETERIZATION)
  def testInvalidShapesSequenceFormat(self):

    def test_raises(diags_tuple_shapes, rhs_shape):
      diagonals = tuple(_tf_ones(shape) for shape in diags_tuple_shapes)
      self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")

    test_raises(((5, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 3), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 6), (5, 4), (5, 3)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (3, 4))

  @flags(FLAG_NO_PARAMETERIZATION)
  def testInvalidShapesMatrixFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")

    test_raises((5, 4, 7), (5, 4))
    test_raises((5, 4, 4), (3, 4))
    test_raises((5, 4, 4), (5, 3))

  # Tests with placeholders

  def _testWithPlaceholders(self,
                            diags_shape,
                            rhs_shape,
                            diags_feed,
                            rhs_feed,
                            expected,
                            diags_format="compact"):
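    """Runs tridiagonal_solve on placeholders and checks the fed result."""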
    if context.executing_eagerly():
      return
    diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
    rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
    if test_util.is_xla_enabled() and self.pivoting:
      # Pivoting is not supported by xla backends.
      return
    x = linalg_impl.tridiagonal_solve(
        diags, rhs, diags_format, partial_pivoting=self.pivoting)
    with self.cached_session() as sess:
      result = sess.run(x, feed_dict={diags: diags_feed, rhs: rhs_feed})
      self.assertAllClose(result, expected)

  @test_util.run_deprecated_v1
  def testCompactFormatAllDimsUnknown(self):
    self._testWithPlaceholders(
        diags_shape=[None, None],
        rhs_shape=[None],
        diags_feed=_sample_diags,
        rhs_feed=_sample_rhs,
        expected=_sample_result)

  @test_util.run_deprecated_v1
  def testCompactFormatUnknownMatrixSize(self):
    self._testWithPlaceholders(
        diags_shape=[3, None],
        rhs_shape=[4],
        diags_feed=_sample_diags,
        rhs_feed=_sample_rhs,
        expected=_sample_result)

  @test_util.run_deprecated_v1
  def testCompactFormatUnknownRhsCount(self):
    self._testWithPlaceholders(
        diags_shape=[3, 4],
        rhs_shape=[4, None],
        diags_feed=_sample_diags,
        rhs_feed=np.transpose([_sample_rhs, 2 * _sample_rhs]),
        expected=np.transpose([_sample_result, 2 * _sample_result]))

  @test_util.run_deprecated_v1
  def testCompactFormatUnknownBatchSize(self):
    self._testWithPlaceholders(
        diags_shape=[None, 3, 4],
        rhs_shape=[None, 4],
        diags_feed=np.array([_sample_diags, -_sample_diags]),
        rhs_feed=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]))

  @test_util.run_deprecated_v1
  def testMatrixFormatWithUnknownDims(self):
    if context.executing_eagerly():
      return

    def test_with_matrix_shapes(matrix_shape, rhs_shape=None):
      matrix = np.array([[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4],
                         [0, 0, 1, 2]])
      rhs = np.array([1, 2, 3, 4])
      x = np.array([-9, 5, -4, 4])
      self._testWithPlaceholders(
          diags_shape=matrix_shape,
          rhs_shape=rhs_shape,
          diags_feed=matrix,
          rhs_feed=np.transpose([rhs, 2 * rhs]),
          expected=np.transpose([x, 2 * x]),
          diags_format="matrix")

    test_with_matrix_shapes(matrix_shape=[4, 4], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[None, 4], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[4, None], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[None, None], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[4, 4])
    test_with_matrix_shapes(matrix_shape=[None, 4])
    test_with_matrix_shapes(matrix_shape=[4, None])
    test_with_matrix_shapes(matrix_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=None, rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=None)

  @test_util.run_deprecated_v1
  def testSequenceFormatWithUnknownDims(self):
    if context.executing_eagerly():
      return
    if test_util.is_xla_enabled() and self.pivoting:
      # Pivoting is not supported by xla backends.
      return
    superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
    diag = array_ops.placeholder(dtypes.float64, shape=[None])
    subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
    rhs = array_ops.placeholder(dtypes.float64, shape=[None])

    x = linalg_impl.tridiagonal_solve((superdiag, diag, subdiag),
                                      rhs,
                                      diagonals_format="sequence",
                                      partial_pivoting=self.pivoting)
    with self.cached_session() as sess:
      result = sess.run(
          x,
          feed_dict={
              subdiag: [20, 1, -1, 1],
              diag: [1, 3, 2, 2],
              superdiag: [2, 1, 4, 20],
              rhs: [1, 2, 3, 4]
          })
      self.assertAllClose(result, [-9, 5, -4, 4])

  # Benchmark

  class TridiagonalSolveBenchmark(test.Benchmark):
    sizes = [(100000, 1, 1), (1000000, 1, 1), (10000000, 1, 1), (100000, 10, 1),
             (100000, 100, 1), (10000, 1, 10), (10000, 1, 100)]

    pivoting_options = [(True, "pivoting"), (False, "no_pivoting")]

    def _generateData(self, matrix_size, batch_size, num_rhs, seed=42):
      np.random.seed(seed)
      data = np.random.normal(size=(batch_size, matrix_size, 3 + num_rhs))
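      # The first three columns become the compact-format diagonals; the
      # remaining num_rhs columns become the right-hand sides.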
      diags = np.stack([data[:, :, 0], data[:, :, 1], data[:, :, 2]], axis=-2)
      rhs = data[:, :, 3:]
      return (variables.Variable(diags, dtype=dtypes.float64),
              variables.Variable(rhs, dtype=dtypes.float64))

    def _generateMatrixData(self, matrix_size, batch_size, num_rhs, seed=42):
      np.random.seed(seed)
      import scipy.sparse as sparse  # pylint:disable=g-import-not-at-top
      # By being strictly diagonally dominant, we guarantee invertibility.
      diag = 2 * np.abs(np.random.randn(matrix_size)) + 4.1
      subdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
      superdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
      matrix = sparse.diags([superdiag, diag, subdiag], [1, 0, -1]).toarray()
      vector = np.random.randn(batch_size, matrix_size, num_rhs)
      return (variables.Variable(np.tile(matrix, (batch_size, 1, 1))),
              variables.Variable(vector))

    def _benchmark(self, generate_data_fn, test_name_format_string):
      devices = [("/cpu:0", "cpu")]
      if test.is_gpu_available(cuda_only=True):
        devices += [("/gpu:0", "gpu")]

      for device_option, pivoting_option, size_option in \
          itertools.product(devices, self.pivoting_options, self.sizes):

        device_id, device_name = device_option
        pivoting, pivoting_name = pivoting_option
        matrix_size, batch_size, num_rhs = size_option

        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device(device_id):
          diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
          # Pivoting is not supported by XLA backends.
          if test_util.is_xla_enabled() and pivoting:
            return
          x = linalg_impl.tridiagonal_solve(
              diags, rhs, partial_pivoting=pivoting)
          self.evaluate(variables.global_variables_initializer())
          self.run_op_benchmark(
              sess,
              control_flow_ops.group(x),
              min_iters=10,
              store_memory_usage=False,
              name=test_name_format_string.format(device_name, matrix_size,
                                                  batch_size, num_rhs,
                                                  pivoting_name))

    def benchmarkTridiagonalSolveOp_WithMatrixInput(self):
      self._benchmark(
          self._generateMatrixData,
          test_name_format_string=(
              "tridiagonal_solve_matrix_format_{}_matrix_size_{}_"
              "batch_size_{}_num_rhs_{}_{}"))

    def benchmarkTridiagonalSolveOp(self):
      self._benchmark(
          self._generateData,
          test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_"
                                   "batch_size_{}_num_rhs_{}_{}"))


if __name__ == "__main__":
  for name, fun in dict(TridiagonalSolveOpTest.__dict__).items():
    if not name.startswith("test"):
      continue
    if hasattr(fun, FLAG_NO_PARAMETERIZATION):
      continue

    # Replace testFoo with testFoo_pivoting and testFoo_noPivoting, setting
    # self.pivoting to the corresponding value.
    delattr(TridiagonalSolveOpTest, name)

    def decor(test_fun, pivoting):

      def wrapped(instance):
        instance.pivoting = pivoting
        test_fun(instance)

      return wrapped

    setattr(TridiagonalSolveOpTest, name + "_pivoting",
            decor(fun, pivoting=True))
    if not hasattr(fun, FLAG_REQUIRES_PIVOTING):
      setattr(TridiagonalSolveOpTest, name + "_noPivoting",
              decor(fun, pivoting=False))

  test.main()