xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/custom_call_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include <sstream>
#include <string>
#include <vector>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/driver_types.h"
#define PLATFORM "CUDA"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#define PLATFORM "ROCM"
#endif
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
36 
// Vendor-neutral aliases for the runtime memcpy API, so the callbacks below
// are written once and compile against either the CUDA or the ROCm runtime.
#if GOOGLE_CUDA
#define gpuSuccess cudaSuccess
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#elif TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
#define gpuMemcpyAsync hipMemcpyAsync
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define gpuMemcpy hipMemcpy
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#endif
52 
53 namespace xla {
54 namespace {
55 
56 class CustomCallTest : public ClientLibraryTestBase {};
57 
58 bool is_invoked_called = false;
Callback_IsInvoked(se::gpu::GpuStreamHandle,void **,const char *,size_t)59 void Callback_IsInvoked(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
60                         const char* /*opaque*/, size_t /*opaque_len*/) {
61   is_invoked_called = true;
62 }
63 XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, PLATFORM);
64 
// A registered custom-call target is executed when the computation runs —
// and not before.
TEST_F(CustomCallTest, IsInvoked) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_IsInvoked", /*operands=*/{},
             ShapeUtil::MakeShape(F32, {}),
             /*opaque=*/"");
  // Building the computation alone must not trigger the callback.
  EXPECT_FALSE(is_invoked_called);
  TF_ASSERT_OK(Execute(&b, {}).status());
  EXPECT_TRUE(is_invoked_called);
}
74 
// A custom call naming a target that was never registered must fail at
// execution time (with a non-OK status) rather than crash.
TEST_F(CustomCallTest, UnknownTarget) {
  XlaBuilder b(TestName());
  CustomCall(&b, "UnknownTarget", /*operands=*/{},
             ShapeUtil::MakeShape(F32, {}),
             /*opaque=*/"");
  ASSERT_FALSE(Execute(&b, {}).ok());
}
Callback_Memcpy(se::gpu::GpuStreamHandle stream,void ** buffers,const char *,size_t)82 void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers,
83                      const char* /*opaque*/, size_t /*opaque_len*/) {
84   void* src = buffers[0];
85   void* dst = buffers[1];
86   auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
87                             gpuMemcpyDeviceToDevice, stream);
88   ASSERT_EQ(err, gpuSuccess);
89 }
90 XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM);
// Device buffers handed to the callback are real, usable device pointers:
// the copied-through constant must round-trip back to the host intact.
TEST_F(CustomCallTest, Memcpy) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_Memcpy",
             /*operands=*/{Broadcast(ConstantR0WithType(&b, F32, 42.0), {128})},
             ShapeUtil::MakeShape(F32, {128}), /*opaque=*/"");
  TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {}));
  EXPECT_THAT(result.data<float>(), ::testing::Each(42));
}
99 
100 // Check that opaque handles nulls within the string.
101 std::string& kExpectedOpaque = *new std::string("abc\0def", 7);
Callback_Opaque(se::gpu::GpuStreamHandle,void **,const char * opaque,size_t opaque_len)102 void Callback_Opaque(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
103                      const char* opaque, size_t opaque_len) {
104   std::string opaque_str(opaque, opaque_len);
105   ASSERT_EQ(opaque_str, kExpectedOpaque);
106 }
107 XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, PLATFORM);
// The opaque string — embedded NUL included — is delivered unmodified to the
// callback (which does the actual comparison via ASSERT_EQ).
TEST_F(CustomCallTest, Opaque) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_Opaque", /*operands=*/{},
             ShapeUtil::MakeShape(F32, {}), kExpectedOpaque);
  TF_ASSERT_OK(Execute(&b, {}).status());
}
114 
Callback_SubBuffers(se::gpu::GpuStreamHandle stream,void ** buffers,const char *,size_t)115 void Callback_SubBuffers(se::gpu::GpuStreamHandle stream, void** buffers,
116                          const char* /*opaque*/, size_t /*opaque_len*/) {
117   // `buffers` is a flat array containing device pointers to the following.
118   //
119   //  0:  param 0 at tuple index {0}, shape f32[128]
120   //  1:  param 0 at tuple index {1}, shape f32[256]
121   //  2:  param 1 at tuple index {0}, shape f32[1024]
122   //  3:  param 1 at tuple index {1}, shape f32[8]
123   //  4:  result at tuple index {0}, shape f32[8]
124   //  5:  result at tuple index {1, 0}, shape f32[128]
125   //  6:  result at tuple index {1, 1}, shape f32[256]
126   //  7:  result at tuple index {2}, shape f32[1024]
127   //
128 
129   // Set output leaf buffers, copying data from the corresponding same-sized
130   // inputs.
131   gpuMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
132                  gpuMemcpyDeviceToDevice, stream);
133   gpuMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
134                  gpuMemcpyDeviceToDevice, stream);
135   gpuMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
136                  gpuMemcpyDeviceToDevice, stream);
137   gpuMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
138                  gpuMemcpyDeviceToDevice, stream);
139 }
140 XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, PLATFORM);
// Feeds two tuple operands and expects a nested-tuple result whose leaves
// were permuted by Callback_SubBuffers (4->{0}, 1->{1,0}, 2->{1,1}, 3->{2}).
TEST_F(CustomCallTest, SubBuffers) {
  XlaBuilder b(TestName());
  CustomCall(&b, "Callback_SubBuffers", /*operands=*/
             {
                 Tuple(&b,
                       {
                           Broadcast(ConstantR0WithType(&b, F32, 1), {128}),
                           Broadcast(ConstantR0WithType(&b, F32, 2), {256}),
                       }),
                 Tuple(&b,
                       {
                           Broadcast(ConstantR0WithType(&b, F32, 3), {1024}),
                           Broadcast(ConstantR0WithType(&b, F32, 4), {8}),
                       }),
             },
             ShapeUtil::MakeTupleShape({
                 ShapeUtil::MakeShape(F32, {8}),
                 ShapeUtil::MakeTupleShape({
                     ShapeUtil::MakeShape(F32, {128}),
                     ShapeUtil::MakeShape(F32, {256}),
                 }),
                 ShapeUtil::MakeShape(F32, {1024}),
             }),
             /*opaque=*/"");
  TF_ASSERT_OK_AND_ASSIGN(auto result, ExecuteAndTransfer(&b, {}));
  EXPECT_THAT(result.data<float>({0}), ::testing::Each(4));
  EXPECT_THAT(result.data<float>({1, 0}), ::testing::Each(1));
  EXPECT_THAT(result.data<float>({1, 1}), ::testing::Each(2));
  EXPECT_THAT(result.data<float>({2}), ::testing::Each(3));
}
171 
// The test cases for custom calls with tokens encode the argument and result
// types using a string with A(=Array), T(=Token) and {} for Tuples. They also
// encode the check the callback has to do as a string of As and Ts, where
// every A position must carry a non-null buffer and every T position a null
// one. That check string is passed to the custom call as its opaque data.
//
// As an example, "ATTA" as an input encodes 4 inputs to the custom call, and
// "{A{A}T}" as an output encodes a return type containing a single tuple with
// another tuple as its 2nd element. An output is either a single element or a
// tuple. Note: no error checking is performed on the encoding.

// One parameterized test case: the encoded input signature, the encoded
// output signature, and the opaque A/T check string handed to the callback.
struct TokenTestCase {
  std::string input;
  std::string output;
  std::string opaque;
};
188 
operator <<(std::ostream & s,const TokenTestCase & tc)189 std::ostream& operator<<(std::ostream& s, const TokenTestCase& tc) {
190   s << tc.input << "x" << tc.output << "x" << tc.opaque;
191   return s;
192 }
193 
Callback_Tokens(se::gpu::GpuStreamHandle stream,void ** buffers,const char * opaque,size_t opaque_len)194 void Callback_Tokens(se::gpu::GpuStreamHandle stream, void** buffers,
195                      const char* opaque, size_t opaque_len) {
196   for (int i = 0; i < opaque_len; ++i) {
197     char c = opaque[i];
198     ASSERT_TRUE(c == 'A' || c == 'T');
199     if (c == 'A') {
200       ASSERT_NE(buffers[i], nullptr);
201     } else {
202       ASSERT_EQ(buffers[i], nullptr);
203     }
204   }
205 }
206 
207 XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Tokens, PLATFORM);
208 
GetTokenTestCases()209 std::vector<TokenTestCase> GetTokenTestCases() {
210   return {{"{AT}{AT}", "{A{AT}A}", "ATATAATA"},  // tokens in input and output
211           {"{A}", "T", "AT"},                    // single token as output
212           {"{{T}}", "A", "TA"},                  // single token as input
213           {"AA", "{TA}", "AATA"},
214           {"TA{TA{TA}}", "{AA}", "TATATAAA"}};
215 }
216 
217 class CustomCallTokensTest
218     : public ::testing::WithParamInterface<TokenTestCase>,
219       public ClientLibraryTestBase {
220  public:
BuildInputs(XlaBuilder & b,std::istringstream & str)221   static std::vector<XlaOp> BuildInputs(XlaBuilder& b,
222                                         std::istringstream& str) {
223     std::vector<XlaOp> values;
224     while (!str.eof()) {
225       int ch = str.get();
226       if (ch == 'A') {
227         values.push_back(Broadcast(ConstantR0WithType(&b, F32, 1), {128}));
228       } else if (ch == 'T') {
229         values.push_back(CreateToken(&b));
230       } else if (ch == '{') {
231         // build a tuple of values. This will eat the } as well.
232         std::vector<XlaOp> tuple_elements = BuildInputs(b, str);
233         values.push_back(Tuple(&b, tuple_elements));
234       } else if (ch == '}') {
235         break;
236       }
237     }
238     return values;
239   }
240 
BuildOutputType(std::istringstream & str)241   static std::vector<Shape> BuildOutputType(std::istringstream& str) {
242     std::vector<Shape> shapes;
243     while (!str.eof()) {
244       int ch = str.get();
245       if (ch == 'A') {
246         shapes.push_back(ShapeUtil::MakeShape(F32, {8}));
247       } else if (ch == 'T') {
248         shapes.push_back(ShapeUtil::MakeTokenShape());
249       } else if (ch == '{') {
250         // build a tuple shape. This will eat the } as well.
251         std::vector<Shape> tuple_elements = BuildOutputType(str);
252         shapes.push_back(ShapeUtil::MakeTupleShape(tuple_elements));
253       } else if (ch == '}') {
254         break;
255       }
256     }
257     return shapes;
258   }
259 };
260 
// Builds the operands and result shape from the encoded test case, then runs
// the custom call; Callback_Tokens validates the null/non-null buffer layout.
TEST_P(CustomCallTokensTest, TokensTest) {
  const TokenTestCase& tc = GetParam();

  XlaBuilder b("CustomCallTokens");

  std::istringstream input(tc.input);
  std::istringstream output(tc.output);
  std::vector<XlaOp> call_inputs = BuildInputs(b, input);
  std::vector<Shape> call_output = BuildOutputType(output);
  // The encoded output must describe exactly one (possibly tuple) shape.
  ASSERT_EQ(call_output.size(), 1);

  CustomCall(&b, "Callback_Tokens", call_inputs, call_output.front(),
             tc.opaque);
  TF_ASSERT_OK(Execute(&b, {}).status());
}
276 
277 INSTANTIATE_TEST_CASE_P(CustomCallTokens, CustomCallTokensTest,
278                         ::testing::ValuesIn(GetTokenTestCases()));
279 
Callback_WithStatusSucceeded(se::gpu::GpuStreamHandle,void **,const char *,size_t,XlaCustomCallStatus * status)280 void Callback_WithStatusSucceeded(se::gpu::GpuStreamHandle /*stream*/,
281                                   void** /*buffers*/, const char* /*opaque*/,
282                                   size_t /*opaque_len*/,
283                                   XlaCustomCallStatus* status) {
284   XlaCustomCallStatusSetSuccess(status);
285 }
286 XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusSucceeded, PLATFORM);
287 
// A status-returning callback that sets success yields an OK execution.
TEST_F(CustomCallTest, WithStatusSucceeded) {
  XlaBuilder b(TestName());
  CustomCall(
      &b, "Callback_WithStatusSucceeded", /*operands=*/{},
      ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
      /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
      /*api_version=*/CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
  TF_ASSERT_OK(Execute(&b, {}).status());
}
299 
Callback_WithStatusFailed(se::gpu::GpuStreamHandle,void **,const char *,size_t,XlaCustomCallStatus * status)300 void Callback_WithStatusFailed(se::gpu::GpuStreamHandle /*stream*/,
301                                void** /*buffers*/, const char* /*opaque*/,
302                                size_t /*opaque_len*/,
303                                XlaCustomCallStatus* status) {
304   XlaCustomCallStatusSetFailure(status, "Failed", 6);
305 }
306 XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_WithStatusFailed, PLATFORM);
307 
// A callback that sets a failure status surfaces as an INTERNAL error whose
// message contains the callback-provided text.
TEST_F(CustomCallTest, WithStatusFailed) {
  XlaBuilder b(TestName());
  CustomCall(
      &b, "Callback_WithStatusFailed", /*operands=*/{},
      ShapeUtil::MakeShape(F32, {}), /*opaque=*/"",
      /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
      /*api_version=*/CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
  auto status = Execute(&b, {}).status();
  EXPECT_EQ(status.code(), tensorflow::error::Code::INTERNAL);
  EXPECT_THAT(status.error_message(), ::testing::HasSubstr("Failed"));
}
321 
322 }  // anonymous namespace
323 }  // namespace xla
324