1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.linalg.linalg_impl.tridiagonal_solve.""" 16 17import itertools 18 19import numpy as np 20 21from tensorflow.python.client import session 22from tensorflow.python.eager import backprop 23from tensorflow.python.eager import context 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import variables 32from tensorflow.python.ops.linalg import linalg_impl 33from tensorflow.python.platform import benchmark 34from tensorflow.python.platform import test 35 36_sample_diags = np.array([[2, 1, 4, 0], [1, 3, 2, 2], [0, 1, -1, 1]]) 37_sample_rhs = np.array([1, 2, 3, 4]) 38_sample_result = np.array([-9, 5, -4, 4]) 39 40# Flag, indicating that test should be run only with partial_pivoting=True 41FLAG_REQUIRES_PIVOTING = "FLAG_REQUIRES_PIVOT" 42 43# Flag, indicating that test shouldn't be parameterized by different values of 44# partial_pivoting, etc. 
FLAG_NO_PARAMETERIZATION = "FLAG_NO_PARAMETERIZATION"


def flags(*args):
  """Returns a decorator that marks a function with the given flags.

  Each name in `args` is set as an attribute with value True on the decorated
  function. The `__main__` block of this file inspects these attributes.
  """

  def decorator(f):
    for flag in args:
      setattr(f, flag, True)
    return f

  return decorator


def _tfconst(array):
  """Converts `array` to a float64 constant; `None` is passed through."""
  # NOTE(review): some callers (e.g. testConjugateRhs) pass complex-valued
  # arrays through this float64 conversion -- confirm that this is intended.
  if array is None:
    return None
  return constant_op.constant(array, dtypes.float64)


def _tf_ones(shape):
  """Returns a float64 tensor of ones with the given shape."""
  return array_ops.ones(shape, dtype=dtypes.float64)


class TridiagonalSolveOpTest(test.TestCase):
  """Tests for `linalg_impl.tridiagonal_solve`.

  When the file is run as a script, the `__main__` block replaces most
  `testFoo` methods with `testFoo_pivoting` / `testFoo_noPivoting` variants
  that set `self.pivoting` before calling the original test.
  """

  def _usePivoting(self):
    """Returns the `partial_pivoting` value to use in the running test.

    Falls back to True when `self.pivoting` has not been set, i.e. when the
    test was not wrapped by the parameterization in the `__main__` block.
    """
    return getattr(self, "pivoting", True)

  def _test(self,
            diags,
            rhs,
            expected,
            diags_format="compact",
            transpose_rhs=False,
            conjugate_rhs=False):
    """Solves the system and compares the result with `expected`.

    Args:
      diags: diagonals, in the layout named by `diags_format`.
      rhs: right-hand side(s).
      expected: expected solution, or None to assert that no element of the
        result is finite (used by the non-invertible cases).
      diags_format: passed to `tridiagonal_solve` as the diagonals format.
      transpose_rhs: passed through to `tridiagonal_solve`.
      conjugate_rhs: passed through to `tridiagonal_solve`.
    """
    with self.cached_session():
      pivoting = self._usePivoting()
      if test_util.is_xla_enabled() and pivoting:
        # Pivoting is not supported by xla backends.
        return
      result = linalg_impl.tridiagonal_solve(
          diags,
          rhs,
          diags_format,
          transpose_rhs,
          conjugate_rhs,
          partial_pivoting=pivoting)
      result = self.evaluate(result)
      if expected is None:
        # Every element of the result must be non-finite.
        self.assertAllEqual(
            np.zeros_like(result, dtype=np.bool_), np.isfinite(result))
      else:
        self.assertAllClose(result, expected)

  def _testWithLists(self,
                     diags,
                     rhs,
                     expected=None,
                     diags_format="compact",
                     transpose_rhs=False,
                     conjugate_rhs=False):
    """Same as `_test`, but converts the inputs with `_tfconst` first."""
    self._test(
        _tfconst(diags), _tfconst(rhs), _tfconst(expected), diags_format,
        transpose_rhs, conjugate_rhs)

  def _assertRaises(self, diags, rhs, diags_format="compact"):
    """Asserts that building the solve op raises `ValueError`."""
    pivoting = self._usePivoting()
    if test_util.is_xla_enabled() and pivoting:
      # Pivoting is not supported by xla backends.
      return
    with self.assertRaises(ValueError):
      linalg_impl.tridiagonal_solve(
          diags, rhs, diags_format, partial_pivoting=pivoting)

  # Tests with various dtypes

  def testReal(self):
    for dtype in dtypes.float32, dtypes.float64:
      self._test(
          diags=constant_op.constant(_sample_diags, dtype),
          rhs=constant_op.constant(_sample_rhs, dtype),
          expected=constant_op.constant(_sample_result, dtype))

  def testComplex(self):
    for dtype in dtypes.complex64, dtypes.complex128:
      self._test(
          diags=constant_op.constant(_sample_diags, dtype) * (1 + 1j),
          rhs=constant_op.constant(_sample_rhs, dtype) * (1 - 1j),
          expected=constant_op.constant(_sample_result, dtype) * (1 - 1j) /
          (1 + 1j))

  # Tests with small matrix sizes

  def test3x3(self):
    self._testWithLists(
        diags=[[2, -1, 0], [1, 3, 1], [0, -1, -2]],
        rhs=[1, 2, 3],
        expected=[-3, 2, 7])

  def test2x2(self):
    self._testWithLists(
        diags=[[2, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[-5, 3])

  def test2x2Complex(self):
    for dtype in dtypes.complex64, dtypes.complex128:
      self._test(
          diags=constant_op.constant([[2j, 0j], [1j, 3j], [0j, 1j]], dtype),
          rhs=constant_op.constant([1 - 1j, 4 - 4j], dtype),
          expected=constant_op.constant([5 + 5j, -3 - 3j], dtype))

  def test1x1(self):
    self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2])

  def test0x0(self):
    if test_util.is_xla_enabled():
      # The following test crashes with XLA due to slicing 0 length tensors.
      return
    self._test(
        diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32),
        rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32),
        expected=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32))

  def test2x2WithMultipleRhs(self):
    self._testWithLists(
        diags=[[2, 0], [1, 3], [0, 1]],
        rhs=[[1, 2, 3], [4, 8, 12]],
        expected=[[-5, -10, -15], [3, 6, 9]])

  def test1x1WithMultipleRhs(self):
    self._testWithLists(
        diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]])

  def test1x1NotInvertible(self):
    if test_util.is_xla_enabled():
      # XLA implementation does not check invertibility.
      return
    self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]])

  def test2x2NotInvertible(self):
    if test_util.is_xla_enabled():
      # XLA implementation does not check invertibility.
      return
    self._testWithLists(diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4])

  # Other edge cases

  @flags(FLAG_REQUIRES_PIVOTING)
  def testCaseRequiringPivoting(self):
    # Without partial pivoting (e.g. Thomas algorithm) this would fail.
    self._testWithLists(
        diags=[[2, -1, 1, 0], [1, 4, 1, -1], [0, 2, -2, 3]],
        rhs=[1, 2, 3, 4],
        expected=[8, -3.5, 0, -4])

  @flags(FLAG_REQUIRES_PIVOTING)
  def testCaseRequiringPivotingLastRows(self):
    self._testWithLists(
        diags=[[2, 1, -1, 0], [1, -1, 2, 1], [0, 1, -6, 1]],
        rhs=[1, 2, -1, -2],
        expected=[5, -2, -5, 3])

  def testNotInvertible(self):
    if test_util.is_xla_enabled():
      # XLA implementation does not check invertibility.
      return
    self._testWithLists(
        diags=[[2, -1, 1, 0], [1, 4, 1, -1], [0, 2, 0, 3]], rhs=[1, 2, 3, 4])

  def testDiagonal(self):
    self._testWithLists(
        diags=[[0, 0, 0, 0], [1, 2, -1, -2], [0, 0, 0, 0]],
        rhs=[1, 2, 3, 4],
        expected=[1, 1, -3, -2])

  def testUpperTriangular(self):
    self._testWithLists(
        diags=[[2, 4, -1, 0], [1, 3, 1, 2], [0, 0, 0, 0]],
        rhs=[1, 6, 4, 4],
        expected=[13, -6, 6, 2])

  def testLowerTriangular(self):
    self._testWithLists(
        diags=[[0, 0, 0, 0], [2, -1, 3, 1], [0, 1, 4, 2]],
        rhs=[4, 5, 6, 1],
        expected=[2, -3, 6, -11])

  # Multiple right-hand sides and batching

  def testWithTwoRightHandSides(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs, 2 * _sample_rhs]),
        expected=np.transpose([_sample_result, 2 * _sample_result]))

  def testBatching(self):
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]))

  def testWithTwoBatchingDimensions(self):
    self._testWithLists(
        diags=np.array([[_sample_diags, -_sample_diags, _sample_diags],
                        [-_sample_diags, _sample_diags, -_sample_diags]]),
        rhs=np.array([[_sample_rhs, 2 * _sample_rhs, 3 * _sample_rhs],
                      [4 * _sample_rhs, 5 * _sample_rhs, 6 * _sample_rhs]]),
        expected=np.array(
            [[_sample_result, -2 * _sample_result, 3 * _sample_result],
             [-4 * _sample_result, 5 * _sample_result, -6 * _sample_result]]))

  def testBatchingAndTwoRightHandSides(self):
    rhs = np.transpose([_sample_rhs, 2 * _sample_rhs])
    expected_result = np.transpose([_sample_result, 2 * _sample_result])
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([rhs, 2 * rhs]),
        expected=np.array([expected_result, -2 * expected_result]))

  # Various input formats

  def testSequenceFormat(self):
    self._test(
        diags=(_tfconst([2, 1, 4]), _tfconst([1, 3, 2, 2]),
               _tfconst([1, -1, 1])),
        rhs=_tfconst([1, 2, 3, 4]),
        expected=_tfconst([-9, 5, -4, 4]),
        diags_format="sequence")

  def testSequenceFormatWithDummyElements(self):
    dummy = 20
    self._test(
        diags=(_tfconst([2, 1, 4, dummy]), _tfconst([1, 3, 2, 2]),
               _tfconst([dummy, 1, -1, 1])),
        rhs=_tfconst([1, 2, 3, 4]),
        expected=_tfconst([-9, 5, -4, 4]),
        diags_format="sequence")

  def testSequenceFormatWithBatching(self):
    self._test(
        diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]),
               _tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]),
               _tfconst([[1, -1, 1], [-1, 1, -1]])),
        rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]),
        expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]),
        diags_format="sequence")

  def testMatrixFormat(self):
    self._testWithLists(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[1, 2, 3, 4],
        expected=[-9, 5, -4, 4],
        diags_format="matrix")

  def testMatrixFormatWithMultipleRightHandSides(self):
    self._testWithLists(
        diags=[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
        rhs=[[1, -1], [2, -2], [3, -3], [4, -4]],
        expected=[[-9, 9], [5, -5], [-4, 4], [4, -4]],
        diags_format="matrix")

  def testMatrixFormatWithBatching(self):
    self._testWithLists(
        diags=[[[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4], [0, 0, 1, 2]],
               [[-1, -2, 0, 0], [-1, -3, -1, 0], [0, 1, -2, -4],
                [0, 0, -1, -2]]],
        rhs=[[1, 2, 3, 4], [1, 2, 3, 4]],
        expected=[[-9, 5, -4, 4], [9, -5, 4, -4]],
        diags_format="matrix")

  def testRightHandSideAsColumn(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs]),
        expected=np.transpose([_sample_result]),
        diags_format="compact")

  # Tests with transpose and adjoint

  def testTransposeRhs(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, 2 * _sample_result]).T,
        transpose_rhs=True)

  def testConjugateRhs(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.transpose([_sample_rhs * (1 + 1j), _sample_rhs * (1 - 2j)]),
        expected=np.transpose(
            [_sample_result * (1 - 1j), _sample_result * (1 + 2j)]),
        conjugate_rhs=True)

  def testAdjointRhs(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=np.array([_sample_rhs * (1 + 1j), _sample_rhs * (1 - 2j)]),
        expected=np.array(
            [_sample_result * (1 - 1j), _sample_result * (1 + 2j)]).T,
        transpose_rhs=True,
        conjugate_rhs=True)

  def testTransposeRhsWithBatching(self):
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([[_sample_rhs, 2 * _sample_rhs],
                      [3 * _sample_rhs, 4 * _sample_rhs]]),
        expected=np.array([[_sample_result, 2 * _sample_result],
                           [-3 * _sample_result,
                            -4 * _sample_result]]).transpose(0, 2, 1),
        transpose_rhs=True)

  def testTransposeRhsWithRhsAsVector(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=_sample_rhs,
        expected=_sample_result,
        transpose_rhs=True)

  def testConjugateRhsWithRhsAsVector(self):
    self._testWithLists(
        diags=_sample_diags,
        rhs=_sample_rhs * (1 + 1j),
        expected=_sample_result * (1 - 1j),
        conjugate_rhs=True)

  def testTransposeRhsWithRhsAsVectorAndBatching(self):
    self._testWithLists(
        diags=np.array([_sample_diags, -_sample_diags]),
        rhs=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]),
        transpose_rhs=True)

  # Gradient tests

  def _gradientTest(
      self,
      diags,
      rhs,
      y,  # output = reduce_sum(y * tridiag_solve(diags, rhs))
      expected_grad_diags,  # expected gradient of output w.r.t. diags
      expected_grad_rhs,  # expected gradient of output w.r.t. rhs
      diags_format="compact",
      transpose_rhs=False,
      conjugate_rhs=False,
      feed_dict=None):
    """Checks gradients of `reduce_sum(y * tridiagonal_solve(diags, rhs))`.

    Args:
      diags: diagonals tensor (or placeholder) to differentiate against.
      rhs: right-hand side tensor (or placeholder) to differentiate against.
      y: weights multiplied with the solution before summation.
      expected_grad_diags: expected gradient w.r.t. `diags`.
      expected_grad_rhs: expected gradient w.r.t. `rhs`.
      diags_format: passed to `tridiagonal_solve` as `diagonals_format`.
      transpose_rhs: passed through to `tridiagonal_solve`.
      conjugate_rhs: passed through to `tridiagonal_solve`.
      feed_dict: optional feeds used when evaluating the gradients.
    """
    if test_util.is_xla_enabled():
      # Pivoting is not supported by xla backends.
      return
    expected_grad_diags = _tfconst(expected_grad_diags)
    expected_grad_rhs = _tfconst(expected_grad_rhs)
    # One tape per gradient, since each tape is queried once below.
    with backprop.GradientTape() as tape_diags:
      with backprop.GradientTape() as tape_rhs:
        tape_diags.watch(diags)
        tape_rhs.watch(rhs)
        x = linalg_impl.tridiagonal_solve(
            diags,
            rhs,
            diagonals_format=diags_format,
            transpose_rhs=transpose_rhs,
            conjugate_rhs=conjugate_rhs)
        res = math_ops.reduce_sum(x * y)
    with self.cached_session() as sess:
      actual_grad_diags = sess.run(
          tape_diags.gradient(res, diags), feed_dict=feed_dict)
      actual_grad_rhs = sess.run(
          tape_rhs.gradient(res, rhs), feed_dict=feed_dict)
    self.assertAllClose(expected_grad_diags, actual_grad_diags)
    self.assertAllClose(expected_grad_rhs, actual_grad_rhs)

  def _gradientTestWithLists(self,
                             diags,
                             rhs,
                             y,
                             expected_grad_diags,
                             expected_grad_rhs,
                             diags_format="compact",
                             transpose_rhs=False,
                             conjugate_rhs=False):
    """Same as `_gradientTest`, converting diags, rhs and y with `_tfconst`."""
    self._gradientTest(
        _tfconst(diags), _tfconst(rhs), _tfconst(y), expected_grad_diags,
        expected_grad_rhs, diags_format, transpose_rhs, conjugate_rhs)

  def testGradientSimple(self):
    self._gradientTestWithLists(
        diags=_sample_diags,
        rhs=_sample_rhs,
        y=[1, 3, 2, 4],
        expected_grad_diags=[[-5, 0, 4, 0], [9, 0, -4, -16], [0, 0, 5, 16]],
        expected_grad_rhs=[1, 0, -1, 4])

  def testGradientWithMultipleRhs(self):
    self._gradientTestWithLists(
        diags=_sample_diags,
        rhs=[[1, 2], [2, 4], [3, 6], [4, 8]],
        y=[[1, 5], [2, 6], [3, 7], [4, 8]],
        expected_grad_diags=([[-20, 28, -60, 0], [36, -35, 60, 80],
                              [0, 63, -75, -80]]),
        expected_grad_rhs=[[0, 2], [1, 3], [1, 7], [0, -10]])

  def _makeDataForGradientWithBatching(self):
    """Builds batched inputs and expected gradients from the sample system.

    Returns:
      A tuple `(y, diags, rhs, expected_grad_diags, expected_grad_rhs)` of
      numpy arrays with two leading batch dimensions of sizes 2 and 3.
    """
    y = np.array([1, 3, 2, 4])
    grad_diags = np.array([[-5, 0, 4, 0], [9, 0, -4, -16], [0, 0, 5, 16]])
    grad_rhs = np.array([1, 0, -1, 4])

    diags_batched = np.array(
        [[_sample_diags, 2 * _sample_diags, 3 * _sample_diags],
         [4 * _sample_diags, 5 * _sample_diags, 6 * _sample_diags]])
    rhs_batched = np.array([[_sample_rhs, -_sample_rhs, _sample_rhs],
                            [-_sample_rhs, _sample_rhs, -_sample_rhs]])
    y_batched = np.array([[y, y, y], [y, y, y]])
    expected_grad_diags_batched = np.array(
        [[grad_diags, -grad_diags / 4, grad_diags / 9],
         [-grad_diags / 16, grad_diags / 25, -grad_diags / 36]])
    expected_grad_rhs_batched = np.array(
        [[grad_rhs, grad_rhs / 2, grad_rhs / 3],
         [grad_rhs / 4, grad_rhs / 5, grad_rhs / 6]])

    return (y_batched, diags_batched, rhs_batched, expected_grad_diags_batched,
            expected_grad_rhs_batched)

  def testGradientWithBatchDims(self):
    y, diags, rhs, expected_grad_diags, expected_grad_rhs = (
        self._makeDataForGradientWithBatching())

    self._gradientTestWithLists(
        diags=diags,
        rhs=rhs,
        y=y,
        expected_grad_diags=expected_grad_diags,
        expected_grad_rhs=expected_grad_rhs)

  @test_util.run_deprecated_v1
  def testGradientWithUnknownShapes(self):

    def placeholder(rank):
      # Known rank, every dimension unknown.
      return array_ops.placeholder(dtypes.float64, shape=[None] * rank)

    y, diags, rhs, expected_grad_diags, expected_grad_rhs = (
        self._makeDataForGradientWithBatching())

    diags_placeholder = placeholder(rank=4)
    rhs_placeholder = placeholder(rank=3)
    y_placeholder = placeholder(rank=3)

    self._gradientTest(
        diags=diags_placeholder,
        rhs=rhs_placeholder,
        y=y_placeholder,
        expected_grad_diags=expected_grad_diags,
        expected_grad_rhs=expected_grad_rhs,
        feed_dict={
            diags_placeholder: diags,
            rhs_placeholder: rhs,
            y_placeholder: y
        })

  # Invalid input shapes

  @flags(FLAG_NO_PARAMETERIZATION)
  def testInvalidShapesCompactFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "compact")

    test_raises((5, 4, 4), (5, 4))
    test_raises((5, 3, 4), (4, 5))
    test_raises((5, 3, 4), (5,))
    test_raises((5,), (5, 4))

  @flags(FLAG_NO_PARAMETERIZATION)
  def testInvalidShapesSequenceFormat(self):

    def test_raises(diags_tuple_shapes, rhs_shape):
      diagonals = tuple(_tf_ones(shape) for shape in diags_tuple_shapes)
      self._assertRaises(diagonals, _tf_ones(rhs_shape), "sequence")

    test_raises(((5, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 3), (5, 4), (5, 6)), (5, 4))
    test_raises(((5, 6), (5, 4), (5, 3)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (5, 4))
    test_raises(((5, 4), (7, 4), (5, 4)), (3, 4))

  @flags(FLAG_NO_PARAMETERIZATION)
  def testInvalidShapesMatrixFormat(self):

    def test_raises(diags_shape, rhs_shape):
      self._assertRaises(_tf_ones(diags_shape), _tf_ones(rhs_shape), "matrix")

    test_raises((5, 4, 7), (5, 4))
    test_raises((5, 4, 4), (3, 4))
    test_raises((5, 4, 4), (5, 3))

  # Tests with placeholders

  def _testWithPlaceholders(self,
                            diags_shape,
                            rhs_shape,
                            diags_feed,
                            rhs_feed,
                            expected,
                            diags_format="compact"):
    """Solves with placeholder inputs of the given static shapes.

    Does nothing when executing eagerly, or when XLA is enabled and the test
    runs with pivoting.

    Args:
      diags_shape: static shape of the diagonals placeholder.
      rhs_shape: static shape of the right-hand side placeholder.
      diags_feed: value fed for the diagonals.
      rhs_feed: value fed for the right-hand side.
      expected: expected solution.
      diags_format: passed to `tridiagonal_solve` as the diagonals format.
    """
    if context.executing_eagerly():
      return
    pivoting = self._usePivoting()
    if test_util.is_xla_enabled() and pivoting:
      # Pivoting is not supported by xla backends.
      return
    diags = array_ops.placeholder(dtypes.float64, shape=diags_shape)
    rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape)
    x = linalg_impl.tridiagonal_solve(
        diags, rhs, diags_format, partial_pivoting=pivoting)
    with self.cached_session() as sess:
      result = sess.run(x, feed_dict={diags: diags_feed, rhs: rhs_feed})
      self.assertAllClose(result, expected)

  @test_util.run_deprecated_v1
  def testCompactFormatAllDimsUnknown(self):
    self._testWithPlaceholders(
        diags_shape=[None, None],
        rhs_shape=[None],
        diags_feed=_sample_diags,
        rhs_feed=_sample_rhs,
        expected=_sample_result)

  @test_util.run_deprecated_v1
  def testCompactFormatUnknownMatrixSize(self):
    self._testWithPlaceholders(
        diags_shape=[3, None],
        rhs_shape=[4],
        diags_feed=_sample_diags,
        rhs_feed=_sample_rhs,
        expected=_sample_result)

  @test_util.run_deprecated_v1
  def testCompactFormatUnknownRhsCount(self):
    self._testWithPlaceholders(
        diags_shape=[3, 4],
        rhs_shape=[4, None],
        diags_feed=_sample_diags,
        rhs_feed=np.transpose([_sample_rhs, 2 * _sample_rhs]),
        expected=np.transpose([_sample_result, 2 * _sample_result]))

  @test_util.run_deprecated_v1
  def testCompactFormatUnknownBatchSize(self):
    self._testWithPlaceholders(
        diags_shape=[None, 3, 4],
        rhs_shape=[None, 4],
        diags_feed=np.array([_sample_diags, -_sample_diags]),
        rhs_feed=np.array([_sample_rhs, 2 * _sample_rhs]),
        expected=np.array([_sample_result, -2 * _sample_result]))

  @test_util.run_deprecated_v1
  def testMatrixFormatWithUnknownDims(self):
    if context.executing_eagerly():
      return

    def test_with_matrix_shapes(matrix_shape, rhs_shape=None):
      matrix = np.array([[1, 2, 0, 0], [1, 3, 1, 0], [0, -1, 2, 4],
                         [0, 0, 1, 2]])
      rhs = np.array([1, 2, 3, 4])
      x = np.array([-9, 5, -4, 4])
      self._testWithPlaceholders(
          diags_shape=matrix_shape,
          rhs_shape=rhs_shape,
          diags_feed=matrix,
          rhs_feed=np.transpose([rhs, 2 * rhs]),
          expected=np.transpose([x, 2 * x]),
          diags_format="matrix")

    test_with_matrix_shapes(matrix_shape=[4, 4], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[None, 4], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[4, None], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[None, None], rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=[4, 4])
    test_with_matrix_shapes(matrix_shape=[None, 4])
    test_with_matrix_shapes(matrix_shape=[4, None])
    test_with_matrix_shapes(matrix_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=None, rhs_shape=[None, None])
    test_with_matrix_shapes(matrix_shape=None)

  @test_util.run_deprecated_v1
  def testSequenceFormatWithUnknownDims(self):
    if context.executing_eagerly():
      return
    pivoting = self._usePivoting()
    if test_util.is_xla_enabled() and pivoting:
      # Pivoting is not supported by xla backends.
      return
    superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
    diag = array_ops.placeholder(dtypes.float64, shape=[None])
    subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
    rhs = array_ops.placeholder(dtypes.float64, shape=[None])

    x = linalg_impl.tridiagonal_solve((superdiag, diag, subdiag),
                                      rhs,
                                      diagonals_format="sequence",
                                      partial_pivoting=pivoting)
    with self.cached_session() as sess:
      result = sess.run(
          x,
          feed_dict={
              subdiag: [20, 1, -1, 1],
              diag: [1, 3, 2, 2],
              superdiag: [2, 1, 4, 20],
              rhs: [1, 2, 3, 4]
          })
      self.assertAllClose(result, [-9, 5, -4, 4])


# Benchmark


class TridiagonalSolveBenchmark(test.Benchmark):
  """Benchmarks `tridiagonal_solve` over devices, pivoting modes and sizes."""

  # Each entry is (matrix_size, batch_size, num_rhs).
  sizes = [(100000, 1, 1), (1000000, 1, 1), (10000000, 1, 1), (100000, 10, 1),
           (100000, 100, 1), (10000, 1, 10), (10000, 1, 100)]

  # Each entry is (partial_pivoting value, name used in the benchmark name).
  pivoting_options = [(True, "pivoting"), (False, "no_pivoting")]

  def _generateData(self, matrix_size, batch_size, num_rhs, seed=42):
    """Generates random compact-format diagonals and right-hand sides.

    Returns:
      A pair of float64 variables: diagonals of shape
      `(batch_size, 3, matrix_size)` and right-hand sides of shape
      `(batch_size, matrix_size, num_rhs)`.
    """
    np.random.seed(seed)
    data = np.random.normal(size=(batch_size, matrix_size, 3 + num_rhs))
    diags = np.stack([data[:, :, 0], data[:, :, 1], data[:, :, 2]], axis=-2)
    rhs = data[:, :, 3:]
    return (variables.Variable(diags, dtype=dtypes.float64),
            variables.Variable(rhs, dtype=dtypes.float64))

  def _generateMatrixData(self, matrix_size, batch_size, num_rhs, seed=42):
    """Generates a dense tridiagonal matrix, tiled over the batch, and a rhs.

    Returns:
      A pair of variables: matrices of shape
      `(batch_size, matrix_size, matrix_size)` and right-hand sides of shape
      `(batch_size, matrix_size, num_rhs)`.
    """
    np.random.seed(seed)
    import scipy.sparse as sparse  # pylint:disable=g-import-not-at-top
    # By being strictly diagonally dominant, we guarantee invertibility.
    diag = 2 * np.abs(np.random.randn(matrix_size)) + 4.1
    subdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
    superdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
    # NOTE(review): `toarray()` materializes a dense matrix_size x matrix_size
    # array; confirm the entries in `sizes` are feasible for this generator.
    matrix = sparse.diags([superdiag, diag, subdiag], [1, 0, -1]).toarray()
    vector = np.random.randn(batch_size, matrix_size, num_rhs)
    return (variables.Variable(np.tile(matrix, (batch_size, 1, 1))),
            variables.Variable(vector))

  def _benchmark(self,
                 generate_data_fn,
                 test_name_format_string,
                 diags_format="compact"):
    """Runs the solve benchmark for every device/pivoting/size combination.

    Args:
      generate_data_fn: callable `(matrix_size, batch_size, num_rhs)` returning
        a `(diags, rhs)` pair laid out according to `diags_format`.
      test_name_format_string: format string receiving, in order, the device
        name, matrix size, batch size, number of right-hand sides and pivoting
        name.
      diags_format: passed to `tridiagonal_solve` as `diagonals_format`.
    """
    devices = [("/cpu:0", "cpu")]
    if test.is_gpu_available(cuda_only=True):
      devices += [("/gpu:0", "gpu")]

    configurations = itertools.product(devices, self.pivoting_options,
                                       self.sizes)
    for device_option, pivoting_option, size_option in configurations:
      device_id, device_name = device_option
      pivoting, pivoting_name = pivoting_option
      matrix_size, batch_size, num_rhs = size_option

      # Pivoting is not supported by XLA backends. Skip only this
      # configuration so that the non-pivoting ones still run.
      if test_util.is_xla_enabled() and pivoting:
        continue

      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device(device_id):
        diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
        x = linalg_impl.tridiagonal_solve(
            diags,
            rhs,
            diagonals_format=diags_format,
            partial_pivoting=pivoting)
        self.evaluate(variables.global_variables_initializer())
        self.run_op_benchmark(
            sess,
            control_flow_ops.group(x),
            min_iters=10,
            store_memory_usage=False,
            name=test_name_format_string.format(device_name, matrix_size,
                                                batch_size, num_rhs,
                                                pivoting_name))

  def benchmarkTridiagonalSolveOp_WithMatrixInput(self):
    self._benchmark(
        self._generateMatrixData,
        test_name_format_string=(
            "tridiagonal_solve_matrix_format_{}_matrix_size_{}_"
            "batch_size_{}_num_rhs_{}_{}"),
        diags_format="matrix")

  def benchmarkTridiagonalSolveOp(self):
    self._benchmark(
        self._generateData,
        test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_"
                                 "batch_size_{}_num_rhs_{}_{}"))


if __name__ == "__main__":

  def decor(test_fun, pivoting):
    """Wraps `test_fun` so that it runs with `self.pivoting` set."""

    def wrapped(instance):
      instance.pivoting = pivoting
      test_fun(instance)

    return wrapped

  # Iterate over a copy, since the class is modified inside the loop.
  for name, fun in dict(TridiagonalSolveOpTest.__dict__).items():
    if not name.startswith("test"):
      continue
    if hasattr(fun, FLAG_NO_PARAMETERIZATION):
      continue

    # Replace testFoo with testFoo_pivoting and testFoo_noPivoting, setting
    # self.pivoting to corresponding value.
    delattr(TridiagonalSolveOpTest, name)
    setattr(TridiagonalSolveOpTest, name + "_pivoting",
            decor(fun, pivoting=True))
    if not hasattr(fun, FLAG_REQUIRES_PIVOTING):
      setattr(TridiagonalSolveOpTest, name + "_noPivoting",
              decor(fun, pivoting=False))

  test.main()