xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/sparse_concat_op_gpu.cu.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
17 
18 #define EIGEN_USE_GPU
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/kernels/gpu_device_array.h"
25 #include "tensorflow/core/kernels/gpu_device_array_gpu.h"
26 #include "tensorflow/core/kernels/gpu_prim_helpers.h"
27 #include "tensorflow/core/kernels/sparse_concat_op.h"
28 #include "tensorflow/core/lib/core/bits.h"
29 #include "tensorflow/core/util/gpu_kernel_helper.h"
30 
31 namespace tensorflow {
32 
33 typedef Eigen::GpuDevice GPUDevice;
34 
35 namespace functor {
36 
37 namespace {
38 
// First pass of the sparse concat. Each iteration of the grid range handles
// one output non-zero `nz` in [0, output_nnz).
//
// The input that owns `nz` is located with an upper-bound search over
// `nnz_scan_data` (prefix sums of the per-input non-zero counts, as built by
// the host code), and `concat_size_scan_data` supplies the offset that is
// added to the index along `concat_dim`.
//
// Two mutually exclusive modes, selected by `need_to_sort`:
//   * need_to_sort == false: writes the final `output_inds` (rank entries per
//     non-zero) and `output_vals` directly. `output_flat_inds` and
//     `output_shape_data` are not read or written.
//   * need_to_sort == true: writes only `output_flat_inds[nz]`, the row-major
//     flattened output index computed with `output_shape_data`.
//     `output_inds` and `output_vals` are not touched.
template <typename T>
__global__ void SparseConcatKernel(
    int64 output_nnz, int rank, int concat_dim, bool need_to_sort,
    GpuDeviceArrayStruct<const int64*> ind_ptrs_data,
    GpuDeviceArrayStruct<const T*> val_ptrs_data,
    GpuDeviceArrayStruct<int64_t> nnz_scan_data,
    GpuDeviceArrayStruct<int64_t> concat_size_scan_data,
    GpuDeviceArrayStruct<int64_t> output_shape_data,
    int64* __restrict__ output_inds, T* __restrict__ output_vals,
    int64* __restrict__ output_flat_inds) {
  const int64* __restrict__* __restrict__ ind_ptrs =
      GetGpuDeviceArrayOnDevice(&ind_ptrs_data);
  const T* __restrict__* __restrict__ val_ptrs =
      GetGpuDeviceArrayOnDevice(&val_ptrs_data);
  const int64* __restrict__ nnz_scan =
      GetGpuDeviceArrayOnDevice(&nnz_scan_data);
  const int64* __restrict__ concat_size_scan =
      GetGpuDeviceArrayOnDevice(&concat_size_scan_data);
  const int64* __restrict__ output_shape =
      GetGpuDeviceArrayOnDevice(&output_shape_data);
  const int64 num_inputs = ind_ptrs_data.size;

  for (int64 nz : GpuGridRangeX<int64_t>(output_nnz)) {
    // Map the global non-zero number to (input, non-zero within that input).
    const int64 input_num =
        gpu_helper::upper_bound<int64_t>(nnz_scan, num_inputs, nz) - 1;
    const int64 input_nz = nz - nnz_scan[input_num];
    const int64 ind_offset = concat_size_scan[input_num];
    // Start of this non-zero's index tuple in its input.
    const int64* input_inds = ind_ptrs[input_num] + input_nz * rank;

    if (need_to_sort) {
      // Accumulate the flattened index in a register and store it once,
      // instead of issuing a global store on every dimension.
      int64 flat_ind = 0;
      for (int j = 0; j < rank; ++j) {
        const int64 output_ind =
            input_inds[j] + (j == concat_dim ? ind_offset : 0);
        flat_ind = flat_ind * output_shape[j] + output_ind;
      }
      output_flat_inds[nz] = flat_ind;
    } else {
      output_vals[nz] = val_ptrs[input_num][input_nz];
      for (int j = 0; j < rank; ++j) {
        output_inds[nz * rank + j] =
            input_inds[j] + (j == concat_dim ? ind_offset : 0);
      }
    }
  }
}
82 
// Second pass of the sparse concat, launched only on the sorting path. For
// each output position it looks up the source non-zero through `permutation`,
// gathers that non-zero's value from the owning input, and expands its
// flattened index (read from `output_flat_inds`) back into `rank`
// per-dimension indices using the sizes in `output_shape_data`.
template <typename T>
__global__ void SparseConcatPermuteKernel(
    int64 output_nnz, int rank, GpuDeviceArrayStruct<const T*> val_ptrs_data,
    GpuDeviceArrayStruct<int64_t> nnz_scan_data,
    GpuDeviceArrayStruct<int64_t> output_shape_data,
    const int64* __restrict__ output_flat_inds,
    const int64* __restrict__ permutation, int64* __restrict__ output_inds,
    T* __restrict__ output_vals) {
  const int64 num_inputs = val_ptrs_data.size;
  const T* __restrict__* __restrict__ val_ptrs =
      GetGpuDeviceArrayOnDevice(&val_ptrs_data);
  const int64* __restrict__ nnz_scan =
      GetGpuDeviceArrayOnDevice(&nnz_scan_data);
  const int64* __restrict__ output_shape =
      GetGpuDeviceArrayOnDevice(&output_shape_data);

  for (int64 dst : GpuGridRangeX<int64_t>(output_nnz)) {
    // Position of this element in the unsorted (concatenated) order.
    const int64 src = permutation[dst];

    // Find which input the source element came from and its offset there.
    const int64 src_input =
        gpu_helper::upper_bound<int64_t>(nnz_scan, num_inputs, src) - 1;
    const int64 src_offset = src - nnz_scan[src_input];
    output_vals[dst] = val_ptrs[src_input][src_offset];

    // Peel dimensions off the flattened index, innermost dimension first.
    int64 remainder = output_flat_inds[src];
    int64* dst_row = output_inds + dst * rank;
    for (int dim = rank; dim-- > 0;) {
      const int64 dim_size = output_shape[dim];
      dst_row[dim] = remainder % dim_size;
      remainder /= dim_size;
    }
  }
}
113 
114 }  // namespace
115 
// GPU specialization of the sparse concat functor.
//
// Reads the per-input index matrices (`inds`), value vectors (`vals`) and
// dense shapes (`shapes`), and allocates and fills op outputs 0 (indices,
// shape {output_nnz, rank}) and 1 (values, shape {output_nnz}). Failures are
// reported through `context` via OP_REQUIRES_OK, after which the call returns
// early.
template <typename T>
struct SparseConcatFunctor<GPUDevice, T> {
  void operator()(OpKernelContext* context, const OpInputList& inds,
                  const OpInputList& vals, const OpInputList& shapes,
                  int concat_dim) {
    const int N = inds.size();
    const TensorShape input_shape0(shapes[0].vec<int64_t>());
    const int rank = input_shape0.dims();

    // The input non-zeros are assumed to be sorted by increasing dimension
    // number (i.e., row-major order), so if the concatenation is along the
    // first dimension then they remain in order and we can directly compute the
    // output indices and values. To concatenate along other dimensions, we
    // first compute the flattened (1D) row-major output indices, then sort
    // these to obtain the required permutation, and finally gather the permuted
    // input values.

    // Per-input device pointers, plus exclusive prefix sums (N + 1 entries,
    // starting at 0) of the non-zero counts and of the sizes along
    // `concat_dim`.
    GpuDeviceArrayOnHost<const int64*> ind_ptrs(context, N);
    GpuDeviceArrayOnHost<const T*> val_ptrs(context, N);
    GpuDeviceArrayOnHost<int64_t> nnz_scan(context, N + 1);
    GpuDeviceArrayOnHost<int64_t> concat_size_scan(context, N + 1);
    OP_REQUIRES_OK(context, ind_ptrs.Init());
    OP_REQUIRES_OK(context, val_ptrs.Init());
    OP_REQUIRES_OK(context, nnz_scan.Init());
    OP_REQUIRES_OK(context, concat_size_scan.Init());
    int64 nnz_sum = 0;
    int64 concat_size_sum = 0;
    nnz_scan.Set(0, nnz_sum);
    concat_size_scan.Set(0, concat_size_sum);
    for (int i = 0; i < N; ++i) {
      ind_ptrs.Set(i, inds[i].matrix<int64_t>().data());
      val_ptrs.Set(i, vals[i].vec<T>().data());
      nnz_sum += inds[i].dim_size(0);
      nnz_scan.Set(i + 1, nnz_sum);
      const TensorShape current_shape(shapes[i].vec<int64_t>());
      concat_size_sum += current_shape.dim_size(concat_dim);
      concat_size_scan.Set(i + 1, concat_size_sum);
    }
    OP_REQUIRES_OK(context, ind_ptrs.Finalize());
    OP_REQUIRES_OK(context, val_ptrs.Finalize());
    OP_REQUIRES_OK(context, nnz_scan.Finalize());
    OP_REQUIRES_OK(context, concat_size_scan.Finalize());
    const int64 output_nnz = nnz_sum;
    const int64 output_concat_size = concat_size_sum;

    if (output_nnz == 0) {
      // Nothing to compute: produce the empty outputs and stop here.
      // GetGpuLaunchConfig expects a strictly positive work element count
      // (see gpu_launch_config.h), and the radix-sort bit count below is
      // derived from the dense output size, which may be zero in this case.
      int64* empty_inds_ptr = nullptr;
      T* empty_vals_ptr = nullptr;
      OP_REQUIRES_OK(context,
                     allocate_outputs(context, rank, output_nnz,
                                      &empty_inds_ptr, &empty_vals_ptr));
      return;
    }

    const bool need_to_sort = concat_dim != 0;

    // The dense output shape is only needed (and only uploaded) when sorting.
    GpuDeviceArrayOnHost<int64_t> output_shape(context, rank);
    int64 output_dense_elements = 1;
    if (need_to_sort) {
      OP_REQUIRES_OK(context, output_shape.Init());
      for (int j = 0; j < rank; ++j) {
        int64 output_dim_size =
            j == concat_dim ? output_concat_size : input_shape0.dim_size(j);
        output_shape.Set(j, output_dim_size);
        output_dense_elements *= output_dim_size;
      }
      OP_REQUIRES_OK(context, output_shape.Finalize());
    }

    int64* output_inds_ptr = nullptr;
    T* output_vals_ptr = nullptr;
    int64* output_flat_inds_ptr = nullptr;
    Tensor output_flat_inds;
    if (need_to_sort) {
      // SparseConcatKernel will (only) produce output_flat_inds.
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DT_INT64, TensorShape({output_nnz}),
                                            &output_flat_inds));
      output_flat_inds_ptr = output_flat_inds.vec<int64_t>().data();
    } else {
      OP_REQUIRES_OK(
          context, allocate_outputs(context, rank, output_nnz, &output_inds_ptr,
                                    &output_vals_ptr));
    }

    const GPUDevice& device = context->eigen_gpu_device();

    GpuLaunchConfig config = GetGpuLaunchConfig(
        output_nnz, device, &SparseConcatKernel<T>,
        /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
    OP_REQUIRES_OK(
        context, GpuLaunchKernel(
                     SparseConcatKernel<T>, config.block_count,
                     config.thread_per_block, 0, device.stream(), output_nnz,
                     rank, concat_dim, need_to_sort, ind_ptrs.data(),
                     val_ptrs.data(), nnz_scan.data(), concat_size_scan.data(),
                     (need_to_sort ? output_shape.data()
                                   : GpuDeviceArrayStruct<int64_t>()),
                     output_inds_ptr, output_vals_ptr, output_flat_inds_ptr));

    // Concatenation along dimension 0: the first kernel already wrote the
    // final indices and values.
    if (!need_to_sort) return;

    OP_REQUIRES_OK(context,
                   allocate_outputs(context, rank, output_nnz, &output_inds_ptr,
                                    &output_vals_ptr));

    // Sort the flattened indices; only the resulting permutation is kept
    // (keys_out is null), and it is applied by SparseConcatPermuteKernel.
    Tensor permutation;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({output_nnz}),
                                          &permutation));
    int64* permutation_ptr = permutation.vec<int64_t>().data();
    OP_REQUIRES_OK(
        context,
        GpuRadixSort(context, /*size=*/output_nnz,
                     /*keys_in=*/output_flat_inds_ptr,
                     /*keys_out=*/static_cast<int64*>(nullptr),
                     /*indices_in=*/static_cast<const int64*>(nullptr),
                     /*indices_out=*/permutation_ptr,
                     /*num_bits=*/Log2Ceiling64(output_dense_elements)));

    config = GetGpuLaunchConfig(
        output_nnz, device, &SparseConcatPermuteKernel<T>,
        /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
    OP_REQUIRES_OK(
        context,
        GpuLaunchKernel(SparseConcatPermuteKernel<T>, config.block_count,
                        config.thread_per_block, 0, device.stream(), output_nnz,
                        rank, val_ptrs.data(), nnz_scan.data(),
                        output_shape.data(), output_flat_inds_ptr,
                        permutation_ptr, output_inds_ptr, output_vals_ptr));
  }

 private:
  // Allocates op output 0 with shape {output_nnz, rank} and op output 1 with
  // shape {output_nnz}, and stores their data pointers in `*output_inds_ptr`
  // and `*output_vals_ptr`. Returns the first allocation error, if any.
  Status allocate_outputs(OpKernelContext* context, int rank, int64 output_nnz,
                          int64** output_inds_ptr, T** output_vals_ptr) const {
    Tensor* output_inds = nullptr;
    TF_RETURN_IF_ERROR(context->allocate_output(
        0, TensorShape({output_nnz, rank}), &output_inds));
    *output_inds_ptr = output_inds->matrix<int64_t>().data();
    Tensor* output_vals = nullptr;
    TF_RETURN_IF_ERROR(
        context->allocate_output(1, TensorShape({output_nnz}), &output_vals));
    *output_vals_ptr = output_vals->vec<T>().data();
    return Status::OK();
  }
};
254 
// Explicitly instantiate the GPU functor once for each type enumerated by
// TF_CALL_POD_TYPES. The helper macro is local to this file and is removed
// again right after use.
#define INSTANTIATE_GPU_SPARSE_CONCAT_FUNCTOR(T) \
  template struct SparseConcatFunctor<GPUDevice, T>;
TF_CALL_POD_TYPES(INSTANTIATE_GPU_SPARSE_CONCAT_FUNCTOR);

#undef INSTANTIATE_GPU_SPARSE_CONCAT_FUNCTOR
260 
261 }  // namespace functor
262 
263 }  // namespace tensorflow
264 
265 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
266