xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/dot_operation_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array3d.h"
22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/reference_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_parser.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/tests/test_utils.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36 
37 namespace xla {
38 namespace {
39 
40 class DotOperationTest : public ClientLibraryTestBase {
41  public:
42   ErrorSpec error_spec_{0.0001, 1e-5};
43 };
44 
45 using TypesF16F32 = ::testing::Types<
46 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
47     Eigen::half,
48 #endif
49     float>;
50 
51 using TypesF16F32F64 = ::testing::Types<
52 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
53     Eigen::half,
54 #endif
55 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
56     double,
57 #endif
58     float>;
59 
60 using TypesF16F32F64CF64 = ::testing::Types<
61 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
62     Eigen::half,
63 #endif
64 #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
65     double, complex64,
66 #endif
67     float>;
68 
69 // Check that we can safely pass an input tuple's elements to a dot operation.
XLA_TEST_F(DotOperationTest,DotOfInputTupleElem)70 XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
71   XlaBuilder builder(TestName());
72 
73   XlaOp param;
74   TF_ASSERT_OK_AND_ASSIGN(
75       auto param_data,
76       CreateParameterAndTransferLiteral(
77           0,
78           LiteralUtil::MakeTupleFromSlices(
79               {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
80                LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
81           "arg0", &builder, &param));
82   auto lhs = GetTupleElement(param, 0);
83   auto rhs = GetTupleElement(param, 1);
84   Dot(lhs, rhs);
85 
86   ComputeAndCompareLiteral(&builder,
87                            LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
88                            {param_data.get()});
89 }
90 
91 template <typename T>
92 class DotOperationTest_F16F32F64CF64 : public DotOperationTest {};
93 TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);
94 
// Dot of two empty vectors reduces over zero elements and yields scalar 0.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  auto lhs = ConstantR1<T>(&builder, {});
  auto rhs = ConstantR1<T>(&builder, {});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(0.0), {},
                                        this->error_spec_);
}
106 
107 template <typename T>
108 class DotOperationTest_F16F32F64 : public DotOperationTest {};
109 TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64);
110 
// 1x2 matrix times length-2 vector: [3,4]·[3,4] = 25, as a rank-1 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, {{3.0f, 4.0f}});
  auto rhs = ConstantFromArray<T>(&builder, {3.0f, 4.0f});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {},
                                        this->error_spec_);
}
121 
// Dot of two single-element vectors: 2 * 3 = 6.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR1<T>(&builder, {static_cast<T>(2.0f)});
  auto rhs = ConstantR1<T>(&builder, {static_cast<T>(3.0f)});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(6.0f), {},
                                        this->error_spec_);
}
132 
// General vector dot: 1*11 + 2.5*(-1) + 42*0.5 = 29.5.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantFromArray<T>(&builder, {1.0f, 2.5f, 42.0f});
  auto rhs = ConstantFromArray<T>(&builder, {11.0f, -1.0f, 0.5f});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(29.5f), {},
                                        this->error_spec_);
}
143 
// Returns the minor-to-major layout ordering for a rank-2 array:
// {1, 0} for row-major (dimension 1 varies fastest), {0, 1} for
// column-major.
std::vector<int64_t> MinorToMajorForIsRowMajor(bool row_major) {
  return {row_major ? 1 : 0, row_major ? 0 : 1};
}
147 
// Degenerate shapes: (0x2) x (2x0) produces an empty 0x0 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {},
                                        this->error_spec_);
}
158 
// Degenerate lhs: (0x2) x (2x3) produces an empty 0x3 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  auto rhs = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {},
                                        this->error_spec_);
}
170 
// Degenerate rhs: (3x2) x (2x0) produces an empty 3x0 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}});
  auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {},
                                        this->error_spec_);
}
182 
// Zero-length contraction dimension: (2x0) x (0x2) yields a 2x2 of zeros.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(
      &builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_);
}
193 
// Dot whose lhs operand is itself a computed value (Exp), exercising fusion
// of an elementwise op into the dot's input.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto param0 =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 4}), "arg0");
  auto param1 =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1");
  auto exp0 = Exp(param0);
  Dot(exp0, param1);

  auto lhs_handle =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
              {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
          .value();
  auto rhs_handle = this->client_
                        ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                            {{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
                        .value();

  // Half precision cannot hit the default relative tolerance; widen it.
  if (std::is_same<Eigen::half, T>::value) {
    this->error_spec_ = ErrorSpec{0.0001, 1e-3};
  }

  this->template ComputeAndCompareR2<T>(
      &builder, Array2D<T>({{296.14560492846033f}, {0.8611737683031964f}}),
      {lhs_handle.get(), rhs_handle.get()}, this->error_spec_);
}
222 
223 template <typename T>
224 class SquareMatrixDot : public DotOperationTest {
225  public:
TestImpl(bool lhs_row_major,bool rhs_row_major)226   void TestImpl(bool lhs_row_major, bool rhs_row_major) {
227     auto lhs_handle =
228         client_
229             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
230                 {{1.0f, 2.0f}, {3.0f, -4.0f}},
231                 LayoutUtil::MakeLayout(
232                     MinorToMajorForIsRowMajor(lhs_row_major))))
233             .value();
234     auto rhs_handle =
235         client_
236             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
237                 {{1.0f, 6.0f}, {7.0f, -4.0f}},
238                 LayoutUtil::MakeLayout(
239                     MinorToMajorForIsRowMajor(rhs_row_major))))
240             .value();
241     XlaBuilder builder(TestName());
242     auto prim_type = primitive_util::NativeToPrimitiveType<T>();
243     Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
244         Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
245 
246     Array2D<T> expected({{15.0f, -2.0f}, {-25.0f, 34.0f}});
247     ComputeAndCompareR2<T>(&builder, expected,
248                            {lhs_handle.get(), rhs_handle.get()}, error_spec_);
249   }
250 };
251 
// Cover all four combinations of operand layouts (F = column-major,
// T = row-major).
TYPED_TEST_CASE(SquareMatrixDot, TypesF16F32F64CF64);
XLA_TYPED_TEST(SquareMatrixDot, TypesFF) { this->TestImpl(false, false); }
XLA_TYPED_TEST(SquareMatrixDot, TypesFT) { this->TestImpl(false, true); }
XLA_TYPED_TEST(SquareMatrixDot, TypesTF) { this->TestImpl(true, false); }
XLA_TYPED_TEST(SquareMatrixDot, TypesTT) { this->TestImpl(true, true); }
257 
// Parameters for a single (m x k) . (k x n) dot test case, including the
// layout of each operand and an optional addend fused after the dot.
struct DotTestParam {
  int m;
  int k;
  int n;
  bool dot_lhs_row_major;
  bool dot_rhs_row_major;
  bool has_addend;
  bool addend_row_major;  // Only meaningful when has_addend is true.
};
267 
PrintDotTestParam(const::testing::TestParamInfo<DotTestParam> & test_param)268 std::string PrintDotTestParam(
269     const ::testing::TestParamInfo<DotTestParam>& test_param) {
270   const DotTestParam& param = test_param.param;
271   if (param.has_addend) {
272     return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
273                         param.dot_lhs_row_major ? "T" : "F",
274                         param.dot_rhs_row_major ? "T" : "F",
275                         param.addend_row_major ? "T" : "F");
276   } else {
277     return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
278                         param.dot_lhs_row_major ? "T" : "F",
279                         param.dot_rhs_row_major ? "T" : "F");
280   }
281 }
282 
283 class ParametricDotTest : public DotOperationTest,
284                           public ::testing::WithParamInterface<DotTestParam> {
285  protected:
286   template <typename NativeT>
287   void TestImpl();
288 
289   template <typename NativeT>
290   void ComputeAndCompareR2WithError(XlaBuilder* builder,
291                                     const Array2D<NativeT>& expected,
292                                     absl::Span<GlobalData* const> arguments);
293 };
294 
295 template <typename NativeT>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<NativeT> & expected,absl::Span<GlobalData * const> arguments)296 void ParametricDotTest::ComputeAndCompareR2WithError(
297     XlaBuilder* builder, const Array2D<NativeT>& expected,
298     absl::Span<GlobalData* const> arguments) {
299   ErrorSpec error_spec(0.3, 3e-3);
300   ComputeAndCompareR2(builder, expected, arguments, error_spec);
301 }
302 
303 template <>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<Eigen::half> & expected,absl::Span<GlobalData * const> arguments)304 void ParametricDotTest::ComputeAndCompareR2WithError<Eigen::half>(
305     XlaBuilder* builder, const Array2D<Eigen::half>& expected,
306     absl::Span<GlobalData* const> arguments) {
307   ErrorSpec error_spec(0.3, 7e-3);
308   ComputeAndCompareR2(builder, expected, arguments, error_spec);
309 }
310 
311 template <>
ComputeAndCompareR2WithError(XlaBuilder * builder,const Array2D<int32_t> & expected,absl::Span<GlobalData * const> arguments)312 void ParametricDotTest::ComputeAndCompareR2WithError<int32_t>(
313     XlaBuilder* builder, const Array2D<int32_t>& expected,
314     absl::Span<GlobalData* const> arguments) {
315   ComputeAndCompareR2(builder, expected, arguments);
316 }
317 
318 template <typename NativeT>
TestImpl()319 void ParametricDotTest::TestImpl() {
320   DotTestParam param = GetParam();
321 
322   std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
323       MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
324   Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
325       *dot_lhs_data, LayoutUtil::MakeLayout(
326                          MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
327   std::unique_ptr<GlobalData> dot_lhs_handle =
328       client_->TransferToServer(dot_lhs_lit).value();
329 
330   std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
331       MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
332   Layout rhs_layout = LayoutUtil::MakeLayout(
333       MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
334   Literal dot_rhs_lit =
335       LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
336   std::unique_ptr<GlobalData> dot_rhs_handle =
337       client_->TransferToServer(dot_rhs_lit).value();
338 
339   std::unique_ptr<Array2D<NativeT>> addend_data;
340   Literal addend_lit;
341   std::unique_ptr<GlobalData> addend_handle;
342 
343   if (param.has_addend) {
344     addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n);
345     addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
346         *addend_data, LayoutUtil::MakeLayout(
347                           MinorToMajorForIsRowMajor(param.addend_row_major)));
348     addend_handle = client_->TransferToServer(addend_lit).value();
349   }
350 
351   XlaBuilder builder(TestName());
352   auto prim_type = primitive_util::NativeToPrimitiveType<NativeT>();
353   auto result =
354       Dot(Parameter(&builder, 0,
355                     ShapeUtil::MakeShapeWithLayout(
356                         prim_type, {param.m, param.k},
357                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)),
358                     "dot_lhs"),
359           Parameter(&builder, 1,
360                     ShapeUtil::MakeShapeWithLayout(
361                         prim_type, {param.k, param.n},
362                         MinorToMajorForIsRowMajor(param.dot_rhs_row_major)),
363                     "dot_rhs"));
364 
365   if (param.has_addend) {
366     result =
367         Add(result,
368             Parameter(&builder, 2,
369                       ShapeUtil::MakeShapeWithLayout(
370                           prim_type, {param.m, param.n},
371                           MinorToMajorForIsRowMajor(param.addend_row_major)),
372                       "addend"));
373   }
374 
375   std::unique_ptr<Array2D<NativeT>> expected;
376   if (param.has_addend) {
377     expected = ReferenceUtil::ApplyElementwise2D(
378         std::plus<NativeT>(),
379         *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data),
380         *addend_data);
381   } else {
382     expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data);
383   }
384 
385   std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()};
386   if (param.has_addend) {
387     args.push_back(addend_handle.get());
388   }
389   ComputeAndCompareR2WithError<NativeT>(&builder, *expected, args);
390 }
391 
CreateDotTestParameters()392 std::vector<DotTestParam> CreateDotTestParameters() {
393   std::vector<DotTestParam> params;
394 
395   auto add_matrix_matrix_dot_test = [&](int m, int k, int n) {
396     for (bool lhs_row_major : {true, false}) {
397       for (bool rhs_row_major : {true, false}) {
398         params.push_back({/*m=*/m, /*k=*/k, /*n=*/n,
399                           /*dot_lhs_row_major=*/lhs_row_major,
400                           /*dot_rhs_row_major=*/rhs_row_major,
401                           /*has_addend=*/false, /*addend_row_major=*/true});
402       }
403     }
404   };
405 
406   add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
407   add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
408   add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
409   add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
410   add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
411   add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
412   add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
413   add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
414 
415   return params;
416 }
417 
// Instantiate the parametric dot tests for each supported native type.
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl<Eigen::half>(); }
#endif
XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl<float>(); }
XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl<double>(); }
XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl<std::complex<float>>(); }
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128
XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl<std::complex<double>>(); }
#endif
XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl<int32_t>(); }

INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,
                        ::testing::ValuesIn(CreateDotTestParameters()),
                        PrintDotTestParam);
432 
433 class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
434  public:
ParametricDotTestWithoutLayoutAssignment()435   ParametricDotTestWithoutLayoutAssignment() {
436     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
437         "layout-assignment");
438     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
439         "hlo-verifier");
440     // Disable algebraic simplification because the pass may replace a dot
441     // instruction with a layout-changing multiplication instruction.
442     execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
443         "algsimp");
444   }
445 };
446 
CreateNoLayoutAssignmentDotTestParameters()447 std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() {
448   std::vector<DotTestParam> params;
449 
450   auto add_matrix_vector_dot_test = [&](int k, int n) {
451     for (bool lhs_row_major : {true, false}) {
452       for (bool rhs_row_major : {true, false}) {
453         for (bool has_addend : {true, false}) {
454           // The addend needs to be row major to match the result of the dot.
455           params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
456                             /*dot_lhs_row_major=*/lhs_row_major,
457                             /*dot_rhs_row_major=*/rhs_row_major,
458                             /*has_addend=*/has_addend,
459                             /*addend_row_major=*/true});
460           if (n != 1) {
461             params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
462                               /*dot_lhs_row_major=*/lhs_row_major,
463                               /*dot_rhs_row_major=*/rhs_row_major,
464                               /*has_addend=*/has_addend,
465                               /*addend_row_major=*/true});
466           }
467         }
468       }
469     }
470   };
471 
472   add_matrix_vector_dot_test(/*k=*/8, /*n=*/8);
473   add_matrix_vector_dot_test(/*k=*/130, /*n=*/8);
474   add_matrix_vector_dot_test(/*k=*/8, /*n=*/130);
475   add_matrix_vector_dot_test(/*k=*/290, /*n=*/130);
476   add_matrix_vector_dot_test(/*k=*/1, /*n=*/1);
477   add_matrix_vector_dot_test(/*k=*/1, /*n=*/16);
478   add_matrix_vector_dot_test(/*k=*/1, /*n=*/4);
479   add_matrix_vector_dot_test(/*k=*/1, /*n=*/3);
480   add_matrix_vector_dot_test(/*k=*/3, /*n=*/16);
481   add_matrix_vector_dot_test(/*k=*/3, /*n=*/3);
482   add_matrix_vector_dot_test(/*k=*/29, /*n=*/29);
483   add_matrix_vector_dot_test(/*k=*/8, /*n=*/2);
484   add_matrix_vector_dot_test(/*k=*/2, /*n=*/8);
485   add_matrix_vector_dot_test(/*k=*/259, /*n=*/258);
486 
487   return params;
488 }
489 
490 #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment,TestF16)491 XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) {
492   TestImpl<Eigen::half>();
493 }
494 #endif
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment,TestF32)495 XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) {
496   TestImpl<float>();
497 }
498 // TODO(b/147505663): Disabled for now.
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment,DISABLED_TestF64)499 XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, DISABLED_TestF64) {
500   TestImpl<double>();
501 }
502 
503 INSTANTIATE_TEST_CASE_P(
504     DotTests, ParametricDotTestWithoutLayoutAssignment,
505     ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()),
506     PrintDotTestParam);
507 
508 template <typename T>
509 class NonsquareMatrixDot : public DotOperationTest {
510  public:
TestImpl(bool lhs_row_major,bool rhs_row_major)511   void TestImpl(bool lhs_row_major, bool rhs_row_major) {
512     auto lhs_handle =
513         client_
514             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
515                 {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
516                 LayoutUtil::MakeLayout(
517                     MinorToMajorForIsRowMajor(lhs_row_major))))
518             .value();
519     auto rhs_handle =
520         client_
521             ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
522                 {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
523                 LayoutUtil::MakeLayout(
524                     MinorToMajorForIsRowMajor(rhs_row_major))))
525             .value();
526 
527     XlaBuilder builder(TestName());
528     auto prim_type = primitive_util::NativeToPrimitiveType<T>();
529     Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
530         Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
531 
532     Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
533 
534     ComputeAndCompareR2<T>(&builder, expected,
535                            {lhs_handle.get(), rhs_handle.get()}, error_spec_);
536   }
537 };
538 
// Cover all four combinations of operand layouts (F = column-major,
// T = row-major).
TYPED_TEST_CASE(NonsquareMatrixDot, TypesF16F32F64CF64);
XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) { this->TestImpl(false, false); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) { this->TestImpl(false, true); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) { this->TestImpl(true, false); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
544 
// (1x4) . (4x2) dot with complex64 operands.
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
  auto lhs_handle =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
          .value();
  auto rhs_handle =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
              LayoutUtil::MakeLayout({1, 0})))
          .value();

  XlaBuilder builder(TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
  Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
      Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));

  Array2D<complex64> expected({{30.0, -2.0}});

  ComputeAndCompareR2<complex64>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}
568 
// Two independent dots (A.B and B.A) whose results are summed; exercises
// concurrent execution of independent matmuls in one computation.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto matrix1 =
      ConstantR2FromArray2D<T>(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto matrix2 =
      ConstantR2FromArray2D<T>(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto matrix12 = Dot(matrix1, matrix2);
  auto matrix21 = Dot(matrix2, matrix1);
  Add(matrix12, matrix21);

  Array2D<T> expected({{42.0f, 56.0f}, {74.0f, 96.0f}});
  this->template ComputeAndCompareR2<T>(&builder, expected, {},
                                        this->error_spec_);
}
585 
586 template <typename T>
587 class DotOperationTestForBatchMatMul : public DotOperationTest {};
588 TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64);
589 
590 // Regression test for b/32055648. The root of the graph is a kFusion of 4
591 // bitcasts. Although bitcasts don't map to thunks, the root should still be
592 // sync-dependent on bitcasts' operands.
XLA_TYPED_TEST(DotOperationTestForBatchMatMul,Types)593 XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
594   using T = TypeParam;
595   XlaBuilder builder(this->TestName());
596   auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
597                      "x");
598   auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
599                      "y");
600 
601   auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
602   auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2});
603 
604   // Slice batches into individual matrices and multiply them.
605   std::vector<XlaOp> out_slices;
606   const auto n = 4;
607   out_slices.reserve(n);
608   for (int i = 0; i < n; ++i) {
609     // Slice off individual matrices and reshape to 2D tensors.
610     auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
611     x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2});
612     auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
613     y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2});
614 
615     auto out = Dot(x_slice, y_slice);
616     out = Reshape(out, {0, 1}, {1, 2, 2});
617     out_slices.push_back(out);
618   }
619   auto out_flat = ConcatInDim(&builder, out_slices, 0);
620   Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
621 
622   auto x_data = this->client_
623                     ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
624                         {{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
625                           {{2000.0f, 200.0f}, {20.0f, 2.0f}}},
626                          {{{3000.0f, 300.0f}, {30.0f, 3.0f}},
627                           {{4000.0f, 400.0f}, {40.0f, 4.0f}}}}))
628                     .value();
629   auto y_data =
630       this->client_
631           ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
632               {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
633                {{{11.0f, 22.0f}, {33.0f, 44.0f}},
634                 {{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
635           .value();
636 
637   if (std::is_same<Eigen::half, T>::value) {
638     this->error_spec_ = ErrorSpec{0.0001, 1e-3};
639   }
640   this->template ComputeAndCompareR4<T>(
641       &builder,
642       /*expected=*/
643       {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}},
644         {{11400.0f, 13600.0f}, {114.0f, 136.0f}}},
645        {{{42900.0f, 79200.0f}, {429.0f, 792.0f}},
646         {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}},
647       {x_data.get(), y_data.get()}, this->error_spec_);
648 }
649 
// DotGeneral with one batch dimension: batched (2x2).(2x2); the rhs is a
// pair of identity matrices, so the result equals the lhs.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  auto y =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  // Contract lhs dim 2 with rhs dim 1; dim 0 is the batch dimension on both.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .value();

  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
          .value();

  this->template ComputeAndCompareR3<T>(
      &builder,
      /*expected=*/
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
685 
// DotGeneral with a rank-3 lhs and rank-2 rhs sharing a batch dimension:
// dim 0 is batched on both sides, so each lhs batch contracts with one row
// of the rhs.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2}), "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .value();

  auto y_data = this->client_
                    ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                        {{1.0f, 0.0f}, {0.0f, 1.0f}}))
                    .value();

  this->template ComputeAndCompareR2<T>(
      &builder,
      /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()},
      this->error_spec_);
}
718 
// Batch matmul with a rank-2 LHS and a rank-3 RHS: dimension 0 is batched and
// dimension 1 contracted on both operands; the result is rank 2 (checked via
// ComputeAndCompareR2 below).
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto lhs =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "x");
  auto rhs =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  // Batch dimension 0 and contracting dimension 1 on both sides.
  DotDimensionNumbers dims;
  dims.add_lhs_batch_dimensions(0);
  dims.add_rhs_batch_dimensions(0);
  dims.add_lhs_contracting_dimensions(1);
  dims.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dims);

  // Identity matrix on the LHS.
  auto lhs_data = this->client_
                      ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                          {{1.0f, 0.0f}, {0.0f, 1.0f}}))
                      .value();
  auto rhs_data = this->client_
                      ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
                          {{{1.0f, 2.0f}, {3.0f, 4.0f}},
                           {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
                      .value();

  this->template ComputeAndCompareR2<T>(
      &builder, /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}},
      {lhs_data.get(), rhs_data.get()}, this->error_spec_);
}
751 
// Batch matmul with two batch dimensions: both operands are rank 4 with
// dimensions 0 and 1 as batch dimensions, and the trailing 2x2 matrices are
// multiplied (lhs dim 3 contracts with rhs dim 2).
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "y");

  // Standard matrix multiply on the innermost two dimensions, batched over the
  // leading two dimensions.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_rhs_contracting_dimensions(2);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_lhs_batch_dimensions(1);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
               {{{9.0f, 10.0f}, {11.0f, 12.0f}},
                {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
          .value();

  // First outer batch multiplies by the identity; the second by the
  // column-swap permutation matrix, which is why the expected second half has
  // its columns exchanged.
  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
               {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
          .value();

  this->template ComputeAndCompareR4<T>(
      &builder,
      /*expected=*/
      {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
       {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
793 
// Exercises Dot over all combinations of transposed/untransposed operands and
// row/column-major host layouts. The host arrays are pre-transposed whenever a
// Transpose op is built, so that every one of the eight configurations
// computes the same product and can be checked against one expected value.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
  using T = TypeParam;
  for (bool transpose_lhs : {false, true}) {
    for (bool transpose_rhs : {false, true}) {
      for (bool row_major : {false, true}) {
        std::unique_ptr<Array2D<T>> lhs(
            new Array2D<T>({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}));
        std::unique_ptr<Array2D<T>> rhs(
            new Array2D<T>({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}));

        // Pre-transpose the host data so the Transpose built below undoes it.
        if (transpose_lhs) {
          lhs = ReferenceUtil::TransposeArray2D(*lhs);
        }
        if (transpose_rhs) {
          rhs = ReferenceUtil::TransposeArray2D(*rhs);
        }
        // Transfer with the layout under test (row- or column-major).
        auto lhs_handle =
            this->client_
                ->TransferToServer(
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                        *lhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .value();
        auto rhs_handle =
            this->client_
                ->TransferToServer(
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                        *rhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .value();

        XlaBuilder builder(this->TestName());
        auto prim_type = primitive_util::NativeToPrimitiveType<T>();
        // Parameter shapes follow the (possibly transposed) host arrays.
        auto lhs_arg = Parameter(
            &builder, 0,
            ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
            "lhs");
        auto rhs_arg = Parameter(
            &builder, 1,
            ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
            "rhs");
        if (transpose_lhs) {
          lhs_arg = Transpose(lhs_arg, {1, 0});
        }
        if (transpose_rhs) {
          rhs_arg = Transpose(rhs_arg, {1, 0});
        }
        Dot(lhs_arg, rhs_arg);

        // Same expected product in all eight configurations.
        Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
        VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
                << transpose_rhs << " " << row_major;
        this->template ComputeAndCompareR2<T>(
            &builder, expected, {lhs_handle.get(), rhs_handle.get()},
            this->error_spec_);
      }
    }
  }
}
853 
// Dot(const LHS, Concat(params)): exercises the optimization that decomposes a
// dot whose non-constant operand is a concatenation.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstLHS) {
  using T = TypeParam;
  auto prim_type = primitive_util::NativeToPrimitiveType<T>();

  // These fixtures only live for the test body, so build them on the stack
  // instead of heap-allocating with raw `new` (same pattern as
  // DotRank2AndRank2NonDefaultContractionDims in this file).
  Array2D<T> constant_lhs_array({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
                                 {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}});

  XlaBuilder builder(this->TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  // RHS: three parameters concatenated along dimension 0 into a 6x2 operand,
  // matching the 2x6 constant LHS.
  auto rhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0");
  auto rhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1");
  auto rhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2");
  Dot(lhs_constant,
      ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));

  Array2D<T> arg_0_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  Array2D<T> arg_1_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
  Array2D<T> arg_2_value_array({{1.0f, 2.0f}});

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_2_value_array)));

  Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
899 
// Dot(Concat(params), const RHS): exercises the optimization that decomposes a
// dot whose non-constant operand is a concatenation.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstRHS) {
  using T = TypeParam;

  // Short-lived fixture: stack allocation instead of raw `new` into a
  // unique_ptr (same pattern as DotRank2AndRank2NonDefaultContractionDims in
  // this file).
  Array2D<T> constant_rhs_array({{1.0f, 2.0f},
                                 {3.0f, 4.0f},
                                 {5.0f, 6.0f},
                                 {6.0f, 5.0f},
                                 {4.0f, 3.0f},
                                 {2.0f, 1.0f}});

  XlaBuilder builder(this->TestName());
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  // LHS: three parameters concatenated along dimension 1 into a 2x6 operand,
  // matching the 6x2 constant RHS.
  auto lhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "lhs_arg_0");
  auto lhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 3}), "lhs_arg_1");
  auto lhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShapeWithType<T>({2, 1}), "lhs_arg_2");
  Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1),
      rhs_constant);

  Array2D<T> arg_0_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  Array2D<T> arg_1_value_array({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
  Array2D<T> arg_2_value_array({{1.0f}, {2.0f}});

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_2_value_array)));

  Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
948 
// Dot(DynamicSlice(const LHS), const RHS) with classic lhs[1]/rhs[0]
// contraction: slicing row 1 out of the LHS and multiplying selects row 1 of
// the full product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
  // Stack-allocated fixtures: no need for raw `new` into unique_ptr for data
  // that only lives for the test body.
  Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto one = ConstantR0<int32_t>(&builder, 1);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  // Slice out row 1 of the LHS (a 1x6 slab).
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{96.0, 105.0, 114.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
976 
// Dot(const LHS, DynamicSlice(const RHS)) with classic lhs[1]/rhs[0]
// contraction: slicing column 1 out of the RHS and multiplying selects column
// 1 of the full product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
  // Stack-allocated fixtures: no need for raw `new` into unique_ptr for data
  // that only lives for the test body.
  Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out column 1 of the RHS (a 6x1 slab).
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{105.0}, {105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1004 
// "Reverse" matmul form: contraction is lhs[0]/rhs[1]. A 6x1 slice (column 1)
// of the constant LHS dotted against the RHS selects one row of the full
// product.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSReverseMM) {
  // Stack-allocated fixtures: no need for raw `new` into unique_ptr for data
  // that only lives for the test body.
  Array2D<float> constant_lhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out column 1 of the LHS (a 6x1 slab).
  auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{105.0, 105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1034 
// "Reverse" matmul form with the slice taken from the constant RHS: row 1 of
// the RHS is sliced out and contracted via lhs[0]/rhs[1].
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
  std::unique_ptr<Array2D<float>> lhs_data(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> rhs_data(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2FromArray2D(&builder, *lhs_data);
  auto rhs = ConstantR2FromArray2D(&builder, *rhs_data);
  auto idx0 = ConstantR0<int32_t>(&builder, 0);
  auto idx1 = ConstantR0<int32_t>(&builder, 1);
  auto rhs_row = DynamicSlice(rhs, {idx1, idx0}, {1, 6});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs_row, dnums);

  ComputeAndCompareR2<float>(&builder,
                             Array2D<float>({{96.0}, {105.0}, {114.0}}), {},
                             error_spec_);
}
1062 
// "Rows" variant: one column of the constant LHS is sliced out and contracted
// against the RHS along dimension 0 on both sides.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
  std::unique_ptr<Array2D<float>> lhs_data(
      new Array2D<float>({{1.0, 2.0},
                          {3.0, 4.0},
                          {5.0, 6.0},
                          {6.0, 5.0},
                          {4.0, 3.0},
                          {2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> rhs_data(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2FromArray2D(&builder, *lhs_data);
  auto rhs = ConstantR2FromArray2D(&builder, *rhs_data);
  auto idx0 = ConstantR0<int32_t>(&builder, 0);
  auto idx1 = ConstantR0<int32_t>(&builder, 1);
  auto lhs_column = DynamicSlice(lhs, {idx0, idx1}, {6, 1});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_column, rhs, dnums);

  ComputeAndCompareR2<float>(&builder, Array2D<float>({{126.0, 129.0, 132.0}}),
                             {}, error_spec_);
}
1095 
// "Rows" variant with the slice taken from the constant RHS: column 1 of the
// RHS is sliced out and contracted along dimension 0 on both sides.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(
      new Array2D<float>({{1.0, 2.0},
                          {3.0, 4.0},
                          {5.0, 6.0},
                          {6.0, 5.0},
                          {4.0, 3.0},
                          {2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0},
                          {4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0},
                          {9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0},
                          {3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out column 1 of the RHS (a 6x1 slab).
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{129.0}, {129.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1128 
// "Cols" variant: row 1 of the constant LHS is sliced out and contracted with
// the RHS along dimension 1 on both sides.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
  std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> constant_rhs_array(
      new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out row 1 of the LHS (a 1x6 slab).
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{56.0, 168.0, 91.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1153 
// "Cols" variant with the slice taken from the constant RHS: row 1 of the RHS
// is sliced out and contracted with the LHS along dimension 1 on both sides.
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
  std::unique_ptr<Array2D<float>> lhs_data(new Array2D<float>(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  std::unique_ptr<Array2D<float>> rhs_data(
      new Array2D<float>({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                          {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                          {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2FromArray2D(&builder, *lhs_data);
  auto rhs = ConstantR2FromArray2D(&builder, *rhs_data);
  auto idx0 = ConstantR0<int32_t>(&builder, 0);
  auto idx1 = ConstantR0<int32_t>(&builder, 1);
  auto rhs_row = DynamicSlice(rhs, {idx1, idx0}, {1, 6});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs_row, dnums);

  ComputeAndCompareR2<float>(&builder, Array2D<float>({{168.0}, {168.0}}), {},
                             error_spec_);
}
1178 
// Dot with non-default contraction: contracting dimension 0 of both operands
// computes result[i][j] = sum_k lhs[k][i] * rhs[k][j].
XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
  XlaBuilder builder(TestName());

  Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);

  Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);

  // (Removed an unused local `Shape shape` that was never referenced.)
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, rhs_constant, dot_dnums);

  Array2D<float> expected({
      {26.f, 30.f},
      {38.f, 44.f},
  });

  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1201 
// Parameter tuple: (x dimensions, y dimensions, einsum config string).
using EinsumParamType =
    std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::string>;
// Value-parameterized fixture for running Einsum over many shape/config
// combinations (see GetEinsumTestCases below).
class EinsumTest : public DotOperationTest,
                   public ::testing::WithParamInterface<EinsumParamType> {};
// Builds an Einsum from the parameterized (x shape, y shape, config) tuple on
// fake input data and compares against the reference result.
XLA_TEST_P(EinsumTest, SimpleEinsumTest) {
  XlaBuilder builder(TestName());
  // Use value() rather than the deprecated ValueOrDie() alias, matching the
  // rest of this file.
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .value(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .value(),
      &builder);
  auto config = std::get<2>(GetParam());
  // A config without a comma is a unary einsum; the second operand is unused.
  if (config.find(',') == config.npos) {
    Einsum(x, config);
  } else {
    Einsum(x, y, config);
  }
  ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
}
1224 
// Test cases as (x shape, y shape, einsum config) triples. Configs containing
// "->" name their output dimensions explicitly; the second group omits "->"
// and relies on the implicit-output convention. "..." denotes broadcasting
// batch dimensions.
std::vector<EinsumParamType> GetEinsumTestCases() {
  using v = std::vector<int64_t>;
  using p = EinsumParamType;
  std::vector<p> test_cases = {
      // Explicit-output configs.
      p{v{5, 6}, v{6, 7}, "mk,kn->mn"},
      p{v{5, 6}, v{6, 7}, "mk,kn->nm"},
      p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"},
      p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"},
      p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"},
      p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"},
      p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
      p{v{6}, v{6, 7}, "b,bc->c"},
      p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"},
      p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"},
      p{v{77}, v{77}, "a,a->a"},
      p{v{77}, v{77, 55}, "a,ab->ba"},
      p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
      p{v{55}, v{}, "a,->a"},
      p{v{11, 111}, v{11}, "ab,a->ab"},
      p{v{16, 34}, v{16, 34}, "ab,ab->ab"},
      p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"},
      p{v{5, 19}, v{}, "ab,->ab"},
      p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf->bhqk"},
      p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
      p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
      p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
      p{v{6, 6}, v{7, 7}, "mm,nn->mn"},
      p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
      p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
      p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
      p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
      p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
      // Implicit-output configs (no "->").
      p{v{5, 6}, v{6, 7}, "mk,kn"},
      p{v{5, 6}, v{6, 7}, "mk,kn"},
      p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"},
      p{v{5, 6}, v{6, 7}, "ab,cd"},
      p{v{6}, v{6, 7}, "b,bc"},
      p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"},
      p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"},
      p{v{77}, v{77}, "a,a"},
      p{v{77}, v{77, 55}, "a,ab"},
      p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"},
      p{v{55}, v{}, "a"},
      p{v{11, 111}, v{11}, "ab,a"},
      p{v{16, 34}, v{16, 34}, "ab,ab"},
      p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"},
      p{v{5, 19}, v{}, "ab"},
      p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"},
      p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"},
      p{v{5, 6}, v{}, "...mk"},
      p{v{5, 6, 12, 13}, v{}, "...mk"},
      p{v{5, 6, 12, 13}, v{}, "m...k"},
      p{v{5, 6, 12, 13}, v{}, "mk..."},
      p{v{5, 6}, v{6, 7}, "...mk->km..."},
      p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
      p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
      p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"},
      p{v{16, 16, 16}, v{}, "iii"},
      p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."},
  };
  return test_cases;
}
1287 
// Instantiates SimpleEinsumTest once per entry in GetEinsumTestCases().
INSTANTIATE_TEST_SUITE_P(Einsum, EinsumTest,
                         ::testing::ValuesIn(GetEinsumTestCases()));
1290 
// Parameter tuple: (x dimensions, y dimensions, expected output dimensions).
using BatchDotParamType = std::tuple<std::vector<int64_t>, std::vector<int64_t>,
                                     std::vector<int64_t>>;
// Value-parameterized fixture for BatchDot broadcasting tests (see
// GetBatchDotTestCases below).
class BatchDotTest : public DotOperationTest,
                     public ::testing::WithParamInterface<BatchDotParamType> {};
// Checks both the statically inferred output shape and the numeric result of
// BatchDot for the parameterized broadcasting operand shapes.
XLA_TEST_P(BatchDotTest, BroadcastingBatchDotTest) {
  XlaBuilder builder(TestName());
  // Use value() rather than the deprecated ValueOrDie() alias, matching the
  // rest of this file.
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .value(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .value(),
      &builder);
  auto batch_dot = BatchDot(x, y);
  // The inferred shape must match the expected output dimensions exactly.
  auto output_shape = builder.GetShape(batch_dot).value();
  EXPECT_EQ(output_shape.dimensions(), std::get<2>(GetParam()));
  ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
}
1310 
GetBatchDotTestCases()1311 std::vector<BatchDotParamType> GetBatchDotTestCases() {
1312   using v = std::vector<int64_t>;
1313   using p = BatchDotParamType;
1314   std::vector<p> test_cases = {
1315       p{v{5, 6}, v{6, 7}, v{5, 7}},
1316       p{v{5, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1317       p{v{5, 6, 11}, v{11, 7}, v{5, 6, 7}},
1318       p{v{5, 6, 11}, v{1, 11, 7}, v{5, 6, 7}},
1319       p{v{6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1320       p{v{1, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
1321       p{v{8, 1, 2, 3}, v{8, 3, 4}, v{8, 8, 2, 4}},
1322       p{v{8, 8, 2, 3}, v{8, 1, 3, 2}, v{8, 8, 2, 2}},
1323   };
1324   return test_cases;
1325 }
1326 
// Instantiates BroadcastingBatchDotTest once per entry in
// GetBatchDotTestCases().
INSTANTIATE_TEST_SUITE_P(BatchDot, BatchDotTest,
                         ::testing::ValuesIn(GetBatchDotTestCases()));

// Fixture for dot tests written as textual HLO modules and run through
// HloTestBase's RunAndCompare.
class DotOperationTextTest : public HloTestBase {};
1331 
// Dot whose batch dimensions ({2,0}) are listed out of order relative to the
// operand layouts; verified against the reference evaluator.
XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) {
  const absl::string_view kHloText =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{1e-3, 1e-3}));
}
1346 
// Dot with out-of-order batch dimensions AND two contracting dimensions per
// operand, also listed in non-monotonic order.
XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) {
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
1361 
// A dot with no dimension numbers: the f32[2,3] x f32[4,5] operands produce a
// f32[2,3,4,5] result (outer product of the operand shapes).
XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) {
  const absl::string_view kHloText =
      R"(
HloModule DotWithNoDnums

ENTRY %test {
  %lhs = f32[2,3]{1,0} parameter(0)
  %rhs = f32[4,5]{1,0} parameter(1)
  ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs)
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{1e-3, 1e-3}));
}
1376 
// Dot where the RHS keeps two non-contracting dimensions ([32,4]) in the
// output — the shape pattern an Einsum lowers to.
XLA_TEST_F(DotOperationTextTest, Einsum) {
  absl::string_view hlo_string =
      R"(
HloModule Einsum

ENTRY %test {
  %lhs = f32[8,64,96]{2,1,0} parameter(0)
  %rhs = f32[96,32,4]{2,1,0} parameter(1)
  ROOT %dot = f32[8,64,32,4]{3,2,1,0}  dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1391 
// Regression test for a caching bug in the XLA CPU backend: two dots sharing
// the same LHS but contracting different RHS dimensions must not be confused
// with each other.
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) {
  const absl::string_view kHloText =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(2)
  rhs_1 = f32[1,40] parameter(1)

  dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  ROOT result = f32[20,1] divide(dot_0, dot_1)
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{4e-3, 4e-3}));
}
1412 
// Tests for a caching bug in the XLA CPU backend.
// Variant of CpuTiledDotEmitterCachingBug_1 with four distinct operands and
// differently-shaped dots whose results are reshaped to a common f32[20].
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) {
  absl::string_view hlo_string =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs_0 = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(1)
  lhs_1 = f32[1,40] parameter(2)
  rhs_1 = f32[20,40] parameter(3)

  dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  dot_0_reshaped = f32[20] reshape(dot_0)
  dot_1_reshaped = f32[20] reshape(dot_1)

  ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1437 
XLA_TEST_F(DotOperationTextTest, S32IotaDot) {
  // s32 batched dot over iota-generated operands; integer math must be exact,
  // hence the zero error spec.
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[5,55,8] iota(), iota_dimension=1
  arg1 = s32[5,8,200] iota(), iota_dimension=2
  ROOT dot = s32[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1452 
XLA_TEST_F(DotOperationTextTest, S32IotaSquaredDot) {
  // s32 dot whose operands are iotas raised to the fourth power via repeated
  // multiplies; compared exactly (zero error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[16,2] iota(), iota_dimension=0
  a = s32[16,2] multiply(arg0, arg0)
  r = s32[16,2] multiply(a, a)
  arg1 = s32[2,98] iota(), iota_dimension=1
  b = s32[2,98] multiply(arg1, arg1)
  s = s32[2,98] multiply(b, b)
  ROOT dot = s32[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1471 
XLA_TEST_F(DotOperationTextTest, U16IotaDot) {
  // u16 batched dot on parameters, with the result widened to s32 by the
  // final convert; compared exactly (zero error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[5,55,8] parameter(0)
  arg1 = u16[5,8,200] parameter(1)
  dot = u16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  ROOT c = s32[5,55,200] convert(dot)
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1487 
XLA_TEST_F(DotOperationTextTest, U16IotaSquaredDot) {
  // u16 dot whose operands are iotas raised to the fourth power via repeated
  // multiplies; compared exactly (zero error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[16,2] iota(), iota_dimension=0
  a = u16[16,2] multiply(arg0, arg0)
  r = u16[16,2] multiply(a, a)
  arg1 = u16[2,98] iota(), iota_dimension=1
  b = u16[2,98] multiply(arg1, arg1)
  s = u16[2,98] multiply(b, b)
  ROOT dot = u16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1506 
XLA_TEST_F(DotOperationTextTest, S16IotaDot) {
  // s16 batched dot over iota-generated operands; compared exactly (zero
  // error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[5,55,8] iota(), iota_dimension=1
  arg1 = s16[5,8,200] iota(), iota_dimension=2
  ROOT dot = s16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1521 
XLA_TEST_F(DotOperationTextTest, S16IotaSquaredDot) {
  // s16 dot whose operands are iotas raised to the fourth power via repeated
  // multiplies; compared exactly (zero error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[16,2] iota(), iota_dimension=0
  a = s16[16,2] multiply(arg0, arg0)
  r = s16[16,2] multiply(a, a)
  arg1 = s16[2,98] iota(), iota_dimension=1
  b = s16[2,98] multiply(arg1, arg1)
  s = s16[2,98] multiply(b, b)
  ROOT dot = s16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1540 
XLA_TEST_F(DotOperationTextTest, PREDDot) {
  // Dot over pred (boolean) operands; compared exactly (zero error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = pred[20,2] parameter(0)
  arg1 = pred[2,20] parameter(1)
  ROOT dot = pred[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1555 
XLA_TEST_F(DotOperationTextTest, S8Dot) {
  // Dot over s8 operands; compared exactly (zero error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s8[20,2] parameter(0)
  arg1 = s8[2,20] parameter(1)
  ROOT dot = s8[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1570 
XLA_TEST_F(DotOperationTextTest, S32Dot) {
  // Dot over s32 operands; compared exactly (zero error spec).
  const absl::string_view kHloText = R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[20,55] parameter(0)
  arg1 = s32[55,20] parameter(1)
  ROOT dot = s32[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{0, 0}));
}
1585 
XLA_TEST_F(DotOperationTextTest, GpuTransposeOutput) {
  // A dot whose result is immediately transposed before becoming the root.
  const absl::string_view kHloText = R"(
HloModule TransposeOutput

ENTRY TransposeOutput {
  p0 = f32[32,32] parameter(0)
  p1 = f32[32,64] parameter(1)
  dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT tr = f32[64,32] transpose(dot), dimensions={1,0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{4e-3, 4e-3}));
}
1601 
XLA_TEST_F(DotOperationTextTest, MatrixVectorComplex) {
  // Complex (c64) matrix-vector dot followed by an add.
  const absl::string_view kHloText = R"(
HloModule MatrixVectorComplex

ENTRY MatrixVectorComplex {
  p0 = c64[5,5] parameter(0)
  p1 = c64[5,1] parameter(1)
  p2 = c64[5,1] parameter(2)
  dot = c64[5,1] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT add = c64[5,1] add(dot, p2)
}
)";

  // Parsed without verification, then run and compared against the reference.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHloText));
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{4e-3, 4e-3}));
}
1620 
XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) {
  // bf16 vector-matrix dot followed by a vector add.
  const absl::string_view kHloText = R"(
HloModule MatrixVectorBF16

ENTRY MatrixVectorBF16 {
  p0 = bf16[128] parameter(0)
  p1 = bf16[128,256] parameter(1)
  p2 = bf16[256] parameter(2)
  dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT add = bf16[256] add(dot, p2)
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{4e-3, 4e-3}));
}
1637 
1638 // Regression test for b/138155357, where we were incorrectly creating a dot-add
1639 // fusion where the dot had a batch dimension.  This isn't supported on the CPU
1640 // backend.
XLA_TEST_F(DotOperationTextTest, FusedBatchDotRegressionTest) {
  // The callee computation contains batched dots (dot.24, dot.26) whose
  // results feed an add (add.27) — the pattern that previously triggered the
  // invalid dot-add fusion described in the comment above.
  absl::string_view module_string = R"(
HloModule jaxpr_computation__5.33

jaxpr_computation__6.8 {
  tuple.9 = () tuple()
  parameter.14 = () parameter(4)
  parameter.13 = (f32[2]{0}) parameter(3)
  get-tuple-element.15 = f32[2]{0} get-tuple-element(parameter.13), index=0
  reshape.16 = f32[1,2]{1,0} reshape(get-tuple-element.15)
  parameter.10 = f32[2,2]{1,0} parameter(0)
  reshape.17 = f32[2,1]{1,0} reshape(get-tuple-element.15)
  dot.18 = f32[2,1]{1,0} dot(parameter.10, reshape.17), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  reshape.19 = f32[2]{0} reshape(dot.18)
  reshape.20 = f32[2,1]{1,0} reshape(reshape.19)
  dot.21 = f32[1,1]{1,0} dot(reshape.16, reshape.20), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  reshape.22 = f32[] reshape(dot.21)
  parameter.11 = f32[2,1,2]{2,1,0} parameter(1)
  broadcast.23 = f32[2,2,1]{2,1,0} broadcast(reshape.20), dimensions={1,2}
  dot.24 = f32[2,1,1]{2,1,0} dot(parameter.11, broadcast.23), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  broadcast.25 = f32[2,1,2]{2,1,0} broadcast(reshape.16), dimensions={1,2}
  parameter.12 = f32[2,2,1]{2,1,0} parameter(2)
  dot.26 = f32[2,1,1]{2,1,0} dot(broadcast.25, parameter.12), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  add.27 = f32[2,1,1]{2,1,0} add(dot.24, dot.26)
  reshape.28 = f32[2]{0} reshape(add.27)
  ROOT tuple.29 = (f32[], f32[2]{0}) tuple(reshape.22, reshape.28)
}

ENTRY jaxpr_computation__5.33 {
  constant.2 = f32[] constant(1)
  broadcast.3 = f32[2,2]{1,0} broadcast(constant.2), dimensions={}
  constant.5 = f32[2,1,2]{2,1,0} constant({ { { 1, 0 } }, { { 0, 1 } } })
  constant.4 = f32[2,2,1]{2,1,0} constant({ { {1}, {1} }, { {1}, {1} } })
  parameter.6 = f32[2]{0} parameter(0)
  tuple.7 = (f32[2]{0}) tuple(parameter.6)
  tuple.1 = () tuple()
  call.30 = (f32[], f32[2]{0}) call(broadcast.3, constant.5, constant.4, tuple.7, tuple.1), to_apply=jaxpr_computation__6.8
  get-tuple-element.31 = f32[] get-tuple-element(call.30), index=0
  ROOT get-tuple-element.32 = f32[2]{0} get-tuple-element(call.30), index=1
})";
  // Verified parse, then numeric comparison with the default error tolerance
  // (std::nullopt).
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_string));
  EXPECT_TRUE(RunAndCompare(std::move(module), /*error=*/std::nullopt));
}
1685 
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) {
  // Constant LHS dotted with an RHS built as transpose+reshape of a
  // parameter, exercising the reorder-contracting-dims simplification.
  Array3D<float> param_values(2, 3, 2);
  param_values.FillIota(0);
  Array2D<float> constant_values(2, 6);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_values), &builder);
  auto rhs = Reshape(Transpose(param, {1, 0, 2}), {6, 2});
  auto lhs = ConstantR2FromArray2D(&builder, constant_values);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1702 
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) {
  // Constant RHS with the LHS built as transpose+reshape of a parameter;
  // contracts lhs dim 0 against rhs dim 1 via DotGeneral.
  Array3D<float> param_values(2, 3, 2);
  param_values.FillIota(0);
  Array2D<float> constant_values(2, 6);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_values), &builder);
  auto lhs = Reshape(Transpose(param, {1, 0, 2}), {6, 2});
  auto rhs = ConstantR2FromArray2D(&builder, constant_values);

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1723 
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) {
  // Constant RHS with the LHS built as a rank-4 transpose collapsed to a
  // matrix, exercising the reorder-contracting-dims simplification.
  Array4D<float> param_values(2, 2, 3, 4);
  param_values.FillIota(0);
  Array2D<float> constant_values(24, 2);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR4FromArray4D<float>(param_values), &builder);
  auto lhs = Reshape(Transpose(param, {0, 2, 3, 1}), {2, 24});
  auto rhs = ConstantR2FromArray2D(&builder, constant_values);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1740 
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) {
  // Batched matrix-matrix dot against a constant RHS; the LHS contracting
  // dimension is built up via reshape+transpose+reshape of a parameter.
  Array3D<float> param_values(2, 6, 2);
  param_values.FillIota(0);
  Array3D<float> constant_values(2, 6, 3);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_values), &builder);
  auto expanded = Reshape(param, {2, 2, 3, 2});
  auto shuffled = Transpose(expanded, {0, 2, 1, 3});
  auto lhs = Reshape(shuffled, {2, 6, 2});
  auto rhs = ConstantR3FromArray3D(&builder, constant_values);

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dnums);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1764 
XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
  // The LHS goes through two transpose/reshape rounds before the dot, so the
  // reorder-contracting-dims rewrite has to be applied more than once.
  Array4D<float> param_values(2, 2, 3, 5);
  param_values.FillIota(0);
  Array2D<float> constant_values(2, 30);
  constant_values.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR4FromArray4D<float>(param_values), &builder);
  auto first_pass = Reshape(Transpose(param, {0, 2, 1, 3}), {2, 6, 5});
  auto lhs = Reshape(Transpose(first_pass, {0, 2, 1}), {2, 30});
  auto rhs = ConstantR2FromArray2D(&builder, constant_values);

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dnums);

  // Constant folding is disabled by default in unit tests; clearing the
  // disabled-pass list lets algsimp run repeatedly as the transpose and
  // reshape are folded into the constant side of the dot.
  mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompare(&builder, {}, error_spec_);
}
1791 
XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) {
  // s8 x s16 dot accumulating into a wider s32 result.
  const absl::string_view kHloText = R"(
HloModule WiderIntegralAccumulation

ENTRY MatrixVectorComplex {
  p0 = s8[5,5]{1,0} parameter(0)
  p1 = s16[5,1]{0,1} parameter(1)
  ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1},
                                        rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{4e-3, 4e-3}));
}
1807 
1808 // This benchmark is to show the performance impact of the following
1809 // transformation:
1810 //   dot(reshape(transpose(A)), Const) ==>
1811 //   dot(reshape(A), reshape(transpose(reshape(Const)))),
1812 // and then fold the reshape and transpose on the Const side.
1813 // We can compare performance with and without algsimp pass to see the impact.
DOT_ReorderContracting(::testing::benchmark::State & state)1814 void DOT_ReorderContracting(::testing::benchmark::State& state) {
1815   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
1816   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
1817   se::StreamExecutorMemoryAllocator allocator(platform, executors);
1818 
1819   xla::LocalClientOptions client_options;
1820   client_options.set_platform(platform);
1821   auto client =
1822       ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
1823 
1824   int device_ordinal = client->default_device_ordinal();
1825 
1826   const int64_t d0 = 128;
1827   const int64_t d1 = 128;
1828   const int64_t d2 = 128;
1829   const int64_t d3 = 128;
1830 
1831   Array3D<float> input_arr(d0, d1, d2);
1832   Array2D<float> const_arr(d1 * d2, d3);
1833   input_arr.FillIota(0);
1834   const_arr.FillIota(0);
1835   XlaBuilder builder("ReorderContracting");
1836   auto t0 =
1837       Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0");
1838   auto t1 = Transpose(t0, {0, 2, 1});
1839   auto lhs = Reshape(t1, {d0, d2 * d1});
1840   auto rhs = ConstantR2FromArray2D(&builder, const_arr);
1841   Dot(lhs, rhs);
1842   auto computation = builder.Build().value();
1843 
1844   auto input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_arr);
1845   ScopedShapedBuffer buffer0 =
1846       client->LiteralToShapedBuffer(input_literal, device_ordinal).value();
1847 
1848   TF_ASSERT_OK_AND_ASSIGN(
1849       auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
1850                                         ExecutableBuildOptions()));
1851   auto executable = std::move(executables[0]);
1852 
1853   se::Stream stream(executors[device_ordinal]);
1854   stream.Init();
1855 
1856   ExecutableRunOptions options;
1857   options.set_allocator(&allocator);
1858 
1859   const int kWarmups = 2;
1860   for (int i = 0; i < kWarmups; ++i) {
1861     ASSERT_IS_OK(executable->Run({&buffer0}, options));
1862   }
1863 
1864   const int64_t total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3;
1865   for (auto s : state) {
1866     ASSERT_IS_OK(executable->Run({&buffer0}, options));
1867   }
1868   state.SetBytesProcessed(state.iterations() * total_bytes * sizeof(float));
1869 }
1870 
// UseRealTime: report wall-clock time rather than per-thread CPU time.
BENCHMARK(DOT_ReorderContracting)->UseRealTime();
1872 
1873 }  // namespace
1874 }  // namespace xla
1875