/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#ifdef GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"
#include "tensorflow/core/util/gpu_solvers.h"

namespace tensorflow {

static const char kNotInvertibleMsg[] = "The matrix is not invertible.";

static const char kNotInvertibleScalarMsg[] =
    "The matrix is not invertible: it is a scalar with value zero.";

template <typename Scalar>
__global__ void SolveForSizeOneOrTwoKernel(const int m,
                                           const Scalar* __restrict__ diags,
                                           const Scalar* __restrict__ rhs,
                                           const int num_rhs,
                                           Scalar* __restrict__ x) {
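  // `diags` is the compact [3, m] representation in row-major order: row 0 is
  // the superdiagonal, row 1 the main diagonal, row 2 the subdiagonal. Hence
  // for m == 1 the single diagonal element is diags[1], and for m == 2 the
  // system matrix is
  //   [diags[2], diags[0]]
  //   [diags[5], diags[3]].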
  const Scalar nan = Eigen::NumTraits<Scalar>::quiet_NaN();
  if (m == 1) {
    bool singular = diags[1] == Scalar(0);
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = singular ? nan : rhs[i] / diags[1];
    }
  } else {
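    // m == 2: solve the 2x2 system directly via Cramer's rule.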
    const Scalar det = diags[2] * diags[3] - diags[0] * diags[5];
    bool singular = det == Scalar(0);
    for (int i : GpuGridRangeX(num_rhs)) {
      x[i] = singular ? nan
                      : (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
      x[i + num_rhs] =
          singular ? nan
                   : (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
    }
  }
}

template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}

template <typename Scalar>
void CopyDeviceToDevice(OpKernelContext* context, const Scalar* src,
                        Scalar* dst, const int num_elements) {
  auto src_device_mem = AsDeviceMemory(src);
  auto dst_device_mem = AsDeviceMemory(dst);
  auto* stream = context->op_device_context()->stream();
  bool copy_status = stream
                         ->ThenMemcpyD2D(&dst_device_mem, src_device_mem,
                                         sizeof(Scalar) * num_elements)
                         .ok();

  if (!copy_status) {
    context->SetStatus(errors::Internal("Copying device-to-device failed."));
  }
}

// This implementation is used in cases when the batching mechanism of
// LinearAlgebraOp is suitable. See TridiagonalSolveOpGpu below.
template <class Scalar>
class TridiagonalSolveOpGpuLinalg : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit TridiagonalSolveOpGpuLinalg(OpKernelConstruction* context)
      : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
    perturb_singular_ = false;
    if (context->HasAttr("perturb_singular")) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("perturb_singular", &perturb_singular_));
    }
    OP_REQUIRES(
        context, perturb_singular_ == false,
        errors::Unimplemented("A solver supporting perturb_singular is"
                              " not implemented on GPU."));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    auto num_inputs = input_matrix_shapes.size();
    OP_REQUIRES(context, num_inputs == 2,
                errors::InvalidArgument("Expected two input matrices, got ",
                                        num_inputs, "."));

    auto num_diags = input_matrix_shapes[0].dim_size(0);
    OP_REQUIRES(
        context, num_diags == 3,
        errors::InvalidArgument("Expected diagonals to be provided as a "
                                "matrix with 3 rows, got ",
                                num_diags, " rows."));

    auto num_rows1 = input_matrix_shapes[0].dim_size(1);
    auto num_rows2 = input_matrix_shapes[1].dim_size(0);
    OP_REQUIRES(context, num_rows1 == num_rows2,
                errors::InvalidArgument("Expected same number of rows in both "
                                        "arguments, got ",
                                        num_rows1, " and ", num_rows2, "."));
  }

  bool EnableInputForwarding() const final { return false; }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({input_matrix_shapes[1]});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const auto diagonals = inputs[0];
    // Superdiagonal elements, the last one is ignored.
    const auto& superdiag = diagonals.row(0);
    // Diagonal elements.
    const auto& diag = diagonals.row(1);
    // Subdiagonal elements, the first one is ignored.
    const auto& subdiag = diagonals.row(2);
    // Right-hand sides.
    const auto& rhs = inputs[1];
    MatrixMap& x = outputs->at(0);
    const int m = diag.size();
    const int k = rhs.cols();

    if (m == 0) {
      return;
    }
    if (m < 3) {
      // Cusparse gtsv routine requires m >= 3. Solving manually for m < 3.
      SolveForSizeOneOrTwo(context, diagonals.data(), rhs.data(), x.data(), m,
                           k);
      return;
    }
    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));
    OP_REQUIRES_OK(context, cusparse_solver->Initialize());
    if (k == 1) {
      // rhs is copied into x, then gtsv replaces x with solution.
      CopyDeviceToDevice(context, rhs.data(), x.data(), m);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), x.data(), m, 1);
    } else {
      // Gtsv expects rhs in column-major form, so we have to transpose.
      // rhs is transposed into temp, gtsv replaces temp with solution, then
      // temp is transposed into x.
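      // Note: the matrix maps here are row-major, so rhs (logically [m, k])
      // would be read by gtsv as a column-major [k, m] matrix. Transposing
      // into temp, a [k, m] row-major buffer, yields exactly the column-major
      // [m, k] layout gtsv expects, with leading dimension m.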
      std::unique_ptr<GpuSolver> cublas_solver(new GpuSolver(context));
      Tensor temp;
      TensorShape temp_shape({k, m});
      OP_REQUIRES_OK(context,
                     cublas_solver->allocate_scoped_tensor(
                         DataTypeToEnum<Scalar>::value, temp_shape, &temp));
      TransposeWithGeam(context, cublas_solver, rhs.data(),
                        temp.flat<Scalar>().data(), m, k);
      SolveWithGtsv(context, cusparse_solver, superdiag.data(), diag.data(),
                    subdiag.data(), temp.flat<Scalar>().data(), m, k);
      TransposeWithGeam(context, cublas_solver, temp.flat<Scalar>().data(),
                        x.data(), k, m);
    }
  }

 private:
  void TransposeWithGeam(OpKernelContext* context,
                         const std::unique_ptr<GpuSolver>& cublas_solver,
                         const Scalar* src, Scalar* dst, const int src_rows,
                         const int src_cols) const {
    const Scalar zero(0), one(1);
    OP_REQUIRES_OK(context,
                   cublas_solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, src_rows,
                                       src_cols, &one, src, src_cols, &zero,
                                       static_cast<const Scalar*>(nullptr),
                                       src_rows, dst, src_rows));
  }

  void SolveWithGtsv(OpKernelContext* context,
                     std::unique_ptr<GpuSparse>& cusparse_solver,
                     const Scalar* superdiag, const Scalar* diag,
                     const Scalar* subdiag, Scalar* rhs, const int num_eqs,
                     const int num_rhs) const {
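    // gtsv2 solves the system in place: rhs is overwritten with the solution.
    // It also needs an explicit workspace buffer, sized by the corresponding
    // *BufferSizeExt query below.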
    auto buffer_function = pivoting_
                               ? &GpuSparse::Gtsv2BufferSizeExt<Scalar>
                               : &GpuSparse::Gtsv2NoPivotBufferSizeExt<Scalar>;
    size_t buffer_size;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*buffer_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64_t>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();

    auto solver_function = pivoting_ ? &GpuSparse::Gtsv2<Scalar>
                                     : &GpuSparse::Gtsv2NoPivot<Scalar>;
    OP_REQUIRES_OK(context, (cusparse_solver.get()->*solver_function)(
                                num_eqs, num_rhs, subdiag, diag, superdiag, rhs,
                                num_eqs, buffer));
  }

  void SolveForSizeOneOrTwo(OpKernelContext* context, const Scalar* diagonals,
                            const Scalar* rhs, Scalar* output, int m, int k) {
    const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
    GpuLaunchConfig cfg = GetGpuLaunchConfig(
        /*work_element_count=*/1, device, &SolveForSizeOneOrTwoKernel<Scalar>,
        /*dynamic_shared_memory_size=*/0,
        /*block_size_limit=*/0);
    TF_CHECK_OK(GpuLaunchKernel(SolveForSizeOneOrTwoKernel<Scalar>,
                                cfg.block_count, cfg.thread_per_block,
                                /*shared_memory_size_bytes=*/0, device.stream(),
                                m, diagonals, rhs, k, output));
  }

  bool pivoting_;
  bool perturb_singular_;
};

template <class Scalar>
class TridiagonalSolveOpGpu : public OpKernel {
 public:
  explicit TridiagonalSolveOpGpu(OpKernelConstruction* context)
      : OpKernel(context), linalgOp_(context) {
    OP_REQUIRES_OK(context, context->GetAttr("partial_pivoting", &pivoting_));
  }

  void Compute(OpKernelContext* context) final {
    const Tensor& lhs = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = lhs.dims();
    const int64 num_rhs = rhs.dim_size(rhs.dims() - 1);
    const int64 matrix_size = lhs.dim_size(ndims - 1);
    int64 batch_size = 1;
    for (int i = 0; i < ndims - 2; i++) {
      batch_size *= lhs.dim_size(i);
    }

    // The batching mechanism of LinearAlgebraOp is used when it's not
    // possible or desirable to use GtsvBatched.
    const bool use_linalg_op =
        pivoting_            // GtsvBatched doesn't do pivoting
        || num_rhs > 1       // GtsvBatched doesn't support multiple rhs
        || matrix_size < 3   // Not supported in cuSparse, use the custom kernel
        || batch_size == 1;  // No point in using GtsvBatched for a single batch

    if (use_linalg_op) {
      linalgOp_.Compute(context);
    } else {
      ComputeWithGtsvBatched(context, lhs, rhs, batch_size);
    }
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(TridiagonalSolveOpGpu);

  void ComputeWithGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                              const Tensor& rhs, const int batch_size) {
    const Scalar* rhs_data = rhs.flat<Scalar>().data();
    const int ndims = lhs.dims();

    // To use GtsvBatched we need to transpose the left-hand side from shape
    // [..., 3, M] into shape [3, ..., M]. With shape [..., 3, M] the stride
    // between corresponding diagonal elements of consecutive batch components
    // is 3 * M, while for the right-hand side the stride is M. Unfortunately,
    // GtsvBatched requires the strides to be the same. For this reason we
    // transpose into [3, ..., M], so that diagonals, superdiagonals, and
    // subdiagonals are separated from each other, and have stride M.
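    // For example, an lhs of shape [B, 3, M] becomes [3, B, M]: the
    // superdiagonals occupy the first B * M elements of the transposed
    // buffer, the main diagonals the next B * M, and the subdiagonals the
    // last B * M, which is what the pointer offsets below rely on.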
    Tensor lhs_transposed;
    TransposeLhsForGtsvBatched(context, lhs, lhs_transposed);
    int matrix_size = lhs.dim_size(ndims - 1);
    const Scalar* lhs_data = lhs_transposed.flat<Scalar>().data();
    const Scalar* superdiag = lhs_data;
    const Scalar* diag = lhs_data + matrix_size * batch_size;
    const Scalar* subdiag = lhs_data + 2 * matrix_size * batch_size;

    // Copy right-hand side into the output. GtsvBatched will replace it with
    // the solution.
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
    CopyDeviceToDevice(context, rhs_data, output->flat<Scalar>().data(),
                       rhs.flat<Scalar>().size());
    Scalar* x = output->flat<Scalar>().data();

    std::unique_ptr<GpuSparse> cusparse_solver(new GpuSparse(context));

    OP_REQUIRES_OK(context, cusparse_solver->Initialize());

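    // Gtsv2StridedBatch solves batch_size independent systems in place, with
    // a stride of matrix_size elements between consecutive systems in both
    // the diagonals and the right-hand sides.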
    size_t buffer_size;
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatchBufferSizeExt(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, &buffer_size));
    Tensor temp_tensor;
    TensorShape temp_shape({static_cast<int64_t>(buffer_size)});
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_UINT8, temp_shape, &temp_tensor));
    void* buffer = temp_tensor.flat<std::uint8_t>().data();
    OP_REQUIRES_OK(context, cusparse_solver->Gtsv2StridedBatch(
                                matrix_size, subdiag, diag, superdiag, x,
                                batch_size, matrix_size, buffer));
  }

  void TransposeLhsForGtsvBatched(OpKernelContext* context, const Tensor& lhs,
                                  Tensor& lhs_transposed) {
    const int ndims = lhs.dims();

    // Permutation of indices, transforming [..., 3, M] into [3, ..., M].
    // E.g. for ndims = 6, it is [4, 0, 1, 2, 3, 5].
    std::vector<int> perm(ndims);
    perm[0] = ndims - 2;
    for (int i = 0; i < ndims - 2; ++i) {
      perm[i + 1] = i;
    }
    perm[ndims - 1] = ndims - 1;

    std::vector<int64_t> dims;
    for (int index : perm) {
      dims.push_back(lhs.dim_size(index));
    }
    TensorShape lhs_transposed_shape(
        gtl::ArraySlice<int64_t>(dims.data(), ndims));

    std::unique_ptr<GpuSolver> cublas_solver(new GpuSolver(context));
    OP_REQUIRES_OK(context, cublas_solver->allocate_scoped_tensor(
                                DataTypeToEnum<Scalar>::value,
                                lhs_transposed_shape, &lhs_transposed));
    auto device = context->eigen_device<Eigen::GpuDevice>();
    OP_REQUIRES_OK(
        context,
        DoTranspose(device, lhs, gtl::ArraySlice<int>(perm.data(), ndims),
                    &lhs_transposed));
  }

  TridiagonalSolveOpGpuLinalg<Scalar> linalgOp_;
  bool pivoting_;
};

REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("TridiagonalSolve", (TridiagonalSolveOpGpu<complex128>),
                       complex128);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA