xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tests/randomized_tests.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Randomized tests for XLA implementations of Tensorflow operations.
17 //
18 // For each operator, the tests in this file choose a set of random inputs and
19 // attributes. The test then compares the outputs of the operator when executed
20 // via Tensorflow using the CPU device and when executed via XLA.
21 //
22 // By default, each test chooses a random seed nondeterministically (using
23 // std::random_device). However, a particular choice of random seed can be
24 // forced using the flag --tf_xla_random_seed; each test logs the
25 // flag value necessary to reproduce its outputs.
26 //
27 // Example usage:
28 // Run tests, comparing the Tensorflow CPU operators with their XLA-compiled
29 // counterparts:
30 // randomized_tests \
31 //   --tf_xla_test_use_jit=true --tf_xla_test_device=CPU:0 \
32 //   --tf_xla_test_repetitions=20
33 
34 // TODO(phawkins): add tests for:
35 // * DepthwiseConv2DNative
36 // * Gather
37 // * InvertPermutation
38 // * MaxPoolGrad (requires implementation of forward operator)
39 // * Select
40 // * Unpack
41 //
42 // TODO(phawkins): improve tests for:
43 // * StridedSliceGrad (need to use shape function to compute sensible inputs)
44 
45 #include <algorithm>
46 #include <random>
47 #include <unordered_map>
48 
49 #include "absl/algorithm/container.h"
50 #include "absl/container/flat_hash_set.h"
51 #include "absl/strings/str_cat.h"
52 #include "absl/strings/string_view.h"
53 #include "tensorflow/compiler/jit/defs.h"
54 #include "tensorflow/compiler/jit/flags.h"
55 #include "tensorflow/compiler/tf2xla/type_util.h"
56 #include "tensorflow/core/common_runtime/device.h"
57 #include "tensorflow/core/common_runtime/device_factory.h"
58 #include "tensorflow/core/common_runtime/device_mgr.h"
59 #include "tensorflow/core/common_runtime/graph_constructor.h"
60 #include "tensorflow/core/framework/kernel_shape_util.h"
61 #include "tensorflow/core/framework/node_def_builder.h"
62 #include "tensorflow/core/framework/node_def_util.h"
63 #include "tensorflow/core/framework/op_kernel.h"
64 #include "tensorflow/core/framework/tensor.h"
65 #include "tensorflow/core/framework/tensor_testutil.h"
66 #include "tensorflow/core/framework/types.pb.h"
67 #include "tensorflow/core/graph/graph.h"
68 #include "tensorflow/core/lib/core/status.h"
69 #include "tensorflow/core/lib/core/status_test_util.h"
70 #include "tensorflow/core/platform/bfloat16.h"
71 #include "tensorflow/core/platform/test.h"
72 #include "tensorflow/core/public/session.h"
73 #include "tensorflow/core/public/session_options.h"
74 #include "tensorflow/core/util/command_line_flags.h"
75 #include "tensorflow/core/util/device_name_utils.h"
76 #include "tensorflow/core/util/tensor_format.h"
77 
78 namespace tensorflow {
79 namespace {
80 
81 // Command line flags: see main() below.
82 int64_t tf_xla_random_seed = 0;
83 int32_t tf_xla_test_repetitions = 20;
84 int64_t tf_xla_max_tensor_size = 10000LL;
85 string* tf_xla_test_device_ptr;       // initial value set in main()
86 string* tf_xla_reference_device_ptr;  // initial value set in main()
87 bool tf_xla_test_use_jit = true;
88 bool tf_xla_test_use_mlir = false;
89 
// Qualifies a local device name such as "CPU:0" with the fixed
// "/job:localhost/replica:0/task:0/device:" prefix.
string LocalDeviceToFullDeviceName(const string& device) {
  static constexpr absl::string_view kLocalDevicePrefix =
      "/job:localhost/replica:0/task:0/device:";
  return absl::StrCat(kLocalDevicePrefix, device);
}
93 
94 constexpr std::array<DataType, 5> kAllXlaTypes = {
95     {DT_INT32, DT_INT64, DT_FLOAT, DT_BOOL, DT_COMPLEX64}};
96 constexpr std::array<DataType, 4> kAllNumberTypes = {
97     {DT_INT32, DT_INT64, DT_FLOAT, DT_COMPLEX64}};
98 
99 // An OpTestBuilder is a graph builder class that takes as input an operator to
100 // test, its inputs and attributes, and builds a graph that executes the
101 // operator.
102 class OpTestBuilder {
103  public:
104   explicit OpTestBuilder(const string& op_name);
105 
106   // Adds an input 'tensor' as a Placeholder node.
107   OpTestBuilder& Input(const Tensor& tensor);
108 
109   // Adds a random input tensor with 'type' as a Placeholder node.
110   // If 'dims' is not provided, RandomDims() is used.
111   OpTestBuilder& RandomInput(DataType type);
112   OpTestBuilder& RandomInput(DataType type, std::vector<int64_t> dims);
113 
114   // As RandomInput but the values are unique.
115   OpTestBuilder& RandomUniqueInput(DataType type, std::vector<int64_t> dims);
116 
117   // Add variadic input tensors as Placehodler nodes.
118   OpTestBuilder& VariadicInput(const std::vector<Tensor>& tensor);
119 
120   // Sets an attribute.
121   template <class T>
122   OpTestBuilder& Attr(absl::string_view attr_name, T&& value);
123 
124   // Overload needed to allow {...} expressions for value.
125   template <class T>
126   OpTestBuilder& Attr(absl::string_view attr_name,
127                       std::initializer_list<T> value);
128 
129   // Adds nodes that executes the operator under test on 'device' to 'graphdef'.
130   // If 'use_jit' is true, marks the operator under test to be compiled by XLA.
131   // The graph will consist of one Placeholder node per input, the operator
132   // itself, and one Identity node per output. If 'test_node_def' is not null,
133   // sets it to the NodeDef of the operator under test. Fills 'inputs' and
134   // 'outputs' with the names of the input placeholder nodes and the output
135   // identity nodes, respectively.
136   Status BuildGraph(const string& name_prefix, const string& device,
137                     bool use_jit, GraphDef* graphdef, NodeDef** test_node_def,
138                     std::vector<string>* inputs,
139                     std::vector<string>* outputs) const;
140 
141   struct InputDescription {
142     Tensor tensor;
143 
144     DataType type = DT_INVALID;
145     bool has_dims = false;
146     bool needs_unique_values = false;
147     std::vector<int64_t> dims;
148   };
149 
inputs() const150   const std::vector<InputDescription>& inputs() const { return inputs_; }
151 
152  private:
153   NodeDef node_def_;
154   std::vector<InputDescription> inputs_;
155 };
156 
OpTestBuilder(const string & op_name)157 OpTestBuilder::OpTestBuilder(const string& op_name) {
158   node_def_.set_op(op_name);
159 }
160 
Input(const Tensor & tensor)161 OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) {
162   VLOG(1) << "Adding input: " << tensor.DebugString();
163   InputDescription input;
164   input.tensor = tensor;
165   inputs_.push_back(input);
166   return *this;
167 }
168 
RandomInput(DataType type)169 OpTestBuilder& OpTestBuilder::RandomInput(DataType type) {
170   VLOG(1) << "Adding random input: " << type;
171   InputDescription input;
172   input.type = type;
173   inputs_.push_back(input);
174   return *this;
175 }
176 
RandomInput(DataType type,std::vector<int64_t> dims)177 OpTestBuilder& OpTestBuilder::RandomInput(DataType type,
178                                           std::vector<int64_t> dims) {
179   VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
180   InputDescription input;
181   input.type = type;
182   input.has_dims = true;
183   input.dims = std::move(dims);
184   inputs_.push_back(input);
185   return *this;
186 }
187 
RandomUniqueInput(DataType type,std::vector<int64_t> dims)188 OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
189                                                 std::vector<int64_t> dims) {
190   VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
191   InputDescription input;
192   input.type = type;
193   input.has_dims = true;
194   input.needs_unique_values = true;
195   input.dims = std::move(dims);
196   inputs_.push_back(input);
197   return *this;
198 }
199 
VariadicInput(const std::vector<Tensor> & tensors)200 OpTestBuilder& OpTestBuilder::VariadicInput(
201     const std::vector<Tensor>& tensors) {
202   VLOG(1) << "Adding variadic input of length " << tensors.size() << ":";
203   for (auto& t : tensors) {
204     Input(t);
205   }
206   return *this;
207 }
208 
209 template <class T>
Attr(absl::string_view attr_name,T && value)210 OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) {
211   AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
212   return *this;
213 }
214 
215 template <class T>
Attr(absl::string_view attr_name,std::initializer_list<T> value)216 OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name,
217                                    std::initializer_list<T> value) {
218   Attr<std::initializer_list<T>>(attr_name, std::move(value));
219   return *this;
220 }
221 
BuildGraph(const string & name_prefix,const string & device,bool use_jit,GraphDef * graphdef,NodeDef ** test_node_def,std::vector<string> * inputs,std::vector<string> * outputs) const222 Status OpTestBuilder::BuildGraph(const string& name_prefix,
223                                  const string& device, bool use_jit,
224                                  GraphDef* graphdef, NodeDef** test_node_def,
225                                  std::vector<string>* inputs,
226                                  std::vector<string>* outputs) const {
227   OpRegistryInterface* op_registry = OpRegistry::Global();
228 
229   const OpDef* op_def;
230   TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def_.op(), &op_def));
231 
232   NodeDef* test_def = graphdef->add_node();
233   *test_def = node_def_;
234   test_def->set_name(absl::StrCat(name_prefix, "_op_under_test"));
235   test_def->set_device(device);
236   AddDefaultsToNodeDef(*op_def, test_def);
237   if (use_jit) {
238     AddNodeAttr(kXlaCompileAttr, true, test_def);
239   }
240   VLOG(1) << "Op under test: " << test_def->DebugString();
241 
242   DataTypeVector input_types, output_types;
243   TF_RETURN_IF_ERROR(
244       InOutTypesForNode(*test_def, *op_def, &input_types, &output_types));
245 
246   // Build feed and fetch nodes.
247   for (int i = 0; i < input_types.size(); ++i) {
248     NodeDef* def = graphdef->add_node();
249     string name = absl::StrCat(name_prefix, "_input_", i);
250     TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
251                            .Device(device)
252                            .Attr("dtype", input_types[i])
253                            .Finalize(def));
254     inputs->push_back(name);
255     test_def->add_input(name);
256   }
257 
258   for (int i = 0; i < output_types.size(); ++i) {
259     NodeDef* def = graphdef->add_node();
260     string name = absl::StrCat(name_prefix, "_output_", i);
261     TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
262                            .Device(device)
263                            .Attr("T", output_types[i])
264                            .Input(test_def->name(), i, output_types[i])
265                            .Finalize(def));
266     outputs->push_back(name);
267   }
268 
269   if (test_node_def) {
270     *test_node_def = test_def;
271   }
272 
273   return OkStatus();
274 }
275 
276 // Test fixture. The fixture manages the random number generator and its seed,
277 // and has a number of convenience methods for building random Tensors, shapes,
278 // etc.
279 class OpTest : public ::testing::Test {
280  public:
281   OpTest();
282 
283   enum TestResult {
284     // The test saw an unrecoverable error. Don't try any more runs.
285     kFatalError,
286     // The parameters of the test were invalid (e.g., the "golden"
287     // implementation failed, or the parameters are oversize). Reruns are ok.
288     kInvalid,
289     // The test ran successfully, and we have a verdict. Does *not* mean the
290     // test passed.
291     kOk,
292   };
293 
294   // Runs 'fn' up to --tf_xla_test_repetitions times, or until a test failure
295   // occurs; whichever happens first. Reruns if the TestResult is kInvalid.
296   void Repeatedly(const std::function<TestResult(void)>& fn);
297 
298   // Select a random element from 'candidates'.
299   template <typename T>
300   T Choose(absl::Span<const T> candidates);
301 
302   static constexpr int kDefaultMaxRank = 5;
303   static constexpr int64_t kDefaultMaxDimensionSize = 256LL;
304 
305   // Returns true if 'dims' have a size less than tf_xla_max_tensor_size.
306   bool TensorSizeIsOk(absl::Span<const int64_t> dims);
307 
308   // Returns a random dimension size, in the range [min, max).
309   int64_t RandomDim(int64_t min = 0, int64_t max = kDefaultMaxDimensionSize);
310 
311   // Returns a random shape. The tensor has rank in the range [min_rank,
312   // max_rank). Each dimension has size [min_size, max_size).
313   std::vector<int64_t> RandomDims(int min_rank = 0,
314                                   int max_rank = kDefaultMaxRank,
315                                   int64_t min_size = 0,
316                                   int64_t max_size = kDefaultMaxDimensionSize);
317 
318   // Given a shape 'dims', build dimensions that are broadcastable to 'dims'.
319   std::vector<int64_t> BroadcastableToDims(std::vector<int64_t> dims);
320 
321   // Given a shape 'dims', build a pair of dimensions such that one broadcasts
322   // to the other.
323   std::pair<std::vector<int64_t>, std::vector<int64_t>> BroadcastableDims(
324       std::vector<int64_t> dims);
325 
326   // Builds a random pair of broadcastable dims.
327   // TODO(phawkins): currently the maximum rank is 3, because broadcasting > 3
328   // dimensions is unimplemented by the Tensorflow Eigen code (b/29268487)
329   std::pair<std::vector<int64_t>, std::vector<int64_t>> BroadcastableDims();
330 
331   // Returns a tensor filled with random but "reasonable" values from the middle
332   // of the type's range. If the shape is omitted, a random shape is used.
333   // TODO(phawkins): generalize this code to a caller-supplied distribution.
334   Tensor RandomTensor(DataType dtype, bool needs_unique_values,
335                       absl::Span<const int64_t> shape);
336   Tensor RandomTensor(DataType dtype);
337 
338   // Like RandomTensor, but uses values >= 0.
339   Tensor RandomNonNegativeTensor(DataType dtype,
340                                  absl::Span<const int64_t> shape);
341   Tensor RandomNonNegativeTensor(DataType dtype);
342 
343   // Like RandomTensor, but all values are in the range [lo, hi].
344   template <typename T>
345   Tensor RandomBoundedTensor(DataType dtype, T lo, T hi,
346                              bool needs_unique_values,
347                              absl::Span<const int64_t> shape);
348   template <typename T>
349   Tensor RandomBoundedTensor(DataType dtype, T lo, T hi,
350                              bool needs_unique_values);
351 
352   // Like RandomTensor, but the value at index i is in the range [lo[i], hi[i]].
353   Tensor RandomBoundedTensor(DataType dtype, Tensor lo, Tensor hi);
354 
355   // Like RandomTensor, but return a pair {left, right} with
356   // left[i] <= right[i].
357   std::pair<Tensor, Tensor> RandomLteTensors(DataType dtype,
358                                              absl::Span<const int64_t> shape);
359   std::pair<Tensor, Tensor> RandomLteTensors(DataType dtype);
360 
361   // Returns a random subset of the integers in the range [0, rank), suitable
362   // for use as reduction indices.
363   Tensor RandomReductionIndices(int rank);
364 
365   // Returns a random bit.
366   bool RandomBool();
367 
368   // Randomly choose a seed for a random number generator.
369   int64_t RandomSeed();
370 
371   struct WindowedSpatialDims {
372     Padding padding;
373     std::vector<int64_t> kernel_dims;
374     std::vector<int64_t> stride_dims;
375     std::vector<int64_t> input_dims;
376     std::vector<int64_t> output_dims;
377   };
378   // Choose spatial dimensions for a windowed op such as pooling or convolution.
379   WindowedSpatialDims ChooseWindowedSpatialDims(int num_spatial_dims);
380 
381   struct BatchMatMulArguments {
382     std::vector<int64_t> lhs_dims;
383     std::vector<int64_t> rhs_dims;
384     DataType dtype;
385     bool adj_lhs;
386     bool adj_rhs;
387   };
388   // Choose arguments for the tf.BatchMatMul{V2} ops.
389   BatchMatMulArguments ChooseBatchMatMulArguments(bool broadcastable_batch);
390 
391   struct ConcatArguments {
392     std::vector<Tensor> values;
393     Tensor axis;
394     int n;
395     DataType type;
396     DataType type_idx;
397   };
398   // Choose arguments for the tf.Concat{V2} ops.
399   ConcatArguments ChooseConcatArguments(bool int64_idx_allowed);
400 
401   struct EinsumArguments {
402     std::vector<int64_t> lhs_dims;
403     std::vector<int64_t> rhs_dims;
404     DataType type;
405     std::string equation;
406   };
407   // Choose arguments for the tf.{Xla}Einsum ops.
408   EinsumArguments ChooseEinsumArguments();
409 
410   struct GatherArguments {
411     int64_t batch_dims;
412     DataType axis_type;
413     DataType indices_type;
414     DataType params_type;
415     std::vector<int64_t> params_shape;
416     Tensor indices;
417     Tensor axis;
418   };
419   // Choose arguments for the tf.Gather{V2} ops.
420   GatherArguments ChooseGatherArguments(bool axis_0);
421 
422   struct PadArguments {
423     DataType input_type;
424     DataType paddings_type;
425     std::vector<int64_t> input_shape;
426     Tensor paddings;
427     Tensor constant_values;
428   };
429   // Choose arguments for the tf.Pad{V2} ops.
430   PadArguments ChoosePadArguments();
431 
432   struct ScatterArguments {
433     DataType type;
434     DataType indices_type;
435     Tensor indices;
436     Tensor updates;
437     std::vector<int64_t> shape;
438   };
439   // Choose arguments for ScatterNd and TensorScatterUpdate.
440   ScatterArguments ChooseScatterArguments();
441 
442   struct SliceArguments {
443     DataType type;
444     DataType indices_type;
445     std::vector<int64_t> shape;
446     Tensor indices;
447     std::vector<int64_t> size;
448   };
449   // Choose arguments for the tf.{XlaDynamicUpdate}Slice ops.
450   SliceArguments ChooseSliceArguments(bool neg_one_size);
451 
452   struct XlaDotArguments {
453     std::vector<int64_t> lhs_dims;
454     std::vector<int64_t> rhs_dims;
455     std::string dnums_encoded;
456     std::string precision_config_encoded;
457     DataType dtype;
458   };
459   // Choose arguments for tf.XlaDot operation.
460   XlaDotArguments ChooseXlaDotArguments();
461 
462   // Builds dimensions for a windowed op such as pooling or convolution,
463   // including a batch and feature dimension.
464   std::vector<int64_t> ImageDims(TensorFormat format, int batch, int feature,
465                                  const std::vector<int64_t>& spatial_dims);
466 
467   // Converts an int64 vector to an int32 vector.
468   std::vector<int32> AsInt32s(const std::vector<int64_t>& int64s);
469 
generator()470   std::mt19937& generator() { return *generator_; }
471 
472   // Run the test case described by 'builder' with and without XLA and check
473   // that the outputs are close. Tensors x and y are close if they have the same
474   // type, same shape, and have close values. For floating-point tensors, the
475   // element-wise difference between x and y must no more than
476   // atol + rtol * abs(x); or both elements may be NaN or infinity. For
477   // non-floating-point tensors the element values must match exactly.
478   TestResult ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
479                                            double atol = 1e-2,
480                                            double rtol = 1e-2);
481 
482  protected:
483   // Per-test state:
484   std::unique_ptr<std::mt19937> generator_;
485 
486   std::unique_ptr<Session> session_;
487 
488   // Number of test cases built in 'session_'. Used to uniquify node names.
489   int num_tests_ = 0;
490 };
491 
OpTest()492 OpTest::OpTest() {
493   // Creates a random-number generator for the test case. Use the value of
494   // --tf_xla_random_seed as the seed, if provided.
495   int64_t s = tf_xla_random_seed;
496   unsigned int seed;
497   if (s <= 0) {
498     std::random_device random_device;
499     seed = random_device();
500   } else {
501     seed = static_cast<unsigned int>(s);
502   }
503   LOG(ERROR) << "Random seed for test case: " << seed
504              << ". To reproduce the "
505                 "results of this test, pass flag --tf_xla_random_seed="
506              << seed;
507   generator_.reset(new std::mt19937(seed));
508 
509   // Create a session with an empty graph.
510   SessionOptions session_options;
511   session_.reset(NewSession(session_options));
512   GraphDef def;
513   TF_CHECK_OK(session_->Create(def));
514 }
515 
516 namespace {
517 template <typename T>
TensorFromValues(DataType dtype,absl::Span<const int64_t> shape,absl::Span<T> vals)518 Tensor TensorFromValues(DataType dtype, absl::Span<const int64_t> shape,
519                         absl::Span<T> vals) {
520   Tensor tensor(dtype, TensorShape(shape));
521   test::FillValues<T>(&tensor, vals);
522   return tensor;
523 }
524 
ShapeNumVals(absl::Span<const int64_t> shape)525 int64_t ShapeNumVals(absl::Span<const int64_t> shape) {
526   int64_t num_vals = 1;
527   for (int i = 0; i < shape.size(); ++i) {
528     num_vals *= shape[i];
529   }
530   return num_vals;
531 }
532 }  // namespace
533 
534 // TensorGenerator is an abstact class that has one implementing class for each
535 // (DataType,T) pair. The implementing class implements RandomVals, which is
536 // the only Tensor generation code that is specific to the DataType.
537 template <typename T>
538 class TensorGenerator {
539  public:
TensorGenerator(OpTest & test)540   explicit TensorGenerator(OpTest& test) : test_(test) {}
~TensorGenerator()541   virtual ~TensorGenerator() {}
542   virtual DataType dtype() = 0;
543   virtual void RandomVals(std::optional<T> lo, std::optional<T> hi,
544                           bool needs_unique_values,
545                           absl::FixedArray<T>& vals) = 0;
546 
RandomTensor(std::optional<T> lo,std::optional<T> hi,bool needs_unique_values,absl::Span<const int64_t> shape)547   Tensor RandomTensor(std::optional<T> lo, std::optional<T> hi,
548                       bool needs_unique_values,
549                       absl::Span<const int64_t> shape) {
550     absl::FixedArray<T> vals(ShapeNumVals(shape));
551     RandomVals(lo, hi, needs_unique_values, vals);
552     return TensorFromValues<T>(dtype(), shape, absl::Span<T>(vals));
553   }
554 
RandomLteTensors(absl::Span<const int64_t> shape)555   std::pair<Tensor, Tensor> RandomLteTensors(absl::Span<const int64_t> shape) {
556     int64_t num_vals = ShapeNumVals(shape);
557     absl::FixedArray<T> less(num_vals);
558     RandomVals({}, {}, false, less);
559     absl::FixedArray<T> greater(num_vals);
560     RandomVals({}, {}, false, greater);
561     for (int i = 0; i < num_vals; ++i) {
562       if (less[i] > greater[i]) {
563         std::swap(less[i], greater[i]);
564       }
565     }
566     std::pair<Tensor, Tensor> pair(
567         TensorFromValues<T>(dtype(), shape, absl::Span<T>(less)),
568         TensorFromValues<T>(dtype(), shape, absl::Span<T>(greater)));
569     return pair;
570   }
571 
572  protected:
573   OpTest& test_;
574 };
575 
576 class TensorGeneratorFloat : public TensorGenerator<float> {
577  public:
TensorGeneratorFloat(OpTest & test)578   explicit TensorGeneratorFloat(OpTest& test) : TensorGenerator(test) {}
dtype()579   DataType dtype() override { return DT_FLOAT; }
RandomVals(std::optional<float> lo,std::optional<float> hi,bool needs_unique_values,absl::FixedArray<float> & vals)580   void RandomVals(std::optional<float> lo, std::optional<float> hi,
581                   bool needs_unique_values,
582                   absl::FixedArray<float>& vals) override {
583     absl::flat_hash_set<float> already_generated;
584     std::uniform_real_distribution<float> distribution(lo.value_or(-1.0f),
585                                                        hi.value_or(1.0f));
586     for (int64_t i = 0; i < vals.size(); ++i) {
587       float generated;
588       do {
589         generated = distribution(test_.generator());
590       } while (needs_unique_values &&
591                !already_generated.insert(generated).second);
592       vals[i] = (generated);
593     }
594   }
595 };
596 
597 class TensorGeneratorDouble : public TensorGenerator<double> {
598  public:
TensorGeneratorDouble(OpTest & test)599   explicit TensorGeneratorDouble(OpTest& test) : TensorGenerator(test) {}
dtype()600   DataType dtype() override { return DT_DOUBLE; }
RandomVals(std::optional<double> lo,std::optional<double> hi,bool needs_unique_values,absl::FixedArray<double> & vals)601   void RandomVals(std::optional<double> lo, std::optional<double> hi,
602                   bool needs_unique_values,
603                   absl::FixedArray<double>& vals) override {
604     absl::flat_hash_set<double> already_generated;
605     std::uniform_real_distribution<double> distribution(lo.value_or(-1.0),
606                                                         hi.value_or(1.0));
607     for (int64_t i = 0; i < vals.size(); ++i) {
608       double generated;
609       do {
610         generated = distribution(test_.generator());
611       } while (needs_unique_values &&
612                !already_generated.insert(generated).second);
613       vals[i] = generated;
614     }
615   }
616 };
617 
618 class TensorGeneratorComplex64 : public TensorGenerator<complex64> {
619  public:
TensorGeneratorComplex64(OpTest & test)620   explicit TensorGeneratorComplex64(OpTest& test) : TensorGenerator(test) {}
dtype()621   DataType dtype() override { return DT_COMPLEX64; }
RandomVals(std::optional<complex64> lo,std::optional<complex64> hi,bool needs_unique_values,absl::FixedArray<complex64> & vals)622   void RandomVals(std::optional<complex64> lo, std::optional<complex64> hi,
623                   bool needs_unique_values,
624                   absl::FixedArray<complex64>& vals) override {
625     absl::flat_hash_set<std::pair<float, float>> already_generated;
626     if (lo || hi) {
627       LOG(FATAL) << "Lower or upper bounds are not supported for complex64.";
628     }
629     std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
630     for (int64_t i = 0; i < vals.size(); ++i) {
631       complex64 generated;
632       do {
633         generated = complex64(distribution(test_.generator()),
634                               distribution(test_.generator()));
635       } while (needs_unique_values &&
636                !already_generated
637                     .insert(std::make_pair(generated.real(), generated.imag()))
638                     .second);
639       vals[i] = generated;
640     }
641   }
642 };
643 
644 class TensorGeneratorInt32 : public TensorGenerator<int32> {
645  public:
TensorGeneratorInt32(OpTest & test)646   explicit TensorGeneratorInt32(OpTest& test) : TensorGenerator(test) {}
dtype()647   DataType dtype() override { return DT_INT32; }
RandomVals(std::optional<int32> lo,std::optional<int32> hi,bool needs_unique_values,absl::FixedArray<int32> & vals)648   void RandomVals(std::optional<int32> lo, std::optional<int32> hi,
649                   bool needs_unique_values,
650                   absl::FixedArray<int32>& vals) override {
651     absl::flat_hash_set<int32> already_generated;
652     std::uniform_int_distribution<int32> distribution(lo.value_or(-(1 << 20)),
653                                                       hi.value_or(1 << 20));
654     for (int64_t i = 0; i < vals.size(); ++i) {
655       int32_t generated;
656       do {
657         generated = distribution(test_.generator());
658       } while (needs_unique_values &&
659                !already_generated.insert(generated).second);
660       vals[i] = generated;
661     }
662   }
663 };
664 
665 class TensorGeneratorInt64 : public TensorGenerator<int64> {
666  public:
TensorGeneratorInt64(OpTest & test)667   explicit TensorGeneratorInt64(OpTest& test) : TensorGenerator(test) {}
dtype()668   DataType dtype() override { return DT_INT64; }
RandomVals(std::optional<int64> lo,std::optional<int64> hi,bool needs_unique_values,absl::FixedArray<int64> & vals)669   void RandomVals(std::optional<int64> lo, std::optional<int64> hi,
670                   bool needs_unique_values,
671                   absl::FixedArray<int64>& vals) override {
672     absl::flat_hash_set<int64_t> already_generated;
673     std::uniform_int_distribution<int64_t> distribution(
674         lo.value_or(-(1LL << 40)), hi.value_or(1LL << 40));
675     for (int64_t i = 0; i < vals.size(); ++i) {
676       int64_t generated;
677       do {
678         generated = distribution(test_.generator());
679       } while (needs_unique_values &&
680                !already_generated.insert(generated).second);
681       vals[i] = generated;
682     }
683   }
684 };
685 
686 class TensorGeneratorBool : public TensorGenerator<bool> {
687  public:
TensorGeneratorBool(OpTest & test)688   explicit TensorGeneratorBool(OpTest& test) : TensorGenerator(test) {}
dtype()689   DataType dtype() override { return DT_BOOL; }
RandomVals(std::optional<bool> lo,std::optional<bool> hi,bool needs_unique_values,absl::FixedArray<bool> & vals)690   void RandomVals(std::optional<bool> lo, std::optional<bool> hi,
691                   bool needs_unique_values,
692                   absl::FixedArray<bool>& vals) override {
693     absl::flat_hash_set<bool> already_generated;
694     if (lo || hi) {
695       LOG(FATAL) << "Lower or upper bounds are not supported for bool.";
696     }
697     std::bernoulli_distribution distribution;
698     for (int64_t i = 0; i < vals.size(); ++i) {
699       bool generated;
700       do {
701         generated = distribution(test_.generator());
702       } while (needs_unique_values &&
703                !already_generated.insert(generated).second);
704       vals[i] = generated;
705     }
706   }
707 };
708 
Repeatedly(const std::function<TestResult (void)> & fn)709 void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
710   int const max_repetitions = tf_xla_test_repetitions;
711   int valid_test_runs = 0;
712   // We run up to 100 * max_repetitions times; the idea is that if we roll the
713   // dice enough times we will find some valid parameters. We want to put an
714   // upper limit on the number iterations just in case the probability of
715   // finding feasible parameters is very low.
716   for (int i = 0; !HasFailure() && i < max_repetitions * 100 &&
717                   valid_test_runs < max_repetitions;
718        ++i) {
719     TestResult result = fn();
720     switch (result) {
721       case kOk:
722         ++valid_test_runs;
723         break;
724 
725       case kFatalError:
726         ASSERT_TRUE(false) << "Test had fatal failure";
727         return;
728 
729       case kInvalid:
730         break;
731     }
732   }
733   if (!HasFailure()) {
734     EXPECT_GE(valid_test_runs, max_repetitions)
735         << "Not enough test instances passed; this means that either the "
736            "golden implementation is buggy or the operator harness is not "
737            "producing well-formed test cases with a high probability.";
738   }
739 }
740 
741 template <typename T>
Choose(absl::Span<const T> candidates)742 T OpTest::Choose(absl::Span<const T> candidates) {
743   std::uniform_int_distribution<size_t> d(0, candidates.size() - 1);
744   return candidates[d(generator())];
745 }
746 
RandomDim(int64_t min,int64_t max)747 int64_t OpTest::RandomDim(int64_t min, int64_t max) {
748   std::uniform_int_distribution<int64_t> size_distribution(min, max - 1);
749   return size_distribution(generator());
750 }
751 
TensorSizeIsOk(absl::Span<const int64_t> dims)752 bool OpTest::TensorSizeIsOk(absl::Span<const int64_t> dims) {
753   int64_t size = 1LL;
754   for (int64_t dim : dims) {
755     size *= dim;
756   }
757   return size < tf_xla_max_tensor_size;
758 }
759 
RandomDims(int min_rank,int max_rank,int64_t min_size,int64_t max_size)760 std::vector<int64_t> OpTest::RandomDims(int min_rank, int max_rank,
761                                         int64_t min_size, int64_t max_size) {
762   CHECK_LE(0, min_rank);
763   CHECK_LE(min_rank, max_rank);
764   std::uniform_int_distribution<int> rank_distribution(min_rank, max_rank);
765   int rank = rank_distribution(generator());
766   std::vector<int64_t> dims(rank);
767   if (rank == 0) {
768     return dims;
769   }
770   int64_t per_dim_limit = std::pow(tf_xla_max_tensor_size, 1.0 / rank);
771   int64_t per_dim_max = std::min(max_size, per_dim_limit);
772   std::generate(dims.begin(), dims.end(), [this, min_size, per_dim_max]() {
773     return RandomDim(min_size, per_dim_max);
774   });
775   CHECK(TensorSizeIsOk(dims));  // Crash OK
776   return dims;
777 }
778 
RandomBool()779 bool OpTest::RandomBool() {
780   std::bernoulli_distribution d(0.5);
781   return d(generator());
782 }
783 
RandomSeed()784 int64_t OpTest::RandomSeed() {
785   std::uniform_int_distribution<int64_t> seed_dist(
786       std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max());
787   int64_t seed = seed_dist(generator());
788   if (seed == 0) return 1;
789   return seed;
790 }
791 
RandomTensor(DataType dtype,bool needs_unique_values,absl::Span<const int64_t> shape)792 Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
793                             absl::Span<const int64_t> shape) {
794   switch (dtype) {
795     case DT_FLOAT:
796       return TensorGeneratorFloat(*this).RandomTensor(
797           {}, {}, needs_unique_values, shape);
798     case DT_DOUBLE:
799       return TensorGeneratorDouble(*this).RandomTensor(
800           {}, {}, needs_unique_values, shape);
801     case DT_COMPLEX64:
802       return TensorGeneratorComplex64(*this).RandomTensor(
803           {}, {}, needs_unique_values, shape);
804     case DT_INT32:
805       return TensorGeneratorInt32(*this).RandomTensor(
806           {}, {}, needs_unique_values, shape);
807     case DT_INT64:
808       return TensorGeneratorInt64(*this).RandomTensor(
809           {}, {}, needs_unique_values, shape);
810     case DT_BOOL:
811       return TensorGeneratorBool(*this).RandomTensor(
812           {}, {}, needs_unique_values, shape);
813     default:
814       LOG(FATAL) << "Unimplemented type " << dtype << " in RandomTensor";
815   }
816 }
817 
RandomTensor(DataType dtype)818 Tensor OpTest::RandomTensor(DataType dtype) {
819   return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims());
820 }
821 
RandomNonNegativeTensor(DataType dtype,absl::Span<const int64_t> shape)822 Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
823                                        absl::Span<const int64_t> shape) {
824   switch (dtype) {
825     case DT_FLOAT:
826       return TensorGeneratorFloat(*this).RandomTensor({0.0f}, {}, false, shape);
827     case DT_DOUBLE:
828       return TensorGeneratorDouble(*this).RandomTensor({0.0}, {}, false, shape);
829     case DT_INT32:
830       return TensorGeneratorInt32(*this).RandomTensor({0}, {}, false, shape);
831     case DT_INT64:
832       return TensorGeneratorInt64(*this).RandomTensor({0}, {}, false, shape);
833     default:
834       LOG(FATAL) << "Unimplemented type " << dtype
835                  << " in RandomNonNegativeTensor";
836   }
837 }
838 
RandomNonNegativeTensor(DataType dtype)839 Tensor OpTest::RandomNonNegativeTensor(DataType dtype) {
840   return RandomNonNegativeTensor(dtype, RandomDims());
841 }
842 
843 template <typename T>
RandomBoundedTensor(DataType dtype,T lo,T hi,bool needs_unique_values,absl::Span<const int64_t> shape)844 Tensor OpTest::RandomBoundedTensor(DataType dtype, T lo, T hi,
845                                    bool needs_unique_values,
846                                    absl::Span<const int64_t> shape) {
847   switch (dtype) {
848     case DT_FLOAT:
849       return TensorGeneratorFloat(*this).RandomTensor(
850           {lo}, {hi}, needs_unique_values, shape);
851     case DT_DOUBLE:
852       return TensorGeneratorDouble(*this).RandomTensor(
853           {lo}, {hi}, needs_unique_values, shape);
854     case DT_INT32:
855       return TensorGeneratorInt32(*this).RandomTensor(
856           {lo}, {hi}, needs_unique_values, shape);
857     case DT_INT64:
858       return TensorGeneratorInt64(*this).RandomTensor(
859           {lo}, {hi}, needs_unique_values, shape);
860     default:
861       LOG(FATAL) << "RandomBoundedTensor does not support type " << dtype
862                  << ".";
863   }
864 }
865 
866 template <typename T>
RandomBoundedTensor(DataType dtype,T lo,T hi,bool needs_unique_values)867 Tensor OpTest::RandomBoundedTensor(DataType dtype, T lo, T hi,
868                                    bool needs_unique_values) {
869   return RandomBoundedTensor<T>(dtype, lo, hi, needs_unique_values,
870                                 RandomDims());
871 }
872 
RandomBoundedTensor(DataType dtype,Tensor lo,Tensor hi)873 Tensor OpTest::RandomBoundedTensor(DataType dtype, Tensor lo, Tensor hi) {
874   TensorShape shape = lo.shape();
875   if (hi.shape() != shape) {
876     LOG(FATAL) << "hi and lo do not have the same shape in RandomBoundedTensor";
877   }
878   if (hi.dtype() != dtype) {
879     LOG(FATAL) << "hi does not have the expected dtype in RandomBoundedTensor";
880   }
881   if (lo.dtype() != dtype) {
882     LOG(FATAL) << "lo does not have the expected dtype in RandomBoundedTensor";
883   }
884   Tensor tensor(dtype, shape);
885   switch (dtype) {
886     case DT_FLOAT: {
887       auto lo_flat = lo.flat<float>();
888       auto hi_flat = hi.flat<float>();
889       test::FillFn<float>(&tensor, [this, &lo_flat, &hi_flat](int i) -> float {
890         std::uniform_real_distribution<float> distribution(lo_flat(i),
891                                                            hi_flat(i));
892         return distribution(generator());
893       });
894       break;
895     }
896     case DT_DOUBLE: {
897       auto lo_flat = lo.flat<double>();
898       auto hi_flat = hi.flat<double>();
899       test::FillFn<double>(
900           &tensor, [this, &lo_flat, &hi_flat](int i) -> double {
901             std::uniform_real_distribution<double> distribution(lo_flat(i),
902                                                                 hi_flat(i));
903             return distribution(generator());
904           });
905       break;
906     }
907     case DT_INT32: {
908       auto lo_flat = lo.flat<int32>();
909       auto hi_flat = hi.flat<int32>();
910       test::FillFn<int32>(&tensor, [this, &lo_flat, &hi_flat](int i) -> int32 {
911         std::uniform_int_distribution<int32> distribution(lo_flat(i),
912                                                           hi_flat(i));
913         return distribution(generator());
914       });
915       break;
916     }
917     case DT_INT64: {
918       auto lo_flat = lo.flat<int64>();
919       auto hi_flat = hi.flat<int64>();
920       test::FillFn<int64_t>(
921           &tensor, [this, &lo_flat, &hi_flat](int i) -> int64_t {
922             std::uniform_int_distribution<int64_t> distribution(lo_flat(i),
923                                                                 hi_flat(i));
924             return distribution(generator());
925           });
926       break;
927     }
928     default:
929       LOG(FATAL) << "RandomBoundedTensor does not support type " << dtype
930                  << ".";
931   }
932   return tensor;
933 }
934 
RandomLteTensors(DataType dtype,absl::Span<const int64_t> shape)935 std::pair<Tensor, Tensor> OpTest::RandomLteTensors(
936     DataType dtype, absl::Span<const int64_t> shape) {
937   switch (dtype) {
938     case DT_FLOAT:
939       return TensorGeneratorFloat(*this).RandomLteTensors(shape);
940     case DT_DOUBLE:
941       return TensorGeneratorDouble(*this).RandomLteTensors(shape);
942     case DT_COMPLEX64:
943       LOG(FATAL) << "RandomLteTensors unavailable for DT_COMPLEX64";
944       break;
945     case DT_INT32:
946       return TensorGeneratorInt32(*this).RandomLteTensors(shape);
947     case DT_INT64:
948       return TensorGeneratorInt64(*this).RandomLteTensors(shape);
949     case DT_BOOL:
950       LOG(FATAL) << "RandomLteTensors unavailable for DT_BOOL";
951       break;
952     default:
953       LOG(FATAL) << "Unimplemented type " << dtype << " in RandomLteTensors";
954   }
955   Tensor tensor(dtype, TensorShape(shape));
956   return std::pair<Tensor, Tensor>(tensor, tensor);
957 }
958 
RandomLteTensors(DataType dtype)959 std::pair<Tensor, Tensor> OpTest::RandomLteTensors(DataType dtype) {
960   return RandomLteTensors(dtype, RandomDims());
961 }
962 
BroadcastableToDims(std::vector<int64_t> dims)963 std::vector<int64_t> OpTest::BroadcastableToDims(std::vector<int64_t> dims) {
964   if (dims.empty()) return dims;
965 
966   // Remove some dimensions from the front of 'dims'.
967   size_t skip =
968       std::uniform_int_distribution<size_t>(0, dims.size() - 1)(generator());
969 
970   std::vector<int64_t> bdims(dims.begin() + skip, dims.end());
971 
972   // Randomly replace some of the remaining dimensions of 'dims' with 1.
973   std::bernoulli_distribution random_bool;
974 
975   for (int64_t& dim : bdims) {
976     if (random_bool(generator())) {
977       dim = 1LL;
978     }
979   }
980   return bdims;
981 }
982 
BroadcastableDims(std::vector<int64_t> dims)983 std::pair<std::vector<int64_t>, std::vector<int64_t>> OpTest::BroadcastableDims(
984     std::vector<int64_t> dims) {
985   auto bdims = BroadcastableToDims(dims);
986   // Possibly swap the roles of 'dims' and 'bdims'.
987   std::bernoulli_distribution random_bool;
988   if (random_bool(generator())) {
989     dims.swap(bdims);
990   }
991   return {dims, bdims};
992 }
993 
994 std::pair<std::vector<int64_t>, std::vector<int64_t>>
BroadcastableDims()995 OpTest::BroadcastableDims() {
996   return BroadcastableDims(RandomDims(0, 3));
997 }
998 
RandomReductionIndices(int rank)999 Tensor OpTest::RandomReductionIndices(int rank) {
1000   std::bernoulli_distribution random_bool;
1001   std::vector<int32> indices;
1002   for (int i = 0; i < rank; ++i) {
1003     if (random_bool(generator())) {
1004       indices.push_back(i);
1005     }
1006   }
1007   return test::AsTensor<int32>(indices);
1008 }
1009 
1010 // Helper that converts 'values' to an int32 or int64 Tensor.
AsIntTensor(DataType dtype,const std::vector<int64_t> & values)1011 static Tensor AsIntTensor(DataType dtype, const std::vector<int64_t>& values) {
1012   switch (dtype) {
1013     case DT_INT32: {
1014       std::vector<int32> values32(values.begin(), values.end());
1015       return test::AsTensor<int32>(values32);
1016     }
1017     case DT_INT64:
1018       return test::AsTensor<int64_t>(values);
1019     default:
1020       LOG(FATAL);
1021   }
1022 }
1023 
ChooseBatchMatMulArguments(bool broadcastable_batch)1024 OpTest::BatchMatMulArguments OpTest::ChooseBatchMatMulArguments(
1025     bool broadcastable_batch) {
1026   BatchMatMulArguments a;
1027   a.dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
1028 
1029   int64_t min_size = 0;
1030   int64_t max_size = 7;
1031   auto batch_dims_to = RandomDims(0, 3, min_size, max_size);
1032   int rank = batch_dims_to.size() + 2;
1033   std::pair<std::vector<int64_t>, std::vector<int64_t>> batch_dims_nobcast(
1034       batch_dims_to, batch_dims_to);
1035   auto batch_dims = broadcastable_batch ? BroadcastableDims(batch_dims_to)
1036                                         : batch_dims_nobcast;
1037   std::vector<int64_t> lhs_dims(batch_dims.first), rhs_dims(batch_dims.second);
1038   int64_t inner_dim = RandomDim();
1039   lhs_dims.push_back(RandomDim(min_size, max_size));
1040   lhs_dims.push_back(inner_dim);
1041   rhs_dims.push_back(inner_dim);
1042   rhs_dims.push_back(RandomDim(min_size, max_size));
1043 
1044   std::bernoulli_distribution random_bool;
1045   a.adj_lhs = random_bool(generator());
1046   a.adj_rhs = random_bool(generator());
1047   if (a.adj_lhs) {
1048     std::swap(lhs_dims[rank - 1], lhs_dims[rank - 2]);
1049   }
1050   if (a.adj_rhs) {
1051     std::swap(rhs_dims[rank - 1], rhs_dims[rank - 2]);
1052   }
1053 
1054   a.lhs_dims = lhs_dims;
1055   a.rhs_dims = rhs_dims;
1056   return a;
1057 }
1058 
ChooseConcatArguments(bool int64_idx_allowed)1059 OpTest::ConcatArguments OpTest::ChooseConcatArguments(bool int64_idx_allowed) {
1060   ConcatArguments a;
1061 
1062   std::bernoulli_distribution random_bool;
1063   bool use_int64_idx = random_bool(generator());
1064 
1065   a.type = Choose<DataType>(kAllXlaTypes);
1066   a.type_idx = use_int64_idx ? DT_INT64 : DT_INT32;
1067   a.n = std::uniform_int_distribution<int>(2, 4)(generator());
1068 
1069   std::vector<int64_t> dims = RandomDims(1, 4, 0, 64);
1070 
1071   int axis =
1072       std::uniform_int_distribution<int32>(0, dims.size() - 1)(generator());
1073   a.axis =
1074       use_int64_idx ? test::AsScalar<int64>(axis) : test::AsScalar<int32>(axis);
1075 
1076   for (int i = 0; i < a.n; ++i) {
1077     std::vector<int64_t> shape = dims;
1078     shape[axis] = RandomDim(0, 64);
1079     a.values.push_back(RandomTensor(a.type, false, shape));
1080   }
1081 
1082   return a;
1083 }
1084 
ChooseEinsumArguments()1085 OpTest::EinsumArguments OpTest::ChooseEinsumArguments() {
1086   EinsumArguments a;
1087 
1088   enum EinsumType { matmul, batchmatmul, dot, outer };
1089   int op_kind = Choose<int>({matmul, batchmatmul, dot, outer});
1090   switch (op_kind) {
1091     case matmul:
1092     case batchmatmul: {
1093       std::vector<int64> dims;
1094       if (op_kind == matmul) {
1095         a.equation = "ij,jk->ik";
1096         dims = RandomDims(2, 2);
1097       } else {
1098         a.equation = "...ij,...jk->...ik";
1099         dims = RandomDims(2);
1100       }
1101       int64_t ndims = dims.size();
1102       int64_t inner_dim = RandomDim();
1103       a.lhs_dims = dims;
1104       a.rhs_dims = dims;
1105       a.lhs_dims[ndims - 1] = inner_dim;
1106       a.rhs_dims[ndims - 2] = inner_dim;
1107       break;
1108     }
1109     case dot: {
1110       a.equation = "i,i->";
1111       std::vector<int64> dims = RandomDims(1, 1);
1112       a.lhs_dims = dims;
1113       a.rhs_dims = dims;
1114       break;
1115     }
1116     case outer: {
1117       a.equation = "i,j->ij";
1118       a.lhs_dims = RandomDims(1, 1);
1119       a.rhs_dims = RandomDims(1, 1);
1120       break;
1121     }
1122   }
1123 
1124   a.type = Choose<DataType>(kAllXlaTypes);
1125   return a;
1126 }
1127 
ChooseGatherArguments(bool axis_0)1128 OpTest::GatherArguments OpTest::ChooseGatherArguments(bool axis_0) {
1129   GatherArguments a;
1130 
1131   a.axis_type = DT_INT32;
1132   a.indices_type = DT_INT32;
1133   a.params_type = Choose<DataType>(kAllXlaTypes);
1134 
1135   // Choose parameters such that
1136   // 0 <= batch_dims <= axis < params.rank <= kDefaultMaxRank
1137   a.batch_dims = 0;
1138   int64_t axis;
1139   if (axis_0) {
1140     axis = 0;
1141   } else {
1142     std::uniform_int_distribution<int64_t> axis_distribution(
1143         a.batch_dims, kDefaultMaxRank - 1);
1144     axis = axis_distribution(generator());
1145   }
1146   a.axis = test::AsScalar<int32>((int32)axis);
1147   a.params_shape = RandomDims(axis + 1, kDefaultMaxRank, 1, 16);
1148   std::vector<int64_t> indices_shape = RandomDims(0, 3, 0, 16);
1149   a.indices = RandomBoundedTensor<int32>(DT_INT32, 0, a.params_shape[axis] - 1,
1150                                          false, indices_shape);
1151 
1152   return a;
1153 }
1154 
ChoosePadArguments()1155 OpTest::PadArguments OpTest::ChoosePadArguments() {
1156   PadArguments a;
1157 
1158   a.input_type = Choose<DataType>(kAllXlaTypes);
1159   a.input_shape = RandomDims();
1160   int input_rank = a.input_shape.size();
1161 
1162   a.paddings_type = Choose<DataType>({DT_INT32, DT_INT64});
1163   std::vector<int64_t> paddings_vec;
1164   for (int i = 0; i < input_rank; ++i) {
1165     std::uniform_int_distribution<int> pad_distribution(0, a.input_shape[i]);
1166     int pad_size = pad_distribution(generator());
1167     std::uniform_int_distribution<int> lower_distribution(0, pad_size);
1168     int low_pad_size = lower_distribution(generator());
1169     paddings_vec.push_back(low_pad_size);
1170     paddings_vec.push_back(pad_size - low_pad_size);
1171     a.input_shape[i] -= pad_size;
1172   }
1173   CHECK(
1174       a.paddings.CopyFrom(AsIntTensor(a.paddings_type, paddings_vec),
1175                           TensorShape({static_cast<int64_t>(input_rank), 2})));
1176 
1177   a.constant_values = RandomTensor(a.input_type, false, {});
1178 
1179   return a;
1180 }
1181 
ChooseScatterArguments()1182 OpTest::ScatterArguments OpTest::ChooseScatterArguments() {
1183   ScatterArguments a;
1184 
1185   a.type = Choose<DataType>(kAllXlaTypes);
1186   a.indices_type = DT_INT32;
1187   a.shape = RandomDims(1, kDefaultMaxRank, 1);
1188   int rank = a.shape.size();
1189   std::uniform_int_distribution<int32> index_len_dist(1, rank);
1190   int index_len = index_len_dist(generator());
1191   std::vector<int64_t> indices_first = RandomDims(1, kDefaultMaxRank - 1, 1);
1192   std::vector<int64_t> indices_shape(indices_first);
1193   indices_shape.push_back(index_len);
1194   std::vector<int64_t> updates_shape(indices_first);
1195   for (int i = 0; i < rank - index_len; ++i) {
1196     updates_shape.push_back(a.shape[index_len + i]);
1197   }
1198   Tensor indices_lo(a.indices_type, TensorShape(indices_shape));
1199   test::FillFn<int32>(&indices_lo, [](int i) -> int32 { return 0; });
1200   Tensor indices_hi(a.indices_type, TensorShape(indices_shape));
1201   test::FillFn<int32>(&indices_hi, [index_len, &a](int i) -> int32 {
1202     int idx_dim = i % index_len;
1203     return a.shape[idx_dim] - 1;
1204   });
1205   a.indices = RandomBoundedTensor(a.indices_type, indices_lo, indices_hi);
1206   a.updates = RandomTensor(a.type, false, updates_shape);
1207 
1208   return a;
1209 }
1210 
ChooseSliceArguments(bool neg_one_size)1211 OpTest::SliceArguments OpTest::ChooseSliceArguments(bool neg_one_size) {
1212   SliceArguments a;
1213 
1214   a.type = Choose<DataType>(kAllXlaTypes);
1215   a.indices_type = DT_INT32;
1216   a.shape = RandomDims();
1217   int rank = a.shape.size();
1218 
1219   std::vector<int32> indices(rank);
1220   a.size.resize(rank);
1221   for (int i = 0; i < rank; ++i) {
1222     indices[i] =
1223         std::uniform_int_distribution<int32>(0, a.shape[i])(generator());
1224     int64_t low = neg_one_size ? -1 : 0;
1225     a.size[i] = std::uniform_int_distribution<int64_t>(
1226         low, a.shape[i] - indices[i])(generator());
1227   }
1228   a.indices = test::AsTensor<int32>(indices);
1229 
1230   return a;
1231 }
1232 
ChooseWindowedSpatialDims(int num_spatial_dims)1233 OpTest::WindowedSpatialDims OpTest::ChooseWindowedSpatialDims(
1234     int num_spatial_dims) {
1235   WindowedSpatialDims d;
1236   d.padding = Choose<Padding>({SAME, VALID});
1237   std::uniform_int_distribution<int> random_int(1, 5);
1238   d.kernel_dims.resize(num_spatial_dims);
1239   d.input_dims.resize(num_spatial_dims);
1240   d.output_dims.resize(num_spatial_dims);
1241   d.stride_dims.resize(num_spatial_dims);
1242   for (int i = 0; i < num_spatial_dims; ++i) {
1243     Status s;
1244     // Repeatedly try different filter/stride sizes until we find a valid
1245     // combination.
1246     do {
1247       // CPU implementations require stride <= kernel size.
1248       d.kernel_dims[i] = random_int(generator()),
1249       d.input_dims[i] = RandomDim(d.kernel_dims[i]);
1250       d.stride_dims[i] =
1251           std::uniform_int_distribution<int>(1, d.kernel_dims[i])(generator());
1252       int64_t pad_dummy;
1253       s = GetWindowedOutputSize(d.input_dims[i], d.kernel_dims[i],
1254                                 d.stride_dims[i], d.padding, &d.output_dims[i],
1255                                 &pad_dummy);
1256     } while (!s.ok());
1257   }
1258   return d;
1259 }
1260 
ChooseXlaDotArguments()1261 OpTest::XlaDotArguments OpTest::ChooseXlaDotArguments() {
1262   std::vector<int64_t> batch_dims = RandomDims(0, 2);
1263   std::vector<int64_t> contracting_dims = RandomDims(0, 2);
1264   std::vector<int64_t> lhs_outer_dims = RandomDims(0, 2);
1265   std::vector<int64_t> rhs_outer_dims = RandomDims(0, 2);
1266 
1267   XlaDotArguments a;
1268   a.lhs_dims.insert(a.lhs_dims.end(), batch_dims.begin(), batch_dims.end());
1269   a.lhs_dims.insert(a.lhs_dims.end(), contracting_dims.begin(),
1270                     contracting_dims.end());
1271   a.lhs_dims.insert(a.lhs_dims.end(), lhs_outer_dims.begin(),
1272                     lhs_outer_dims.end());
1273   a.rhs_dims.insert(a.rhs_dims.end(), batch_dims.begin(), batch_dims.end());
1274   a.rhs_dims.insert(a.rhs_dims.end(), contracting_dims.begin(),
1275                     contracting_dims.end());
1276   a.rhs_dims.insert(a.rhs_dims.end(), rhs_outer_dims.begin(),
1277                     rhs_outer_dims.end());
1278 
1279   xla::DotDimensionNumbers dnums;
1280   for (auto i = 0; i < batch_dims.size(); ++i) {
1281     dnums.add_lhs_batch_dimensions(i);
1282     dnums.add_rhs_batch_dimensions(i);
1283   }
1284   for (auto i = 0; i < contracting_dims.size(); ++i) {
1285     dnums.add_lhs_contracting_dimensions(batch_dims.size() + i);
1286     dnums.add_rhs_contracting_dimensions(batch_dims.size() + i);
1287   }
1288   dnums.SerializeToString(&a.dnums_encoded);
1289 
1290   a.precision_config_encoded = "";
1291 
1292   a.dtype = Choose<DataType>(kAllXlaTypes);
1293   return a;
1294 }
1295 
ImageDims(TensorFormat format,int batch,int feature,const std::vector<int64_t> & spatial_dims)1296 std::vector<int64_t> OpTest::ImageDims(
1297     TensorFormat format, int batch, int feature,
1298     const std::vector<int64_t>& spatial_dims) {
1299   std::vector<int64_t> dims;
1300   switch (format) {
1301     case FORMAT_NHWC:
1302       dims.push_back(batch);
1303       for (int dim : spatial_dims) {
1304         dims.push_back(dim);
1305       }
1306       dims.push_back(feature);
1307       break;
1308     case FORMAT_NCHW:
1309       dims.push_back(batch);
1310       dims.push_back(feature);
1311       for (int dim : spatial_dims) {
1312         dims.push_back(dim);
1313       }
1314       break;
1315     default:
1316       LOG(FATAL) << "Tensor format " << ToString(format) << " not supported.";
1317   }
1318   return dims;
1319 }
1320 
AsInt32s(const std::vector<int64_t> & int64s)1321 std::vector<int32> OpTest::AsInt32s(const std::vector<int64_t>& int64s) {
1322   return std::vector<int32>(int64s.begin(), int64s.end());
1323 }
1324 
1325 // Functions for comparing tensors.
1326 
// Magnitude of a real scalar, returned as double. Integral arguments use
// std::fabs's integral overload, which converts them to double first.
template <typename T>
double Abs(T x) {
  const double magnitude = std::fabs(x);
  return magnitude;
}
1331 
1332 template <>
Abs(complex64 x)1333 double Abs<complex64>(complex64 x) {
1334   return std::abs(x);
1335 }
1336 
// Returns true when x and y are both NaN, compare equal (which also accepts
// matching infinities), or satisfy |x - y| < atol + rtol * |x|. The relative
// tolerance is scaled by x, not y.
template <typename T>
bool IsClose(const T& x, const T& y, double atol, double rtol) {
  const bool both_nan = std::isnan(x) && std::isnan(y);
  if (both_nan || x == y) return true;
  const double tolerance = atol + rtol * Abs(x);
  return Abs(x - y) < tolerance;
}
1343 
1344 template <>
IsClose(const complex64 & x,const complex64 & y,double atol,double rtol)1345 bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
1346                         double rtol) {
1347   if (std::isnan(x.real()) && std::isnan(y.real())) {
1348     if (std::isnan(x.imag()) && std::isnan(y.imag())) {
1349       return true;
1350     }
1351     if (x.imag() == y.imag()) return true;  // Allow inf == inf.
1352     return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag());
1353   } else if (std::isnan(x.imag()) && std::isnan(y.imag())) {
1354     if (x.real() == y.real()) return true;  // Allow inf == inf.
1355     return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real());
1356   }
1357   if (x == y) return true;  // Allow inf == inf.
1358   return Abs(x - y) < atol + rtol * Abs(x);
1359 }
1360 
1361 template <typename T>
Str(T x)1362 string Str(T x) {
1363   return absl::StrCat(x);
1364 }
1365 template <>
Str(complex64 x)1366 string Str<complex64>(complex64 x) {
1367   return absl::StrCat("(", x.real(), ", ", x.imag(), ")");
1368 }
1369 
1370 template <typename T>
TensorsAreCloseImpl(const Tensor & x,const Tensor & y,double atol,double rtol)1371 Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
1372                            double rtol) {
1373   auto Tx = x.flat<T>();
1374   auto Ty = y.flat<T>();
1375   for (int i = 0; i < Tx.size(); ++i) {
1376     if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
1377       return errors::InvalidArgument(
1378           absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)),
1379                        " vs. ", Str(Ty(i)), ". x = ", x.DebugString(),
1380                        "y = ", y.DebugString(), "atol = ", atol,
1381                        " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i))));
1382     }
1383   }
1384   return OkStatus();
1385 }
1386 
1387 template <typename T>
TensorsAreEqualImpl(const Tensor & x,const Tensor & y)1388 Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
1389   auto Tx = x.flat<T>();
1390   auto Ty = y.flat<T>();
1391   for (int i = 0; i < Tx.size(); ++i) {
1392     if (Tx(i) != Ty(i)) {
1393       return errors::InvalidArgument(absl::StrCat(
1394           i, "-th tensor element isn't equal: ", Str(Tx(i)), " vs. ",
1395           Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString()));
1396     }
1397   }
1398   return OkStatus();
1399 }
1400 
TensorsAreEqualImplBfloat16(const Tensor & x,const Tensor & y)1401 Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) {
1402   auto Tx = x.flat<bfloat16>();
1403   auto Ty = y.flat<bfloat16>();
1404   for (int i = 0; i < Tx.size(); ++i) {
1405     if (Tx(i) != Ty(i)) {
1406       return errors::InvalidArgument(absl::StrCat(
1407           i, "-th tensor element isn't equal: ", static_cast<float>(Tx(i)),
1408           " vs. ", static_cast<float>(Ty(i)), ". x = ", x.DebugString(),
1409           "y = ", y.DebugString()));
1410     }
1411   }
1412   return OkStatus();
1413 }
1414 
1415 // Tests if "x" and "y" are tensors of the same type, same shape, and with
1416 // close values. For floating-point tensors, the element-wise difference between
1417 // x and y must no more than atol + rtol * abs(x). For non-floating-point
1418 // tensors the values must match exactly.
TensorsAreClose(const Tensor & a,const Tensor & b,double atol,double rtol)1419 Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
1420                        double rtol) {
1421   if (a.dtype() != b.dtype()) {
1422     return errors::InvalidArgument(absl::StrCat(
1423         "Tensors have different types: ", DataTypeString(a.dtype()), " and ",
1424         DataTypeString(b.dtype())));
1425   }
1426   if (!a.IsSameSize(b)) {
1427     return errors::InvalidArgument(
1428         absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(),
1429                      " and ", b.shape().DebugString()));
1430   }
1431 
1432   switch (a.dtype()) {
1433     case DT_FLOAT:
1434       return TensorsAreCloseImpl<float>(a, b, atol, rtol);
1435     case DT_DOUBLE:
1436       return TensorsAreCloseImpl<double>(a, b, atol, rtol);
1437     case DT_COMPLEX64:
1438       return TensorsAreCloseImpl<complex64>(a, b, atol, rtol);
1439     case DT_INT32:
1440       return TensorsAreEqualImpl<int32>(a, b);
1441     case DT_INT64:
1442       return TensorsAreEqualImpl<int64_t>(a, b);
1443     case DT_BOOL:
1444       return TensorsAreEqualImpl<bool>(a, b);
1445     case DT_BFLOAT16:
1446       return TensorsAreEqualImplBfloat16(a, b);
1447     default:
1448       LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype());
1449   }
1450 }
1451 
ExpectTfAndXlaOutputsAreClose(const OpTestBuilder & builder,double atol,double rtol)1452 OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
1453     const OpTestBuilder& builder, double atol, double rtol) {
1454   const std::vector<OpTestBuilder::InputDescription>& inputs = builder.inputs();
1455   std::vector<Tensor> input_tensors;
1456   input_tensors.reserve(inputs.size());
1457   for (const OpTestBuilder::InputDescription& input : inputs) {
1458     if (input.type == DT_INVALID) {
1459       input_tensors.push_back(input.tensor);
1460     } else {
1461       std::vector<int64_t> dims;
1462       if (input.has_dims) {
1463         dims = input.dims;
1464       } else {
1465         dims = RandomDims();
1466       }
1467       if (!TensorSizeIsOk(dims)) {
1468         VLOG(1) << "Input: " << input.type << " "
1469                 << TensorShape(input.dims).DebugString();
1470         VLOG(1) << "Ignoring oversize dims.";
1471         return kInvalid;
1472       }
1473       input_tensors.push_back(
1474           RandomTensor(input.type, input.needs_unique_values, dims));
1475     }
1476     VLOG(1) << "Input: " << input_tensors.back().DebugString();
1477   }
1478 
1479   string reference_device =
1480       LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr);
1481   string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
1482 
1483   DeviceNameUtils::ParsedName parsed_name;
1484   if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) {
1485     LOG(ERROR) << "Could not parse device name: " << *tf_xla_test_device_ptr;
1486     return kFatalError;
1487   }
1488   DeviceType test_device_type(parsed_name.type);
1489   ++num_tests_;
1490 
1491   GraphDef graph;
1492   std::vector<string> expected_inputs, test_inputs;
1493   std::vector<string> expected_fetches, test_fetches;
1494   Status status = builder.BuildGraph(
1495       absl::StrCat("test", num_tests_, "_expected"), reference_device,
1496       /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs,
1497       &expected_fetches);
1498   if (!status.ok()) {
1499     LOG(ERROR) << "Expected graph construction failed: " << status;
1500     return kFatalError;
1501   }
1502 
1503   NodeDef* node_def;
1504   status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"),
1505                               test_device, tf_xla_test_use_jit, &graph,
1506                               &node_def, &test_inputs, &test_fetches);
1507   if (!status.ok()) {
1508     LOG(ERROR) << "Test graph construction failed: " << status;
1509     return kFatalError;
1510   }
1511 
1512   // Check that there's a kernel corresponding to 'node_def' on the device under
1513   // test.
1514   status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr);
1515   if (!status.ok()) {
1516     VLOG(1) << "Skipping test because there is no corresponding registered "
1517             << "kernel on the test device: " << status;
1518     return kInvalid;
1519   }
1520 
1521   status = session_->Extend(graph);
1522   if (!status.ok()) {
1523     LOG(ERROR) << "Session::Extend() failed: " << status;
1524     return kFatalError;
1525   }
1526 
1527   std::vector<std::pair<string, Tensor>> expected_feeds(expected_inputs.size());
1528   std::vector<std::pair<string, Tensor>> test_feeds(test_inputs.size());
1529   CHECK_EQ(input_tensors.size(), expected_inputs.size());
1530   CHECK_EQ(input_tensors.size(), test_inputs.size());
1531 
1532   for (int i = 0; i < input_tensors.size(); ++i) {
1533     expected_feeds[i] = {expected_inputs[i], input_tensors[i]};
1534     test_feeds[i] = {test_inputs[i], input_tensors[i]};
1535   }
1536 
1537   std::vector<Tensor> expected_outputs, test_outputs;
1538   VLOG(1) << "Running expected graph";
1539   Status s =
1540       session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs);
1541   if (!s.ok()) {
1542     VLOG(1) << "Expected graph failed with status: " << s << ". Ignoring test";
1543     return kInvalid;
1544   }
1545   for (const Tensor& expected : expected_outputs) {
1546     VLOG(1) << "Expected: " << expected.DebugString();
1547   }
1548 
1549   VLOG(1) << "Running test graph";
1550   status = session_->Run(test_feeds, test_fetches, {}, &test_outputs);
1551   if (!status.ok()) {
1552     LOG(ERROR) << "Test graph failed: " << status;
1553     return kFatalError;
1554   }
1555 
1556   CHECK_EQ(expected_outputs.size(), test_outputs.size());
1557   for (int j = 0; s.ok() && j < test_outputs.size(); ++j) {
1558     s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol);
1559   }
1560   TF_EXPECT_OK(s);
1561 
1562   return kOk;
1563 }
1564 
// Compares TF and XLA results of _EagerConst on one random tensor whose dtype
// is drawn from kAllXlaTypes.
TEST_F(OpTest, _EagerConst) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("_EagerConst");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1572 
// Compares TF and XLA Abs on a random int32, float or complex64 tensor.
TEST_F(OpTest, Abs) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Abs");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1580 
// Compares TF and XLA Acos on float inputs bounded to the range -1..1, the
// real domain of arccosine.
TEST_F(OpTest, Acos) {
  Repeatedly([this]() {
    Tensor operand = RandomBoundedTensor<float>(DT_FLOAT, -1, 1, false);
    OpTestBuilder builder("Acos");
    builder.Input(operand);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1589 
// Compares TF and XLA Acosh on an unconstrained random float tensor.
TEST_F(OpTest, Acosh) {
  Repeatedly([this]() {
    OpTestBuilder builder("Acosh");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1596 
// Compares TF and XLA Add on two random operands with broadcast-compatible
// shapes.
TEST_F(OpTest, Add) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Add");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1607 
// Compares TF and XLA AddN on between 1 and 5 random operands, all sharing one
// randomly chosen shape.
TEST_F(OpTest, AddN) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const int num_operands =
        std::uniform_int_distribution<int>(1, 5)(generator());
    std::vector<int64_t> operand_shape = RandomDims();

    OpTestBuilder builder("AddN");
    builder.Attr("T", dtype).Attr("N", num_operands);
    for (int remaining = num_operands; remaining > 0; --remaining) {
      builder.RandomInput(dtype, operand_shape);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1624 
// Compares TF and XLA AddV2 over every XLA type on broadcast-compatible
// operands.
TEST_F(OpTest, AddV2) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("AddV2");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1635 
// Compares TF and XLA All (logical-and reduction) on a random bool tensor,
// random reduction axes and a random keep_dims setting.
TEST_F(OpTest, All) {
  Repeatedly([this]() {
    std::vector<int64_t> input_shape = RandomDims();
    Tensor axes = RandomReductionIndices(input_shape.size());
    bool keep = Choose<bool>({false, true});
    OpTestBuilder builder("All");
    builder.RandomInput(DT_BOOL, input_shape);
    builder.Input(axes);
    builder.Attr("keep_dims", keep);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1647 
// Compares TF and XLA Angle on a random complex64 tensor.
TEST_F(OpTest, Angle) {
  Repeatedly([this]() {
    OpTestBuilder builder("Angle");
    builder.RandomInput(DT_COMPLEX64);
    builder.Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1655 
// Compares TF and XLA Any (logical-or reduction) on a random bool tensor,
// random reduction axes and a random keep_dims setting.
TEST_F(OpTest, Any) {
  Repeatedly([this]() {
    std::vector<int64_t> input_shape = RandomDims();
    Tensor axes = RandomReductionIndices(input_shape.size());
    bool keep = Choose<bool>({false, true});
    OpTestBuilder builder("Any");
    builder.RandomInput(DT_BOOL, input_shape);
    builder.Input(axes);
    builder.Attr("keep_dims", keep);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1667 
// Compares TF and XLA ApproximateEqual on two random operands of a randomly
// chosen float or complex64 type.
TEST_F(OpTest, ApproximateEqual) {
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    // The "T" attr must describe the dtype actually fed to the op. It used to
    // be hard-coded to DT_FLOAT, so every DT_COMPLEX64 draw produced a NodeDef
    // whose attr disagreed with its inputs.
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}
1678 
// Compares TF and XLA ArgMax along a random valid axis of a random bool/float
// tensor of rank 1..5 with every dimension of size >= 1.
TEST_F(OpTest, ArgMax) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_BOOL, DT_FLOAT});
    std::vector<int64_t> dims = RandomDims(1, 5, 1);
    int num_dims = dims.size();
    // Valid axes lie in [-num_dims, num_dims). uniform_int_distribution's
    // upper bound is inclusive, so the bound must be num_dims - 1; using
    // num_dims occasionally produced an out-of-range axis.
    int reduce_dim = std::uniform_int_distribution<int32>(
        -num_dims, num_dims - 1)(generator());
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ArgMax")
            .RandomInput(type, dims)
            .Input(test::AsScalar<int32>(reduce_dim))
            .Attr("T", type)
            .Attr("Tidx", DT_INT32)
            .Attr("output_type", DT_INT32));
  });
}
1695 
// Compares TF and XLA ArgMin along a random valid axis of a random bool/float
// tensor of rank 1..5 with every dimension of size >= 1.
TEST_F(OpTest, ArgMin) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_BOOL, DT_FLOAT});
    std::vector<int64_t> dims = RandomDims(1, 5, 1);
    int num_dims = dims.size();
    // Valid axes lie in [-num_dims, num_dims). uniform_int_distribution's
    // upper bound is inclusive, so the bound must be num_dims - 1; using
    // num_dims occasionally produced an out-of-range axis.
    int reduce_dim = std::uniform_int_distribution<int32>(
        -num_dims, num_dims - 1)(generator());
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ArgMin")
            .RandomInput(type, dims)
            .Input(test::AsScalar<int32>(reduce_dim))
            .Attr("T", type)
            .Attr("Tidx", DT_INT32)
            .Attr("output_type", DT_INT32));
  });
}
1712 
// Compares TF and XLA Asin on float inputs bounded to the range -1..1, the
// real domain of arcsine.
TEST_F(OpTest, Asin) {
  Repeatedly([this]() {
    Tensor operand = RandomBoundedTensor<float>(DT_FLOAT, -1, 1, false);
    OpTestBuilder builder("Asin");
    builder.Input(operand);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1721 
// Compares TF and XLA Asinh on a random float tensor.
TEST_F(OpTest, Asinh) {
  Repeatedly([this]() {
    OpTestBuilder builder("Asinh");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1728 
// Compares TF and XLA Atanh on an unconstrained random float tensor.
TEST_F(OpTest, Atanh) {
  Repeatedly([this]() {
    OpTestBuilder builder("Atanh");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1735 
// Compares TF and XLA Atan on a random float tensor.
TEST_F(OpTest, Atan) {
  Repeatedly([this]() {
    OpTestBuilder builder("Atan");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1742 
// Compares TF and XLA Atan2 on two float operands with broadcast-compatible
// shapes.
TEST_F(OpTest, Atan2) {
  Repeatedly([this]() {
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Atan2");
    builder.RandomInput(DT_FLOAT, shapes.first);
    builder.RandomInput(DT_FLOAT, shapes.second);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1752 
// Compares TF and XLA 2-D AvgPool on a random NHWC float input with a random
// window (never larger than the spatial extent), stride in 1..5, and padding.
TEST_F(OpTest, AvgPool) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    std::vector<int64_t> input_shape = RandomDims(4, 4, 1);
    int window_h =
        std::uniform_int_distribution<int>(1, input_shape[1])(generator());
    int window_w =
        std::uniform_int_distribution<int>(1, input_shape[2])(generator());
    int stride_h = stride_dist(generator());
    int stride_w = stride_dist(generator());
    string padding = Choose<string>({"SAME", "VALID"});

    OpTestBuilder builder("AvgPool");
    builder.RandomInput(DT_FLOAT, input_shape)
        .Attr("T", DT_FLOAT)
        .Attr("ksize", {1, window_h, window_w, 1})
        .Attr("strides", {1, stride_h, stride_w, 1})
        .Attr("padding", padding)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): the CPU device only implements spatial pooling. Add tests
  // for batch pooling when supported.
}
1776 
// Compares TF and XLA AvgPool3D on a random NDHWC float input. The two skips
// below together disable this test in both MLIR and non-MLIR modes.
TEST_F(OpTest, AvgPool3D) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    // Entries 0..2 are spatial extents, 3 is batch, 4 is feature depth.
    std::vector<int64_t> extents = RandomDims(5, 5, 1);

    std::vector<int64_t> spatial, window, strides;
    for (int axis = 0; axis < 3; ++axis) {
      // Draw order (window then stride per axis) is kept for seed stability.
      window.push_back(
          std::uniform_int_distribution<int>(1, extents[axis])(generator()));
      spatial.push_back(extents[axis]);
      strides.push_back(stride_dist(generator()));
    }
    const int64_t batch = extents[3];
    const int64_t depth = extents[4];
    string padding = Choose<string>({"SAME", "VALID"});

    OpTestBuilder builder("AvgPool3D");
    builder.RandomInput(DT_FLOAT, ImageDims(FORMAT_NHWC, batch, depth, spatial))
        .Attr("T", DT_FLOAT)
        .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, window))
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, strides))
        .Attr("padding", padding)
        .Attr("data_format", "NDHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): test NCHW format (not supported by CPU)
}
1807 
// Compares TF and XLA AvgPoolGrad for a random 2-D NHWC pooling configuration.
// Both skips together disable the test in MLIR and non-MLIR modes.
TEST_F(OpTest, AvgPoolGrad) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    int batch = RandomDim(1);
    int depth = RandomDim(1);
    WindowedSpatialDims window = ChooseWindowedSpatialDims(2);
    std::vector<int32> forward_input_shape =
        AsInt32s(ImageDims(FORMAT_NHWC, batch, depth, window.input_dims));
    std::vector<int64_t> grad_shape =
        ImageDims(FORMAT_NHWC, batch, depth, window.output_dims);
    const char* padding_name = window.padding == SAME ? "SAME" : "VALID";

    OpTestBuilder builder("AvgPoolGrad");
    builder.Input(test::AsTensor<int32>(forward_input_shape))
        .RandomInput(DT_FLOAT, grad_shape)
        .Attr("T", DT_FLOAT)
        .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, window.kernel_dims))
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, window.stride_dims))
        .Attr("padding", padding_name)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1829 
// Compares TF and XLA AvgPool3DGrad for a random 3-D NDHWC pooling
// configuration. Both skips together disable the test in MLIR and non-MLIR
// modes.
TEST_F(OpTest, AvgPool3DGrad) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    int batch = RandomDim(1);
    int depth = RandomDim(1);
    WindowedSpatialDims window = ChooseWindowedSpatialDims(3);
    std::vector<int32> forward_input_shape =
        AsInt32s(ImageDims(FORMAT_NHWC, batch, depth, window.input_dims));
    std::vector<int64_t> grad_shape =
        ImageDims(FORMAT_NHWC, batch, depth, window.output_dims);
    const char* padding_name = window.padding == SAME ? "SAME" : "VALID";

    OpTestBuilder builder("AvgPool3DGrad");
    builder.Input(test::AsTensor<int32>(forward_input_shape))
        .RandomInput(DT_FLOAT, grad_shape)
        .Attr("T", DT_FLOAT)
        .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, window.kernel_dims))
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, window.stride_dims))
        .Attr("padding", padding_name)
        .Attr("data_format", "NDHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1851 
// Compares TF and XLA BatchMatMul (no broadcasting of batch dims).
TEST_F(OpTest, BatchMatMul) {
  // See note about failing Kokoro tests: b/214080339#comment22
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    const BatchMatMulArguments args = ChooseBatchMatMulArguments(false);
    OpTestBuilder builder("BatchMatMul");
    builder.RandomInput(args.dtype, args.lhs_dims)
        .RandomInput(args.dtype, args.rhs_dims)
        .Attr("T", args.dtype)
        .Attr("adj_x", args.adj_lhs)
        .Attr("adj_y", args.adj_rhs);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1866 
// Compares TF and XLA BatchMatMulV2, which is fed arguments chosen with
// ChooseBatchMatMulArguments(true).
TEST_F(OpTest, BatchMatMulV2) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  // :randomized_tests_seeded is flaky with --tf_xla_random_seed=200839030
  // See b/229622638.
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    const BatchMatMulArguments args = ChooseBatchMatMulArguments(true);
    OpTestBuilder builder("BatchMatMulV2");
    builder.RandomInput(args.dtype, args.lhs_dims)
        .RandomInput(args.dtype, args.rhs_dims)
        .Attr("T", args.dtype)
        .Attr("adj_x", args.adj_lhs)
        .Attr("adj_y", args.adj_rhs);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1882 
// Compares TF and XLA BatchToSpace with two spatial dims and a square block.
// The input is laid out as [batch * block^2, spatial0, spatial1, depth].
TEST_F(OpTest, BatchToSpace) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    const int kSpatialRank = 2;
    std::vector<int64_t> spatial = RandomDims(kSpatialRank, kSpatialRank, 0, 5);
    int64_t block = RandomDim(2, 5);

    std::vector<int64_t> input_shape;
    input_shape.reserve(kSpatialRank + 2);
    int64_t batch = RandomDim();
    for (int axis = 0; axis < kSpatialRank; ++axis) batch *= block;
    input_shape.push_back(batch);
    input_shape.insert(input_shape.end(), spatial.begin(), spatial.end());
    input_shape.push_back(RandomDim());

    // Crop amounts are drawn blindly, so some draws are illegal on purpose.
    std::uniform_int_distribution<int> crop_dist(0, 4);
    std::vector<int64_t> crop_amounts;
    for (int k = 0; k < 2 * kSpatialRank; ++k) {
      crop_amounts.push_back(crop_dist(generator()));
    }
    Tensor crops;
    CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_amounts),
                         TensorShape({kSpatialRank, 2})));

    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("BatchToSpace");
    builder.RandomInput(dtype, input_shape)
        .Input(crops)
        .Attr("T", dtype)
        .Attr("block_size", block);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1919 
// Compares TF and XLA BatchToSpaceND with 1..3 block dims and 0..3 trailing
// dims. The input is laid out as
// [batch * prod(block_dims), block_multipliers..., remaining_dims...].
TEST_F(OpTest, BatchToSpaceND) {
  Repeatedly([this]() {
    // BatchToSpaceND requires every block_shape entry to be >= 1; drawing 0
    // here (min size was previously 0) could only ever yield invalid inputs.
    std::vector<int64_t> block_dims = RandomDims(1, 3, 1, 5);
    int num_block_dims = block_dims.size();
    std::vector<int64_t> remaining_dims = RandomDims(0, 3);
    std::vector<int64_t> block_multipliers =
        RandomDims(block_dims.size(), block_dims.size(), 0, 4);

    std::vector<int64_t> input_dims(1 + num_block_dims + remaining_dims.size());
    input_dims[0] = RandomDim();
    for (int i = 0; i < num_block_dims; ++i) {
      input_dims[0] *= block_dims[i];
    }
    std::copy(block_multipliers.begin(), block_multipliers.end(),
              input_dims.begin() + 1);
    std::copy(remaining_dims.begin(), remaining_dims.end(),
              input_dims.begin() + 1 + num_block_dims);

    std::vector<int64_t> crop_vals;
    std::uniform_int_distribution<int> distribution(0, 3);
    for (int i = 0; i < num_block_dims; ++i) {
      // Chooses crop values; does not always choose legal values.
      crop_vals.push_back(distribution(generator()));
      crop_vals.push_back(distribution(generator()));
    }
    Tensor crops;
    CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
                         TensorShape({num_block_dims, 2})));

    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("BatchToSpaceND")
            .RandomInput(type, input_dims)
            .Input(test::AsTensor<int32>(
                std::vector<int32>(block_dims.begin(), block_dims.end())))
            .Input(crops)
            .Attr("T", type));
  });
}
1959 
// Compares TF and XLA BiasAdd on a rank >= 2 value and a rank-1 bias whose
// length matches the value's last dimension.
TEST_F(OpTest, BiasAdd) {
  Repeatedly([this]() {
    std::vector<int64_t> value_shape = RandomDims(2, kDefaultMaxRank);
    std::vector<int64_t> bias_shape = {value_shape.back()};
    // TODO(phawkins): test both data formats.
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("BiasAdd");
    builder.RandomInput(dtype, value_shape)
        .RandomInput(dtype, bias_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1972 
// Compares TF and XLA BiasAddGrad on a random gradient tensor.
TEST_F(OpTest, BiasAddGrad) {
  Repeatedly([this]() {
    // TODO(phawkins): test both data formats.
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    // BiasAddGrad only accepts inputs of rank >= 2. Constrain the shape the
    // same way the BiasAdd/BiasAddV1 tests do, instead of letting the builder
    // pick an arbitrary rank that is rejected most of the time.
    std::vector<int64_t> dims = RandomDims(2, kDefaultMaxRank);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("BiasAddGrad").RandomInput(type, dims).Attr("T", type));
  });
}
1981 
// Compares TF and XLA BiasAddV1 on a rank >= 2 value and a rank-1 bias whose
// length matches the value's last dimension.
TEST_F(OpTest, BiasAddV1) {
  Repeatedly([this]() {
    std::vector<int64_t> value_shape = RandomDims(2, kDefaultMaxRank);
    std::vector<int64_t> bias_shape = {value_shape.back()};
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("BiasAddV1");
    builder.RandomInput(dtype, value_shape)
        .RandomInput(dtype, bias_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1993 
// Compares TF and XLA Bitcast between two randomly chosen number types.
// Incompatible size combinations are left for the harness to reject.
TEST_F(OpTest, Bitcast) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    const DataType from_type = Choose<DataType>(kAllNumberTypes);
    const DataType to_type = Choose<DataType>(kAllNumberTypes);
    OpTestBuilder builder("Bitcast");
    builder.RandomInput(from_type).Attr("T", from_type).Attr("type", to_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2005 
// Compares TF and XLA BitwiseAnd on broadcast-compatible int32 operands.
TEST_F(OpTest, BitwiseAnd) {
  Repeatedly([this]() {
    const DataType dtype = DT_INT32;
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("BitwiseAnd");
    builder.RandomInput(dtype, shapes.first)
        .RandomInput(dtype, shapes.second)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2016 
// Compares TF and XLA BitwiseOr on broadcast-compatible int32 operands.
TEST_F(OpTest, BitwiseOr) {
  Repeatedly([this]() {
    const DataType dtype = DT_INT32;
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("BitwiseOr");
    builder.RandomInput(dtype, shapes.first)
        .RandomInput(dtype, shapes.second)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2027 
// Compares TF and XLA BitwiseXor on broadcast-compatible int32 operands.
TEST_F(OpTest, BitwiseXor) {
  Repeatedly([this]() {
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("BitwiseXor");
    builder.RandomInput(DT_INT32, shapes.first)
        .RandomInput(DT_INT32, shapes.second)
        .Attr("T", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2037 
// Compares TF and XLA BroadcastArgs. The two operands are broadcast-compatible
// shapes passed as int32 shape vectors.
TEST_F(OpTest, BroadcastArgs) {
  Repeatedly([this]() {
    // TODO(phawkins): only int32 seems to be implemented in Tensorflow.
    // auto type = Choose<DataType>({DT_INT32, DT_INT64});
    const DataType index_type = DT_INT32;
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("BroadcastArgs");
    builder.Input(AsIntTensor(index_type, shapes.first))
        .Input(AsIntTensor(index_type, shapes.second))
        .Attr("T", index_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2051 
// Compares TF and XLA BroadcastGradientArgs. The two operands are
// broadcast-compatible shapes passed as int32 shape vectors.
TEST_F(OpTest, BroadcastGradientArgs) {
  Repeatedly([this]() {
    // TODO(phawkins): only int32 seems to be implemented in Tensorflow.
    // auto type = Choose<DataType>({DT_INT32, DT_INT64});
    const DataType index_type = DT_INT32;
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("BroadcastGradientArgs");
    builder.Input(AsIntTensor(index_type, shapes.first))
        .Input(AsIntTensor(index_type, shapes.second))
        .Attr("T", index_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2065 
// Compares TF and XLA BroadcastTo. A random tensor is broadcast to a random
// target shape, given as an int32 or int64 shape vector.
TEST_F(OpTest, BroadcastTo) {
  Repeatedly([this]() {
    const DataType value_type = Choose<DataType>(kAllXlaTypes);
    const DataType shape_type = Choose<DataType>({DT_INT32, DT_INT64});
    auto target_shape = RandomDims();
    auto source_shape = BroadcastableToDims(target_shape);
    OpTestBuilder builder("BroadcastTo");
    builder.RandomInput(value_type, source_shape)
        .Input(AsIntTensor(shape_type, target_shape))
        .Attr("T", value_type)
        .Attr("Tidx", shape_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2080 
// Compares TF and XLA Cast between every pair drawn from
// {int32, float, bool, complex64}.
TEST_F(OpTest, Cast) {
  Repeatedly([this]() {
    const DataType from_type =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
    const DataType to_type =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
    OpTestBuilder builder("Cast");
    builder.RandomInput(from_type)
        .Attr("SrcT", from_type)
        .Attr("DstT", to_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2092 
// Compares TF and XLA float -> bfloat16 Cast with Truncate=true.
TEST_F(OpTest, CastBF16) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    // The one-candidate Choose<> calls are kept rather than replaced with
    // constants, so the random stream consumed per repetition stays the same
    // (presumably Choose draws from generator() even for one candidate).
    const DataType from_type = Choose<DataType>({DT_FLOAT});
    const DataType to_type = Choose<DataType>({DT_BFLOAT16});
    OpTestBuilder builder("Cast");
    builder.RandomInput(from_type)
        .Attr("SrcT", from_type)
        .Attr("DstT", to_type)
        .Attr("Truncate", true);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2106 
// Compares TF and XLA Ceil on a random float tensor.
TEST_F(OpTest, Ceil) {
  Repeatedly([this]() {
    OpTestBuilder builder("Ceil");
    builder.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2113 
// Compares TF and XLA ClipByValue. The min and max bounds broadcast to the
// input shape, and min <= max holds element-wise.
TEST_F(OpTest, ClipByValue) {
  // TODO(b/211012085): Change input_dims to BroadcastableDimsN(3). The
  //                    compiled ClipByValue fails in this case.
  //                    --tf_xla_random_seed=200839030
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_INT64, DT_FLOAT});
    // Broadcasting min/max must not grow the result beyond the input shape.
    auto value_shape = RandomDims();
    // Implementations differ when max < min, so the bounds are drawn ordered.
    auto bound_shape = BroadcastableToDims(value_shape);
    auto bounds = RandomLteTensors(dtype, bound_shape);
    OpTestBuilder builder("ClipByValue");
    builder.RandomInput(dtype, value_shape)
        .Input(bounds.first)
        .Input(bounds.second)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2135 
// Compares TF and XLA Complex on broadcast-compatible float real/imag parts.
TEST_F(OpTest, Complex) {
  Repeatedly([this]() {
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Complex");
    builder.RandomInput(DT_FLOAT, shapes.first)
        .RandomInput(DT_FLOAT, shapes.second)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2145 
// Compares TF and XLA Concat (axis first, then the N values).
TEST_F(OpTest, Concat) {
  Repeatedly([this]() {
    ConcatArguments args = ChooseConcatArguments(false);
    OpTestBuilder builder("Concat");
    builder.Input(args.axis)
        .VariadicInput(args.values)
        .Attr("N", args.n)
        .Attr("T", args.type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2156 
// Compares TF and XLA ConcatV2 (the N values first, then the axis).
TEST_F(OpTest, ConcatV2) {
  Repeatedly([this]() {
    ConcatArguments args = ChooseConcatArguments(true);
    OpTestBuilder builder("ConcatV2");
    builder.VariadicInput(args.values)
        .Input(args.axis)
        .Attr("N", args.n)
        .Attr("T", args.type)
        .Attr("Tidx", args.type_idx);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2168 
// Compares TF and XLA ConcatOffset for 2..5 int32 shape vectors. The vectors
// agree everywhere except along the concat axis.
TEST_F(OpTest, ConcatOffset) {
  Repeatedly([this]() {
    const int num_shapes = std::uniform_int_distribution<int>(2, 5)(generator());

    std::vector<int64_t> base_shape = RandomDims(1);
    int axis = std::uniform_int_distribution<int32>(
        0, base_shape.size() - 1)(generator());

    OpTestBuilder builder("ConcatOffset");
    builder.Input(test::AsScalar<int32>(axis));
    builder.Attr("N", num_shapes);
    // Only the axis entry changes between shapes, so one buffer is reused.
    std::vector<int32> shape(base_shape.begin(), base_shape.end());
    for (int k = 0; k < num_shapes; ++k) {
      shape[axis] = RandomDim();
      builder.Input(test::AsTensor<int32>(shape));
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2188 
// Compares TF and XLA Conj on a random complex64 tensor.
TEST_F(OpTest, Conj) {
  Repeatedly([this]() {
    OpTestBuilder builder("Conj");
    builder.RandomInput(DT_COMPLEX64).Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2196 
// Compares TF and XLA Const whose "value" attr is a random float tensor.
TEST_F(OpTest, Const) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT});
    Tensor value = RandomTensor(dtype);
    OpTestBuilder builder("Const");
    builder.Attr("value", value).Attr("dtype", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2205 
// Compares TF and XLA FFT on a random complex64 tensor of rank >= 1.
TEST_F(OpTest, FFT) {
  Repeatedly([this]() {
    std::vector<int64_t> signal_shape = RandomDims(1, kDefaultMaxRank);
    OpTestBuilder builder("FFT");
    builder.RandomInput(DT_COMPLEX64, signal_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2213 
// Compares TF and XLA FFT2D on a random complex64 tensor of rank >= 2.
TEST_F(OpTest, FFT2D) {
  Repeatedly([this]() {
    std::vector<int64_t> signal_shape = RandomDims(2, kDefaultMaxRank);
    OpTestBuilder builder("FFT2D");
    builder.RandomInput(DT_COMPLEX64, signal_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2221 
// Compares TF and XLA FFT3D on a random complex64 tensor of rank >= 3.
TEST_F(OpTest, FFT3D) {
  Repeatedly([this]() {
    std::vector<int64_t> signal_shape = RandomDims(3, kDefaultMaxRank);
    OpTestBuilder builder("FFT3D");
    builder.RandomInput(DT_COMPLEX64, signal_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2229 
// Compares TF and XLA IFFT on a random complex64 tensor of rank >= 1.
TEST_F(OpTest, IFFT) {
  Repeatedly([this]() {
    std::vector<int64_t> signal_shape = RandomDims(1, kDefaultMaxRank);
    OpTestBuilder builder("IFFT");
    builder.RandomInput(DT_COMPLEX64, signal_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2237 
// Compares TF and XLA IFFT2D on a random complex64 tensor of rank >= 2.
TEST_F(OpTest, IFFT2D) {
  Repeatedly([this]() {
    std::vector<int64_t> signal_shape = RandomDims(2, kDefaultMaxRank);
    OpTestBuilder builder("IFFT2D");
    builder.RandomInput(DT_COMPLEX64, signal_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2245 
// Compares TF and XLA IFFT3D on a random complex64 tensor of rank >= 3.
TEST_F(OpTest, IFFT3D) {
  Repeatedly([this]() {
    std::vector<int64_t> signal_shape = RandomDims(3, kDefaultMaxRank);
    OpTestBuilder builder("IFFT3D");
    builder.RandomInput(DT_COMPLEX64, signal_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2253 
// Compares TF and XLA for RFFT on a random float tensor (rank >= 1, every
// dimension at least 3). The fft_length input is the innermost dimension.
TEST_F(OpTest, RFFT) {
  Repeatedly([this]() {
    const std::vector<int64_t> input_dims = RandomDims(1, kDefaultMaxRank, 3);
    const int64_t fft_length = input_dims.back();
    Tensor fft_shape = test::AsTensor<int32>(AsInt32s({fft_length}));
    OpTestBuilder builder("RFFT");
    builder.RandomInput(DT_FLOAT, input_dims);
    builder.Input(fft_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2262 
// Compares TF and XLA for RFFT2D on a random float tensor (rank >= 2, every
// dimension at least 3). The fft_length input is the two innermost dims.
TEST_F(OpTest, RFFT2D) {
  Repeatedly([this]() {
    const std::vector<int64_t> input_dims = RandomDims(2, kDefaultMaxRank, 3);
    const std::vector<int64_t> fft_lengths(input_dims.end() - 2,
                                           input_dims.end());
    Tensor fft_shape = test::AsTensor<int32>(AsInt32s(fft_lengths));
    OpTestBuilder builder("RFFT2D");
    builder.RandomInput(DT_FLOAT, input_dims);
    builder.Input(fft_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2272 
// Compares TF and XLA for RFFT3D on a random float tensor (rank >= 3, every
// dimension at least 3). The fft_length input is the three innermost dims.
TEST_F(OpTest, RFFT3D) {
  Repeatedly([this]() {
    const std::vector<int64_t> input_dims = RandomDims(3, kDefaultMaxRank, 3);
    const std::vector<int64_t> fft_lengths(input_dims.end() - 3,
                                           input_dims.end());
    Tensor fft_shape = test::AsTensor<int32>(AsInt32s(fft_lengths));
    OpTestBuilder builder("RFFT3D");
    builder.RandomInput(DT_FLOAT, input_dims);
    builder.Input(fft_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2282 
// Compares TF and XLA for IRFFT. The innermost dimension n of a random shape
// is passed as fft_length, while the complex64 input's innermost dimension is
// shrunk to n / 2 + 1.
TEST_F(OpTest, IRFFT) {
  Repeatedly([this]() {
    std::vector<int64_t> input_dims = RandomDims(1, kDefaultMaxRank, 3);
    int64_t& innermost = input_dims.back();
    const int64_t fft_length = innermost;
    innermost = innermost / 2 + 1;
    Tensor fft_shape = test::AsTensor<int32>(AsInt32s({fft_length}));
    OpTestBuilder builder("IRFFT");
    builder.RandomInput(DT_COMPLEX64, input_dims);
    builder.Input(fft_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2294 
// Compares TF and XLA for IRFFT2D. The two innermost dimensions of a random
// shape become fft_length; the complex64 input's innermost dimension is then
// shrunk to n / 2 + 1.
TEST_F(OpTest, IRFFT2D) {
  Repeatedly([this]() {
    std::vector<int64_t> input_dims = RandomDims(2, kDefaultMaxRank, 3);
    const std::vector<int64_t> fft_lengths(input_dims.end() - 2,
                                           input_dims.end());
    input_dims.back() = input_dims.back() / 2 + 1;
    Tensor fft_shape = test::AsTensor<int32>(AsInt32s(fft_lengths));
    OpTestBuilder builder("IRFFT2D");
    builder.RandomInput(DT_COMPLEX64, input_dims);
    builder.Input(fft_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2307 
// Compares TF and XLA for IRFFT3D. The three innermost dimensions of a random
// shape become fft_length; the complex64 input's innermost dimension is then
// shrunk to n / 2 + 1.
TEST_F(OpTest, IRFFT3D) {
  Repeatedly([this]() {
    std::vector<int64_t> input_dims = RandomDims(3, kDefaultMaxRank, 3);
    const std::vector<int64_t> fft_lengths(input_dims.end() - 3,
                                           input_dims.end());
    input_dims.back() = input_dims.back() / 2 + 1;
    Tensor fft_shape = test::AsTensor<int32>(AsInt32s(fft_lengths));
    OpTestBuilder builder("IRFFT3D");
    builder.RandomInput(DT_COMPLEX64, input_dims);
    builder.Input(fft_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2320 
// Compares TF and XLA for Conv2D on random NHWC input and
// [kh, kw, in, out] filter.
TEST_F(OpTest, Conv2D) {
  // The two guards below cover both values of tf_xla_test_use_mlir, so this
  // test is currently skipped unconditionally (see the referenced bugs).
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    WindowedSpatialDims window = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int out_channels = channel_dist(generator());
    const int64_t batch = RandomDim();

    std::vector<int64_t> input_shape =
        ImageDims(FORMAT_NHWC, batch, in_channels, window.input_dims);
    std::vector<int64_t> filter_shape = {window.kernel_dims[0],
                                         window.kernel_dims[1], in_channels,
                                         out_channels};
    std::vector<int64_t> strides =
        ImageDims(FORMAT_NHWC, 1, 1, window.stride_dims);
    const char* padding = window.padding == SAME ? "SAME" : "VALID";
    const DataType dtype = DT_FLOAT;

    OpTestBuilder builder("Conv2D");
    builder.RandomInput(dtype, input_shape);
    builder.RandomInput(dtype, filter_shape);
    builder.Attr("T", dtype);
    builder.Attr("strides", strides);
    builder.Attr("padding", padding);
    builder.Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2348 
// Compares TF and XLA for Conv2DBackpropFilter on random NHWC activations and
// output gradients. The filter shape [kh, kw, in, out] is fed as an int32
// shape tensor.
TEST_F(OpTest, Conv2DBackpropFilter) {
  Repeatedly([this]() {
    WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> random_int(1, 5);
    int features_in = random_int(generator());
    int features_out = random_int(generator());
    // Store the batch size as int64_t, like Conv2D and the int64_t shape
    // vectors built from it, rather than as int32_t.
    int64_t batch = RandomDim();
    std::vector<int64_t> activations =
        ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
    std::vector<int64_t> backprop =
        ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
    Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
        {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}));
    DataType type = DT_FLOAT;
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Conv2DBackpropFilter")
            .RandomInput(type, activations)
            .Input(kernel_shape)
            .RandomInput(type, backprop)
            .Attr("T", type)
            .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
            .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
            .Attr("data_format", "NHWC"));
  });
}
2374 
// Compares TF and XLA for Conv2DBackpropInput. The NHWC input shape is fed as
// an int32 shape tensor, followed by a random filter and output gradient.
TEST_F(OpTest, Conv2DBackpropInput) {
  Repeatedly([this]() {
    WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> random_int(1, 5);
    int features_in = random_int(generator());
    int features_out = random_int(generator());
    // Store the batch size as int64_t, like Conv2D and the int64_t shape
    // vectors built from it, rather than as int32_t.
    int64_t batch = RandomDim();
    Tensor in_shape = test::AsTensor<int32>(
        AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
    std::vector<int64_t> backprop =
        ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
    std::vector<int64_t> kernel = {d.kernel_dims[0], d.kernel_dims[1],
                                   features_in, features_out};
    DataType type = DT_FLOAT;
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Conv2DBackpropInput")
            .Input(in_shape)
            .RandomInput(type, kernel)
            .RandomInput(type, backprop)
            .Attr("T", type)
            .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
            .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
            .Attr("data_format", "NHWC"));
  });
}
2400 
// Compares TF and XLA for Conv3D on a random [batch, d, h, w, in] input and a
// [kd, kh, kw, in, out] filter.
TEST_F(OpTest, Conv3D) {
  // The two guards below cover both values of tf_xla_test_use_mlir, so this
  // test is currently skipped unconditionally (see the referenced bugs).
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    WindowedSpatialDims window = ChooseWindowedSpatialDims(3);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int out_channels = channel_dist(generator());
    const int64_t batch = RandomDim();

    std::vector<int64_t> input_shape = {batch, window.input_dims[0],
                                        window.input_dims[1],
                                        window.input_dims[2], in_channels};
    std::vector<int64_t> filter_shape = {
        window.kernel_dims[0], window.kernel_dims[1], window.kernel_dims[2],
        in_channels, out_channels};
    std::vector<int64_t> strides =
        ImageDims(FORMAT_NHWC, 1, 1, window.stride_dims);
    const char* padding = window.padding == SAME ? "SAME" : "VALID";
    const DataType dtype = DT_FLOAT;

    OpTestBuilder builder("Conv3D");
    builder.RandomInput(dtype, input_shape);
    builder.RandomInput(dtype, filter_shape);
    builder.Attr("T", dtype);
    builder.Attr("strides", strides);
    builder.Attr("padding", padding);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2424 
// Compares TF and XLA for Conv3DBackpropFilterV2 on random activations and
// output gradients. The filter shape [kd, kh, kw, in, out] is fed as an int32
// shape tensor.
TEST_F(OpTest, Conv3DBackpropFilter) {
  // The two guards below cover both values of tf_xla_test_use_mlir, so this
  // test is currently skipped unconditionally (see the referenced bugs).
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    WindowedSpatialDims d = ChooseWindowedSpatialDims(3);
    std::uniform_int_distribution<int> random_int(1, 5);
    int features_in = random_int(generator());
    int features_out = random_int(generator());
    // Store the batch size as int64_t, matching the int64_t shape vectors
    // built from it, rather than as int32_t.
    int64_t batch = RandomDim(1);
    std::vector<int64_t> activations =
        ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
    std::vector<int64_t> backprop =
        ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
    Tensor kernel_shape = test::AsTensor<int32>(
        AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2],
                  features_in, features_out}));
    DataType type = DT_FLOAT;
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Conv3DBackpropFilterV2")
            .RandomInput(type, activations)
            .Input(kernel_shape)
            .RandomInput(type, backprop)
            .Attr("T", type)
            .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
            .Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
  });
}
2452 
// Compares TF and XLA for Conv3DBackpropInputV2. The input shape is fed as an
// int32 shape tensor, followed by a random filter and output gradient.
TEST_F(OpTest, Conv3DBackpropInput) {
  // The two guards below cover both values of tf_xla_test_use_mlir, so this
  // test is currently skipped unconditionally (see the referenced bugs).
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    WindowedSpatialDims d = ChooseWindowedSpatialDims(3);
    std::uniform_int_distribution<int> random_int(1, 5);
    int features_in = random_int(generator());
    int features_out = random_int(generator());
    // Store the batch size as int64_t, matching the int64_t shape vectors
    // built from it, rather than as int32_t.
    int64_t batch = RandomDim(1);
    Tensor in_shape = test::AsTensor<int32>(
        AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
    std::vector<int64_t> backprop =
        ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
    std::vector<int64_t> kernel = {d.kernel_dims[0], d.kernel_dims[1],
                                   d.kernel_dims[2], features_in, features_out};
    // The Conv3DBackpropInputV2 op definition only admits real floating-point
    // T, so complex64 cannot form a valid node. Use float, like Conv3D and
    // Conv3DBackpropFilter.
    DataType type = DT_FLOAT;
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Conv3DBackpropInputV2")
            .Input(in_shape)
            .RandomInput(type, kernel)
            .RandomInput(type, backprop)
            .Attr("T", type)
            .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
            .Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
  });
}
2479 
// Compares TF and XLA for ComplexAbs on a random complex64 tensor, requesting
// a float output.
TEST_F(OpTest, ComplexAbs) {
  Repeatedly([this]() {
    const DataType input_dtype = DT_COMPLEX64;
    const DataType output_dtype = DT_FLOAT;
    OpTestBuilder builder("ComplexAbs");
    builder.RandomInput(input_dtype);
    builder.Attr("T", input_dtype);
    builder.Attr("Tout", output_dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2490 
// Compares TF and XLA for Cos on a random float or complex64 tensor.
TEST_F(OpTest, Cos) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Cos");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2498 
// Compares TF and XLA for Cosh on a random float or complex64 tensor.
TEST_F(OpTest, Cosh) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Cosh");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2506 
// Compares TF and XLA for DepthToSpace on a random rank-4 tensor.
TEST_F(OpTest, DepthToSpace) {
  Repeatedly([this]() {
    const int64_t block_size = RandomDim(2, 5);
    std::vector<int64_t> shape = RandomDims(4, 4);
    // Ceil-divide dims 1 and 2 by block_size, and scale dim 3 so that it is
    // a multiple of block_size^2.
    for (const int axis : {1, 2}) {
      shape[axis] = (shape[axis] + block_size - 1) / block_size;
    }
    shape[3] *= block_size * block_size;
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("DepthToSpace");
    builder.RandomInput(dtype, shape);
    builder.Attr("T", dtype);
    builder.Attr("block_size", block_size);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2521 
// Compares TF and XLA for DepthwiseConv2dNative on a random
// [batch, h, w, in] input and [kh, kw, in, multiplier] filter.
TEST_F(OpTest, DepthwiseConv2DNative) {
  // The two guards below cover both values of tf_xla_test_use_mlir, so this
  // test is currently skipped unconditionally (see the referenced bugs).
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    WindowedSpatialDims window = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int multiplier = channel_dist(generator());
    const int64_t batch = RandomDim();

    std::vector<int64_t> input_shape = {batch, window.input_dims[0],
                                        window.input_dims[1], in_channels};
    std::vector<int64_t> filter_shape = {window.kernel_dims[0],
                                         window.kernel_dims[1], in_channels,
                                         multiplier};
    std::vector<int64_t> strides =
        ImageDims(FORMAT_NHWC, 1, 1, window.stride_dims);
    // Current impl only supports equal strides, so copy dim 1 into dim 2.
    strides[2] = strides[1];

    OpTestBuilder builder("DepthwiseConv2dNative");
    builder.RandomInput(DT_FLOAT, input_shape);
    builder.RandomInput(DT_FLOAT, filter_shape);
    builder.Attr("T", DT_FLOAT);
    builder.Attr("strides", strides);
    builder.Attr("padding", window.padding == SAME ? "SAME" : "VALID");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2546 
// Compares TF and XLA for DepthwiseConv2dNativeBackpropFilter on random NHWC
// activations and output gradients (in * multiplier channels). The filter
// shape [kh, kw, in, multiplier] is fed as an int32 shape tensor.
TEST_F(OpTest, DepthwiseConv2DNativeBackpropFilter) {
  // The two guards below cover both values of tf_xla_test_use_mlir, so this
  // test is currently skipped unconditionally (see the referenced bugs).
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> random_int(1, 5);
    int features_in = random_int(generator());
    int depth_multiplier = random_int(generator());
    // Store the batch size as int64_t, matching the int64_t shape vectors
    // built from it, rather than as int32_t.
    int64_t batch = RandomDim();
    std::vector<int64_t> activations =
        ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
    std::vector<int64_t> backprop = ImageDims(
        FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
    Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
        {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}));
    std::vector<int64_t> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
    strides[2] = strides[1];  // Current impl only supports equal strides
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("DepthwiseConv2dNativeBackpropFilter")
            .RandomInput(DT_FLOAT, activations)
            .Input(kernel_shape)
            .RandomInput(DT_FLOAT, backprop)
            .Attr("T", DT_FLOAT)
            .Attr("strides", strides)
            .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
            .Attr("data_format", "NHWC"));
  });
}
2575 
// Compares TF and XLA for DepthwiseConv2dNativeBackpropInput. The NHWC input
// shape is fed as an int32 shape tensor, followed by a random filter and an
// output gradient with in * multiplier channels.
TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
  // The two guards below cover both values of tf_xla_test_use_mlir, so this
  // test is currently skipped unconditionally (see the referenced bugs).
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> random_int(1, 5);
    int features_in = random_int(generator());
    int depth_multiplier = random_int(generator());
    // Store the batch size as int64_t, matching the int64_t shape vectors
    // built from it, rather than as int32_t.
    int64_t batch = RandomDim();
    Tensor in_shape = test::AsTensor<int32>(
        AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
    std::vector<int64_t> backprop = ImageDims(
        FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
    std::vector<int64_t> kernel = {d.kernel_dims[0], d.kernel_dims[1],
                                   features_in, depth_multiplier};
    std::vector<int64_t> strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims);
    strides[2] = strides[1];  // Current impl only supports equal strides
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("DepthwiseConv2dNativeBackpropInput")
            .Input(in_shape)
            .RandomInput(DT_FLOAT, kernel)
            .RandomInput(DT_FLOAT, backprop)
            .Attr("T", DT_FLOAT)
            .Attr("strides", strides)
            .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
            .Attr("data_format", "NHWC"));
  });
}
2604 
// Compares TF and XLA for Diag on a random tensor of rank >= 1.
TEST_F(OpTest, Diag) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    // Diag causes a quadratic blowup in output size, so redraw the shape
    // until the squared element count fits within tf_xla_max_tensor_size.
    std::vector<int64_t> dims;
    for (;;) {
      dims = RandomDims(1);
      const int64_t num_elements = TensorShape(dims).num_elements();
      if (num_elements * num_elements <= tf_xla_max_tensor_size) break;
    }
    OpTestBuilder builder("Diag");
    builder.RandomInput(dtype, dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2619 
// Compares TF and XLA for DiagPart. The input shape is a random list of 1-3
// dims repeated twice, e.g. [a, b] -> [a, b, a, b].
TEST_F(OpTest, DiagPart) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> half_dims = RandomDims(1, 3);
    std::vector<int64_t> input_dims(half_dims);
    input_dims.insert(input_dims.end(), half_dims.begin(), half_dims.end());
    OpTestBuilder builder("DiagPart");
    builder.RandomInput(dtype, input_dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2633 
// Compares TF and XLA for Digamma on a random float tensor.
TEST_F(OpTest, Digamma) {
  Repeatedly([this]() {
    OpTestBuilder builder("Digamma");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2640 
// Compares TF and XLA for Div on two random tensors with broadcast-compatible
// shapes.
TEST_F(OpTest, Div) {
  Repeatedly([this]() {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    const auto& lhs_dims = shapes.first;
    const auto& rhs_dims = shapes.second;
    OpTestBuilder builder("Div");
    builder.RandomInput(dtype, lhs_dims);
    builder.RandomInput(dtype, rhs_dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2651 
// Compares TF and XLA for DivNoNan on two random float or complex64 tensors
// with broadcast-compatible shapes.
TEST_F(OpTest, DivNoNan) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    const auto& lhs_dims = shapes.first;
    const auto& rhs_dims = shapes.second;
    OpTestBuilder builder("DivNoNan");
    builder.RandomInput(dtype, lhs_dims);
    builder.RandomInput(dtype, rhs_dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2662 
// Compares TF and XLA for DynamicStitch with 2-5 partitions. The index inputs
// together hold a random permutation of [0, total); each data input has its
// index input's shape followed by a shared trailing shape.
TEST_F(OpTest, DynamicStitch) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const int num_partitions =
        std::uniform_int_distribution<int>(2, 5)(generator());
    OpTestBuilder builder("DynamicStitch");
    builder.Attr("T", dtype);
    builder.Attr("N", num_partitions);

    // TODO(phawkins): the XLA implementation of DynamicStitch does not
    // accept an empty set of indices.
    std::vector<std::vector<int64_t>> index_shapes;
    int total = 0;
    while (true) {
      total = 0;
      index_shapes.clear();
      for (int p = 0; p < num_partitions; ++p) {
        index_shapes.push_back(RandomDims(0, 3, 0, 5));
        total += TensorShape(index_shapes.back()).num_elements();
      }
      if (total != 0) break;
    }

    // Shuffle the range of indices that cover the output.
    // TODO(phawkins): The documentation for DynamicStitch doesn't require
    // that the indices cover all positions of the output. The XLA
    // implementation does so require. However, the native TF implementation
    // leaves undefined values if we don't cover everything, so we can't
    // really test that case anyway.
    std::vector<int32> permutation(total);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::shuffle(permutation.begin(), permutation.end(), generator());

    // Slice the permutation into one index tensor per partition.
    int offset = 0;
    for (const std::vector<int64_t>& index_shape : index_shapes) {
      const TensorShape shape(index_shape);
      Tensor index_tensor = test::AsTensor<int32>(
          absl::Span<const int32>(permutation)
              .subspan(offset, shape.num_elements()),
          shape);
      builder.Input(index_tensor);
      offset += index_tensor.NumElements();
    }

    // Each data input is its index shape followed by a shared trailing shape.
    const std::vector<int64_t> trailing_dims = RandomDims(0, 3, 0, 5);
    for (const std::vector<int64_t>& index_shape : index_shapes) {
      std::vector<int64_t> data_dims = index_shape;
      data_dims.insert(data_dims.end(), trailing_dims.begin(),
                       trailing_dims.end());
      builder.RandomInput(dtype, data_dims);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2714 
// Compares TF and XLA for a two-operand Einsum with a randomly chosen
// equation and operand shapes.
TEST_F(OpTest, Einsum) {
  Repeatedly([this]() {
    const EinsumArguments args = ChooseEinsumArguments();
    OpTestBuilder builder("Einsum");
    builder.RandomInput(args.type, args.lhs_dims);
    builder.RandomInput(args.type, args.rhs_dims);
    builder.Attr("equation", args.equation);
    builder.Attr("T", args.type);
    builder.Attr("N", 2);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2726 
// Compares TF and XLA for Empty with a random int32 shape input and init=true
// (presumably so the output contents are deterministic and comparable).
TEST_F(OpTest, Empty) {
  Repeatedly([this]() {
    // Pass kAllXlaTypes directly, as every other test in this file does.
    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Empty")
            .Input(AsIntTensor(DT_INT32, RandomDims()))
            .Attr("init", true)
            .Attr("dtype", type));
  });
}
2737 
// Compares TF and XLA for Elu on a random float tensor.
TEST_F(OpTest, Elu) {
  Repeatedly([this]() {
    OpTestBuilder builder("Elu");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2744 
// Compares TF and XLA for EluGrad on two random float tensors of one shape.
TEST_F(OpTest, EluGrad) {
  Repeatedly([this]() {
    const auto shape = RandomDims();
    OpTestBuilder builder("EluGrad");
    builder.RandomInput(DT_FLOAT, shape);
    builder.RandomInput(DT_FLOAT, shape);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2754 
// Compares TF and XLA for ScatterNd using randomly chosen indices and updates;
// the output shape is fed as an int32 tensor.
TEST_F(OpTest, ScatterNd) {
  Repeatedly([this]() {
    auto args = ChooseScatterArguments();
    std::vector<int32> shape_values(args.shape.begin(), args.shape.end());
    Tensor shape = test::AsTensor<int32>(shape_values);
    OpTestBuilder builder("ScatterNd");
    builder.Input(args.indices);
    builder.Input(args.updates);
    builder.Input(shape);
    builder.Attr("T", args.type);
    builder.Attr("Tindices", args.indices_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2768 
// Compares TF and XLA for Selu on a random float tensor.
TEST_F(OpTest, Selu) {
  Repeatedly([this]() {
    OpTestBuilder builder("Selu");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2775 
// Compares TF and XLA for SeluGrad on two random float tensors of one shape.
TEST_F(OpTest, SeluGrad) {
  Repeatedly([this]() {
    const auto shape = RandomDims();
    OpTestBuilder builder("SeluGrad");
    builder.RandomInput(DT_FLOAT, shape);
    builder.RandomInput(DT_FLOAT, shape);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2785 
// Compares TF and XLA for Equal on two random tensors with
// broadcast-compatible shapes.
TEST_F(OpTest, Equal) {
  Repeatedly([this]() {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    const auto& lhs_dims = shapes.first;
    const auto& rhs_dims = shapes.second;
    OpTestBuilder builder("Equal");
    builder.RandomInput(dtype, lhs_dims);
    builder.RandomInput(dtype, rhs_dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2796 
// Compares TF and XLA for Erf on a random float tensor.
TEST_F(OpTest, Erf) {
  Repeatedly([this]() {
    OpTestBuilder builder("Erf");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2803 
// Compares TF and XLA for Erfc on a random float tensor.
TEST_F(OpTest, Erfc) {
  Repeatedly([this]() {
    OpTestBuilder builder("Erfc");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2810 
// Compares TF and XLA for Exp on a random float or complex64 tensor.
TEST_F(OpTest, Exp) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Exp");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2818 
// Compares TF and XLA for Expm1 on a random float or complex64 tensor.
TEST_F(OpTest, Expm1) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Expm1");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2826 
// Compares TF and XLA for ExpandDims on a random tensor, inserting a new axis
// at a random position drawn from [-rank - 1, rank].
TEST_F(OpTest, ExpandDims) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> in_dims = RandomDims();
    // Compute the axis bounds in signed arithmetic. Writing
    // `-1 - in_dims.size()` evaluates in unsigned size_t, wraps around, and
    // only yields the intended negative value after narrowing to int32.
    const int32 rank = static_cast<int32>(in_dims.size());
    std::uniform_int_distribution<int32> axis_dist(-1 - rank, rank);
    Tensor dim(DT_INT32, TensorShape());
    dim.scalar<int32>()() = axis_dist(generator());
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims")
                                             .RandomInput(type, in_dims)
                                             .Input(dim)
                                             .Attr("T", type));
  });
}
2840 
// Compares TF and XLA for Fill. The output shape is a random int32 vector and
// the fill value is a random scalar.
TEST_F(OpTest, Fill) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> dims = RandomDims();
    std::vector<int32> shape;
    shape.reserve(dims.size());
    for (const int64_t extent : dims) {
      shape.push_back(static_cast<int32>(extent));
    }
    OpTestBuilder builder("Fill");
    builder.Input(test::AsTensor<int32>(shape));
    builder.RandomInput(dtype, {});  // scalar (rank-0) fill value
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2853 
// Compares TF and XLA for Floor on a random float tensor.
TEST_F(OpTest, Floor) {
  Repeatedly([this]() {
    OpTestBuilder builder("Floor");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2860 
// Compares TF and XLA for FloorDiv on two random int32 tensors with
// broadcast-compatible shapes.
TEST_F(OpTest, FloorDiv) {
  Repeatedly([this]() {
    const DataType dtype = DT_INT32;
    const auto shapes = BroadcastableDims();
    const auto& lhs_dims = shapes.first;
    const auto& rhs_dims = shapes.second;
    OpTestBuilder builder("FloorDiv");
    builder.RandomInput(dtype, lhs_dims);
    builder.RandomInput(dtype, rhs_dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2871 
// Compares TF and XLA for FloorMod on two random int32 or float tensors with
// broadcast-compatible shapes.
TEST_F(OpTest, FloorMod) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    const auto& lhs_dims = shapes.first;
    const auto& rhs_dims = shapes.second;
    OpTestBuilder builder("FloorMod");
    builder.RandomInput(dtype, lhs_dims);
    builder.RandomInput(dtype, rhs_dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2882 
// Compares TF and XLA for the original Gather op (no axis input).
TEST_F(OpTest, Gather) {
  Repeatedly([this]() {
    // NOTE(review): the `true` argument presumably restricts the chosen
    // arguments to what v1 Gather supports (e.g. axis 0); confirm against
    // ChooseGatherArguments.
    GatherArguments args = ChooseGatherArguments(true);
    OpTestBuilder builder("Gather");
    builder.RandomInput(args.params_type, args.params_shape);
    builder.Input(args.indices);
    builder.Attr("Tparams", args.params_type);
    builder.Attr("Tindices", args.indices_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2894 
// Compares TF and XLA for GatherV2 with random params, indices, axis and
// batch_dims.
TEST_F(OpTest, GatherV2) {
  Repeatedly([this]() {
    GatherArguments args = ChooseGatherArguments(false);
    OpTestBuilder builder("GatherV2");
    builder.RandomInput(args.params_type, args.params_shape);
    builder.Input(args.indices);
    builder.Input(args.axis);
    builder.Attr("batch_dims", args.batch_dims);
    builder.Attr("Taxis", args.axis_type);
    builder.Attr("Tindices", args.indices_type);
    builder.Attr("Tparams", args.params_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2909 
// Compares TF and XLA for GatherNd. The int32 indices tensor has shape
// batch_shape + [index_depth], and each index component is bounded by the
// matching params dimension.
TEST_F(OpTest, GatherNd) {
  // :randomized_tests_mlir fails with --tf_xla_random_seed=459353625
  // --test_arg=--tf_xla_test_repetitions=100
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  // See b/214080339#comment27 as this test causes Kokoro to crash.
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    const DataType params_type = Choose<DataType>(kAllXlaTypes);
    // GatherNd seems undefined on the case where params has rank 0.
    const std::vector<int64_t> params_shape = RandomDims(1);
    const DataType indices_type = DT_INT32;
    const std::vector<int64_t> batch_shape = RandomDims(0, 4, 0, 32);
    const int64_t index_depth = RandomDim(0, params_shape.size() + 1);
    std::vector<int64_t> indices_dims = batch_shape;
    indices_dims.push_back(index_depth);
    const TensorShape indices_shape(indices_dims);

    // Lower bound for every index component is 0.
    Tensor lo(indices_type, indices_shape);
    test::FillFn<int32>(&lo, [](int) -> int32 { return 0; });
    // Upper bound for component k of each index vector is params_shape[k] - 1.
    // NOTE(review): if RandomDims can yield a zero-sized params dimension, this
    // bound becomes -1 (< lo); confirm RandomBoundedTensor tolerates that.
    Tensor hi(indices_type, indices_shape);
    test::FillFn<int32>(&hi, [index_depth, &params_shape](int flat) -> int32 {
      const int component = flat % index_depth;
      return params_shape[component] - 1;
    });
    Tensor indices = RandomBoundedTensor(indices_type, lo, hi);

    OpTestBuilder builder("GatherNd");
    builder.RandomInput(params_type, params_shape);
    builder.Input(indices);
    builder.Attr("Tindices", indices_type);
    builder.Attr("Tparams", params_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2941 
// Compares TF and XLA for Greater on two random int32 or float tensors with
// broadcast-compatible shapes.
TEST_F(OpTest, Greater) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    const auto& lhs_dims = shapes.first;
    const auto& rhs_dims = shapes.second;
    OpTestBuilder builder("Greater");
    builder.RandomInput(dtype, lhs_dims);
    builder.RandomInput(dtype, rhs_dims);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2952 
// GreaterEqual on int32 or float operands whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, GreaterEqual) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("GreaterEqual");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2963 
// Identity on a random input of any type in kAllXlaTypes.
TEST_F(OpTest, Identity) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("Identity");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2971 
// Imag on a random complex64 input.
TEST_F(OpTest, Imag) {
  Repeatedly([this]() {
    OpTestBuilder builder("Imag");
    builder.RandomInput(DT_COMPLEX64);
    builder.Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2979 
// InplaceUpdate: `x` and `v` share every dimension after the first, and the
// index vector `i` holds v's leading-dimension count of unique values drawn
// from [0, x_rows - 1].
TEST_F(OpTest, InplaceUpdate) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> trailing_dims =
        RandomDims(0, kDefaultMaxRank - 1, 0, kDefaultMaxDimensionSize);
    // TODO(b/211012712): Once needs_unique_values case is linear instead of
    // quadratic time, use default Dim max instead of 8.
    int64_t num_rows = RandomDim(1, 8);
    std::vector<int64_t> v_shape = {num_rows};
    v_shape.insert(v_shape.end(), trailing_dims.begin(), trailing_dims.end());
    // x gets at least num_rows rows (RandomDim's lower bound), so num_rows
    // unique indices fit in [0, x_rows - 1].
    int64_t x_rows = RandomDim(num_rows);
    std::vector<int64_t> x_shape = {x_rows};
    x_shape.insert(x_shape.end(), trailing_dims.begin(), trailing_dims.end());
    std::vector<int64_t> indices_shape = {num_rows};
    Tensor indices = RandomBoundedTensor<int32>(
        DT_INT32, 0, x_rows - 1, /*needs_unique_values=*/true, indices_shape);
    OpTestBuilder builder("InplaceUpdate");
    builder.RandomInput(dtype, x_shape);
    builder.Input(indices);
    builder.RandomInput(dtype, v_shape);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3001 
// Inv on a random int32, float or complex64 input.
TEST_F(OpTest, Inv) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Inv");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3009 
// Invert on a random int32 input.
TEST_F(OpTest, Invert) {
  Repeatedly([this]() {
    DataType int_type = DT_INT32;
    OpTestBuilder builder("Invert");
    builder.RandomInput(int_type);
    builder.Attr("T", int_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3017 
// InvertPermutation on a 1-D int32 tensor of `size` unique values drawn from
// [0, size - 1], i.e. a permutation.
TEST_F(OpTest, InvertPermutation) {
  Repeatedly([this]() {
    // TODO(b/211012712): Once needs_unique_values case is linear instead of
    // quadratic time, use default Dim max instead of 8.
    int64_t size = RandomDim(0, 8);
    Tensor permutation =
        RandomBoundedTensor<int32>(DT_INT32, 0, size - 1, true, {size});
    OpTestBuilder builder("InvertPermutation");
    builder.Input(permutation);
    builder.Attr("T", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3028 
// IsFinite on a random float input.
TEST_F(OpTest, IsFinite) {
  Repeatedly([this]() {
    OpTestBuilder builder("IsFinite");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3035 
// IsInf on a random float input.
TEST_F(OpTest, IsInf) {
  Repeatedly([this]() {
    OpTestBuilder builder("IsInf");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3042 
// IsNan on a random float input.
TEST_F(OpTest, IsNan) {
  Repeatedly([this]() {
    OpTestBuilder builder("IsNan");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3049 
// L2Loss on a random float input.
TEST_F(OpTest, L2Loss) {
  Repeatedly([this]() {
    DataType float_type = DT_FLOAT;
    OpTestBuilder builder("L2Loss");
    builder.RandomInput(float_type);
    builder.Attr("T", float_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3057 
// LeakyRelu on a random float input with `alpha` drawn from [-2, 2).
TEST_F(OpTest, LeakyRelu) {
  Repeatedly([this]() {
    std::uniform_real_distribution<float> slope_dist(-2.0f, 2.0f);
    OpTestBuilder builder("LeakyRelu");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    builder.Attr("alpha", slope_dist(generator()));
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3068 
// LeakyReluGrad on two same-shaped float inputs (rank >= 1) with `alpha`
// drawn from [-2, 2).
TEST_F(OpTest, LeakyReluGrad) {
  Repeatedly([this]() {
    auto shape = RandomDims(1);
    std::uniform_real_distribution<float> slope_dist(-2.0f, 2.0f);
    OpTestBuilder builder("LeakyReluGrad");
    builder.RandomInput(DT_FLOAT, shape);
    builder.RandomInput(DT_FLOAT, shape);
    builder.Attr("T", DT_FLOAT);
    builder.Attr("alpha", slope_dist(generator()));
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3081 
// LeftShift on int32 or int64 (chosen by a coin flip). Shift amounts are
// bounded to [0, 31] for int32 and [0, 63] for int64.
TEST_F(OpTest, LeftShift) {
  Repeatedly([this]() {
    bool wide = RandomBool();
    auto shape = RandomDims();
    DataType dtype = wide ? DT_INT64 : DT_INT32;
    int shift_limit = wide ? 63 : 31;
    auto shifts = RandomBoundedTensor(dtype, 0, shift_limit, false, shape);
    OpTestBuilder builder("LeftShift");
    builder.RandomInput(dtype, shape);
    builder.Input(shifts);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3095 
// Less on int32 or float operands whose shapes come from BroadcastableDims().
TEST_F(OpTest, Less) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Less");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3106 
// LessEqual on int32 or float operands whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, LessEqual) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("LessEqual");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3117 
// Lgamma on a random float input.
TEST_F(OpTest, Lgamma) {
  Repeatedly([this]() {
    OpTestBuilder builder("Lgamma");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3124 
// LinSpace with two random scalar float inputs and a scalar count drawn from
// [-50, 50], whose integer type ("Tidx") is int32 or int64.
TEST_F(OpTest, LinSpace) {
  Repeatedly([this]() {
    // Wraps `value` in a scalar tensor of the requested index type.
    auto MakeIndexScalar = [](DataType index_type, int value) {
      return index_type == DT_INT32 ? test::AsScalar<int32>(value)
                                    : test::AsScalar<int64_t>(value);
    };
    std::uniform_int_distribution<int> count_dist(-50, 50);
    DataType index_type = Choose<DataType>({DT_INT32, DT_INT64});
    OpTestBuilder builder("LinSpace");
    builder.RandomInput(DT_FLOAT, {});
    builder.RandomInput(DT_FLOAT, {});
    builder.Input(MakeIndexScalar(index_type, count_dist(generator())));
    builder.Attr("T", DT_FLOAT);
    builder.Attr("Tidx", index_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3142 
// Log on a random float or complex64 input.
TEST_F(OpTest, Log) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Log");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3150 
// Log1p on a random float or complex64 input. The "T" attr is taken from the
// sampled dtype so it always matches the input, which lets the complex64
// draws actually exercise the op.
TEST_F(OpTest, Log1p) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Log1p").RandomInput(type).Attr("T", type));
  });
}
3158 
// LogicalAnd on bool operands whose shapes come from BroadcastableDims().
TEST_F(OpTest, LogicalAnd) {
  Repeatedly([this]() {
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("LogicalAnd");
    builder.RandomInput(DT_BOOL, shapes.first);
    builder.RandomInput(DT_BOOL, shapes.second);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3168 
// LogicalNot on a random bool input.
TEST_F(OpTest, LogicalNot) {
  Repeatedly([this]() {
    OpTestBuilder builder("LogicalNot");
    builder.RandomInput(DT_BOOL);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3175 
// LogicalOr on bool operands whose shapes come from BroadcastableDims().
TEST_F(OpTest, LogicalOr) {
  Repeatedly([this]() {
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("LogicalOr");
    builder.RandomInput(DT_BOOL, shapes.first);
    builder.RandomInput(DT_BOOL, shapes.second);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3185 
// LogSoftmax on a random float input of rank exactly 2.
TEST_F(OpTest, LogSoftmax) {
  Repeatedly([this]() {
    OpTestBuilder builder("LogSoftmax");
    builder.RandomInput(DT_FLOAT, RandomDims(2, 2));
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3194 
// LRN on a rank-4 float input with dimension sizes in RandomDims(4, 4, 1, 8).
// depth_radius is in [1, input_shape[3]]; bias, alpha and beta are drawn
// independently from [0.01, 2.0).
TEST_F(OpTest, LRN) {
  Repeatedly([this]() {
    // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
    std::vector<int64_t> input_shape = RandomDims(4, 4, 1, 8);
    // CuDNN requires depth_radius > 0.
    std::uniform_int_distribution<int> radius_dist(1, input_shape[3]);
    std::uniform_real_distribution<float> coeff_dist(0.01, 2.0);
    OpTestBuilder builder("LRN");
    builder.RandomInput(DT_FLOAT, input_shape);
    builder.Attr("T", DT_FLOAT);
    builder.Attr("depth_radius", radius_dist(generator()));
    builder.Attr("bias", coeff_dist(generator()));
    builder.Attr("alpha", coeff_dist(generator()));
    builder.Attr("beta", coeff_dist(generator()));
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3212 
// LRNGrad on three same-shaped rank-4 float inputs. depth_radius is in
// [1, shape[3]]; bias, alpha and beta are drawn independently from [0, 2).
TEST_F(OpTest, LRNGrad) {
  Repeatedly([this]() {
    // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
    std::vector<int64_t> shape = RandomDims(4, 4, 1, 8);
    // CuDNN requires depth_radius > 0.
    std::uniform_int_distribution<int> radius_dist(1, shape[3]);
    std::uniform_real_distribution<float> coeff_dist(0.0, 2.0);
    OpTestBuilder builder("LRNGrad");
    for (int input = 0; input < 3; ++input) {
      builder.RandomInput(DT_FLOAT, shape);
    }
    builder.Attr("T", DT_FLOAT);
    builder.Attr("depth_radius", radius_dist(generator()));
    builder.Attr("bias", coeff_dist(generator()));
    builder.Attr("alpha", coeff_dist(generator()));
    builder.Attr("beta", coeff_dist(generator()));
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3232 
// MatMul of a [rows, inner] by [inner, cols] product. When transpose_a or
// transpose_b is set, the corresponding operand is supplied with its two
// dimensions swapped.
TEST_F(OpTest, MatMul) {
  Repeatedly([this]() {
    int64_t rows = RandomDim();
    int64_t inner = RandomDim();
    int64_t cols = RandomDim();

    std::bernoulli_distribution coin;
    bool transpose_a = coin(generator());
    bool transpose_b = coin(generator());

    std::vector<int64_t> a_dims = transpose_a
                                      ? std::vector<int64_t>{inner, rows}
                                      : std::vector<int64_t>{rows, inner};
    std::vector<int64_t> b_dims = transpose_b
                                      ? std::vector<int64_t>{cols, inner}
                                      : std::vector<int64_t>{inner, cols};

    DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("MatMul");
    builder.RandomInput(dtype, a_dims);
    builder.RandomInput(dtype, b_dims);
    builder.Attr("T", dtype);
    builder.Attr("transpose_a", transpose_a);
    builder.Attr("transpose_b", transpose_b);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3261 
// MatrixBandPart on a random input of any type in kAllXlaTypes. The scalar
// num_lower / num_upper values are drawn from
// [-2 * kDefaultMaxDimensionSize, 2 * kDefaultMaxDimensionSize] with an index
// type of int32 or int64.
TEST_F(OpTest, MatrixBandPart) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>(kAllXlaTypes);
    DataType index_type = Choose<DataType>({DT_INT32, DT_INT64});
    Tensor num_lower = RandomBoundedTensor<int32>(
        index_type, -2 * kDefaultMaxDimensionSize,
        2 * kDefaultMaxDimensionSize, /*needs_unique_values=*/false, {});
    Tensor num_upper = RandomBoundedTensor<int32>(
        index_type, -2 * kDefaultMaxDimensionSize,
        2 * kDefaultMaxDimensionSize, /*needs_unique_values=*/false, {});
    OpTestBuilder builder("MatrixBandPart");
    builder.RandomInput(dtype);
    builder.Input(num_lower);
    builder.Input(num_upper);
    builder.Attr("T", dtype);
    builder.Attr("Tindex", index_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3280 
// MatrixDiag on a random int32, float or complex64 input of rank >= 1.
TEST_F(OpTest, MatrixDiag) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("MatrixDiag");
    builder.RandomInput(dtype, RandomDims(1));
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3289 
// MatrixDiagPart on a random int32, float or complex64 input of rank >= 2.
TEST_F(OpTest, MatrixDiagPart) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("MatrixDiagPart");
    builder.RandomInput(dtype, RandomDims(2));
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3299 
// MatrixDiagPartV3 with a random alignment and a diagonal-offset pair
// (k_lo, k_hi) satisfying
// -2 * kDefaultMaxDimensionSize <= k_lo <= k_hi <= 2 * kDefaultMaxDimensionSize,
// plus a random scalar third input. Currently skipped in both modes.
TEST_F(OpTest, MatrixDiagPartV3) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    DataType dtype = Choose<DataType>(kAllXlaTypes);
    std::string alignment = Choose<std::string>(
        {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"});
    int32 k_lo = std::uniform_int_distribution<int32>(
        -2 * kDefaultMaxDimensionSize,
        2 * kDefaultMaxDimensionSize)(generator());
    int32 k_hi = std::uniform_int_distribution<int32>(
        k_lo, 2 * kDefaultMaxDimensionSize)(generator());
    Tensor offsets = test::AsTensor<int32>({k_lo, k_hi});
    OpTestBuilder builder("MatrixDiagPartV3");
    builder.RandomInput(dtype);
    builder.Input(offsets);
    builder.RandomInput(dtype, {});
    builder.Attr("align", alignment);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3321 
// MatrixSetDiag: the diagonal input keeps the batch dimensions of the matrix
// input and replaces its two trailing dimensions with min(rows, cols).
TEST_F(OpTest, MatrixSetDiag) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>(kAllXlaTypes);
    auto matrix_shape = RandomDims(2);
    int rank = matrix_shape.size();
    int64_t rows = matrix_shape[rank - 2];
    int64_t cols = matrix_shape[rank - 1];
    std::vector<int64_t> diagonal_shape(matrix_shape.begin(),
                                        matrix_shape.end() - 2);
    diagonal_shape.push_back(std::min(rows, cols));
    OpTestBuilder builder("MatrixSetDiag");
    builder.RandomInput(dtype, matrix_shape);
    builder.RandomInput(dtype, diagonal_shape);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3337 
// MatrixSetDiagV2 with a band of `num_diags` consecutive diagonals
// k0..k1 (k1 = k0 + num_diags - 1) of a matrix with non-zero dimensions.
// The diagonal input has shape batch + [num_diags, max_diag_len], or
// batch + [max_diag_len] when only one diagonal is set.
TEST_F(OpTest, MatrixSetDiagV2) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    auto shape = RandomDims(2, kDefaultMaxRank, 1 /* non-zero dims */);
    int rank = shape.size();
    int64_t rows = shape[rank - 2];
    int64_t cols = shape[rank - 1];
    int64_t max_num_diags = rows + cols - 1;
    // Prefer bands of at least two diagonals, but a 1x1 matrix has only its
    // main diagonal (max_num_diags == 1); std::uniform_int_distribution
    // requires lower bound <= upper bound.
    int64_t min_num_diags = std::min<int64_t>(2, max_num_diags);
    int64_t num_diags = std::uniform_int_distribution<int64_t>(
        min_num_diags, max_num_diags)(generator());
    int32 k0 = std::uniform_int_distribution<int32>(
        -rows + 1, cols - num_diags)(generator());
    int32 k1 = k0 + num_diags - 1;
    Tensor k = test::AsTensor<int32>({k0, k1});
    int64_t max_diag_len =
        std::min(rows + std::min(k1, 0), cols + std::min(-k0, 0));
    // Keep the batch dimensions, drop the two matrix dimensions.
    std::vector<int64_t> diagonal_shape(shape.begin(), shape.end() - 2);
    // Per the op documentation, when k[0] == k[1] the diagonal has no
    // num_diags dimension.
    if (num_diags > 1) diagonal_shape.push_back(num_diags);
    diagonal_shape.push_back(max_diag_len);
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixSetDiagV2")
                                             .RandomInput(type, shape)
                                             .RandomInput(type, diagonal_shape)
                                             .Input(k)
                                             .Attr("T", type));
  });
}
3364 
// Max reduction over random axes of an int32 or float input, with or without
// keep_dims.
TEST_F(OpTest, Max) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    std::vector<int64_t> input_shape = RandomDims();
    Tensor axes = RandomReductionIndices(input_shape.size());
    bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Max");
    builder.RandomInput(dtype, input_shape);
    builder.Input(axes);
    builder.Attr("T", dtype);
    builder.Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3378 
// Maximum on int32 or float operands whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, Maximum) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Maximum");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3389 
// MaxPool on a rank-4 NHWC float input with non-zero dimensions. Kernel
// extents are bounded by the H and W input extents; strides are in [1, 5].
TEST_F(OpTest, MaxPool) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    std::vector<int64_t> input_shape = RandomDims(4, 4, 1);
    int kernel_rows =
        std::uniform_int_distribution<int>(1, input_shape[1])(generator());
    int kernel_cols =
        std::uniform_int_distribution<int>(1, input_shape[2])(generator());
    int stride_rows = stride_dist(generator());
    int stride_cols = stride_dist(generator());

    string padding_mode = Choose<string>({"SAME", "VALID"});
    OpTestBuilder builder("MaxPool");
    builder.RandomInput(DT_FLOAT, input_shape);
    builder.Attr("T", DT_FLOAT);
    builder.Attr("ksize", {1, kernel_rows, kernel_cols, 1});
    builder.Attr("strides", {1, stride_rows, stride_cols, 1});
    builder.Attr("padding", padding_mode);
    builder.Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): test NCHW format (not supported by CPU)
}
3413 
// MaxPool3D on an NDHWC float input. dims[0..2] are the three spatial
// extents, dims[3] the batch and dims[4] the feature count. Kernel extents are
// bounded by the spatial extents; strides are in [1, 5].
TEST_F(OpTest, MaxPool3D) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    std::vector<int64_t> dims = RandomDims(5, 5, 1);

    std::vector<int64_t> spatial_dims;
    std::vector<int64_t> ksize = {1};
    std::vector<int64_t> strides = {1};
    for (int axis = 0; axis < 3; ++axis) {
      ksize.push_back(
          std::uniform_int_distribution<int>(1, dims[axis])(generator()));
      spatial_dims.push_back(dims[axis]);
      strides.push_back(stride_dist(generator()));
    }
    ksize.push_back(1);
    strides.push_back(1);
    int64_t batch = dims[3];
    int64_t feature = dims[4];

    string padding_mode = Choose<string>({"SAME", "VALID"});
    OpTestBuilder builder("MaxPool3D");
    builder.RandomInput(DT_FLOAT,
                        ImageDims(FORMAT_NHWC, batch, feature, spatial_dims));
    builder.Attr("T", DT_FLOAT);
    builder.Attr("ksize", ksize);
    builder.Attr("strides", strides);
    builder.Attr("padding", padding_mode);
    builder.Attr("data_format", "NDHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): test NCDHW format (not supported by CPU)
}
3448 
// Mean reduction over random axes of an int32, float or complex64 input with
// non-zero dimensions, with or without keep_dims.
TEST_F(OpTest, Mean) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    // TODO(phawkins): CPU and XLA differ output for reducing across a
    // size-0 dimension (nan vs 0). For now, require size >= 1.
    std::vector<int64_t> input_shape = RandomDims(0, kDefaultMaxRank, 1);
    Tensor axes = RandomReductionIndices(input_shape.size());
    bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Mean");
    builder.RandomInput(dtype, input_shape);
    builder.Input(axes);
    builder.Attr("T", dtype);
    builder.Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3464 
// Min reduction over random axes of an int32 or float input, with or without
// keep_dims.
TEST_F(OpTest, Min) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    std::vector<int64_t> input_shape = RandomDims();
    Tensor axes = RandomReductionIndices(input_shape.size());
    bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Min");
    builder.RandomInput(dtype, input_shape);
    builder.Input(axes);
    builder.Attr("T", dtype);
    builder.Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3478 
// Minimum on int32 or float operands whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, Minimum) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Minimum");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3489 
// Mod on int32 operands whose shapes come from BroadcastableDims().
TEST_F(OpTest, Mod) {
  Repeatedly([this]() {
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Mod");
    builder.RandomInput(DT_INT32, shapes.first);
    builder.RandomInput(DT_INT32, shapes.second);
    builder.Attr("T", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3499 
// Mul on int32, float or complex64 operands whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, Mul) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("Mul");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3510 
// MulNoNan on operands whose shapes come from BroadcastableDims(). The node
// must be "MulNoNan" (not "Mul") for this test to exercise the op. int32 is
// not sampled because, per the op registration, MulNoNan's "T" attr only
// accepts floating-point and complex types.
TEST_F(OpTest, MulNoNan) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MulNoNan")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}
3521 
// Neg on a random int32, float or complex64 input.
TEST_F(OpTest, Neg) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Neg");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3529 
// NextAfter on two same-shaped inputs; only DT_FLOAT is currently sampled.
TEST_F(OpTest, NextAfter) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_FLOAT});
    auto shape = RandomDims();
    OpTestBuilder builder("NextAfter");
    builder.RandomInput(dtype, shape);
    builder.RandomInput(dtype, shape);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3540 
// NotEqual on int32, float or complex64 operands whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, NotEqual) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    auto shapes = BroadcastableDims();
    OpTestBuilder builder("NotEqual");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3551 
// OneHot with int32 indices drawn from [-2 * depth, 2 * depth], an axis drawn
// from [-rank - 5, rank + 5], and random scalar on/off values. Both ranges
// extend past the in-range values, presumably to exercise out-of-range
// handling.
TEST_F(OpTest, OneHot) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>(kAllXlaTypes);

    std::vector<int64_t> indices_shape = RandomDims();
    int rank = indices_shape.size();

    int32_t depth = RandomDim();

    Tensor indices(DT_INT32, TensorShape(indices_shape));
    std::uniform_int_distribution<int32> index_dist(-depth * 2, depth * 2);
    test::FillFn<int32>(&indices, [this, &index_dist](int) -> int32 {
      return index_dist(generator());
    });

    int axis =
        std::uniform_int_distribution<int32>(-rank - 5, rank + 5)(generator());

    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("OneHot")
            .Attr("T", dtype)
            .Attr("TI", DT_INT32)
            .Attr("axis", axis)
            .Input(indices)
            .Input(test::AsScalar<int32>(depth))
            .RandomInput(dtype, {})
            .RandomInput(dtype, {}));
  });
}
3581 
// OnesLike on a random int32, float or complex64 input.
TEST_F(OpTest, OnesLike) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("OnesLike");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3590 
// Pack of 1 to 5 same-shaped inputs along an axis drawn from
// [-(rank + 1), rank].
TEST_F(OpTest, Pack) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>(kAllXlaTypes);
    int num_inputs = std::uniform_int_distribution<int>(1, 5)(generator());

    std::vector<int64_t> input_shape = RandomDims();
    int rank = input_shape.size();
    int axis =
        std::uniform_int_distribution<int32>(-rank - 1, rank)(generator());

    OpTestBuilder builder("Pack");
    builder.Attr("T", dtype).Attr("N", num_inputs).Attr("axis", axis);
    for (int remaining = num_inputs; remaining > 0; --remaining) {
      builder.RandomInput(dtype, input_shape);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3611 
// Pad with input type, shape and paddings produced by ChoosePadArguments().
TEST_F(OpTest, Pad) {
  // See note about failing Kokoro tests: b/214080339#comment22
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    auto args = ChoosePadArguments();
    OpTestBuilder builder("Pad");
    builder.RandomInput(args.input_type, args.input_shape);
    builder.Input(args.paddings);
    builder.Attr("T", args.input_type);
    builder.Attr("Tpaddings", args.paddings_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3625 
// PadV2 with input type, shape, paddings and constant value produced by
// ChoosePadArguments().
TEST_F(OpTest, PadV2) {
  Repeatedly([this]() {
    auto args = ChoosePadArguments();
    OpTestBuilder builder("PadV2");
    builder.RandomInput(args.input_type, args.input_shape);
    builder.Input(args.paddings);
    builder.Input(args.constant_values);
    builder.Attr("T", args.input_type);
    builder.Attr("Tpaddings", args.paddings_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3638 
// Pow on float or complex64 operands whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, Pow) {
  // TODO(phawkins): Feeding large DT_INT32 values to Pow() leads to
  // nontermination.
  Repeatedly([this]() {
    auto shapes = BroadcastableDims();
    DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Pow");
    builder.RandomInput(dtype, shapes.first);
    builder.RandomInput(dtype, shapes.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3651 
// Prod reduction over random axes of an int32, float or complex64 input, with
// or without keep_dims.
TEST_F(OpTest, Prod) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    std::vector<int64_t> input_shape = RandomDims();
    Tensor axes = RandomReductionIndices(input_shape.size());
    bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Prod");
    builder.RandomInput(dtype, input_shape);
    builder.Input(axes);
    builder.Attr("T", dtype);
    builder.Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3665 
// Qr (full_matrices = true) on a float or complex64 input of rank >= 2 with
// non-zero dimensions.
TEST_F(OpTest, Qr) {
  Repeatedly([this]() {
    DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Qr");
    builder.RandomInput(dtype, RandomDims(2, kDefaultMaxRank, 1));
    builder.Attr("T", dtype);
    builder.Attr("full_matrices", true);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3676 
TEST_F(OpTest, QuantizeAndDequantizeV2) {
  Repeatedly([this]() {
    const int64_t num_bits =
        std::uniform_int_distribution<int64_t>(1, 64)(generator());
    const std::string round_mode =
        Choose<std::string>({"HALF_TO_EVEN", "HALF_UP"});
    const std::vector<int64_t> dims = RandomDims(0, kDefaultMaxRank, 1);
    // The two boolean attributes are drawn in the order they are attached.
    const bool signed_input = RandomBool();
    const bool narrow_range = RandomBool();

    OpTestBuilder builder("QuantizeAndDequantizeV2");
    builder.RandomInput(DT_FLOAT, dims);
    // The next two operands are unused because range_given = false.
    builder.RandomInput(DT_FLOAT, dims).RandomInput(DT_FLOAT, dims);
    builder.Attr("signed_input", signed_input)
        .Attr("num_bits", num_bits)
        .Attr("range_given", false)
        .Attr("round_mode", round_mode)
        .Attr("narrow_range", narrow_range)
        .Attr("axis", -1)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3697 
TEST_F(OpTest, RandomShuffle) {
  // See b/209062491 as this test passes with --tf_xla_test_device=CPU:0
  // The two skips below are complementary, so the body currently never runs.
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> input_dims = RandomDims(1);
    const auto seed = RandomSeed();
    const auto seed2 = RandomSeed();
    OpTestBuilder builder("RandomShuffle");
    builder.RandomInput(dtype, input_dims)
        .Attr("seed", seed)
        .Attr("seed2", seed2)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3711 
TEST_F(OpTest, RandomStandardNormal) {
  Repeatedly([this]() {
    const DataType shape_type = Choose<DataType>({DT_INT32, DT_INT64});
    // The single operand is a random dims vector encoded as an integer
    // tensor of type shape_type.
    auto shape_operand = AsIntTensor(shape_type, RandomDims());
    const auto seed = RandomSeed();
    const auto seed2 = RandomSeed();
    OpTestBuilder builder("RandomStandardNormal");
    builder.Input(shape_operand)
        .Attr("seed", seed)
        .Attr("seed2", seed2)
        .Attr("T", shape_type)
        .Attr("dtype", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3724 
TEST_F(OpTest, RandomUniform) {
  // Exercises the RandomUniform op. The node must be built as
  // "RandomUniform": building "RandomStandardNormal" here would just repeat
  // the RandomStandardNormal test and leave RandomUniform untested.
  Repeatedly([this]() {
    auto shape_type = Choose<DataType>({DT_INT32, DT_INT64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("RandomUniform")
            // Shape operand: a random dims vector as an integer tensor.
            .Input(AsIntTensor(shape_type, RandomDims()))
            .Attr("seed", RandomSeed())
            .Attr("seed2", RandomSeed())
            .Attr("T", shape_type)
            .Attr("dtype", DT_FLOAT));
  });
}
3737 
TEST_F(OpTest, Range) {
  Repeatedly([this]() {
    // Wraps an int as a scalar tensor of the chosen index type; any other
    // type is a fatal error.
    auto to_scalar = [](DataType dtype, int value) -> Tensor {
      switch (dtype) {
        case DT_INT32:
          return test::AsScalar<int32>(value);
        case DT_INT64:
          return test::AsScalar<int64_t>(value);
        case DT_FLOAT:
          return test::AsScalar<float>(value);
        case DT_DOUBLE:
          return test::AsScalar<double>(value);
        default:
          break;
      }
      LOG(FATAL) << "Unknown type " << DataTypeString(dtype);
    };
    std::uniform_int_distribution<int> endpoint_dist(-50, 50);
    const DataType tidx =
        Choose<DataType>({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE});
    // Drawn in operand order: start, limit, delta.
    const int start = endpoint_dist(generator());
    const int limit = endpoint_dist(generator());
    const int delta = endpoint_dist(generator());

    OpTestBuilder builder("Range");
    builder.Input(to_scalar(tidx, start))
        .Input(to_scalar(tidx, limit))
        .Input(to_scalar(tidx, delta))
        .Attr("Tidx", tidx);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3757 
TEST_F(OpTest, Rank) {
  // Covers an integer, a floating-point and a complex element type.
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Rank");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3765 
TEST_F(OpTest, Real) {
  // Only the DT_COMPLEX64 input type is exercised.
  Repeatedly([this]() {
    OpTestBuilder builder("Real");
    builder.RandomInput(DT_COMPLEX64);
    builder.Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3773 
TEST_F(OpTest, RealDiv) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    // Numerator and denominator shapes come from BroadcastableDims().
    const auto operand_dims = BroadcastableDims();
    OpTestBuilder builder("RealDiv");
    builder.RandomInput(dtype, operand_dims.first);
    builder.RandomInput(dtype, operand_dims.second);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3784 
TEST_F(OpTest, Reciprocal) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Reciprocal");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3792 
TEST_F(OpTest, ReciprocalGrad) {
  // Both operands share one random shape.
  Repeatedly([this]() {
    const std::vector<int64_t> shape = RandomDims();
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("ReciprocalGrad");
    builder.RandomInput(dtype, shape).RandomInput(dtype, shape).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
TEST_F(OpTest, Relu) {
  // Single random DT_FLOAT operand.
  Repeatedly([this]() {
    OpTestBuilder relu("Relu");
    relu.RandomInput(DT_FLOAT);
    relu.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(relu);
  });
}
3809 
TEST_F(OpTest, Relu6) {
  // Single random DT_FLOAT operand.
  Repeatedly([this]() {
    OpTestBuilder relu6("Relu6");
    relu6.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(relu6);
  });
}
3816 
TEST_F(OpTest, Relu6Grad) {
  Repeatedly([this]() {
    // Both float operands share one random shape of rank >= 1.
    const std::vector<int64_t> shape = RandomDims(1);
    OpTestBuilder builder("Relu6Grad");
    builder.RandomInput(DT_FLOAT, shape);
    builder.RandomInput(DT_FLOAT, shape);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3826 
TEST_F(OpTest, ReluGrad) {
  Repeatedly([this]() {
    // Two DT_FLOAT operands of one shared shape, rank >= 1.
    const std::vector<int64_t> operand_shape = RandomDims(1);
    OpTestBuilder builder("ReluGrad");
    builder.RandomInput(DT_FLOAT, operand_shape)
        .RandomInput(DT_FLOAT, operand_shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3836 
TEST_F(OpTest, Reshape) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> dims = RandomDims();
    std::bernoulli_distribution coin;

    // Shuffles `dims` in place, then folds it into a new shape. Each extent
    // either starts a new dimension or is multiplied into the previous one,
    // so the result has the same element count as `dims`.
    auto regroup = [&]() {
      std::shuffle(dims.begin(), dims.end(), generator());
      std::vector<int64_t> shape;
      for (const int64_t extent : dims) {
        if (!shape.empty() && !coin(generator())) {
          shape.back() *= extent;
        } else {
          shape.push_back(extent);
        }
      }
      return shape;
    };
    const std::vector<int64_t> input_shape = regroup();
    const std::vector<int64_t> output_shape = regroup();

    OpTestBuilder builder("Reshape");
    builder.RandomInput(dtype, input_shape)
        .Input(test::AsTensor<int32>(
            std::vector<int32>(output_shape.begin(), output_shape.end())))
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3863 
TEST_F(OpTest, ResizeNearestNeighbor) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_INT32, DT_INT64});
    // 4-D images operand with every dimension of size >= 1.
    std::vector<int64_t> images_dims = RandomDims(4, 4, 1);
    // The size operand must be a 1-D tensor with exactly two elements.
    // RandomDims(2, kDefaultMaxRank, 1) could yield more than two entries,
    // which is an invalid input. Using exactly two matches ResizeBilinear.
    std::vector<int64_t> new_size = RandomDims(2, 2, 1);
    bool align_corners = RandomBool();
    bool half_pixel_centers = RandomBool();
    // TensorFlow rejects half_pixel_centers=true combined with
    // align_corners=true, so that combination is never generated.
    if (align_corners) half_pixel_centers = false;
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ResizeNearestNeighbor")
            .RandomInput(type, images_dims)
            .Input(AsIntTensor(DT_INT32, new_size))
            .Attr("align_corners", align_corners)
            .Attr("half_pixel_centers", half_pixel_centers)
            .Attr("T", type));
  });
}
3876 
TEST_F(OpTest, ResizeBilinear) {
  Repeatedly([this]() {
    const std::vector<int64_t> images_dims = RandomDims(4, 4);
    const std::vector<int64_t> target_size = RandomDims(2, 2);
    // The size operand holds the two target extents as an int32 vector.
    const std::vector<int32> size_values(target_size.begin(),
                                         target_size.end());

    OpTestBuilder builder("ResizeBilinear");
    builder.RandomInput(DT_FLOAT, images_dims)
        .Input(test::AsTensor<int32>(size_values))
        .Attr("T", DT_FLOAT)
        .Attr("align_corners", true);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3891 
TEST_F(OpTest, ResizeBilinearGrad) {
  Repeatedly([this]() {
    const std::vector<int64_t> in_dims = RandomDims(4, 4);
    const std::vector<int64_t> out_dims = RandomDims(2, 2);
    // The second operand keeps dims 0 and 3 of in_dims and takes its middle
    // two dims from out_dims.
    const std::vector<int64_t> second_dims = {in_dims[0], out_dims[0],
                                              out_dims[1], in_dims[3]};

    OpTestBuilder builder("ResizeBilinearGrad");
    builder.RandomInput(DT_FLOAT, in_dims);
    builder.RandomInput(DT_FLOAT, second_dims);
    builder.Attr("T", DT_FLOAT).Attr("align_corners", true);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3906 
TEST_F(OpTest, Reverse) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  Repeatedly([this]() {
    const std::vector<int64_t> shape = RandomDims(1);
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const int64_t rank = shape.size();
    // Second operand: a random DT_BOOL vector with one entry per dimension.
    OpTestBuilder builder("Reverse");
    builder.RandomInput(dtype, shape)
        .RandomInput(DT_BOOL, {rank})
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3919 
TEST_F(OpTest, ReverseSequence) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    std::vector<int64_t> dims = RandomDims(/*min_rank=*/2);
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const int64_t rank = dims.size();

    // Shuffle all dimension ids. The first two (always distinct) become the
    // batch and sequence dimensions.
    std::vector<int> dim_order(rank);
    absl::c_iota(dim_order, 0);
    absl::c_shuffle(dim_order, generator());
    const int batch_dim = dim_order[0];
    const int seq_dim = dim_order[1];

    // One length per batch entry, each drawn from [0, dims[seq_dim]].
    const int batch_size = dims[batch_dim];
    const int max_seq_len = dims[seq_dim];
    std::uniform_int_distribution<int32> len_dist(0, max_seq_len);
    std::vector<int32> seq_lens(batch_size);
    for (int32& len : seq_lens) {
      len = len_dist(generator());
    }

    OpTestBuilder builder("ReverseSequence");
    builder.RandomInput(dtype, dims)
        .Input(test::AsTensor<int32>(seq_lens))
        .Attr("seq_dim", seq_dim)
        .Attr("batch_dim", batch_dim)
        .Attr("T", dtype)
        .Attr("Tlen", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3952 
TEST_F(OpTest, ReverseV2) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> shape = RandomDims();
    // Random subset of axes, generated by RandomReductionIndices.
    Tensor axes = RandomReductionIndices(shape.size());
    OpTestBuilder builder("ReverseV2");
    builder.RandomInput(dtype, shape);
    builder.Input(axes);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3964 
TEST_F(OpTest, RightShift) {
  Repeatedly([this]() {
    const bool wide = RandomBool();
    const std::vector<int64_t> shape = RandomDims();
    const DataType dtype = wide ? DT_INT64 : DT_INT32;
    // Shift amounts are bounded by 0 and (bit width - 1) of the element
    // type: 63 for int64, 31 for int32.
    const int max_shift = wide ? 63 : 31;
    auto shift_amounts = RandomBoundedTensor(dtype, 0, max_shift, false, shape);

    OpTestBuilder builder("RightShift");
    builder.RandomInput(dtype, shape).Input(shift_amounts).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
3978 
TEST_F(OpTest, Rint) {
  // Single random DT_FLOAT operand.
  Repeatedly([this]() {
    OpTestBuilder rint("Rint");
    rint.RandomInput(DT_FLOAT);
    rint.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(rint);
  });
}
3985 
TEST_F(OpTest, Roll) {
  Repeatedly([this]() {
    const DataType input_type = Choose<DataType>(kAllXlaTypes);
    const DataType axis_type = Choose<DataType>({DT_INT32, DT_INT64});
    // TODO(b/201095155,b/197140886): shift_type = DT_INT64 doesn't work.
    const DataType shift_type = DT_INT32;
    const std::vector<int64_t> input_shape = RandomDims(1);
    const int rank = input_shape.size();
    // axis and shift share one 1-D shape whose length is bounded by rank + 1.
    const std::vector<int64_t> axis_shape = RandomDims(1, 1, 1, rank + 1);
    auto axis = RandomBoundedTensor(axis_type, 0, rank - 1, true, axis_shape);
    auto shift = RandomTensor(shift_type, false, axis_shape);

    OpTestBuilder builder("Roll");
    builder.RandomInput(input_type, input_shape);
    builder.Input(shift).Input(axis);
    builder.Attr("T", input_type)
        .Attr("Taxis", axis_type)
        .Attr("Tshift", shift_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4007 
TEST_F(OpTest, Round) {
  // Single random DT_FLOAT operand.
  Repeatedly([this]() {
    OpTestBuilder round_op("Round");
    round_op.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(round_op);
  });
}
4014 
TEST_F(OpTest, Rsqrt) {
  Repeatedly([this]() {
    // Real and complex element types.
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Rsqrt");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4022 
TEST_F(OpTest, RsqrtGrad) {
  Repeatedly([this]() {
    const std::vector<int64_t> shape = RandomDims();
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    // Both operands have the same dtype and shape.
    OpTestBuilder builder("RsqrtGrad");
    builder.RandomInput(dtype, shape).RandomInput(dtype, shape);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4033 
TEST_F(OpTest, Select) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> shape = RandomDims();
    // A DT_BOOL operand followed by two operands of the chosen dtype; all
    // three share one shape.
    OpTestBuilder builder("Select");
    builder.RandomInput(DT_BOOL, shape);
    builder.RandomInput(dtype, shape);
    builder.RandomInput(dtype, shape);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4045 
TEST_F(OpTest, SelectV2) {
  Repeatedly([this]() {
    const DataType value_type = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> common_shape = RandomDims();
    // Same operand layout as the Select test: bool operand, then two values.
    OpTestBuilder builder("SelectV2");
    builder.RandomInput(DT_BOOL, common_shape)
        .RandomInput(value_type, common_shape)
        .RandomInput(value_type, common_shape)
        .Attr("T", value_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4057 
TEST_F(OpTest, Shape) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder shape_op("Shape");
    shape_op.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(shape_op);
  });
}
4065 
TEST_F(OpTest, ShapeN) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    // Between 1 and 5 inputs, all of the same dtype.
    const int num_inputs = std::uniform_int_distribution<int>(1, 5)(generator());
    OpTestBuilder builder("ShapeN");
    builder.Attr("T", dtype).Attr("N", num_inputs);
    for (int remaining = num_inputs; remaining > 0; --remaining) {
      builder.RandomInput(dtype);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4079 
TEST_F(OpTest, Sigmoid) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder sigmoid("Sigmoid");
    sigmoid.RandomInput(dtype);
    sigmoid.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(sigmoid);
  });
}
4087 
TEST_F(OpTest, SigmoidGrad) {
  Repeatedly([this]() {
    const std::vector<int64_t> operand_shape = RandomDims();
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("SigmoidGrad");
    builder.RandomInput(dtype, operand_shape)
        .RandomInput(dtype, operand_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4098 
TEST_F(OpTest, Sign) {
  // Covers integer, float and complex element types.
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Sign");
    builder.RandomInput(dtype);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4106 
TEST_F(OpTest, Sin) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder sin_op("Sin");
    sin_op.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(sin_op);
  });
}
4114 
TEST_F(OpTest, Sinh) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder sinh_op("Sinh");
    sinh_op.RandomInput(dtype);
    sinh_op.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(sinh_op);
  });
}
4122 
TEST_F(OpTest, Size) {
  // Any XLA-supported element type.
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder size_op("Size");
    size_op.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(size_op);
  });
}
4130 
TEST_F(OpTest, Slice) {
  Repeatedly([this]() {
    SliceArguments args = ChooseSliceArguments(true);
    // The size operand is passed as an int32 vector.
    const std::vector<int32> slice_size(args.size.begin(), args.size.end());
    OpTestBuilder builder("Slice");
    builder.RandomInput(args.type, args.shape)
        .Input(args.indices)
        .Input(test::AsTensor<int32>(slice_size))
        .Attr("T", args.type)
        .Attr("Index", args.indices_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4144 
TEST_F(OpTest, Softmax) {
  Repeatedly([this]() {
    // Rank-2 float operand.
    const std::vector<int64_t> logits_dims = RandomDims(2, 2);
    OpTestBuilder builder("Softmax");
    builder.RandomInput(DT_FLOAT, logits_dims);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4153 
TEST_F(OpTest, SoftmaxCrossEntropyWithLogits) {
  Repeatedly([this]() {
    // Two rank-2 float operands of identical, non-empty shape.
    const std::vector<int64_t> matrix_dims = RandomDims(2, 2, 1);
    OpTestBuilder builder("SoftmaxCrossEntropyWithLogits");
    builder.RandomInput(DT_FLOAT, matrix_dims)
        .RandomInput(DT_FLOAT, matrix_dims)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4164 
TEST_F(OpTest, Softplus) {
  // Single random DT_FLOAT operand.
  Repeatedly([this]() {
    OpTestBuilder softplus("Softplus");
    softplus.RandomInput(DT_FLOAT);
    softplus.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(softplus);
  });
}
4171 
TEST_F(OpTest, SoftplusGrad) {
  Repeatedly([this]() {
    const std::vector<int64_t> shape = RandomDims();
    // Two float operands of the same shape.
    OpTestBuilder builder("SoftplusGrad");
    builder.RandomInput(DT_FLOAT, shape);
    builder.RandomInput(DT_FLOAT, shape);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4181 
TEST_F(OpTest, Softsign) {
  // Single random DT_FLOAT operand.
  Repeatedly([this]() {
    OpTestBuilder softsign("Softsign");
    softsign.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(softsign);
  });
}
4188 
TEST_F(OpTest, SoftsignGrad) {
  Repeatedly([this]() {
    const std::vector<int64_t> operand_shape = RandomDims();
    OpTestBuilder builder("SoftsignGrad");
    builder.RandomInput(DT_FLOAT, operand_shape)
        .RandomInput(DT_FLOAT, operand_shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4198 
TEST_F(OpTest, SpaceToBatch) {
  Repeatedly([this]() {
    // Only the first num_block_dims entries of block_dims are used.
    std::vector<int64_t> block_dims = RandomDims(4, 4, 0, 5);
    const int num_block_dims = 2;
    int64_t block_size = RandomDim(2, 5);

    // Shape layout: one leading dim, num_block_dims blocked dims (each a
    // multiple of block_size before padding is removed), one trailing dim.
    std::vector<int64_t> input_dims(num_block_dims + 2);
    input_dims.front() = RandomDim();
    for (int d = 1; d <= num_block_dims; ++d) {
      input_dims[d] = block_dims[d - 1] * block_size;
    }
    input_dims.back() = RandomDim();

    std::vector<int64_t> padding_vals;
    std::uniform_int_distribution<int> pad_dist(0, 7);
    for (int d = 1; d <= num_block_dims; ++d) {
      int64_t before = 0;
      int64_t after = 0;
      // Redraw until the combined padding fits within this dimension.
      for (;;) {
        before = pad_dist(generator());
        after = pad_dist(generator());
        if (before + after <= input_dims[d]) break;
      }
      // Remove the padding from the input so that padding restores the
      // block_size multiple.
      input_dims[d] -= before + after;
      padding_vals.push_back(before);
      padding_vals.push_back(after);
    }
    Tensor paddings;
    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
                            TensorShape({num_block_dims, 2})));

    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("SpaceToBatch");
    builder.RandomInput(dtype, input_dims)
        .Input(paddings)
        .Attr("T", dtype)
        .Attr("block_size", block_size);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4237 
TEST_F(OpTest, SpaceToBatchND) {
  Repeatedly([this]() {
    std::vector<int64_t> block_dims = RandomDims(1, 3, 0, 5);
    int num_block_dims = block_dims.size();
    std::vector<int64_t> remaining_dims = RandomDims(0, 3);
    std::vector<int64_t> block_multipliers =
        RandomDims(block_dims.size(), block_dims.size(), 0, 4);

    // Shape layout: one leading dim, then block_dims[i] * block_multipliers[i]
    // for each blocked dim, then remaining_dims.
    std::vector<int64_t> input_dims;
    input_dims.reserve(1 + num_block_dims + remaining_dims.size());
    input_dims.push_back(RandomDim());
    for (int i = 0; i < num_block_dims; ++i) {
      input_dims.push_back(block_dims[i] * block_multipliers[i]);
    }
    input_dims.insert(input_dims.end(), remaining_dims.begin(),
                      remaining_dims.end());

    std::vector<int64_t> padding_vals;
    std::uniform_int_distribution<int> pad_dist(0, 7);
    for (int i = 0; i < num_block_dims; ++i) {
      int64_t& extent = input_dims[1 + i];
      int64_t before = 0;
      int64_t after = 0;
      // Redraw until the combined padding fits within this dimension.
      while (true) {
        before = pad_dist(generator());
        after = pad_dist(generator());
        if (before + after <= extent) break;
      }
      extent -= before + after;
      padding_vals.push_back(before);
      padding_vals.push_back(after);
    }
    Tensor paddings;
    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
                            TensorShape({num_block_dims, 2})));

    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int32> block_shape(block_dims.begin(), block_dims.end());
    OpTestBuilder builder("SpaceToBatchND");
    builder.RandomInput(dtype, input_dims)
        .Input(test::AsTensor<int32>(block_shape))
        .Input(paddings)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4281 
TEST_F(OpTest, SpaceToDepth) {
  Repeatedly([this]() {
    const int64_t block = RandomDim(2, 5);
    std::vector<int64_t> input_dims = RandomDims(4, 4);
    // Round the spatial dimensions (indices 1 and 2) up to a multiple of the
    // block size.
    for (const int axis : {1, 2}) {
      input_dims[axis] = (input_dims[axis] + block - 1) / block * block;
    }
    OpTestBuilder builder("SpaceToDepth");
    builder.RandomInput(DT_FLOAT, input_dims)
        .Attr("T", DT_FLOAT)
        .Attr("block_size", block);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4295 
TEST_F(OpTest, SparseMatMul) {
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {
    // Logical product is [rows, inner] x [inner, cols].
    const int64_t rows = RandomDim();
    const int64_t inner = RandomDim();
    const int64_t cols = RandomDim();

    std::bernoulli_distribution coin;
    const bool transpose_a = coin(generator());
    const bool transpose_b = coin(generator());
    // Store each operand in the layout implied by its transpose flag.
    const std::vector<int64_t> a_dims =
        transpose_a ? std::vector<int64_t>{inner, rows}
                    : std::vector<int64_t>{rows, inner};
    const std::vector<int64_t> b_dims =
        transpose_b ? std::vector<int64_t>{cols, inner}
                    : std::vector<int64_t>{inner, cols};

    OpTestBuilder builder("SparseMatMul");
    builder.RandomInput(DT_FLOAT, a_dims).RandomInput(DT_FLOAT, b_dims);
    builder.Attr("Ta", DT_FLOAT)
        .Attr("Tb", DT_FLOAT)
        .Attr("transpose_a", transpose_a)
        .Attr("transpose_b", transpose_b);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4326 
TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) {
  Repeatedly([this]() {
    const std::vector<int64_t> logits_dims = RandomDims(2, 2, 1);
    const int64_t batch_size = logits_dims[0];
    const int64_t num_classes = logits_dims[1];

    // One label per row, each drawn from [0, num_classes - 1].
    std::vector<int32> labels(batch_size);
    for (int32& label : labels) {
      label =
          std::uniform_int_distribution<int32>(0, num_classes - 1)(generator());
    }

    OpTestBuilder builder("SparseSoftmaxCrossEntropyWithLogits");
    builder.RandomInput(DT_FLOAT, logits_dims)
        .Input(test::AsTensor<int32>(labels))
        .Attr("T", DT_FLOAT)
        .Attr("Tlabels", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4347 
TEST_F(OpTest, Split) {
  // See b/214080339#comment27 as this test causes Kokoro to crash.
  // NOTE(review): indexing `dims` with a negative `dim` is an out-of-bounds
  // write and a plausible cause of that crash. It is normalized below, so
  // these skips are worth re-evaluating.
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> dims = RandomDims(1);
    const int32 rank = static_cast<int32>(dims.size());
    // `dim` is drawn from [-rank, rank - 1] and passed to the op unchanged.
    // `axis` is its non-negative equivalent, used to index `dims`.
    const int32 dim =
        std::uniform_int_distribution<int32>(-rank, rank - 1)(generator());
    const int32 axis = dim < 0 ? dim + rank : dim;
    const int n = std::uniform_int_distribution<int>(1, 5)(generator());
    // Ensure the split dimension is evenly divisible by 'n'.
    dims[axis] = dims[axis] / n * n;
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
                                             .Input(test::AsScalar<int32>(dim))
                                             .RandomInput(type, dims)
                                             .Attr("T", type)
                                             .Attr("num_split", n));
  });
}
4370 
TEST_F(OpTest, SplitV) {
  // Likely this only fails when dim is negative. Try type = DT_FLOAT first.
  // NOTE(review): a negative `dim` used to index `dims` out of bounds, and
  // size_splits had 2n entries. Both are fixed below, so these skips may be
  // stale.
  if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
  if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> dims = RandomDims(1, kDefaultMaxRank, 1);
    const int32 rank = static_cast<int32>(dims.size());
    // `dim` is drawn from [-rank, rank - 1] and passed to the op unchanged.
    // `axis` is its non-negative equivalent, used to read `dims`.
    const int32 dim =
        std::uniform_int_distribution<int32>(-rank, rank - 1)(generator());
    const int32 axis = dim < 0 ? dim + rank : dim;
    const int64_t axis_size = dims[axis];
    const int n = std::uniform_int_distribution<int>(
        1, std::min(5, static_cast<int>(axis_size)))(generator());
    // size_splits must have exactly `n` (= num_split) entries: n - 1 equal
    // chunks plus a final chunk that absorbs the remainder, summing to
    // axis_size.
    const int32 chunk = static_cast<int32>(axis_size / n);
    std::vector<int32> size_splits(static_cast<size_t>(n - 1), chunk);
    size_splits.push_back(static_cast<int32>(axis_size - (n - 1) * chunk));
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("SplitV")
            .RandomInput(type, dims)
            .Input(test::AsTensor<int32>(size_splits))
            .Input(test::AsScalar<int32>(dim))
            .Attr("T", type)
            .Attr("num_split", n)
            .Attr("Tlen", DT_INT32));
  });
}
4398 
TEST_F(OpTest, Sqrt) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder sqrt_op("Sqrt");
    sqrt_op.RandomInput(dtype);
    sqrt_op.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(sqrt_op);
  });
}
4406 
TEST_F(OpTest, StopGradient) {
  // Any XLA-supported element type.
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("StopGradient");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4414 
TEST_F(OpTest, SqrtGrad) {
  Repeatedly([this]() {
    const std::vector<int64_t> shape = RandomDims();
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    // Both operands share dtype and shape.
    OpTestBuilder builder("SqrtGrad");
    builder.RandomInput(dtype, shape);
    builder.RandomInput(dtype, shape);
    builder.Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4425 
TEST_F(OpTest, SquaredDifference) {
  // Float-only SquaredDifference over a pair of broadcast-compatible shapes.
  Repeatedly([this] {
    auto [lhs_shape, rhs_shape] = BroadcastableDims();
    OpTestBuilder builder("SquaredDifference");
    builder.RandomInput(DT_FLOAT, lhs_shape)
        .RandomInput(DT_FLOAT, rhs_shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4435 
TEST_F(OpTest, Square) {
  // Square on a random int32, float, or complex64 tensor.
  Repeatedly([this] {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Square");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4443 
TEST_F(OpTest, Squeeze) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> shape = RandomDims(0, kDefaultMaxRank, 0, 5);
    // Each size-1 axis is independently selected for squeezing by a fair coin
    // flip; the coin is only flipped for size-1 axes.
    std::bernoulli_distribution coin;
    std::vector<int> axes_to_squeeze;
    const int rank = shape.size();
    for (int axis = 0; axis < rank; ++axis) {
      const bool is_unit_axis = shape[axis] == 1;
      if (is_unit_axis && coin(generator())) axes_to_squeeze.push_back(axis);
    }
    OpTestBuilder builder("Squeeze");
    builder.RandomInput(dtype, shape)
        .Attr("squeeze_dims", axes_to_squeeze)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4461 
TEST_F(OpTest, Sub) {
  // Elementwise Sub over broadcast-compatible operands.
  Repeatedly([this] {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    auto [lhs_shape, rhs_shape] = BroadcastableDims();
    OpTestBuilder builder("Sub");
    builder.RandomInput(dtype, lhs_shape)
        .RandomInput(dtype, rhs_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4472 
TEST_F(OpTest, Sum) {
  // Sum-reduce a random tensor along random axes, with or without keep_dims.
  Repeatedly([this] {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const std::vector<int64_t> input_shape = RandomDims();
    Tensor axes = RandomReductionIndices(input_shape.size());
    const bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Sum");
    builder.RandomInput(dtype, input_shape)
        .Input(axes)
        .Attr("T", dtype)
        .Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4486 
TEST_F(OpTest, StridedSlice) {
  if (tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/201095155";
  }
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> input_shape = RandomDims();
    const int rank = input_shape.size();

    // Draws an int32 uniformly from [-2 * extent, 2 * extent] using a fresh
    // distribution per draw.
    auto random_index = [this](int64_t extent) {
      return std::uniform_int_distribution<int32>(-2 * extent,
                                                  2 * extent)(generator());
    };

    std::vector<int32> begin_idx(rank);
    std::vector<int32> end_idx(rank);
    std::vector<int32> stride(rank);
    for (int axis = 0; axis < rank; ++axis) {
      begin_idx[axis] = random_index(input_shape[axis]);
      end_idx[axis] = random_index(input_shape[axis]);
      // TODO(b/31360685): support strides other than 1 or -1
      stride[axis] = std::bernoulli_distribution()(generator()) ? 1 : -1;
    }

    // All masks are drawn from [0, 2^rank - 1].
    const int64_t all_bits = (1LL << rank) - 1;
    std::uniform_int_distribution<int64_t> mask_dist(0, all_bits);
    const int64_t begin_mask = mask_dist(generator());
    const int64_t end_mask = mask_dist(generator());

    // The ellipsis mask has either no bit or exactly one bit set.
    int64_t ellipsis_mask = 0;
    if (rank > 0 && std::bernoulli_distribution()(generator())) {
      const int bit =
          std::uniform_int_distribution<int>(0, rank - 1)(generator());
      ellipsis_mask = 1LL << bit;
    }

    const int64_t new_axis_mask = mask_dist(generator());
    const int64_t shrink_axis_mask = mask_dist(generator());

    OpTestBuilder builder("StridedSlice");
    builder.RandomInput(dtype, input_shape)
        .Input(test::AsTensor<int32>(begin_idx))
        .Input(test::AsTensor<int32>(end_idx))
        .Input(test::AsTensor<int32>(stride))
        .Attr("T", dtype)
        .Attr("Index", DT_INT32)
        .Attr("begin_mask", begin_mask)
        .Attr("end_mask", end_mask)
        .Attr("ellipsis_mask", ellipsis_mask)
        .Attr("new_axis_mask", new_axis_mask)
        .Attr("shrink_axis_mask", shrink_axis_mask);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4532 
TEST_F(OpTest, StridedSliceGrad) {
  if (tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/201095155";
  }
  if (!tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/197140886";
  }
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);

    // Shape of the forward op's input.
    std::vector<int64_t> forward_shape = RandomDims();
    const size_t rank = forward_shape.size();

    // Draws an int64 uniformly from [-2 * extent, 2 * extent] with a fresh
    // distribution per draw. Zero is in range, so a zero stride can occur.
    auto random_coord = [this](int64_t extent) {
      return std::uniform_int_distribution<int64_t>(-2 * extent,
                                                    2 * extent)(generator());
    };

    std::vector<int64_t> begin_idx(rank);
    std::vector<int64_t> end_idx(rank);
    std::vector<int64_t> stride(rank);
    for (size_t axis = 0; axis < rank; ++axis) {
      begin_idx[axis] = random_coord(forward_shape[axis]);
      end_idx[axis] = random_coord(forward_shape[axis]);
      stride[axis] = random_coord(forward_shape[axis]);
    }

    // All masks are drawn from [0, 2^rank - 1].
    const int64_t all_bits = (1LL << rank) - 1;
    std::uniform_int_distribution<int64_t> mask_dist(0, all_bits);
    const int64_t begin_mask = mask_dist(generator());
    const int64_t end_mask = mask_dist(generator());

    // The ellipsis mask has either no bit or exactly one bit set.
    int64_t ellipsis_mask = 0;
    if (rank != 0 && std::bernoulli_distribution()(generator())) {
      const int bit =
          std::uniform_int_distribution<int>(0, rank - 1)(generator());
      ellipsis_mask = 1LL << bit;
    }

    const int64_t new_axis_mask = mask_dist(generator());
    const int64_t shrink_axis_mask = mask_dist(generator());

    // TODO(phawkins): use shape inference for the forward op to compute the
    // gradient shape for the backward op. At present, there is a low
    // probability of the golden op succeeding.
    const std::vector<int64_t> grad_shape = RandomDims(1);

    OpTestBuilder builder("StridedSliceGrad");
    builder.Input(test::AsTensor<int64_t>(forward_shape))
        .Input(test::AsTensor<int64_t>(begin_idx))
        .Input(test::AsTensor<int64_t>(end_idx))
        .Input(test::AsTensor<int64_t>(stride))
        .RandomInput(dtype, grad_shape)
        .Attr("T", dtype)
        .Attr("Index", DT_INT64)
        .Attr("begin_mask", begin_mask)
        .Attr("end_mask", end_mask)
        .Attr("ellipsis_mask", ellipsis_mask)
        .Attr("new_axis_mask", new_axis_mask)
        .Attr("shrink_axis_mask", shrink_axis_mask);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4587 
TEST_F(OpTest, Tan) {
  // Tan on a random float or complex64 tensor.
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Tan");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4595 
TEST_F(OpTest, Tanh) {
  // Tanh on a random float or complex64 tensor.
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Tanh");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4603 
TEST_F(OpTest, TanhGrad) {
  // Both TanhGrad operands share a single randomly drawn shape.
  Repeatedly([this] {
    const std::vector<int64_t> shape = RandomDims();
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("TanhGrad");
    builder.RandomInput(dtype, shape)
        .RandomInput(dtype, shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4614 
TEST_F(OpTest, TensorScatterUpdate) {
  if (tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/201095155";
  }
  if (!tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/197140886";
  }
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    // Type, shape, indices and updates all come from ChooseScatterArguments.
    const auto args = ChooseScatterArguments();
    OpTestBuilder builder("TensorScatterUpdate");
    builder.RandomInput(args.type, args.shape)
        .Input(args.indices)
        .Input(args.updates)
        .Attr("T", args.type)
        .Attr("Tindices", args.indices_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4628 
TEST_F(OpTest, Tile) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> input_shape = RandomDims(1);
    // One multiple per input dimension, each drawn from [1, 3] in order.
    std::vector<int32> multiples(input_shape.size());
    std::generate(multiples.begin(), multiples.end(), [this] {
      return std::uniform_int_distribution<int>(1, 3)(generator());
    });
    OpTestBuilder builder("Tile");
    builder.RandomInput(dtype, input_shape)
        .Input(test::AsTensor<int32>(multiples))
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4644 
TEST_F(OpTest, TopKV2) {
  if (tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/201095155";
  }
  if (!tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/197140886";
  }
  Repeatedly([this]() {  // NOLINT: due to GTEST_SKIP
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_INT64});
    const std::vector<int64_t> shape = RandomDims(1);
    // k is drawn from [1, shape[0]].
    const int32 k =
        std::uniform_int_distribution<int32>(1, shape[0])(generator());
    const bool sorted = RandomBool();
    OpTestBuilder builder("TopKV2");
    builder.RandomInput(dtype, shape)
        .Input(test::AsScalar<int32>(k))
        .Attr("sorted", sorted)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4659 
TEST_F(OpTest, Transpose) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    std::vector<int64_t> input_shape = RandomDims();
    // Start from the identity permutation 0..rank-1, then shuffle it.
    std::vector<int32> permutation(input_shape.size());
    std::iota(permutation.begin(), permutation.end(), 0);
    std::shuffle(permutation.begin(), permutation.end(), generator());
    OpTestBuilder builder("Transpose");
    builder.RandomInput(dtype, input_shape)
        .Input(test::AsTensor<int32>(permutation))
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4673 
TEST_F(OpTest, TruncateDiv) {
  if (tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/201095155";
  }
  if (!tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/197140886";
  }
  // int32-only TruncateDiv over broadcast-compatible operands.
  Repeatedly([this]() {
    const DataType dtype = DT_INT32;
    auto [lhs_shape, rhs_shape] = BroadcastableDims();
    OpTestBuilder builder("TruncateDiv");
    builder.RandomInput(dtype, lhs_shape)
        .RandomInput(dtype, rhs_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4686 
TEST_F(OpTest, TruncateMod) {
  // TruncateMod for int32 or float over broadcast-compatible operands.
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    auto [lhs_shape, rhs_shape] = BroadcastableDims();
    OpTestBuilder builder("TruncateMod");
    builder.RandomInput(dtype, lhs_shape)
        .RandomInput(dtype, rhs_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4697 
TEST_F(OpTest, Unpack) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> shape = RandomDims(1);
    // Unpack along a random axis; `num` must equal that axis' size.
    const int axis =
        std::uniform_int_distribution<int>(0, shape.size() - 1)(generator());
    const int64_t num_outputs = shape[axis];
    OpTestBuilder builder("Unpack");
    builder.RandomInput(dtype, shape)
        .Attr("axis", axis)
        .Attr("T", dtype)
        .Attr("num", num_outputs);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4711 
TEST_F(OpTest, Xdivy) {
  // Xdivy for float or complex64 over broadcast-compatible operands.
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto [lhs_shape, rhs_shape] = BroadcastableDims();
    OpTestBuilder builder("Xdivy");
    builder.RandomInput(dtype, lhs_shape)
        .RandomInput(dtype, rhs_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4722 
TEST_F(OpTest, XlaDot) {
  // Operand shapes, dtype and encoded attributes come from
  // ChooseXlaDotArguments.
  Repeatedly([this] {
    const XlaDotArguments args = ChooseXlaDotArguments();
    OpTestBuilder builder("XlaDot");
    builder.RandomInput(args.dtype, args.lhs_dims)
        .RandomInput(args.dtype, args.rhs_dims)
        .Attr("dimension_numbers", args.dnums_encoded)
        .Attr("precision_config", args.precision_config_encoded)
        .Attr("T", args.dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4735 
TEST_F(OpTest, XlaDotV2) {
  // Same arguments as XlaDot; V2 takes separate lhs/rhs/result type attrs,
  // all set to the chosen dtype here.
  Repeatedly([this] {
    const XlaDotArguments args = ChooseXlaDotArguments();
    OpTestBuilder builder("XlaDotV2");
    builder.RandomInput(args.dtype, args.lhs_dims)
        .RandomInput(args.dtype, args.rhs_dims)
        .Attr("dimension_numbers", args.dnums_encoded)
        .Attr("precision_config", args.precision_config_encoded)
        .Attr("LhsT", args.dtype)
        .Attr("RhsT", args.dtype)
        .Attr("preferred_element_type", args.dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4750 
TEST_F(OpTest, XlaDynamicUpdateSlice) {
  // Operand, update (shaped like the slice size) and start indices come from
  // ChooseSliceArguments(false).
  Repeatedly([this] {
    const SliceArguments args = ChooseSliceArguments(false);
    OpTestBuilder builder("XlaDynamicUpdateSlice");
    builder.RandomInput(args.type, args.shape)
        .RandomInput(args.type, args.size)
        .Input(args.indices)
        .Attr("T", args.type)
        .Attr("Tindices", args.indices_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4762 
TEST_F(OpTest, XlaEinsum) {
  // Equation and operand shapes come from ChooseEinsumArguments.
  Repeatedly([this] {
    const EinsumArguments args = ChooseEinsumArguments();
    OpTestBuilder builder("XlaEinsum");
    builder.RandomInput(args.type, args.lhs_dims)
        .RandomInput(args.type, args.rhs_dims)
        .Attr("equation", args.equation)
        .Attr("T", args.type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4773 
TEST_F(OpTest, XlaSort) {
  // XlaSort on a random tensor of any XLA-supported type.
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64_t> shape = RandomDims();
    OpTestBuilder builder("XlaSort");
    builder.RandomInput(dtype, shape).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4782 
TEST_F(OpTest, Xlog1py) {
  // Xlog1py for float or complex64 over broadcast-compatible operands.
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto [lhs_shape, rhs_shape] = BroadcastableDims();
    OpTestBuilder builder("Xlog1py");
    builder.RandomInput(dtype, lhs_shape)
        .RandomInput(dtype, rhs_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4793 
TEST_F(OpTest, Xlogy) {
  // Xlogy for float or complex64 over broadcast-compatible operands.
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    auto [lhs_shape, rhs_shape] = BroadcastableDims();
    OpTestBuilder builder("Xlogy");
    builder.RandomInput(dtype, lhs_shape)
        .RandomInput(dtype, rhs_shape)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4804 
TEST_F(OpTest, ZerosLike) {
  if (tensorflow::tf_xla_test_use_mlir) {
    GTEST_SKIP() << "b/201095155";
  }
  // ZerosLike on a random int32, float, or complex64 tensor.
  Repeatedly([this]() {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("ZerosLike");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
4813 
TEST_F(OpTest, Zeta) {
  // Compares TF and XLA for Zeta(x, q) on two broadcast-compatible float
  // inputs. (This test previously built an "Xlogy" op by mistake, so Zeta
  // itself was never exercised.)
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Zeta")
                                             .RandomInput(DT_FLOAT, dims.first)
                                             .RandomInput(DT_FLOAT, dims.second)
                                             .Attr("T", DT_FLOAT));
  });
}
4823 
4824 // Example failing run:
4825 //   --tf_xla_reference_device=GPU:0
4826 //   --tf_xla_test_use_jit=true --tf_xla_test_device=GPU:0
4827 //   --tf_xla_test_use_mlir=true
4828 //   --tf_xla_test_repetitions=2
4829 //   --gunit_filter='OpTest.FusedBatchNormTraining'
4830 //   --tf_xla_random_seed=2838146746
TEST_F(OpTest,FusedBatchNormTraining)4831 TEST_F(OpTest, FusedBatchNormTraining) {
4832   if (tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/201095155";
4833   if (!tensorflow::tf_xla_test_use_mlir) GTEST_SKIP() << "b/197140886";
4834   bool is_nhwc = RandomBool();
4835   std::vector<int64_t> x_dims = RandomDims(/*min_rank=*/4, /*max_rank=*/4,
4836                                            /*min_size=*/5, /*max_size=*/20);
4837   std::vector<int64_t> scale_dims = {x_dims[is_nhwc ? 3 : 1]};
4838   std::vector<int64_t> offset_dims = {x_dims[is_nhwc ? 3 : 1]};
4839   std::vector<int64_t> mean_dims = {0};
4840   std::vector<int64_t> variance_dims = {0};
4841   DataType type = DT_FLOAT;
4842   Repeatedly([&] {
4843     return ExpectTfAndXlaOutputsAreClose(
4844         OpTestBuilder("FusedBatchNorm")
4845             .RandomInput(type, x_dims)
4846             .RandomInput(type, scale_dims)
4847             .RandomInput(type, offset_dims)
4848             .RandomInput(type, mean_dims)
4849             .RandomInput(type, variance_dims)
4850             .Attr("T", type)
4851             .Attr("data_format", is_nhwc ? "NHWC" : "NCHW")
4852             .Attr("epsilon", static_cast<float>(1.001e-05))
4853             .Attr("is_training", true));
4854   });
4855 }
4856 }  // anonymous namespace
4857 }  // namespace tensorflow
4858 
// Test entry point. Registers and parses this binary's flags, initializes
// gtest, creates every available device, checks that the requested test
// device exists, optionally enables the MLIR bridge, then runs all tests.
// Returns 2 on flag-parsing errors or leftover unknown arguments.
int main(int argc, char** argv) {
  // Default device names; the "tf_xla_test_device" / "tf_xla_reference_device"
  // flags below write into these same strings when supplied. The strings are
  // heap-allocated and not freed in this function.
  tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0");
  tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0");
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag(
          "tf_xla_random_seed", &tensorflow::tf_xla_random_seed,
          "Random seed to use for XLA tests. <= 0 means choose a seed "
          "nondeterministically."),
      // TODO(phawkins): it might make more sense to run each test up to a
      // configurable time bound.
      tensorflow::Flag("tf_xla_test_repetitions",
                       &tensorflow::tf_xla_test_repetitions,
                       "Number of repetitions for each test."),
      tensorflow::Flag("tf_xla_max_tensor_size",
                       &tensorflow::tf_xla_max_tensor_size,
                       "Maximum number of elements for random input tensors."),
      tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr,
                       "Tensorflow device type to use for test"),
      tensorflow::Flag("tf_xla_reference_device",
                       tensorflow::tf_xla_reference_device_ptr,
                       "Tensorflow device type to use for reference"),
      tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit,
                       "Use JIT compilation for the operator under test"),
      tensorflow::Flag(
          "tf_xla_test_use_mlir", &tensorflow::tf_xla_test_use_mlir,
          "Use MLIR legalization kernels for the operator under test"),
  };
  // Usage text is built before parsing so it can be printed on any error.
  tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << "\n" << usage;
    return 2;
  }
  testing::InitGoogleTest(&argc, argv);
  // Flags::Parse and InitGoogleTest are expected to strip the arguments they
  // recognize from argv, so anything left after argv[0] is treated as unknown.
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return 2;
  }
  // XLA devices register kernels at construction time; create all known devices
  // to make sure the kernels are registered.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
      tensorflow::SessionOptions(), "", &devices));
  // device_mgr owns the devices and lives until main returns, i.e. across
  // RUN_ALL_TESTS() below.
  tensorflow::StaticDeviceMgr device_mgr(std::move(devices));

  // Abort early if the requested test device was not among those created.
  tensorflow::Device* ignored;
  TF_QCHECK_OK(
      device_mgr.LookupDevice(*tensorflow::tf_xla_test_device_ptr, &ignored))
      << "Unknown test device (" << *tensorflow::tf_xla_test_device_ptr
      << "). Did you build in the right configuration (e.g., is CUDA enabled)?";

  // With --tf_xla_test_use_mlir, turn on the MLIR bridge before any test runs.
  if (tensorflow::tf_xla_test_use_mlir)
    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
        tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
  return RUN_ALL_TESTS();
}
4915