/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include <cmath>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

static const char kThomasFailedMsg[] =
    "The matrix is either not invertible, or requires pivoting. "
    "Try setting partial_pivoting = True.";

template <class Scalar>
class TridiagonalSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);
  using MatrixMapRow =
      decltype(std::declval<const ConstMatrixMaps>()[0].row(0));

  explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
    perturb_singular_ = false;
    if (context->HasAttr("perturb_singular")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("perturb_singular", &perturb_singular_));
    }
    OP_REQUIRES(context, pivoting_ || !perturb_singular_,
                errors::InvalidArgument("Setting perturb_singular requires "
                                        "also setting partial_pivoting."));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_eqs_left = input_matrix_shapes[0].dim_size(1);
    auto num_eqs_right = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(
        context, num_eqs_left == num_eqs_right,
        errors::InvalidArgument("Expected the same number of left-hand sides "
                                "and right-hand sides, got ",
                                num_eqs_left, " and ", num_eqs_right, "."));
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

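  // Rough cost estimate for solving one tridiagonal system, used by the base
  // class to shard batched solves. The breakdown here is an editorial summary
  // of the formula in GetCostPerUnit, not an exact operation count: each of
  // the num_eqs rows costs about one division per right-hand side plus one
  // for the superdiagonal, and roughly 2 * num_rhss + 1 add/multiply pairs
  // for elimination and back-substitution. With partial pivoting, a row
  // interchange adds about one extra add/multiply pair per right-hand side;
  // assuming interchanges on roughly half of the rows gives the
  // 2.5 * num_rhss + 1.5 factor below.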
  int64_t GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1));
    const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0));

    const double add_cost = Eigen::TensorOpCost::AddCost<Scalar>();
    const double mult_cost = Eigen::TensorOpCost::MulCost<Scalar>();
    const double div_cost = Eigen::TensorOpCost::DivCost<Scalar>();

    double cost;
    if (pivoting_) {
      // Assuming cases with and without row interchange are equiprobable.
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2.5 * num_rhss + 1.5));
    } else {
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2 * num_rhss + 1));
    }
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64_t>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];

    // Superdiagonal elements; the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements; the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];

    const int n = diag.size();
    MatrixMap& x = outputs->at(0);
    constexpr Scalar zero(0);

    if (n == 0) {
      return;
    }
    if (pivoting_ && perturb_singular_) {
      SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
          context, superdiag, diag, subdiag, rhs, x);
      return;
    }

    if (n == 1) {
      if (diag(0) == zero) {
        LOG(WARNING) << kNotInvertibleScalarMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      } else {
        x.row(0) = rhs.row(0) / diag(0);
      }
      return;
    }

    if (pivoting_) {
      SolveWithGaussianEliminationWithPivoting(context, superdiag, diag,
                                               subdiag, rhs, x);
    } else {
      SolveWithThomasAlgorithm(context, superdiag, diag, subdiag, rhs, x);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp);

  // Adjust pivot such that neither 'rhs[i,:] / pivot' nor '1 / pivot' causes
  // overflow, where i indexes the multiple right-hand sides. During the
  // back-substitution phase in
  // SolveWithGaussianEliminationWithPivotingAndPerturbSingular, we compute
  // the i'th row of the solution as rhs[i,:] * (1 / pivot). This logic is
  // extracted from the LAPACK routine xLAGTS.
  void MaybePerturbPivot(RealScalar perturb, Scalar& pivot,
                         Eigen::Matrix<Scalar, 1, Eigen::Dynamic>& rhs_row) {
    constexpr RealScalar one(1);
    // The following logic is extracted from xLAMCH in LAPACK.
    constexpr RealScalar tiny = std::numeric_limits<RealScalar>::min();
    constexpr RealScalar small = one / std::numeric_limits<RealScalar>::max();
    constexpr RealScalar safemin =
        (small < tiny
             ? tiny
             : (one + std::numeric_limits<RealScalar>::epsilon()) * small);
    constexpr RealScalar bignum = one / safemin;

    RealScalar abs_pivot = std::abs(pivot);
    if (abs_pivot >= one) {
      return;
    }
    // Safeguard against infinite loop if 'perturb' is zero.
    // 'perturb' should never have magnitude smaller than safemin.
    perturb = std::max(std::abs(perturb), safemin);
    // Make sure perturb and pivot have the same sign.
    perturb = std::copysign(perturb, std::real(pivot));

    bool stop = false;
    const RealScalar max_factor = rhs_row.array().abs().maxCoeff();
    while (abs_pivot < one && !stop) {
      if (abs_pivot < safemin) {
        if (abs_pivot == 0 || max_factor * safemin > abs_pivot) {
          pivot += perturb;
          perturb *= 2;
        } else {
          pivot *= bignum;
          rhs_row *= bignum;
          stop = true;
        }
      } else if (max_factor > abs_pivot * bignum) {
        pivot += perturb;
        perturb *= 2;
      } else {
        stop = true;
      }
      abs_pivot = std::abs(pivot);
    }
  }

  // This function roughly follows LAPACK's xLAGTF + xLAGTS routines.
  //
  // It computes the solution to a linear system with multiple
  // right-hand sides
  //     T * X = RHS
  // where T is a tridiagonal matrix, using a row-pivoted LU decomposition.
  //
  // This routine differs from SolveWithGaussianEliminationWithPivoting by
  // allowing the tridiagonal matrix to be numerically singular.
  // If tiny diagonal elements of U are encountered, signaling that T is
  // numerically singular, the diagonal elements are perturbed by
  // an amount proportional to eps*max_abs_u to avoid overflow, where
  // max_abs_u is max_{i,j} | U(i,j) |. This is useful when using this
  // routine for computing eigenvectors of a matrix T' via inverse
  // iteration by solving the singular system
  //     (T' - lambda*I) X = RHS,
  // where lambda is an eigenvalue of T'.
  //
  // By fusing the factorization and solution, we avoid storing L
  // and pivoting information, and the forward solve is done on-the-fly
  // during factorization, instead of requiring a separate loop.
  void SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
      OpKernelContext* context, const MatrixMapRow& superdiag,
      const MatrixMapRow& diag, const MatrixMapRow& subdiag,
      const ConstMatrixMap& rhs, MatrixMap& x) {
    constexpr Scalar zero(0);
    constexpr RealScalar realzero(0);
    constexpr Scalar one(1);
    constexpr RealScalar eps = std::numeric_limits<RealScalar>::epsilon();

    const int n = diag.size();
    if (n == 0) return;
    if (n == 1) {
      Scalar denom = diag(0);
      RealScalar tol = eps * std::abs(denom);
      Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = rhs.row(0);
      MaybePerturbPivot(tol, denom, row);
      x = row * (one / denom);
      return;
    }

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition
    // of the input matrix (subject to row exchanges due to pivoting). For
    // a pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
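    //
    // As an editorial illustration (not part of the LAPACK routines):
    // u(k, j) stores U(k, k + j), so for n = 5 the factor U has the banded
    // form
    //
    //     [ u(0,0)  u(0,1)  u(0,2)    0       0    ]
    //     [   0     u(1,0)  u(1,1)  u(1,2)    0    ]
    //     [   0       0     u(2,0)  u(2,1)  u(2,2) ]
    //     [   0       0       0     u(3,0)  u(3,1) ]
    //     [   0       0       0       0     u(4,0) ]
    //
    // Only rows 0..n-3 can have a non-zero second superdiagonal entry, so
    // u(n-2, 2), u(n-1, 1) and u(n-1, 2) are never read during the back
    // substitution below.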
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // We accumulate max( abs( U(i,j) ) ) in max_abs_u for use in perturbing
    // near-zero pivots during the solution phase.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    RealScalar max_abs_u = std::max(std::abs(u(0, 0)), std::abs(u(0, 1)));
    RealScalar scale1 = std::abs(u(0, 0)) + std::abs(u(0, 1));
    x.row(0) = rhs.row(0);
    for (int k = 0; k < n - 1; ++k) {
      // The non-zeros in the (k+1)-st row are
      //   [ ... subdiag(k+1) (diag(k+1)-shift) superdiag(k+1) ... ]
      u(k + 1, 0) = diag(k + 1);
      RealScalar scale2 = std::abs(subdiag(k + 1)) + std::abs(u(k + 1, 0));
      if (k < n - 2) scale2 += std::abs(superdiag(k + 1));
      if (subdiag(k + 1) == zero) {
        // The sub-diagonal in the k+1 row is already zero. Move to the next
        // row.
        scale1 = scale2;
        u(k + 1, 1) = superdiag(k + 1);
        u(k, 2) = zero;
        x.row(k + 1) = rhs.row(k + 1);
      } else {
        const RealScalar piv1 =
            u(k, 0) == zero ? realzero : std::abs(u(k, 0)) / scale1;
        const RealScalar piv2 = std::abs(subdiag(k + 1)) / scale2;
        if (piv2 <= piv1) {
          // No row pivoting needed.
          scale1 = scale2;
          Scalar factor = subdiag(k + 1) / u(k, 0);
          u(k + 1, 0) = diag(k + 1) - factor * u(k, 1);
          u(k + 1, 1) = superdiag(k + 1);
          u(k, 2) = zero;
          x.row(k + 1) = rhs.row(k + 1) - factor * x.row(k);
        } else {
          // Swap rows k and k+1.
          Scalar factor = u(k, 0) / subdiag(k + 1);
          u(k, 0) = subdiag(k + 1);
          u(k + 1, 0) = u(k, 1) - factor * diag(k + 1);
          u(k, 1) = diag(k + 1);
          if (k < n - 2) {
            u(k, 2) = superdiag(k + 1);
            u(k + 1, 1) = -factor * superdiag(k + 1);
          }
          x.row(k + 1) = x.row(k) - factor * rhs.row(k + 1);
          x.row(k) = rhs.row(k + 1);
        }
      }
      if (k < n - 2) {
        for (int i = 0; i < 3; ++i) {
          max_abs_u = std::max(max_abs_u, std::abs(u(k, i)));
        }
      }
    }
    max_abs_u = std::max(max_abs_u, std::abs(u(n - 1, 0)));

    // We have already solved L z = P rhs above. Now we solve U x = z,
    // possibly perturbing small pivots to avoid overflow. The variable tol
    // contains eps * max( abs( u(:,:) ) ). If tiny pivots are encountered,
    // they are perturbed by a small amount on the scale of tol to avoid
    // overflow or scaled up to avoid underflow.
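    //
    // Concretely, the code below implements the bottom-up recurrence
    // (z denotes the forward-reduced right-hand sides currently stored in x):
    //   x(n-1,:) = z(n-1,:) / u(n-1,0)
    //   x(n-2,:) = (z(n-2,:) - u(n-2,1) * x(n-1,:)) / u(n-2,0)
    //   x(k,:)   = (z(k,:) - u(k,1) * x(k+1,:) - u(k,2) * x(k+2,:)) / u(k,0)
    // for k = n-3, ..., 0, with every divisor first passed through
    // MaybePerturbPivot.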
    RealScalar tol = eps * max_abs_u;
    Scalar denom = u(n - 1, 0);
    Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = x.row(n - 1);
    MaybePerturbPivot(tol, denom, row);
    x.row(n - 1) = row * (one / denom);
    if (n > 1) {
      denom = u(n - 2, 0);
      row = x.row(n - 2) - u(n - 2, 1) * x.row(n - 1);
      MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
      x.row(n - 2) = row * (one / denom);

      for (int k = n - 3; k >= 0; --k) {
        row = x.row(k) - u(k, 1) * x.row(k + 1) - u(k, 2) * x.row(k + 2);
        denom = u(k, 0);
        MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
        x.row(k) = row * (one / denom);
      }
    }
  }

  void SolveWithGaussianEliminationWithPivoting(OpKernelContext* context,
                                                const MatrixMapRow& superdiag,
                                                const MatrixMapRow& diag,
                                                const MatrixMapRow& subdiag,
                                                const ConstMatrixMap& rhs,
                                                MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition of
    // the input matrix (subject to row exchanges due to pivoting). For a
    // pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // The code below roughly follows LAPACK's dgtsv routine, with the main
    // difference being that it does not overwrite the input.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    x.row(0) = rhs.row(0);
    for (int i = 0; i < n - 1; ++i) {
      if (std::abs(u(i)) >= std::abs(subdiag(i + 1))) {
        // No row interchange.
        if (u(i) == zero) {
          LOG(WARNING) << kNotInvertibleMsg;
          x.fill(std::numeric_limits<Scalar>::quiet_NaN());
          return;
        }
        const Scalar factor = subdiag(i + 1) / u(i, 0);
        u(i + 1, 0) = diag(i + 1) - factor * u(i, 1);
        x.row(i + 1) = rhs.row(i + 1) - factor * x.row(i);
        if (i != n - 2) {
          u(i + 1, 1) = superdiag(i + 1);
          u(i, 2) = 0;
        }
      } else {
        // Interchange rows i and i + 1.
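        // Editorial note on the step below: since |u(i,0)| < |subdiag(i+1)|,
        // row i+1 is used as the pivot row. With
        // factor = u(i,0) / subdiag(i+1) (so |factor| < 1), the new row i of
        // U is the original row i+1, and the new row i+1 is the original
        // row i minus factor times row i+1; the right-hand-side rows are
        // combined and swapped in the same way.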
        const Scalar factor = u(i, 0) / subdiag(i + 1);
        u(i, 0) = subdiag(i + 1);
        u(i + 1, 0) = u(i, 1) - factor * diag(i + 1);
        u(i, 1) = diag(i + 1);
        x.row(i + 1) = x.row(i) - factor * rhs.row(i + 1);
        x.row(i) = rhs.row(i + 1);
        if (i != n - 2) {
          u(i, 2) = superdiag(i + 1);
          u(i + 1, 1) = -factor * superdiag(i + 1);
        }
      }
    }
    if (u(n - 1, 0) == zero) {
      LOG(WARNING) << kNotInvertibleMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }
    x.row(n - 1) /= u(n - 1, 0);
    x.row(n - 2) = (x.row(n - 2) - u(n - 2, 1) * x.row(n - 1)) / u(n - 2, 0);
    for (int i = n - 3; i >= 0; --i) {
      x.row(i) = (x.row(i) - u(i, 1) * x.row(i + 1) - u(i, 2) * x.row(i + 2)) /
                 u(i, 0);
    }
  }

  void SolveWithThomasAlgorithm(OpKernelContext* context,
                                const MatrixMapRow& superdiag,
                                const MatrixMapRow& diag,
                                const MatrixMapRow& subdiag,
                                const ConstMatrixMap& rhs, MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The superdiagonal of the U matrix in the LU decomposition of the input
    // matrix (in the Thomas algorithm, the U matrix has ones on the diagonal
    // and one superdiagonal).
    Eigen::Matrix<Scalar, Eigen::Dynamic, 1> u(n);

    if (diag(0) == zero) {
      LOG(WARNING) << kThomasFailedMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }

    u(0) = superdiag(0) / diag(0);
    x.row(0) = rhs.row(0) / diag(0);
    for (int i = 1; i < n; ++i) {
      auto denom = diag(i) - subdiag(i) * u(i - 1);
      if (denom == zero) {
        LOG(WARNING) << kThomasFailedMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
        return;
      }
      u(i) = superdiag(i) / denom;
      x.row(i) = (rhs.row(i) - subdiag(i) * x.row(i - 1)) / denom;
    }
    for (int i = n - 2; i >= 0; --i) {
      x.row(i) -= u(i) * x.row(i + 1);
    }
  }

  bool pivoting_;
  bool perturb_singular_;
};

REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<double>),
                       double);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex64>),
                       complex64);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex128>),
                       complex128);

}  // namespace tensorflow