// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/platform/random.h"

25 namespace xla {
26 namespace {
27 
28 // Regression test that verifies that substreams of a multistream GPU
29 // computation wait for the inputs to be produced before executing.
TEST(GpuMultiStream,Basics)30 TEST(GpuMultiStream, Basics) {
31   TF_ASSERT_OK_AND_ASSIGN(
32       std::unique_ptr<PjRtClient> client,
33       GetGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(),
34                    /*distributed_client=*/nullptr, /*node_id=*/0));
35 
36   PjRtDevice* device = client->addressable_devices().at(0);
37 
38   int n = 1024;
39   Shape shape = ShapeUtil::MakeShape(S32, {n});
40   std::vector<int32_t> inputs(n);
41   std::vector<int32_t> expected_outputs(n);
42 
43   XlaBuilder builder("acomputation");
44   auto p0 = Parameter(&builder, 0, shape, "param");
45   auto p1 = Parameter(&builder, 1, shape, "param");
46   Tuple(&builder, {Neg(p0), Neg(p1)});
47   TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build());
48 
49   CompileOptions compile_options;
50   compile_options.executable_build_options.mutable_debug_options()
51       ->set_xla_gpu_disable_multi_streaming(false);
52   compile_options.executable_build_options.mutable_debug_options()
53       ->set_xla_gpu_use_random_streams(true);
54   DeviceAssignment device_assignment(1, 1);
55   device_assignment(0, 0) = device->id();
56   compile_options.executable_build_options.set_device_assignment(
57       device_assignment);
58   TF_ASSERT_OK_AND_ASSIGN(
59       std::unique_ptr<PjRtLoadedExecutable> executable,
60       client->Compile(computation, std::move(compile_options)));
61 
62   int64_t dummy_size = 1 << 20;
63   std::vector<int32_t> dummy_inputs(dummy_size);
64   Shape dummy_shape = ShapeUtil::MakeShape(S32, {dummy_size});
65 
66   for (int i = 0; i < 100; ++i) {
67     for (int i = 0; i < n; ++i) {
68       inputs[i] = tensorflow::random::New64();
69       expected_outputs[i] = -inputs[i];
70     }
71     // Transfer a large dummy buffer, behind which the inputs to the computation
72     // must wait.
73     TF_ASSERT_OK_AND_ASSIGN(
74         auto dummy_buffer,
75         client->BufferFromHostBuffer(
76             dummy_inputs.data(), S32, dummy_shape.dimensions(),
77             /*byte_strides=*/std::nullopt,
78             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
79             /*on_done_with_host_buffer=*/nullptr, device));
80     TF_ASSERT_OK_AND_ASSIGN(
81         auto in_buffer0,
82         client->BufferFromHostBuffer(
83             inputs.data(), S32, shape.dimensions(),
84             /*byte_strides=*/std::nullopt,
85             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
86             /*on_done_with_host_buffer=*/nullptr, device));
87     TF_ASSERT_OK_AND_ASSIGN(
88         auto in_buffer1,
89         client->BufferFromHostBuffer(
90             inputs.data(), S32, shape.dimensions(),
91             /*byte_strides=*/std::nullopt,
92             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
93             /*on_done_with_host_buffer=*/nullptr, device));
94     // The execution may be enqueued before the transfers complete, requiring
95     // adequate device-side synchronization.
96     ExecuteOptions options;
97     options.untuple_result = true;
98     TF_ASSERT_OK_AND_ASSIGN(
99         auto out_buffers,
100         executable->Execute({{in_buffer0.get(), in_buffer1.get()}}, options));
101 
102     TF_ASSERT_OK_AND_ASSIGN(auto out_literal,
103                             out_buffers[0][0]->ToLiteralSync());
104     LiteralTestUtil::ExpectR1Equal<int32_t>(expected_outputs, *out_literal);
105     TF_ASSERT_OK_AND_ASSIGN(out_literal, out_buffers[0][1]->ToLiteralSync());
106     LiteralTestUtil::ExpectR1Equal<int32_t>(expected_outputs, *out_literal);
107   }
108 }
109 
110 }  // namespace
111 }  // namespace xla
112