1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
18
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
43
44 namespace xla {
45
46 // Sets the use_bfloat16 on a container of test cases according to the values in
47 // use_bfloat16_params. Generates one set of test cases for each values in
48 // use_bfloat16_params with that value. Returns the result.
49 template <typename TestCase>
ExpandUseBfloat16(absl::Span<const bool> use_bfloat16_params,absl::Span<const TestCase> specs)50 std::vector<TestCase> ExpandUseBfloat16(
51 absl::Span<const bool> use_bfloat16_params,
52 absl::Span<const TestCase> specs) {
53 std::vector<TestCase> expanded;
54 for (bool use_bfloat16 : use_bfloat16_params) {
55 for (const auto& spec : specs) {
56 expanded.push_back(spec);
57 expanded.back().use_bfloat16 = use_bfloat16;
58 }
59 }
60 return expanded;
61 }
62
63 // A client library test establishes an in-process XLA client connection.
64 class ClientLibraryTestBase : public ManifestCheckingTest {
65 protected:
66 explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
67
68 // Creates a new ClientLibraryTestBase with custom client options.
69 ClientLibraryTestBase(se::Platform* platform,
70 const LocalClientOptions& client_options);
71
72 // Returns the name of the test currently being run.
73 std::string TestName() const;
74
SetFastMathDisabled(bool disabled)75 void SetFastMathDisabled(bool disabled) {
76 auto* opts = execution_options_.mutable_debug_options();
77 opts->set_xla_cpu_enable_fast_math(!disabled);
78 opts->set_xla_cpu_enable_fast_min_max(!disabled);
79 opts->set_xla_gpu_enable_fast_min_max(!disabled);
80 }
81
SetSeed(uint64_t seed)82 void SetSeed(uint64_t seed) { execution_options_.set_seed(seed); }
83
84 // Provides mutable access to the execution DebugOptions field; this lets
85 // tests tweak the options that will be used to compile/run the graph.
mutable_debug_options()86 DebugOptions* mutable_debug_options() {
87 return execution_options_.mutable_debug_options();
88 }
89
90 // TODO(b/25566808): Add helper that populates a literal from a testdata file.
91
92 // Convenience methods for building and running a computation with the member
93 // execution options. Modify execution_options_ in your test if you want to
94 // customize the options.
95 StatusOr<std::unique_ptr<GlobalData>> Execute(
96 XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
97
98 StatusOr<Literal> ExecuteAndTransfer(
99 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
100 const Shape* shape_with_output_layout = nullptr);
101
102 StatusOr<Literal> ExecuteAndTransfer(
103 const XlaComputation& computation,
104 absl::Span<GlobalData* const> arguments,
105 const Shape* shape_with_output_layout = nullptr);
106
107 // This executes the computation via the reference client (which connects a
108 // interpreter backend). The result is used as the expected value of the
109 // computation.
110 StatusOr<Literal> ExecuteAndTransferReference(
111 const XlaComputation& computation,
112 absl::Span<GlobalData* const> arguments,
113 const Shape* shape_with_output_layout = nullptr);
114
115 // Run a computation and return its value as a string. If an error
116 // occurs, then instead return the error as a string.
117 std::string ExecuteToString(XlaBuilder* builder,
118 absl::Span<GlobalData* const> arguments);
119
120 // Convenience methods for building and running a computation, transferring
121 // the result, and comparing it to the expected value(s). Methods are
122 // templated on the native host type which maps to specific XLA types (See
123 // XlaBuilder for details). For each rank, two forms are
124 // provided: one for floating point types with an ErrorSpec parameter, and one
125 // for integral types without the ErrorSpec parameter.
126 template <typename NativeT>
127 void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
128 absl::Span<GlobalData* const> arguments);
129 template <typename NativeT>
130 void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
131 absl::Span<GlobalData* const> arguments,
132 ErrorSpec error);
133
134 template <typename NativeT>
135 void ComputeAndCompareR1(XlaBuilder* builder,
136 absl::Span<const NativeT> expected,
137 absl::Span<GlobalData* const> arguments);
138 template <typename NativeT>
139 void ComputeAndCompareR1(XlaBuilder* builder,
140 absl::Span<const NativeT> expected,
141 absl::Span<GlobalData* const> arguments,
142 ErrorSpec error);
143
144 // As above, but uses a bitmap to hold the predicate vector to avoid
145 // deficiencies of vector<bool>.
146 void ComputeAndCompareR1(XlaBuilder* builder,
147 const tensorflow::core::Bitmap& expected,
148 absl::Span<GlobalData* const> arguments);
149
150 template <typename NativeT>
151 void ComputeAndCompareR2(XlaBuilder* builder,
152 const Array2D<NativeT>& expected,
153 absl::Span<GlobalData* const> arguments);
154 template <typename NativeT>
155 void ComputeAndCompareR2(XlaBuilder* builder,
156 const Array2D<NativeT>& expected,
157 absl::Span<GlobalData* const> arguments,
158 ErrorSpec error);
159
160 template <typename NativeT>
161 void ComputeAndCompareR3(XlaBuilder* builder,
162 const Array3D<NativeT>& expected,
163 absl::Span<GlobalData* const> arguments);
164 template <typename NativeT>
165 void ComputeAndCompareR3(XlaBuilder* builder,
166 const Array3D<NativeT>& expected,
167 absl::Span<GlobalData* const> arguments,
168 ErrorSpec error);
169
170 template <typename NativeT>
171 void ComputeAndCompareR4(XlaBuilder* builder,
172 const Array4D<NativeT>& expected,
173 absl::Span<GlobalData* const> arguments);
174 template <typename NativeT>
175 void ComputeAndCompareR4(XlaBuilder* builder,
176 const Array4D<NativeT>& expected,
177 absl::Span<GlobalData* const> arguments,
178 ErrorSpec error);
179
180 // Build and run the computation and compare the result with the given
181 // literal. shape_with_layout indicates the result layout to request when
182 // calling Execute.
183 void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
184 absl::Span<GlobalData* const> arguments,
185 const Shape* shape_with_layout = nullptr);
186 void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
187 absl::Span<GlobalData* const> arguments,
188 ErrorSpec error,
189 const Shape* shape_with_layout = nullptr);
190
191 // Build and run the computation and return the result as a literal.
192 // shape_with_layout indicates the result layout to request when calling
193 // Execute.
194 StatusOr<Literal> ComputeAndTransfer(
195 XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
196 const Shape* shape_with_layout = nullptr);
197
198 // ComputeAndCompare variant which returns an error status.
199 Status ComputeAndCompareLiteralWithStatus(
200 XlaBuilder* builder, const Literal& expected,
201 absl::Span<GlobalData* const> arguments,
202 const Shape* shape_with_layout = nullptr);
203 Status ComputeAndCompareLiteralWithStatus(
204 XlaBuilder* builder, const Literal& expected,
205 absl::Span<GlobalData* const> arguments, ErrorSpec error,
206 const Shape* shape_with_layout = nullptr);
207
208 // Compare the result of the computation to a strings. In XLA strings are
209 // represented using rank-1 U8 shapes.
210 void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
211 absl::Span<GlobalData* const> arguments);
212
213 // Convenience method for running a built computation, transferring the
214 // result, and comparing it to the expected tuple literal.
215 void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
216 absl::Span<GlobalData* const> arguments);
217 void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
218 absl::Span<GlobalData* const> arguments,
219 ErrorSpec error);
220
221 // Convenience method for running a built computation and comparing the result
222 // with the reference result.
223 void ComputeAndCompare(XlaBuilder* builder,
224 absl::Span<const Literal> arguments);
225 void ComputeAndCompare(XlaBuilder* builder,
226 absl::Span<const Literal> arguments, ErrorSpec error);
227 template <typename NativeT>
228 void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
229 absl::Span<GlobalData* const> arguments);
230 template <typename NativeT>
231 void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
232 absl::Span<GlobalData* const> arguments,
233 ErrorSpec error);
234 // Create scalar operations for use in reductions.
235 XlaComputation CreateScalarRelu();
236 XlaComputation CreateScalarMax();
237 XlaComputation CreateScalarReluSensitivity();
238
239 // Special case convenience functions for creating filled arrays.
240
241 // Creates an array of pseudorandom values lying between the given minimum and
242 // maximum values.
243 template <typename NativeT>
244 std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
245 NativeT max_value, uint32_t seed);
246 template <typename NativeT>
247 std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
248 const int cols,
249 NativeT min_value,
250 NativeT max_value,
251 uint32_t seed);
252
253 // Creates a (rows x cols) array filled in the following form:
254 //
255 // [ 0 1 ... cols-1]
256 // [ 1,000 1,001 ... 1000.0 + cols-1]
257 // [ ... ... ... ...]
258 // [(rows-1)*1000.0 ... ... (rows-1)*1000.0 + cols-1]
259 //
260 // If provided, offset is added uniformly to every element (e.g. an offset of
261 // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
262 std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
263 const int cols,
264 float offset = 0.0);
265
266 // Creates a (rows x cols) array as above, padded out to
267 // (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows
268 // and cols_padded > cols.
269 std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
270 const int rows, const int cols, const int rows_padded,
271 const int cols_padded);
272
273 // Creates a parameter instruction, transfers the literal for the parameter to
274 // server, then stores into "data_handle" the global handle for that
275 // parameter. When the use_bfloat16 flag is set but the literal has F32
276 // elements, the literal will be converted to BF16 before being transferred.
277 StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
278 int64_t parameter_number, const Literal& literal, const std::string& name,
279 XlaBuilder* builder, XlaOp* data_handle);
280
281 // As above, but the caller can specify the device that the literal is
282 // transferred to. If device_handle is nullptr, the literal will be
283 // transferred to the default device.
284 StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
285 int64_t parameter_number, const Literal& literal, const std::string& name,
286 const DeviceHandle* device_handle, XlaBuilder* builder,
287 XlaOp* data_handle);
288
289 // Creates a parameter instruction and sets the value that will be passed to
290 // the computation as specified. This function must be used for all parameters
291 // or none and no parameters must be passed when invoking the computation if
292 // using this mechanism. If using this mechanism, then each parameter must be
293 // set exactly once. The first added parameter gets index 0, then 1 and so on.
294 XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
295
296 template <class T>
AddParam(const Array<T> & argument,XlaBuilder * builder)297 XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
298 return AddParam(LiteralUtil::CreateFromArray(argument), builder);
299 }
300
301 // Creates a constant instruction with the given literal. When the
302 // use_bfloat16 flag is set but the literal has F32 elements, the elements
303 // will be converted to BF16s.
304 XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
305
306 // Creates a constant instruction with the given array. When the use_bfloat16
307 // flag is set but the array has float elements, the elements will be
308 // converted to bfloat16s.
309
310 template <typename NativeT>
CreateConstantFromArray(const Array<NativeT> & array,XlaBuilder * builder)311 XlaOp CreateConstantFromArray(const Array<NativeT>& array,
312 XlaBuilder* builder) {
313 return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
314 builder);
315 }
316
317 // Same as CreateConstantFromArray, but for scalars.
318 template <typename NativeT>
CreateConstantFromScalar(NativeT value,XlaBuilder * builder)319 XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
320 return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
321 builder);
322 }
323
324 // Creates a parameter instruction that wraps a given value and then stores
325 // into "data_handle" the global handle for that parameter.
326 //
327 // "parameter_number" is the parameter number.
328 // "name" is the name of the parameter instruction.
329 //
330 // When the use_bfloat16 flag is set but NativeT is float, the data will be
331 // converted to bfloat16.
332 template <typename NativeT>
333 std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
334 int64_t parameter_number,
335 const std::string& name,
336 XlaBuilder* builder,
337 XlaOp* data_handle);
338
339 // Creates a parameter instruction that wraps the given values and then stores
340 // into "data_handle" the global handle for that parameter.
341 //
342 // "parameter_number" is the parameter number.
343 // "name" is the name of the parameter instruction.
344 //
345 // When the use_bfloat16 flag is set but NativeT is float, the data will be
346 // converted to bfloat16.
347 template <typename NativeT>
348 std::unique_ptr<GlobalData> CreateR1Parameter(
349 absl::Span<const NativeT> values, int64_t parameter_number,
350 const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
351
352 // Creates a parameter instruction that wraps the given constant array
353 // "array_2d" and then stores it to the global handle for that parameter
354 // "data_handle".
355 //
356 // "parameter_number" is the parameter number.
357 // "name" is the name of the parameter instruction.
358 //
359 // When the use_bfloat16 flag is set but NativeT is float, the data will be
360 // converted to bfloat16.
361 template <typename NativeT>
362 std::unique_ptr<GlobalData> CreateR2Parameter(
363 const Array2D<NativeT>& array_2d, int64_t parameter_number,
364 const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
365
366 // Creates a parameter instruction that wraps the given constant array
367 // "array_3d" and then stores it to the global handle for that parameter
368 // "data_handle".
369 //
370 // "parameter_number" is the parameter number.
371 // "name" is the name of the parameter instruction.
372 //
373 // When the use_bfloat16 flag is set but NativeT is float, the data will be
374 // converted to bfloat16.
375 template <typename NativeT>
376 std::unique_ptr<GlobalData> CreateR3Parameter(
377 const Array3D<NativeT>& array_3d, int64_t parameter_number,
378 const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
379
380 // Creates a parameter instruction that wraps the given constant array
381 // "array_4d" and then stores it to the global handle for that parameter
382 // "data_handle".
383 //
384 // "parameter_number" is the parameter number.
385 // "name" is the name of the parameter instruction.
386 //
387 // When the use_bfloat16 flag is set but NativeT is float, the data will be
388 // converted to bfloat16.
389 template <typename NativeT>
390 std::unique_ptr<GlobalData> CreateR4Parameter(
391 const Array4D<NativeT>& array_4d, int64_t parameter_number,
392 const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
393
394 template <typename NativeT>
395 std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
396 int64_t parameter_number,
397 const std::string& name,
398 XlaBuilder* builder,
399 XlaOp* data_handle);
400
401 // Getter and setter for the use_bfloat16 flag, which indicates whether to run
402 // tests with all float-type input/output converted to bfloat16.
use_bfloat16()403 bool use_bfloat16() const { return use_bfloat16_; }
set_use_bfloat16(bool value)404 void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
405
406 // The float type used in this test, BF16 or F32 according to use_bfloat16.
FloatType()407 PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
408
409 // Executes the computation and calculates the expected reference value using
410 // the reference client. Returns two literals in the order of (expected,
411 // actual).
412 StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
413 XlaBuilder* builder, absl::Span<const Literal> arguments);
414
415 // Converts an f32 literal to bf16 if use_bfloat16_ is true.
416 Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
417
418 LocalClient* client_;
419 LocalClient* ref_client_; // To compute reference result.
420 ExecutionOptions execution_options_;
421
422 private:
423 Status ComputeAndCompareLiteralWithAllOutputLayouts(
424 const xla::XlaComputation& computation, const Literal& expected,
425 absl::Span<GlobalData* const> arguments,
426 const std::function<void(const Literal& actual,
427 const std::string& error_message)>&
428 verify_output);
429 Status ComputeAndCompareLiteralWithAllInputLayouts(
430 const xla::XlaComputation& computation, const Literal& expected,
431 absl::Span<GlobalData* const> arguments,
432 const std::function<void(const Literal& actual,
433 const std::string& error_message)>&
434 verify_output,
435 const Shape* output_with_layout = nullptr);
436
437 // Converts an f32 shape to bf16 if use_bfloat16_ is true.
438 Shape MaybeConvertShapeToBfloat16(const Shape& shape);
439
440 // Whether to run tests with all float-type input/output converted to
441 // bfloat16.
442 bool use_bfloat16_ = false;
443
444 // Arguments to be passed to the computation when it runs.
445 std::vector<Literal> arguments_;
446 };
447
448 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments)449 void ClientLibraryTestBase::ComputeAndCompareR0(
450 XlaBuilder* builder, NativeT expected,
451 absl::Span<GlobalData* const> arguments) {
452 Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
453 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
454 arguments);
455 }
456
457 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)458 void ClientLibraryTestBase::ComputeAndCompareR0(
459 XlaBuilder* builder, NativeT expected,
460 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
461 static_assert(std::is_same<NativeT, float>::value ||
462 std::is_same<NativeT, double>::value ||
463 std::is_same<NativeT, bfloat16>::value ||
464 std::is_same<NativeT, half>::value ||
465 std::is_same<NativeT, complex64>::value ||
466 std::is_same<NativeT, complex128>::value,
467 "Float or complex type required when specifying an ErrorSpec");
468 Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
469 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
470 arguments, error);
471 }
472
473 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)474 void ClientLibraryTestBase::ComputeAndCompareR1(
475 XlaBuilder* builder, absl::Span<const NativeT> expected,
476 absl::Span<GlobalData* const> arguments) {
477 Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
478 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
479 arguments);
480 }
481
482 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)483 void ClientLibraryTestBase::ComputeAndCompareR1(
484 XlaBuilder* builder, absl::Span<const NativeT> expected,
485 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
486 static_assert(std::is_same<NativeT, float>::value ||
487 std::is_same<NativeT, double>::value ||
488 std::is_same<NativeT, bfloat16>::value ||
489 std::is_same<NativeT, half>::value ||
490 std::is_same<NativeT, complex64>::value ||
491 std::is_same<NativeT, complex128>::value,
492 "Float or complex type required when specifying an ErrorSpec");
493 Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
494 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
495 arguments, error);
496 }
497
498 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)499 void ClientLibraryTestBase::ComputeAndCompareR2(
500 XlaBuilder* builder, const Array2D<NativeT>& expected,
501 absl::Span<GlobalData* const> arguments) {
502 Literal expected_literal =
503 LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
504 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
505 arguments);
506 }
507
508 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)509 void ClientLibraryTestBase::ComputeAndCompareR2(
510 XlaBuilder* builder, const Array2D<NativeT>& expected,
511 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
512 static_assert(std::is_same<NativeT, float>::value ||
513 std::is_same<NativeT, double>::value ||
514 std::is_same<NativeT, bfloat16>::value ||
515 std::is_same<NativeT, half>::value ||
516 std::is_same<NativeT, complex64>::value ||
517 std::is_same<NativeT, complex128>::value,
518 "Float or complex type required when specifying an ErrorSpec");
519 Literal expected_literal =
520 LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
521 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
522 arguments, error);
523 }
524
525 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments)526 void ClientLibraryTestBase::ComputeAndCompareR3(
527 XlaBuilder* builder, const Array3D<NativeT>& expected,
528 absl::Span<GlobalData* const> arguments) {
529 Literal expected_literal =
530 LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
531 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
532 arguments);
533 }
534
535 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)536 void ClientLibraryTestBase::ComputeAndCompareR3(
537 XlaBuilder* builder, const Array3D<NativeT>& expected,
538 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
539 static_assert(std::is_same<NativeT, float>::value ||
540 std::is_same<NativeT, double>::value ||
541 std::is_same<NativeT, bfloat16>::value ||
542 std::is_same<NativeT, half>::value ||
543 std::is_same<NativeT, complex64>::value ||
544 std::is_same<NativeT, complex128>::value,
545 "Float or complex type required when specifying an ErrorSpec");
546 Literal expected_literal =
547 LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
548 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
549 arguments, error);
550 }
551
552 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments)553 void ClientLibraryTestBase::ComputeAndCompareR4(
554 XlaBuilder* builder, const Array4D<NativeT>& expected,
555 absl::Span<GlobalData* const> arguments) {
556 Literal expected_literal =
557 LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
558 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
559 arguments);
560 }
561
562 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)563 void ClientLibraryTestBase::ComputeAndCompareR4(
564 XlaBuilder* builder, const Array4D<NativeT>& expected,
565 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
566 static_assert(std::is_same<NativeT, float>::value ||
567 std::is_same<NativeT, double>::value ||
568 std::is_same<NativeT, bfloat16>::value ||
569 std::is_same<NativeT, half>::value ||
570 std::is_same<NativeT, complex64>::value ||
571 std::is_same<NativeT, complex128>::value,
572 "Float or complex type required when specifying an ErrorSpec");
573 Literal expected_literal =
574 LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
575 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
576 arguments, error);
577 }
578
579 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments)580 void ClientLibraryTestBase::ComputeAndCompare(
581 XlaBuilder* builder, const Array<NativeT>& expected,
582 absl::Span<GlobalData* const> arguments) {
583 Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
584 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
585 arguments);
586 }
587
588 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)589 void ClientLibraryTestBase::ComputeAndCompare(
590 XlaBuilder* builder, const Array<NativeT>& expected,
591 absl::Span<GlobalData* const> arguments, ErrorSpec error) {
592 static_assert(std::is_same<NativeT, float>::value ||
593 std::is_same<NativeT, double>::value ||
594 std::is_same<NativeT, bfloat16>::value ||
595 std::is_same<NativeT, half>::value ||
596 std::is_same<NativeT, complex64>::value ||
597 std::is_same<NativeT, complex128>::value,
598 "Float or complex type required when specifying an ErrorSpec");
599 Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
600 ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
601 arguments, error);
602 }
603
604 template <typename NativeT>
CreateR0Parameter(NativeT value,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)605 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
606 NativeT value, int64_t parameter_number, const std::string& name,
607 XlaBuilder* builder, XlaOp* data_handle) {
608 Literal literal = LiteralUtil::CreateR0(value);
609 if (use_bfloat16_ && literal.shape().element_type() == F32) {
610 literal = LiteralUtil::ConvertF32ToBF16(literal);
611 }
612 std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
613 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
614 return data;
615 }
616
617 template <typename NativeT>
CreateR1Parameter(absl::Span<const NativeT> values,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)618 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
619 absl::Span<const NativeT> values, int64_t parameter_number,
620 const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
621 Literal literal = LiteralUtil::CreateR1(values);
622 if (use_bfloat16_ && literal.shape().element_type() == F32) {
623 literal = LiteralUtil::ConvertF32ToBF16(literal);
624 }
625 std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
626 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
627 return data;
628 }
629
630 template <typename NativeT>
CreateR2Parameter(const Array2D<NativeT> & array_2d,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)631 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
632 const Array2D<NativeT>& array_2d, int64_t parameter_number,
633 const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
634 Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
635 if (use_bfloat16_ && literal.shape().element_type() == F32) {
636 literal = LiteralUtil::ConvertF32ToBF16(literal);
637 }
638 std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
639 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
640 return data;
641 }
642
643 template <typename NativeT>
CreateR3Parameter(const Array3D<NativeT> & array_3d,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)644 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
645 const Array3D<NativeT>& array_3d, int64_t parameter_number,
646 const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
647 Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
648 if (use_bfloat16_ && literal.shape().element_type() == F32) {
649 literal = LiteralUtil::ConvertF32ToBF16(literal);
650 }
651 std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
652 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
653 return data;
654 }
655
656 template <typename NativeT>
CreateR4Parameter(const Array4D<NativeT> & array_4d,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)657 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
658 const Array4D<NativeT>& array_4d, int64_t parameter_number,
659 const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
660 Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d);
661 if (use_bfloat16_ && literal.shape().element_type() == F32) {
662 literal = LiteralUtil::ConvertF32ToBF16(literal);
663 }
664 std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
665 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
666 return data;
667 }
668
669 template <typename NativeT>
CreateParameter(const Array<NativeT> & array,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)670 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
671 const Array<NativeT>& array, int64_t parameter_number,
672 const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
673 Literal literal = LiteralUtil::CreateFromArray(array);
674 if (use_bfloat16_ && literal.shape().element_type() == F32) {
675 literal = LiteralUtil::ConvertF32ToBF16(literal);
676 }
677 std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
678 *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
679 return data;
680 }
681
682 template <typename NativeT>
CreatePseudorandomR1(const int width,NativeT min_value,NativeT max_value,uint32_t seed)683 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
684 const int width, NativeT min_value, NativeT max_value, uint32_t seed) {
685 std::vector<NativeT> result(width);
686 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
687 for (int i = 0; i < width; ++i) {
688 result[i] = generator.get();
689 }
690 return result;
691 }
692
693 template <typename NativeT>
CreatePseudorandomR2(const int rows,const int cols,NativeT min_value,NativeT max_value,uint32_t seed)694 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
695 const int rows, const int cols, NativeT min_value, NativeT max_value,
696 uint32_t seed) {
697 auto result = std::make_unique<Array2D<NativeT>>(rows, cols);
698 PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
699 for (int y = 0; y < rows; ++y) {
700 for (int x = 0; x < cols; ++x) {
701 (*result)(y, x) = generator.get();
702 }
703 }
704 return result;
705 }
706
707 } // namespace xla
708
709 #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
710