/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <sstream>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/driver_types.h"
#define PLATFORM "CUDA"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#define PLATFORM "ROCM"
#endif
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"

#if GOOGLE_CUDA
#define gpuSuccess cudaSuccess
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#elif TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
#define gpuMemcpyAsync hipMemcpyAsync
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define gpuMemcpy hipMemcpy
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#endif

namespace xla {
namespace {

class CustomCallTest : public ClientLibraryTestBase {};

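// Global flag set by Callback_IsInvoked so the test below can verify that the
// registered custom-call target was actually executed.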
bool is_invoked_called = false;
void Callback_IsInvoked(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
                        const char* /*opaque*/, size_t /*opaque_len*/) {
  is_invoked_called = true;
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, PLATFORM);

TEST_F(CustomCallTest, IsInvoked) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_IsInvoked", /*operands=*/{},
             ShapeUtil::MakeShape(F32, {}),
             /*opaque=*/"");
  EXPECT_FALSE(is_invoked_called);
  TF_ASSERT_OK(Execute(&b, {}).status());
  EXPECT_TRUE(is_invoked_called);
}

TEST_F(CustomCallTest, UnknownTarget) {
  XlaBuilder b(TestName());
  CustomCall(&b, "UnknownTarget", /*operands=*/{},
             ShapeUtil::MakeShape(F32, {}),
             /*opaque=*/"");
  ASSERT_FALSE(Execute(&b, {}).ok());
}
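
// Copies the f32[128] operand buffer (buffers[0]) into the result buffer
// (buffers[1]) with an asynchronous device-to-device copy on the given stream.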
void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers,
                     const char* /*opaque*/, size_t /*opaque_len*/) {
  void* src = buffers[0];
  void* dst = buffers[1];
  auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
                            gpuMemcpyDeviceToDevice, stream);
  ASSERT_EQ(err, gpuSuccess);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM);

TEST_F(CustomCallTest, Memcpy) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_Memcpy",
             /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})},
             ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
  TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {}));
  EXPECT_THAT(result.data<float>(), ::testing::Each(42));
}

// Check that the opaque argument is passed through unchanged even when it
// contains embedded null bytes.
std::string& kExpectedOpaque = *new std::string("abc\0def", 7);
void Callback_Opaque(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
                     const char* opaque, size_t opaque_len) {
  std::string opaque_str(opaque, opaque_len);
  ASSERT_EQ(opaque_str, kExpectedOpaque);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, PLATFORM);

TEST_F(CustomCallTest, Opaque) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_Opaque", /*operands=*/{},
             ShapeUtil::MakeShape(F32, {}), kExpectedOpaque);
  TF_ASSERT_OK(Execute(&b, {}).status());
}

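// Exercises a custom call whose operands and result are nested tuples; the
// runtime flattens all tuple leaves into the single `buffers` array documented
// below.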
void Callback_SubBuffers(se::gpu::GpuStreamHandle stream, void** buffers,
                         const char* /*opaque*/, size_t /*opaque_len*/) {
  // `buffers` is a flat array containing device pointers to the following.
  //
  //  0: param 0 at tuple index {0}, shape f32[128]
  //  1: param 0 at tuple index {1}, shape f32[256]
  //  2: param 1 at tuple index {0}, shape f32[1024]
  //  3: param 1 at tuple index {1}, shape f32[8]
  //  4: result at tuple index {0}, shape f32[8]
  //  5: result at tuple index {1, 0}, shape f32[128]
  //  6: result at tuple index {1, 1}, shape f32[256]
  //  7: result at tuple index {2}, shape f32[1024]
  //

  // Set output leaf buffers, copying data from the corresponding same-sized
  // inputs.
  gpuMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
                 gpuMemcpyDeviceToDevice, stream);
  gpuMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
                 gpuMemcpyDeviceToDevice, stream);
  gpuMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
                 gpuMemcpyDeviceToDevice, stream);
  gpuMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
                 gpuMemcpyDeviceToDevice, stream);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, PLATFORM);

TEST_F(CustomCallTest, SubBuffers) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_SubBuffers", /*operands=*/
             {
                 Tuple(&b,
                       {
                           Broadcast(ConstantR0WithType(&b, F32, 1), {128}),
                           Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
                       }),
                 Tuple(&b,
                       {
                           Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
                           Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
                       }),
             },
             ShapeUtil::MakeTupleShape({
                 ShapeUtil::MakeShape(F32, {8}),
                 ShapeUtil::MakeTupleShape({
                     ShapeUtil::MakeShape(F32, {128}),
                     ShapeUtil::MakeShape(F32, {256}),
                 }),
                 ShapeUtil::MakeShape(F32, {1024}),
             }),
             /*opaque=*/"");
  TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {}));
  EXPECT_THAT(result.data<float>({0}), ::testing::Each(4));
  EXPECT_THAT(result.data<float>({1, 0}), ::testing::Each(1));
  EXPECT_THAT(result.data<float>({1, 1}), ::testing::Each(2));
  EXPECT_THAT(result.data<float>({2}), ::testing::Each(3));
}

// The token test cases encode the argument and result types of a custom call
// as strings built from 'A' (= array), 'T' (= token), and '{}' (= tuple). They
// also encode the check that the callback has to perform as a string of 'A'
// and 'T' characters, where every 'A' buffer must be non-null and every 'T'
// buffer must be null. That expectation string is passed to the custom call as
// its opaque data.
//
// For example, "ATTA" as an input encodes four inputs to the custom call, and
// "{A{A}T}" as an output encodes a return type that is a single tuple whose
// second element is itself a tuple. An output is either a single element or a
// tuple. Note that no error checking is performed on these strings.

struct TokenTestCase {
  std::string input;
  std::string output;
  std::string opaque;
};

std::ostream& operator<<(std::ostream& s, const TokenTestCase& tc) {
  s << tc.input << "x" << tc.output << "x" << tc.opaque;
  return s;
}

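// Checks each entry of `buffers` against the expectation string passed via
// `opaque`: positions marked 'A' must be non-null array buffers, positions
// marked 'T' (tokens) must be null.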
void Callback_Tokens(se::gpu::GpuStreamHandle stream, void** buffers,
                     const char* opaque, size_t opaque_len) {
  for (int i = 0; i < opaque_len; ++i) {
    char c = opaque[i];
    ASSERT_TRUE(c == 'A' || c == 'T');
    if (c == 'A') {
      ASSERT_NE(buffers[i], nullptr);
    } else {
      ASSERT_EQ(buffers[i], nullptr);
    }
  }
}

XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Tokens, PLATFORM);

std::vector<TokenTestCase> GetTokenTestCases() {
  return {{"{AT}{AT}", "{A{AT}A}", "ATATAATA"},  // tokens in input and output
          {"{A}", "T", "AT"},                    // single token as output
          {"{{T}}", "A", "TA"},                  // single token as input
          {"AA", "{TA}", "AATA"},
          {"TA{TA{TA}}", "{AA}", "TATATAAA"}};
}

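// Parameterized test fixture that builds the custom-call operands and result
// shape from the encoded TokenTestCase strings and forwards the expectation
// string to Callback_Tokens as opaque data.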
class CustomCallTokensTest
    : public ::testing::WithParamInterface<TokenTestCase>,
      public ClientLibraryTestBase {
 public:
  static std::vector<XlaOp> BuildInputs(XlaBuilder& b,
                                        std::istringstream& str) {
    std::vector<XlaOp> values;
    while (!str.eof()) {
      int ch = str.get();
      if (ch == 'A') {
        values.push_back(Broadcast(ConstantR0WithType(&b, F32, 1), {128}));
      } else if (ch == 'T') {
        values.push_back(CreateToken(&b));
      } else if (ch == '{') {
        // Build a tuple of values. This consumes the matching '}' as well.
        std::vector<XlaOp> tuple_elements = BuildInputs(b, str);
        values.push_back(Tuple(&b, tuple_elements));
      } else if (ch == '}') {
        break;
      }
    }
    return values;
  }

  static std::vector<Shape> BuildOutputType(std::istringstream& str) {
    std::vector<Shape> shapes;
    while (!str.eof()) {
      int ch = str.get();
      if (ch == 'A') {
        shapes.push_back(ShapeUtil::MakeShape(F32, {8}));
      } else if (ch == 'T') {
        shapes.push_back(ShapeUtil::MakeTokenShape());
      } else if (ch == '{') {
        // Build a tuple shape. This consumes the matching '}' as well.
        std::vector<Shape> tuple_elements = BuildOutputType(str);
        shapes.push_back(ShapeUtil::MakeTupleShape(tuple_elements));
      } else if (ch == '}') {
        break;
      }
    }
    return shapes;
  }
};

TEST_P(CustomCallTokensTest, TokensTest) {
  const TokenTestCase& tc = GetParam();

  XlaBuilder b("CustomCallTokens");

  std::istringstream input(tc.input);
  std::istringstream output(tc.output);
  std::vector<XlaOp> call_inputs = BuildInputs(b, input);
  std::vector<Shape> call_output = BuildOutputType(output);
  ASSERT_EQ(call_output.size(), 1);

  CustomCall(&b, "Callback_Tokens", call_inputs, call_output.front(),
             tc.opaque);
  TF_ASSERT_OK(Execute(&b, {}).status());
}

INSTANTIATE_TEST_CASE_P(CustomCallTokens, CustomCallTokensTest,
                        ::testing::ValuesIn(GetTokenTestCases()));

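// Custom-call targets registered with API_VERSION_STATUS_RETURNING receive an
// additional XlaCustomCallStatus* argument. Reporting success here should let
// the computation complete normally.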
void Callback_WithStatusSucceeded(se::gpu::GpuStreamHandle /*stream*/,
                                  void** /*buffers*/, const char* /*opaque*/,
                                  size_t /*opaque_len*/,
                                  XlaCustomCallStatus* status) {
  XlaCustomCallStatusSetSuccess(status);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusSucceeded, PLATFORM);

TEST_F(CustomCallTest, WithStatusSucceeded) {
  XlaBuilder b(TestName());
  CustomCall(
      &b, "Callback_WithStatusSucceeded", /*operands=*/{},
      ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
      /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
      /*api_version=*/CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
  TF_ASSERT_OK(Execute(&b, {}).status());
}

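// Reports a failure through the XlaCustomCallStatus; the test below expects
// execution to fail with an INTERNAL error whose message contains "Failed".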
void Callback_WithStatusFailed(se::gpu::GpuStreamHandle /*stream*/,
                               void** /*buffers*/, const char* /*opaque*/,
                               size_t /*opaque_len*/,
                               XlaCustomCallStatus* status) {
  XlaCustomCallStatusSetFailure(status, "Failed", 6);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, PLATFORM);

TEST_F(CustomCallTest, WithStatusFailed) {
  XlaBuilder b(TestName());
  CustomCall(
      &b, "Callback_WithStatusFailed", /*operands=*/{},
      ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
      /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
      /*api_version=*/CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
  auto status = Execute(&b, {}).status();
  EXPECT_EQ(status.code(), tensorflow::error::Code::INTERNAL);
  EXPECT_THAT(status.error_message(), ::testing::HasSubstr("Failed"));
}

}  // anonymous namespace
}  // namespace xla