/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>

#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

static const char kThomasFailedMsg[] =
    "The matrix is either not invertible, or requires pivoting. "
    "Try setting partial_pivoting = True.";

template <class Scalar>
class TridiagonalSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);
  using MatrixMapRow =
      decltype(std::declval<const ConstMatrixMaps>()[0].row(0));

  explicit TridiagonalSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
    perturb_singular_ = false;
    if (context->HasAttr("perturb_singular")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("perturb_singular", &perturb_singular_));
    }
    OP_REQUIRES(context, pivoting_ || !perturb_singular_,
                errors::InvalidArgument("Setting perturb_singular requires "
                                        "also setting partial_pivoting."));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_eqs_left = input_matrix_shapes[0].dim_size(1);
    auto num_eqs_right = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(
        context, num_eqs_left == num_eqs_right,
        errors::InvalidArgument("Expected the same number of left-hand sides "
                                "and right-hand sides, got ",
                                num_eqs_left, " and ", num_eqs_right, "."));
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

  int64_t GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    const int num_eqs = static_cast<int>(input_matrix_shapes[0].dim_size(1));
    const int num_rhss = static_cast<int>(input_matrix_shapes[1].dim_size(0));

    const double add_cost = Eigen::TensorOpCost::AddCost<Scalar>();
    const double mult_cost = Eigen::TensorOpCost::MulCost<Scalar>();
    const double div_cost = Eigen::TensorOpCost::DivCost<Scalar>();

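    // Per equation, the Thomas path below performs one division to form the
    // next pivot plus one division per right-hand side, and one multiply-add
    // for the pivot update plus two multiply-adds per right-hand side
    // (forward sweep and back substitution). The pivoting path averages the
    // costs of the interchange and no-interchange cases.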
    double cost;
    if (pivoting_) {
      // Assuming cases with and without row interchange are equiprobable.
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2.5 * num_rhss + 1.5));
    } else {
      cost = num_eqs * (div_cost * (num_rhss + 1) +
                        (add_cost + mult_cost) * (2 * num_rhss + 1));
    }
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64_t>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];

    // Superdiagonal elements; the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements; the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];

    const int n = diag.size();
    MatrixMap& x = outputs->at(0);
    constexpr Scalar zero(0);

    if (n == 0) {
      return;
    }
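    // Dispatch to one of three solvers: the xLAGTF/xLAGTS-style path when
    // both partial_pivoting and perturb_singular are set, Gaussian
    // elimination with partial pivoting when only partial_pivoting is set,
    // and the pivot-free Thomas algorithm otherwise.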
    if (pivoting_ && perturb_singular_) {
      SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
          context, superdiag, diag, subdiag, rhs, x);
      return;
    }

    if (n == 1) {
      if (diag(0) == zero) {
        LOG(WARNING) << kNotInvertibleScalarMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      } else {
        x.row(0) = rhs.row(0) / diag(0);
      }
      return;
    }

    if (pivoting_) {
      SolveWithGaussianEliminationWithPivoting(context, superdiag, diag,
                                               subdiag, rhs, x);
    } else {
      SolveWithThomasAlgorithm(context, superdiag, diag, subdiag, rhs, x);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOp);

  // Adjusts 'pivot' so that neither 'rhs[i,:] / pivot' nor '1 / pivot'
  // overflows, where i indexes the multiple right-hand sides. During the
  // back-substitution phase in
  // SolveWithGaussianEliminationWithPivotingAndPerturbSingular, the i'th row
  // of the solution is computed as rhs[i,:] * (1 / pivot). This logic is
  // extracted from the LAPACK routine xLAGTS.
  void MaybePerturbPivot(RealScalar perturb, Scalar& pivot,
                         Eigen::Matrix<Scalar, 1, Eigen::Dynamic>& rhs_row) {
    constexpr RealScalar one(1);
    // The following logic is extracted from xLAMCH in LAPACK.
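    // 'safemin' is the safe minimum: the smallest positive value whose
    // reciprocal does not overflow. 'bignum' is its reciprocal, used below to
    // rescale tiny pivots and the corresponding right-hand-side row.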
    constexpr RealScalar tiny = std::numeric_limits<RealScalar>::min();
    constexpr RealScalar small = one / std::numeric_limits<RealScalar>::max();
    constexpr RealScalar safemin =
        (small < tiny
             ? tiny
             : (one + std::numeric_limits<RealScalar>::epsilon()) * small);
    constexpr RealScalar bignum = one / safemin;

    RealScalar abs_pivot = std::abs(pivot);
    if (abs_pivot >= one) {
      return;
    }
    // Safeguard against infinite loop if 'perturb' is zero.
    // 'perturb' should never have magnitude smaller than safemin.
    perturb = std::max(std::abs(perturb), safemin);
    // Make sure perturb and pivot have the same sign.
    perturb = std::copysign(perturb, std::real(pivot));

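    // Grow the perturbation (or rescale the pivot and the right-hand-side row
    // by 'bignum') until dividing 'rhs_row' by 'pivot' can no longer overflow.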
    bool stop = false;
    const RealScalar max_factor = rhs_row.array().abs().maxCoeff();
    while (abs_pivot < one && !stop) {
      if (abs_pivot < safemin) {
        if (abs_pivot == 0 || max_factor * safemin > abs_pivot) {
          pivot += perturb;
          perturb *= 2;
        } else {
          pivot *= bignum;
          rhs_row *= bignum;
          stop = true;
        }
      } else if (max_factor > abs_pivot * bignum) {
        pivot += perturb;
        perturb *= 2;
      } else {
        stop = true;
      }
      abs_pivot = std::abs(pivot);
    }
  }

  // This function roughly follows LAPACK's xLAGTF + xLAGTS routines.
  //
  // It computes the solution of a linear system with multiple right-hand
  // sides,
  //     T * X = RHS,
  // where T is a tridiagonal matrix, using a row-pivoted LU decomposition.
  //
  // This routine differs from SolveWithGaussianEliminationWithPivoting by
  // allowing the tridiagonal matrix to be numerically singular.
  // If tiny diagonal elements of U are encountered, signaling that T is
  // numerically singular, the diagonal elements are perturbed by
  // an amount proportional to eps*max_abs_u to avoid overflow, where
  // max_abs_u is max_{i,j} | U(i,j) |. This is useful when using this
  // routine for computing eigenvectors of a matrix T' via inverse
  // iteration by solving the singular system
  //   (T' - lambda*I) X = RHS,
  // where lambda is an eigenvalue of T'.
  //
  // By fusing the factorization and solution, we avoid storing L
  // and pivoting information, and the forward solve is done on-the-fly
  // during factorization, instead of requiring a separate loop.
  void SolveWithGaussianEliminationWithPivotingAndPerturbSingular(
      OpKernelContext* context, const MatrixMapRow& superdiag,
      const MatrixMapRow& diag, const MatrixMapRow& subdiag,
      const ConstMatrixMap& rhs, MatrixMap& x) {
    constexpr Scalar zero(0);
    constexpr RealScalar realzero(0);
    constexpr Scalar one(1);
    constexpr RealScalar eps = std::numeric_limits<RealScalar>::epsilon();

    const int n = diag.size();
    if (n == 0) return;
    if (n == 1) {
      Scalar denom = diag(0);
      RealScalar tol = eps * std::abs(denom);
      Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = rhs.row(0);
      MaybePerturbPivot(tol, denom, row);
      x = row * (one / denom);
      return;
    }

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition
    // of the input matrix (subject to row exchanges due to pivoting). For
    // a pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // We accumulate max( abs( U(i,j) ) ) in max_abs_u for use in perturbing
    // near-zero pivots during the solution phase.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    RealScalar max_abs_u = std::max(std::abs(u(0, 0)), std::abs(u(0, 1)));
    RealScalar scale1 = std::abs(u(0, 0)) + std::abs(u(0, 1));
    x.row(0) = rhs.row(0);
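    // Scaled partial pivoting, as in xLAGTF: at step k, compare the pivot
    // candidates u(k, 0) and subdiag(k + 1) after scaling each by the sum of
    // absolute values in its row (scale1 and scale2), and eliminate with the
    // relatively larger of the two, swapping rows if necessary.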
    for (int k = 0; k < n - 1; ++k) {
      // The non-zeros in the (k+1)-st row are
      //    [ ... subdiag(k+1) (diag(k+1)-shift) superdiag(k+1) ... ]
      u(k + 1, 0) = diag(k + 1);
      RealScalar scale2 = std::abs(subdiag(k + 1)) + std::abs(u(k + 1, 0));
      if (k < n - 2) scale2 += std::abs(superdiag(k + 1));
      if (subdiag(k + 1) == zero) {
        // The sub-diagonal in the k+1 row is already zero. Move to the next
        // row.
        scale1 = scale2;
        u(k + 1, 1) = superdiag(k + 1);
        u(k, 2) = zero;
        x.row(k + 1) = rhs.row(k + 1);
      } else {
        const RealScalar piv1 =
            u(k, 0) == zero ? realzero : std::abs(u(k, 0)) / scale1;
        const RealScalar piv2 = std::abs(subdiag(k + 1)) / scale2;
        if (piv2 <= piv1) {
          // No row pivoting needed.
          scale1 = scale2;
          Scalar factor = subdiag(k + 1) / u(k, 0);
          u(k + 1, 0) = diag(k + 1) - factor * u(k, 1);
          u(k + 1, 1) = superdiag(k + 1);
          u(k, 2) = zero;
          x.row(k + 1) = rhs.row(k + 1) - factor * x.row(k);
        } else {
          // Swap rows k and k+1.
          Scalar factor = u(k, 0) / subdiag(k + 1);
          u(k, 0) = subdiag(k + 1);
          u(k + 1, 0) = u(k, 1) - factor * diag(k + 1);
          u(k, 1) = diag(k + 1);
          if (k < n - 2) {
            u(k, 2) = superdiag(k + 1);
            u(k + 1, 1) = -factor * superdiag(k + 1);
          }
          x.row(k + 1) = x.row(k) - factor * rhs.row(k + 1);
          x.row(k) = rhs.row(k + 1);
        }
      }
      if (k < n - 2) {
        for (int i = 0; i < 3; ++i) {
          max_abs_u = std::max(max_abs_u, std::abs(u(k, i)));
        }
      }
    }
    max_abs_u = std::max(max_abs_u, std::abs(u(n - 1, 0)));

    // We have already solved L z = P rhs above. Now we solve U x = z,
    // possibly perturbing small pivots to avoid overflow. The variable tol
    // contains eps * max( abs( u(:,:) ) ). If tiny pivots are encountered,
    // they are perturbed by a small amount on the scale of tol to avoid
    // overflow or scaled up to avoid underflow.
    RealScalar tol = eps * max_abs_u;
    Scalar denom = u(n - 1, 0);
    Eigen::Matrix<Scalar, 1, Eigen::Dynamic> row = x.row(n - 1);
    MaybePerturbPivot(tol, denom, row);
    x.row(n - 1) = row * (one / denom);
    if (n > 1) {
      denom = u(n - 2, 0);
      row = x.row(n - 2) - u(n - 2, 1) * x.row(n - 1);
      MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
      x.row(n - 2) = row * (one / denom);

      for (int k = n - 3; k >= 0; --k) {
        row = x.row(k) - u(k, 1) * x.row(k + 1) - u(k, 2) * x.row(k + 2);
        denom = u(k, 0);
        MaybePerturbPivot(std::copysign(tol, std::real(denom)), denom, row);
        x.row(k) = row * (one / denom);
      }
    }
  }

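  // Solves T * X = RHS by Gaussian elimination with partial (row) pivoting,
  // without perturbing singular pivots. If an exactly zero pivot is
  // encountered, the matrix is reported as not invertible: a warning is
  // logged and the output is filled with NaNs.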
  void SolveWithGaussianEliminationWithPivoting(OpKernelContext* context,
                                                const MatrixMapRow& superdiag,
                                                const MatrixMapRow& diag,
                                                const MatrixMapRow& subdiag,
                                                const ConstMatrixMap& rhs,
                                                MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The three columns in u are the diagonal, superdiagonal, and second
    // superdiagonal, respectively, of the U matrix in the LU decomposition of
    // the input matrix (subject to row exchanges due to pivoting). For a
    // pivoted tridiagonal matrix, the U matrix has at most two non-zero
    // superdiagonals.
    Eigen::Array<Scalar, Eigen::Dynamic, 3> u(n, 3);

    // The code below roughly follows LAPACK's dgtsv routine, with the main
    // difference being that it does not overwrite the input.
    u(0, 0) = diag(0);
    u(0, 1) = superdiag(0);
    x.row(0) = rhs.row(0);
    for (int i = 0; i < n - 1; ++i) {
      if (std::abs(u(i)) >= std::abs(subdiag(i + 1))) {
        // No row interchange.
        if (u(i) == zero) {
          LOG(WARNING) << kNotInvertibleMsg;
          x.fill(std::numeric_limits<Scalar>::quiet_NaN());
          return;
        }
        const Scalar factor = subdiag(i + 1) / u(i, 0);
        u(i + 1, 0) = diag(i + 1) - factor * u(i, 1);
        x.row(i + 1) = rhs.row(i + 1) - factor * x.row(i);
        if (i != n - 2) {
          u(i + 1, 1) = superdiag(i + 1);
          u(i, 2) = 0;
        }
      } else {
        // Interchange rows i and i + 1.
        const Scalar factor = u(i, 0) / subdiag(i + 1);
        u(i, 0) = subdiag(i + 1);
        u(i + 1, 0) = u(i, 1) - factor * diag(i + 1);
        u(i, 1) = diag(i + 1);
        x.row(i + 1) = x.row(i) - factor * rhs.row(i + 1);
        x.row(i) = rhs.row(i + 1);
        if (i != n - 2) {
          u(i, 2) = superdiag(i + 1);
          u(i + 1, 1) = -factor * superdiag(i + 1);
        }
      }
    }
    if (u(n - 1, 0) == zero) {
      LOG(WARNING) << kNotInvertibleMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }
    x.row(n - 1) /= u(n - 1, 0);
    x.row(n - 2) = (x.row(n - 2) - u(n - 2, 1) * x.row(n - 1)) / u(n - 2, 0);
    for (int i = n - 3; i >= 0; --i) {
      x.row(i) = (x.row(i) - u(i, 1) * x.row(i + 1) - u(i, 2) * x.row(i + 2)) /
                 u(i, 0);
    }
  }

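  // Solves T * X = RHS with the Thomas algorithm: LU decomposition of a
  // tridiagonal matrix without pivoting. With the k-th equation written as
  //   subdiag(k) * x(k-1) + diag(k) * x(k) + superdiag(k) * x(k+1) = rhs(k),
  // the forward sweep computes
  //   denom(k) = diag(k) - subdiag(k) * u(k-1),
  //   u(k) = superdiag(k) / denom(k),
  //   x(k) = (rhs(k) - subdiag(k) * x(k-1)) / denom(k),
  // and the backward sweep then applies x(k) -= u(k) * x(k+1). A zero pivot
  // denom(k) means the matrix is singular or requires pivoting.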
  void SolveWithThomasAlgorithm(OpKernelContext* context,
                                const MatrixMapRow& superdiag,
                                const MatrixMapRow& diag,
                                const MatrixMapRow& subdiag,
                                const ConstMatrixMap& rhs, MatrixMap& x) {
    const int n = diag.size();
    const Scalar zero(0);

    // The superdiagonal of the U matrix in the LU decomposition of the input
    // matrix (in the Thomas algorithm, the U matrix has ones on the diagonal
    // and one superdiagonal).
    Eigen::Matrix<Scalar, Eigen::Dynamic, 1> u(n);

    if (diag(0) == zero) {
      LOG(WARNING) << kThomasFailedMsg;
      x.fill(std::numeric_limits<Scalar>::quiet_NaN());
      return;
    }

    u(0) = superdiag(0) / diag(0);
    x.row(0) = rhs.row(0) / diag(0);
    for (int i = 1; i < n; ++i) {
      auto denom = diag(i) - subdiag(i) * u(i - 1);
      if (denom == zero) {
        LOG(WARNING) << kThomasFailedMsg;
        x.fill(std::numeric_limits<Scalar>::quiet_NaN());
        return;
      }
      u(i) = superdiag(i) / denom;
      x.row(i) = (rhs.row(i) - subdiag(i) * x.row(i - 1)) / denom;
    }
    for (int i = n - 2; i >= 0; --i) {
      x.row(i) -= u(i) * x.row(i + 1);
    }
  }

  bool pivoting_;
  bool perturb_singular_;
};

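// For reference, a minimal example of the diagonals layout this kernel
// expects, expressed through the Python API (a sketch assuming
// tf.linalg.tridiagonal_solve with the default 'compact' diagonals format;
// argument names may differ across TF versions):
//
//   diagonals = tf.constant([[1., 1., 0.],   # superdiagonal, last ignored
//                            [2., 2., 2.],   # main diagonal
//                            [0., 1., 1.]])  # subdiagonal, first ignored
//   rhs = tf.constant([[1.], [2.], [3.]])
//   x = tf.linalg.tridiagonal_solve(diagonals, rhs, partial_pivoting=True)
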
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<double>),
                       double);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex64>),
                       complex64);
REGISTER_LINALG_OP_CPU("TridiagonalSolve", (TridiagonalSolveOp<complex128>),
                       complex128);
}  // namespace tensorflow