/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Randomized tests for XLA implementations of Tensorflow operations.
//
// For each operator, the tests in this file choose a set of random inputs and
// attributes. The test then compares the outputs of the operator when executed
// via Tensorflow using the CPU device and when executed via XLA.
//
// By default, each test chooses a random seed nondeterministically (using
// std::random_device). However, a particular choice of random seed can be
// forced using the flag --tf_xla_random_seed; each test logs the
// flag value necessary to reproduce its outputs.
//
// Example usage:
// Run tests, comparing the Tensorflow CPU operators with their XLA-compiled
// counterparts:
// randomized_tests \
//   --tf_xla_test_use_jit=true --tf_xla_test_device=CPU:0 \
//   --tf_xla_test_repetitions=20

// TODO(phawkins): add tests for:
// * DepthwiseConv2DNative
// * Gather
// * InvertPermutation
// * MaxPoolGrad (requires implementation of forward operator)
// * Select
// * Unpack
//
// TODO(phawkins): improve tests for:
// * StridedSliceGrad (need to use shape function to compute sensible inputs)

#include <algorithm>
#include <array>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <random>
#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/fixed_array.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
namespace {

// Command line flags: see main() below.
int64_t tf_xla_random_seed = 0;
int32_t tf_xla_test_repetitions = 20;
int64_t tf_xla_max_tensor_size = 10000LL;
string* tf_xla_test_device_ptr;       // initial value set in main()
string* tf_xla_reference_device_ptr;  // initial value set in main()
bool tf_xla_test_use_jit = true;
bool tf_xla_test_use_mlir = false;

string LocalDeviceToFullDeviceName(const string& device) {
  return absl::StrCat("/job:localhost/replica:0/task:0/device:", device);
}

constexpr std::array<DataType, 5> kAllXlaTypes = {
    {DT_INT32, DT_INT64, DT_FLOAT, DT_BOOL, DT_COMPLEX64}};
constexpr std::array<DataType, 4> kAllNumberTypes = {
    {DT_INT32, DT_INT64, DT_FLOAT, DT_COMPLEX64}};

// An OpTestBuilder is a graph builder class that takes as input an operator to
// test, its inputs and attributes, and builds a graph that executes the
// operator.
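//
// For example (mirroring OpTest.Abs below), a test case for the "Abs"
// operator can be described as:
//   OpTestBuilder("Abs").RandomInput(type).Attr("T", type)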
class OpTestBuilder {
 public:
  explicit OpTestBuilder(const string& op_name);

  // Adds an input 'tensor' as a Placeholder node.
  OpTestBuilder& Input(const Tensor& tensor);

  // Adds a random input tensor with 'type' as a Placeholder node.
  // If 'dims' is not provided, RandomDims() is used.
  OpTestBuilder& RandomInput(DataType type);
  OpTestBuilder& RandomInput(DataType type, std::vector<int64_t> dims);

  // As RandomInput but the values are unique.
  OpTestBuilder& RandomUniqueInput(DataType type, std::vector<int64_t> dims);

  // Adds variadic input tensors as Placeholder nodes.
  OpTestBuilder& VariadicInput(const std::vector<Tensor>& tensor);

  // Sets an attribute.
  template <class T>
  OpTestBuilder& Attr(absl::string_view attr_name, T&& value);

  // Overload needed to allow {...} expressions for value.
  template <class T>
  OpTestBuilder& Attr(absl::string_view attr_name,
                      std::initializer_list<T> value);

  // Adds nodes that execute the operator under test on 'device' to 'graphdef'.
  // If 'use_jit' is true, marks the operator under test to be compiled by XLA.
  // The graph will consist of one Placeholder node per input, the operator
  // itself, and one Identity node per output. If 'test_node_def' is not null,
  // sets it to the NodeDef of the operator under test. Fills 'inputs' and
  // 'outputs' with the names of the input placeholder nodes and the output
  // identity nodes, respectively.
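  //
  // The generated subgraph has the form
  //   <name_prefix>_input_0, ... (Placeholder)
  //       -> <name_prefix>_op_under_test
  //       -> <name_prefix>_output_0, ... (Identity)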
  Status BuildGraph(const string& name_prefix, const string& device,
                    bool use_jit, GraphDef* graphdef, NodeDef** test_node_def,
                    std::vector<string>* inputs,
                    std::vector<string>* outputs) const;

  struct InputDescription {
    Tensor tensor;

    DataType type = DT_INVALID;
    bool has_dims = false;
    bool needs_unique_values = false;
    std::vector<int64_t> dims;
  };

  const std::vector<InputDescription>& inputs() const { return inputs_; }

 private:
  NodeDef node_def_;
  std::vector<InputDescription> inputs_;
};

OpTestBuilder::OpTestBuilder(const string& op_name) {
  node_def_.set_op(op_name);
}

OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) {
  VLOG(1) << "Adding input: " << tensor.DebugString();
  InputDescription input;
  input.tensor = tensor;
  inputs_.push_back(input);
  return *this;
}

OpTestBuilder& OpTestBuilder::RandomInput(DataType type) {
  VLOG(1) << "Adding random input: " << type;
  InputDescription input;
  input.type = type;
  inputs_.push_back(input);
  return *this;
}

OpTestBuilder& OpTestBuilder::RandomInput(DataType type,
                                          std::vector<int64_t> dims) {
  VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
  InputDescription input;
  input.type = type;
  input.has_dims = true;
  input.dims = std::move(dims);
  inputs_.push_back(input);
  return *this;
}

OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
                                                std::vector<int64_t> dims) {
  VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
  InputDescription input;
  input.type = type;
  input.has_dims = true;
  input.needs_unique_values = true;
  input.dims = std::move(dims);
  inputs_.push_back(input);
  return *this;
}

OpTestBuilder& OpTestBuilder::VariadicInput(
    const std::vector<Tensor>& tensors) {
  VLOG(1) << "Adding variadic input of length " << tensors.size() << ":";
  for (auto& t : tensors) {
    Input(t);
  }
  return *this;
}

template <class T>
OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) {
  AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
  return *this;
}

template <class T>
OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name,
                                   std::initializer_list<T> value) {
  Attr<std::initializer_list<T>>(attr_name, std::move(value));
  return *this;
}

Status OpTestBuilder::BuildGraph(const string& name_prefix,
                                 const string& device, bool use_jit,
                                 GraphDef* graphdef, NodeDef** test_node_def,
                                 std::vector<string>* inputs,
                                 std::vector<string>* outputs) const {
  OpRegistryInterface* op_registry = OpRegistry::Global();

  const OpDef* op_def;
  TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def_.op(), &op_def));

  NodeDef* test_def = graphdef->add_node();
  *test_def = node_def_;
  test_def->set_name(absl::StrCat(name_prefix, "_op_under_test"));
  test_def->set_device(device);
  AddDefaultsToNodeDef(*op_def, test_def);
  if (use_jit) {
    AddNodeAttr(kXlaCompileAttr, true, test_def);
  }
  VLOG(1) << "Op under test: " << test_def->DebugString();

  DataTypeVector input_types, output_types;
  TF_RETURN_IF_ERROR(
      InOutTypesForNode(*test_def, *op_def, &input_types, &output_types));

  // Build feed and fetch nodes.
  for (int i = 0; i < input_types.size(); ++i) {
    NodeDef* def = graphdef->add_node();
    string name = absl::StrCat(name_prefix, "_input_", i);
    TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
                           .Device(device)
                           .Attr("dtype", input_types[i])
                           .Finalize(def));
    inputs->push_back(name);
    test_def->add_input(name);
  }

  for (int i = 0; i < output_types.size(); ++i) {
    NodeDef* def = graphdef->add_node();
    string name = absl::StrCat(name_prefix, "_output_", i);
    TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
                           .Device(device)
                           .Attr("T", output_types[i])
                           .Input(test_def->name(), i, output_types[i])
                           .Finalize(def));
    outputs->push_back(name);
  }

  if (test_node_def) {
    *test_node_def = test_def;
  }

  return OkStatus();
}

// Test fixture. The fixture manages the random number generator and its seed,
// and has a number of convenience methods for building random Tensors, shapes,
// etc.
class OpTest : public ::testing::Test {
 public:
  OpTest();

  enum TestResult {
    // The test saw an unrecoverable error. Don't try any more runs.
    kFatalError,
    // The parameters of the test were invalid (e.g., the "golden"
    // implementation failed, or the parameters are oversize). Reruns are ok.
    kInvalid,
    // The test ran successfully, and we have a verdict. Does *not* mean the
    // test passed.
    kOk,
  };

  // Runs 'fn' up to --tf_xla_test_repetitions times, or until a test failure
  // occurs; whichever happens first. Reruns if the TestResult is kInvalid.
  void Repeatedly(const std::function<TestResult(void)>& fn);

  // Select a random element from 'candidates'.
  template <typename T>
  T Choose(absl::Span<const T> candidates);

  static constexpr int kDefaultMaxRank = 5;
  static constexpr int64_t kDefaultMaxDimensionSize = 256LL;

  // Returns true if the product of 'dims' is less than tf_xla_max_tensor_size.
  bool TensorSizeIsOk(absl::Span<const int64_t> dims);

  // Returns a random dimension size, in the range [min, max).
  int64_t RandomDim(int64_t min = 0, int64_t max = kDefaultMaxDimensionSize);

  // Returns a random shape. The tensor has rank in the range
  // [min_rank, max_rank]. Each dimension has size in [min_size, max_size).
  std::vector<int64_t> RandomDims(int min_rank = 0,
                                  int max_rank = kDefaultMaxRank,
                                  int64_t min_size = 0,
                                  int64_t max_size = kDefaultMaxDimensionSize);

  // Given a shape 'dims', build dimensions that are broadcastable to 'dims'.
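  // For example, for dims = {3, 4} this may return {3, 4}, {1, 4}, {3, 1},
  // {1, 1}, {4}, or {1}.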
  std::vector<int64_t> BroadcastableToDims(std::vector<int64_t> dims);

  // Given a shape 'dims', build a pair of dimensions such that one broadcasts
  // to the other.
  std::pair<std::vector<int64_t>, std::vector<int64_t>> BroadcastableDims(
      std::vector<int64_t> dims);

  // Builds a random pair of broadcastable dims.
  // TODO(phawkins): currently the maximum rank is 3, because broadcasting > 3
  // dimensions is unimplemented by the Tensorflow Eigen code (b/29268487)
  std::pair<std::vector<int64_t>, std::vector<int64_t>> BroadcastableDims();

  // Returns a tensor filled with random but "reasonable" values from the middle
  // of the type's range. If the shape is omitted, a random shape is used.
  // TODO(phawkins): generalize this code to a caller-supplied distribution.
  Tensor RandomTensor(DataType dtype, bool needs_unique_values,
                      absl::Span<const int64_t> shape);
  Tensor RandomTensor(DataType dtype);

  // Like RandomTensor, but uses values >= 0.
  Tensor RandomNonNegativeTensor(DataType dtype,
                                 absl::Span<const int64_t> shape);
  Tensor RandomNonNegativeTensor(DataType dtype);

  // Like RandomTensor, but all values are in the range [lo, hi].
  template <typename T>
  Tensor RandomBoundedTensor(DataType dtype, T lo, T hi,
                             bool needs_unique_values,
                             absl::Span<const int64_t> shape);
  template <typename T>
  Tensor RandomBoundedTensor(DataType dtype, T lo, T hi,
                             bool needs_unique_values);

  // Like RandomTensor, but the value at index i is in the range [lo[i], hi[i]].
  Tensor RandomBoundedTensor(DataType dtype, Tensor lo, Tensor hi);

  // Like RandomTensor, but return a pair {left, right} with
  // left[i] <= right[i].
  std::pair<Tensor, Tensor> RandomLteTensors(DataType dtype,
                                             absl::Span<const int64_t> shape);
  std::pair<Tensor, Tensor> RandomLteTensors(DataType dtype);

  // Returns a random subset of the integers in the range [0, rank), suitable
  // for use as reduction indices.
  Tensor RandomReductionIndices(int rank);

  // Returns a random bit.
  bool RandomBool();

  // Randomly choose a seed for a random number generator.
  int64_t RandomSeed();

  struct WindowedSpatialDims {
    Padding padding;
    std::vector<int64_t> kernel_dims;
    std::vector<int64_t> stride_dims;
    std::vector<int64_t> input_dims;
    std::vector<int64_t> output_dims;
  };
  // Choose spatial dimensions for a windowed op such as pooling or convolution.
  WindowedSpatialDims ChooseWindowedSpatialDims(int num_spatial_dims);

  struct BatchMatMulArguments {
    std::vector<int64_t> lhs_dims;
    std::vector<int64_t> rhs_dims;
    DataType dtype;
    bool adj_lhs;
    bool adj_rhs;
  };
  // Choose arguments for the tf.BatchMatMul{V2} ops.
  BatchMatMulArguments ChooseBatchMatMulArguments(bool broadcastable_batch);

  struct ConcatArguments {
    std::vector<Tensor> values;
    Tensor axis;
    int n;
    DataType type;
    DataType type_idx;
  };
  // Choose arguments for the tf.Concat{V2} ops.
  ConcatArguments ChooseConcatArguments(bool int64_idx_allowed);

  struct EinsumArguments {
    std::vector<int64_t> lhs_dims;
    std::vector<int64_t> rhs_dims;
    DataType type;
    std::string equation;
  };
  // Choose arguments for the tf.{Xla}Einsum ops.
  EinsumArguments ChooseEinsumArguments();

  struct GatherArguments {
    int64_t batch_dims;
    DataType axis_type;
    DataType indices_type;
    DataType params_type;
    std::vector<int64_t> params_shape;
    Tensor indices;
    Tensor axis;
  };
  // Choose arguments for the tf.Gather{V2} ops.
  GatherArguments ChooseGatherArguments(bool axis_0);

  struct PadArguments {
    DataType input_type;
    DataType paddings_type;
    std::vector<int64_t> input_shape;
    Tensor paddings;
    Tensor constant_values;
  };
  // Choose arguments for the tf.Pad{V2} ops.
  PadArguments ChoosePadArguments();

  struct ScatterArguments {
    DataType type;
    DataType indices_type;
    Tensor indices;
    Tensor updates;
    std::vector<int64_t> shape;
  };
  // Choose arguments for ScatterNd and TensorScatterUpdate.
  ScatterArguments ChooseScatterArguments();

  struct SliceArguments {
    DataType type;
    DataType indices_type;
    std::vector<int64_t> shape;
    Tensor indices;
    std::vector<int64_t> size;
  };
  // Choose arguments for the tf.{XlaDynamicUpdate}Slice ops.
  SliceArguments ChooseSliceArguments(bool neg_one_size);

  struct XlaDotArguments {
    std::vector<int64_t> lhs_dims;
    std::vector<int64_t> rhs_dims;
    std::string dnums_encoded;
    std::string precision_config_encoded;
    DataType dtype;
  };
  // Choose arguments for tf.XlaDot operation.
  XlaDotArguments ChooseXlaDotArguments();

  // Builds dimensions for a windowed op such as pooling or convolution,
  // including a batch and feature dimension.
  std::vector<int64_t> ImageDims(TensorFormat format, int batch, int feature,
                                 const std::vector<int64_t>& spatial_dims);

  // Converts an int64 vector to an int32 vector.
  std::vector<int32> AsInt32s(const std::vector<int64_t>& int64s);

  std::mt19937& generator() { return *generator_; }

  // Run the test case described by 'builder' with and without XLA and check
  // that the outputs are close. Tensors x and y are close if they have the same
  // type, same shape, and have close values. For floating-point tensors, the
  // element-wise difference between x and y must be no more than
  // atol + rtol * abs(x); or both elements may be NaN or infinity. For
  // non-floating-point tensors the element values must match exactly.
  TestResult ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
                                           double atol = 1e-2,
                                           double rtol = 1e-2);

 protected:
  // Per-test state:
  std::unique_ptr<std::mt19937> generator_;

  std::unique_ptr<Session> session_;

  // Number of test cases built in 'session_'. Used to uniquify node names.
  int num_tests_ = 0;
};

OpTest::OpTest() {
  // Creates a random-number generator for the test case. Use the value of
  // --tf_xla_random_seed as the seed, if provided.
  int64_t s = tf_xla_random_seed;
  unsigned int seed;
  if (s <= 0) {
    std::random_device random_device;
    seed = random_device();
  } else {
    seed = static_cast<unsigned int>(s);
  }
  LOG(ERROR) << "Random seed for test case: " << seed
             << ". To reproduce the "
                "results of this test, pass flag --tf_xla_random_seed="
             << seed;
  generator_.reset(new std::mt19937(seed));

  // Create a session with an empty graph.
  SessionOptions session_options;
  session_.reset(NewSession(session_options));
  GraphDef def;
  TF_CHECK_OK(session_->Create(def));
}

namespace {
template <typename T>
Tensor TensorFromValues(DataType dtype, absl::Span<const int64_t> shape,
                        absl::Span<T> vals) {
  Tensor tensor(dtype, TensorShape(shape));
  test::FillValues<T>(&tensor, vals);
  return tensor;
}

int64_t ShapeNumVals(absl::Span<const int64_t> shape) {
  int64_t num_vals = 1;
  for (int i = 0; i < shape.size(); ++i) {
    num_vals *= shape[i];
  }
  return num_vals;
}
}  // namespace

// TensorGenerator is an abstract class that has one implementing class for each
// (DataType,T) pair. The implementing class implements RandomVals, which is
// the only Tensor generation code that is specific to the DataType.
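// For example, TensorGeneratorFloat below draws uniform float values in
// [lo, hi] (defaulting to [-1, 1]) and, when unique values are requested,
// re-draws any value that has already been generated.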
template <typename T>
class TensorGenerator {
 public:
  explicit TensorGenerator(OpTest& test) : test_(test) {}
  virtual ~TensorGenerator() {}
  virtual DataType dtype() = 0;
  virtual void RandomVals(std::optional<T> lo, std::optional<T> hi,
                          bool needs_unique_values,
                          absl::FixedArray<T>& vals) = 0;

  Tensor RandomTensor(std::optional<T> lo, std::optional<T> hi,
                      bool needs_unique_values,
                      absl::Span<const int64_t> shape) {
    absl::FixedArray<T> vals(ShapeNumVals(shape));
    RandomVals(lo, hi, needs_unique_values, vals);
    return TensorFromValues<T>(dtype(), shape, absl::Span<T>(vals));
  }

  std::pair<Tensor, Tensor> RandomLteTensors(absl::Span<const int64_t> shape) {
    int64_t num_vals = ShapeNumVals(shape);
    absl::FixedArray<T> less(num_vals);
    RandomVals({}, {}, false, less);
    absl::FixedArray<T> greater(num_vals);
    RandomVals({}, {}, false, greater);
    for (int i = 0; i < num_vals; ++i) {
      if (less[i] > greater[i]) {
        std::swap(less[i], greater[i]);
      }
    }
    std::pair<Tensor, Tensor> pair(
        TensorFromValues<T>(dtype(), shape, absl::Span<T>(less)),
        TensorFromValues<T>(dtype(), shape, absl::Span<T>(greater)));
    return pair;
  }

 protected:
  OpTest& test_;
};

class TensorGeneratorFloat : public TensorGenerator<float> {
 public:
  explicit TensorGeneratorFloat(OpTest& test) : TensorGenerator(test) {}
  DataType dtype() override { return DT_FLOAT; }
  void RandomVals(std::optional<float> lo, std::optional<float> hi,
                  bool needs_unique_values,
                  absl::FixedArray<float>& vals) override {
    absl::flat_hash_set<float> already_generated;
    std::uniform_real_distribution<float> distribution(lo.value_or(-1.0f),
                                                       hi.value_or(1.0f));
    for (int64_t i = 0; i < vals.size(); ++i) {
      float generated;
      do {
        generated = distribution(test_.generator());
      } while (needs_unique_values &&
               !already_generated.insert(generated).second);
      vals[i] = generated;
    }
  }
};

class TensorGeneratorDouble : public TensorGenerator<double> {
 public:
  explicit TensorGeneratorDouble(OpTest& test) : TensorGenerator(test) {}
  DataType dtype() override { return DT_DOUBLE; }
  void RandomVals(std::optional<double> lo, std::optional<double> hi,
                  bool needs_unique_values,
                  absl::FixedArray<double>& vals) override {
    absl::flat_hash_set<double> already_generated;
    std::uniform_real_distribution<double> distribution(lo.value_or(-1.0),
                                                        hi.value_or(1.0));
    for (int64_t i = 0; i < vals.size(); ++i) {
      double generated;
      do {
        generated = distribution(test_.generator());
      } while (needs_unique_values &&
               !already_generated.insert(generated).second);
      vals[i] = generated;
    }
  }
};

class TensorGeneratorComplex64 : public TensorGenerator<complex64> {
 public:
  explicit TensorGeneratorComplex64(OpTest& test) : TensorGenerator(test) {}
  DataType dtype() override { return DT_COMPLEX64; }
  void RandomVals(std::optional<complex64> lo, std::optional<complex64> hi,
                  bool needs_unique_values,
                  absl::FixedArray<complex64>& vals) override {
    absl::flat_hash_set<std::pair<float, float>> already_generated;
    if (lo || hi) {
      LOG(FATAL) << "Lower or upper bounds are not supported for complex64.";
    }
    std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
    for (int64_t i = 0; i < vals.size(); ++i) {
      complex64 generated;
      do {
        generated = complex64(distribution(test_.generator()),
                              distribution(test_.generator()));
      } while (needs_unique_values &&
               !already_generated
                    .insert(std::make_pair(generated.real(), generated.imag()))
                    .second);
      vals[i] = generated;
    }
  }
};

class TensorGeneratorInt32 : public TensorGenerator<int32> {
 public:
  explicit TensorGeneratorInt32(OpTest& test) : TensorGenerator(test) {}
  DataType dtype() override { return DT_INT32; }
  void RandomVals(std::optional<int32> lo, std::optional<int32> hi,
                  bool needs_unique_values,
                  absl::FixedArray<int32>& vals) override {
    absl::flat_hash_set<int32> already_generated;
    std::uniform_int_distribution<int32> distribution(lo.value_or(-(1 << 20)),
                                                      hi.value_or(1 << 20));
    for (int64_t i = 0; i < vals.size(); ++i) {
      int32_t generated;
      do {
        generated = distribution(test_.generator());
      } while (needs_unique_values &&
               !already_generated.insert(generated).second);
      vals[i] = generated;
    }
  }
};

class TensorGeneratorInt64 : public TensorGenerator<int64> {
 public:
  explicit TensorGeneratorInt64(OpTest& test) : TensorGenerator(test) {}
  DataType dtype() override { return DT_INT64; }
  void RandomVals(std::optional<int64> lo, std::optional<int64> hi,
                  bool needs_unique_values,
                  absl::FixedArray<int64>& vals) override {
    absl::flat_hash_set<int64_t> already_generated;
    std::uniform_int_distribution<int64_t> distribution(
        lo.value_or(-(1LL << 40)), hi.value_or(1LL << 40));
    for (int64_t i = 0; i < vals.size(); ++i) {
      int64_t generated;
      do {
        generated = distribution(test_.generator());
      } while (needs_unique_values &&
               !already_generated.insert(generated).second);
      vals[i] = generated;
    }
  }
};

class TensorGeneratorBool : public TensorGenerator<bool> {
 public:
  explicit TensorGeneratorBool(OpTest& test) : TensorGenerator(test) {}
  DataType dtype() override { return DT_BOOL; }
  void RandomVals(std::optional<bool> lo, std::optional<bool> hi,
                  bool needs_unique_values,
                  absl::FixedArray<bool>& vals) override {
    absl::flat_hash_set<bool> already_generated;
    if (lo || hi) {
      LOG(FATAL) << "Lower or upper bounds are not supported for bool.";
    }
    std::bernoulli_distribution distribution;
    for (int64_t i = 0; i < vals.size(); ++i) {
      bool generated;
      do {
        generated = distribution(test_.generator());
      } while (needs_unique_values &&
               !already_generated.insert(generated).second);
      vals[i] = generated;
    }
  }
};

void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
  int const max_repetitions = tf_xla_test_repetitions;
  int valid_test_runs = 0;
  // We run up to 100 * max_repetitions times; the idea is that if we roll the
  // dice enough times we will find some valid parameters. We want to put an
  // upper limit on the number of iterations just in case the probability of
  // finding feasible parameters is very low.
  for (int i = 0; !HasFailure() && i < max_repetitions * 100 &&
                  valid_test_runs < max_repetitions;
       ++i) {
    TestResult result = fn();
    switch (result) {
      case kOk:
        ++valid_test_runs;
        break;

      case kFatalError:
        ASSERT_TRUE(false) << "Test had fatal failure";
        return;

      case kInvalid:
        break;
    }
  }
  if (!HasFailure()) {
    EXPECT_GE(valid_test_runs, max_repetitions)
        << "Not enough test instances passed; this means that either the "
           "golden implementation is buggy or the operator harness is not "
           "producing well-formed test cases with a high probability.";
  }
}

template <typename T>
T OpTest::Choose(absl::Span<const T> candidates) {
  std::uniform_int_distribution<size_t> d(0, candidates.size() - 1);
  return candidates[d(generator())];
}

int64_t OpTest::RandomDim(int64_t min, int64_t max) {
  std::uniform_int_distribution<int64_t> size_distribution(min, max - 1);
  return size_distribution(generator());
}

bool OpTest::TensorSizeIsOk(absl::Span<const int64_t> dims) {
  int64_t size = 1LL;
  for (int64_t dim : dims) {
    size *= dim;
  }
  return size < tf_xla_max_tensor_size;
}

std::vector<int64_t> OpTest::RandomDims(int min_rank, int max_rank,
                                        int64_t min_size, int64_t max_size) {
  CHECK_LE(0, min_rank);
  CHECK_LE(min_rank, max_rank);
  std::uniform_int_distribution<int> rank_distribution(min_rank, max_rank);
  int rank = rank_distribution(generator());
  std::vector<int64_t> dims(rank);
  if (rank == 0) {
    return dims;
  }
  int64_t per_dim_limit = std::pow(tf_xla_max_tensor_size, 1.0 / rank);
  int64_t per_dim_max = std::min(max_size, per_dim_limit);
  std::generate(dims.begin(), dims.end(), [this, min_size, per_dim_max]() {
    return RandomDim(min_size, per_dim_max);
  });
  CHECK(TensorSizeIsOk(dims));  // Crash OK
  return dims;
}

bool OpTest::RandomBool() {
  std::bernoulli_distribution d(0.5);
  return d(generator());
}

int64_t OpTest::RandomSeed() {
  std::uniform_int_distribution<int64_t> seed_dist(
      std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max());
  int64_t seed = seed_dist(generator());
  if (seed == 0) return 1;
  return seed;
}

Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
                            absl::Span<const int64_t> shape) {
  switch (dtype) {
    case DT_FLOAT:
      return TensorGeneratorFloat(*this).RandomTensor(
          {}, {}, needs_unique_values, shape);
    case DT_DOUBLE:
      return TensorGeneratorDouble(*this).RandomTensor(
          {}, {}, needs_unique_values, shape);
    case DT_COMPLEX64:
      return TensorGeneratorComplex64(*this).RandomTensor(
          {}, {}, needs_unique_values, shape);
    case DT_INT32:
      return TensorGeneratorInt32(*this).RandomTensor(
          {}, {}, needs_unique_values, shape);
    case DT_INT64:
      return TensorGeneratorInt64(*this).RandomTensor(
          {}, {}, needs_unique_values, shape);
    case DT_BOOL:
      return TensorGeneratorBool(*this).RandomTensor(
          {}, {}, needs_unique_values, shape);
    default:
      LOG(FATAL) << "Unimplemented type " << dtype << " in RandomTensor";
  }
}

Tensor OpTest::RandomTensor(DataType dtype) {
  return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims());
}

Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
                                       absl::Span<const int64_t> shape) {
  switch (dtype) {
    case DT_FLOAT:
      return TensorGeneratorFloat(*this).RandomTensor({0.0f}, {}, false, shape);
    case DT_DOUBLE:
      return TensorGeneratorDouble(*this).RandomTensor({0.0}, {}, false, shape);
    case DT_INT32:
      return TensorGeneratorInt32(*this).RandomTensor({0}, {}, false, shape);
    case DT_INT64:
      return TensorGeneratorInt64(*this).RandomTensor({0}, {}, false, shape);
    default:
      LOG(FATAL) << "Unimplemented type " << dtype
                 << " in RandomNonNegativeTensor";
  }
}

Tensor OpTest::RandomNonNegativeTensor(DataType dtype) {
  return RandomNonNegativeTensor(dtype, RandomDims());
}

template <typename T>
Tensor OpTest::RandomBoundedTensor(DataType dtype, T lo, T hi,
                                   bool needs_unique_values,
                                   absl::Span<const int64_t> shape) {
  switch (dtype) {
    case DT_FLOAT:
      return TensorGeneratorFloat(*this).RandomTensor(
          {lo}, {hi}, needs_unique_values, shape);
    case DT_DOUBLE:
      return TensorGeneratorDouble(*this).RandomTensor(
          {lo}, {hi}, needs_unique_values, shape);
    case DT_INT32:
      return TensorGeneratorInt32(*this).RandomTensor(
          {lo}, {hi}, needs_unique_values, shape);
    case DT_INT64:
      return TensorGeneratorInt64(*this).RandomTensor(
          {lo}, {hi}, needs_unique_values, shape);
    default:
      LOG(FATAL) << "RandomBoundedTensor does not support type " << dtype
                 << ".";
  }
}

template <typename T>
Tensor OpTest::RandomBoundedTensor(DataType dtype, T lo, T hi,
                                   bool needs_unique_values) {
  return RandomBoundedTensor<T>(dtype, lo, hi, needs_unique_values,
                                RandomDims());
}

Tensor OpTest::RandomBoundedTensor(DataType dtype, Tensor lo, Tensor hi) {
  TensorShape shape = lo.shape();
  if (hi.shape() != shape) {
    LOG(FATAL) << "hi and lo do not have the same shape in RandomBoundedTensor";
  }
  if (hi.dtype() != dtype) {
    LOG(FATAL) << "hi does not have the expected dtype in RandomBoundedTensor";
  }
  if (lo.dtype() != dtype) {
    LOG(FATAL) << "lo does not have the expected dtype in RandomBoundedTensor";
  }
  Tensor tensor(dtype, shape);
  switch (dtype) {
    case DT_FLOAT: {
      auto lo_flat = lo.flat<float>();
      auto hi_flat = hi.flat<float>();
      test::FillFn<float>(&tensor, [this, &lo_flat, &hi_flat](int i) -> float {
        std::uniform_real_distribution<float> distribution(lo_flat(i),
                                                           hi_flat(i));
        return distribution(generator());
      });
      break;
    }
    case DT_DOUBLE: {
      auto lo_flat = lo.flat<double>();
      auto hi_flat = hi.flat<double>();
      test::FillFn<double>(
          &tensor, [this, &lo_flat, &hi_flat](int i) -> double {
            std::uniform_real_distribution<double> distribution(lo_flat(i),
                                                                hi_flat(i));
            return distribution(generator());
          });
      break;
    }
    case DT_INT32: {
      auto lo_flat = lo.flat<int32>();
      auto hi_flat = hi.flat<int32>();
      test::FillFn<int32>(&tensor, [this, &lo_flat, &hi_flat](int i) -> int32 {
        std::uniform_int_distribution<int32> distribution(lo_flat(i),
                                                          hi_flat(i));
        return distribution(generator());
      });
      break;
    }
    case DT_INT64: {
      auto lo_flat = lo.flat<int64>();
      auto hi_flat = hi.flat<int64>();
      test::FillFn<int64_t>(
          &tensor, [this, &lo_flat, &hi_flat](int i) -> int64_t {
            std::uniform_int_distribution<int64_t> distribution(lo_flat(i),
                                                                hi_flat(i));
            return distribution(generator());
          });
      break;
    }
    default:
      LOG(FATAL) << "RandomBoundedTensor does not support type " << dtype
                 << ".";
  }
  return tensor;
}

std::pair<Tensor, Tensor> OpTest::RandomLteTensors(
    DataType dtype, absl::Span<const int64_t> shape) {
  switch (dtype) {
    case DT_FLOAT:
      return TensorGeneratorFloat(*this).RandomLteTensors(shape);
    case DT_DOUBLE:
      return TensorGeneratorDouble(*this).RandomLteTensors(shape);
    case DT_COMPLEX64:
      LOG(FATAL) << "RandomLteTensors unavailable for DT_COMPLEX64";
      break;
    case DT_INT32:
      return TensorGeneratorInt32(*this).RandomLteTensors(shape);
    case DT_INT64:
      return TensorGeneratorInt64(*this).RandomLteTensors(shape);
    case DT_BOOL:
      LOG(FATAL) << "RandomLteTensors unavailable for DT_BOOL";
      break;
    default:
      LOG(FATAL) << "Unimplemented type " << dtype << " in RandomLteTensors";
  }
  Tensor tensor(dtype, TensorShape(shape));
  return std::pair<Tensor, Tensor>(tensor, tensor);
}

std::pair<Tensor, Tensor> OpTest::RandomLteTensors(DataType dtype) {
  return RandomLteTensors(dtype, RandomDims());
}

std::vector<int64_t> OpTest::BroadcastableToDims(std::vector<int64_t> dims) {
  if (dims.empty()) return dims;

  // Remove some dimensions from the front of 'dims'.
  size_t skip =
      std::uniform_int_distribution<size_t>(0, dims.size() - 1)(generator());

  std::vector<int64_t> bdims(dims.begin() + skip, dims.end());

  // Randomly replace some of the remaining dimensions of 'dims' with 1.
  std::bernoulli_distribution random_bool;

  for (int64_t& dim : bdims) {
    if (random_bool(generator())) {
      dim = 1LL;
    }
  }
  return bdims;
}

std::pair<std::vector<int64_t>, std::vector<int64_t>> OpTest::BroadcastableDims(
    std::vector<int64_t> dims) {
  auto bdims = BroadcastableToDims(dims);
  // Possibly swap the roles of 'dims' and 'bdims'.
  std::bernoulli_distribution random_bool;
  if (random_bool(generator())) {
    dims.swap(bdims);
  }
  return {dims, bdims};
}

std::pair<std::vector<int64_t>, std::vector<int64_t>>
OpTest::BroadcastableDims() {
  return BroadcastableDims(RandomDims(0, 3));
}

Tensor OpTest::RandomReductionIndices(int rank) {
  std::bernoulli_distribution random_bool;
  std::vector<int32> indices;
  for (int i = 0; i < rank; ++i) {
    if (random_bool(generator())) {
      indices.push_back(i);
    }
  }
  return test::AsTensor<int32>(indices);
}

// Helper that converts 'values' to an int32 or int64 Tensor.
static Tensor AsIntTensor(DataType dtype, const std::vector<int64_t>& values) {
  switch (dtype) {
    case DT_INT32: {
      std::vector<int32> values32(values.begin(), values.end());
      return test::AsTensor<int32>(values32);
    }
    case DT_INT64:
      return test::AsTensor<int64_t>(values);
    default:
      LOG(FATAL);
  }
}

OpTest::BatchMatMulArguments OpTest::ChooseBatchMatMulArguments(
    bool broadcastable_batch) {
  BatchMatMulArguments a;
  a.dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});

  int64_t min_size = 0;
  int64_t max_size = 7;
  auto batch_dims_to = RandomDims(0, 3, min_size, max_size);
  int rank = batch_dims_to.size() + 2;
  std::pair<std::vector<int64_t>, std::vector<int64_t>> batch_dims_nobcast(
      batch_dims_to, batch_dims_to);
  auto batch_dims = broadcastable_batch ? BroadcastableDims(batch_dims_to)
                                        : batch_dims_nobcast;
  std::vector<int64_t> lhs_dims(batch_dims.first), rhs_dims(batch_dims.second);
  int64_t inner_dim = RandomDim();
  lhs_dims.push_back(RandomDim(min_size, max_size));
  lhs_dims.push_back(inner_dim);
  rhs_dims.push_back(inner_dim);
  rhs_dims.push_back(RandomDim(min_size, max_size));

  std::bernoulli_distribution random_bool;
  a.adj_lhs = random_bool(generator());
  a.adj_rhs = random_bool(generator());
  if (a.adj_lhs) {
    std::swap(lhs_dims[rank - 1], lhs_dims[rank - 2]);
  }
  if (a.adj_rhs) {
    std::swap(rhs_dims[rank - 1], rhs_dims[rank - 2]);
  }

  a.lhs_dims = lhs_dims;
  a.rhs_dims = rhs_dims;
  return a;
}

OpTest::ConcatArguments OpTest::ChooseConcatArguments(bool int64_idx_allowed) {
  ConcatArguments a;

  std::bernoulli_distribution random_bool;
  bool use_int64_idx = int64_idx_allowed && random_bool(generator());

  a.type = Choose<DataType>(kAllXlaTypes);
  a.type_idx = use_int64_idx ? DT_INT64 : DT_INT32;
  a.n = std::uniform_int_distribution<int>(2, 4)(generator());

  std::vector<int64_t> dims = RandomDims(1, 4, 0, 64);

  int axis =
      std::uniform_int_distribution<int32>(0, dims.size() - 1)(generator());
  a.axis =
      use_int64_idx ? test::AsScalar<int64>(axis) : test::AsScalar<int32>(axis);

  for (int i = 0; i < a.n; ++i) {
    std::vector<int64_t> shape = dims;
    shape[axis] = RandomDim(0, 64);
    a.values.push_back(RandomTensor(a.type, false, shape));
  }

  return a;
}

OpTest::EinsumArguments OpTest::ChooseEinsumArguments() {
  EinsumArguments a;

  enum EinsumType { matmul, batchmatmul, dot, outer };
  int op_kind = Choose<int>({matmul, batchmatmul, dot, outer});
  switch (op_kind) {
    case matmul:
    case batchmatmul: {
      std::vector<int64> dims;
      if (op_kind == matmul) {
        a.equation = "ij,jk->ik";
        dims = RandomDims(2, 2);
      } else {
        a.equation = "...ij,...jk->...ik";
        dims = RandomDims(2);
      }
      int64_t ndims = dims.size();
      int64_t inner_dim = RandomDim();
      a.lhs_dims = dims;
      a.rhs_dims = dims;
      a.lhs_dims[ndims - 1] = inner_dim;
      a.rhs_dims[ndims - 2] = inner_dim;
      break;
    }
    case dot: {
      a.equation = "i,i->";
      std::vector<int64> dims = RandomDims(1, 1);
      a.lhs_dims = dims;
      a.rhs_dims = dims;
      break;
    }
    case outer: {
      a.equation = "i,j->ij";
      a.lhs_dims = RandomDims(1, 1);
      a.rhs_dims = RandomDims(1, 1);
      break;
    }
  }

  a.type = Choose<DataType>(kAllXlaTypes);
  return a;
}

OpTest::GatherArguments OpTest::ChooseGatherArguments(bool axis_0) {
  GatherArguments a;

  a.axis_type = DT_INT32;
  a.indices_type = DT_INT32;
  a.params_type = Choose<DataType>(kAllXlaTypes);

  // Choose parameters such that
  // 0 <= batch_dims <= axis < params.rank <= kDefaultMaxRank
  a.batch_dims = 0;
  int64_t axis;
  if (axis_0) {
    axis = 0;
  } else {
    std::uniform_int_distribution<int64_t> axis_distribution(
        a.batch_dims, kDefaultMaxRank - 1);
    axis = axis_distribution(generator());
  }
  a.axis = test::AsScalar<int32>((int32)axis);
  a.params_shape = RandomDims(axis + 1, kDefaultMaxRank, 1, 16);
  std::vector<int64_t> indices_shape = RandomDims(0, 3, 0, 16);
  a.indices = RandomBoundedTensor<int32>(DT_INT32, 0, a.params_shape[axis] - 1,
                                         false, indices_shape);

  return a;
}

OpTest::PadArguments OpTest::ChoosePadArguments() {
  PadArguments a;

  a.input_type = Choose<DataType>(kAllXlaTypes);
  a.input_shape = RandomDims();
  int input_rank = a.input_shape.size();

  a.paddings_type = Choose<DataType>({DT_INT32, DT_INT64});
  std::vector<int64_t> paddings_vec;
  for (int i = 0; i < input_rank; ++i) {
    std::uniform_int_distribution<int> pad_distribution(0, a.input_shape[i]);
    int pad_size = pad_distribution(generator());
    std::uniform_int_distribution<int> lower_distribution(0, pad_size);
    int low_pad_size = lower_distribution(generator());
    paddings_vec.push_back(low_pad_size);
    paddings_vec.push_back(pad_size - low_pad_size);
    a.input_shape[i] -= pad_size;
  }
  CHECK(
      a.paddings.CopyFrom(AsIntTensor(a.paddings_type, paddings_vec),
                          TensorShape({static_cast<int64_t>(input_rank), 2})));

  a.constant_values = RandomTensor(a.input_type, false, {});

  return a;
}

OpTest::ScatterArguments OpTest::ChooseScatterArguments() {
  ScatterArguments a;

  a.type = Choose<DataType>(kAllXlaTypes);
  a.indices_type = DT_INT32;
  a.shape = RandomDims(1, kDefaultMaxRank, 1);
  int rank = a.shape.size();
  std::uniform_int_distribution<int32> index_len_dist(1, rank);
  int index_len = index_len_dist(generator());
  std::vector<int64_t> indices_first = RandomDims(1, kDefaultMaxRank - 1, 1);
  std::vector<int64_t> indices_shape(indices_first);
  indices_shape.push_back(index_len);
  std::vector<int64_t> updates_shape(indices_first);
  for (int i = 0; i < rank - index_len; ++i) {
    updates_shape.push_back(a.shape[index_len + i]);
  }
  Tensor indices_lo(a.indices_type, TensorShape(indices_shape));
  test::FillFn<int32>(&indices_lo, [](int i) -> int32 { return 0; });
  Tensor indices_hi(a.indices_type, TensorShape(indices_shape));
  test::FillFn<int32>(&indices_hi, [index_len, &a](int i) -> int32 {
    int idx_dim = i % index_len;
    return a.shape[idx_dim] - 1;
  });
  a.indices = RandomBoundedTensor(a.indices_type, indices_lo, indices_hi);
  a.updates = RandomTensor(a.type, false, updates_shape);

  return a;
}

OpTest::SliceArguments OpTest::ChooseSliceArguments(bool neg_one_size) {
  SliceArguments a;

  a.type = Choose<DataType>(kAllXlaTypes);
  a.indices_type = DT_INT32;
  a.shape = RandomDims();
  int rank = a.shape.size();

  std::vector<int32> indices(rank);
  a.size.resize(rank);
  for (int i = 0; i < rank; ++i) {
    indices[i] =
        std::uniform_int_distribution<int32>(0, a.shape[i])(generator());
    int64_t low = neg_one_size ? -1 : 0;
    a.size[i] = std::uniform_int_distribution<int64_t>(
        low, a.shape[i] - indices[i])(generator());
  }
  a.indices = test::AsTensor<int32>(indices);

  return a;
}

OpTest::WindowedSpatialDims OpTest::ChooseWindowedSpatialDims(
    int num_spatial_dims) {
  WindowedSpatialDims d;
  d.padding = Choose<Padding>({SAME, VALID});
  std::uniform_int_distribution<int> random_int(1, 5);
  d.kernel_dims.resize(num_spatial_dims);
  d.input_dims.resize(num_spatial_dims);
  d.output_dims.resize(num_spatial_dims);
  d.stride_dims.resize(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    Status s;
    // Repeatedly try different filter/stride sizes until we find a valid
    // combination.
    do {
      // CPU implementations require stride <= kernel size.
      d.kernel_dims[i] = random_int(generator());
      d.input_dims[i] = RandomDim(d.kernel_dims[i]);
      d.stride_dims[i] =
          std::uniform_int_distribution<int>(1, d.kernel_dims[i])(generator());
      int64_t pad_dummy;
      s = GetWindowedOutputSize(d.input_dims[i], d.kernel_dims[i],
                                d.stride_dims[i], d.padding, &d.output_dims[i],
                                &pad_dummy);
    } while (!s.ok());
  }
  return d;
}

OpTest::XlaDotArguments OpTest::ChooseXlaDotArguments() {
  std::vector<int64_t> batch_dims = RandomDims(0, 2);
  std::vector<int64_t> contracting_dims = RandomDims(0, 2);
  std::vector<int64_t> lhs_outer_dims = RandomDims(0, 2);
  std::vector<int64_t> rhs_outer_dims = RandomDims(0, 2);

  XlaDotArguments a;
  a.lhs_dims.insert(a.lhs_dims.end(), batch_dims.begin(), batch_dims.end());
  a.lhs_dims.insert(a.lhs_dims.end(), contracting_dims.begin(),
                    contracting_dims.end());
  a.lhs_dims.insert(a.lhs_dims.end(), lhs_outer_dims.begin(),
                    lhs_outer_dims.end());
  a.rhs_dims.insert(a.rhs_dims.end(), batch_dims.begin(), batch_dims.end());
  a.rhs_dims.insert(a.rhs_dims.end(), contracting_dims.begin(),
                    contracting_dims.end());
  a.rhs_dims.insert(a.rhs_dims.end(), rhs_outer_dims.begin(),
                    rhs_outer_dims.end());

  xla::DotDimensionNumbers dnums;
  for (auto i = 0; i < batch_dims.size(); ++i) {
    dnums.add_lhs_batch_dimensions(i);
    dnums.add_rhs_batch_dimensions(i);
  }
  for (auto i = 0; i < contracting_dims.size(); ++i) {
    dnums.add_lhs_contracting_dimensions(batch_dims.size() + i);
    dnums.add_rhs_contracting_dimensions(batch_dims.size() + i);
  }
  dnums.SerializeToString(&a.dnums_encoded);

  a.precision_config_encoded = "";

  a.dtype = Choose<DataType>(kAllXlaTypes);
  return a;
}

std::vector<int64_t> OpTest::ImageDims(
    TensorFormat format, int batch, int feature,
    const std::vector<int64_t>& spatial_dims) {
  std::vector<int64_t> dims;
  switch (format) {
    case FORMAT_NHWC:
      dims.push_back(batch);
      for (int dim : spatial_dims) {
        dims.push_back(dim);
      }
      dims.push_back(feature);
      break;
    case FORMAT_NCHW:
      dims.push_back(batch);
      dims.push_back(feature);
      for (int dim : spatial_dims) {
        dims.push_back(dim);
      }
      break;
    default:
      LOG(FATAL) << "Tensor format " << ToString(format) << " not supported.";
  }
  return dims;
}

std::vector<int32> OpTest::AsInt32s(const std::vector<int64_t>& int64s) {
  return std::vector<int32>(int64s.begin(), int64s.end());
}

// Functions for comparing tensors.

template <typename T>
double Abs(T x) {
  return std::fabs(x);
}

template <>
double Abs<complex64>(complex64 x) {
  return std::abs(x);
}

template <typename T>
bool IsClose(const T& x, const T& y, double atol, double rtol) {
  if (std::isnan(x) && std::isnan(y)) return true;
  if (x == y) return true;  // Allow inf == inf.
  return Abs(x - y) < atol + rtol * Abs(x);
}

template <>
bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
                        double rtol) {
  if (std::isnan(x.real()) && std::isnan(y.real())) {
    if (std::isnan(x.imag()) && std::isnan(y.imag())) {
      return true;
    }
    if (x.imag() == y.imag()) return true;  // Allow inf == inf.
    return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag());
  } else if (std::isnan(x.imag()) && std::isnan(y.imag())) {
    if (x.real() == y.real()) return true;  // Allow inf == inf.
    return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real());
  }
  if (x == y) return true;  // Allow inf == inf.
  return Abs(x - y) < atol + rtol * Abs(x);
}

template <typename T>
string Str(T x) {
  return absl::StrCat(x);
}
template <>
string Str<complex64>(complex64 x) {
  return absl::StrCat("(", x.real(), ", ", x.imag(), ")");
}

template <typename T>
Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
                           double rtol) {
  auto Tx = x.flat<T>();
  auto Ty = y.flat<T>();
  for (int i = 0; i < Tx.size(); ++i) {
    if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
      return errors::InvalidArgument(
          absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)),
                       " vs. ", Str(Ty(i)), ". x = ", x.DebugString(),
                       "y = ", y.DebugString(), "atol = ", atol,
                       " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i))));
    }
  }
  return OkStatus();
}

template <typename T>
Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
  auto Tx = x.flat<T>();
  auto Ty = y.flat<T>();
  for (int i = 0; i < Tx.size(); ++i) {
    if (Tx(i) != Ty(i)) {
      return errors::InvalidArgument(absl::StrCat(
          i, "-th tensor element isn't equal: ", Str(Tx(i)), " vs. ",
          Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString()));
    }
  }
  return OkStatus();
}

Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) {
  auto Tx = x.flat<bfloat16>();
  auto Ty = y.flat<bfloat16>();
  for (int i = 0; i < Tx.size(); ++i) {
    if (Tx(i) != Ty(i)) {
      return errors::InvalidArgument(absl::StrCat(
          i, "-th tensor element isn't equal: ", static_cast<float>(Tx(i)),
          " vs. ", static_cast<float>(Ty(i)), ". x = ", x.DebugString(),
          "y = ", y.DebugString()));
    }
  }
  return OkStatus();
}

// Tests if "x" and "y" are tensors of the same type, same shape, and with
// close values. For floating-point tensors, the element-wise difference between
// x and y must be no more than atol + rtol * abs(x). For non-floating-point
// tensors the values must match exactly.
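// For example, with the default atol = rtol = 1e-2 used by
// ExpectTfAndXlaOutputsAreClose, x = 1.0 is considered close to any y with
// |x - y| < 0.02.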
Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
                       double rtol) {
  if (a.dtype() != b.dtype()) {
    return errors::InvalidArgument(absl::StrCat(
        "Tensors have different types: ", DataTypeString(a.dtype()), " and ",
        DataTypeString(b.dtype())));
  }
  if (!a.IsSameSize(b)) {
    return errors::InvalidArgument(
        absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(),
                     " and ", b.shape().DebugString()));
  }

  switch (a.dtype()) {
    case DT_FLOAT:
      return TensorsAreCloseImpl<float>(a, b, atol, rtol);
    case DT_DOUBLE:
      return TensorsAreCloseImpl<double>(a, b, atol, rtol);
    case DT_COMPLEX64:
      return TensorsAreCloseImpl<complex64>(a, b, atol, rtol);
    case DT_INT32:
      return TensorsAreEqualImpl<int32>(a, b);
    case DT_INT64:
      return TensorsAreEqualImpl<int64_t>(a, b);
    case DT_BOOL:
      return TensorsAreEqualImpl<bool>(a, b);
    case DT_BFLOAT16:
      return TensorsAreEqualImplBfloat16(a, b);
    default:
      LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype());
  }
}
1451
ExpectTfAndXlaOutputsAreClose(const OpTestBuilder & builder,double atol,double rtol)1452 OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
1453 const OpTestBuilder& builder, double atol, double rtol) {
1454 const std::vector<OpTestBuilder::InputDescription>& inputs = builder.inputs();
1455 std::vector<Tensor> input_tensors;
1456 input_tensors.reserve(inputs.size());
1457 for (const OpTestBuilder::InputDescription& input : inputs) {
1458 if (input.type == DT_INVALID) {
1459 input_tensors.push_back(input.tensor);
1460 } else {
1461 std::vector<int64_t> dims;
1462 if (input.has_dims) {
1463 dims = input.dims;
1464 } else {
1465 dims = RandomDims();
1466 }
1467 if (!TensorSizeIsOk(dims)) {
1468 VLOG(1) << "Input: " << input.type << " "
1469 << TensorShape(input.dims).DebugString();
1470 VLOG(1) << "Ignoring oversize dims.";
1471 return kInvalid;
1472 }
1473 input_tensors.push_back(
1474 RandomTensor(input.type, input.needs_unique_values, dims));
1475 }
1476 VLOG(1) << "Input: " << input_tensors.back().DebugString();
1477 }
1478
1479 string reference_device =
1480 LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr);
1481 string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
1482
1483 DeviceNameUtils::ParsedName parsed_name;
1484 if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) {
1485 LOG(ERROR) << "Could not parse device name: " << *tf_xla_test_device_ptr;
1486 return kFatalError;
1487 }
1488 DeviceType test_device_type(parsed_name.type);
1489 ++num_tests_;
1490
1491 GraphDef graph;
1492 std::vector<string> expected_inputs, test_inputs;
1493 std::vector<string> expected_fetches, test_fetches;
1494 Status status = builder.BuildGraph(
1495 absl::StrCat("test", num_tests_, "_expected"), reference_device,
1496 /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs,
1497 &expected_fetches);
1498 if (!status.ok()) {
1499 LOG(ERROR) << "Expected graph construction failed: " << status;
1500 return kFatalError;
1501 }
1502
1503 NodeDef* node_def;
1504 status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"),
1505 test_device, tf_xla_test_use_jit, &graph,
1506 &node_def, &test_inputs, &test_fetches);
1507 if (!status.ok()) {
1508 LOG(ERROR) << "Test graph construction failed: " << status;
1509 return kFatalError;
1510 }
1511
1512 // Check that there's a kernel corresponding to 'node_def' on the device under
1513 // test.
1514 status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr);
1515 if (!status.ok()) {
1516 VLOG(1) << "Skipping test because there is no corresponding registered "
1517 << "kernel on the test device: " << status;
1518 return kInvalid;
1519 }
1520
1521 status = session_->Extend(graph);
1522 if (!status.ok()) {
1523 LOG(ERROR) << "Session::Extend() failed: " << status;
1524 return kFatalError;
1525 }
1526
1527 std::vector<std::pair<string, Tensor>> expected_feeds(expected_inputs.size());
1528 std::vector<std::pair<string, Tensor>> test_feeds(test_inputs.size());
1529 CHECK_EQ(input_tensors.size(), expected_inputs.size());
1530 CHECK_EQ(input_tensors.size(), test_inputs.size());
1531
1532 for (int i = 0; i < input_tensors.size(); ++i) {
1533 expected_feeds[i] = {expected_inputs[i], input_tensors[i]};
1534 test_feeds[i] = {test_inputs[i], input_tensors[i]};
1535 }
1536
1537 std::vector<Tensor> expected_outputs, test_outputs;
1538 VLOG(1) << "Running expected graph";
1539 Status s =
1540 session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs);
1541 if (!s.ok()) {
1542 VLOG(1) << "Expected graph failed with status: " << s << ". Ignoring test";
1543 return kInvalid;
1544 }
1545 for (const Tensor& expected : expected_outputs) {
1546 VLOG(1) << "Expected: " << expected.DebugString();
1547 }
1548
1549 VLOG(1) << "Running test graph";
1550 status = session_->Run(test_feeds, test_fetches, {}, &test_outputs);
1551 if (!status.ok()) {
1552 LOG(ERROR) << "Test graph failed: " << status;
1553 return kFatalError;
1554 }
1555
1556 CHECK_EQ(expected_outputs.size(), test_outputs.size());
1557 for (int j = 0; s.ok() && j < test_outputs.size(); ++j) {
1558 s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol);
1559 }
1560 TF_EXPECT_OK(s);
1561
1562 return kOk;
1563 }
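// Usage sketch (added commentary): the tests below call this helper from
// inside Repeatedly(), which is assumed to re-invoke the lambda for several
// repetitions, treating kInvalid results (oversized random dims, no
// registered kernel, or a reference graph that rejects the random inputs) as
// skipped attempts; kFatalError aborts the test and kOk means the comparison
// actually ran. A minimal test therefore looks like:
//
//   Repeatedly([this]() {
//     return ExpectTfAndXlaOutputsAreClose(
//         OpTestBuilder("Abs").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
//   });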
1564
1565 TEST_F(OpTest, _EagerConst) {
1566 Repeatedly([this]() {
1567 auto type = Choose<DataType>(kAllXlaTypes);
1568 return ExpectTfAndXlaOutputsAreClose(
1569 OpTestBuilder("_EagerConst").RandomInput(type).Attr("T", type));
1570 });
1571 }
1572
1573 TEST_F(OpTest, Abs) {
1574 Repeatedly([this]() {
1575 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1576 return ExpectTfAndXlaOutputsAreClose(
1577 OpTestBuilder("Abs").RandomInput(type).Attr("T", type));
1578 });
1579 }
1580
1581 TEST_F(OpTest, Acos) {
1582 Repeatedly([this]() {
1583 return ExpectTfAndXlaOutputsAreClose(
1584 OpTestBuilder("Acos")
1585 .Input(RandomBoundedTensor<float>(DT_FLOAT, -1, 1, false))
1586 .Attr("T", DT_FLOAT));
1587 });
1588 }
1589
1590 TEST_F(OpTest, Acosh) {
1591 Repeatedly([this]() {
1592 return ExpectTfAndXlaOutputsAreClose(
1593 OpTestBuilder("Acosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
1594 });
1595 }
1596
1597 TEST_F(OpTest, Add) {
1598 Repeatedly([this]() {
1599 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1600 auto dims = BroadcastableDims();
1601 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
1602 .RandomInput(type, dims.first)
1603 .RandomInput(type, dims.second)
1604 .Attr("T", type));
1605 });
1606 }
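// Added note: BroadcastableDims (defined earlier in this file) is assumed to
// return a pair of shapes that broadcast against each other under the usual
// numpy/TF rules, e.g. something like ({2, 3, 4}, {3, 4}) or
// ({2, 1, 4}, {2, 3, 4}); binary-op tests such as Add above and AddV2 below
// rely on this to exercise broadcasting in both operands.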
1607
1608 TEST_F(OpTest, AddN) {
1609 Repeatedly([this]() {
1610 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1611 int n = std::uniform_int_distribution<int>(1, 5)(generator());
1612
1613 auto shape = RandomDims();
1614
1615 OpTestBuilder builder("AddN");
1616 builder.Attr("T", type);
1617 builder.Attr("N", n);
1618 for (int i = 0; i < n; ++i) {
1619 builder.RandomInput(type, shape);
1620 }
1621 return ExpectTfAndXlaOutputsAreClose(builder);
1622 });
1623 }
1624
1625 TEST_F(OpTest, AddV2) {
1626 Repeatedly([this]() {
1627 auto type = Choose<DataType>(kAllXlaTypes);
1628 auto dims = BroadcastableDims();
1629 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("AddV2")
1630 .RandomInput(type, dims.first)
1631 .RandomInput(type, dims.second)
1632 .Attr("T", type));
1633 });
1634 }
1635
1636 TEST_F(OpTest, All) {
1637 Repeatedly([this]() {
1638 std::vector<int64_t> data_dims = RandomDims();
1639 Tensor indices = RandomReductionIndices(data_dims.size());
1640 bool keep_dims = Choose<bool>({false, true});
1641 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("All")
1642 .RandomInput(DT_BOOL, data_dims)
1643 .Input(indices)
1644 .Attr("keep_dims", keep_dims));
1645 });
1646 }
1647
1648 TEST_F(OpTest, Angle) {
1649 Repeatedly([this]() {
1650 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Angle")
1651 .RandomInput(DT_COMPLEX64)
1652 .Attr("T", DT_COMPLEX64));
1653 });
1654 }
1655
1656 TEST_F(OpTest, Any) {
1657 Repeatedly([this]() {
1658 std::vector<int64_t> data_dims = RandomDims();
1659 Tensor indices = RandomReductionIndices(data_dims.size());
1660 bool keep_dims = Choose<bool>({false, true});
1661 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Any")
1662 .RandomInput(DT_BOOL, data_dims)
1663 .Input(indices)
1664 .Attr("keep_dims", keep_dims));
1665 });
1666 }
1667
1668 TEST_F(OpTest, ApproximateEqual) {
1669 Repeatedly([this]() {
1670 auto dims = BroadcastableDims();
1671 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
1672 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
1673 .RandomInput(type, dims.first)
1674 .RandomInput(type, dims.second)
1675 .Attr("T", DT_FLOAT));
1676 });
1677 }
1678
1679 TEST_F(OpTest, ArgMax) {
1680 Repeatedly([this]() {
1681 auto type = Choose<DataType>({DT_BOOL, DT_FLOAT});
1682 std::vector<int64_t> dims = RandomDims(1, 5, 1);
1683 int num_dims = dims.size();
1684 int reduce_dim =
1685 std::uniform_int_distribution<int32>(-num_dims, num_dims)(generator());
1686 return ExpectTfAndXlaOutputsAreClose(
1687 OpTestBuilder("ArgMax")
1688 .RandomInput(type, dims)
1689 .Input(test::AsScalar<int32>(reduce_dim))
1690 .Attr("T", type)
1691 .Attr("Tidx", DT_INT32)
1692 .Attr("output_type", DT_INT32));
1693 });
1694 }
1695
1696 TEST_F(OpTest, ArgMin) {
1697 Repeatedly([this]() {
1698 auto type = Choose<DataType>({DT_BOOL, DT_FLOAT});
1699 std::vector<int64_t> dims = RandomDims(1, 5, 1);
1700 int num_dims = dims.size();
1701 int reduce_dim =
1702 std::uniform_int_distribution<int32>(-num_dims, num_dims)(generator());
1703 return ExpectTfAndXlaOutputsAreClose(
1704 OpTestBuilder("ArgMin")
1705 .RandomInput(type, dims)
1706 .Input(test::AsScalar<int32>(reduce_dim))
1707 .Attr("T", type)
1708 .Attr("Tidx", DT_INT32)
1709 .Attr("output_type", DT_INT32));
1710 });
1711 }
1712
1713 TEST_F(OpTest, Asin) {
1714 Repeatedly([this]() {
1715 return ExpectTfAndXlaOutputsAreClose(
1716 OpTestBuilder("Asin")
1717 .Input(RandomBoundedTensor<float>(DT_FLOAT, -1, 1, false))
1718 .Attr("T", DT_FLOAT));
1719 });
1720 }
1721
1722 TEST_F(OpTest, Asinh) {
1723 Repeatedly([this]() {
1724 return ExpectTfAndXlaOutputsAreClose(
1725 OpTestBuilder("Asinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
1726 });
1727 }
1728
1729 TEST_F(OpTest, Atanh) {
1730 Repeatedly([this]() {
1731 return ExpectTfAndXlaOutputsAreClose(
1732 OpTestBuilder("Atanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
1733 });
1734 }
1735
1736 TEST_F(OpTest, Atan) {
1737 Repeatedly([this]() {
1738 return ExpectTfAndXlaOutputsAreClose(
1739 OpTestBuilder("Atan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
1740 });
1741 }
1742
1743 TEST_F(OpTest, Atan2) {
1744 Repeatedly([this]() {
1745 auto dims = BroadcastableDims();
1746 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Atan2")
1747 .RandomInput(DT_FLOAT, dims.first)
1748 .RandomInput(DT_FLOAT, dims.second)
1749 .Attr("T", DT_FLOAT));
1750 });
1751 }
1752
1753 TEST_F(OpTest, AvgPool) {
1754 Repeatedly([this]() {
1755 std::uniform_int_distribution<int> random_int(1, 5);
1756 std::vector<int64_t> dims = RandomDims(4, 4, 1);
1757 int kernel_rows =
1758 std::uniform_int_distribution<int>(1, dims[1])(generator());
1759 int kernel_cols =
1760 std::uniform_int_distribution<int>(1, dims[2])(generator());
1761 int stride_rows = random_int(generator()),
1762 stride_cols = random_int(generator());
1763 string padding = Choose<string>({"SAME", "VALID"});
1764 return ExpectTfAndXlaOutputsAreClose(
1765 OpTestBuilder("AvgPool")
1766 .RandomInput(DT_FLOAT, dims)
1767 .Attr("T", DT_FLOAT)
1768 .Attr("ksize", {1, kernel_rows, kernel_cols, 1})
1769 .Attr("strides", {1, stride_rows, stride_cols, 1})
1770 .Attr("padding", padding)
1771 .Attr("data_format", "NHWC"));
1772 });
1773 // TODO(phawkins): the CPU device only implements spatial pooling. Add tests
1774 // for batch pooling when supported.
1775 }
1776
1777 TEST_F(OpTest, AvgPool3D) {
1778 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
1779 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
1780 Repeatedly([this]() {
1781 std::uniform_int_distribution<int> random_int(1, 5);
1782 std::vector<int64_t> dims = RandomDims(5, 5, 1);
1783
1784 std::vector<int64_t> input_dims, kernel_dims, stride_dims;
1785 for (int i = 0; i < 3; ++i) {
1786 kernel_dims.push_back(
1787 std::uniform_int_distribution<int>(1, dims[i])(generator()));
1788 input_dims.push_back(dims[i]);
1789 stride_dims.push_back(random_int(generator()));
1790 }
1791 int64_t batch = dims[3];
1792 int64_t feature = dims[4];
1793
1794 string padding = Choose<string>({"SAME", "VALID"});
1795 return ExpectTfAndXlaOutputsAreClose(
1796 OpTestBuilder("AvgPool3D")
1797 .RandomInput(DT_FLOAT,
1798 ImageDims(FORMAT_NHWC, batch, feature, input_dims))
1799 .Attr("T", DT_FLOAT)
1800 .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, kernel_dims))
1801 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, stride_dims))
1802 .Attr("padding", padding)
1803 .Attr("data_format", "NDHWC"));
1804 });
1805 // TODO(phawkins): test NCHW format (not supported by CPU)
1806 }
1807
1808 TEST_F(OpTest, AvgPoolGrad) {
1809 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
1810 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
1811 Repeatedly([this]() {
1812 int batch = RandomDim(1), features = RandomDim(1);
1813 WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
1814 std::vector<int32> input_dims =
1815 AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims));
1816 std::vector<int64_t> output_dims =
1817 ImageDims(FORMAT_NHWC, batch, features, d.output_dims);
1818 return ExpectTfAndXlaOutputsAreClose(
1819 OpTestBuilder("AvgPoolGrad")
1820 .Input(test::AsTensor<int32>(input_dims))
1821 .RandomInput(DT_FLOAT, output_dims)
1822 .Attr("T", DT_FLOAT)
1823 .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims))
1824 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
1825 .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
1826 .Attr("data_format", "NHWC"));
1827 });
1828 }
1829
1830 TEST_F(OpTest, AvgPool3DGrad) {
1831 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
1832 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
1833 Repeatedly([this]() {
1834 int batch = RandomDim(1), features = RandomDim(1);
1835 WindowedSpatialDims d = ChooseWindowedSpatialDims(3);
1836 std::vector<int32> input_dims =
1837 AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims));
1838 std::vector<int64_t> output_dims =
1839 ImageDims(FORMAT_NHWC, batch, features, d.output_dims);
1840 return ExpectTfAndXlaOutputsAreClose(
1841 OpTestBuilder("AvgPool3DGrad")
1842 .Input(test::AsTensor<int32>(input_dims))
1843 .RandomInput(DT_FLOAT, output_dims)
1844 .Attr("T", DT_FLOAT)
1845 .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims))
1846 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
1847 .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
1848 .Attr("data_format", "NDHWC"));
1849 });
1850 }
1851
1852 TEST_F(OpTest, BatchMatMul) {
1853 // See note about failing Kokoro tests: b/214080339#comment22
1854 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
1855 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
1856 Repeatedly([this]() {
1857 const BatchMatMulArguments a = ChooseBatchMatMulArguments(false);
1858 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
1859 .RandomInput(a.dtype, a.lhs_dims)
1860 .RandomInput(a.dtype, a.rhs_dims)
1861 .Attr("T", a.dtype)
1862 .Attr("adj_x", a.adj_lhs)
1863 .Attr("adj_y", a.adj_rhs));
1864 });
1865 }
1866
1867 TEST_F(OpTest, BatchMatMulV2) {
1868 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
1869 // :randomized_tests_seeded is flaky with --tf_xla_random_seed=200839030
1870 // See b/229622638.
1871 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
1872 Repeatedly([this]() {
1873 const BatchMatMulArguments a = ChooseBatchMatMulArguments(true);
1874 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMulV2")
1875 .RandomInput(a.dtype, a.lhs_dims)
1876 .RandomInput(a.dtype, a.rhs_dims)
1877 .Attr("T", a.dtype)
1878 .Attr("adj_x", a.adj_lhs)
1879 .Attr("adj_y", a.adj_rhs));
1880 });
1881 }
1882
1883 TEST_F(OpTest, BatchToSpace) {
1884 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
1885 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
1886 Repeatedly([this]() {
1887 const int num_block_dims = 2;
1888 std::vector<int64_t> block_dims =
1889 RandomDims(num_block_dims, num_block_dims, 0, 5);
1890 int64_t block_size = RandomDim(2, 5);
1891
1892 std::vector<int64_t> input_dims(1 + num_block_dims + 1);
1893 input_dims[0] = RandomDim();
1894 for (int i = 0; i < num_block_dims; ++i) {
1895 input_dims[0] *= block_size;
1896 input_dims[1 + i] = block_dims[i];
1897 }
1898 input_dims[1 + num_block_dims] = RandomDim();
1899
1900 std::vector<int64_t> crop_vals;
1901 std::uniform_int_distribution<int> distribution(0, 4);
1902 for (int i = 0; i < num_block_dims; ++i) {
1903 // Chooses crop values; does not always choose legal values.
1904 crop_vals.push_back(distribution(generator()));
1905 crop_vals.push_back(distribution(generator()));
1906 }
1907 Tensor crops;
1908 CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
1909 TensorShape({num_block_dims, 2})));
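    // Added worked example: with num_block_dims = 2 and block_size = 2, the
    // input built above has shape [4 * b, h, w, c], and BatchToSpace
    // rearranges it into roughly [b, 2 * h, 2 * w, c] minus whatever the
    // crops trim from each spatial edge; intentionally illegal crop values
    // simply make the reference graph fail, which skips the attempt.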
1910
1911 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1912 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace")
1913 .RandomInput(type, input_dims)
1914 .Input(crops)
1915 .Attr("T", type)
1916 .Attr("block_size", block_size));
1917 });
1918 }
1919
1920 TEST_F(OpTest, BatchToSpaceND) {
1921 Repeatedly([this]() {
1922 std::vector<int64_t> block_dims = RandomDims(1, 3, 0, 5);
1923 int num_block_dims = block_dims.size();
1924 std::vector<int64_t> remaining_dims = RandomDims(0, 3);
1925 std::vector<int64_t> block_multipliers =
1926 RandomDims(block_dims.size(), block_dims.size(), 0, 4);
1927
1928 std::vector<int64_t> input_dims(1 + num_block_dims + remaining_dims.size());
1929 input_dims[0] = RandomDim();
1930 for (int i = 0; i < num_block_dims; ++i) {
1931 input_dims[0] *= block_dims[i];
1932 }
1933 std::copy(block_multipliers.begin(), block_multipliers.end(),
1934 input_dims.begin() + 1);
1935 std::copy(remaining_dims.begin(), remaining_dims.end(),
1936 input_dims.begin() + 1 + num_block_dims);
1937
1938 std::vector<int64_t> crop_vals;
1939 std::uniform_int_distribution<int> distribution(0, 3);
1940 for (int i = 0; i < num_block_dims; ++i) {
1941 // Chooses crop values; does not always choose legal values.
1942 crop_vals.push_back(distribution(generator()));
1943 crop_vals.push_back(distribution(generator()));
1944 }
1945 Tensor crops;
1946 CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
1947 TensorShape({num_block_dims, 2})));
1948
1949 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1950 return ExpectTfAndXlaOutputsAreClose(
1951 OpTestBuilder("BatchToSpaceND")
1952 .RandomInput(type, input_dims)
1953 .Input(test::AsTensor<int32>(
1954 std::vector<int32>(block_dims.begin(), block_dims.end())))
1955 .Input(crops)
1956 .Attr("T", type));
1957 });
1958 }
1959
1960 TEST_F(OpTest, BiasAdd) {
1961 Repeatedly([this]() {
1962 auto x_dims = RandomDims(2, kDefaultMaxRank);
1963 auto y_dims = {x_dims[x_dims.size() - 1]};
1964 // TODO(phawkins): test both data formats.
1965 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1966 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd")
1967 .RandomInput(type, x_dims)
1968 .RandomInput(type, y_dims)
1969 .Attr("T", type));
1970 });
1971 }
1972
1973 TEST_F(OpTest, BiasAddGrad) {
1974 Repeatedly([this]() {
1975 // TODO(phawkins): test both data formats.
1976 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1977 return ExpectTfAndXlaOutputsAreClose(
1978 OpTestBuilder("BiasAddGrad").RandomInput(type).Attr("T", type));
1979 });
1980 }
1981
1982 TEST_F(OpTest, BiasAddV1) {
1983 Repeatedly([this]() {
1984 auto x_dims = RandomDims(2, kDefaultMaxRank);
1985 auto y_dims = {x_dims[x_dims.size() - 1]};
1986 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
1987 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1")
1988 .RandomInput(type, x_dims)
1989 .RandomInput(type, y_dims)
1990 .Attr("T", type));
1991 });
1992 }
1993
1994 TEST_F(OpTest, Bitcast) {
1995 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
1996 Repeatedly([this]() { // NOLINT: due to GTEST_SKIP
1997 auto src_type = Choose<DataType>(kAllNumberTypes);
1998 auto dst_type = Choose<DataType>(kAllNumberTypes);
1999 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Bitcast")
2000 .RandomInput(src_type)
2001 .Attr("T", src_type)
2002 .Attr("type", dst_type));
2003 });
2004 }
2005
2006 TEST_F(OpTest, BitwiseAnd) {
2007 Repeatedly([this]() {
2008 DataType type = DT_INT32;
2009 auto dims = BroadcastableDims();
2010 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseAnd")
2011 .RandomInput(type, dims.first)
2012 .RandomInput(type, dims.second)
2013 .Attr("T", type));
2014 });
2015 }
2016
2017 TEST_F(OpTest, BitwiseOr) {
2018 Repeatedly([this]() {
2019 DataType type = DT_INT32;
2020 auto dims = BroadcastableDims();
2021 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseOr")
2022 .RandomInput(type, dims.first)
2023 .RandomInput(type, dims.second)
2024 .Attr("T", type));
2025 });
2026 }
2027
2028 TEST_F(OpTest, BitwiseXor) {
2029 Repeatedly([this]() {
2030 auto dims = BroadcastableDims();
2031 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseXor")
2032 .RandomInput(DT_INT32, dims.first)
2033 .RandomInput(DT_INT32, dims.second)
2034 .Attr("T", DT_INT32));
2035 });
2036 }
2037
2038 TEST_F(OpTest, BroadcastArgs) {
2039 Repeatedly([this]() {
2040 // TODO(phawkins): only int32 seems to be implemented in Tensorflow.
2041 // auto type = Choose<DataType>({DT_INT32, DT_INT64});
2042 DataType type = DT_INT32;
2043 auto dims = BroadcastableDims();
2044 return ExpectTfAndXlaOutputsAreClose(
2045 OpTestBuilder("BroadcastArgs")
2046 .Input(AsIntTensor(type, dims.first))
2047 .Input(AsIntTensor(type, dims.second))
2048 .Attr("T", type));
2049 });
2050 }
2051
2052 TEST_F(OpTest, BroadcastGradientArgs) {
2053 Repeatedly([this]() {
2054 // TODO(phawkins): only int32 seems to be implemented in Tensorflow.
2055 // auto type = Choose<DataType>({DT_INT32, DT_INT64});
2056 DataType type = DT_INT32;
2057 auto dims = BroadcastableDims();
2058 return ExpectTfAndXlaOutputsAreClose(
2059 OpTestBuilder("BroadcastGradientArgs")
2060 .Input(AsIntTensor(type, dims.first))
2061 .Input(AsIntTensor(type, dims.second))
2062 .Attr("T", type));
2063 });
2064 }
2065
2066 TEST_F(OpTest, BroadcastTo) {
2067 Repeatedly([this]() {
2068 auto type = Choose<DataType>(kAllXlaTypes);
2069 auto type_idx = Choose<DataType>({DT_INT32, DT_INT64});
2070 auto dims_to = RandomDims();
2071 auto dims_from = BroadcastableToDims(dims_to);
2072 return ExpectTfAndXlaOutputsAreClose(
2073 OpTestBuilder("BroadcastTo")
2074 .RandomInput(type, dims_from)
2075 .Input(AsIntTensor(type_idx, dims_to))
2076 .Attr("T", type)
2077 .Attr("Tidx", type_idx));
2078 });
2079 }
2080
2081 TEST_F(OpTest, Cast) {
2082 Repeatedly([this]() {
2083 DataType src_type, dst_type;
2084 src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
2085 dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
2086 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
2087 .RandomInput(src_type)
2088 .Attr("SrcT", src_type)
2089 .Attr("DstT", dst_type));
2090 });
2091 }
2092
2093 TEST_F(OpTest, CastBF16) {
2094 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2095 Repeatedly([this]() {
2096 DataType src_type, dst_type;
2097 src_type = Choose<DataType>({DT_FLOAT});
2098 dst_type = Choose<DataType>({DT_BFLOAT16});
2099 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
2100 .RandomInput(src_type)
2101 .Attr("SrcT", src_type)
2102 .Attr("DstT", dst_type)
2103 .Attr("Truncate", true));
2104 });
2105 }
2106
2107 TEST_F(OpTest, Ceil) {
2108 Repeatedly([this]() {
2109 return ExpectTfAndXlaOutputsAreClose(
2110 OpTestBuilder("Ceil").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
2111 });
2112 }
2113
2114 TEST_F(OpTest, ClipByValue) {
2115 // TODO(b/211012085): Change input_dims to BroadcastableDimsN(3). The
2116 // compiled ClipByValue fails in this case.
2117 // --tf_xla_random_seed=200839030
2118 Repeatedly([this]() {
2119 auto type = Choose<DataType>({DT_INT32, DT_INT64, DT_FLOAT});
2120 // ClipByValue requires that broadcasting min and max tensors do not cause
2121 // the returned shape to be larger than the input shape.
2122 auto input_dims = RandomDims();
2123 // clip_value_min must be <= clip_value_max for correct results. Different
2124 // implementations handle the max < min case differently, so ensure that
2125 // min <= max.
2126 auto min_max_dims = BroadcastableToDims(input_dims);
2127 auto min_max = RandomLteTensors(type, min_max_dims);
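    // Added note: ClipByValue computes, elementwise,
    //   output = min(max(input, clip_value_min), clip_value_max),
    // so RandomLteTensors (defined earlier in this file) is relied on to
    // return a (min, max) pair with min <= max elementwise; otherwise TF and
    // XLA may legitimately disagree, as noted above.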
2128 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ClipByValue")
2129 .RandomInput(type, input_dims)
2130 .Input(min_max.first)
2131 .Input(min_max.second)
2132 .Attr("T", type));
2133 });
2134 }
2135
2136 TEST_F(OpTest, Complex) {
2137 Repeatedly([this]() {
2138 auto dims = BroadcastableDims();
2139 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Complex")
2140 .RandomInput(DT_FLOAT, dims.first)
2141 .RandomInput(DT_FLOAT, dims.second)
2142 .Attr("T", DT_FLOAT));
2143 });
2144 }
2145
2146 TEST_F(OpTest, Concat) {
2147 Repeatedly([this]() { // NOLINT: due to GTEST_SKIP
2148 ConcatArguments a = ChooseConcatArguments(false);
2149 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Concat")
2150 .Input(a.axis)
2151 .VariadicInput(a.values)
2152 .Attr("N", a.n)
2153 .Attr("T", a.type));
2154 });
2155 }
2156
2157 TEST_F(OpTest, ConcatV2) {
2158 Repeatedly([this]() {
2159 ConcatArguments a = ChooseConcatArguments(true);
2160 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ConcatV2")
2161 .VariadicInput(a.values)
2162 .Input(a.axis)
2163 .Attr("N", a.n)
2164 .Attr("T", a.type)
2165 .Attr("Tidx", a.type_idx));
2166 });
2167 }
2168
2169 TEST_F(OpTest, ConcatOffset) {
2170 Repeatedly([this]() {
2171 int n = std::uniform_int_distribution<int>(2, 5)(generator());
2172
2173 std::vector<int64_t> dims = RandomDims(1);
2174 int concat_dim =
2175 std::uniform_int_distribution<int32>(0, dims.size() - 1)(generator());
2176
2177 OpTestBuilder builder("ConcatOffset");
2178 builder.Input(test::AsScalar<int32>(concat_dim));
2179 builder.Attr("N", n);
2180 for (int i = 0; i < n; ++i) {
2181 std::vector<int32> shape(dims.begin(), dims.end());
2182 shape[concat_dim] = RandomDim();
2183 builder.Input(test::AsTensor<int32>(shape));
2184 }
2185 return ExpectTfAndXlaOutputsAreClose(builder);
2186 });
2187 }
2188
2189 TEST_F(OpTest, Conj) {
2190 Repeatedly([this]() {
2191 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Conj")
2192 .RandomInput(DT_COMPLEX64)
2193 .Attr("T", DT_COMPLEX64));
2194 });
2195 }
2196
2197 TEST_F(OpTest, Const) {
2198 Repeatedly([this]() {
2199 auto type = Choose<DataType>({DT_FLOAT});
2200 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Const")
2201 .Attr("value", RandomTensor(type))
2202 .Attr("dtype", type));
2203 });
2204 }
2205
2206 TEST_F(OpTest, FFT) {
2207 Repeatedly([this]() {
2208 std::vector<int64_t> dims = RandomDims(1, kDefaultMaxRank);
2209 return ExpectTfAndXlaOutputsAreClose(
2210 OpTestBuilder("FFT").RandomInput(DT_COMPLEX64, dims));
2211 });
2212 }
2213
2214 TEST_F(OpTest, FFT2D) {
2215 Repeatedly([this]() {
2216 std::vector<int64_t> dims = RandomDims(2, kDefaultMaxRank);
2217 return ExpectTfAndXlaOutputsAreClose(
2218 OpTestBuilder("FFT2D").RandomInput(DT_COMPLEX64, dims));
2219 });
2220 }
2221
2222 TEST_F(OpTest, FFT3D) {
2223 Repeatedly([this]() {
2224 std::vector<int64_t> dims = RandomDims(3, kDefaultMaxRank);
2225 return ExpectTfAndXlaOutputsAreClose(
2226 OpTestBuilder("FFT3D").RandomInput(DT_COMPLEX64, dims));
2227 });
2228 }
2229
2230 TEST_F(OpTest, IFFT) {
2231 Repeatedly([this]() {
2232 std::vector<int64_t> dims = RandomDims(1, kDefaultMaxRank);
2233 return ExpectTfAndXlaOutputsAreClose(
2234 OpTestBuilder("IFFT").RandomInput(DT_COMPLEX64, dims));
2235 });
2236 }
2237
2238 TEST_F(OpTest, IFFT2D) {
2239 Repeatedly([this]() {
2240 std::vector<int64_t> dims = RandomDims(2, kDefaultMaxRank);
2241 return ExpectTfAndXlaOutputsAreClose(
2242 OpTestBuilder("IFFT2D").RandomInput(DT_COMPLEX64, dims));
2243 });
2244 }
2245
2246 TEST_F(OpTest, IFFT3D) {
2247 Repeatedly([this]() {
2248 std::vector<int64_t> dims = RandomDims(3, kDefaultMaxRank);
2249 return ExpectTfAndXlaOutputsAreClose(
2250 OpTestBuilder("IFFT3D").RandomInput(DT_COMPLEX64, dims));
2251 });
2252 }
2253
2254 TEST_F(OpTest, RFFT) {
2255 Repeatedly([this]() {
2256 std::vector<int64_t> dims = RandomDims(1, kDefaultMaxRank, 3);
2257 Tensor fft_shape = test::AsTensor<int32>(AsInt32s({dims[dims.size() - 1]}));
2258 return ExpectTfAndXlaOutputsAreClose(
2259 OpTestBuilder("RFFT").RandomInput(DT_FLOAT, dims).Input(fft_shape));
2260 });
2261 }
2262
2263 TEST_F(OpTest, RFFT2D) {
2264 Repeatedly([this]() {
2265 std::vector<int64_t> dims = RandomDims(2, kDefaultMaxRank, 3);
2266 Tensor fft_shape = test::AsTensor<int32>(
2267 AsInt32s({dims[dims.size() - 2], dims[dims.size() - 1]}));
2268 return ExpectTfAndXlaOutputsAreClose(
2269 OpTestBuilder("RFFT2D").RandomInput(DT_FLOAT, dims).Input(fft_shape));
2270 });
2271 }
2272
2273 TEST_F(OpTest, RFFT3D) {
2274 Repeatedly([this]() {
2275 std::vector<int64_t> dims = RandomDims(3, kDefaultMaxRank, 3);
2276 Tensor fft_shape = test::AsTensor<int32>(AsInt32s(
2277 {dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]}));
2278 return ExpectTfAndXlaOutputsAreClose(
2279 OpTestBuilder("RFFT3D").RandomInput(DT_FLOAT, dims).Input(fft_shape));
2280 });
2281 }
2282
2283 TEST_F(OpTest, IRFFT) {
2284 Repeatedly([this]() {
2285 std::vector<int64_t> dims = RandomDims(1, kDefaultMaxRank, 3);
2286 int64_t orig_size = dims[dims.size() - 1];
2287 dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1;
2288 Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size}));
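    // Added note: a real signal of length N has N / 2 + 1 meaningful complex
    // FFT coefficients, so the complex input built above has innermost
    // dimension orig_size / 2 + 1 while fft_shape asks IRFFT to reconstruct
    // the original length (e.g. orig_size = 9 gives a complex input of
    // innermost dimension 5 and a real output of innermost dimension 9).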
2289 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT")
2290 .RandomInput(DT_COMPLEX64, dims)
2291 .Input(fft_shape));
2292 });
2293 }
2294
2295 TEST_F(OpTest, IRFFT2D) {
2296 Repeatedly([this]() {
2297 std::vector<int64_t> dims = RandomDims(2, kDefaultMaxRank, 3);
2298 std::vector<int64_t> orig_size = {dims[dims.size() - 2],
2299 dims[dims.size() - 1]};
2300 dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1;
2301 Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size}));
2302 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT2D")
2303 .RandomInput(DT_COMPLEX64, dims)
2304 .Input(fft_shape));
2305 });
2306 }
2307
2308 TEST_F(OpTest, IRFFT3D) {
2309 Repeatedly([this]() {
2310 std::vector<int64_t> dims = RandomDims(3, kDefaultMaxRank, 3);
2311 std::vector<int64_t> orig_size = {
2312 dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]};
2313 dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1;
2314 Tensor fft_shape = test::AsTensor<int32>(AsInt32s({orig_size}));
2315 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT3D")
2316 .RandomInput(DT_COMPLEX64, dims)
2317 .Input(fft_shape));
2318 });
2319 }
2320
2321 TEST_F(OpTest, Conv2D) {
2322 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2323 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2324 Repeatedly([this]() {
2325 WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
2326 std::uniform_int_distribution<int> random_int(1, 5);
2327 int features_in = random_int(generator());
2328 int features_out = random_int(generator());
2329
2330 int64_t batch = RandomDim();
2331
2332 std::vector<int64_t> data_dims =
2333 ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
2334
2335 std::vector<int64_t> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
2336 features_in, features_out};
2337 DataType type = DT_FLOAT;
2338 return ExpectTfAndXlaOutputsAreClose(
2339 OpTestBuilder("Conv2D")
2340 .RandomInput(type, data_dims)
2341 .RandomInput(type, kernel_dims)
2342 .Attr("T", type)
2343 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
2344 .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
2345 .Attr("data_format", "NHWC"));
2346 });
2347 }
2348
2349 TEST_F(OpTest, Conv2DBackpropFilter) {
2350 Repeatedly([this]() {
2351 WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
2352 std::uniform_int_distribution<int> random_int(1, 5);
2353 int features_in = random_int(generator());
2354 int features_out = random_int(generator());
2355 int32_t batch = RandomDim();
2356 std::vector<int64_t> activations =
2357 ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
2358 std::vector<int64_t> backprop =
2359 ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
2360 Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
2361 {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}));
2362 DataType type = DT_FLOAT;
2363 return ExpectTfAndXlaOutputsAreClose(
2364 OpTestBuilder("Conv2DBackpropFilter")
2365 .RandomInput(type, activations)
2366 .Input(kernel_shape)
2367 .RandomInput(type, backprop)
2368 .Attr("T", type)
2369 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
2370 .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
2371 .Attr("data_format", "NHWC"));
2372 });
2373 }
2374
2375 TEST_F(OpTest, Conv2DBackpropInput) {
2376 Repeatedly([this]() {
2377 WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
2378 std::uniform_int_distribution<int> random_int(1, 5);
2379 int features_in = random_int(generator());
2380 int features_out = random_int(generator());
2381 int32_t batch = RandomDim();
2382 Tensor in_shape = test::AsTensor<int32>(
2383 AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
2384 std::vector<int64_t> backprop =
2385 ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
2386 std::vector<int64_t> kernel = {d.kernel_dims[0], d.kernel_dims[1],
2387 features_in, features_out};
2388 DataType type = DT_FLOAT;
2389 return ExpectTfAndXlaOutputsAreClose(
2390 OpTestBuilder("Conv2DBackpropInput")
2391 .Input(in_shape)
2392 .RandomInput(type, kernel)
2393 .RandomInput(type, backprop)
2394 .Attr("T", type)
2395 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
2396 .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
2397 .Attr("data_format", "NHWC"));
2398 });
2399 }
2400
2401 TEST_F(OpTest, Conv3D) {
2402 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2403 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2404 Repeatedly([this]() {
2405 WindowedSpatialDims d = ChooseWindowedSpatialDims(3);
2406 std::uniform_int_distribution<int> random_int(1, 5);
2407 int features_in = random_int(generator());
2408 int features_out = random_int(generator());
2409 std::vector<int64_t> data = {RandomDim(), d.input_dims[0], d.input_dims[1],
2410 d.input_dims[2], features_in};
2411
2412 std::vector<int64_t> kernel = {d.kernel_dims[0], d.kernel_dims[1],
2413 d.kernel_dims[2], features_in, features_out};
2414 DataType type = DT_FLOAT;
2415 return ExpectTfAndXlaOutputsAreClose(
2416 OpTestBuilder("Conv3D")
2417 .RandomInput(type, data)
2418 .RandomInput(type, kernel)
2419 .Attr("T", type)
2420 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
2421 .Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
2422 });
2423 }
2424
2425 TEST_F(OpTest, Conv3DBackpropFilter) {
2426 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2427 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2428 Repeatedly([this]() {
2429 WindowedSpatialDims d = ChooseWindowedSpatialDims(3);
2430 std::uniform_int_distribution<int> random_int(1, 5);
2431 int features_in = random_int(generator());
2432 int features_out = random_int(generator());
2433 int32_t batch = RandomDim(1);
2434 std::vector<int64_t> activations =
2435 ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
2436 std::vector<int64_t> backprop =
2437 ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
2438 Tensor kernel_shape = test::AsTensor<int32>(
2439 AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2],
2440 features_in, features_out}));
2441 DataType type = DT_FLOAT;
2442 return ExpectTfAndXlaOutputsAreClose(
2443 OpTestBuilder("Conv3DBackpropFilterV2")
2444 .RandomInput(type, activations)
2445 .Input(kernel_shape)
2446 .RandomInput(type, backprop)
2447 .Attr("T", type)
2448 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
2449 .Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
2450 });
2451 }
2452
2453 TEST_F(OpTest, Conv3DBackpropInput) {
2454 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2455 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2456 Repeatedly([this]() {
2457 WindowedSpatialDims d = ChooseWindowedSpatialDims(3);
2458 std::uniform_int_distribution<int> random_int(1, 5);
2459 int features_in = random_int(generator());
2460 int features_out = random_int(generator());
2461 int32_t batch = RandomDim(1);
2462 Tensor in_shape = test::AsTensor<int32>(
2463 AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
2464 std::vector<int64_t> backprop =
2465 ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
2466 std::vector<int64_t> kernel = {d.kernel_dims[0], d.kernel_dims[1],
2467 d.kernel_dims[2], features_in, features_out};
2468 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
2469 return ExpectTfAndXlaOutputsAreClose(
2470 OpTestBuilder("Conv3DBackpropInputV2")
2471 .Input(in_shape)
2472 .RandomInput(type, kernel)
2473 .RandomInput(type, backprop)
2474 .Attr("T", type)
2475 .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
2476 .Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
2477 });
2478 }
2479
2480 TEST_F(OpTest, ComplexAbs) {
2481 Repeatedly([this]() {
2482 auto type = DT_COMPLEX64;
2483 auto type_out = DT_FLOAT;
2484 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ComplexAbs")
2485 .RandomInput(type)
2486 .Attr("T", type)
2487 .Attr("Tout", type_out));
2488 });
2489 }
2490
2491 TEST_F(OpTest, Cos) {
2492 Repeatedly([this]() {
2493 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
2494 return ExpectTfAndXlaOutputsAreClose(
2495 OpTestBuilder("Cos").RandomInput(type).Attr("T", type));
2496 });
2497 }
2498
2499 TEST_F(OpTest, Cosh) {
2500 Repeatedly([this]() {
2501 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
2502 return ExpectTfAndXlaOutputsAreClose(
2503 OpTestBuilder("Cosh").RandomInput(type).Attr("T", type));
2504 });
2505 }
2506
2507 TEST_F(OpTest, DepthToSpace) {
2508 Repeatedly([this]() {
2509 int64_t block = RandomDim(2, 5);
2510 std::vector<int64_t> input_dims = RandomDims(4, 4);
2511 input_dims[1] = (input_dims[1] + (block - 1)) / block;
2512 input_dims[2] = (input_dims[2] + (block - 1)) / block;
2513 input_dims[3] *= block * block;
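    // Added worked example: with block = 2 the input shape [b, h, w, 4 * c]
    // produced above becomes an output of shape [b, 2 * h, 2 * w, c]; the
    // divisions of the spatial dims above keep the output size roughly
    // comparable to the originally chosen random dims.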
2514 auto type = Choose<DataType>(kAllXlaTypes);
2515 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DepthToSpace")
2516 .RandomInput(type, input_dims)
2517 .Attr("T", type)
2518 .Attr("block_size", block));
2519 });
2520 }
2521
2522 TEST_F(OpTest, DepthwiseConv2DNative) {
2523 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2524 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2525 Repeatedly([this]() {
2526 WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
2527 std::uniform_int_distribution<int> random_int(1, 5);
2528 int features_in = random_int(generator());
2529 int depth_multiplier = random_int(generator());
2530 std::vector<int64_t> input_dims = {RandomDim(), d.input_dims[0],
2531 d.input_dims[1], features_in};
2532
2533 std::vector<int64_t> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
2534 features_in, depth_multiplier};
2535 std::vector<int64_t> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
2536 strides[2] = strides[1]; // Current impl only supports equal strides
2537 return ExpectTfAndXlaOutputsAreClose(
2538 OpTestBuilder("DepthwiseConv2dNative")
2539 .RandomInput(DT_FLOAT, input_dims)
2540 .RandomInput(DT_FLOAT, kernel_dims)
2541 .Attr("T", DT_FLOAT)
2542 .Attr("strides", strides)
2543 .Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
2544 });
2545 }
2546
2547 TEST_F(OpTest, DepthwiseConv2DNativeBackpropFilter) {
2548 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2549 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2550 Repeatedly([this]() {
2551 WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
2552 std::uniform_int_distribution<int> random_int(1, 5);
2553 int features_in = random_int(generator());
2554 int depth_multiplier = random_int(generator());
2555 int32_t batch = RandomDim();
2556 std::vector<int64_t> activations =
2557 ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
2558 std::vector<int64_t> backprop = ImageDims(
2559 FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
2560 Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
2561 {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}));
2562 std::vector<int64_t> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
2563 strides[2] = strides[1]; // Current impl only supports equal strides
2564 return ExpectTfAndXlaOutputsAreClose(
2565 OpTestBuilder("DepthwiseConv2dNativeBackpropFilter")
2566 .RandomInput(DT_FLOAT, activations)
2567 .Input(kernel_shape)
2568 .RandomInput(DT_FLOAT, backprop)
2569 .Attr("T", DT_FLOAT)
2570 .Attr("strides", strides)
2571 .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
2572 .Attr("data_format", "NHWC"));
2573 });
2574 }
2575
2576 TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
2577 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2578 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2579 Repeatedly([this]() {
2580 WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
2581 std::uniform_int_distribution<int> random_int(1, 5);
2582 int features_in = random_int(generator());
2583 int depth_multiplier = random_int(generator());
2584 int32_t batch = RandomDim();
2585 Tensor in_shape = test::AsTensor<int32>(
2586 AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
2587 std::vector<int64_t> backprop = ImageDims(
2588 FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
2589 std::vector<int64_t> kernel = {d.kernel_dims[0], d.kernel_dims[1],
2590 features_in, depth_multiplier};
2591 std::vector<int64_t> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
2592 strides[2] = strides[1]; // Current impl only supports equal strides
2593 return ExpectTfAndXlaOutputsAreClose(
2594 OpTestBuilder("DepthwiseConv2dNativeBackpropInput")
2595 .Input(in_shape)
2596 .RandomInput(DT_FLOAT, kernel)
2597 .RandomInput(DT_FLOAT, backprop)
2598 .Attr("T", DT_FLOAT)
2599 .Attr("strides", strides)
2600 .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
2601 .Attr("data_format", "NHWC"));
2602 });
2603 }
2604
2605 TEST_F(OpTest, Diag) {
2606 Repeatedly([this]() {
2607 auto type = Choose<DataType>(kAllXlaTypes);
2608 std::vector<int64_t> dims;
2609 // Diag causes a quadratic blowup in output size.
2610 int64_t size;
2611 do {
2612 dims = RandomDims(1);
2613 size = TensorShape(dims).num_elements();
2614 } while (size * size > tf_xla_max_tensor_size);
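    // Added note: Diag of an input with `size` elements produces an output
    // with size * size elements (an n-element vector becomes an n x n matrix
    // with the input on the diagonal), hence the retry loop above; e.g. a
    // 1000-element input would already yield a 1,000,000-element output.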
2615 return ExpectTfAndXlaOutputsAreClose(
2616 OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type));
2617 });
2618 }
2619
2620 TEST_F(OpTest, DiagPart) {
2621 Repeatedly([this]() {
2622 auto type = Choose<DataType>(kAllXlaTypes);
2623 auto dims = RandomDims(1, 3);
2624 // Duplicate the random dims.
2625 std::vector<int64_t> doubled_dims(dims.size() * 2);
2626 std::copy(dims.begin(), dims.end(), doubled_dims.begin());
2627 std::copy(dims.begin(), dims.end(), doubled_dims.begin() + dims.size());
2628 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart")
2629 .RandomInput(type, doubled_dims)
2630 .Attr("T", type));
2631 });
2632 }
2633
2634 TEST_F(OpTest, Digamma) {
2635 Repeatedly([this]() {
2636 return ExpectTfAndXlaOutputsAreClose(
2637 OpTestBuilder("Digamma").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
2638 });
2639 }
2640
2641 TEST_F(OpTest, Div) {
2642 Repeatedly([this]() {
2643 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
2644 auto dims = BroadcastableDims();
2645 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
2646 .RandomInput(type, dims.first)
2647 .RandomInput(type, dims.second)
2648 .Attr("T", type));
2649 });
2650 }
2651
2652 TEST_F(OpTest, DivNoNan) {
2653 Repeatedly([this]() {
2654 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
2655 auto dims = BroadcastableDims();
2656 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DivNoNan")
2657 .RandomInput(type, dims.first)
2658 .RandomInput(type, dims.second)
2659 .Attr("T", type));
2660 });
2661 }
2662
2663 TEST_F(OpTest, DynamicStitch) {
2664 Repeatedly([this]() {
2665 auto type = Choose<DataType>(kAllXlaTypes);
2666 int n = std::uniform_int_distribution<int>(2, 5)(generator());
2667 OpTestBuilder builder("DynamicStitch");
2668 builder.Attr("T", type);
2669 builder.Attr("N", n);
2670 std::vector<std::vector<int64_t>> index_dims;
2671 int size = 0;
2672 // TODO(phawkins): the XLA implementation of DynamicStitch does not
2673 // accept an empty set of indices.
2674 do {
2675 size = 0;
2676 index_dims.clear();
2677 for (int i = 0; i < n; ++i) {
2678 std::vector<int64_t> dims = RandomDims(0, 3, 0, 5);
2679 size += TensorShape(dims).num_elements();
2680 index_dims.push_back(dims);
2681 }
2682 } while (size == 0);
2683
2684 // Shuffle the range of indices that cover the output.
2685 // TODO(phawkins): The documentation for DynamicStitch doesn't require
2686 // that the indices cover all positions of the output. The XLA
2687 // implementation does require it. However, the native TF implementation
2688 // leaves undefined values if we don't cover everything, so we can't
2689 // really test that case anyway.
2690 std::vector<int32> indices(size);
2691 std::iota(indices.begin(), indices.end(), 0);
2692 std::shuffle(indices.begin(), indices.end(), generator());
2693
2694 int pos = 0;
2695 for (int i = 0; i < n; ++i) {
2696 TensorShape shape(index_dims[i]);
2697 Tensor t = test::AsTensor<int32>(
2698 absl::Span<const int32>(indices).subspan(pos, shape.num_elements()),
2699 shape);
2700 builder.Input(t);
2701 pos += t.NumElements();
2702 }
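    // Added worked example: with n = 2 and the shuffled flat indices split
    // as, say, [1, 3] and [0, 2], output row 0 comes from data1[0], row 1
    // from data0[0], row 2 from data1[1] and row 3 from data0[1]; because the
    // indices are a permutation of [0, size), every output position is
    // written exactly once, matching the XLA requirement described above.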
2703
2704 std::vector<int64_t> constant_dims = RandomDims(0, 3, 0, 5);
2705 for (int i = 0; i < n; ++i) {
2706 std::vector<int64_t> dims(index_dims[i].begin(), index_dims[i].end());
2707 std::copy(constant_dims.begin(), constant_dims.end(),
2708 std::back_inserter(dims));
2709 builder.RandomInput(type, dims);
2710 }
2711 return ExpectTfAndXlaOutputsAreClose(builder);
2712 });
2713 }
2714
2715 TEST_F(OpTest, Einsum) {
2716 Repeatedly([this]() {
2717 const EinsumArguments a = ChooseEinsumArguments();
2718 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Einsum")
2719 .RandomInput(a.type, a.lhs_dims)
2720 .RandomInput(a.type, a.rhs_dims)
2721 .Attr("equation", a.equation)
2722 .Attr("T", a.type)
2723 .Attr("N", 2));
2724 });
2725 }
2726
2727 TEST_F(OpTest, Empty) {
2728 Repeatedly([this]() {
2729 auto type = Choose<DataType>({kAllXlaTypes});
2730 return ExpectTfAndXlaOutputsAreClose(
2731 OpTestBuilder("Empty")
2732 .Input(AsIntTensor(DT_INT32, RandomDims()))
2733 .Attr("init", true)
2734 .Attr("dtype", type));
2735 });
2736 }
2737
2738 TEST_F(OpTest, Elu) {
2739 Repeatedly([this]() {
2740 return ExpectTfAndXlaOutputsAreClose(
2741 OpTestBuilder("Elu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
2742 });
2743 }
2744
2745 TEST_F(OpTest, EluGrad) {
2746 Repeatedly([this]() {
2747 auto dims = RandomDims();
2748 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("EluGrad")
2749 .RandomInput(DT_FLOAT, dims)
2750 .RandomInput(DT_FLOAT, dims)
2751 .Attr("T", DT_FLOAT));
2752 });
2753 }
2754
2755 TEST_F(OpTest, ScatterNd) {
2756 Repeatedly([this]() {
2757 auto a = ChooseScatterArguments();
2758 auto shape = test::AsTensor<int32>(
2759 std::vector<int32>(a.shape.begin(), a.shape.end()));
2760 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ScatterNd")
2761 .Input(a.indices)
2762 .Input(a.updates)
2763 .Input(shape)
2764 .Attr("T", a.type)
2765 .Attr("Tindices", a.indices_type));
2766 });
2767 }
2768
2769 TEST_F(OpTest, Selu) {
2770 Repeatedly([this]() {
2771 return ExpectTfAndXlaOutputsAreClose(
2772 OpTestBuilder("Selu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
2773 });
2774 }
2775
2776 TEST_F(OpTest, SeluGrad) {
2777 Repeatedly([this]() {
2778 auto dims = RandomDims();
2779 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SeluGrad")
2780 .RandomInput(DT_FLOAT, dims)
2781 .RandomInput(DT_FLOAT, dims)
2782 .Attr("T", DT_FLOAT));
2783 });
2784 }
2785
2786 TEST_F(OpTest, Equal) {
2787 Repeatedly([this]() {
2788 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
2789 auto dims = BroadcastableDims();
2790 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
2791 .RandomInput(type, dims.first)
2792 .RandomInput(type, dims.second)
2793 .Attr("T", type));
2794 });
2795 }
2796
2797 TEST_F(OpTest, Erf) {
2798 Repeatedly([this]() {
2799 return ExpectTfAndXlaOutputsAreClose(
2800 OpTestBuilder("Erf").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
2801 });
2802 }
2803
2804 TEST_F(OpTest, Erfc) {
2805 Repeatedly([this]() {
2806 return ExpectTfAndXlaOutputsAreClose(
2807 OpTestBuilder("Erfc").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
2808 });
2809 }
2810
2811 TEST_F(OpTest, Exp) {
2812 Repeatedly([this]() {
2813 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
2814 return ExpectTfAndXlaOutputsAreClose(
2815 OpTestBuilder("Exp").RandomInput(type).Attr("T", type));
2816 });
2817 }
2818
2819 TEST_F(OpTest, Expm1) {
2820 Repeatedly([this]() {
2821 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
2822 return ExpectTfAndXlaOutputsAreClose(
2823 OpTestBuilder("Expm1").RandomInput(type).Attr("T", type));
2824 });
2825 }
2826
2827 TEST_F(OpTest, ExpandDims) {
2828 Repeatedly([this]() {
2829 auto type = Choose<DataType>(kAllXlaTypes);
2830 std::vector<int64_t> in_dims = RandomDims();
2831 Tensor dim(DT_INT32, TensorShape());
2832 std::uniform_int_distribution<int32> d(-1 - in_dims.size(), in_dims.size());
2833 dim.scalar<int32>()() = d(generator());
2834 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims")
2835 .RandomInput(type, in_dims)
2836 .Input(dim)
2837 .Attr("T", type));
2838 });
2839 }
2840
2841 TEST_F(OpTest, Fill) {
2842 Repeatedly([this]() {
2843 auto type = Choose<DataType>(kAllXlaTypes);
2844 std::vector<int64_t> dims = RandomDims();
2845 std::vector<int32> shape(dims.begin(), dims.end());
2846 return ExpectTfAndXlaOutputsAreClose(
2847 OpTestBuilder("Fill")
2848 .Input(test::AsTensor<int32>(shape))
2849 .RandomInput(type, {})
2850 .Attr("T", type));
2851 });
2852 }
2853
2854 TEST_F(OpTest, Floor) {
2855 Repeatedly([this]() {
2856 return ExpectTfAndXlaOutputsAreClose(
2857 OpTestBuilder("Floor").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
2858 });
2859 }
2860
2861 TEST_F(OpTest, FloorDiv) {
2862 Repeatedly([this]() {
2863 DataType type = DT_INT32;
2864 auto dims = BroadcastableDims();
2865 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv")
2866 .RandomInput(type, dims.first)
2867 .RandomInput(type, dims.second)
2868 .Attr("T", type));
2869 });
2870 }
2871
2872 TEST_F(OpTest, FloorMod) {
2873 Repeatedly([this]() {
2874 auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
2875 auto dims = BroadcastableDims();
2876 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod")
2877 .RandomInput(type, dims.first)
2878 .RandomInput(type, dims.second)
2879 .Attr("T", type));
2880 });
2881 }
2882
2883 TEST_F(OpTest, Gather) {
2884 Repeatedly([this]() {
2885 GatherArguments a = ChooseGatherArguments(true);
2886 return ExpectTfAndXlaOutputsAreClose(
2887 OpTestBuilder("Gather")
2888 .RandomInput(a.params_type, a.params_shape)
2889 .Input(a.indices)
2890 .Attr("Tparams", a.params_type)
2891 .Attr("Tindices", a.indices_type));
2892 });
2893 }
2894
2895 TEST_F(OpTest, GatherV2) {
2896 Repeatedly([this]() {
2897 GatherArguments a = ChooseGatherArguments(false);
2898 return ExpectTfAndXlaOutputsAreClose(
2899 OpTestBuilder("GatherV2")
2900 .RandomInput(a.params_type, a.params_shape)
2901 .Input(a.indices)
2902 .Input(a.axis)
2903 .Attr("batch_dims", a.batch_dims)
2904 .Attr("Taxis", a.axis_type)
2905 .Attr("Tindices", a.indices_type)
2906 .Attr("Tparams", a.params_type));
2907 });
2908 }
2909
2910 TEST_F(OpTest, GatherNd) {
2911 // :randomized_tests_mlir fails with --tf_xla_random_seed=459353625
2912 // --test_arg=--tf_xla_test_repetitions=100
2913 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
2914 // See b/214080339#comment27 as this test causes Kokoro to crash.
2915 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
2916 Repeatedly([this]() { // NOLINT: due to GTEST_SKIP
2917 auto params_type = Choose<DataType>(kAllXlaTypes);
2918 // GatherNd seems undefined in the case where params has rank 0.
2919 std::vector<int64_t> params_shape = RandomDims(1);
2920 auto indices_type = DT_INT32;
2921 std::vector<int64_t> output_outer_shape = RandomDims(0, 4, 0, 32);
2922 int64_t index_len = RandomDim(0, params_shape.size() + 1);
2923 std::vector<int64_t> output_shape(output_outer_shape);
2924 output_shape.push_back(index_len);
2925 Tensor lo(indices_type, TensorShape(output_shape));
2926 test::FillFn<int32>(&lo, [](int i) -> int32 { return 0; });
2927 Tensor hi(indices_type, TensorShape(output_shape));
2928 test::FillFn<int32>(&hi, [index_len, &params_shape](int i) -> int32 {
2929 int idx_dim = i % index_len;
2930 return params_shape[idx_dim] - 1;
2931 });
2932 Tensor indices = RandomBoundedTensor(indices_type, lo, hi);
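    // Added note: each index is a vector of index_len coordinates into the
    // leading dimensions of params, so e.g. params_shape = {5, 7, 3} with
    // index_len = 2 means every index is a pair (i, j) with 0 <= i < 5 and
    // 0 <= j < 7, and the gathered slice contributes the remaining {3} to the
    // output; the lo/hi tensors above bound each coordinate accordingly.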
2933 return ExpectTfAndXlaOutputsAreClose(
2934 OpTestBuilder("GatherNd")
2935 .RandomInput(params_type, params_shape)
2936 .Input(indices)
2937 .Attr("Tindices", indices_type)
2938 .Attr("Tparams", params_type));
2939 });
2940 }
2941
2942 TEST_F(OpTest, Greater) {
2943 Repeatedly([this]() {
2944 auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
2945 auto dims = BroadcastableDims();
2946 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater")
2947 .RandomInput(type, dims.first)
2948 .RandomInput(type, dims.second)
2949 .Attr("T", type));
2950 });
2951 }
2952
2953 TEST_F(OpTest, GreaterEqual) {
2954 Repeatedly([this]() {
2955 auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
2956 auto dims = BroadcastableDims();
2957 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual")
2958 .RandomInput(type, dims.first)
2959 .RandomInput(type, dims.second)
2960 .Attr("T", type));
2961 });
2962 }
2963
2964 TEST_F(OpTest, Identity) {
2965 Repeatedly([this]() {
2966 auto type = Choose<DataType>(kAllXlaTypes);
2967 return ExpectTfAndXlaOutputsAreClose(
2968 OpTestBuilder("Identity").RandomInput(type).Attr("T", type));
2969 });
2970 }
2971
2972 TEST_F(OpTest, Imag) {
2973 Repeatedly([this]() {
2974 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Imag")
2975 .RandomInput(DT_COMPLEX64)
2976 .Attr("T", DT_COMPLEX64));
2977 });
2978 }
2979
2980 TEST_F(OpTest, InplaceUpdate) {
2981 Repeatedly([this]() {
2982 auto type = Choose<DataType>(kAllXlaTypes);
2983 std::vector<int64_t> common_dims =
2984 RandomDims(0, kDefaultMaxRank - 1, 0, kDefaultMaxDimensionSize);
2985 // TODO(b/211012712): Once needs_unique_values case is linear instead of
2986 // quadratic time, use default Dim max instead of 8.
2987 std::vector<int64_t> v_dims{RandomDim(1, 8)};
2988 v_dims.insert(v_dims.end(), common_dims.begin(), common_dims.end());
2989 std::vector<int64_t> x_dims{RandomDim(v_dims[0])};
2990 x_dims.insert(x_dims.end(), common_dims.begin(), common_dims.end());
2991 std::vector<int64_t> i_shape{v_dims[0]};
2992 Tensor i =
2993 RandomBoundedTensor<int32>(DT_INT32, 0, x_dims[0] - 1, true, i_shape);
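    // Added note: InplaceUpdate overwrites whole rows of x along its first
    // dimension, roughly x[i[k], ...] = v[k, ...] for each k, which is why i
    // holds row indices in [0, x_dims[0]) and why v shares all trailing
    // dimensions with x; the indices are presumably generated unique so that
    // no row is written twice in an order-dependent way.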
2994 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("InplaceUpdate")
2995 .RandomInput(type, x_dims)
2996 .Input(i)
2997 .RandomInput(type, v_dims)
2998 .Attr("T", type));
2999 });
3000 }
3001
3002 TEST_F(OpTest, Inv) {
3003 Repeatedly([this]() {
3004 auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
3005 return ExpectTfAndXlaOutputsAreClose(
3006 OpTestBuilder("Inv").RandomInput(type).Attr("T", type));
3007 });
3008 }
3009
3010 TEST_F(OpTest, Invert) {
3011 Repeatedly([this]() {
3012 DataType type = DT_INT32;
3013 return ExpectTfAndXlaOutputsAreClose(
3014 OpTestBuilder("Invert").RandomInput(type).Attr("T", type));
3015 });
3016 }
3017
3018 TEST_F(OpTest, InvertPermutation) {
3019 Repeatedly([this]() {
3020 // TODO(b/211012712): Once needs_unique_values case is linear instead of
3021 // quadratic time, use default Dim max instead of 8.
3022 int64_t len = RandomDim(0, 8);
3023 Tensor x = RandomBoundedTensor<int32>(DT_INT32, 0, len - 1, true, {len});
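    // Added worked example: InvertPermutation computes y such that
    // y[x[i]] = i, so x = [2, 0, 1] produces y = [1, 2, 0]; generating x with
    // unique values in [0, len) guarantees it is a valid permutation.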
3024 return ExpectTfAndXlaOutputsAreClose(
3025 OpTestBuilder("InvertPermutation").Input(x).Attr("T", DT_INT32));
3026 });
3027 }
3028
3029 TEST_F(OpTest, IsFinite) {
3030 Repeatedly([this]() {
3031 return ExpectTfAndXlaOutputsAreClose(
3032 OpTestBuilder("IsFinite").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
3033 });
3034 }
3035
3036 TEST_F(OpTest, IsInf) {
3037 Repeatedly([this]() {
3038 return ExpectTfAndXlaOutputsAreClose(
3039 OpTestBuilder("IsInf").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
3040 });
3041 }
3042
TEST_F(OpTest,IsNan)3043 TEST_F(OpTest, IsNan) {
3044 Repeatedly([this]() {
3045 return ExpectTfAndXlaOutputsAreClose(
3046 OpTestBuilder("IsNan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
3047 });
3048 }
3049
TEST_F(OpTest,L2Loss)3050 TEST_F(OpTest, L2Loss) {
3051 Repeatedly([this]() {
3052 DataType type = DT_FLOAT;
3053 return ExpectTfAndXlaOutputsAreClose(
3054 OpTestBuilder("L2Loss").RandomInput(type).Attr("T", type));
3055 });
3056 }
3057
TEST_F(OpTest,LeakyRelu)3058 TEST_F(OpTest, LeakyRelu) {
3059 Repeatedly([this]() {
3060 std::uniform_real_distribution<float> alpha(-2.0f, 2.0f);
3061 return ExpectTfAndXlaOutputsAreClose(
3062 OpTestBuilder("LeakyRelu")
3063 .RandomInput(DT_FLOAT)
3064 .Attr("T", DT_FLOAT)
3065 .Attr("alpha", alpha(generator())));
3066 });
3067 }
3068
TEST_F(OpTest,LeakyReluGrad)3069 TEST_F(OpTest, LeakyReluGrad) {
3070 Repeatedly([this]() {
3071 auto dims = RandomDims(1);
3072 std::uniform_real_distribution<float> alpha(-2.0f, 2.0f);
3073 return ExpectTfAndXlaOutputsAreClose(
3074 OpTestBuilder("LeakyReluGrad")
3075 .RandomInput(DT_FLOAT, dims)
3076 .RandomInput(DT_FLOAT, dims)
3077 .Attr("T", DT_FLOAT)
3078 .Attr("alpha", alpha(generator())));
3079 });
3080 }
3081
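// The shift-amount operand of LeftShift/RightShift is drawn from
// [0, bit width - 1] via RandomBoundedTensor rather than being a fully random
// input, since out-of-range shift counts are undefined and the CPU and XLA
// backends could legitimately disagree on them.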
TEST_F(OpTest, LeftShift) {
  Repeatedly([this]() {
    bool is64 = RandomBool();
    auto dims = RandomDims();
    auto type = is64 ? DT_INT64 : DT_INT32;
    int max_shift = is64 ? 63 : 31;
    auto y = RandomBoundedTensor(type, 0, max_shift, false, dims);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LeftShift")
        .RandomInput(type, dims)
        .Input(y)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Less) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, LessEqual) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Lgamma) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Lgamma").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, LinSpace) {
  Repeatedly([this]() {
    auto ToScalar = [](DataType type, int x) {
      if (type == DT_INT32) return test::AsScalar<int32>(x);
      return test::AsScalar<int64_t>(x);
    };
    std::uniform_int_distribution<int> distribution(-50, 50);
    auto type = Choose<DataType>({DT_INT32, DT_INT64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("LinSpace")
            .RandomInput(DT_FLOAT, {})
            .RandomInput(DT_FLOAT, {})
            .Input(ToScalar(type, distribution(generator())))
            .Attr("T", DT_FLOAT)
            .Attr("Tidx", type));
  });
}

TEST_F(OpTest, Log) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Log").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Log1p) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Log1p").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, LogicalAnd) {
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("LogicalAnd")
            .RandomInput(DT_BOOL, dims.first)
            .RandomInput(DT_BOOL, dims.second));
  });
}

TEST_F(OpTest, LogicalNot) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("LogicalNot").RandomInput(DT_BOOL));
  });
}

TEST_F(OpTest, LogicalOr) {
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("LogicalOr")
            .RandomInput(DT_BOOL, dims.first)
            .RandomInput(DT_BOOL, dims.second));
  });
}

TEST_F(OpTest, LogSoftmax) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("LogSoftmax")
            .RandomInput(DT_FLOAT, RandomDims(2, 2))
            .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, LRN) {
  Repeatedly([this]() {
    // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
    std::vector<int64_t> data_dims = RandomDims(4, 4, 1, 8);
    // CuDNN requires depth_radius > 0.
    std::uniform_int_distribution<int> radius(1, data_dims[3]);
    std::uniform_real_distribution<float> coeff(0.01, 2.0);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("LRN")
            .RandomInput(DT_FLOAT, data_dims)
            .Attr("T", DT_FLOAT)
            .Attr("depth_radius", radius(generator()))
            .Attr("bias", coeff(generator()))
            .Attr("alpha", coeff(generator()))
            .Attr("beta", coeff(generator())));
  });
}

TEST_F(OpTest, LRNGrad) {
  Repeatedly([this]() {
    // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
    std::vector<int64_t> dims = RandomDims(4, 4, 1, 8);
    // CuDNN requires depth_radius > 0.
    std::uniform_int_distribution<int> radius(1, dims[3]);
    std::uniform_real_distribution<float> coeff(0.0, 2.0);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("LRNGrad")
            .RandomInput(DT_FLOAT, dims)
            .RandomInput(DT_FLOAT, dims)
            .RandomInput(DT_FLOAT, dims)
            .Attr("T", DT_FLOAT)
            .Attr("depth_radius", radius(generator()))
            .Attr("bias", coeff(generator()))
            .Attr("alpha", coeff(generator()))
            .Attr("beta", coeff(generator())));
  });
}

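// MatMul: a is x*y and b is y*z, so the product is always well formed. When a
// transpose flag is set, the corresponding operand is generated already
// transposed, keeping the contracted dimension equal to y in every case.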
TEST_F(OpTest, MatMul) {
  Repeatedly([this]() {
    int64_t x = RandomDim();
    int64_t y = RandomDim();
    int64_t z = RandomDim();

    std::vector<int64_t> a_dims = {x, y};
    std::vector<int64_t> b_dims = {y, z};

    std::bernoulli_distribution random_bool;
    bool transpose_a = random_bool(generator());
    bool transpose_b = random_bool(generator());
    if (transpose_a) {
      std::swap(a_dims[0], a_dims[1]);
    }
    if (transpose_b) {
      std::swap(b_dims[0], b_dims[1]);
    }

    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
        .RandomInput(type, a_dims)
        .RandomInput(type, b_dims)
        .Attr("T", type)
        .Attr("transpose_a", transpose_a)
        .Attr("transpose_b", transpose_b));
  });
}

TEST_F(OpTest, MatrixBandPart) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    auto index_type = Choose<DataType>({DT_INT32, DT_INT64});
    auto num_lower =
        RandomBoundedTensor<int32>(index_type, -2 * kDefaultMaxDimensionSize,
            2 * kDefaultMaxDimensionSize, false, {});
    auto num_upper =
        RandomBoundedTensor<int32>(index_type, -2 * kDefaultMaxDimensionSize,
            2 * kDefaultMaxDimensionSize, false, {});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixBandPart")
        .RandomInput(type)
        .Input(num_lower)
        .Input(num_upper)
        .Attr("T", type)
        .Attr("Tindex", index_type));
  });
}

TEST_F(OpTest, MatrixDiag) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
        .RandomInput(type, RandomDims(1))
        .Attr("T", type));
  });
}

TEST_F(OpTest, MatrixDiagPart) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
        .RandomInput(type, RandomDims(2))
        .Attr("T", type));
  });
}

TEST_F(OpTest, MatrixDiagPartV3) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    auto type = Choose<DataType>(kAllXlaTypes);
    auto align = Choose<std::string>(
        {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"});
    auto k0 = std::uniform_int_distribution<int32>(
        -2 * kDefaultMaxDimensionSize,
        2 * kDefaultMaxDimensionSize)(generator());
    auto k1 = std::uniform_int_distribution<int32>(
        k0, 2 * kDefaultMaxDimensionSize)(generator());
    auto k = test::AsTensor<int32>({k0, k1});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPartV3")
        .RandomInput(type)
        .Input(k)
        .RandomInput(type, {})
        .Attr("align", align)
        .Attr("T", type));
  });
}

TEST_F(OpTest, MatrixSetDiag) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    auto shape = RandomDims(2);
    int rank = shape.size();
    std::vector<int64_t> diagonal_shape(shape);
    diagonal_shape.pop_back();
    diagonal_shape.pop_back();
    diagonal_shape.push_back(std::min(shape[rank - 2], shape[rank - 1]));
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixSetDiag")
        .RandomInput(type, shape)
        .RandomInput(type, diagonal_shape)
        .Attr("T", type));
  });
}

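// MatrixSetDiagV2 writes a band of diagonals [k0, k1] into each innermost
// matrix. The diagonal input therefore has shape [..., num_diags,
// max_diag_len], where max_diag_len is the longest diagonal in the band for
// an M x N matrix: min(M + min(k1, 0), N - max(k0, 0)).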
TEST_F(OpTest, MatrixSetDiagV2) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    auto shape = RandomDims(2, kDefaultMaxRank, 1 /* non-zero dims */);
    int rank = shape.size();
    int64_t max_num_diags = shape[rank - 2] + shape[rank - 1] - 1;
    int64_t num_diags =
        std::uniform_int_distribution<int64_t>(2, max_num_diags)(generator());
    int32 k0 = std::uniform_int_distribution<int32>(
        -shape[rank - 2] + 1, shape[rank - 1] - num_diags)(generator());
    int32 k1 = k0 + num_diags - 1;
    Tensor k = test::AsTensor<int32>({k0, k1});
    int64_t max_diag_len = std::min(shape[rank - 2] + std::min(k1, 0),
        shape[rank - 1] + std::min(-k0, 0));
    std::vector<int64_t> diagonal_shape(shape);
    diagonal_shape.pop_back();
    diagonal_shape.pop_back();
    diagonal_shape.push_back(num_diags);
    diagonal_shape.push_back(max_diag_len);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixSetDiagV2")
        .RandomInput(type, shape)
        .RandomInput(type, diagonal_shape)
        .Input(k)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Max) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
    std::vector<int64_t> data_dims = RandomDims();
    Tensor indices = RandomReductionIndices(data_dims.size());
    bool keep_dims = Choose<bool>({false, true});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Max")
        .RandomInput(type, data_dims)
        .Input(indices)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims));
  });
}

TEST_F(OpTest, Maximum) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

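// MaxPool: kernel sizes are bounded by the input's spatial dimensions, so
// even VALID padding yields a non-empty output; strides are drawn
// independently from [1, 5].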
TEST_F(OpTest, MaxPool) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int> random_int(1, 5);
    std::vector<int64_t> dims = RandomDims(4, 4, 1);
    int kernel_rows =
        std::uniform_int_distribution<int>(1, dims[1])(generator());
    int kernel_cols =
        std::uniform_int_distribution<int>(1, dims[2])(generator());
    int stride_rows = random_int(generator()),
        stride_cols = random_int(generator());

    string padding = Choose<string>({"SAME", "VALID"});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("MaxPool")
            .RandomInput(DT_FLOAT, dims)
            .Attr("T", DT_FLOAT)
            .Attr("ksize", {1, kernel_rows, kernel_cols, 1})
            .Attr("strides", {1, stride_rows, stride_cols, 1})
            .Attr("padding", padding)
            .Attr("data_format", "NHWC"));
  });
  // TODO(phawkins): test NCHW format (not supported by CPU)
}

TEST_F(OpTest, MaxPool3D) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    std::uniform_int_distribution<int> random_int(1, 5);
    std::vector<int64_t> dims = RandomDims(5, 5, 1);

    std::vector<int64_t> input_dims, kernel_dims, stride_dims;
    kernel_dims.push_back(1);
    stride_dims.push_back(1);
    for (int i = 0; i < 3; ++i) {
      kernel_dims.push_back(
          std::uniform_int_distribution<int>(1, dims[i])(generator()));
      input_dims.push_back(dims[i]);
      stride_dims.push_back(random_int(generator()));
    }
    kernel_dims.push_back(1);
    stride_dims.push_back(1);
    int64_t batch = dims[3];
    int64_t feature = dims[4];

    string padding = Choose<string>({"SAME", "VALID"});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("MaxPool3D")
            .RandomInput(DT_FLOAT,
                ImageDims(FORMAT_NHWC, batch, feature, input_dims))
            .Attr("T", DT_FLOAT)
            .Attr("ksize", kernel_dims)
            .Attr("strides", stride_dims)
            .Attr("padding", padding)
            .Attr("data_format", "NDHWC"));
  });
  // TODO(phawkins): test NCHW format (not supported by CPU)
}

TEST_F(OpTest, Mean) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    // TODO(phawkins): CPU and XLA produce different outputs when reducing
    // across a size-0 dimension (nan vs 0). For now, require size >= 1.
    std::vector<int64_t> data_dims = RandomDims(0, kDefaultMaxRank, 1);
    Tensor indices = RandomReductionIndices(data_dims.size());
    bool keep_dims = Choose<bool>({false, true});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mean")
        .RandomInput(type, data_dims)
        .Input(indices)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims));
  });
}

TEST_F(OpTest, Min) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
    std::vector<int64_t> data_dims = RandomDims();
    Tensor indices = RandomReductionIndices(data_dims.size());
    bool keep_dims = Choose<bool>({false, true});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Min")
        .RandomInput(type, data_dims)
        .Input(indices)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims));
  });
}

TEST_F(OpTest, Minimum) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Mod) {
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mod")
        .RandomInput(DT_INT32, dims.first)
        .RandomInput(DT_INT32, dims.second)
        .Attr("T", DT_INT32));
  });
}

TEST_F(OpTest, Mul) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, MulNoNan) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MulNoNan")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Neg) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Neg").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, NextAfter) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT});
    auto dims = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NextAfter")
        .RandomInput(type, dims)
        .RandomInput(type, dims)
        .Attr("T", type));
  });
}

TEST_F(OpTest, NotEqual) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

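// OneHot draws indices from [-2*depth, 2*depth] and the axis from a range
// wider than the valid [-rank-1, rank], so out-of-range indices (which yield
// off_value entries) and invalid axes are exercised alongside ordinary cases.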
TEST_F(OpTest, OneHot) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);

    std::vector<int64_t> dims = RandomDims();
    int num_dims = dims.size();

    int32_t depth = RandomDim();

    Tensor indices(DT_INT32, TensorShape(dims));
    std::uniform_int_distribution<int32> distribution(-depth * 2, depth * 2);
    test::FillFn<int32>(&indices, [this, &distribution](int i) -> int32 {
      return distribution(generator());
    });

    int axis = std::uniform_int_distribution<int32>(-num_dims - 5,
        num_dims + 5)(generator());

    OpTestBuilder builder("OneHot");
    builder.Attr("T", type);
    builder.Attr("TI", DT_INT32);
    builder.Attr("axis", axis);
    builder.Input(indices);
    builder.Input(test::AsScalar<int32>(depth));
    builder.RandomInput(type, {});
    builder.RandomInput(type, {});
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}

TEST_F(OpTest, OnesLike) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Pack) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    int n = std::uniform_int_distribution<int>(1, 5)(generator());

    std::vector<int64_t> dims = RandomDims();
    int num_dims = dims.size();
    int axis = std::uniform_int_distribution<int32>(-num_dims - 1,
        num_dims)(generator());

    OpTestBuilder builder("Pack");
    builder.Attr("T", type);
    builder.Attr("N", n);
    builder.Attr("axis", axis);
    for (int i = 0; i < n; ++i) {
      builder.RandomInput(type, dims);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}

TEST_F(OpTest, Pad) {
  // See note about failing Kokoro tests: b/214080339#comment22
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    auto a = ChoosePadArguments();
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Pad")
            .RandomInput(a.input_type, a.input_shape)
            .Input(a.paddings)
            .Attr("T", a.input_type)
            .Attr("Tpaddings", a.paddings_type));
  });
}

TEST_F(OpTest, PadV2) {
  Repeatedly([this]() {
    auto a = ChoosePadArguments();
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("PadV2")
            .RandomInput(a.input_type, a.input_shape)
            .Input(a.paddings)
            .Input(a.constant_values)
            .Attr("T", a.input_type)
            .Attr("Tpaddings", a.paddings_type));
  });
}

TEST_F(OpTest, Pow) {
  // TODO(phawkins): Feeding large DT_INT32 values to Pow() leads to
  // nontermination.
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Prod) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    std::vector<int64_t> data_dims = RandomDims();
    Tensor indices = RandomReductionIndices(data_dims.size());
    bool keep_dims = Choose<bool>({false, true});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Prod")
        .RandomInput(type, data_dims)
        .Input(indices)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims));
  });
}

TEST_F(OpTest, Qr) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Qr")
            .RandomInput(type, RandomDims(2, kDefaultMaxRank, 1))
            .Attr("T", type)
            .Attr("full_matrices", true));
  });
}

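// QuantizeAndDequantizeV2 runs with range_given=false, so the input_min and
// input_max operands are supplied but ignored and the quantization range is
// derived from the data itself; num_bits is drawn from [1, 64].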
TEST_F(OpTest, QuantizeAndDequantizeV2) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int64_t> num_bits_dist(1, 64);
    int64_t num_bits = num_bits_dist(generator());
    std::string round_mode = Choose<std::string>({"HALF_TO_EVEN", "HALF_UP"});
    auto dims = RandomDims(0, kDefaultMaxRank, 1);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("QuantizeAndDequantizeV2")
            .RandomInput(DT_FLOAT, dims)
            .RandomInput(DT_FLOAT, dims)  // unused because range_given = false
            .RandomInput(DT_FLOAT, dims)  // unused because range_given = false
            .Attr("signed_input", RandomBool())
            .Attr("num_bits", num_bits)
            .Attr("range_given", false)
            .Attr("round_mode", round_mode)
            .Attr("narrow_range", RandomBool())
            .Attr("axis", -1)
            .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, RandomShuffle) {
  // See b/209062491 as this test passes with --tf_xla_test_device=CPU:0
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RandomShuffle")
        .RandomInput(type, RandomDims(1))
        .Attr("seed", RandomSeed())
        .Attr("seed2", RandomSeed())
        .Attr("T", type));
  });
}

TEST_F(OpTest, RandomStandardNormal) {
  Repeatedly([this]() {
    auto shape_type = Choose<DataType>({DT_INT32, DT_INT64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("RandomStandardNormal")
            .Input(AsIntTensor(shape_type, RandomDims()))
            .Attr("seed", RandomSeed())
            .Attr("seed2", RandomSeed())
            .Attr("T", shape_type)
            .Attr("dtype", DT_FLOAT));
  });
}

TEST_F(OpTest, RandomUniform) {
  Repeatedly([this]() {
    auto shape_type = Choose<DataType>({DT_INT32, DT_INT64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("RandomUniform")
            .Input(AsIntTensor(shape_type, RandomDims()))
            .Attr("seed", RandomSeed())
            .Attr("seed2", RandomSeed())
            .Attr("T", shape_type)
            .Attr("dtype", DT_FLOAT));
  });
}

TEST_F(OpTest, Range) {
  Repeatedly([this]() {
    auto ToScalar = [](DataType type, int x) {
      if (type == DT_INT32) return test::AsScalar<int32>(x);
      if (type == DT_INT64) return test::AsScalar<int64_t>(x);
      if (type == DT_FLOAT) return test::AsScalar<float>(x);
      if (type == DT_DOUBLE) return test::AsScalar<double>(x);
      LOG(FATAL) << "Unknown type " << DataTypeString(type);
    };
    std::uniform_int_distribution<int> distribution(-50, 50);
    DataType tidx = Choose<DataType>({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Range")
            .Input(ToScalar(tidx, distribution(generator())))
            .Input(ToScalar(tidx, distribution(generator())))
            .Input(ToScalar(tidx, distribution(generator())))
            .Attr("Tidx", tidx));
  });
}

TEST_F(OpTest, Rank) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Rank").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Real) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Real")
        .RandomInput(DT_COMPLEX64)
        .Attr("T", DT_COMPLEX64));
  });
}

TEST_F(OpTest, RealDiv) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Reciprocal) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Reciprocal").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, ReciprocalGrad) {
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims();
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad")
        .RandomInput(type, dims)
        .RandomInput(type, dims)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Relu) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Relu").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Relu6) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Relu6").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Relu6Grad) {
  Repeatedly([this]() {
    auto dims = RandomDims(1);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad")
        .RandomInput(DT_FLOAT, dims)
        .RandomInput(DT_FLOAT, dims)
        .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, ReluGrad) {
  Repeatedly([this]() {
    auto dims = RandomDims(1);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad")
        .RandomInput(DT_FLOAT, dims)
        .RandomInput(DT_FLOAT, dims)
        .Attr("T", DT_FLOAT));
  });
}

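// Reshape builds the "before" and "after" shapes from the same multiset of
// random dimension sizes: each pass shuffles the sizes and either appends a
// size as a new dimension or folds it into the previous one, so both shapes
// always describe the same number of elements.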
TEST_F(OpTest, Reshape) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> dims = RandomDims();
    std::bernoulli_distribution random_bool;
    std::vector<int64_t> dims_before, dims_after;
    for (std::vector<int64_t>* out : {&dims_before, &dims_after}) {
      std::shuffle(dims.begin(), dims.end(), generator());
      for (int64_t dim : dims) {
        // Either add the dimension as a new dimension or merge it with the
        // previous dimension.
        if (out->empty() || random_bool(generator())) {
          out->push_back(dim);
        } else {
          out->back() *= dim;
        }
      }
    }
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Reshape")
            .RandomInput(type, dims_before)
            .Input(test::AsTensor<int32>(
                std::vector<int32>(dims_after.begin(), dims_after.end())))
            .Attr("T", type));
  });
}

TEST_F(OpTest, ResizeNearestNeighbor) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_INT32, DT_INT64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ResizeNearestNeighbor")
            .RandomInput(type, RandomDims(4, 4, 1))
            .Input(AsIntTensor(DT_INT32, RandomDims(2, kDefaultMaxRank, 1)))
            .Attr("align_corners", RandomBool())
            .Attr("half_pixel_centers", RandomBool())
            .Attr("T", type));
  });
}

TEST_F(OpTest, ResizeBilinear) {
  Repeatedly([this]() {
    std::vector<int64_t> in_dims = RandomDims(4, 4);
    std::vector<int64_t> out_dims = RandomDims(2, 2);

    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ResizeBilinear")
            .RandomInput(DT_FLOAT, in_dims)
            .Input(test::AsTensor<int32>(
                std::vector<int32>(out_dims.begin(), out_dims.end())))
            .Attr("T", DT_FLOAT)
            .Attr("align_corners", true));
  });
}

TEST_F(OpTest, ResizeBilinearGrad) {
  Repeatedly([this]() {
    std::vector<int64_t> in_dims = RandomDims(4, 4);
    std::vector<int64_t> out_dims = RandomDims(2, 2);

    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ResizeBilinearGrad")
            .RandomInput(DT_FLOAT, in_dims)
            .RandomInput(DT_FLOAT,
                {in_dims[0], out_dims[0], out_dims[1], in_dims[3]})
            .Attr("T", DT_FLOAT)
            .Attr("align_corners", true));
  });
}

TEST_F(OpTest, Reverse) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims(1);
    auto type = Choose<DataType>(kAllXlaTypes);
    int64_t rank = dims.size();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse")
        .RandomInput(type, dims)
        .RandomInput(DT_BOOL, {rank})
        .Attr("T", type));
  });
}

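// ReverseSequence picks two distinct dimensions at random to serve as the
// batch and sequence dimensions, then draws a sequence length in
// [0, max_seq_len] for every batch element.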
TEST_F(OpTest, ReverseSequence) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims(/*min_rank=*/2);
    auto type = Choose<DataType>(kAllXlaTypes);
    int64_t rank = dims.size();

    // Choose random batch and sequence dimensions.
    std::vector<int> shuffled_dim_ids(rank);
    absl::c_iota(shuffled_dim_ids, 0);
    absl::c_shuffle(shuffled_dim_ids, generator());
    shuffled_dim_ids.resize(2);
    int batch_dim = shuffled_dim_ids[0];
    int seq_dim = shuffled_dim_ids[1];

    int batch_size = dims[batch_dim];
    int max_seq_len = dims[seq_dim];
    std::vector<int32> seq_lens(batch_size);
    std::uniform_int_distribution<int32> d(0, max_seq_len);
    absl::c_generate(seq_lens, [&]() { return d(generator()); });

    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ReverseSequence")
            .RandomInput(type, dims)
            .Input(test::AsTensor<int32>(seq_lens))
            .Attr("seq_dim", seq_dim)
            .Attr("batch_dim", batch_dim)
            .Attr("T", type)
            .Attr("Tlen", DT_INT32));
  });
}

TEST_F(OpTest, ReverseV2) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> data_dims = RandomDims();
    Tensor indices = RandomReductionIndices(data_dims.size());
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2")
        .RandomInput(type, data_dims)
        .Input(indices)
        .Attr("T", type));
  });
}

TEST_F(OpTest, RightShift) {
  Repeatedly([this]() {
    bool is64 = RandomBool();
    auto dims = RandomDims();
    auto type = is64 ? DT_INT64 : DT_INT32;
    int max_shift = is64 ? 63 : 31;
    auto y = RandomBoundedTensor(type, 0, max_shift, false, dims);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RightShift")
        .RandomInput(type, dims)
        .Input(y)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Rint) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Rint").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

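// Roll uses one random 1-D shape for both the shift and axis inputs; axis
// values are unique and drawn from [0, rank - 1], while the shift values are
// unconstrained.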
TEST_F(OpTest, Roll) {
  Repeatedly([this]() {
    auto input_type = Choose<DataType>(kAllXlaTypes);
    auto axis_type = Choose<DataType>({DT_INT32, DT_INT64});
    // TODO(b/201095155,b/197140886): shift_type = DT_INT64 doesn't work.
    auto shift_type = DT_INT32;
    auto input_shape = RandomDims(1);
    int rank = input_shape.size();
    auto axis_shape = RandomDims(1, 1, 1, rank + 1);
    auto axis = RandomBoundedTensor(axis_type, 0, rank - 1, true, axis_shape);
    auto shift = RandomTensor(shift_type, false, axis_shape);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Roll")
            .RandomInput(input_type, input_shape)
            .Input(shift)
            .Input(axis)
            .Attr("T", input_type)
            .Attr("Taxis", axis_type)
            .Attr("Tshift", shift_type));
  });
}

TEST_F(OpTest, Round) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Round").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Rsqrt) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Rsqrt").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, RsqrtGrad) {
  Repeatedly([this]() {
    auto dims = RandomDims();
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
        .RandomInput(type, dims)
        .RandomInput(type, dims)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Select) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    auto shape = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Select")
        .RandomInput(DT_BOOL, shape)
        .RandomInput(type, shape)
        .RandomInput(type, shape)
        .Attr("T", type));
  });
}

TEST_F(OpTest, SelectV2) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    auto shape = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SelectV2")
        .RandomInput(DT_BOOL, shape)
        .RandomInput(type, shape)
        .RandomInput(type, shape)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Shape) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Shape").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, ShapeN) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    int n = std::uniform_int_distribution<int>(1, 5)(generator());
    OpTestBuilder builder("ShapeN");
    builder.Attr("T", type);
    builder.Attr("N", n);
    for (int i = 0; i < n; ++i) {
      builder.RandomInput(type);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}

TEST_F(OpTest, Sigmoid) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Sigmoid").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, SigmoidGrad) {
  Repeatedly([this]() {
    auto dims = RandomDims();
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
        .RandomInput(type, dims)
        .RandomInput(type, dims)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Sign) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Sign").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Sin) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Sin").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Sinh) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Sinh").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Size) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Size").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Slice) {
  Repeatedly([this]() {
    SliceArguments a = ChooseSliceArguments(true);
    std::vector<int32> size;
    size.insert(size.end(), a.size.begin(), a.size.end());
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice")
        .RandomInput(a.type, a.shape)
        .Input(a.indices)
        .Input(test::AsTensor<int32>(size))
        .Attr("T", a.type)
        .Attr("Index", a.indices_type));
  });
}

TEST_F(OpTest, Softmax) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Softmax")
            .RandomInput(DT_FLOAT, RandomDims(2, 2))
            .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, SoftmaxCrossEntropyWithLogits) {
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims(2, 2, 1);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("SoftmaxCrossEntropyWithLogits")
            .RandomInput(DT_FLOAT, dims)
            .RandomInput(DT_FLOAT, dims)
            .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Softplus) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Softplus").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, SoftplusGrad) {
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad")
        .RandomInput(DT_FLOAT, dims)
        .RandomInput(DT_FLOAT, dims)
        .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Softsign) {
  Repeatedly([this]() {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Softsign").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, SoftsignGrad) {
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftsignGrad")
        .RandomInput(DT_FLOAT, dims)
        .RandomInput(DT_FLOAT, dims)
        .Attr("T", DT_FLOAT));
  });
}

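// SpaceToBatch(ND): each padded spatial dimension must be divisible by the
// block size, so the test first picks spatial extents that are exact
// multiples and then moves a random portion of each extent into the paddings
// (pad_before + pad_after is subtracted from the input dimension).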
TEST_F(OpTest, SpaceToBatch) {
  Repeatedly([this]() {
    std::vector<int64_t> block_dims = RandomDims(4, 4, 0, 5);
    const int num_block_dims = 2;
    int64_t block_size = RandomDim(2, 5);

    std::vector<int64_t> input_dims(1 + num_block_dims + 1);
    input_dims[0] = RandomDim();
    for (int i = 0; i < num_block_dims; ++i) {
      input_dims[1 + i] = block_dims[i] * block_size;
    }
    input_dims[1 + num_block_dims] = RandomDim();

    std::vector<int64_t> padding_vals;
    std::uniform_int_distribution<int> distribution(0, 7);
    for (int i = 0; i < num_block_dims; ++i) {
      int64_t pad_before;
      int64_t pad_after;
      do {
        pad_before = distribution(generator());
        pad_after = distribution(generator());
      } while (pad_before + pad_after > input_dims[1 + i]);
      input_dims[1 + i] -= pad_before + pad_after;
      padding_vals.push_back(pad_before);
      padding_vals.push_back(pad_after);
    }
    Tensor paddings;
    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
        TensorShape({num_block_dims, 2})));

    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch")
        .RandomInput(type, input_dims)
        .Input(paddings)
        .Attr("T", type)
        .Attr("block_size", block_size));
  });
}

TEST_F(OpTest, SpaceToBatchND) {
  Repeatedly([this]() {
    std::vector<int64_t> block_dims = RandomDims(1, 3, 0, 5);
    int num_block_dims = block_dims.size();
    std::vector<int64_t> remaining_dims = RandomDims(0, 3);
    std::vector<int64_t> block_multipliers =
        RandomDims(block_dims.size(), block_dims.size(), 0, 4);

    std::vector<int64_t> input_dims(1 + num_block_dims + remaining_dims.size());
    input_dims[0] = RandomDim();
    for (int i = 0; i < num_block_dims; ++i) {
      input_dims[1 + i] = block_dims[i] * block_multipliers[i];
    }
    std::copy(remaining_dims.begin(), remaining_dims.end(),
        input_dims.begin() + 1 + num_block_dims);

    std::vector<int64_t> padding_vals;
    std::uniform_int_distribution<int> distribution(0, 7);
    for (int i = 0; i < num_block_dims; ++i) {
      int64_t pad_before;
      int64_t pad_after;
      do {
        pad_before = distribution(generator());
        pad_after = distribution(generator());
      } while (pad_before + pad_after > input_dims[1 + i]);
      input_dims[1 + i] -= pad_before + pad_after;
      padding_vals.push_back(pad_before);
      padding_vals.push_back(pad_after);
    }
    Tensor paddings;
    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
        TensorShape({num_block_dims, 2})));

    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("SpaceToBatchND")
            .RandomInput(type, input_dims)
            .Input(test::AsTensor<int32>(
                std::vector<int32>(block_dims.begin(), block_dims.end())))
            .Input(paddings)
            .Attr("T", type));
  });
}

TEST_F(OpTest, SpaceToDepth) {
  Repeatedly([this]() {
    int64_t block = RandomDim(2, 5);
    std::vector<int64_t> input_dims = RandomDims(4, 4);
    // Round spatial dimensions up to a multiple of the block size
    input_dims[1] = (input_dims[1] + (block - 1)) / block * block;
    input_dims[2] = (input_dims[2] + (block - 1)) / block * block;
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToDepth")
        .RandomInput(DT_FLOAT, input_dims)
        .Attr("T", DT_FLOAT)
        .Attr("block_size", block));
  });
}

TEST_F(OpTest, SparseMatMul) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    int64_t x = RandomDim();
    int64_t y = RandomDim();
    int64_t z = RandomDim();

    std::vector<int64_t> a_dims = {x, y};
    std::vector<int64_t> b_dims = {y, z};

    std::bernoulli_distribution random_bool;
    bool transpose_a = random_bool(generator());
    bool transpose_b = random_bool(generator());
    if (transpose_a) {
      std::swap(a_dims[0], a_dims[1]);
    }
    if (transpose_b) {
      std::swap(b_dims[0], b_dims[1]);
    }

    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
        .RandomInput(DT_FLOAT, a_dims)
        .RandomInput(DT_FLOAT, b_dims)
        .Attr("Ta", DT_FLOAT)
        .Attr("Tb", DT_FLOAT)
        .Attr("transpose_a", transpose_a)
        .Attr("transpose_b", transpose_b));
  });
}

TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) {
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims(2, 2, 1);
    int64_t batch_size = dims[0];
    int64_t num_classes = dims[1];

    std::vector<int32> indices(batch_size);
    for (int64_t i = 0; i < batch_size; ++i) {
      indices[i] =
          std::uniform_int_distribution<int32>(0, num_classes - 1)(generator());
    }

    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits")
            .RandomInput(DT_FLOAT, dims)
            .Input(test::AsTensor<int32>(indices))
            .Attr("T", DT_FLOAT)
            .Attr("Tlabels", DT_INT32));
  });
}

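// Split requires the size of the split dimension to be a multiple of
// num_split, so that size is rounded down to a multiple of n before the
// input is generated. Note that dim may be negative, i.e. counted from the
// last dimension backwards.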
TEST_F(OpTest, Split) {
  // See b/214080339#comment27 as this test causes Kokoro to crash.
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> dims = RandomDims(1);
    int32_t dim = std::uniform_int_distribution<int32>(
        -static_cast<int32>(dims.size()),
        static_cast<int32>(dims.size()) - 1)(generator());
    int n = std::uniform_int_distribution<int>(1, 5)(generator());
    // 'dim' may be negative, so canonicalize it before indexing into 'dims'.
    int index = dim < 0 ? dim + dims.size() : dim;
    // Ensure the size of the split dimension is evenly divisible by 'n'.
    dims[index] = dims[index] / n * n;
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
        .Input(test::AsScalar<int32>(dim))
        .RandomInput(type, dims)
        .Attr("T", type)
        .Attr("num_split", n));
  });
}

TEST_F(OpTest, SplitV) {
  // Likely this only fails when dim is negative. Try type = DT_FLOAT first.
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> dims = RandomDims(1, kDefaultMaxRank, 1);
    int32_t dim = std::uniform_int_distribution<int32>(
        -static_cast<int32>(dims.size()),
        static_cast<int32>(dims.size()) - 1)(generator());
    // 'dim' may be negative, so canonicalize it before indexing into 'dims'.
    int index = dim < 0 ? dim + dims.size() : dim;
    int n = std::uniform_int_distribution<int>(
        1, std::min(5, static_cast<int>(dims[index])))(generator());
    // Emit exactly 'n' split sizes that sum to the size of the split dimension.
    std::vector<int32> size_splits;
    size_splits.reserve(n);
    for (int i = 0; i < n - 1; ++i) {
      size_splits.push_back(dims[index] / n);
    }
    size_splits.push_back(dims[index] - (n - 1) * (dims[index] / n));
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("SplitV")
            .RandomInput(type, dims)
            .Input(test::AsTensor<int32>(size_splits))
            .Input(test::AsScalar<int32>(dim))
            .Attr("T", type)
            .Attr("num_split", n)
            .Attr("Tlen", DT_INT32));
  });
}

TEST_F(OpTest, Sqrt) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Sqrt").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, StopGradient) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("StopGradient").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, SqrtGrad) {
  Repeatedly([this]() {
    auto dims = RandomDims();
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
        .RandomInput(type, dims)
        .RandomInput(type, dims)
        .Attr("T", type));
  });
}

TEST_F(OpTest, SquaredDifference) {
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SquaredDifference")
        .RandomInput(DT_FLOAT, dims.first)
        .RandomInput(DT_FLOAT, dims.second)
        .Attr("T", DT_FLOAT));
  });
}

TEST_F(OpTest, Square) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Square").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Squeeze) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> t_dims = RandomDims(0, kDefaultMaxRank, 0, 5);
    std::bernoulli_distribution random_bool;
    std::vector<int> squeeze_dims;
    for (int i = 0; i < t_dims.size(); ++i) {
      if (t_dims[i] == 1 && random_bool(generator())) {
        squeeze_dims.push_back(i);
      }
    }
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze")
        .RandomInput(type, t_dims)
        .Attr("squeeze_dims", squeeze_dims)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Sub) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
        .RandomInput(type, dims.first)
        .RandomInput(type, dims.second)
        .Attr("T", type));
  });
}

TEST_F(OpTest, Sum) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    std::vector<int64_t> data_dims = RandomDims();
    Tensor indices = RandomReductionIndices(data_dims.size());
    bool keep_dims = Choose<bool>({false, true});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sum")
        .RandomInput(type, data_dims)
        .Input(indices)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims));
  });
}

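// StridedSlice: begin/end values may lie well outside the input (the op
// clamps them), strides are limited to +/-1 (b/31360685), and the
// begin/end/new_axis/shrink_axis masks are random bitmasks over the rank,
// with at most one bit set in ellipsis_mask.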
TEST_F(OpTest, StridedSlice) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> data_dims = RandomDims();
    std::vector<int32> begin(data_dims.size()), end(data_dims.size());
    std::vector<int32> strides(data_dims.size());
    for (int i = 0; i < data_dims.size(); ++i) {
      begin[i] = std::uniform_int_distribution<int32>(
          -2 * data_dims[i], 2 * data_dims[i])(generator());
      end[i] = std::uniform_int_distribution<int32>(
          -2 * data_dims[i], 2 * data_dims[i])(generator());
      // TODO(b/31360685): support strides other than 1 or -1
      strides[i] = std::bernoulli_distribution()(generator()) ? 1 : -1;
    }
    int64_t max_bitmask = (1LL << data_dims.size()) - 1;
    std::uniform_int_distribution<int64_t> bitmask_distribution(0, max_bitmask);
    int64_t begin_mask = bitmask_distribution(generator());
    int64_t end_mask = bitmask_distribution(generator());

    // Create an ellipsis bitmask with at most one 1 bit set.
    int64_t ellipsis_mask = 0;
    if (!data_dims.empty() && std::bernoulli_distribution()(generator())) {
      int ellipsis_pos = std::uniform_int_distribution<int>(
          0, data_dims.size() - 1)(generator());
      ellipsis_mask = 1LL << ellipsis_pos;
    }

    int64_t new_axis_mask = bitmask_distribution(generator());
    int64_t shrink_axis_mask = bitmask_distribution(generator());
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("StridedSlice")
            .RandomInput(type, data_dims)
            .Input(test::AsTensor<int32>(begin))
            .Input(test::AsTensor<int32>(end))
            .Input(test::AsTensor<int32>(strides))
            .Attr("T", type)
            .Attr("Index", DT_INT32)
            .Attr("begin_mask", begin_mask)
            .Attr("end_mask", end_mask)
            .Attr("ellipsis_mask", ellipsis_mask)
            .Attr("new_axis_mask", new_axis_mask)
            .Attr("shrink_axis_mask", shrink_axis_mask));
  });
}

TEST_F(OpTest, StridedSliceGrad) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);

    // Dimensions of the forward input.
    std::vector<int64_t> dims = RandomDims();

    std::vector<int64_t> begin(dims.size()), end(dims.size());
    std::vector<int64_t> strides(dims.size());
    for (int i = 0; i < dims.size(); ++i) {
      begin[i] = std::uniform_int_distribution<int64_t>(
          -2 * dims[i], 2 * dims[i])(generator());
      end[i] = std::uniform_int_distribution<int64_t>(-2 * dims[i],
          2 * dims[i])(generator());
      strides[i] = std::uniform_int_distribution<int64_t>(
          -2 * dims[i], 2 * dims[i])(generator());
    }
    int64_t max_bitmask = (1LL << dims.size()) - 1;
    std::uniform_int_distribution<int64_t> bitmask_distribution(0, max_bitmask);
    int64_t begin_mask = bitmask_distribution(generator());
    int64_t end_mask = bitmask_distribution(generator());

    // Create an ellipsis bitmask with at most one 1 bit set.
    int64_t ellipsis_mask = 0;
    if (!dims.empty() && std::bernoulli_distribution()(generator())) {
      int ellipsis_pos =
          std::uniform_int_distribution<int>(0, dims.size() - 1)(generator());
      ellipsis_mask = 1LL << ellipsis_pos;
    }

    int64_t new_axis_mask = bitmask_distribution(generator());
    int64_t shrink_axis_mask = bitmask_distribution(generator());

    // TODO(phawkins): use shape inference for the forward op to compute the
    // gradient shape for the backward op. At present, there is a low
    // probability of the golden op succeeding.
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("StridedSliceGrad")
            .Input(test::AsTensor<int64_t>(dims))
            .Input(test::AsTensor<int64_t>(begin))
            .Input(test::AsTensor<int64_t>(end))
            .Input(test::AsTensor<int64_t>(strides))
            .RandomInput(type, RandomDims(1))
            .Attr("T", type)
            .Attr("Index", DT_INT64)
            .Attr("begin_mask", begin_mask)
            .Attr("end_mask", end_mask)
            .Attr("ellipsis_mask", ellipsis_mask)
            .Attr("new_axis_mask", new_axis_mask)
            .Attr("shrink_axis_mask", shrink_axis_mask));
  });
}

TEST_F(OpTest,Tan)4588 TEST_F(OpTest, Tan) {
4589 Repeatedly([this]() {
4590 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
4591 return ExpectTfAndXlaOutputsAreClose(
4592 OpTestBuilder("Tan").RandomInput(type).Attr("T", type));
4593 });
4594 }
4595
TEST_F(OpTest,Tanh)4596 TEST_F(OpTest, Tanh) {
4597 Repeatedly([this]() {
4598 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
4599 return ExpectTfAndXlaOutputsAreClose(
4600 OpTestBuilder("Tanh").RandomInput(type).Attr("T", type));
4601 });
4602 }
4603
TEST_F(OpTest,TanhGrad)4604 TEST_F(OpTest, TanhGrad) {
4605 Repeatedly([this]() {
4606 auto dims = RandomDims();
4607 auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
4608 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
4609 .RandomInput(type, dims)
4610 .RandomInput(type, dims)
4611 .Attr("T", type));
4612 });
4613 }
4614
TEST_F(OpTest,TensorScatterUpdate)4615 TEST_F(OpTest, TensorScatterUpdate) {
4616 if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
4617 if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
4618 Repeatedly([this]() { // NOLINT: due to GTEST_SKIP
4619 auto a = ChooseScatterArguments();
4620 return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TensorScatterUpdate")
4621 .RandomInput(a.type, a.shape)
4622 .Input(a.indices)
4623 .Input(a.updates)
4624 .Attr("T", a.type)
4625 .Attr("Tindices", a.indices_type));
4626 });
4627 }
4628
TEST_F(OpTest,Tile)4629 TEST_F(OpTest, Tile) {
4630 Repeatedly([this]() {
4631 auto type = Choose<DataType>(kAllXlaTypes);
4632 std::vector<int64_t> t_dims = RandomDims(1);
4633 std::vector<int32> multiples(t_dims.size());
    for (int i = 0; i < t_dims.size(); ++i) {
      multiples[i] = std::uniform_int_distribution<int>(1, 3)(generator());
    }
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Tile")
            .RandomInput(type, t_dims)
            .Input(test::AsTensor<int32>(multiples))
            .Attr("T", type));
  });
}

TEST_F(OpTest, TopKV2) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_INT64});
    auto shape = RandomDims(1);
    int32 k = std::uniform_int_distribution<int32>(1, shape[0])(generator());
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TopKV2")
                                             .RandomInput(type, shape)
                                             .Input(test::AsScalar<int32>(k))
                                             .Attr("sorted", RandomBool())
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, Transpose) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> data_dims = RandomDims();
    std::vector<int32> perm(data_dims.size());
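    // Use a random permutation of [0, rank) as the transpose ordering.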
    std::iota(perm.begin(), perm.end(), 0);
    std::shuffle(perm.begin(), perm.end(), generator());
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose")
                                             .RandomInput(type, data_dims)
                                             .Input(test::AsTensor<int32>(perm))
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, TruncateDiv) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    DataType type = DT_INT32;
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, TruncateMod) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, Unpack) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    auto shape = RandomDims(1);
    int axis =
        std::uniform_int_distribution<int>(0, shape.size() - 1)(generator());
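    // Unpack splits the input into `num` slices along `axis`, so the `num`
    // attribute must equal the size of that dimension.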
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Unpack")
                                             .RandomInput(type, shape)
                                             .Attr("axis", axis)
                                             .Attr("T", type)
                                             .Attr("num", shape[axis]));
  });
}

TEST_F(OpTest, Xdivy) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
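    // Xdivy computes x / y but returns 0 wherever x == 0, even if y == 0.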
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Xdivy")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, XlaDot) {
  Repeatedly([this]() {
    const XlaDotArguments& a = ChooseXlaDotArguments();
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("XlaDot")
            .RandomInput(a.dtype, a.lhs_dims)
            .RandomInput(a.dtype, a.rhs_dims)
            .Attr("dimension_numbers", a.dnums_encoded)
            .Attr("precision_config", a.precision_config_encoded)
            .Attr("T", a.dtype));
  });
}

TEST_F(OpTest, XlaDotV2) {
  Repeatedly([this]() {
    const XlaDotArguments& a = ChooseXlaDotArguments();
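    // Unlike XlaDot, XlaDotV2 carries separate lhs/rhs element types and a
    // preferred output element type; this test uses the same dtype for all
    // three.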
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("XlaDotV2")
            .RandomInput(a.dtype, a.lhs_dims)
            .RandomInput(a.dtype, a.rhs_dims)
            .Attr("dimension_numbers", a.dnums_encoded)
            .Attr("precision_config", a.precision_config_encoded)
            .Attr("LhsT", a.dtype)
            .Attr("RhsT", a.dtype)
            .Attr("preferred_element_type", a.dtype));
  });
}

TEST_F(OpTest, XlaDynamicUpdateSlice) {
  Repeatedly([this]() {
    SliceArguments a = ChooseSliceArguments(false);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("XlaDynamicUpdateSlice")
                                             .RandomInput(a.type, a.shape)
                                             .RandomInput(a.type, a.size)
                                             .Input(a.indices)
                                             .Attr("T", a.type)
                                             .Attr("Tindices", a.indices_type));
  });
}

TEST_F(OpTest, XlaEinsum) {
  Repeatedly([this]() {
    const EinsumArguments a = ChooseEinsumArguments();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("XlaEinsum")
                                             .RandomInput(a.type, a.lhs_dims)
                                             .RandomInput(a.type, a.rhs_dims)
                                             .Attr("equation", a.equation)
                                             .Attr("T", a.type));
  });
}

TEST_F(OpTest, XlaSort) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("XlaSort")
                                             .RandomInput(type, RandomDims())
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, Xlog1py) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Xlog1py")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, Xlogy) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
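    // Xlogy computes x * log(y) but returns 0 wherever x == 0.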
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Xlogy")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}

TEST_F(OpTest, ZerosLike) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type));
  });
}

TEST_F(OpTest, Zeta) {
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Zeta")
                                             .RandomInput(DT_FLOAT, dims.first)
                                             .RandomInput(DT_FLOAT, dims.second)
                                             .Attr("T", DT_FLOAT));
  });
}

// Example failing run:
// --tf_xla_reference_device=GPU:0
// --tf_xla_test_use_jit=true --tf_xla_test_device=GPU:0
// --tf_xla_test_use_mlir=true
// --tf_xla_test_repetitions=2
// --gunit_filter='OpTest.FusedBatchNormTraining'
// --tf_xla_random_seed=2838146746
TEST_F(OpTest, FusedBatchNormTraining) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  bool is_nhwc = RandomBool();
  std::vector<int64_t> x_dims = RandomDims(/*min_rank=*/4, /*max_rank=*/4,
                                           /*min_size=*/5, /*max_size=*/20);
  std::vector<int64_t> scale_dims = {x_dims[is_nhwc ? 3 : 1]};
  std::vector<int64_t> offset_dims = {x_dims[is_nhwc ? 3 : 1]};
  std::vector<int64_t> mean_dims = {0};
  std::vector<int64_t> variance_dims = {0};
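  // With is_training=true the op computes the batch statistics itself, so the
  // mean/variance inputs are passed as empty (size-0) tensors here.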
  DataType type = DT_FLOAT;
  Repeatedly([&] {
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("FusedBatchNorm")
            .RandomInput(type, x_dims)
            .RandomInput(type, scale_dims)
            .RandomInput(type, offset_dims)
            .RandomInput(type, mean_dims)
            .RandomInput(type, variance_dims)
            .Attr("T", type)
            .Attr("data_format", is_nhwc ? "NHWC" : "NCHW")
            .Attr("epsilon", static_cast<float>(1.001e-05))
            .Attr("is_training", true));
  });
}

}  // anonymous namespace
}  // namespace tensorflow

int main(int argc, char** argv) {
  tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0");
  tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0");
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag(
          "tf_xla_random_seed", &tensorflow::tf_xla_random_seed,
          "Random seed to use for XLA tests. <= 0 means choose a seed "
          "nondeterministically."),
      // TODO(phawkins): it might make more sense to run each test up to a
      // configurable time bound.
      tensorflow::Flag("tf_xla_test_repetitions",
                       &tensorflow::tf_xla_test_repetitions,
                       "Number of repetitions for each test."),
      tensorflow::Flag("tf_xla_max_tensor_size",
                       &tensorflow::tf_xla_max_tensor_size,
                       "Maximum number of elements for random input tensors."),
      tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr,
                       "Tensorflow device type to use for the test"),
      tensorflow::Flag("tf_xla_reference_device",
                       tensorflow::tf_xla_reference_device_ptr,
                       "Tensorflow device type to use as the reference"),
      tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit,
                       "Use JIT compilation for the operator under test"),
      tensorflow::Flag(
          "tf_xla_test_use_mlir", &tensorflow::tf_xla_test_use_mlir,
          "Use MLIR legalization kernels for the operator under test"),
  };
  tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  testing::InitGoogleTest(&argc, argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  // XLA devices register kernels at construction time; create all known
  // devices to make sure the kernels are registered.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
      tensorflow::SessionOptions(), "", &devices));
  tensorflow::StaticDeviceMgr device_mgr(std::move(devices));

  tensorflow::Device* ignored;
  TF_QCHECK_OK(
      device_mgr.LookupDevice(*tensorflow::tf_xla_test_device_ptr, &ignored))
      << "Unknown test device (" << *tensorflow::tf_xla_test_device_ptr
      << "). Did you build in the right configuration (e.g., is CUDA enabled)?";

  if (tensorflow::tf_xla_test_use_mlir)
    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
        tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
  return RUN_ALL_TESTS();
}