/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#ifdef GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"
#include "tensorflow/core/util/gpu_solvers.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

template <typename Scalar>
__global__ void SolveForSizeOneOrTwoKernel(const int m,
                                           const Scalar* __restrict__ diags,
                                           const Scalar* __restrict__ rhs,
                                           const int num_rhs,
                                           Scalar* __restrict__ x) {
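  // `diags` is the compact [3, m] diagonals matrix stored row-major: row 0 is
  // the superdiagonal, row 1 the main diagonal, row 2 the subdiagonal. For
  // m == 1 the only pivot is diags[1]; for m == 2 the system is solved in
  // closed form via Cramer's rule, with
  // det = diag[0] * diag[1] - superdiag[0] * subdiag[1].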
  const Scalar nan = Eigen::NumTraits<Scalar>::quiet_NaN();
  if (m == 1) {
    bool singular = diags[1] == Scalar(0);
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = singular ? nan : rhs[i] / diags[1];
    }
  } else {
    const Scalar det = diags[2] * diags[3] - diags[0] * diags[5];
    bool singular = det == Scalar(0);
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = singular ? nan
                      : (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
      x[i + num_rhs] =
          singular ? nan
                   : (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
    }
  }
}

template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}

template <typename Scalar>
void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src,
                        Scalar* dst, const int num_elements) {
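  // Wraps the raw device pointers in StreamExecutor DeviceMemory objects and
  // enqueues an asynchronous device-to-device copy on the op's compute stream.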
  auto src_device_mem = AsDeviceMemory(src);
  auto dst_device_mem = AsDeviceMemory(dst);
  auto* stream = context->op_device_context()->stream();
  bool copy_status = stream
                         ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                         sizeof(Scalar) * num_elements)
                         .ok();

  if (!copy_status) {
    context->SetStatus(errors::Internal("Copying device-to-device failed."));
  }
}

// This implementation is used in cases when the batching mechanism of
// LinearAlgebraOp is suitable. See TridiagonalSolveOpGpu below.
template <class Scalar>
class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit TridiagonalSolveOpGpuLinalg(OpKernelConstruction* context)
      : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
    perturb_singular_ = false;
    if (context->HasAttr("perturb_singular")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("perturb_singular", &perturb_singular_));
    }
    OP_REQUIRES(
        context, perturb_singular_ == false,
        errors::Unimplemented("A solver supporting perturb_singular is"
                              " not implemented on GPU."));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 columns, got ",
                                num_diags, " columns."));

    auto num_rows1 = input_matrix_shapes[0].dim_size(1);
    auto num_rows2 = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(context, num_rows1 == num_rows2,
                errors::InvalidArgument("Expected same number of rows in both "
                                        "arguments, got ",
                                        num_rows1, " and ", num_rows2, "."));
  }

  bool EnableInputForwarding() const final { return false; }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];
    // Superdiagonal elements; the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements; the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];
    MatrixMap& x = outputs->at(0);
    const int m = diag.size();
    const int k = rhs.cols();
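    // Dispatch: m == 0 returns immediately, m < 3 goes to the custom kernel
    // above, and everything else is solved in place by cuSPARSE gtsv2, with
    // the right-hand sides transposed to column-major first when k > 1.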

    if (m == 0) {
      return;
    }
    if (m < 3) {
      // The cuSPARSE gtsv routine requires m >= 3; solve manually for m < 3.
      SolveForSizeOneOrTwo(context, diagonals.data(), rhs.data(), x.data(), m,
                           k);
      return;
    }
    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
    OP_REQUIRES_OK(context, cusparse_solver->Initialize());
    if (k == 1) {
      // rhs is copied into x, then gtsv replaces x with the solution.
      CopyDeviceToDevice(context, rhs.data(), x.data(), m);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), x.data(), m, 1);
    } else {
      // Gtsv expects rhs in column-major form, so we have to transpose:
      // rhs is transposed into temp, gtsv replaces temp with the solution,
      // and temp is transposed back into x.
      std::unique_ptr<GpuSolver> cublas_solver(new GpuSolver(context));
      Tensor temp;
      TensorShape temp_shape({k, m});
      OP_REQUIRES_OK(context,
                     cublas_solver->allocate_scoped_tensor(
                         DataTypeToEnum<Scalar>::value, temp_shape, &temp));
      TransposeWithGeam(context, cublas_solver, rhs.data(),
                        temp.flat<Scalar>().data(), m, k);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), temp.flat<Scalar>().data(), m, k);
      TransposeWithGeam(context, cublas_solver, temp.flat<Scalar>().data(),
                        x.data(), k, m);
    }
  }

 private:
  void TransposeWithGeam(OpKernelContext* context,
                         const std::unique_ptr<GpuSolver>& cublas_solver,
                         const Scalar* src, Scalar* dst, const int src_rows,
                         const int src_cols) const {
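    // cuBLAS geam computes C = alpha * op(A) + beta * op(B); with alpha = 1,
    // beta = 0, and op(A) = A^T this performs an out-of-place transpose of
    // src into dst.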
    const Scalar zero(0), one(1);
    OP_REQUIRES_OK(context,
                   cublas_solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, src_rows,
                                       src_cols, &one, src, src_cols, &zero,
                                       static_cast<const Scalar*>(nullptr),
                                       src_rows, dst, src_rows));
  }

  void SolveWithGtsv(OpKernelContext* context,
                     std::unique_ptr<GpuSparse>& cusparse_solver,
                     const Scalar* superdiag, const Scalar* diag,
                     const Scalar* subdiag, Scalar* rhs, const int num_eqs,
                     const int num_rhs) const {
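    // cuSPARSE gtsv2 needs a caller-allocated workspace: first query its size
    // with the matching *BufferSizeExt routine (pivoting or no-pivot variant),
    // then run the corresponding solver with that buffer.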
    auto buffer_function = pivoting_
                               ? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
                               : &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
    size_t buffer_size;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64_t>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();

    auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
                                     : &GpuSparse::Gtsv2NoPivot<Scalar>;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, buffer));
  }

  void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
                            const Scalar* rhs, Scalar* output, int m, int k) {
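    // Launch the custom kernel defined above; its grid-stride loop
    // (GpuGridRangeX) covers all k right-hand sides regardless of the launch
    // configuration chosen here.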
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    GpuLaunchConfig cfg = GetGpuLaunchConfig(
        /*work_element_count=*/1, device, &SolveForSizeOneOrTwoKernel<Scalar>,
        /*dynamic_shared_memory_size=*/0,
        /*block_size_limit=*/0);
    TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
                                cfg.block_count, cfg.thread_per_block,
                                /*shared_memory_size_bytes=*/0, device.stream(),
                                m, diagonals, rhs, k, output));
  }

  bool pivoting_;
  bool perturb_singular_;
};

template <class Scalar>
class TridiagonalSolveOpGpu : public OpKernel {
 public:
  explicit TridiagonalSolveOpGpu(OpKernelConstruction* context)
      : OpKernel(context), linalgOp_(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void Compute(OpKernelContext* context) final {
    const Tensor& lhs = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = lhs.dims();
    const int64 num_rhs = rhs.dim_size(rhs.dims() - 1);
    const int64 matrix_size = lhs.dim_size(ndims - 1);
    int64 batch_size = 1;
    for (int i = 0; i < ndims - 2; i++) {
      batch_size *= lhs.dim_size(i);
    }

    // The batching mechanism of LinearAlgebraOp is used when it's not
    // possible or desirable to use GtsvBatched.
    const bool use_linalg_op =
        pivoting_            // GtsvBatched doesn't do pivoting
        || num_rhs > 1       // GtsvBatched doesn't support multiple rhs
        || matrix_size < 3   // Not supported in cuSparse, use the custom kernel
        || batch_size == 1;  // No point in using GtsvBatched

    if (use_linalg_op) {
      linalgOp_.Compute(context);
    } else {
      ComputeWithGtsvBatched(context, lhs, rhs, batch_size);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOpGpu);

  void ComputeWithGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                              const Tensor& rhs, const int batch_size) {
    const Scalar* rhs_data = rhs.flat<Scalar>().data();
    const int ndims = lhs.dims();

    // To use GtsvBatched we need to transpose the left-hand side from shape
    // [..., 3, M] into shape [3, ..., M]. With shape [..., 3, M] the stride
    // between corresponding diagonal elements of consecutive batch components
    // is 3 * M, while for the right-hand side the stride is M. Unfortunately,
    // GtsvBatched requires the strides to be the same. For this reason we
    // transpose into [3, ..., M], so that diagonals, superdiagonals, and
    // subdiagonals are separated from each other and each has stride M.
    Tensor lhs_transposed;
    TransposeLhsForGtsvBatched(context, lhs, lhs_transposed);
    int matrix_size = lhs.dim_size(ndims - 1);
    const Scalar* lhs_data = lhs_transposed.flat<Scalar>().data();
    const Scalar* superdiag = lhs_data;
    const Scalar* diag = lhs_data + matrix_size * batch_size;
    const Scalar* subdiag = lhs_data + 2 * matrix_size * batch_size;
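    // After the transpose, the superdiagonals, diagonals, and subdiagonals
    // each occupy a contiguous block of matrix_size * batch_size elements.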

    // Copy right-hand side into the output. GtsvBatched will replace it with
    // the solution.
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
    CopyDeviceToDevice(context, rhs_data, output->flat<Scalar>().data(),
                       rhs.flat<Scalar>().size());
    Scalar* x = output->flat<Scalar>().data();

    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));

    OP_REQUIRES_OK(context, cusparse_solver->Initialize());

    size_t buffer_size;
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64_t>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, buffer));
  }

  void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                                  Tensor& lhs_transposed) {
    const int ndims = lhs.dims();

    // Permutation of indices, transforming [..., 3, M] into [3, ..., M].
    // E.g. for ndims = 6, it is [4, 0, 1, 2, 3, 5].
    std::vector<int> perm(ndims);
    perm[0] = ndims - 2;
    for (int i = 0; i < ndims - 2; ++i) {
      perm[i + 1] = i;
    }
    perm[ndims - 1] = ndims - 1;

    std::vector<int64_t> dims;
    for (int index : perm) {
      dims.push_back(lhs.dim_size(index));
    }
    TensorShape lhs_transposed_shape(
        gtl::ArraySlice<int64_t>(dims.data(), ndims));

    std::unique_ptr<GpuSolver> cublas_solver(new GpuSolver(context));
    OP_REQUIRES_OK(context, cublas_solver->allocate_scoped_tensor(
                                DataTypeToEnum<Scalar>::value,
                                lhs_transposed_shape, &lhs_transposed));
    auto device = context->eigen_device<Eigen::GpuDevice>();
    OP_REQUIRES_OK(
        context,
        DoTranspose(device, lhs, gtl::ArraySlice<int>(perm.data(), ndims),
                    &lhs_transposed));
  }

  TridiagonalSolveOpGpuLinalg<Scalar> linalgOp_;
  bool pivoting_;
};

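// Register the GPU kernel for all element types supported by cuSPARSE gtsv2:
// float, double, complex64, and complex128.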
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex128>),
                       complex128);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA