/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/kernels/gpu_prim_helpers.h"
#include "tensorflow/core/kernels/sparse_concat_op.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

namespace {

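// Concatenates the nonzeros of the inputs. When concatenating along
// dimension 0 (need_to_sort == false) the inputs are already in row-major
// order, so the final output indices and values are written directly.
// Otherwise, only the flattened (1D) row-major output indices are computed
// (output_flat_inds); these are sorted afterwards to obtain a permutation.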
template <typename T>
__global__ void SparseConcatKernel(
    int64 output_nnz, int rank, int concat_dim, bool need_to_sort,
    GpuDeviceArrayStruct<const int64*> ind_ptrs_data,
    GpuDeviceArrayStruct<const T*> val_ptrs_data,
    GpuDeviceArrayStruct<int64_t> nnz_scan_data,
    GpuDeviceArrayStruct<int64_t> concat_size_scan_data,
    GpuDeviceArrayStruct<int64_t> output_shape_data,
    int64* __restrict__ output_inds, T* __restrict__ output_vals,
    int64* __restrict__ output_flat_inds) {
  const int64* __restrict__* __restrict__ ind_ptrs =
      GetGpuDeviceArrayOnDevice(&ind_ptrs_data);
  const T* __restrict__* __restrict__ val_ptrs =
      GetGpuDeviceArrayOnDevice(&val_ptrs_data);
  const int64* __restrict__ nnz_scan =
      GetGpuDeviceArrayOnDevice(&nnz_scan_data);
  const int64* __restrict__ concat_size_scan =
      GetGpuDeviceArrayOnDevice(&concat_size_scan_data);
  const int64* __restrict__ output_shape =
      GetGpuDeviceArrayOnDevice(&output_shape_data);
  const int64 num_inputs = ind_ptrs_data.size;

  for (int64 nz : GpuGridRangeX<int64_t>(output_nnz)) {
    const int64 input_num =
        gpu_helper::upper_bound<int64_t>(nnz_scan, num_inputs, nz) - 1;
    const int64 input_nz = nz - nnz_scan[input_num];
    const int64 ind_offset = concat_size_scan[input_num];
    if (!need_to_sort) {
      output_vals[nz] = val_ptrs[input_num][input_nz];
    }
    int64 flat_ind = 0;
    for (int j = 0; j < rank; ++j) {
      const int64 output_ind = ind_ptrs[input_num][input_nz * rank + j] +
                               (j == concat_dim ? ind_offset : 0);
      if (!need_to_sort) {
        output_inds[nz * rank + j] = output_ind;
      } else {
        flat_ind = flat_ind * output_shape[j] + output_ind;
        output_flat_inds[nz] = flat_ind;
      }
    }
  }
}

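// Given the sorting permutation of the flattened output indices, gathers the
// input values in sorted order and unflattens each flattened index back into
// `rank` output coordinates.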
template <typename T>
__global__ void SparseConcatPermuteKernel(
    int64 output_nnz, int rank, GpuDeviceArrayStruct<const T*> val_ptrs_data,
    GpuDeviceArrayStruct<int64_t> nnz_scan_data,
    GpuDeviceArrayStruct<int64_t> output_shape_data,
    const int64* __restrict__ output_flat_inds,
    const int64* __restrict__ permutation, int64* __restrict__ output_inds,
    T* __restrict__ output_vals) {
  const T* __restrict__* __restrict__ val_ptrs =
      GetGpuDeviceArrayOnDevice(&val_ptrs_data);
  const int64* __restrict__ nnz_scan =
      GetGpuDeviceArrayOnDevice(&nnz_scan_data);
  const int64* __restrict__ output_shape =
      GetGpuDeviceArrayOnDevice(&output_shape_data);
  const int64 num_inputs = val_ptrs_data.size;

  for (int64 nz : GpuGridRangeX<int64_t>(output_nnz)) {
    const int64 permuted_nz = permutation[nz];
    const int64 input_num =
        gpu_helper::upper_bound<int64_t>(nnz_scan, num_inputs, permuted_nz) - 1;
    const int64 input_nz = permuted_nz - nnz_scan[input_num];
    output_vals[nz] = val_ptrs[input_num][input_nz];
    int64 output_flat_ind = output_flat_inds[permuted_nz];
    for (int j = rank - 1; j >= 0; --j) {
      const int64 output_dim_size = output_shape[j];
      output_inds[nz * rank + j] = output_flat_ind % output_dim_size;
      output_flat_ind /= output_dim_size;
    }
  }
}

}  // namespace

template <typename T>
struct SparseConcatFunctor<GPUDevice, T> {
  void operator()(OpKernelContext* context, const OpInputList& inds,
                  const OpInputList& vals, const OpInputList& shapes,
                  int concat_dim) {
    const int N = inds.size();
    const TensorShape input_shape0(shapes[0].vec<int64_t>());
    const int rank = input_shape0.dims();

    // The input non-zeros are assumed to be sorted by increasing dimension
    // number (i.e., row-major order), so if the concatenation is along the
    // first dimension then they remain in order and we can directly compute
    // the output indices and values. To concatenate along other dimensions,
    // we first compute the flattened (1D) row-major output indices, then sort
    // these to obtain the required permutation, and finally gather the
    // permuted input values.

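    // Per-input device pointers to indices and values, plus exclusive scans
    // of the per-input nonzero counts and concat-dimension sizes; the kernels
    // use these to map a global nonzero index back to its originating input.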
    GpuDeviceArrayOnHost<const int64*> ind_ptrs(context, N);
    GpuDeviceArrayOnHost<const T*> val_ptrs(context, N);
    GpuDeviceArrayOnHost<int64_t> nnz_scan(context, N + 1);
    GpuDeviceArrayOnHost<int64_t> concat_size_scan(context, N + 1);
    OP_REQUIRES_OK(context, ind_ptrs.Init());
    OP_REQUIRES_OK(context, val_ptrs.Init());
    OP_REQUIRES_OK(context, nnz_scan.Init());
    OP_REQUIRES_OK(context, concat_size_scan.Init());
    int64 nnz_sum = 0;
    int64 concat_size_sum = 0;
    nnz_scan.Set(0, nnz_sum);
    concat_size_scan.Set(0, concat_size_sum);
    for (int i = 0; i < N; ++i) {
      ind_ptrs.Set(i, inds[i].matrix<int64_t>().data());
      val_ptrs.Set(i, vals[i].vec<T>().data());
      nnz_sum += inds[i].dim_size(0);
      nnz_scan.Set(i + 1, nnz_sum);
      const TensorShape current_shape(shapes[i].vec<int64_t>());
      concat_size_sum += current_shape.dim_size(concat_dim);
      concat_size_scan.Set(i + 1, concat_size_sum);
    }
    OP_REQUIRES_OK(context, ind_ptrs.Finalize());
    OP_REQUIRES_OK(context, val_ptrs.Finalize());
    OP_REQUIRES_OK(context, nnz_scan.Finalize());
    OP_REQUIRES_OK(context, concat_size_scan.Finalize());
    const int64 output_nnz = nnz_sum;
    const int64 output_concat_size = concat_size_sum;

    const bool need_to_sort = concat_dim != 0;

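    // When sorting is needed, also build the dense output shape on the device
    // and compute the total number of dense elements, which bounds the number
    // of bits required for the radix sort of the flattened indices.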
    GpuDeviceArrayOnHost<int64_t> output_shape(context, rank);
    int64 output_dense_elements;
    if (need_to_sort) {
      OP_REQUIRES_OK(context, output_shape.Init());
      output_dense_elements = 1;
      for (int j = 0; j < rank; ++j) {
        int64 output_dim_size =
            j == concat_dim ? output_concat_size : input_shape0.dim_size(j);
        output_shape.Set(j, output_dim_size);
        output_dense_elements *= output_dim_size;
      }
      OP_REQUIRES_OK(context, output_shape.Finalize());
    }

    int64* output_inds_ptr = nullptr;
    T* output_vals_ptr = nullptr;
    int64* output_flat_inds_ptr = nullptr;
    Tensor output_flat_inds;
    if (need_to_sort) {
      // SparseConcatKernel will (only) produce output_flat_inds.
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DT_INT64, TensorShape({output_nnz}),
                                            &output_flat_inds));
      output_flat_inds_ptr = output_flat_inds.vec<int64_t>().data();
    } else {
      OP_REQUIRES_OK(
          context, allocate_outputs(context, rank, output_nnz, &output_inds_ptr,
                                    &output_vals_ptr));
    }

    const GPUDevice& device = context->eigen_gpu_device();

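    // Launch SparseConcatKernel over all output nonzeros. If sorting is
    // needed it only writes output_flat_inds; otherwise it writes the final
    // output indices and values directly.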
    GpuLaunchConfig config = GetGpuLaunchConfig(
        output_nnz, device, &SparseConcatKernel<T>,
        /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
    OP_REQUIRES_OK(
        context, GpuLaunchKernel(
                     SparseConcatKernel<T>, config.block_count,
                     config.thread_per_block, 0, device.stream(), output_nnz,
                     rank, concat_dim, need_to_sort, ind_ptrs.data(),
                     val_ptrs.data(), nnz_scan.data(), concat_size_scan.data(),
                     (need_to_sort ? output_shape.data()
                                   : GpuDeviceArrayStruct<int64_t>()),
                     output_inds_ptr, output_vals_ptr, output_flat_inds_ptr));

    if (!need_to_sort) return;

    OP_REQUIRES_OK(context,
                   allocate_outputs(context, rank, output_nnz, &output_inds_ptr,
                                    &output_vals_ptr));

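    // Sort the flattened output indices to obtain the permutation that puts
    // the output nonzeros into row-major order. Only the permutation is
    // needed, so the sorted keys themselves are discarded.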
    Tensor permutation;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({output_nnz}),
                                          &permutation));
    int64* permutation_ptr = permutation.vec<int64_t>().data();
    OP_REQUIRES_OK(
        context,
        GpuRadixSort(context, /*size=*/output_nnz,
                     /*keys_in=*/output_flat_inds_ptr,
                     /*keys_out=*/static_cast<int64*>(nullptr),
                     /*indices_in=*/static_cast<const int64*>(nullptr),
                     /*indices_out=*/permutation_ptr,
                     /*num_bits=*/Log2Ceiling64(output_dense_elements)));

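    // Gather the values and unflatten the indices according to the sorted
    // permutation.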
    config = GetGpuLaunchConfig(
        output_nnz, device, &SparseConcatPermuteKernel<T>,
        /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
    OP_REQUIRES_OK(
        context,
        GpuLaunchKernel(SparseConcatPermuteKernel<T>, config.block_count,
                        config.thread_per_block, 0, device.stream(), output_nnz,
                        rank, val_ptrs.data(), nnz_scan.data(),
                        output_shape.data(), output_flat_inds_ptr,
                        permutation_ptr, output_inds_ptr, output_vals_ptr));
  }

 private:
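  // Allocates the op's two outputs (indices of shape [output_nnz, rank] and
  // values of shape [output_nnz]) and returns raw pointers to their buffers.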
  Status allocate_outputs(OpKernelContext* context, int rank, int64 output_nnz,
                          int64** output_inds_ptr, T** output_vals_ptr) const {
    Tensor* output_inds = nullptr;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({output_nnz, rank}), &output_inds));
    *output_inds_ptr = output_inds->matrix<int64_t>().data();
    Tensor* output_vals = nullptr;
    TF_RETURN_IF_ERROR(
        context->allocate_output(1, TensorShape({output_nnz}), &output_vals));
    *output_vals_ptr = output_vals->vec<T>().data();
    return Status::OK();
  }
};

#define DEFINE_SPARSE_CONCAT_FUNCTOR(T) \
  template struct SparseConcatFunctor<GPUDevice, T>;
TF_CALL_POD_TYPES(DEFINE_SPARSE_CONCAT_FUNCTOR);

#undef DEFINE_SPARSE_CONCAT_FUNCTOR

}  // namespace functor

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM