xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/client_library_test_base.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
18 
19 #include <memory>
20 #include <string>
21 #include <type_traits>
22 #include <vector>
23 
24 #include "absl/strings/string_view.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/array2d.h"
27 #include "tensorflow/compiler/xla/array3d.h"
28 #include "tensorflow/compiler/xla/array4d.h"
29 #include "tensorflow/compiler/xla/client/client_library.h"
30 #include "tensorflow/compiler/xla/client/global_data.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/literal.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/bitmap.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 
44 namespace xla {
45 
46 // Sets the use_bfloat16 on a container of test cases according to the values in
47 // use_bfloat16_params. Generates one set of test cases for each values in
48 // use_bfloat16_params with that value. Returns the result.
49 template <typename TestCase>
ExpandUseBfloat16(absl::Span<const bool> use_bfloat16_params,absl::Span<const TestCase> specs)50 std::vector<TestCase> ExpandUseBfloat16(
51     absl::Span<const bool> use_bfloat16_params,
52     absl::Span<const TestCase> specs) {
53   std::vector<TestCase> expanded;
54   for (bool use_bfloat16 : use_bfloat16_params) {
55     for (const auto& spec : specs) {
56       expanded.push_back(spec);
57       expanded.back().use_bfloat16 = use_bfloat16;
58     }
59   }
60   return expanded;
61 }
62 
63 // A client library test establishes an in-process XLA client connection.
64 class ClientLibraryTestBase : public ManifestCheckingTest {
65  protected:
66   explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
67 
68   // Creates a new ClientLibraryTestBase with custom client options.
69   ClientLibraryTestBase(se::Platform* platform,
70                         const LocalClientOptions& client_options);
71 
72   // Returns the name of the test currently being run.
73   std::string TestName() const;
74 
SetFastMathDisabled(bool disabled)75   void SetFastMathDisabled(bool disabled) {
76     auto* opts = execution_options_.mutable_debug_options();
77     opts->set_xla_cpu_enable_fast_math(!disabled);
78     opts->set_xla_cpu_enable_fast_min_max(!disabled);
79     opts->set_xla_gpu_enable_fast_min_max(!disabled);
80   }
81 
SetSeed(uint64_t seed)82   void SetSeed(uint64_t seed) { execution_options_.set_seed(seed); }
83 
84   // Provides mutable access to the execution DebugOptions field; this lets
85   // tests tweak the options that will be used to compile/run the graph.
mutable_debug_options()86   DebugOptions* mutable_debug_options() {
87     return execution_options_.mutable_debug_options();
88   }
89 
90   // TODO(b/25566808): Add helper that populates a literal from a testdata file.
91 
92   // Convenience methods for building and running a computation with the member
93   // execution options. Modify execution_options_ in your test if you want to
94   // customize the options.
95   StatusOr<std::unique_ptr<GlobalData>> Execute(
96       XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
97 
98   StatusOr<Literal> ExecuteAndTransfer(
99       XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
100       const Shape* shape_with_output_layout = nullptr);
101 
102   StatusOr<Literal> ExecuteAndTransfer(
103       const XlaComputation& computation,
104       absl::Span<GlobalData* const> arguments,
105       const Shape* shape_with_output_layout = nullptr);
106 
107   // This executes the computation via the reference client (which connects a
108   // interpreter backend). The result is used as the expected value of the
109   // computation.
110   StatusOr<Literal> ExecuteAndTransferReference(
111       const XlaComputation& computation,
112       absl::Span<GlobalData* const> arguments,
113       const Shape* shape_with_output_layout = nullptr);
114 
115   // Run a computation and return its value as a string. If an error
116   // occurs, then instead return the error as a string.
117   std::string ExecuteToString(XlaBuilder* builder,
118                               absl::Span<GlobalData* const> arguments);
119 
120   // Convenience methods for building and running a computation, transferring
121   // the result, and comparing it to the expected value(s). Methods are
122   // templated on the native host type which maps to specific XLA types (See
123   // XlaBuilder for details). For each rank, two forms are
124   // provided: one for floating point types with an ErrorSpec parameter, and one
125   // for integral types without the ErrorSpec parameter.
126   template <typename NativeT>
127   void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
128                            absl::Span<GlobalData* const> arguments);
129   template <typename NativeT>
130   void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
131                            absl::Span<GlobalData* const> arguments,
132                            ErrorSpec error);
133 
134   template <typename NativeT>
135   void ComputeAndCompareR1(XlaBuilder* builder,
136                            absl::Span<const NativeT> expected,
137                            absl::Span<GlobalData* const> arguments);
138   template <typename NativeT>
139   void ComputeAndCompareR1(XlaBuilder* builder,
140                            absl::Span<const NativeT> expected,
141                            absl::Span<GlobalData* const> arguments,
142                            ErrorSpec error);
143 
144   // As above, but uses a bitmap to hold the predicate vector to avoid
145   // deficiencies of vector<bool>.
146   void ComputeAndCompareR1(XlaBuilder* builder,
147                            const tensorflow::core::Bitmap& expected,
148                            absl::Span<GlobalData* const> arguments);
149 
150   template <typename NativeT>
151   void ComputeAndCompareR2(XlaBuilder* builder,
152                            const Array2D<NativeT>& expected,
153                            absl::Span<GlobalData* const> arguments);
154   template <typename NativeT>
155   void ComputeAndCompareR2(XlaBuilder* builder,
156                            const Array2D<NativeT>& expected,
157                            absl::Span<GlobalData* const> arguments,
158                            ErrorSpec error);
159 
160   template <typename NativeT>
161   void ComputeAndCompareR3(XlaBuilder* builder,
162                            const Array3D<NativeT>& expected,
163                            absl::Span<GlobalData* const> arguments);
164   template <typename NativeT>
165   void ComputeAndCompareR3(XlaBuilder* builder,
166                            const Array3D<NativeT>& expected,
167                            absl::Span<GlobalData* const> arguments,
168                            ErrorSpec error);
169 
170   template <typename NativeT>
171   void ComputeAndCompareR4(XlaBuilder* builder,
172                            const Array4D<NativeT>& expected,
173                            absl::Span<GlobalData* const> arguments);
174   template <typename NativeT>
175   void ComputeAndCompareR4(XlaBuilder* builder,
176                            const Array4D<NativeT>& expected,
177                            absl::Span<GlobalData* const> arguments,
178                            ErrorSpec error);
179 
180   // Build and run the computation and compare the result with the given
181   // literal. shape_with_layout indicates the result layout to request when
182   // calling Execute.
183   void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
184                                 absl::Span<GlobalData* const> arguments,
185                                 const Shape* shape_with_layout = nullptr);
186   void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected,
187                                 absl::Span<GlobalData* const> arguments,
188                                 ErrorSpec error,
189                                 const Shape* shape_with_layout = nullptr);
190 
191   // Build and run the computation and return the result as a literal.
192   // shape_with_layout indicates the result layout to request when calling
193   // Execute.
194   StatusOr<Literal> ComputeAndTransfer(
195       XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
196       const Shape* shape_with_layout = nullptr);
197 
198   // ComputeAndCompare variant which returns an error status.
199   Status ComputeAndCompareLiteralWithStatus(
200       XlaBuilder* builder, const Literal& expected,
201       absl::Span<GlobalData* const> arguments,
202       const Shape* shape_with_layout = nullptr);
203   Status ComputeAndCompareLiteralWithStatus(
204       XlaBuilder* builder, const Literal& expected,
205       absl::Span<GlobalData* const> arguments, ErrorSpec error,
206       const Shape* shape_with_layout = nullptr);
207 
208   // Compare the result of the computation to a strings. In XLA strings are
209   // represented using rank-1 U8 shapes.
210   void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected,
211                              absl::Span<GlobalData* const> arguments);
212 
213   // Convenience method for running a built computation, transferring the
214   // result, and comparing it to the expected tuple literal.
215   void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
216                               absl::Span<GlobalData* const> arguments);
217   void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected,
218                               absl::Span<GlobalData* const> arguments,
219                               ErrorSpec error);
220 
221   // Convenience method for running a built computation and comparing the result
222   // with the reference result.
223   void ComputeAndCompare(XlaBuilder* builder,
224                          absl::Span<const Literal> arguments);
225   void ComputeAndCompare(XlaBuilder* builder,
226                          absl::Span<const Literal> arguments, ErrorSpec error);
227   template <typename NativeT>
228   void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
229                          absl::Span<GlobalData* const> arguments);
230   template <typename NativeT>
231   void ComputeAndCompare(XlaBuilder* builder, const Array<NativeT>& expected,
232                          absl::Span<GlobalData* const> arguments,
233                          ErrorSpec error);
234   // Create scalar operations for use in reductions.
235   XlaComputation CreateScalarRelu();
236   XlaComputation CreateScalarMax();
237   XlaComputation CreateScalarReluSensitivity();
238 
239   // Special case convenience functions for creating filled arrays.
240 
241   // Creates an array of pseudorandom values lying between the given minimum and
242   // maximum values.
243   template <typename NativeT>
244   std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
245                                             NativeT max_value, uint32_t seed);
246   template <typename NativeT>
247   std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
248                                                          const int cols,
249                                                          NativeT min_value,
250                                                          NativeT max_value,
251                                                          uint32_t seed);
252 
253   // Creates a (rows x cols) array filled in the following form:
254   //
255   //  [      0              1 ...                   cols-1]
256   //  [  1,000          1,001 ...          1000.0 + cols-1]
257   //  [    ...            ... ...                      ...]
258   //  [(rows-1)*1000.0    ... ... (rows-1)*1000.0 + cols-1]
259   //
260   // If provided, offset is added uniformly to every element (e.g. an offset of
261   // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
262   std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
263                                                         const int cols,
264                                                         float offset = 0.0);
265 
266   // Creates a (rows x cols) array as above, padded out to
267   // (rows_padded x cols_padded) with zeroes.  Requires rows_padded >= rows
268   // and cols_padded > cols.
269   std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
270       const int rows, const int cols, const int rows_padded,
271       const int cols_padded);
272 
273   // Creates a parameter instruction, transfers the literal for the parameter to
274   // server, then stores into "data_handle" the global handle for that
275   // parameter. When the use_bfloat16 flag is set but the literal has F32
276   // elements, the literal will be converted to BF16 before being transferred.
277   StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
278       int64_t parameter_number, const Literal& literal, const std::string& name,
279       XlaBuilder* builder, XlaOp* data_handle);
280 
281   // As above, but the caller can specify the device that the literal is
282   // transferred to. If device_handle is nullptr, the literal will be
283   // transferred to the default device.
284   StatusOr<std::unique_ptr<GlobalData>> CreateParameterAndTransferLiteral(
285       int64_t parameter_number, const Literal& literal, const std::string& name,
286       const DeviceHandle* device_handle, XlaBuilder* builder,
287       XlaOp* data_handle);
288 
289   // Creates a parameter instruction and sets the value that will be passed to
290   // the computation as specified. This function must be used for all parameters
291   // or none and no parameters must be passed when invoking the computation if
292   // using this mechanism. If using this mechanism, then each parameter must be
293   // set exactly once. The first added parameter gets index 0, then 1 and so on.
294   XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
295 
296   template <class T>
AddParam(const Array<T> & argument,XlaBuilder * builder)297   XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
298     return AddParam(LiteralUtil::CreateFromArray(argument), builder);
299   }
300 
301   // Creates a constant instruction with the given literal. When the
302   // use_bfloat16 flag is set but the literal has F32 elements, the elements
303   // will be converted to BF16s.
304   XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
305 
306   // Creates a constant instruction with the given array. When the use_bfloat16
307   // flag is set but the array has float elements, the elements will be
308   // converted to bfloat16s.
309 
310   template <typename NativeT>
CreateConstantFromArray(const Array<NativeT> & array,XlaBuilder * builder)311   XlaOp CreateConstantFromArray(const Array<NativeT>& array,
312                                 XlaBuilder* builder) {
313     return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
314                                      builder);
315   }
316 
317   // Same as CreateConstantFromArray, but for scalars.
318   template <typename NativeT>
CreateConstantFromScalar(NativeT value,XlaBuilder * builder)319   XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
320     return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
321                                      builder);
322   }
323 
324   // Creates a parameter instruction that wraps a given value and then stores
325   // into "data_handle" the global handle for that parameter.
326   //
327   // "parameter_number" is the parameter number.
328   // "name" is the name of the parameter instruction.
329   //
330   // When the use_bfloat16 flag is set but NativeT is float, the data will be
331   // converted to bfloat16.
332   template <typename NativeT>
333   std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
334                                                 int64_t parameter_number,
335                                                 const std::string& name,
336                                                 XlaBuilder* builder,
337                                                 XlaOp* data_handle);
338 
339   // Creates a parameter instruction that wraps the given values and then stores
340   // into "data_handle" the global handle for that parameter.
341   //
342   // "parameter_number" is the parameter number.
343   // "name" is the name of the parameter instruction.
344   //
345   // When the use_bfloat16 flag is set but NativeT is float, the data will be
346   // converted to bfloat16.
347   template <typename NativeT>
348   std::unique_ptr<GlobalData> CreateR1Parameter(
349       absl::Span<const NativeT> values, int64_t parameter_number,
350       const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
351 
352   // Creates a parameter instruction that wraps the given constant array
353   // "array_2d" and then stores it to the global handle for that parameter
354   // "data_handle".
355   //
356   // "parameter_number" is the parameter number.
357   // "name" is the name of the parameter instruction.
358   //
359   // When the use_bfloat16 flag is set but NativeT is float, the data will be
360   // converted to bfloat16.
361   template <typename NativeT>
362   std::unique_ptr<GlobalData> CreateR2Parameter(
363       const Array2D<NativeT>& array_2d, int64_t parameter_number,
364       const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
365 
366   // Creates a parameter instruction that wraps the given constant array
367   // "array_3d" and then stores it to the global handle for that parameter
368   // "data_handle".
369   //
370   // "parameter_number" is the parameter number.
371   // "name" is the name of the parameter instruction.
372   //
373   // When the use_bfloat16 flag is set but NativeT is float, the data will be
374   // converted to bfloat16.
375   template <typename NativeT>
376   std::unique_ptr<GlobalData> CreateR3Parameter(
377       const Array3D<NativeT>& array_3d, int64_t parameter_number,
378       const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
379 
380   // Creates a parameter instruction that wraps the given constant array
381   // "array_4d" and then stores it to the global handle for that parameter
382   // "data_handle".
383   //
384   // "parameter_number" is the parameter number.
385   // "name" is the name of the parameter instruction.
386   //
387   // When the use_bfloat16 flag is set but NativeT is float, the data will be
388   // converted to bfloat16.
389   template <typename NativeT>
390   std::unique_ptr<GlobalData> CreateR4Parameter(
391       const Array4D<NativeT>& array_4d, int64_t parameter_number,
392       const std::string& name, XlaBuilder* builder, XlaOp* data_handle);
393 
394   template <typename NativeT>
395   std::unique_ptr<GlobalData> CreateParameter(const Array<NativeT>& array_4d,
396                                               int64_t parameter_number,
397                                               const std::string& name,
398                                               XlaBuilder* builder,
399                                               XlaOp* data_handle);
400 
401   // Getter and setter for the use_bfloat16 flag, which indicates whether to run
402   // tests with all float-type input/output converted to bfloat16.
use_bfloat16()403   bool use_bfloat16() const { return use_bfloat16_; }
set_use_bfloat16(bool value)404   void set_use_bfloat16(bool value) { use_bfloat16_ = value; }
405 
406   // The float type used in this test, BF16 or F32 according to use_bfloat16.
FloatType()407   PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
408 
409   // Executes the computation and calculates the expected reference value using
410   // the reference client. Returns two literals in the order of (expected,
411   // actual).
412   StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
413       XlaBuilder* builder, absl::Span<const Literal> arguments);
414 
415   // Converts an f32 literal to bf16 if use_bfloat16_ is true.
416   Literal MaybeConvertLiteralToBfloat16(const Literal& literal);
417 
418   LocalClient* client_;
419   LocalClient* ref_client_;  // To compute reference result.
420   ExecutionOptions execution_options_;
421 
422  private:
423   Status ComputeAndCompareLiteralWithAllOutputLayouts(
424       const xla::XlaComputation& computation, const Literal& expected,
425       absl::Span<GlobalData* const> arguments,
426       const std::function<void(const Literal& actual,
427                                const std::string& error_message)>&
428           verify_output);
429   Status ComputeAndCompareLiteralWithAllInputLayouts(
430       const xla::XlaComputation& computation, const Literal& expected,
431       absl::Span<GlobalData* const> arguments,
432       const std::function<void(const Literal& actual,
433                                const std::string& error_message)>&
434           verify_output,
435       const Shape* output_with_layout = nullptr);
436 
437   // Converts an f32 shape to bf16 if use_bfloat16_ is true.
438   Shape MaybeConvertShapeToBfloat16(const Shape& shape);
439 
440   // Whether to run tests with all float-type input/output converted to
441   // bfloat16.
442   bool use_bfloat16_ = false;
443 
444   // Arguments to be passed to the computation when it runs.
445   std::vector<Literal> arguments_;
446 };
447 
448 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments)449 void ClientLibraryTestBase::ComputeAndCompareR0(
450     XlaBuilder* builder, NativeT expected,
451     absl::Span<GlobalData* const> arguments) {
452   Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
453   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
454                                                   arguments);
455 }
456 
457 template <typename NativeT>
ComputeAndCompareR0(XlaBuilder * builder,NativeT expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)458 void ClientLibraryTestBase::ComputeAndCompareR0(
459     XlaBuilder* builder, NativeT expected,
460     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
461   static_assert(std::is_same<NativeT, float>::value ||
462                     std::is_same<NativeT, double>::value ||
463                     std::is_same<NativeT, bfloat16>::value ||
464                     std::is_same<NativeT, half>::value ||
465                     std::is_same<NativeT, complex64>::value ||
466                     std::is_same<NativeT, complex128>::value,
467                 "Float or complex type required when specifying an ErrorSpec");
468   Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
469   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
470                                                   arguments, error);
471 }
472 
473 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments)474 void ClientLibraryTestBase::ComputeAndCompareR1(
475     XlaBuilder* builder, absl::Span<const NativeT> expected,
476     absl::Span<GlobalData* const> arguments) {
477   Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
478   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
479                                                   arguments);
480 }
481 
482 template <typename NativeT>
ComputeAndCompareR1(XlaBuilder * builder,absl::Span<const NativeT> expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)483 void ClientLibraryTestBase::ComputeAndCompareR1(
484     XlaBuilder* builder, absl::Span<const NativeT> expected,
485     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
486   static_assert(std::is_same<NativeT, float>::value ||
487                     std::is_same<NativeT, double>::value ||
488                     std::is_same<NativeT, bfloat16>::value ||
489                     std::is_same<NativeT, half>::value ||
490                     std::is_same<NativeT, complex64>::value ||
491                     std::is_same<NativeT, complex128>::value,
492                 "Float or complex type required when specifying an ErrorSpec");
493   Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
494   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
495                                                   arguments, error);
496 }
497 
498 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)499 void ClientLibraryTestBase::ComputeAndCompareR2(
500     XlaBuilder* builder, const Array2D<NativeT>& expected,
501     absl::Span<GlobalData* const> arguments) {
502   Literal expected_literal =
503       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
504   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
505                                                   arguments);
506 }
507 
508 template <typename NativeT>
ComputeAndCompareR2(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)509 void ClientLibraryTestBase::ComputeAndCompareR2(
510     XlaBuilder* builder, const Array2D<NativeT>& expected,
511     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
512   static_assert(std::is_same<NativeT, float>::value ||
513                     std::is_same<NativeT, double>::value ||
514                     std::is_same<NativeT, bfloat16>::value ||
515                     std::is_same<NativeT, half>::value ||
516                     std::is_same<NativeT, complex64>::value ||
517                     std::is_same<NativeT, complex128>::value,
518                 "Float or complex type required when specifying an ErrorSpec");
519   Literal expected_literal =
520       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
521   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
522                                                   arguments, error);
523 }
524 
525 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments)526 void ClientLibraryTestBase::ComputeAndCompareR3(
527     XlaBuilder* builder, const Array3D<NativeT>& expected,
528     absl::Span<GlobalData* const> arguments) {
529   Literal expected_literal =
530       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
531   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
532                                                   arguments);
533 }
534 
535 template <typename NativeT>
ComputeAndCompareR3(XlaBuilder * builder,const Array3D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)536 void ClientLibraryTestBase::ComputeAndCompareR3(
537     XlaBuilder* builder, const Array3D<NativeT>& expected,
538     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
539   static_assert(std::is_same<NativeT, float>::value ||
540                     std::is_same<NativeT, double>::value ||
541                     std::is_same<NativeT, bfloat16>::value ||
542                     std::is_same<NativeT, half>::value ||
543                     std::is_same<NativeT, complex64>::value ||
544                     std::is_same<NativeT, complex128>::value,
545                 "Float or complex type required when specifying an ErrorSpec");
546   Literal expected_literal =
547       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
548   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
549                                                   arguments, error);
550 }
551 
552 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments)553 void ClientLibraryTestBase::ComputeAndCompareR4(
554     XlaBuilder* builder, const Array4D<NativeT>& expected,
555     absl::Span<GlobalData* const> arguments) {
556   Literal expected_literal =
557       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
558   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
559                                                   arguments);
560 }
561 
562 template <typename NativeT>
ComputeAndCompareR4(XlaBuilder * builder,const Array4D<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)563 void ClientLibraryTestBase::ComputeAndCompareR4(
564     XlaBuilder* builder, const Array4D<NativeT>& expected,
565     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
566   static_assert(std::is_same<NativeT, float>::value ||
567                     std::is_same<NativeT, double>::value ||
568                     std::is_same<NativeT, bfloat16>::value ||
569                     std::is_same<NativeT, half>::value ||
570                     std::is_same<NativeT, complex64>::value ||
571                     std::is_same<NativeT, complex128>::value,
572                 "Float or complex type required when specifying an ErrorSpec");
573   Literal expected_literal =
574       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
575   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
576                                                   arguments, error);
577 }
578 
579 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments)580 void ClientLibraryTestBase::ComputeAndCompare(
581     XlaBuilder* builder, const Array<NativeT>& expected,
582     absl::Span<GlobalData* const> arguments) {
583   Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
584   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
585                                                   arguments);
586 }
587 
588 template <typename NativeT>
ComputeAndCompare(XlaBuilder * builder,const Array<NativeT> & expected,absl::Span<GlobalData * const> arguments,ErrorSpec error)589 void ClientLibraryTestBase::ComputeAndCompare(
590     XlaBuilder* builder, const Array<NativeT>& expected,
591     absl::Span<GlobalData* const> arguments, ErrorSpec error) {
592   static_assert(std::is_same<NativeT, float>::value ||
593                     std::is_same<NativeT, double>::value ||
594                     std::is_same<NativeT, bfloat16>::value ||
595                     std::is_same<NativeT, half>::value ||
596                     std::is_same<NativeT, complex64>::value ||
597                     std::is_same<NativeT, complex128>::value,
598                 "Float or complex type required when specifying an ErrorSpec");
599   Literal expected_literal = LiteralUtil::CreateFromArray<NativeT>(expected);
600   ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
601                                                   arguments, error);
602 }
603 
604 template <typename NativeT>
CreateR0Parameter(NativeT value,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)605 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
606     NativeT value, int64_t parameter_number, const std::string& name,
607     XlaBuilder* builder, XlaOp* data_handle) {
608   Literal literal = LiteralUtil::CreateR0(value);
609   if (use_bfloat16_ && literal.shape().element_type() == F32) {
610     literal = LiteralUtil::ConvertF32ToBF16(literal);
611   }
612   std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
613   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
614   return data;
615 }
616 
617 template <typename NativeT>
CreateR1Parameter(absl::Span<const NativeT> values,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)618 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
619     absl::Span<const NativeT> values, int64_t parameter_number,
620     const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
621   Literal literal = LiteralUtil::CreateR1(values);
622   if (use_bfloat16_ && literal.shape().element_type() == F32) {
623     literal = LiteralUtil::ConvertF32ToBF16(literal);
624   }
625   std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
626   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
627   return data;
628 }
629 
630 template <typename NativeT>
CreateR2Parameter(const Array2D<NativeT> & array_2d,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)631 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
632     const Array2D<NativeT>& array_2d, int64_t parameter_number,
633     const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
634   Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
635   if (use_bfloat16_ && literal.shape().element_type() == F32) {
636     literal = LiteralUtil::ConvertF32ToBF16(literal);
637   }
638   std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
639   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
640   return data;
641 }
642 
643 template <typename NativeT>
CreateR3Parameter(const Array3D<NativeT> & array_3d,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)644 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
645     const Array3D<NativeT>& array_3d, int64_t parameter_number,
646     const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
647   Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
648   if (use_bfloat16_ && literal.shape().element_type() == F32) {
649     literal = LiteralUtil::ConvertF32ToBF16(literal);
650   }
651   std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
652   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
653   return data;
654 }
655 
656 template <typename NativeT>
CreateR4Parameter(const Array4D<NativeT> & array_4d,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)657 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR4Parameter(
658     const Array4D<NativeT>& array_4d, int64_t parameter_number,
659     const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
660   Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d);
661   if (use_bfloat16_ && literal.shape().element_type() == F32) {
662     literal = LiteralUtil::ConvertF32ToBF16(literal);
663   }
664   std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
665   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
666   return data;
667 }
668 
669 template <typename NativeT>
CreateParameter(const Array<NativeT> & array,int64_t parameter_number,const std::string & name,XlaBuilder * builder,XlaOp * data_handle)670 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateParameter(
671     const Array<NativeT>& array, int64_t parameter_number,
672     const std::string& name, XlaBuilder* builder, XlaOp* data_handle) {
673   Literal literal = LiteralUtil::CreateFromArray(array);
674   if (use_bfloat16_ && literal.shape().element_type() == F32) {
675     literal = LiteralUtil::ConvertF32ToBF16(literal);
676   }
677   std::unique_ptr<GlobalData> data = client_->TransferToServer(literal).value();
678   *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
679   return data;
680 }
681 
682 template <typename NativeT>
CreatePseudorandomR1(const int width,NativeT min_value,NativeT max_value,uint32_t seed)683 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
684     const int width, NativeT min_value, NativeT max_value, uint32_t seed) {
685   std::vector<NativeT> result(width);
686   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
687   for (int i = 0; i < width; ++i) {
688     result[i] = generator.get();
689   }
690   return result;
691 }
692 
693 template <typename NativeT>
CreatePseudorandomR2(const int rows,const int cols,NativeT min_value,NativeT max_value,uint32_t seed)694 std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
695     const int rows, const int cols, NativeT min_value, NativeT max_value,
696     uint32_t seed) {
697   auto result = std::make_unique<Array2D<NativeT>>(rows, cols);
698   PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
699   for (int y = 0; y < rows; ++y) {
700     for (int x = 0; x < cols; ++x) {
701       (*result)(y, x) = generator.get();
702     }
703   }
704   return result;
705 }
706 
707 }  // namespace xla
708 
709 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
710