/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {
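// Reads every `stride`-th int64 starting at `begin_`, cast to int.  Used
// below to view the batch (first) column of the [total_nnz, 3] indices
// matrix as a flat sequence of batch ids.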
struct StridedDataReader {
  StridedDataReader(const int64* begin, int stride)
      : begin_(begin), stride_(stride) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return static_cast<int>(ldg(begin_ + idx * stride_));
  }

  const int64* begin_;
  const int stride_;
};
}  // namespace

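// Counts nonzeros per batch by running gpuprim::DeviceHistogram::HistogramEven
// over the batch column of `indices`: batch ids lie in [0, size), so `size`
// evenly spaced bins of width one give the per-batch counts (e.g. a batch
// column of {0, 0, 2} with size == 3 produces nnz_per_batch == {2, 0, 1}).
// The histogram is launched twice: first to query temp_storage_bytes, then to
// do the actual counting.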
template <>
Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64_t>::ConstMatrix indices,
    TTypes<int32>::Vec nnz_per_batch) {
  const auto& cu_stream = GetGpuStream(c);

  const int total_nnz = indices.dimension(0);
  const int size = nnz_per_batch.size();

  DCHECK_EQ(indices.rank(), 2);
  DCHECK_EQ(indices.dimension(1), 3);  // batch, row, col

  const int rank = indices.dimension(1);
  gpuprim::CountingInputIterator<int> row_counter(0);
  gpuprim::TransformInputIterator<int, StridedDataReader,
                                  gpuprim::CountingInputIterator<int>>
      indices_first_column(row_counter,
                           StridedDataReader(indices.data(), rank));

  std::size_t temp_storage_bytes = 0;

  DCHECK_NE(indices.data(), nullptr);
  DCHECK_NE(nnz_per_batch.data(), nullptr);

  auto first_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ nullptr,
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (first_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to calculate temp_storage_bytes, status: ",
        GpuGetErrorString(first_success));
  }

  Tensor temp_storage;
  TF_RETURN_IF_ERROR(c->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
      &temp_storage));
  DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
  auto second_success = gpuprim::DeviceHistogram::HistogramEven(
      /*d_temp_storage*/ temp_storage.flat<int8>().data(),
      /*temp_storage_bytes&*/ temp_storage_bytes,
      /*d_samples*/ indices_first_column,
      /*d_histogram*/ nnz_per_batch.data(),
      /*num_levels*/ size + 1,
      /*lower_level*/ 0,
      /*upper_level*/ size,
      /*num_samples*/ total_nnz,
      /*stream*/ cu_stream);

  if (second_success != gpuSuccess) {
    return errors::Internal(
        "SparseTensorToCSRSparseMatrix: Could not launch "
        "gpuprim::DeviceHistogram::HistogramEven "
        "to count nnz entries per batch. temp_storage_bytes: ",
        temp_storage_bytes, ", status: ", GpuGetErrorString(second_success));
  }

  return Status::OK();
}

// TODO(ebrevdo): Write a custom batch-friendly impl of this to update
// the SparseTensor indices directly.
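// Expands CSR row pointers into per-nonzero COO row indices by delegating to
// GpuSparse::Csr2coo, TensorFlow's wrapper over the vendor sparse library
// (cuSPARSE on CUDA, hipSPARSE on ROCm).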
template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
    TTypes<int>::UnalignedVec coo_row_ind) {
  GpuSparse gpu_sparse(c);
  const int nnz = coo_row_ind.size();
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  const int m = csr_row_ptr.size() - 1;  // rows
  return gpu_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
}

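// Copies the row and column components of the [size, stride] sparse `indices`
// matrix into separate int arrays.  When stride == 3 the leading batch column
// is skipped (offset 1); when stride == 2 the rows start at column 0.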
template <int stride>
__global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
                                              int* coo_rows_out,
                                              int* coo_cols_out, int size) {
  const int offset = (stride == 3) ? 1 : 0;
  GPU_1D_KERNEL_LOOP(i, size) {
    coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
    coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
  }
}

template <>
void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
    const GPUDevice& d, TTypes<int64_t>::ConstVec host_dense_shape,
    TTypes<int64_t>::ConstMatrix indices, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::Vec coo_col_ind) {
  const int stride = host_dense_shape.size();
  DCHECK(stride == 2 || stride == 3);
  DCHECK_EQ(stride, indices.dimension(1));
  const int size = coo_row_ind.dimension(0);
  GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
  if (stride == 2) {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<2>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(SparseTensorToCOOMatrixKernel<3>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), indices.data(), coo_row_ind.data(),
                                coo_col_ind.data(), size));
  }
}

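// Inverse of the kernel above for the 2-D case: interleaves the COO row and
// column arrays back into a [size, 2] int64 SparseTensor indices matrix.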
__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
                                                const int* coo_cols,
                                                int64* indices_out, int size) {
  GPU_1D_KERNEL_LOOP(i, size) {
    indices_out[i * 2] = static_cast<int64_t>(ldg(coo_rows + i));
    indices_out[i * 2 + 1] = static_cast<int64_t>(ldg(coo_cols + i));
  }
}

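// Returns the index i in [0, n) such that range[i] <= x < range[i + 1],
// where `range` holds n + 1 non-decreasing entries (x past the end maps to
// n - 1).  For example, with range = {0, 3, 3, 7}, n = 3 and x = 5 the
// result is 2.  Used below to map a flat nonzero index to its batch via the
// batch pointers.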
__device__ inline int BinarySearchRange(int* range, int n, int x) {
  int left = 0;
  int right = n - 1;
  while (left < right) {
    int mid = left + (right - left) / 2;
    if (x < range[mid])
      right = mid - 1;
    else if (range[mid + 1] <= x)
      left = mid + 1;
    else
      return mid;  // range[mid] <= x < range[mid + 1].
  }
  return left;
}

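// 3-D case: `batch_ptr_s` holds batch_size + 1 offsets, where entry b is the
// index of batch b's first nonzero, so nonzero i belongs to the batch found
// by binary searching the batch pointers (staged in shared memory first).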
__global__ void COOMatrixToSparseTensorKernel3D(
    const int* coo_rows, const int* coo_cols, int64* indices_out,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int size) {
  // Step 1: access the batch ptrs and copy to shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, size) {
    // TODO(ebrevdo): Consider special casing batch_size <= 3,
    // alternatively doing linear instead of binary search. Requires
    // some benchmarks.
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    indices_out[i * 3] = static_cast<int64_t>(b);
    indices_out[i * 3 + 1] = static_cast<int64_t>(ldg(coo_rows + i));
    indices_out[i * 3 + 2] = static_cast<int64_t>(ldg(coo_cols + i));
  }
}

template <>
Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<int64_t>::ConstVec host_dense_shape,
    TTypes<int>::ConstVec host_batch_ptr, TTypes<int>::Vec coo_row_ind,
    TTypes<int>::ConstVec coo_col_ind, TTypes<int64_t>::Matrix indices) {
  const int ndims = indices.dimension(1);
  DCHECK(ndims == 2 || ndims == 3);
  DCHECK_EQ(ndims, host_dense_shape.size());
  DCHECK_NE(coo_row_ind.data(), nullptr);
  DCHECK_NE(coo_col_ind.data(), nullptr);
  DCHECK_NE(indices.data(), nullptr);
  const GPUDevice& d = c->eigen_device<GPUDevice>();
  const int size = coo_row_ind.size();
  DCHECK_EQ(size, coo_col_ind.size());
  DCHECK_EQ(size, indices.dimension(0));
  if (ndims == 2) {
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    TF_CHECK_OK(GpuLaunchKernel(COOMatrixToSparseTensorKernel2D,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), coo_row_ind.data(),
                                coo_col_ind.data(), indices.data(), size));
    return Status::OK();
  } else {
    const int batch_size = host_dense_shape(0);
    GpuDeviceArrayOnHost<int> batch_ptr_copy(c, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < batch_size; ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(
        GpuLaunchKernel(COOMatrixToSparseTensorKernel3D, config.block_count,
                        config.thread_per_block, shared_memory_size,
                        d.stream(), coo_row_ind.data(), coo_col_ind.data(),
                        indices.data(), batch_ptr_copy.data(), batch_size,
                        size));
    return Status::OK();
  }
}

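// Scales each nonzero of a batched (3-D) CSR matrix by the scalar for its
// batch: c_values[i] = a_values[i] * b[batch(i)], where batch(i) is found by
// binary searching the batch pointers.  For example, with batch_ptr =
// {0, 2, 5} and b = {x, y}, nonzeros 0..1 are scaled by x and nonzeros 2..4
// by y.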
template <typename T>
__global__ void CSRSparseMatrixBatchMulVecKernel3D(
    const T* a_values, const T* b_batch_values, T* c_values,
    GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
    const int total_nnz) {
  // Step 1: Access the batch ptrs and copy to shared memory.
  // Also copy the per-batch multipliers into shared memory.
  const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
  extern __shared__ int local_batch_ptr[];
  T* local_batch_values =
      reinterpret_cast<T*>(local_batch_ptr + batch_size + 1);
  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    local_batch_ptr[i] = batch_ptr[i];
    if (i < batch_size) {
      local_batch_values[i] = b_batch_values[i];
    }
  }
  __syncthreads();

  GPU_1D_KERNEL_LOOP(i, total_nnz) {
    const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
    c_values[i] = ldg(a_values + i) * local_batch_values[b];
  }
}

template <typename T>
Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
                                      const CSRSparseMatrix& a,
                                      typename TTypes<T>::ConstFlat b,
                                      CSRSparseMatrix* c) {
  DCHECK_EQ(a.dims(), 3);
  const int total_nnz = a.total_nnz();
  Tensor c_values_t;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
                                        TensorShape({total_nnz}), &c_values_t));
  TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
      DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),
      a.row_pointers(), a.col_indices(), c_values_t, c));

  auto a_values = a.values().flat<T>();
  auto c_values = c_values_t.flat<T>();

  auto host_dense_shape = a.dense_shape().vec<int64_t>();
  auto host_batch_ptr = a.batch_pointers().vec<int>();

  const GPUDevice& d = ctx->eigen_device<GPUDevice>();

  const int batch_size = host_dense_shape(0);
  DCHECK_EQ(b.size(), batch_size);

  GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
  TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
  for (int i = 0; i < batch_size; ++i) {
    batch_ptr_copy.Set(i, host_batch_ptr(i));
  }
  TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
  GpuLaunchConfig config = GetGpuLaunchConfig(total_nnz, d);
  // shared memory stores the batch pointers.
  const size_t shared_memory_size =
      (sizeof(int) * (batch_size + 1)  // local batch_pointers.
       + sizeof(T) * batch_size);      // local copy of b.
  TF_CHECK_OK(GpuLaunchKernel(
      CSRSparseMatrixBatchMulVecKernel3D<T>, config.block_count,
      config.thread_per_block, shared_memory_size, d.stream(), a_values.data(),
      b.data(), c_values.data(), batch_ptr_copy.data(), batch_size, total_nnz));

  return Status::OK();
}

#define DEFINE_SPARSE_MUL_VEC_GPU(T)                                        \
  template <>                                                               \
  CSRSparseMatrixBatchMulVec<GPUDevice, T>::CSRSparseMatrixBatchMulVec() {} \
  template <>                                                               \
  Status CSRSparseMatrixBatchMulVec<GPUDevice, T>::Compute(                 \
      OpKernelContext* ctx, const CSRSparseMatrix& a,                       \
      typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c) {                \
    return CSRSparseMatrixBatchMulVecImpl<T>(ctx, a, b, c);                 \
  }

DEFINE_SPARSE_MUL_VEC_GPU(float);
DEFINE_SPARSE_MUL_VEC_GPU(double);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<float>);
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<double>);

#undef DEFINE_SPARSE_MUL_VEC_GPU

template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmax(const int begin,
                                                               const int end,
                                                               const T* logits,
                                                               T* softmax) {
  // For each row, calculate the vector:
  //   softmax[row] = exp(shifted_logits[row]) / sum(exp(shifted_logits[row]))
  // where
  //   shifted_logits[row] = logits[row] - max(logits[row])
  // are the logits normalized for stability.
  T row_max = Eigen::NumTraits<T>::lowest();
  for (int r_i = begin; r_i < end; ++r_i) {
    row_max = Eigen::numext::maxi(row_max, ldg(logits + r_i));
  }
  T sum_exp = 0;
  for (int r_i = begin; r_i < end; ++r_i) {
    const T exp_i = Eigen::numext::exp(ldg(logits + r_i) - row_max);
    softmax[r_i] = exp_i;
    sum_exp += exp_i;
  }
  for (int r_i = begin; r_i < end; ++r_i) {
    softmax[r_i] = softmax[r_i] / sum_exp;
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
                                               const int* row_ptr,
                                               const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
                        softmax);
  }
}

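// Cooperatively copies `length` ints from the device array behind
// `cuda_ptr_s` into `local_ptr` (shared memory at the call sites below),
// striding over the threads of the block, then synchronizes the block.  The
// body is compiled for device code only.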
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
    GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
  const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
  for (int i = threadIdx.x; i < length; i += blockDim.x) {
    local_ptr[i] = cuda_ptr[i];
  }
  __syncthreads();
#endif
}

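// 3-D softmax: one logical row per loop iteration over size = batch_size *
// rows.  Row pointers for batch b occupy entries [b * (rows + 1),
// (b + 1) * (rows + 1)) of `row_ptr` and are relative to the batch, so
// batch_ptr[b] converts them into absolute offsets into `logits`/`softmax`.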
template <typename T>
__global__ void CSRSparseMatrixSoftmaxKernel3D(
    const int size, const int rows, GpuDeviceArrayStruct<int> batch_ptr_s,
    const int* row_ptr, const T* logits, T* softmax) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
                                batch_size + 1);

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int batch_offset = local_batch_ptr[batch];
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmax(batch_offset + ldg(row_ptr + row_offset),
                        batch_offset + ldg(row_ptr + row_offset + 1), logits,
                        softmax);
  }
}

template <typename T>
Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
                                     const CSRSparseMatrix& logits,
                                     typename TTypes<T>::Vec softmax_values) {
  auto host_dense_shape = logits.dense_shape().vec<int64_t>();
  auto host_batch_ptr = logits.batch_pointers().vec<int32>();
  auto row_ptr = logits.row_pointers().vec<int32>();
  auto logits_values = logits.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    const int rows = host_dense_shape(0);
    DCHECK_EQ(rows, row_ptr.size() - 1);
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel2D<T>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), rows /*size*/, row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, row_ptr.size());
    const int size = rows * batch_size;

    GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
    TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
    for (int i = 0; i < host_batch_ptr.size(); ++i) {
      batch_ptr_copy.Set(i, host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores the batch pointers.
    const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(CSRSparseMatrixSoftmaxKernel3D<T>,
                                config.block_count, config.thread_per_block,
                                shared_memory_size, d.stream(), size, rows,
                                batch_ptr_copy.data(), row_ptr.data(),
                                logits_values.data(), softmax_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GPU(T)                                             \
  template <>                                                             \
  Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()(                \
      OpKernelContext* ctx, const CSRSparseMatrix& logits,                \
      typename TTypes<T>::Vec softmax_values) {                           \
    return CSRSparseMatrixSoftmaxGPUImpl<T>(ctx, logits, softmax_values); \
  }

DEFINE_SOFTMAX_GPU(float);
DEFINE_SOFTMAX_GPU(double);

#undef DEFINE_SOFTMAX_GPU

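// Computes one row of the softmax gradient.  Both column-index lists are
// assumed to be sorted within the row (canonical CSR ordering), so a
// two-pointer merge matches columns between softmax and grad_softmax in a
// single pass.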
template <typename T>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmaxGrad(
    const int softmax_begin, const int softmax_end, const int* softmax_col_ind,
    const T* softmax, const int grad_softmax_begin, const int grad_softmax_end,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // Iterate from
  //   softmax_col_ind[softmax_begin] to
  //   softmax_col_ind[softmax_end]
  // and from
  //   grad_softmax_col_ind[grad_softmax_begin] to
  //   grad_softmax_col_ind[grad_softmax_end]
  //
  // looking for matching indices. In the softmax indices only, perform:
  //
  //   gradient = (grad_softmax - sum(grad_softmax * softmax)) * softmax
  //
  // where the sum is along the given row.
  T sum_prod = 0;
  for (int i = softmax_begin, j = grad_softmax_begin;
       i < softmax_end && j < grad_softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    const int grad_softmax_col = ldg(grad_softmax_col_ind + j);
    if (softmax_col == grad_softmax_col) {
      sum_prod += ldg(softmax + i) * ldg(grad_softmax + j);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      ++j;
    } else {
      ++i;
    }
  }

  // Find an upper bound on the column numbers in this row; for use in
  // the special case of an empty grad_softmax row and a non-empty
  // softmax row.
  const int softmax_col_upper_bound =
      (softmax_begin == softmax_end)
          ? -1
          : ldg(softmax_col_ind + softmax_end - 1) + 1;
  for (int i = softmax_begin, j = grad_softmax_begin; i < softmax_end;) {
    const int softmax_col = ldg(softmax_col_ind + i);
    // We need to keep a large grad_softmax_col value if we're at the
    // end of the grad_softmax row, so we can fill in the remainder of
    // the gradients row (the last if branch in this loop).
    const int grad_softmax_col = (j == grad_softmax_end)
                                     ? softmax_col_upper_bound
                                     : ldg(grad_softmax_col_ind + j);

    if (softmax_col == grad_softmax_col) {
      gradient[i] = (ldg(grad_softmax + j) - sum_prod) * ldg(softmax + i);
      ++i;
      ++j;
    } else if (softmax_col > grad_softmax_col) {
      // grad_softmax is nonzero here, but since softmax is zero, the
      // gradient is 0; so we skip it since the sparsity structure
      // already encodes this zero.
      ++j;
    } else {
      // grad_softmax is zero but softmax is not.
      gradient[i] = -sum_prod * ldg(softmax + i);
      ++i;
    }
  }
}

template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel2D(
    const int rows, const int* softmax_row_ptr, const int* softmax_col_ind,
    const T* softmax, const int* grad_softmax_row_ptr,
    const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
  GPU_1D_KERNEL_LOOP(row, rows) {
    CalculateRowSoftmaxGrad(
        ldg(softmax_row_ptr + row) /*softmax_begin*/,
        ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
        softmax, ldg(grad_softmax_row_ptr + row) /*grad_softmax_begin*/,
        ldg(grad_softmax_row_ptr + row + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }
}

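// 3-D softmax gradient: shared memory holds two sets of batch pointers back
// to back, the first (batch_size + 1) entries for the softmax matrix and the
// next (batch_size + 1) entries for grad_softmax, mirroring the layout built
// in CSRSparseMatrixSoftmaxGradGPUImpl below.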
template <typename T>
__global__ void CSRSparseMatrixSoftmaxGradKernel3D(
    const int size, const int rows,
    GpuDeviceArrayStruct<int> softmax_and_grad_batch_ptr_s,
    const int* softmax_row_ptr, const int* softmax_col_ind, const T* softmax,
    const int* grad_softmax_row_ptr, const int* grad_softmax_col_ind,
    const T* grad_softmax, T* gradient) {
  // TODO(ebrevdo): consider something like a merge-path based
  // algorithm to distribute the work in case the row sizes are
  // uneven:
  // http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf

  const int batch_size = size / rows;
  extern __shared__ int local_batch_ptr[];
  CopyFromGpuDeviceArrayToLocal(std::move(softmax_and_grad_batch_ptr_s),
                                local_batch_ptr, 2 * (batch_size + 1));

#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i];
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i];

  GPU_1D_KERNEL_LOOP(i, size) {
    const int batch = i / rows;
    const int row = i % rows;
    const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
    const int grad_softmax_batch_offset = GRAD_SOFTMAX_BATCH_PTR(batch);
    const int row_offset = batch * (rows + 1) + row;
    CalculateRowSoftmaxGrad(
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset) /*softmax_begin*/,
        softmax_batch_offset +
            ldg(softmax_row_ptr + row_offset + 1) /*softmax_end*/,
        softmax_col_ind, softmax,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset) /*grad_softmax_begin*/,
        grad_softmax_batch_offset +
            ldg(grad_softmax_row_ptr + row_offset + 1) /*grad_softmax_end*/,
        grad_softmax_col_ind, grad_softmax, gradient);
  }

#undef SOFTMAX_BATCH_PTR
#undef GRAD_SOFTMAX_BATCH_PTR
}

template <typename T>
Status CSRSparseMatrixSoftmaxGradGPUImpl(
    OpKernelContext* ctx, const CSRSparseMatrix& softmax,
    const CSRSparseMatrix& grad_softmax,
    typename TTypes<T>::Vec gradient_values) {
  auto host_dense_shape = softmax.dense_shape().vec<int64_t>();
  auto softmax_host_batch_ptr = softmax.batch_pointers().vec<int32>();
  auto softmax_row_ptr = softmax.row_pointers().vec<int32>();
  auto softmax_col_ind = softmax.col_indices().vec<int32>();
  auto softmax_values = softmax.values().vec<T>();
  auto grad_softmax_host_batch_ptr = grad_softmax.batch_pointers().vec<int32>();
  auto grad_softmax_row_ptr = grad_softmax.row_pointers().vec<int32>();
  auto grad_softmax_col_ind = grad_softmax.col_indices().vec<int32>();
  auto grad_softmax_values = grad_softmax.values().vec<T>();

  const int ndims = host_dense_shape.size();
  DCHECK(ndims == 2 || ndims == 3);
  const int rows = host_dense_shape(0);
  const GPUDevice& d = ctx->eigen_device<GPUDevice>();
  if (ndims == 2) {
    DCHECK_EQ(rows + 1, softmax_row_ptr.size());
    DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
    GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel2D<T>, config.block_count,
        config.thread_per_block, 0, d.stream(), rows /*size*/,
        softmax_row_ptr.data(), softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  } else {
    const int batch_size = host_dense_shape(0);
    const int rows = host_dense_shape(1);
    DCHECK_EQ(batch_size, softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ(batch_size, grad_softmax_host_batch_ptr.size() - 1);
    DCHECK_EQ((rows + 1) * batch_size, softmax_row_ptr.size());
    DCHECK_EQ((rows + 1) * batch_size, grad_softmax_row_ptr.size());
    const int size = rows * batch_size;
    // The length of softmax_and_grad_batch_ptr_copy is 2 * (batch_size + 1).
    // The first (batch_size + 1) entries contain softmax_batch_ptr and
    // the second (batch_size + 1) entries contain grad_softmax_batch_ptr.
    GpuDeviceArrayOnHost<int> softmax_and_grad_batch_ptr_copy(
        ctx, 2 * softmax_host_batch_ptr.size());
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Init());
    for (int i = 0; i < softmax_host_batch_ptr.size(); ++i) {
      softmax_and_grad_batch_ptr_copy.Set(i, softmax_host_batch_ptr(i));
      softmax_and_grad_batch_ptr_copy.Set(batch_size + 1 + i,
                                          grad_softmax_host_batch_ptr(i));
    }
    TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Finalize());

    GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
    // shared memory stores two copies of batch pointers: one for the
    // softmax CSR matrix, one for the grad_softmax CSR matrix.
    const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
    TF_CHECK_OK(GpuLaunchKernel(
        CSRSparseMatrixSoftmaxGradKernel3D<T>, config.block_count,
        config.thread_per_block, shared_memory_size, d.stream(), size, rows,
        softmax_and_grad_batch_ptr_copy.data(), softmax_row_ptr.data(),
        softmax_col_ind.data(), softmax_values.data(),
        grad_softmax_row_ptr.data(), grad_softmax_col_ind.data(),
        grad_softmax_values.data(), gradient_values.data()));
  }

  return Status::OK();
}

#define DEFINE_SOFTMAX_GRAD_GPU(T)                                          \
  template <>                                                               \
  Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()(              \
      OpKernelContext* ctx, const CSRSparseMatrix& softmax,                 \
      const CSRSparseMatrix& grad_softmax,                                  \
      typename TTypes<T>::Vec gradient_values) {                            \
    return CSRSparseMatrixSoftmaxGradGPUImpl<T>(ctx, softmax, grad_softmax, \
                                                gradient_values);           \
  }

DEFINE_SOFTMAX_GRAD_GPU(float);
DEFINE_SOFTMAX_GRAD_GPU(double);

#undef DEFINE_SOFTMAX_GRAD_GPU

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM