1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/platform/logging.h"
34
35 namespace xla {
36 namespace {
37
// Platform name of the interpreter backend. GetReferencePlatform() looks this
// name up to obtain the platform that reference results are computed on.
constexpr char kInterpreter[]{"interpreter"};
40
41 // Wrapper function that creates a nicer error message (than a bare
42 // ValueOrDie()) if the platform we intend to test is not available.
GetOrCreateLocalClientOrDie(const LocalClientOptions & client_options)43 LocalClient* GetOrCreateLocalClientOrDie(
44 const LocalClientOptions& client_options) {
45 StatusOr<LocalClient*> result =
46 ClientLibrary::GetOrCreateLocalClient(client_options);
47 TF_CHECK_OK(result.status()) << " could not create local client for testing";
48 return result.ValueOrDie();
49 }
50
51 // Helper functions to get the reference platform.
GetReferencePlatform()52 se::Platform* GetReferencePlatform() {
53 auto result = PlatformUtil::GetPlatform(kInterpreter);
54 TF_CHECK_OK(result.status()) << "could not get interpreter platform";
55 return result.ValueOrDie();
56 }
57
58 } // namespace
59
ClientLibraryTestBase(se::Platform * platform,const LocalClientOptions & client_options)60 ClientLibraryTestBase::ClientLibraryTestBase(
61 se::Platform* platform, const LocalClientOptions& client_options)
62 : client_(GetOrCreateLocalClientOrDie(client_options)),
63 execution_options_(CreateDefaultExecutionOptions()) {
64 CHECK_EQ(platform, client_options.platform());
65
66 LocalClientOptions ref_options;
67 ref_options.set_platform(GetReferencePlatform());
68 ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
69
70 // Disabling constant_folding so that tests (usually written using Constants)
71 // will exercise the intended code paths, instead of being constant folded.
72 //
73 // TODO(b/38354253): Constant folding is currently disabled. Change tests to
74 // use Parameters instead of Constants, and re-enable constant folding by
75 // default.
76 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
77 "constant_folding");
78
79 execution_options_.mutable_debug_options()
80 ->set_xla_hlo_evaluator_use_fast_path(true);
81 }
82
ClientLibraryTestBase(se::Platform * platform)83 ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
84 : execution_options_(CreateDefaultExecutionOptions()) {
85 LocalClientOptions default_options;
86 default_options.set_platform(platform);
87 client_ = GetOrCreateLocalClientOrDie(default_options);
88
89 LocalClientOptions ref_options;
90 ref_options.set_platform(GetReferencePlatform());
91 ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
92
93 execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
94 "constant_folding");
95
96 execution_options_.mutable_debug_options()
97 ->set_xla_hlo_evaluator_use_fast_path(true);
98 }
99
TestName() const100 std::string ClientLibraryTestBase::TestName() const {
101 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
102 }
103
Execute(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)104 StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
105 XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
106 // Build the computation, as a convenience.
107 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
108 return client_->Execute(computation, arguments, &execution_options_);
109 }
110
ExecuteAndTransfer(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)111 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
112 const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
113 const Shape* shape_with_output_layout) {
114 ExecutionOptions execution_options = execution_options_;
115 if (shape_with_output_layout != nullptr) {
116 *execution_options.mutable_shape_with_output_layout() =
117 shape_with_output_layout->ToProto();
118 }
119 return client_->ExecuteAndTransfer(computation, arguments,
120 &execution_options);
121 }
122
ExecuteAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)123 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
124 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
125 const Shape* shape_with_output_layout) {
126 // Build the computation, as a convenience.
127 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
128 return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
129 }
130
ExecuteAndTransferReference(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const Shape * shape_with_output_layout)131 StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
132 const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
133 const Shape* shape_with_output_layout) {
134 ExecutionOptions execution_options = execution_options_;
135 if (shape_with_output_layout != nullptr) {
136 *execution_options.mutable_shape_with_output_layout() =
137 shape_with_output_layout->ToProto();
138 }
139 execution_options.clear_device_handles();
140 return ref_client_->ExecuteAndTransfer(computation, arguments,
141 &execution_options);
142 }
143
ExecuteToString(XlaBuilder * builder,absl::Span<GlobalData * const> arguments)144 std::string ClientLibraryTestBase::ExecuteToString(
145 XlaBuilder* builder, absl::Span<GlobalData* const> arguments) {
146 auto computation_status = builder->Build();
147 if (!computation_status.ok()) {
148 return computation_status.status().ToString();
149 }
150 auto computation = std::move(computation_status).value();
151
152 auto result =
153 client_->ExecuteAndTransfer(computation, arguments, &execution_options_);
154 if (!result.ok()) {
155 return result.status().ToString();
156 } else {
157 return result.ValueOrDie().ToString();
158 }
159 }
160
ComputeAndCompareR1(XlaBuilder * builder,const tensorflow::core::Bitmap & expected,absl::Span<GlobalData * const> arguments)161 void ClientLibraryTestBase::ComputeAndCompareR1(
162 XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
163 absl::Span<GlobalData* const> arguments) {
164 Literal expected_literal = LiteralUtil::CreateR1(expected);
165 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
166 arguments);
167 }
168
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,const Shape * shape_with_layout)169 void ClientLibraryTestBase::ComputeAndCompareLiteral(
170 XlaBuilder* builder, const Literal& expected,
171 absl::Span<GlobalData* const> arguments, const Shape* shape_with_layout) {
172 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
173 shape_with_layout));
174 }
175
ComputeAndCompareLiteral(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error,const Shape * shape_with_layout)176 void ClientLibraryTestBase::ComputeAndCompareLiteral(
177 XlaBuilder* builder, const Literal& expected,
178 absl::Span<GlobalData* const> arguments, ErrorSpec error,
179 const Shape* shape_with_layout) {
180 EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
181 error, shape_with_layout));
182 }
183
ComputeAndCompareLiteralWithAllOutputLayouts(const xla::XlaComputation & computation,const Literal & expected,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const std::string & error_message)> & verify_output)184 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
185 const xla::XlaComputation& computation, const Literal& expected,
186 absl::Span<GlobalData* const> arguments,
187 const std::function<void(const Literal& actual,
188 const std::string& error_message)>&
189 verify_output) {
190 // Try with no layout requirement.
191 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
192 verify_output(actual, "");
193
194 // Try with all output layouts.
195 std::vector<int64_t> minor_to_major(expected.shape().rank());
196 std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
197 do {
198 auto layout = ShapeUtil::MakeShapeWithLayout(
199 expected.shape().element_type(), expected.shape().dimensions(),
200 minor_to_major);
201 TF_ASSIGN_OR_RETURN(auto actual,
202 ExecuteAndTransfer(computation, arguments, &layout));
203 verify_output(actual,
204 absl::StrCat("Test with output layout: ",
205 ShapeUtil::HumanStringWithLayout(layout)));
206 } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
207 return OkStatus();
208 }
209
ComputeAndCompareLiteralWithAllInputLayouts(const xla::XlaComputation & computation,const Literal &,absl::Span<GlobalData * const> arguments,const std::function<void (const Literal & actual,const std::string & error_message)> & verify_output,const Shape * output_with_layout)210 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
211 const xla::XlaComputation& computation, const Literal& /*expected*/,
212 absl::Span<GlobalData* const> arguments,
213 const std::function<void(const Literal& actual,
214 const std::string& error_message)>& verify_output,
215 const Shape* output_with_layout) {
216 std::vector<GlobalData*> arguments_with_layout;
217 std::vector<std::string> layout_strings;
218 // This is a recursive function. It's an std::function instead of a lambda
219 // because it needs to capture itself. The index is the index of the argument
220 // to try all layouts for.
221 std::function<Status(int64_t)> choose;
222 choose = [&, this](int64_t index) -> Status {
223 if (index < arguments.size()) {
224 // Try out all layouts for the operand.
225 TF_ASSIGN_OR_RETURN(auto literal,
226 client_->Transfer(*arguments[index], nullptr));
227 // Skip tuples because they don't have a rank.
228 if (literal.shape().IsTuple()) {
229 layout_strings.push_back(
230 ShapeUtil::HumanStringWithLayout(literal.shape()));
231 arguments_with_layout.push_back(arguments[index]);
232 TF_RETURN_IF_ERROR(choose(index + 1));
233 arguments_with_layout.pop_back();
234 layout_strings.pop_back();
235 return OkStatus();
236 }
237
238 std::vector<int64_t> minor_to_major(literal.shape().rank());
239 std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
240 do {
241 auto literal_relayout =
242 literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
243 layout_strings.push_back(
244 ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
245 TF_ASSIGN_OR_RETURN(auto data,
246 client_->TransferToServer(literal_relayout));
247 arguments_with_layout.push_back(data.get());
248 TF_RETURN_IF_ERROR(choose(index + 1));
249 arguments_with_layout.pop_back();
250 layout_strings.pop_back();
251 } while (
252 std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
253 return OkStatus();
254 }
255
256 // Every argument has an assigned layout.
257 TF_ASSIGN_OR_RETURN(
258 auto actual,
259 ExecuteAndTransfer(computation,
260 absl::Span<GlobalData* const>(arguments_with_layout),
261 output_with_layout));
262 std::string error_message = "Test with input layouts: ";
263 for (const auto& str : layout_strings) {
264 absl::StrAppend(&error_message, str, " ");
265 }
266 verify_output(actual, error_message);
267 return OkStatus();
268 };
269
270 return choose(0);
271 }
272
ComputeAndTransfer(XlaBuilder * builder,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)273 StatusOr<Literal> ClientLibraryTestBase::ComputeAndTransfer(
274 XlaBuilder* builder, absl::Span<GlobalData* const> arguments_passed_in,
275 const Shape* shape_with_layout) {
276 std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
277 arguments_passed_in.end());
278
279 // Transfer and use elements of arguments_, if the AddParam() API was used.
280 std::vector<std::unique_ptr<GlobalData>> owning_arguments;
281 if (!arguments_.empty()) {
282 CHECK(arguments.empty());
283 for (const auto& argument : arguments_) {
284 TF_ASSIGN_OR_RETURN(
285 std::unique_ptr<GlobalData> owned_argument,
286 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
287 owning_arguments.push_back(std::move(owned_argument));
288 arguments.push_back(owning_arguments.back().get());
289 }
290 }
291
292 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
293 return ExecuteAndTransfer(computation, arguments, shape_with_layout);
294 }
295
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,const Shape * shape_with_layout)296 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
297 XlaBuilder* builder, const Literal& expected,
298 absl::Span<GlobalData* const> arguments_passed_in,
299 const Shape* shape_with_layout) {
300 std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
301 arguments_passed_in.end());
302
303 // Transfer and use elements of arguments_, if the AddParam() API was used.
304 std::vector<std::unique_ptr<GlobalData>> owning_arguments;
305 if (!arguments_.empty()) {
306 CHECK(arguments.empty());
307 for (const auto& argument : arguments_) {
308 TF_ASSIGN_OR_RETURN(
309 std::unique_ptr<GlobalData> owned_argument,
310 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
311 owning_arguments.push_back(std::move(owned_argument));
312 arguments.push_back(owning_arguments.back().get());
313 }
314 }
315
316 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
317 if (ShapeUtil::ElementIsFloating(expected.shape()) ||
318 ShapeUtil::ElementIsComplex(expected.shape())) {
319 LOG(WARNING) << "performing exact comparison of floating point numbers";
320 }
321 // We allow using a float expected literal for a bfloat16 output. In this
322 // case, we need to convert the expected literal to bfloat16.
323 const Literal* expected_ptr = &expected;
324 Literal converted_expected;
325 Shape layout_shape;
326 if (use_bfloat16_) {
327 converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
328 expected_ptr = &converted_expected;
329 if (shape_with_layout != nullptr) {
330 layout_shape = *shape_with_layout;
331 ShapeUtil::ForEachMutableSubshape(
332 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
333 if (subshape->element_type() == F32) {
334 subshape->set_element_type(BF16);
335 }
336 });
337 shape_with_layout = &layout_shape;
338 }
339 }
340 auto expect_equal = [&](const Literal& actual,
341 const std::string& error_message) {
342 EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
343 };
344 if (execution_options_.debug_options().xla_test_all_output_layouts()) {
345 return ComputeAndCompareLiteralWithAllOutputLayouts(
346 computation, *expected_ptr, arguments, expect_equal);
347 }
348 if (execution_options_.debug_options().xla_test_all_input_layouts()) {
349 return ComputeAndCompareLiteralWithAllInputLayouts(
350 computation, *expected_ptr, arguments, expect_equal, shape_with_layout);
351 }
352 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
353 shape_with_layout));
354 EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
355 return OkStatus();
356 }
357
ComputeAndCompareLiteralWithStatus(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments_passed_in,ErrorSpec error,const Shape * shape_with_layout)358 Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
359 XlaBuilder* builder, const Literal& expected,
360 absl::Span<GlobalData* const> arguments_passed_in, ErrorSpec error,
361 const Shape* shape_with_layout) {
362 std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
363 arguments_passed_in.end());
364
365 // Transfer and use elements of arguments_, if the AddParam() API was used.
366 std::vector<std::unique_ptr<GlobalData>> owning_arguments;
367 if (!arguments_.empty()) {
368 CHECK(arguments.empty());
369 for (const auto& argument : arguments_) {
370 TF_ASSIGN_OR_RETURN(
371 std::unique_ptr<GlobalData> owned_argument,
372 client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument)));
373 owning_arguments.push_back(std::move(owned_argument));
374 arguments.push_back(owning_arguments.back().get());
375 }
376 }
377
378 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
379 // We allow using a float expected literal for a bfloat16 output. In this
380 // case, we need to convert the expected literal to bfloat16.
381 const Literal* expected_ptr = &expected;
382 Literal converted_expected;
383 Shape layout_shape;
384 if (use_bfloat16_) {
385 converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
386 expected_ptr = &converted_expected;
387 if (shape_with_layout != nullptr) {
388 layout_shape = *shape_with_layout;
389 ShapeUtil::ForEachMutableSubshape(
390 &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) {
391 if (subshape->element_type() == F32) {
392 subshape->set_element_type(BF16);
393 }
394 });
395 shape_with_layout = &layout_shape;
396 }
397 }
398 auto expect_near = [&](const Literal& actual,
399 const std::string& error_message) {
400 EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
401 << error_message;
402 };
403 if (execution_options_.debug_options().xla_test_all_output_layouts()) {
404 return ComputeAndCompareLiteralWithAllOutputLayouts(
405 computation, *expected_ptr, arguments, expect_near);
406 }
407 if (execution_options_.debug_options().xla_test_all_input_layouts()) {
408 return ComputeAndCompareLiteralWithAllInputLayouts(
409 computation, *expected_ptr, arguments, expect_near, shape_with_layout);
410 }
411 TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
412 shape_with_layout));
413 EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
414 return OkStatus();
415 }
416
ComputeAndCompareR1U8(XlaBuilder * builder,absl::string_view expected,absl::Span<GlobalData * const> arguments)417 void ClientLibraryTestBase::ComputeAndCompareR1U8(
418 XlaBuilder* builder, absl::string_view expected,
419 absl::Span<GlobalData* const> arguments) {
420 auto actual_status = ExecuteAndTransfer(builder, arguments);
421 EXPECT_IS_OK(actual_status.status());
422 if (!actual_status.ok()) {
423 return;
424 }
425 auto actual = std::move(actual_status).value();
426
427 // Turn the expected value into a literal.
428 Literal expected_literal = LiteralUtil::CreateR1U8(expected);
429
430 VLOG(1) << "expected: " << expected_literal.ToString();
431 VLOG(1) << "actual: " << actual.ToString();
432
433 EXPECT_EQ(expected, actual.GetR1U8AsString());
434 }
435
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments)436 void ClientLibraryTestBase::ComputeAndCompareTuple(
437 XlaBuilder* builder, const Literal& expected,
438 absl::Span<GlobalData* const> arguments) {
439 auto actual_status = ExecuteAndTransfer(builder, arguments);
440 EXPECT_IS_OK(actual_status.status());
441 if (!actual_status.ok()) {
442 return;
443 }
444 auto actual = std::move(actual_status).value();
445 EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
446 }
447
ComputeAndCompareTuple(XlaBuilder * builder,const Literal & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)448 void ClientLibraryTestBase::ComputeAndCompareTuple(
449 XlaBuilder* builder, const Literal& expected,
450 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
451 auto actual_status = ExecuteAndTransfer(builder, arguments);
452 EXPECT_IS_OK(actual_status.status());
453 if (!actual_status.ok()) {
454 return;
455 }
456 auto actual = std::move(actual_status).value();
457 EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
458 }
459
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments)460 void ClientLibraryTestBase::ComputeAndCompare(
461 XlaBuilder* builder, absl::Span<const Literal> arguments) {
462 auto status_or_data = ComputeValueAndReference(builder, arguments);
463 EXPECT_IS_OK(status_or_data);
464 if (!status_or_data.ok()) {
465 return;
466 }
467 Literal reference, result;
468 std::tie(reference, result) = std::move(status_or_data).value();
469 EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
470 }
471
ComputeAndCompare(XlaBuilder * builder,absl::Span<const Literal> arguments,ErrorSpec error)472 void ClientLibraryTestBase::ComputeAndCompare(
473 XlaBuilder* builder, absl::Span<const Literal> arguments, ErrorSpec error) {
474 auto status_or_data = ComputeValueAndReference(builder, arguments);
475 EXPECT_IS_OK(status_or_data);
476 if (!status_or_data.ok()) {
477 return;
478 }
479 Literal reference, result;
480 std::tie(reference, result) = std::move(status_or_data).value();
481 EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
482 }
483
484 StatusOr<std::pair<Literal, Literal>>
ComputeValueAndReference(XlaBuilder * builder,absl::Span<const Literal> arguments)485 ClientLibraryTestBase::ComputeValueAndReference(
486 XlaBuilder* builder, absl::Span<const Literal> arguments) {
487 // Transfer the arguments to the executor service. We put the unique_ptr's
488 // into a vector to keep the data alive on the service until the end of this
489 // function.
490 std::vector<std::unique_ptr<GlobalData>> argument_data;
491 std::vector<std::unique_ptr<GlobalData>> ref_argument_data;
492
493 // Use `arguments_` if the AddParam() API was used. Otherwise, use
494 // plain `arguments`.
495 if (!arguments_.empty()) {
496 CHECK_EQ(arguments.size(), 0);
497 arguments = arguments_;
498 }
499
500 for (const auto& arg : arguments) {
501 TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone()));
502 TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg));
503 argument_data.push_back(std::move(data));
504 ref_argument_data.push_back(std::move(ref_data));
505 }
506
507 // Create raw pointers to the GlobalData for the rest of the call stack.
508 std::vector<GlobalData*> argument_data_ptr;
509 std::transform(
510 argument_data.begin(), argument_data.end(),
511 std::back_inserter(argument_data_ptr),
512 [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
513 std::vector<GlobalData*> ref_argument_data_ptr;
514 std::transform(
515 ref_argument_data.begin(), ref_argument_data.end(),
516 std::back_inserter(ref_argument_data_ptr),
517 [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
518
519 TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
520
521 TF_ASSIGN_OR_RETURN(auto result,
522 ExecuteAndTransfer(computation, argument_data_ptr));
523
524 TF_ASSIGN_OR_RETURN(auto reference, ExecuteAndTransferReference(
525 computation, ref_argument_data_ptr));
526
527 return std::make_pair(std::move(reference), std::move(result));
528 }
529
CreateScalarRelu()530 XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
531 XlaBuilder builder("relu");
532 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
533 auto z_value = Parameter(&builder, 0, shape, "z_value");
534 auto zero = use_bfloat16_
535 ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
536 : ConstantR0<float>(&builder, 0.0f);
537 Max(z_value, zero);
538 auto computation_status = builder.Build();
539 TF_CHECK_OK(computation_status.status());
540 return std::move(computation_status).value();
541 }
542
CreateScalarMax()543 XlaComputation ClientLibraryTestBase::CreateScalarMax() {
544 XlaBuilder builder("max");
545 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
546 auto x = Parameter(&builder, 0, shape, "x");
547 auto y = Parameter(&builder, 1, shape, "y");
548 Max(x, y);
549 auto computation_status = builder.Build();
550 TF_CHECK_OK(computation_status.status());
551 return std::move(computation_status).value();
552 }
553
CreateScalarReluSensitivity()554 XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
555 XlaBuilder builder("relu_sensitivity");
556 auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
557 auto activation = Parameter(&builder, 0, shape, "activation");
558 auto backprop = Parameter(&builder, 1, shape, "backprop");
559 auto zero = use_bfloat16_
560 ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
561 : ConstantR0<float>(&builder, 0.0f);
562 auto activation_gtz = Gt(activation, zero);
563 Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
564
565 auto computation_status = builder.Build();
566 TF_CHECK_OK(computation_status.status());
567 return std::move(computation_status).value();
568 }
569
CreatePatternedMatrix(int rows,int cols,float offset)570 std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
571 int rows, int cols, float offset) {
572 auto array = std::make_unique<Array2D<float>>(rows, cols);
573 for (int64_t row = 0; row < rows; ++row) {
574 for (int64_t col = 0; col < cols; ++col) {
575 (*array)(row, col) = col + (row * 1000.0f) + offset;
576 }
577 }
578 return array;
579 }
580
581 std::unique_ptr<Array2D<float>>
CreatePatternedMatrixWithZeroPadding(int rows,int cols,int rows_padded,int cols_padded)582 ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
583 int rows_padded,
584 int cols_padded) {
585 CHECK_GE(rows_padded, rows);
586 CHECK_GE(cols_padded, cols);
587 auto array = std::make_unique<Array2D<float>>(rows_padded, cols_padded, 0.0);
588 for (int64_t row = 0; row < rows; ++row) {
589 for (int64_t col = 0; col < cols; ++col) {
590 (*array)(row, col) = col + (row * 1000.0f);
591 }
592 }
593 return array;
594 }
595
AddParam(const Literal & argument,XlaBuilder * builder)596 XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
597 XlaBuilder* builder) {
598 arguments_.push_back(argument.Clone());
599 return Parameter(builder, /*parameter_number=*/arguments_.size() - 1,
600 MaybeConvertShapeToBfloat16(argument.shape()), "");
601 }
602
CreateConstantFromLiteral(const Literal & literal,XlaBuilder * builder)603 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
604 XlaBuilder* builder) {
605 return ConstantLiteral(builder, use_bfloat16_
606 ? LiteralUtil::ConvertF32ToBF16(literal)
607 : LiteralSlice(literal));
608 }
609
610 StatusOr<std::unique_ptr<GlobalData>>
CreateParameterAndTransferLiteral(int64_t parameter_number,const Literal & literal,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)611 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
612 int64_t parameter_number, const Literal& literal, const std::string& name,
613 XlaBuilder* builder, XlaOp* data_handle) {
614 return CreateParameterAndTransferLiteral(parameter_number, literal, name,
615 nullptr, builder, data_handle);
616 }
617
// Returns a copy of `shape`. When use_bfloat16_ is set, every F32 subshape
// (visited via ShapeUtil::ForEachMutableSubshape) is changed to BF16.
Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) {
  Shape converted = shape;
  if (use_bfloat16_) {
    ShapeUtil::ForEachMutableSubshape(
        &converted, [](Shape* subshape, const ShapeIndex& /*index*/) {
          if (subshape->element_type() == F32) {
            subshape->set_element_type(BF16);
          }
        });
  }
  return converted;
}
631
// Returns `literal` converted from F32 to BF16 when use_bfloat16_ is set, and
// an unchanged copy otherwise.
Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
    const Literal& literal) {
  return use_bfloat16_ ? LiteralUtil::ConvertF32ToBF16(literal)
                       : literal.Clone();
}
639
640 StatusOr<std::unique_ptr<GlobalData>>
CreateParameterAndTransferLiteral(int64_t parameter_number,const Literal & literal,const std::string & name,const DeviceHandle * device_handle,XlaBuilder * builder,XlaOp * data_handle)641 ClientLibraryTestBase::CreateParameterAndTransferLiteral(
642 int64_t parameter_number, const Literal& literal, const std::string& name,
643 const DeviceHandle* device_handle, XlaBuilder* builder,
644 XlaOp* data_handle) {
645 Literal param_literal = MaybeConvertLiteralToBfloat16(literal);
646 TF_ASSIGN_OR_RETURN(auto data,
647 client_->TransferToServer(param_literal, device_handle));
648 *data_handle =
649 Parameter(builder, parameter_number, param_literal.shape(), name);
650 return data;
651 }
652
653 } // namespace xla
654