/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
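// Functor that reads the int64 value at offset `idx * stride_` from `begin_`
// and narrows it to int.  Combined with a counting iterator below, it exposes
// the first (batch) column of a row-major [nnz, 3] indices matrix as a flat
// sequence of ints.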
struct StridedDataReader {
  StridedDataReader(const int64* begin, int stride)
      : begin_(begin), stride_(stride) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return static_cast<int>(ldg(begin_ + idx * stride_));
  }

  const int64* begin_;
  const int stride_;
};
}  // namespace

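// Counts the number of nonzeros that fall in each batch by histogramming the
// first (batch) column of `indices`.  HistogramEven is called twice: the first
// call, with a null temp-storage pointer, only computes temp_storage_bytes;
// the second call performs the histogram.
//
// Example (conceptual): indices = [[0, r, c], [0, r, c], [2, r, c]] with
// nnz_per_batch of size 3 yields nnz_per_batch = [2, 0, 1].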
template <>
Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64_t>::ConstMatrix indices,
    TTypes<int32>::Vec nnz_per_batch) {
  const auto& cu_stream = GetGpuStream(c);

  const int total_nnz = indices.dimension(0);
  const int size = nnz_per_batch.size();

  DCHECK_EQ(indices.rank(), 2);
  DCHECK_EQ(indices.dimension(1), 3);  // batch, row, col

  const int rank = indices.dimension(1);
  gpuprim::CountingInputIterator<int> row_counter(0);
  gpuprim::TransformInputIterator<int, StridedDataReader,
                                  gpuprim::CountingInputIterator<int>>
      indices_first_column(row_counter,
                           StridedDataReader(indices.data(), rank));

  std::size_t temp_storage_bytes = 0;

  DCHECK_NE(indices.data(), nullptr);
  DCHECK_NE(nnz_per_batch.data(), nullptr);

  auto first_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ nullptr,
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (first_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to calculate temp_storage_bytes, status: ",
        GpuGetErrorString(first_success));
  }

  Tensor temp_storage;
  TF_RETURN_IF_ERROR(c->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
  auto second_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ temp_storage.flat<int8>().data(),
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (second_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to count nnz entries per batch.  temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
  }

  return Status::OK();
}

// TODO(ebrevdo): Write a custom batch-friendly impl of this to update
// the SparseTensor indices directly.
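// Expands the compressed CSR row pointers into one explicit (COO) row index
// per nonzero by calling the csr2coo routine wrapped by GpuSparse.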
template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
    TTypes<int>::UnalignedVec coo_row_ind) {
  GpuSparse gpu_sparse(c);
  const int nnz = coo_row_ind.size();
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  const int m = csr_row_ptr.size() - 1;  // rows
  return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
}

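// Copies the (row, col) columns of a SparseTensor indices matrix into separate
// int32 COO row/col arrays.  For rank-3 (batched) indices the layout per entry
// is (batch, row, col), so the kernel skips the leading batch column
// (offset = 1); for rank-2 indices it reads columns 0 and 1 directly.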
template <int stride>
__global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
                                              int* coo_rows_out,
                                              int* coo_cols_out, int size) {
  const int offset = (stride == 3) ? 1 : 0;
  GPU_1D_KERNEL_LOOP(i, size) {
    coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
    coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
  }
}

template <>
void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
    const GPUDevice& d, TTypes<int64_t>::ConstVec host_dense_shape,
    TTypes<int64_t>::ConstMatrix indices, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::Vec coo_col_ind) {
  const int stride = host_dense_shape.size();
  DCHECK(stride == 2 || stride == 3);
  DCHECK_EQ(stride, indices.dimension(1));
  const int size = coo_row_ind.dimension(0);
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  if (stride == 2) {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  }
}

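// Inverse of the kernel above for the rank-2 case: interleaves the int32 COO
// row/col arrays back into a row-major [nnz, 2] int64 indices matrix.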
__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
                                                const int* coo_cols,
                                                int64* indices_out, int size) {
  GPU_1D_KERNEL_LOOP(i, size) {
    indices_out[i * 2] = static_cast<int64_t>(ldg(coo_rows + i));
    indices_out[i * 2 + 1] = static_cast<int64_t>(ldg(coo_cols + i));
  }
}

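// Returns the index i such that range[i] <= x < range[i + 1], assuming range
// is sorted ascending and x lies within [range[0], range[n]).
//
// Example: with range = {0, 2, 5, 7} and n = 3, BinarySearchRange(range, 3, 4)
// returns 1 because range[1] = 2 <= 4 < 5 = range[2].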
__device__ inline int BinarySearchRange(int* range, int n, int x) {
  int left = 0;
  int right = n - 1;
  while (left < right) {
    int mid = left + (right - left) / 2;
    if (x < range[mid])
      right = mid - 1;
    else if (range[mid + 1] <= x)
      left = mid + 1;
    else
      return mid;  // range[mid] <= x < range[mid + 1].
  }
  return left;
}

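// Rank-3 (batched) counterpart of the 2D kernel above.  The batch pointers are
// first staged in dynamic shared memory; each nonzero i then recovers its
// batch b via binary search over those pointers (batch_ptr[b] <= i <
// batch_ptr[b + 1]) and writes the (batch, row, col) triple.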
__global__ void COOMatrixToSparseTensorKernel3D(
    const int* coo_rows, const int* coo_cols, int64* indices_out,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int size) {
  // Step 1: access the batch ptrs and copy to shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, size) {
    // TODO(ebrevdo): Consider special casing batch_size <= 3,
    // alternatively doing linear instead of binary search.  Requires
    // some benchmarks.
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    indices_out[i * 3] = static_cast<int64_t>(b);
    indices_out[i * 3 + 1] = static_cast<int64_t>(ldg(coo_rows + i));
    indices_out[i * 3 + 2] = static_cast<int64_t>(ldg(coo_cols + i));
  }
}

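// Host-side driver: dispatches to the 2D or 3D kernel above.  In the batched
// case the batch pointers are copied into a GpuDeviceArray and the kernel is
// launched with enough dynamic shared memory to hold batch_size + 1 ints.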
template <>
Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64_t>::ConstVec host_dense_shape,
    TTypes<int>::ConstVec host_batch_ptr, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::ConstVec coo_col_ind, TTypes<int64_t>::Matrix indices) {
  const int ndims = indices.dimension(1);
  DCHECK(ndims == 2 || ndims == 3);
  DCHECK_EQ(ndims, host_dense_shape.size());
  DCHECK_NE(coo_row_ind.data(), nullptr);
  DCHECK_NE(coo_col_ind.data(), nullptr);
  DCHECK_NE(indices.data(), nullptr);
  const GPUDevice& d = c->eigen_device<GPUDevice>();
  const int size = coo_row_ind.size();
  DCHECK_EQ(size, coo_col_ind.size());
  DCHECK_EQ(size, indices.dimension(0));
  if (ndims == 2) {
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), coo_row_ind.data(),
                                coo_col_ind.data(), indices.data(), size));
    return Status::OK();
  } else {
    const int batch_size = host_dense_shape(0);
    GpuDeviceArrayOnHost<int> batch_ptr_copy(c, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < batch_size; ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(
        GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
                        config.thread_per_block, shared_memory_size, d.stream(),
                        coo_row_ind.data(), coo_col_ind.data(), indices.data(),
                        batch_ptr_copy.data(), batch_size, size));
    return Status::OK();
  }
}

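// Computes c_values[i] = a_values[i] * b_batch_values[batch(i)] for a batched
// CSR matrix, where batch(i) is recovered via binary search over the batch
// pointers.  Dynamic shared memory holds the batch_size + 1 batch pointers
// followed by the batch_size per-batch multipliers.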
template <typename T>
__global__ void CSRSparseMatrixBatchMulVecKernel3D(
    const T* a_values, const T* b_batch_values, T* c_values,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int total_nnz) {
  // Step 1: Access the batch ptrs and copy to shared memory.
  //         Also copy the per-batch multipliers into shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  T* local_batch_values =
      reinterpret_cast<T*>(local_batch_ptr + batch_size + 1);
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
    if (i < batch_size) {
      local_batch_values[i] = b_batch_values[i];
    }
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, total_nnz) {
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    c_values[i] = ldg(a_values + i) * local_batch_values[b];
  }
}

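// Allocates the output values tensor, builds a CSR matrix `c` that shares a's
// dense shape, batch pointers, row pointers and column indices, then launches
// the kernel above to scale each batch of values by the corresponding entry
// of `b`.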
template <typename T>
Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
                                      const CSRSparseMatrix& a,
                                      typename TTypes<T>::ConstFlat b,
                                      CSRSparseMatrix* c) {
  DCHECK_EQ(a.dims(), 3);
  const int total_nnz = a.total_nnz();
  Tensor c_values_t;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                        TensorShape({total_nnz}), &c_values_t));
  TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
      DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),
      a.row_pointers(), a.col_indices(), c_values_t, c));

  auto a_values = a.values().flat<T>();
  auto c_values = c_values_t.flat<T>();

  auto host_dense_shape = a.dense_shape().vec<int64_t>();
  auto host_batch_ptr = a.batch_pointers().vec<int>();

  const GPUDevice& d = ctx->eigen_device<GPUDevice>();

  const int batch_size = host_dense_shape(0);
  DCHECK_EQ(b.size(), batch_size);

  GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
  TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
  for (int i = 0; i < batch_size; ++i) {
    batch_ptr_copy.Set(i, host_batch_ptr(i));
  }
  TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
  GpuLaunchConfig config = GetGpuLaunchConfig(total_nnz, d);
  // shared memory stores the batch pointers.
  const size_t shared_memory_size =
      (sizeof(int) * (batch_size + 1)  // local batch_pointers.
       + sizeof(T) * batch_size);      // local copy of b.
  TF_CHECK_OK(GpuLaunchKernel(
      CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
      config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
      b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));

  return Status::OK();
}

#define DEFINE_SPARSE_MUL_VEC_GPU(T)                                        \
  template <>                                                               \
  CSRSparseMatrixBatchMulVec<GPUDevice, T>::CSRSparseMatrixBatchMulVec() {} \
  template <>                                                               \
  Status CSRSparseMatrixBatchMulVec<GPUDevice, T>::Compute(                 \
      OpKernelContext* ctx, const CSRSparseMatrix& a,                       \
      typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c) {                \
    return CSRSparseMatrixBatchMulVecImpl<T>(ctx, a, b, c);                 \
  }

DEFINE_SPARSE_MUL_VEC_GPU(float);
DEFINE_SPARSE_MUL_VEC_GPU(double);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<float>);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<double>);

#undef DEFINE_SPARSE_MUL_VEC_GPU

template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmax(const int begin,
                                                               const int end,
                                                               const T* logits,
                                                               T* softmax) {
  // For each row, calculate the vector:
  //   softmax[row] = exp(shifted_logits[row]) / sum(exp(shifted_logits[row]))
  // where
  //   shifted_logits[row] = logits[row] - max(logits[row])
  // are the logits normalized for stability.
  T row_max = Eigen::NumTraits<T>::lowest();
  for (int r_i = begin; r_i < end; ++r_i) {
    row_max = Eigen::numext::maxi(row_max, ldg(logits + r_i));
  }
  T sum_exp = 0;
  for (int r_i = begin; r_i < end; ++r_i) {
    const T exp_i = Eigen::numext::exp(ldg(logits + r_i) - row_max);
    softmax[r_i] = exp_i;
    sum_exp += exp_i;
  }
  for (int r_i = begin; r_i < end; ++r_i) {
    softmax[r_i] = softmax[r_i] / sum_exp;
  }
}

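// Applies CalculateRowSoftmax to each row of a rank-2 CSR matrix; each thread
// of the grid-stride loop processes one whole row, covering the nonzeros in
// [row_ptr[row], row_ptr[row + 1]).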
template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
                                               const int* row_ptr,
                                               const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
                        softmax);
  }
}

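// Cooperatively copies `length` ints from a GpuDeviceArray into the
// block-local buffer `local_ptr` (typically dynamic shared memory); every
// thread in the block participates, and the copy ends with __syncthreads().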
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
    GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
  for (int i = threadIdx.x; i < length; i += blockDim.x) {
    local_ptr[i] = cuda_ptr[i];
  }
  __syncthreads();
#endif
}

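// Batched (rank-3) softmax.  The batched CSR layout stores rows + 1 row
// pointers per batch, so logical row `row` of batch `batch` lives at
// row_offset = batch * (rows + 1) + row, and its values occupy
// [batch_ptr[batch] + row_ptr[row_offset],
//  batch_ptr[batch] + row_ptr[row_offset + 1]) in the flat values array.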
template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel3D(
    const int size, const int rows, GpuDeviceArrayStruct<int> batch_ptr_s,
    const int* row_ptr, const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
                                batch_size + 1);

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int batch_offset = local_batch_ptr[batch];
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmax(batch_offset + ldg(row_ptr + row_offset),
                        batch_offset + ldg(row_ptr + row_offset + 1), logits,
                        softmax);
  }
}

template <typename T>
Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
                                     const CSRSparseMatrix& logits,
                                     typename TTypes<T>::Vec softmax_values) {
  auto host_dense_shape = logits.dense_shape().vec<int64_t>();
  auto host_batch_ptr = logits.batch_pointers().vec<int32>();
  auto row_ptr = logits.row_pointers().vec<int32>();
  auto logits_values = logits.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    const int rows = host_dense_shape(0);
    DCHECK_EQ(rows, row_ptr.size() - 1);
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), rows /*size*/, row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, row_ptr.size());
    const int size = rows * batch_size;

    GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < host_batch_ptr.size(); ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
                                config.block_count, config.thread_per_block,
                                shared_memory_size, d.stream(), size, rows,
                                batch_ptr_copy.data(), row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GPU(T)                                             \
  template <>                                                             \
  Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()(                \
      OpKernelContext* ctx, const CSRSparseMatrix& logits,                \
      typename TTypes<T>::Vec softmax_values) {                           \
    return CSRSparseMatrixSoftmaxGPUImpl<T>(ctx, logits, softmax_values); \
  }

DEFINE_SOFTMAX_GPU(float);
DEFINE_SOFTMAX_GPU(double);

#undef DEFINE_SOFTMAX_GPU

template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmaxGrad(
    const int softmax_begin, const int softmax_end, const int* softmax_col_ind,
    const T* softmax, const int grad_softmax_begin, const int grad_softmax_end,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // Iterate from
  //   softmax_col_ind[softmax_begin] to
  //   softmax_col_ind[softmax_end]
  // and from
  //  grad_softmax_col_ind[grad_softmax_begin] to
  //  grad_softmax_col_ind[grad_softmax_end]
  //
  // looking for matching indices.  In the softmax indices only, perform:
  //
  //   gradient = (grad_softmax - sum(grad_softmax * softmax)) * softmax
  //
  // where the sum is along the given row.
  T sum_prod = 0;
  for (int i = softmax_begin, j = grad_softmax_begin;
       i < softmax_end && j < grad_softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    const int grad_softmax_col = ldg(grad_softmax_col_ind + j);
    if (softmax_col == grad_softmax_col) {
      sum_prod += ldg(softmax + i) * ldg(grad_softmax + j);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      ++j;
    } else {
      ++i;
    }
  }

  // Find an upper bound on the column numbers in this row; for use in
  // the special case of an empty grad_softmax row and a non-empty
  // softmax row.
  const int softmax_col_upper_bound =
      (softmax_begin == softmax_end)
          ? -1
          : ldg(softmax_col_ind + softmax_end - 1) + 1;
  for (int i = softmax_begin, j = grad_softmax_begin; i < softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    // We need to keep a large grad_softmax_col value if we're at the
    // end of the grad_softmax row, so we can fill in the remainder of
    // the gradients row (the last if branch in this loop).
    const int grad_softmax_col = (j == grad_softmax_end)
                                     ? softmax_col_upper_bound
                                     : ldg(grad_softmax_col_ind + j);

    if (softmax_col == grad_softmax_col) {
      gradient[i] = (ldg(grad_softmax + j) - sum_prod) * ldg(softmax + i);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      // grad_softmax is nonzero here, but since softmax is zero, the
      // gradient is 0; so we skip it since the sparsity structure
      // already encodes this zero.
      ++j;
    } else {
      // grad_softmax is zero but softmax is not.
      gradient[i] = -sum_prod * ldg(softmax + i);
      ++i;
    }
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel2D(
    const int rows, const int* softmax_row_ptr, const int* softmax_col_ind,
    const T* softmax, const int* grad_softmax_row_ptr,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmaxGrad(
        ldg(softmax_row_ptr + row) /*softmax_begin*/,
        ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
        softmax, ldg(grad_softmax_row_ptr + row) /*grad_softmax_begin*/,
        ldg(grad_softmax_row_ptr + row + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }
}

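// Batched (rank-3) softmax gradient.  Dynamic shared memory holds two sets of
// batch pointers back to back: entries [0, batch_size] belong to the softmax
// CSR matrix and entries [batch_size + 1, 2 * batch_size + 1] to grad_softmax.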
template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel3D(
    const int size, const int rows,
    GpuDeviceArrayStruct<int> softmax_and_grad_batch_ptr_s,
    const int* softmax_row_ptr, const int* softmax_col_ind, const T* softmax,
    const int* grad_softmax_row_ptr, const int* grad_softmax_col_ind,
    const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  //   http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf

  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(softmax_and_grad_batch_ptr_s),
                                local_batch_ptr, 2 * (batch_size + 1));

#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i]
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i]

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
    const int grad_softmax_batch_offset = GRAD_SOFTMAX_BATCH_PTR(batch);
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmaxGrad(
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset) /*softmax_begin*/,
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset + 1) /*softmax_end*/,
        softmax_col_ind, softmax,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset) /*grad_softmax_begin*/,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }

#undef SOFTMAX_BATCH_PTR
#undef GRAD_SOFTMAX_BATCH_PTR
}

template <typename T>
Status CSRSparseMatrixSoftmaxGradGPUImpl(
    OpKernelContext* ctx, const CSRSparseMatrix& softmax,
    const CSRSparseMatrix& grad_softmax,
    typename TTypes<T>::Vec gradient_values) {
  auto host_dense_shape = softmax.dense_shape().vec<int64_t>();
  auto softmax_host_batch_ptr = softmax.batch_pointers().vec<int32>();
  auto softmax_row_ptr = softmax.row_pointers().vec<int32>();
  auto softmax_col_ind = softmax.col_indices().vec<int32>();
  auto softmax_values = softmax.values().vec<T>();
  auto grad_softmax_host_batch_ptr = grad_softmax.batch_pointers().vec<int32>();
  auto grad_softmax_row_ptr = grad_softmax.row_pointers().vec<int32>();
  auto grad_softmax_col_ind = grad_softmax.col_indices().vec<int32>();
  auto grad_softmax_values = grad_softmax.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const int rows = host_dense_shape(0);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    DCHECK_EQ(rows + 1, softmax_row_ptr.size());
    DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
        config.thread_per_block, 0, d.stream(), rows /*size*/,
        softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ(batch_size, grad_softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, softmax_row_ptr.size());
    DCHECK_EQ((rows + 1) * batch_size, grad_softmax_row_ptr.size());
    const int size = rows * batch_size;
    // The length of softmax_and_grad_batch_ptr_copy is 2 * (batch_size + 1)
    // The first (batch_size + 1) entries contain softmax_batch_ptr and
    // the second (batch_size + 1) entries contain grad_softmax_batch_ptr.
    GpuDeviceArrayOnHost<int> softmax_and_grad_batch_ptr_copy(
        ctx, 2 * softmax_host_batch_ptr.size());
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Init());
    for (int i = 0; i < softmax_host_batch_ptr.size(); ++i) {
      softmax_and_grad_batch_ptr_copy.Set(i, softmax_host_batch_ptr(i));
      softmax_and_grad_batch_ptr_copy.Set(batch_size + 1 + i,
                                          grad_softmax_host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores two copies of batch pointers: one for the
    // softmax CSR matrix, one for the grad_softmax CSR matrix.
    const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
        config.thread_per_block, shared_memory_size, d.stream(), size, rows,
        softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
        softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GRAD_GPU(T)                                          \
  template <>                                                               \
  Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()(              \
      OpKernelContext* ctx, const CSRSparseMatrix& softmax,                 \
      const CSRSparseMatrix& grad_softmax,                                  \
      typename TTypes<T>::Vec gradient_values) {                            \
    return CSRSparseMatrixSoftmaxGradGPUImpl<T>(ctx, softmax, grad_softmax, \
                                                gradient_values);           \
  }

DEFINE_SOFTMAX_GRAD_GPU(float);
DEFINE_SOFTMAX_GRAD_GPU(double);

#undef DEFINE_SOFTMAX_GRAD_GPU

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM