xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/client_library_test_base.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
17 
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/compiler/xla/client/client_library.h"
24 #include "tensorflow/compiler/xla/client/local_client.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/execution_options_util.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/service/platform_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test_helpers.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace xla {
36 namespace {
37 
38 // Name of the interpreter backend.
39 constexpr char kInterpreter[] = "interpreter";
40 
41 // Wrapper function that creates a nicer error message (than a bare
42 // ValueOrDie()) if the platform we intend to test is not available.
GetOrCreateLocalClientOrDie(const LocalClientOptions & client_options)43 LocalClient* GetOrCreateLocalClientOrDie(
44     const LocalClientOptions& client_options) {
45   StatusOr<LocalClient*> result =
46       ClientLibrary::GetOrCreateLocalClient(client_options);
47   TF_CHECK_OK(result.status()) << " could not create local client for testing";
48   return result.ValueOrDie();
49 }
50 
51 // Helper functions to get the reference platform.
GetReferencePlatform()52 se::Platform* GetReferencePlatform() {
53   auto result = PlatformUtil::GetPlatform(kInterpreter);
54   TF_CHECK_OK(result.status()) << "could not get interpreter platform";
55   return result.ValueOrDie();
56 }
57 
58 }  // namespace
59 
ClientLibraryTestBase(se::Platform * platform,const LocalClientOptions & client_options)60 ClientLibraryTestBase::ClientLibraryTestBase(
61     se::Platform* platform, const LocalClientOptions& client_options)
62     : client_(GetOrCreateLocalClientOrDie(client_options)),
63       execution_options_(CreateDefaultExecutionOptions()) {
64   CHECK_EQ(platform, client_options.platform());
65 
66   LocalClientOptions ref_options;
67   ref_options.set_platform(GetReferencePlatform());
68   ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
69 
70   // Disabling constant_folding so that tests (usually written using Constants)
71   // will exercise the intended code paths, instead of being constant folded.
72   //
73   // TODO(b/38354253): Constant folding is currently disabled. Change tests to
74   // use Parameters instead of Constants, and re-enable constant folding by
75   // default.
76   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
77       "constant_folding");
78 
79   execution_options_.mutable_debug_options()
80       ->set_xla_hlo_evaluator_use_fast_path(true);
81 }
82 
ClientLibraryTestBase(se::Platform * platform)83 ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
84     : execution_options_(CreateDefaultExecutionOptions()) {
85   LocalClientOptions default_options;
86   default_options.set_platform(platform);
87   client_ = GetOrCreateLocalClientOrDie(default_options);
88 
89   LocalClientOptions ref_options;
90   ref_options.set_platform(GetReferencePlatform());
91   ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
92 
93   execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
94       "constant_folding");
95 
96   execution_options_.mutable_debug_options()
97       ->set_xla_hlo_evaluator_use_fast_path(true);
98 }
99 
TestName() const100 std::string ClientLibraryTestBase::TestName() const {
101   return ::testing::UnitTest::GetInstance()->current_test_info()->name();
102 }
103 
Execute(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)104 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
105     XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
106   // Build the computation, as a convenience.
107   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
108   return client_->Execute(computation, arguments, &execution_options_);
109 }
110 
ExecuteAndTransfer(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)111 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
112     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
113     const Shape* shape_with_output_layout) {
114   ExecutionOptions execution_options = execution_options_;
115   if (shape_with_output_layout != nullptr) {
116     *execution_options.mutable_shape_with_output_layout() =
117         shape_with_output_layout->ToProto();
118   }
119   return client_->ExecuteAndTransfer(computation, arguments,
120                                      &execution_options);
121 }
122 
ExecuteAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)123 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
124     XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
125     const Shape* shape_with_output_layout) {
126   // Build the computation, as a convenience.
127   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
128   return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
129 }
130 
ExecuteAndTransferReference(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)131 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
132     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
133     const Shape* shape_with_output_layout) {
134   ExecutionOptions execution_options = execution_options_;
135   if (shape_with_output_layout != nullptr) {
136     *execution_options.mutable_shape_with_output_layout() =
137         shape_with_output_layout->ToProto();
138   }
139   execution_options.clear_device_handles();
140   return ref_client_->ExecuteAndTransfer(computation, arguments,
141                                          &execution_options);
142 }
143 
ExecuteToString(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)144 std::string ClientLibraryTestBase::ExecuteToString(
145     XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
146   auto computation_status = builder->Build();
147   if (!computation_status.ok()) {
148     return computation_status.status().ToString();
149   }
150   auto computation = std::move(computation_status).value();
151 
152   auto result =
153       client_->ExecuteAndTransfer(computation, arguments, &execution_options_);
154   if (!result.ok()) {
155     return result.status().ToString();
156   } else {
157     return result.ValueOrDie().ToString();
158   }
159 }
160 
ComputeAndCompareR1(XlaBuilder * builder,const tensorflow::core::Bitmap & expected,absl::Span<GlobalData * const> arguments)161 void ClientLibraryTestBase::ComputeAndCompareR1(
162     XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
163     absl::Span<GlobalData* const> arguments) {
164   Literal expected_literal = LiteralUtil::CreateR1(expected);
165   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
166                                                   arguments);
167 }
168 
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,const Shape * shape_with_layout)169 void ClientLibraryTestBase::ComputeAndCompareLiteral(
170     XlaBuilder* builder, const Literal& expected,
171     absl::Span<GlobalData* const> arguments, const Shape* shape_with_layout) {
172   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
173                                                   shape_with_layout));
174 }
175 
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error,const Shape * shape_with_layout)176 void ClientLibraryTestBase::ComputeAndCompareLiteral(
177     XlaBuilder* builder, const Literal& expected,
178     absl::Span<GlobalData* const> arguments, ErrorSpec error,
179     const Shape* shape_with_layout) {
180   EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
181                                                   error, shape_with_layout));
182 }
183 
ComputeAndCompareLiteralWithAllOutputLayouts(const xla::XlaComputation & computation,const Literal & expected,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const std::string & error_message)> & verify_output)184 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
185     const xla::XlaComputation& computation, const Literal& expected,
186     absl::Span<GlobalData* const> arguments,
187     const std::function<void(const Literal& actual,
188                              const std::string& error_message)>&
189         verify_output) {
190   // Try with no layout requirement.
191   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
192   verify_output(actual, "");
193 
194   // Try with all output layouts.
195   std::vector<int64_t> minor_to_major(expected.shape().rank());
196   std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
197   do {
198     auto layout = ShapeUtil::MakeShapeWithLayout(
199         expected.shape().element_type(), expected.shape().dimensions(),
200         minor_to_major);
201     TF_ASSIGN_OR_RETURN(auto actual,
202                         ExecuteAndTransfer(computation, arguments, &layout));
203     verify_output(actual,
204                   absl::StrCat("Test with output layout: ",
205                                ShapeUtil::HumanStringWithLayout(layout)));
206   } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
207   return OkStatus();
208 }
209 
ComputeAndCompareLiteralWithAllInputLayouts(const xla::XlaComputation & computation,const Literal &,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const std::string & error_message)> & verify_output,const Shape * output_with_layout)210 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
211     const xla::XlaComputation& computation, const Literal& /*expected*/,
212     absl::Span<GlobalData* const> arguments,
213     const std::function<void(const Literal& actual,
214                              const std::string& error_message)>& verify_output,
215     const Shape* output_with_layout) {
216   std::vector<GlobalData*> arguments_with_layout;
217   std::vector<std::string> layout_strings;
218   // This is a recursive function. It's an std::function instead of a lambda
219   // because it needs to capture itself. The index is the index of the argument
220   // to try all layouts for.
221   std::function<Status(int64_t)> choose;
222   choose = [&, this](int64_t index) -> Status {
223     if (index < arguments.size()) {
224       // Try out all layouts for the operand.
225       TF_ASSIGN_OR_RETURN(auto literal,
226                           client_->Transfer(*arguments[index], nullptr));
227       // Skip tuples because they don't have a rank.
228       if (literal.shape().IsTuple()) {
229         layout_strings.push_back(
230             ShapeUtil::HumanStringWithLayout(literal.shape()));
231         arguments_with_layout.push_back(arguments[index]);
232         TF_RETURN_IF_ERROR(choose(index + 1));
233         arguments_with_layout.pop_back();
234         layout_strings.pop_back();
235         return OkStatus();
236       }
237 
238       std::vector<int64_t> minor_to_major(literal.shape().rank());
239       std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
240       do {
241         auto literal_relayout =
242             literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
243         layout_strings.push_back(
244             ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
245         TF_ASSIGN_OR_RETURN(auto data,
246                             client_->TransferToServer(literal_relayout));
247         arguments_with_layout.push_back(data.get());
248         TF_RETURN_IF_ERROR(choose(index + 1));
249         arguments_with_layout.pop_back();
250         layout_strings.pop_back();
251       } while (
252           std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
253       return OkStatus();
254     }
255 
256     // Every argument has an assigned layout.
257     TF_ASSIGN_OR_RETURN(
258         auto actual,
259         ExecuteAndTransfer(computation,
260                            absl::Span<GlobalData* const>(arguments_with_layout),
261                            output_with_layout));
262     std::string error_message = "Test with input layouts: ";
263     for (const auto& str : layout_strings) {
264       absl::StrAppend(&error_message, str, " ");
265     }
266     verify_output(actual, error_message);
267     return OkStatus();
268   };
269 
270   return choose(0);
271 }
272 
ComputeAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)273 StatusOr<Literal> ClientLibraryTestBase::ComputeAndTransfer(
274     XlaBuilder* builder, absl::Span<GlobalData* const> arguments_passed_in,
275     const Shape* shape_with_layout) {
276   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
277                                      arguments_passed_in.end());
278 
279   // Transfer and use elements of arguments_, if the AddParam() API was used.
280   std::vector<std::unique_ptr<GlobalData>> owning_arguments;
281   if (!arguments_.empty()) {
282     CHECK(arguments.empty());
283     for (const auto& argument : arguments_) {
284       TF_ASSIGN_OR_RETURN(
285           std::unique_ptr<GlobalData> owned_argument,
286           client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
287       owning_arguments.push_back(std::move(owned_argument));
288       arguments.push_back(owning_arguments.back().get());
289     }
290   }
291 
292   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
293   return ExecuteAndTransfer(computation, arguments, shape_with_layout);
294 }
295 
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)296 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
297     XlaBuilder* builder, const Literal& expected,
298     absl::Span<GlobalData* const> arguments_passed_in,
299     const Shape* shape_with_layout) {
300   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
301                                      arguments_passed_in.end());
302 
303   // Transfer and use elements of arguments_, if the AddParam() API was used.
304   std::vector<std::unique_ptr<GlobalData>> owning_arguments;
305   if (!arguments_.empty()) {
306     CHECK(arguments.empty());
307     for (const auto& argument : arguments_) {
308       TF_ASSIGN_OR_RETURN(
309           std::unique_ptr<GlobalData> owned_argument,
310           client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
311       owning_arguments.push_back(std::move(owned_argument));
312       arguments.push_back(owning_arguments.back().get());
313     }
314   }
315 
316   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
317   if (ShapeUtil::ElementIsFloating(expected.shape()) ||
318       ShapeUtil::ElementIsComplex(expected.shape())) {
319     LOG(WARNING) << "performing exact comparison of floating point numbers";
320   }
321   // We allow using a float expected literal for a bfloat16 output. In this
322   // case, we need to convert the expected literal to bfloat16.
323   const Literal* expected_ptr = &expected;
324   Literal converted_expected;
325   Shape layout_shape;
326   if (use_bfloat16_) {
327     converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
328     expected_ptr = &converted_expected;
329     if (shape_with_layout != nullptr) {
330       layout_shape = *shape_with_layout;
331       ShapeUtil::ForEachMutableSubshape(
332           &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
333             if (subshape->element_type() == F32) {
334               subshape->set_element_type(BF16);
335             }
336           });
337       shape_with_layout = &layout_shape;
338     }
339   }
340   auto expect_equal = [&](const Literal& actual,
341                           const std::string& error_message) {
342     EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
343   };
344   if (execution_options_.debug_options().xla_test_all_output_layouts()) {
345     return ComputeAndCompareLiteralWithAllOutputLayouts(
346         computation, *expected_ptr, arguments, expect_equal);
347   }
348   if (execution_options_.debug_options().xla_test_all_input_layouts()) {
349     return ComputeAndCompareLiteralWithAllInputLayouts(
350         computation, *expected_ptr, arguments, expect_equal, shape_with_layout);
351   }
352   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
353                                                       shape_with_layout));
354   EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
355   return OkStatus();
356 }
357 
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,ErrorSpec error,const Shape * shape_with_layout)358 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
359     XlaBuilder* builder, const Literal& expected,
360     absl::Span<GlobalData* const> arguments_passed_in, ErrorSpec error,
361     const Shape* shape_with_layout) {
362   std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
363                                      arguments_passed_in.end());
364 
365   // Transfer and use elements of arguments_, if the AddParam() API was used.
366   std::vector<std::unique_ptr<GlobalData>> owning_arguments;
367   if (!arguments_.empty()) {
368     CHECK(arguments.empty());
369     for (const auto& argument : arguments_) {
370       TF_ASSIGN_OR_RETURN(
371           std::unique_ptr<GlobalData> owned_argument,
372           client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
373       owning_arguments.push_back(std::move(owned_argument));
374       arguments.push_back(owning_arguments.back().get());
375     }
376   }
377 
378   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
379   // We allow using a float expected literal for a bfloat16 output. In this
380   // case, we need to convert the expected literal to bfloat16.
381   const Literal* expected_ptr = &expected;
382   Literal converted_expected;
383   Shape layout_shape;
384   if (use_bfloat16_) {
385     converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
386     expected_ptr = &converted_expected;
387     if (shape_with_layout != nullptr) {
388       layout_shape = *shape_with_layout;
389       ShapeUtil::ForEachMutableSubshape(
390           &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
391             if (subshape->element_type() == F32) {
392               subshape->set_element_type(BF16);
393             }
394           });
395       shape_with_layout = &layout_shape;
396     }
397   }
398   auto expect_near = [&](const Literal& actual,
399                          const std::string& error_message) {
400     EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
401         << error_message;
402   };
403   if (execution_options_.debug_options().xla_test_all_output_layouts()) {
404     return ComputeAndCompareLiteralWithAllOutputLayouts(
405         computation, *expected_ptr, arguments, expect_near);
406   }
407   if (execution_options_.debug_options().xla_test_all_input_layouts()) {
408     return ComputeAndCompareLiteralWithAllInputLayouts(
409         computation, *expected_ptr, arguments, expect_near, shape_with_layout);
410   }
411   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
412                                                       shape_with_layout));
413   EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
414   return OkStatus();
415 }
416 
ComputeAndCompareR1U8(XlaBuilder * builder,absl::string_view expected,absl::Span<GlobalData * const> arguments)417 void ClientLibraryTestBase::ComputeAndCompareR1U8(
418     XlaBuilder* builder, absl::string_view expected,
419     absl::Span<GlobalData* const> arguments) {
420   auto actual_status = ExecuteAndTransfer(builder, arguments);
421   EXPECT_IS_OK(actual_status.status());
422   if (!actual_status.ok()) {
423     return;
424   }
425   auto actual = std::move(actual_status).value();
426 
427   // Turn the expected value into a literal.
428   Literal expected_literal = LiteralUtil::CreateR1U8(expected);
429 
430   VLOG(1) << "expected: " << expected_literal.ToString();
431   VLOG(1) << "actual:   " << actual.ToString();
432 
433   EXPECT_EQ(expected, actual.GetR1U8AsString());
434 }
435 
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments)436 void ClientLibraryTestBase::ComputeAndCompareTuple(
437     XlaBuilder* builder, const Literal& expected,
438     absl::Span<GlobalData* const> arguments) {
439   auto actual_status = ExecuteAndTransfer(builder, arguments);
440   EXPECT_IS_OK(actual_status.status());
441   if (!actual_status.ok()) {
442     return;
443   }
444   auto actual = std::move(actual_status).value();
445   EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
446 }
447 
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)448 void ClientLibraryTestBase::ComputeAndCompareTuple(
449     XlaBuilder* builder, const Literal& expected,
450     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
451   auto actual_status = ExecuteAndTransfer(builder, arguments);
452   EXPECT_IS_OK(actual_status.status());
453   if (!actual_status.ok()) {
454     return;
455   }
456   auto actual = std::move(actual_status).value();
457   EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
458 }
459 
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments)460 void ClientLibraryTestBase::ComputeAndCompare(
461     XlaBuilder* builder, absl::Span<const Literal> arguments) {
462   auto status_or_data = ComputeValueAndReference(builder, arguments);
463   EXPECT_IS_OK(status_or_data);
464   if (!status_or_data.ok()) {
465     return;
466   }
467   Literal reference, result;
468   std::tie(reference, result) = std::move(status_or_data).value();
469   EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
470 }
471 
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments,ErrorSpec error)472 void ClientLibraryTestBase::ComputeAndCompare(
473     XlaBuilder* builder, absl::Span<const Literal> arguments, ErrorSpec error) {
474   auto status_or_data = ComputeValueAndReference(builder, arguments);
475   EXPECT_IS_OK(status_or_data);
476   if (!status_or_data.ok()) {
477     return;
478   }
479   Literal reference, result;
480   std::tie(reference, result) = std::move(status_or_data).value();
481   EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
482 }
483 
484 StatusOr<std::pair<Literal, Literal>>
ComputeValueAndReference(XlaBuilder * builder,absl::Span<const Literal> arguments)485 ClientLibraryTestBase::ComputeValueAndReference(
486     XlaBuilder* builder, absl::Span<const Literal> arguments) {
487   // Transfer the arguments to the executor service. We put the unique_ptr's
488   // into a vector to keep the data alive on the service until the end of this
489   // function.
490   std::vector<std::unique_ptr<GlobalData>> argument_data;
491   std::vector<std::unique_ptr<GlobalData>> ref_argument_data;
492 
493   // Use `arguments_` if the AddParam() API was used.  Otherwise, use
494   // plain `arguments`.
495   if (!arguments_.empty()) {
496     CHECK_EQ(arguments.size(), 0);
497     arguments = arguments_;
498   }
499 
500   for (const auto& arg : arguments) {
501     TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone()));
502     TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg));
503     argument_data.push_back(std::move(data));
504     ref_argument_data.push_back(std::move(ref_data));
505   }
506 
507   // Create raw pointers to the GlobalData for the rest of the call stack.
508   std::vector<GlobalData*> argument_data_ptr;
509   std::transform(
510       argument_data.begin(), argument_data.end(),
511       std::back_inserter(argument_data_ptr),
512       [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
513   std::vector<GlobalData*> ref_argument_data_ptr;
514   std::transform(
515       ref_argument_data.begin(), ref_argument_data.end(),
516       std::back_inserter(ref_argument_data_ptr),
517       [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
518 
519   TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
520 
521   TF_ASSIGN_OR_RETURN(auto result,
522                       ExecuteAndTransfer(computation, argument_data_ptr));
523 
524   TF_ASSIGN_OR_RETURN(auto reference, ExecuteAndTransferReference(
525                                           computation, ref_argument_data_ptr));
526 
527   return std::make_pair(std::move(reference), std::move(result));
528 }
529 
CreateScalarRelu()530 XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
531   XlaBuilder builder("relu");
532   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
533   auto z_value = Parameter(&builder, 0, shape, "z_value");
534   auto zero = use_bfloat16_
535                   ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
536                   : ConstantR0<float>(&builder, 0.0f);
537   Max(z_value, zero);
538   auto computation_status = builder.Build();
539   TF_CHECK_OK(computation_status.status());
540   return std::move(computation_status).value();
541 }
542 
CreateScalarMax()543 XlaComputation ClientLibraryTestBase::CreateScalarMax() {
544   XlaBuilder builder("max");
545   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
546   auto x = Parameter(&builder, 0, shape, "x");
547   auto y = Parameter(&builder, 1, shape, "y");
548   Max(x, y);
549   auto computation_status = builder.Build();
550   TF_CHECK_OK(computation_status.status());
551   return std::move(computation_status).value();
552 }
553 
CreateScalarReluSensitivity()554 XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
555   XlaBuilder builder("relu_sensitivity");
556   auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
557   auto activation = Parameter(&builder, 0, shape, "activation");
558   auto backprop = Parameter(&builder, 1, shape, "backprop");
559   auto zero = use_bfloat16_
560                   ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
561                   : ConstantR0<float>(&builder, 0.0f);
562   auto activation_gtz = Gt(activation, zero);
563   Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
564 
565   auto computation_status = builder.Build();
566   TF_CHECK_OK(computation_status.status());
567   return std::move(computation_status).value();
568 }
569 
CreatePatternedMatrix(int rows,int cols,float offset)570 std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
571     int rows, int cols, float offset) {
572   auto array = std::make_unique<Array2D<float>>(rows, cols);
573   for (int64_t row = 0; row < rows; ++row) {
574     for (int64_t col = 0; col < cols; ++col) {
575       (*array)(row, col) = col + (row * 1000.0f) + offset;
576     }
577   }
578   return array;
579 }
580 
581 std::unique_ptr<Array2D<float>>
CreatePatternedMatrixWithZeroPadding(int rows,int cols,int rows_padded,int cols_padded)582 ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
583                                                             int rows_padded,
584                                                             int cols_padded) {
585   CHECK_GE(rows_padded, rows);
586   CHECK_GE(cols_padded, cols);
587   auto array = std::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0);
588   for (int64_t row = 0; row < rows; ++row) {
589     for (int64_t col = 0; col < cols; ++col) {
590       (*array)(row, col) = col + (row * 1000.0f);
591     }
592   }
593   return array;
594 }
595 
AddParam(const Literal & argument,XlaBuilder * builder)596 XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
597                                       XlaBuilder* builder) {
598   arguments_.push_back(argument.Clone());
599   return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
600                    MaybeConvertShapeToBfloat16(argument.shape()), "");
601 }
602 
CreateConstantFromLiteral(const Literal & literal,XlaBuilder * builder)603 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
604                                                        XlaBuilder* builder) {
605   return ConstantLiteral(builder, use_bfloat16_
606                                       ? LiteralUtil::ConvertF32ToBF16(literal)
607                                       : LiteralSlice(literal));
608 }
609 
610 StatusOr<std::unique_ptr<GlobalData>>
CreateParameterAndTransferLiteral(int64_t parameter_number,const Literal & literal,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)611 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
612     int64_t parameter_number, const Literal& literal, const std::string& name,
613     XlaBuilder* builder, XlaOp* data_handle) {
614   return CreateParameterAndTransferLiteral(parameter_number, literal, name,
615                                            nullptr, builder, data_handle);
616 }
617 
// Returns a copy of `shape`. When the fixture uses bfloat16, every F32 subshape
// (including tuple elements) is rewritten to BF16; otherwise the copy is
// unchanged.
Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
  Shape converted = shape;
  if (use_bfloat16_) {
    ShapeUtil::ForEachMutableSubshape(
        &converted, [](Shape* subshape, const ShapeIndex& /*index*/) {
          if (subshape->element_type() == F32) {
            subshape->set_element_type(BF16);
          }
        });
  }
  return converted;
}
631 
// Returns `literal` converted via LiteralUtil::ConvertF32ToBF16 when the
// fixture uses bfloat16. Otherwise returns an unchanged clone.
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
    const Literal& literal) {
  return use_bfloat16_ ? LiteralUtil::ConvertF32ToBF16(literal)
                       : literal.Clone();
}
639 
640 StatusOr<std::unique_ptr<GlobalData>>
CreateParameterAndTransferLiteral(int64_t parameter_number,const Literal & literal,const std::string & name,const DeviceHandle * device_handle,XlaBuilder * builder,XlaOp * data_handle)641 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
642     int64_t parameter_number, const Literal& literal, const std::string& name,
643     const DeviceHandle* device_handle, XlaBuilder* builder,
644     XlaOp* data_handle) {
645   Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
646   TF_ASSIGN_OR_RETURN(auto data,
647                       client_->TransferToServer(param_literal, device_handle));
648   *data_handle =
649       Parameter(builder, parameter_number, param_literal.shape(), name);
650   return data;
651 }
652 
653 }  // namespace xla
654