/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/gpu_prim_helpers.h"

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

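// Test kernel wrapping GpuRadixSort: sorts keys_in (and, if provided, the
// accompanying indices_in) so the helper can be exercised through the normal
// op-testing machinery.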
template <typename Tkey, typename Tindex>
class TestGpuRadixSortKernel : public tensorflow::OpKernel {
 public:
  explicit TestGpuRadixSortKernel(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("need_keys_out", &need_keys_out_));
    OP_REQUIRES_OK(context, context->GetAttr("num_bits", &num_bits_));
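    // num_bits == -1 selects the default of sorting on all bits of the key.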
    if (num_bits_ == -1) {
      num_bits_ = sizeof(Tkey) * 8;
    }
  }

  void Compute(tensorflow::OpKernelContext* context) override {
    const Tensor& keys_in = context->input(0);
    const Tensor& indices_in = context->input(1);

    const Tkey* keys_in_data = keys_in.flat<Tkey>().data();
    const Tindex* indices_in_data = indices_in.NumElements() == 0
                                        ? nullptr
                                        : indices_in.flat<Tindex>().data();

    int64 size = keys_in.NumElements();

    Tkey* keys_out_data = nullptr;
    if (need_keys_out_) {
      Tensor* keys_out = nullptr;
      OP_REQUIRES_OK(
          context, context->allocate_output(0, TensorShape({size}), &keys_out));
      keys_out_data = keys_out->flat<Tkey>().data();
    }

    Tensor* indices_out;
    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({size}),
                                                     &indices_out));
    Tindex* indices_out_data = indices_out->flat<Tindex>().data();

    OP_REQUIRES_OK(context,
                   GpuRadixSort(context, size, keys_in_data, keys_out_data,
                                indices_in_data, indices_out_data, num_bits_));
  }

 private:
  bool need_keys_out_;
  int num_bits_;
};

REGISTER_OP("TestGpuRadixSort")
    .Input("keys_in: Tkey")
    .Input("indices_in: Tindex")
    .Output("keys_out: Tkey")
    .Output("indices_out: Tindex")
    .Attr("need_keys_out: bool = true")
    .Attr("num_bits: int = -1")
    .Attr("Tkey: type")
    .Attr("Tindex: type");
#define REGISTER_KERNELS(Tkey, Tindex)                           \
  REGISTER_KERNEL_BUILDER(Name("TestGpuRadixSort")               \
                              .Device(tensorflow::DEVICE_GPU)    \
                              .TypeConstraint<Tkey>("Tkey")      \
                              .TypeConstraint<Tindex>("Tindex"), \
                          TestGpuRadixSortKernel<Tkey, Tindex>)
REGISTER_KERNELS(float, int32);
REGISTER_KERNELS(int32, int32);
#undef REGISTER_KERNELS

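// Test kernel wrapping GpuInclusivePrefixSum: writes the running sum of the
// input tensor to an output of the same length.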
template <typename T>
class TestGpuInclusivePrefixSumKernel : public tensorflow::OpKernel {
 public:
  explicit TestGpuInclusivePrefixSumKernel(
      tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const T* input_data = input.flat<T>().data();
    int64 size = input.NumElements();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({size}), &output));
    T* output_data = output->flat<T>().data();

    OP_REQUIRES_OK(
        context, GpuInclusivePrefixSum(context, size, input_data, output_data));
  }
};

REGISTER_OP("TestGpuInclusivePrefixSum")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type");
#define REGISTER_KERNELS(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("TestGpuInclusivePrefixSum")   \
                              .Device(tensorflow::DEVICE_GPU) \
                              .TypeConstraint<T>("T"),        \
                          TestGpuInclusivePrefixSumKernel<T>)
REGISTER_KERNELS(int32);
#undef REGISTER_KERNELS

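// Test kernel wrapping GpuSegmentedReduce: reduces each segment described by
// segment_offsets with ReduceOp, seeding every segment with initial_value.
// N+1 offsets describe N segments; an empty segment reduces to the initial
// value.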
template <typename T, typename Toffset, typename ReduceOp>
class TestGpuSegmentedReduceKernel : public tensorflow::OpKernel {
 public:
  explicit TestGpuSegmentedReduceKernel(
      tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const T* input_data = input.flat<T>().data();
    const Tensor& segment_offsets = context->input(1);
    const Toffset* segment_offsets_data =
        segment_offsets.flat<Toffset>().data();
    int num_segments = segment_offsets.NumElements() - 1;
    const Tensor& initial_value_tensor = context->input(2);
    T initial_value = initial_value_tensor.scalar<T>()();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({num_segments}), &output));
    T* output_data = output->flat<T>().data();

    OP_REQUIRES_OK(
        context,
        GpuSegmentedReduce(context, num_segments, ReduceOp(), initial_value,
                           input_data, segment_offsets_data, output_data));
  }
};

REGISTER_OP("TestGpuSegmentedSum")
    .Input("input: T")
    .Input("segment_offsets: Toffset")
    .Input("initial_value: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Toffset: type");
#define REGISTER_KERNELS(T, Toffset)           \
  REGISTER_KERNEL_BUILDER(                     \
      Name("TestGpuSegmentedSum")              \
          .Device(tensorflow::DEVICE_GPU)      \
          .HostMemory("initial_value")         \
          .TypeConstraint<T>("T")              \
          .TypeConstraint<Toffset>("Toffset"), \
      TestGpuSegmentedReduceKernel<T, Toffset, gpuprim::Sum>)
REGISTER_KERNELS(int32, int32);
#undef REGISTER_KERNELS

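// Test kernel wrapping GpuSelectFlagged: compacts the elements of input whose
// flag is true into an output of statically known size, then verifies that
// the element count computed on the device matches the output_size attr.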
template <typename T>
class TestGpuSelectFlaggedKernel : public tensorflow::OpKernel {
 public:
  explicit TestGpuSelectFlaggedKernel(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("output_size", &output_size_));
  }

  void Compute(tensorflow::OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const T* input_data = input.flat<T>().data();
    const Tensor& flags = context->input(1);
    const bool* flags_data = flags.flat<bool>().data();

    int64_t input_size = input.dim_size(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({output_size_}), &output));
    T* output_data = output->flat<T>().data();

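    // GpuSelectFlagged reports the number of selected elements via device
    // memory, so stage it in a scalar temp and copy it back to the host below.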
    Tensor output_size_t;
    OP_REQUIRES_OK(context, context->allocate_temp(DT_INT64, TensorShape({}),
                                                   &output_size_t));
    int64_t* output_size_data = output_size_t.scalar<int64_t>().data();

    OP_REQUIRES_OK(context,
                   GpuSelectFlagged(context, input_size, input_data, flags_data,
                                    output_data, output_size_data));

    // Copy the computed output size to host and ensure it matches.
    se::Stream* stream = context->op_device_context()->stream();
    int64_t output_size_host;
    OP_REQUIRES(context,
                stream
                    ->ThenMemcpy(&output_size_host,
                                 se::DeviceMemoryBase(
                                     output_size_data,
                                     sizeof(*output_size_data)),
                                 sizeof(output_size_host))
                    .ok(),
                errors::Internal("Failed to copy output_size_gpu to host"));
    OP_REQUIRES_OK(context, stream->BlockHostUntilDone());
    OP_REQUIRES(context, output_size_host == output_size_,
                errors::Internal("Incorrect output size: expected ",
                                 output_size_, ", got ", output_size_host));
  }

 private:
  int64_t output_size_;
};

REGISTER_OP("TestGpuSelectFlagged")
    .Input("input: T")
    .Input("flags: bool")
    .Output("output: T")
    .Attr("T: type")
    .Attr("output_size: int");
#define REGISTER_KERNELS(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("TestGpuSelectFlagged")        \
                              .Device(tensorflow::DEVICE_GPU) \
                              .TypeConstraint<T>("T"),        \
                          TestGpuSelectFlaggedKernel<T>)
REGISTER_KERNELS(int32);
#undef REGISTER_KERNELS

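// Fixture that runs the test ops above on a GPU device and provides helpers
// for constructing each op's NodeDef.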
class GpuPrimHelpersTest : public OpsTestBase {
 protected:
  GpuPrimHelpersTest() {
    SetDevice(DEVICE_GPU,
              std::unique_ptr<tensorflow::Device>(DeviceFactory::NewDevice(
                  "GPU", {}, "/job:a/replica:0/task:0")));
  }

  void MakeRadixSort(DataType key_type, DataType index_type,
                     bool need_keys_out = true, int num_bits = -1) {
    TF_ASSERT_OK(NodeDefBuilder("test_op", "TestGpuRadixSort")
                     .Input(FakeInput(key_type))
                     .Input(FakeInput(index_type))
                     .Attr("need_keys_out", need_keys_out)
                     .Attr("num_bits", num_bits)
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

  void MakeInclusivePrefixSum(DataType type) {
    TF_ASSERT_OK(NodeDefBuilder("test_op", "TestGpuInclusivePrefixSum")
                     .Input(FakeInput(type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

  void MakeSegmentedSum(DataType type, DataType offset_type) {
    TF_ASSERT_OK(NodeDefBuilder("test_op", "TestGpuSegmentedSum")
                     .Input(FakeInput(type))
                     .Input(FakeInput(offset_type))
                     .Input(FakeInput(type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

  void MakeSelectFlagged(DataType type, int64 output_size) {
    TF_ASSERT_OK(NodeDefBuilder("test_op", "TestGpuSelectFlagged")
                     .Input(FakeInput(type))
                     .Input(FakeInput(DT_BOOL))
                     .Attr("output_size", output_size)
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

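// With an empty indices_in, GpuRadixSort fills indices_out with the sort
// permutation: indices_out[i] is the original position of the i-th smallest
// key.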
TEST_F(GpuPrimHelpersTest, GpuRadixSort_Keys) {
  MakeRadixSort(DT_FLOAT, DT_INT32);
  AddInputFromArray<float>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});  // keys
  AddInputFromArray<int32>(TensorShape({0}), {});                        // inds
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_keys_out(allocator(), DT_FLOAT, TensorShape({8}));
  test::FillValues<float>(&expected_keys_out, {0, 1, 2, 3, 4, 5, 6, 7});
  test::ExpectTensorEqual<float>(expected_keys_out, *GetOutput(0));

  Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_indices_out, {6, 4, 1, 5, 0, 7, 2, 3});
  test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
}

TEST_F(GpuPrimHelpersTest, GpuRadixSort_KeysAndIndices) {
  MakeRadixSort(DT_FLOAT, DT_INT32);
  AddInputFromArray<float>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});  // keys
  AddInputFromArray<int32>(TensorShape({8}), {7, 6, 5, 4, 3, 2, 1, 0});  // inds
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_keys_out(allocator(), DT_FLOAT, TensorShape({8}));
  test::FillValues<float>(&expected_keys_out, {0, 1, 2, 3, 4, 5, 6, 7});
  test::ExpectTensorEqual<float>(expected_keys_out, *GetOutput(0));

  Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_indices_out, {1, 3, 6, 2, 7, 0, 5, 4});
  test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
}

TEST_F(GpuPrimHelpersTest, GpuRadixSort_NoKeysOut) {
  MakeRadixSort(DT_FLOAT, DT_INT32, /*need_keys_out=*/false);
  AddInputFromArray<float>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});  // keys
  AddInputFromArray<int32>(TensorShape({0}), {});                        // inds
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_indices_out, {6, 4, 1, 5, 0, 7, 2, 3});
  test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
}

TEST_F(GpuPrimHelpersTest, GpuRadixSort_WithNumBits) {
  // Only sort by the lowest 2 bits, otherwise keep input order (stable sort).
  MakeRadixSort(DT_INT32, DT_INT32, /*need_keys_out=*/true, /*num_bits=*/2);
  AddInputFromArray<int32>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});  // keys
  AddInputFromArray<int32>(TensorShape({0}), {});                        // inds
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_keys_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_keys_out, {4, 0, 1, 5, 2, 6, 7, 3});
  test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput(0));

  Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_indices_out, {0, 6, 4, 7, 1, 2, 3, 5});
  test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
}

TEST_F(GpuPrimHelpersTest, GpuRadixSort_WithNumBitsZero) {
  // Check that num_bits=0 is handled correctly.
  MakeRadixSort(DT_INT32, DT_INT32, /*need_keys_out=*/true, /*num_bits=*/0);
  AddInputFromArray<int32>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});  // keys
  AddInputFromArray<int32>(TensorShape({0}), {});                        // inds
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_keys_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_keys_out, {4, 2, 6, 7, 1, 3, 0, 5});
  test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput(0));

  Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_indices_out, {0, 1, 2, 3, 4, 5, 6, 7});
  test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
}

TEST_F(GpuPrimHelpersTest, GpuRadixSort_KeysAndIndices_WithNumBitsZero) {
  // Check that num_bits=0 is handled correctly (with indices_in).
  MakeRadixSort(DT_INT32, DT_INT32, /*need_keys_out=*/true, /*num_bits=*/0);
  AddInputFromArray<int32>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});  // keys
  AddInputFromArray<int32>(TensorShape({8}), {7, 6, 5, 4, 3, 2, 1, 0});  // inds
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_keys_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_keys_out, {4, 2, 6, 7, 1, 3, 0, 5});
  test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput(0));

  Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_indices_out, {7, 6, 5, 4, 3, 2, 1, 0});
  test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
}

TEST_F(GpuPrimHelpersTest, GpuInclusivePrefixSum) {
  MakeInclusivePrefixSum(DT_INT32);
  AddInputFromArray<int32>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_output(allocator(), DT_INT32, TensorShape({8}));
  test::FillValues<int32>(&expected_output, {4, 6, 12, 19, 20, 23, 23, 28});
  test::ExpectTensorEqual<int32>(expected_output, *GetOutput(0));
}

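// The six offsets {1, 3, 4, 4, 8, 10} define five segments [1,3), [3,4),
// [4,4), [4,8) and [8,10); the empty segment [4,4) reduces to the initial
// value 0.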
TEST_F(GpuPrimHelpersTest, GpuSegmentedReduce_Sum) {
  MakeSegmentedSum(DT_INT32, DT_INT32);
  // Input.
  AddInputFromArray<int32>(TensorShape({10}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  // Segment offsets.
  AddInputFromArray<int32>(TensorShape({6}), {1, 3, 4, 4, 8, 10});
  // Initial value.
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_output(allocator(), DT_INT32, TensorShape({5}));
  test::FillValues<int32>(&expected_output, {3, 3, 0, 22, 17});
  test::ExpectTensorEqual<int32>(expected_output, *GetOutput(0));
}

TEST_F(GpuPrimHelpersTest, GpuSelectFlagged) {
  MakeSelectFlagged(DT_INT32, 3);
  // Input.
  AddInputFromArray<int32>(TensorShape({10}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  // Flags.
  AddInputFromArray<bool>(TensorShape({10}), {0, 0, 1, 0, 1, 0, 0, 1, 0, 0});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_output(allocator(), DT_INT32, TensorShape({3}));
  test::FillValues<int32>(&expected_output, {2, 4, 7});
  test::ExpectTensorEqual<int32>(expected_output, *GetOutput(0));
}

TEST_F(GpuPrimHelpersTest, GpuSelectFlagged_Empty) {
  MakeSelectFlagged(DT_INT32, 0);
  // Input.
  AddInputFromArray<int32>(TensorShape({0}), {});
  // Flags.
  AddInputFromArray<bool>(TensorShape({0}), {});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected_output(allocator(), DT_INT32, TensorShape({0}));
  test::FillValues<int32>(&expected_output, {});
  test::ExpectTensorEqual<int32>(expected_output, *GetOutput(0));
}

}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
435