1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "absl/strings/str_cat.h"
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array3d.h"
22 #include "tensorflow/compiler/xla/client/lib/matrix.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/primitive_util.h"
26 #include "tensorflow/compiler/xla/reference_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_parser.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/tests/test_utils.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/platform/test_benchmark.h"
36
37 namespace xla {
38 namespace {
39
// Base fixture for all dot-operation tests below. Provides the default
// numeric tolerance used when comparing device results against references.
class DotOperationTest : public ClientLibraryTestBase {
 public:
  ErrorSpec error_spec_{0.0001, 1e-5};
};
44
// Element-type lists for the typed tests. Types a backend does not support
// (fp16, fp64/complex64) are compiled out via the XLA_BACKEND_* macros.
using TypesF16F32 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
    Eigen::half,
#endif
    float>;

using TypesF16F32F64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
    Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
    double,
#endif
    float>;

using TypesF16F32F64CF64 = ::testing::Types<
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
    Eigen::half,
#endif
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64)
    double, complex64,
#endif
    float>;
68
69 // Check that we can safely pass an input tuple's elements to a dot operation.
XLA_TEST_F(DotOperationTest,DotOfInputTupleElem)70 XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
71 XlaBuilder builder(TestName());
72
73 XlaOp param;
74 TF_ASSERT_OK_AND_ASSIGN(
75 auto param_data,
76 CreateParameterAndTransferLiteral(
77 0,
78 LiteralUtil::MakeTupleFromSlices(
79 {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
80 LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
81 "arg0", &builder, ¶m));
82 auto lhs = GetTupleElement(param, 0);
83 auto rhs = GetTupleElement(param, 1);
84 Dot(lhs, rhs);
85
86 ComputeAndCompareLiteral(&builder,
87 LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
88 {param_data.get()});
89 }
90
// Typed fixture instantiated for half/float/double/complex64 element types.
template <typename T>
class DotOperationTest_F16F32F64CF64 : public DotOperationTest {};
TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64);
94
// Dot of two zero-element vectors reduces to the scalar 0.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());

  auto lhs = ConstantR1<T>(&builder, {});
  auto rhs = ConstantR1<T>(&builder, {});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(0.0), {},
                                        this->error_spec_);
}
106
// Typed fixture instantiated for half/float/double element types.
template <typename T>
class DotOperationTest_F16F32F64 : public DotOperationTest {};
TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64);
110
// 1x2 matrix times length-2 vector: [3,4]·[3,4] = 25 as a length-1 vector.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, {{3.0f, 4.0f}});
  auto rhs = ConstantFromArray<T>(&builder, {3.0f, 4.0f});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {},
                                        this->error_spec_);
}
121
// Dot of two single-element vectors: 2 * 3 = 6 as a scalar.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, OneElementVectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR1<T>(&builder, {static_cast<T>(2.0f)});
  auto rhs = ConstantR1<T>(&builder, {static_cast<T>(3.0f)});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(6.0f), {},
                                        this->error_spec_);
}
132
// General vector dot: 1*11 + 2.5*(-1) + 42*0.5 = 29.5.
XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantFromArray<T>(&builder, {1.0f, 2.5f, 42.0f});
  auto rhs = ConstantFromArray<T>(&builder, {11.0f, -1.0f, 0.5f});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(29.5f), {},
                                        this->error_spec_);
}
143
// Returns the minor-to-major dimension order for a rank-2 layout:
// {1, 0} describes row-major storage, {0, 1} column-major.
std::vector<int64_t> MinorToMajorForIsRowMajor(bool row_major) {
  if (row_major) {
    return {1, 0};
  }
  return {0, 1};
}
147
// Degenerate shapes: (0x2) x (2x0) yields an empty 0x0 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {},
                                        this->error_spec_);
}
158
// Degenerate shapes: (0x2) x (2x3) yields an empty 0x3 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_0x2_2x3) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  auto rhs = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}});
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {},
                                        this->error_spec_);
}
170
// Degenerate shapes: (3x2) x (2x0) yields an empty 3x0 result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_3x2_2x0) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(
      &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}});
  auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {},
                                        this->error_spec_);
}
182
// Contracting over a zero-sized dimension: (2x0) x (0x2) gives a 2x2 result
// of all zeros (the empty sum).
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, Dot_2x0_0x2) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
  auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
  Dot(lhs, rhs);

  this->template ComputeAndCompareR2<T>(
      &builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_);
}
193
// Dot whose lhs operand is itself a computed value (exp of a parameter),
// exercising fusion of the producer into the dot.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto param0 =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 4}), "arg0");
  auto param1 =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1");
  auto exp0 = Exp(param0);
  Dot(exp0, param1);

  auto lhs_handle =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
              {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
          .value();
  auto rhs_handle = this->client_
                        ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                            {{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
                        .value();

  // Exp amplifies rounding error, so loosen the tolerance for fp16.
  if (std::is_same<Eigen::half, T>::value) {
    this->error_spec_ = ErrorSpec{0.0001, 1e-3};
  }

  this->template ComputeAndCompareR2<T>(
      &builder, Array2D<T>({{296.14560492846033f}, {0.8611737683031964f}}),
      {lhs_handle.get(), rhs_handle.get()}, this->error_spec_);
}
222
// Fixture for 2x2 matrix products where each operand's storage layout
// (row- vs column-major) is chosen independently.
template <typename T>
class SquareMatrixDot : public DotOperationTest {
 public:
  // Runs dot([[1,2],[3,-4]], [[1,6],[7,-4]]) with the given operand layouts.
  void TestImpl(bool lhs_row_major, bool rhs_row_major) {
    auto lhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 2.0f}, {3.0f, -4.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(lhs_row_major))))
            .value();
    auto rhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 6.0f}, {7.0f, -4.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(rhs_row_major))))
            .value();
    XlaBuilder builder(TestName());
    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
    Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
        Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));

    Array2D<T> expected({{15.0f, -2.0f}, {-25.0f, 34.0f}});
    ComputeAndCompareR2<T>(&builder, expected,
                           {lhs_handle.get(), rhs_handle.get()}, error_spec_);
  }
};
251
TYPED_TEST_CASE(SquareMatrixDot, TypesF16F32F64CF64);
// All four combinations of lhs/rhs row-major (T) vs column-major (F) layout.
XLA_TYPED_TEST(SquareMatrixDot, TypesFF) { this->TestImpl(false, false); }
XLA_TYPED_TEST(SquareMatrixDot, TypesFT) { this->TestImpl(false, true); }
XLA_TYPED_TEST(SquareMatrixDot, TypesTF) { this->TestImpl(true, false); }
XLA_TYPED_TEST(SquareMatrixDot, TypesTT) { this->TestImpl(true, true); }
257
// Parameters for the value-parameterized dot tests: an (m x k) x (k x n)
// product, operand layouts, and an optional (m x n) addend.
struct DotTestParam {
  int m;  // rows of lhs / result
  int k;  // contracted dimension
  int n;  // columns of rhs / result
  bool dot_lhs_row_major;
  bool dot_rhs_row_major;
  bool has_addend;        // if true, an addend is added to the dot result
  bool addend_row_major;  // layout of the addend (only used if has_addend)
};
267
PrintDotTestParam(const::testing::TestParamInfo<DotTestParam> & test_param)268 std::string PrintDotTestParam(
269 const ::testing::TestParamInfo<DotTestParam>& test_param) {
270 const DotTestParam& param = test_param.param;
271 if (param.has_addend) {
272 return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
273 param.dot_lhs_row_major ? "T" : "F",
274 param.dot_rhs_row_major ? "T" : "F",
275 param.addend_row_major ? "T" : "F");
276 } else {
277 return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor",
278 param.dot_lhs_row_major ? "T" : "F",
279 param.dot_rhs_row_major ? "T" : "F");
280 }
281 }
282
// Value-parameterized dot test driven by DotTestParam; the per-element-type
// entry points below call TestImpl<NativeT>().
class ParametricDotTest : public DotOperationTest,
                          public ::testing::WithParamInterface<DotTestParam> {
 protected:
  // Builds and runs the parameterized dot (see definition below).
  template <typename NativeT>
  void TestImpl();

  // Comparison helper whose tolerance is specialized per element type.
  template <typename NativeT>
  void ComputeAndCompareR2WithError(XlaBuilder* builder,
                                    const Array2D<NativeT>& expected,
                                    absl::Span<GlobalData* const> arguments);
};
294
// Default comparison: loose tolerance suitable for large float accumulations.
template <typename NativeT>
void ParametricDotTest::ComputeAndCompareR2WithError(
    XlaBuilder* builder, const Array2D<NativeT>& expected,
    absl::Span<GlobalData* const> arguments) {
  ErrorSpec error_spec(0.3, 3e-3);
  ComputeAndCompareR2(builder, expected, arguments, error_spec);
}
302
// fp16 specialization: even looser relative tolerance than the default.
template <>
void ParametricDotTest::ComputeAndCompareR2WithError<Eigen::half>(
    XlaBuilder* builder, const Array2D<Eigen::half>& expected,
    absl::Span<GlobalData* const> arguments) {
  ErrorSpec error_spec(0.3, 7e-3);
  ComputeAndCompareR2(builder, expected, arguments, error_spec);
}
310
// Integer specialization: results must match exactly, so no ErrorSpec.
template <>
void ParametricDotTest::ComputeAndCompareR2WithError<int32_t>(
    XlaBuilder* builder, const Array2D<int32_t>& expected,
    absl::Span<GlobalData* const> arguments) {
  ComputeAndCompareR2(builder, expected, arguments);
}
317
// Core of the parametric dot test: builds dot(lhs, rhs) (optionally plus an
// addend), with operand layouts taken from the test parameter, and compares
// the device result against a host reference matmul.
template <typename NativeT>
void ParametricDotTest::TestImpl() {
  DotTestParam param = GetParam();

  // Linspace inputs give every element a distinct, deterministic value.
  std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
      MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
  Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
      *dot_lhs_data, LayoutUtil::MakeLayout(
                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
  std::unique_ptr<GlobalData> dot_lhs_handle =
      client_->TransferToServer(dot_lhs_lit).value();

  std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
      MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
  Layout rhs_layout = LayoutUtil::MakeLayout(
      MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
  Literal dot_rhs_lit =
      LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
  std::unique_ptr<GlobalData> dot_rhs_handle =
      client_->TransferToServer(dot_rhs_lit).value();

  std::unique_ptr<Array2D<NativeT>> addend_data;
  Literal addend_lit;
  std::unique_ptr<GlobalData> addend_handle;

  if (param.has_addend) {
    addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n);
    addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
        *addend_data, LayoutUtil::MakeLayout(
                          MinorToMajorForIsRowMajor(param.addend_row_major)));
    addend_handle = client_->TransferToServer(addend_lit).value();
  }

  // Parameter shapes carry the same layouts as the transferred literals.
  XlaBuilder builder(TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<NativeT>();
  auto result =
      Dot(Parameter(&builder, 0,
                    ShapeUtil::MakeShapeWithLayout(
                        prim_type, {param.m, param.k},
                        MinorToMajorForIsRowMajor(param.dot_lhs_row_major)),
                    "dot_lhs"),
          Parameter(&builder, 1,
                    ShapeUtil::MakeShapeWithLayout(
                        prim_type, {param.k, param.n},
                        MinorToMajorForIsRowMajor(param.dot_rhs_row_major)),
                    "dot_rhs"));

  if (param.has_addend) {
    result =
        Add(result,
            Parameter(&builder, 2,
                      ShapeUtil::MakeShapeWithLayout(
                          prim_type, {param.m, param.n},
                          MinorToMajorForIsRowMajor(param.addend_row_major)),
                      "addend"));
  }

  // Compute the expected result on the host with the reference matmul.
  std::unique_ptr<Array2D<NativeT>> expected;
  if (param.has_addend) {
    expected = ReferenceUtil::ApplyElementwise2D(
        std::plus<NativeT>(),
        *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data),
        *addend_data);
  } else {
    expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data);
  }

  std::vector<GlobalData*> args = {dot_lhs_handle.get(), dot_rhs_handle.get()};
  if (param.has_addend) {
    args.push_back(addend_handle.get());
  }
  ComputeAndCompareR2WithError<NativeT>(&builder, *expected, args);
}
391
CreateDotTestParameters()392 std::vector<DotTestParam> CreateDotTestParameters() {
393 std::vector<DotTestParam> params;
394
395 auto add_matrix_matrix_dot_test = [&](int m, int k, int n) {
396 for (bool lhs_row_major : {true, false}) {
397 for (bool rhs_row_major : {true, false}) {
398 params.push_back({/*m=*/m, /*k=*/k, /*n=*/n,
399 /*dot_lhs_row_major=*/lhs_row_major,
400 /*dot_rhs_row_major=*/rhs_row_major,
401 /*has_addend=*/false, /*addend_row_major=*/true});
402 }
403 }
404 };
405
406 add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/42);
407 add_matrix_matrix_dot_test(/*m=*/23, /*k=*/1, /*n=*/42);
408 add_matrix_matrix_dot_test(/*m=*/23, /*k=*/42, /*n=*/1);
409 add_matrix_matrix_dot_test(/*m=*/1, /*k=*/23, /*n=*/1);
410 add_matrix_matrix_dot_test(/*m=*/1, /*k=*/1, /*n=*/1);
411 add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7);
412 add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520);
413 add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520);
414
415 return params;
416 }
417
// Per-element-type entry points; unsupported types are compiled out.
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl<Eigen::half>(); }
#endif
XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl<float>(); }
XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl<double>(); }
XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl<std::complex<float>>(); }
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX128
XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl<std::complex<double>>(); }
#endif
XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl<int32_t>(); }

INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest,
                        ::testing::ValuesIn(CreateDotTestParameters()),
                        PrintDotTestParam);
432
// Same parametric dot test, but with layout assignment (and the verifier,
// which would reject the resulting layouts) disabled, so the layouts chosen
// by the test parameters reach the backend unchanged.
class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
 public:
  ParametricDotTestWithoutLayoutAssignment() {
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "layout-assignment");
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "hlo-verifier");
    // Disable algebraic simplification because the pass may replace a dot
    // instruction with a layout-changing multiplication instruction.
    execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
        "algsimp");
  }
};
446
CreateNoLayoutAssignmentDotTestParameters()447 std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() {
448 std::vector<DotTestParam> params;
449
450 auto add_matrix_vector_dot_test = [&](int k, int n) {
451 for (bool lhs_row_major : {true, false}) {
452 for (bool rhs_row_major : {true, false}) {
453 for (bool has_addend : {true, false}) {
454 // The addend needs to be row major to match the result of the dot.
455 params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
456 /*dot_lhs_row_major=*/lhs_row_major,
457 /*dot_rhs_row_major=*/rhs_row_major,
458 /*has_addend=*/has_addend,
459 /*addend_row_major=*/true});
460 if (n != 1) {
461 params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
462 /*dot_lhs_row_major=*/lhs_row_major,
463 /*dot_rhs_row_major=*/rhs_row_major,
464 /*has_addend=*/has_addend,
465 /*addend_row_major=*/true});
466 }
467 }
468 }
469 }
470 };
471
472 add_matrix_vector_dot_test(/*k=*/8, /*n=*/8);
473 add_matrix_vector_dot_test(/*k=*/130, /*n=*/8);
474 add_matrix_vector_dot_test(/*k=*/8, /*n=*/130);
475 add_matrix_vector_dot_test(/*k=*/290, /*n=*/130);
476 add_matrix_vector_dot_test(/*k=*/1, /*n=*/1);
477 add_matrix_vector_dot_test(/*k=*/1, /*n=*/16);
478 add_matrix_vector_dot_test(/*k=*/1, /*n=*/4);
479 add_matrix_vector_dot_test(/*k=*/1, /*n=*/3);
480 add_matrix_vector_dot_test(/*k=*/3, /*n=*/16);
481 add_matrix_vector_dot_test(/*k=*/3, /*n=*/3);
482 add_matrix_vector_dot_test(/*k=*/29, /*n=*/29);
483 add_matrix_vector_dot_test(/*k=*/8, /*n=*/2);
484 add_matrix_vector_dot_test(/*k=*/2, /*n=*/8);
485 add_matrix_vector_dot_test(/*k=*/259, /*n=*/258);
486
487 return params;
488 }
489
// Per-element-type entry points for the no-layout-assignment variant.
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) {
  TestImpl<Eigen::half>();
}
#endif
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) {
  TestImpl<float>();
}
// TODO(b/147505663): Disabled for now.
XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, DISABLED_TestF64) {
  TestImpl<double>();
}

INSTANTIATE_TEST_CASE_P(
    DotTests, ParametricDotTestWithoutLayoutAssignment,
    ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()),
    PrintDotTestParam);
507
// Fixture for (2x3) x (3x2) products where each operand's storage layout
// (row- vs column-major) is chosen independently.
template <typename T>
class NonsquareMatrixDot : public DotOperationTest {
 public:
  // Runs the fixed 2x3 by 3x2 product with the given operand layouts.
  void TestImpl(bool lhs_row_major, bool rhs_row_major) {
    auto lhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(lhs_row_major))))
            .value();
    auto rhs_handle =
        client_
            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
                LayoutUtil::MakeLayout(
                    MinorToMajorForIsRowMajor(rhs_row_major))))
            .value();

    XlaBuilder builder(TestName());
    auto prim_type = primitive_util::NativeToPrimitiveType<T>();
    Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
        Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));

    Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});

    ComputeAndCompareR2<T>(&builder, expected,
                           {lhs_handle.get(), rhs_handle.get()}, error_spec_);
  }
};
538
TYPED_TEST_CASE(NonsquareMatrixDot, TypesF16F32F64CF64);
// All four combinations of lhs/rhs row-major (T) vs column-major (F) layout.
XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) { this->TestImpl(false, false); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) { this->TestImpl(false, true); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) { this->TestImpl(true, false); }
XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
544
// (1x4) x (4x2) product with complex64 elements and explicit row-major
// layouts on both operands.
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
  auto lhs_handle =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
          .value();
  auto rhs_handle =
      client_
          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
              LayoutUtil::MakeLayout({1, 0})))
          .value();

  XlaBuilder builder(TestName());
  auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
  Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
      Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));

  Array2D<complex64> expected({{30.0, -2.0}});

  ComputeAndCompareR2<complex64>(
      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
}
568
// Two independent dots whose results are summed — both matrix products can
// execute concurrently before the Add.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ConcurrentMatMult) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto matrix1 =
      ConstantR2FromArray2D<T>(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto matrix2 =
      ConstantR2FromArray2D<T>(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto matrix12 = Dot(matrix1, matrix2);
  auto matrix21 = Dot(matrix2, matrix1);
  Add(matrix12, matrix21);

  Array2D<T> expected({{42.0f, 56.0f}, {74.0f, 96.0f}});
  this->template ComputeAndCompareR2<T>(&builder, expected, {},
                                        this->error_spec_);
}
585
// Typed fixture for the manually-batched matmul test below.
template <typename T>
class DotOperationTestForBatchMatMul : public DotOperationTest {};
TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64);
589
// Regression test for b/32055648. The root of the graph is a kFusion of 4
// bitcasts. Although bitcasts don't map to thunks, the root should still be
// sync-dependent on bitcasts' operands.
//
// Implements a [2,2,2,2] batched matmul "by hand": flatten the batch dims,
// slice out each 2x2 matrix, Dot them pairwise, then concatenate and reshape
// back to the original batch shape.
XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
  using T = TypeParam;
  XlaBuilder builder(this->TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "y");

  // Collapse the two leading batch dimensions into one of size 4.
  auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
  auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2});

  // Slice batches into individual matrices and multiply them.
  std::vector<XlaOp> out_slices;
  const auto n = 4;
  out_slices.reserve(n);
  for (int i = 0; i < n; ++i) {
    // Slice off individual matrices and reshape to 2D tensors.
    auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2});
    auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
    y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2});

    auto out = Dot(x_slice, y_slice);
    out = Reshape(out, {0, 1}, {1, 2, 2});
    out_slices.push_back(out);
  }
  auto out_flat = ConcatInDim(&builder, out_slices, 0);
  Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});

  auto x_data = this->client_
                    ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
                        {{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
                          {{2000.0f, 200.0f}, {20.0f, 2.0f}}},
                         {{{3000.0f, 300.0f}, {30.0f, 3.0f}},
                          {{4000.0f, 400.0f}, {40.0f, 4.0f}}}}))
                    .value();
  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
               {{{11.0f, 22.0f}, {33.0f, 44.0f}},
                {{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
          .value();

  // Inputs span several orders of magnitude, so loosen the fp16 tolerance.
  if (std::is_same<Eigen::half, T>::value) {
    this->error_spec_ = ErrorSpec{0.0001, 1e-3};
  }
  this->template ComputeAndCompareR4<T>(
      &builder,
      /*expected=*/
      {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}},
        {{11400.0f, 13600.0f}, {114.0f, 136.0f}}},
       {{{42900.0f, 79200.0f}, {429.0f, 792.0f}},
        {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
649
// DotGeneral with one batch dimension: batched [2,2,2] x [2,2,2] matmul.
// The rhs batches are identity matrices, so the result equals x.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  auto y =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  // Batch over dim 0; contract x's dim 2 with y's dim 1.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .value();

  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
          .value();

  this->template ComputeAndCompareR3<T>(
      &builder,
      /*expected=*/
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
685
// DotGeneral with a rank-3 lhs and rank-2 rhs: the rhs's dim 0 serves as the
// batch dimension, paired with the lhs's dim 0.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x =
      Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2}), "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .value();

  auto y_data = this->client_
                    ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                        {{1.0f, 0.0f}, {0.0f, 1.0f}}))
                    .value();

  this->template ComputeAndCompareR2<T>(
      &builder,
      /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()},
      this->error_spec_);
}
718
// Mirror of the previous test: rank-2 lhs, rank-3 rhs, batched over dim 0.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "x");
  auto y =
      Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);

  DotGeneral(x, y, dnums);

  auto x_data = this->client_
                    ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                        {{1.0f, 0.0f}, {0.0f, 1.0f}}))
                    .value();

  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
              {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
          .value();

  this->template ComputeAndCompareR2<T>(
      &builder,
      /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()},
      this->error_spec_);
}
751
// DotGeneral with two batch dimensions: [2,2,2,2] x [2,2,2,2], batched over
// dims 0 and 1, contracting x's dim 3 with y's dim 2.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
  using T = TypeParam;

  XlaBuilder builder(this->TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
                     "y");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_rhs_contracting_dimensions(2);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_lhs_batch_dimensions(1);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  DotGeneral(x, y, dnums);

  auto x_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
               {{{9.0f, 10.0f}, {11.0f, 12.0f}},
                {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
          .value();

  // y batches are identity / column-swap matrices, making the expected
  // result easy to read off from x.
  auto y_data =
      this->client_
          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
              {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
               {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
          .value();

  this->template ComputeAndCompareR4<T>(
      &builder,
      /*expected=*/
      {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
       {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}},
      {x_data.get(), y_data.get()}, this->error_spec_);
}
793
// Exercises folding of explicit Transpose ops into the dot, across all
// combinations of lhs/rhs transposition and input storage layout. The host
// arrays are pre-transposed so the computation always evaluates the same
// (2x3) x (3x2) product and yields the same expected result.
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) {
  using T = TypeParam;
  for (bool transpose_lhs : {false, true}) {
    for (bool transpose_rhs : {false, true}) {
      for (bool row_major : {false, true}) {
        std::unique_ptr<Array2D<T>> lhs(
            new Array2D<T>({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}));
        std::unique_ptr<Array2D<T>> rhs(
            new Array2D<T>({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}));

        // Pre-transpose the host data; the graph transposes it back below.
        if (transpose_lhs) {
          lhs = ReferenceUtil::TransposeArray2D(*lhs);
        }
        if (transpose_rhs) {
          rhs = ReferenceUtil::TransposeArray2D(*rhs);
        }
        auto lhs_handle =
            this->client_
                ->TransferToServer(
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                        *lhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .value();
        auto rhs_handle =
            this->client_
                ->TransferToServer(
                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                        *rhs, LayoutUtil::MakeLayout(
                                  MinorToMajorForIsRowMajor(row_major))))
                .value();

        XlaBuilder builder(this->TestName());
        auto prim_type = primitive_util::NativeToPrimitiveType<T>();
        auto lhs_arg = Parameter(
            &builder, 0,
            ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
            "lhs");
        auto rhs_arg = Parameter(
            &builder, 1,
            ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
            "rhs");
        if (transpose_lhs) {
          lhs_arg = Transpose(lhs_arg, {1, 0});
        }
        if (transpose_rhs) {
          rhs_arg = Transpose(rhs_arg, {1, 0});
        }
        Dot(lhs_arg, rhs_arg);

        Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
        VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
                << transpose_rhs << " " << row_major;
        this->template ComputeAndCompareR2<T>(
            &builder, expected, {lhs_handle.get(), rhs_handle.get()},
            this->error_spec_);
      }
    }
  }
}
853
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstLHS) {
  using T = TypeParam;
  auto prim_type = primitive_util::NativeToPrimitiveType<T>();

  // Constant 2x6 LHS; the RHS is a concatenation of three parameters along
  // dimension 0 (2 + 3 + 1 = 6 rows). Stack values replace the former
  // unique_ptr + raw new — nothing here needed heap allocation.
  Array2D<T> constant_lhs_array({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
                                 {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}});

  XlaBuilder builder(this->TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0");
  auto rhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1");
  auto rhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2");
  Dot(lhs_constant,
      ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));

  Array2D<T> arg_0_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  Array2D<T> arg_1_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
  Array2D<T> arg_2_value_array({{1.0f, 2.0f}});

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_2_value_array)));

  Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
899
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64,
               DotOfConcatOptimizationWithConstRHS) {
  using T = TypeParam;
  // Constant 6x2 RHS; the LHS is a concatenation of three parameters along
  // dimension 1 (2 + 3 + 1 = 6 columns). Stack values replace the former
  // unique_ptr + raw new — nothing here needed heap allocation.
  Array2D<T> constant_rhs_array({{1.0f, 2.0f},
                                 {3.0f, 4.0f},
                                 {5.0f, 6.0f},
                                 {6.0f, 5.0f},
                                 {4.0f, 3.0f},
                                 {2.0f, 1.0f}});

  XlaBuilder builder(this->TestName());
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto lhs_arg_0 = Parameter(
      &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "lhs_arg_0");
  auto lhs_arg_1 = Parameter(
      &builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 3}), "lhs_arg_1");
  auto lhs_arg_2 = Parameter(
      &builder, 2, ShapeUtil::MakeShapeWithType<T>({2, 1}), "lhs_arg_2");
  Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1),
      rhs_constant);

  Array2D<T> arg_0_value_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  Array2D<T> arg_1_value_array({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
  Array2D<T> arg_2_value_array({{1.0f}, {2.0f}});

  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_0_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_0_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_1_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_1_value_array)));
  TF_ASSERT_OK_AND_ASSIGN(
      auto arg_2_value,
      this->client_->TransferToServer(
          LiteralUtil::CreateR2FromArray2D<T>(arg_2_value_array)));

  Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
  this->template ComputeAndCompareR2<T>(
      &builder, expected,
      {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()},
      this->error_spec_);
}
948
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto one = ConstantR0<int32_t>(&builder, 1);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  // Slice out row 1 of the LHS, then classic matmul against the RHS.
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{96.0, 105.0, 114.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
976
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out column 1 of the RHS, then classic matmul with the LHS.
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{105.0}, {105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1004
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSReverseMM) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out column 1 of the LHS; "reverse" matmul contracts lhs dim 0
  // against rhs dim 1.
  auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{105.0, 105.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1034
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out row 1 of the RHS; "reverse" matmul contracts lhs dim 0
  // against rhs dim 1.
  auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{96.0}, {105.0}, {114.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1062
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array({{1.0, 2.0},
                                     {3.0, 4.0},
                                     {5.0, 6.0},
                                     {6.0, 5.0},
                                     {4.0, 3.0},
                                     {2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out column 1 of the LHS; both operands contract over dimension 0
  // (their rows).
  auto dynamic_slice = DynamicSlice(lhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{126.0, 129.0, 132.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1095
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array({{1.0, 2.0},
                                     {3.0, 4.0},
                                     {5.0, 6.0},
                                     {6.0, 5.0},
                                     {4.0, 3.0},
                                     {2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0},
                                     {4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0},
                                     {9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0},
                                     {3.0, 2.0, 1.0}});
  // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out column 1 of the RHS; both operands contract over dimension 0
  // (their rows).
  auto dynamic_slice = DynamicSlice(rhs_constant, {zero, one}, {6, 1});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{129.0}, {129.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1128
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out row 1 of the LHS; both operands contract over dimension 1
  // (their columns).
  auto dynamic_slice = DynamicSlice(lhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(dynamic_slice, rhs_constant, dot_dnums);

  Array2D<float> expected({{56.0, 168.0, 91.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1153
XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
  // Stack values instead of unique_ptr + raw new; nothing escapes this test.
  Array2D<float> constant_lhs_array(
      {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  Array2D<float> constant_rhs_array({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0},
                                     {7.0, 8.0, 9.0, 9.0, 8.0, 7.0},
                                     {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}});
  // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}

  XlaBuilder builder(TestName());
  auto lhs_constant = ConstantR2FromArray2D(&builder, constant_lhs_array);
  auto rhs_constant = ConstantR2FromArray2D(&builder, constant_rhs_array);
  auto zero = ConstantR0<int32_t>(&builder, 0);
  auto one = ConstantR0<int32_t>(&builder, 1);
  // Slice out row 1 of the RHS; both operands contract over dimension 1
  // (their columns).
  auto dynamic_slice = DynamicSlice(rhs_constant, {one, zero}, {1, 6});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs_constant, dynamic_slice, dot_dnums);

  Array2D<float> expected({{168.0}, {168.0}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1178
XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
  XlaBuilder builder(TestName());

  Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);

  Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
  auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);

  // Contract dimension 0 of both operands, i.e. lhs^T * rhs.
  // (The previously-declared local `Shape shape` was unused and has been
  // removed.)
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs_constant, rhs_constant, dot_dnums);

  // expected[i][j] = sum_k lhs[k][i] * rhs[k][j].
  Array2D<float> expected({
      {26.f, 30.f},
      {38.f, 44.f},
  });

  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1201
// Parameter tuple: (lhs shape, rhs shape, einsum config string).
using EinsumParamType =
    std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::string>;
class EinsumTest : public DotOperationTest,
                   public ::testing::WithParamInterface<EinsumParamType> {};
XLA_TEST_P(EinsumTest, SimpleEinsumTest) {
  XlaBuilder builder(TestName());
  // Build fake operand literals for the two parameter shapes.
  // `.value()` replaces the deprecated `.ValueOrDie()`, matching the rest of
  // this file.
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .value(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .value(),
      &builder);
  auto config = std::get<2>(GetParam());
  // A config without a comma describes a unary einsum over x only.
  if (config.find(',') == config.npos) {
    Einsum(x, config);
  } else {
    Einsum(x, y, config);
  }
  ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
}
1224
// Test cases for SimpleEinsumTest: {lhs shape, rhs shape, einsum config}.
// For unary configs (no comma) the rhs shape is ignored by the test.
std::vector<EinsumParamType> GetEinsumTestCases() {
  using v = std::vector<int64_t>;
  using p = EinsumParamType;
  std::vector<p> test_cases = {
      // Configs with an explicit "->" output spec.
      p{v{5, 6}, v{6, 7}, "mk,kn->mn"},
      p{v{5, 6}, v{6, 7}, "mk,kn->nm"},
      p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn->nmB"},
      p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->nmB"},
      p{v{31, 55, 11}, v{55, 11, 29}, "mkB,kBn->Bnm"},
      p{v{8, 55, 11, 3}, v{55, 11, 3, 29}, "mkBC,kBCn->BCnm"},
      p{v{5, 6}, v{6, 7}, "ab,cd->dcba"},
      p{v{6}, v{6, 7}, "b,bc->c"},
      p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc->ab"},
      p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba->ca"},
      p{v{77}, v{77}, "a,a->a"},
      p{v{77}, v{77, 55}, "a,ab->ba"},
      p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"},
      p{v{55}, v{}, "a,->a"},
      p{v{11, 111}, v{11}, "ab,a->ab"},
      p{v{16, 34}, v{16, 34}, "ab,ab->ab"},
      p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac->abc"},
      p{v{5, 19}, v{}, "ab,->ab"},
      p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf->bhqk"},
      p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn->...mn"},
      p{v{5, 6}, v{6, 7}, "...mk,...kn->...mn"},
      p{v{5, 6}, v{6, 7}, "...mk,kn->...mn"},
      p{v{6, 6}, v{7, 7}, "mm,nn->mn"},
      p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
      p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->...mn"},
      p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn->...mn"},
      p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn->n"},
      p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb...->ba...ij"},
      // Configs without "->": the output spec is inferred.
      p{v{5, 6}, v{6, 7}, "mk,kn"},
      p{v{5, 6}, v{6, 7}, "mk,kn"},
      p{v{5, 6, 11}, v{6, 11, 7}, "mkB,kBn"},
      p{v{5, 6}, v{6, 7}, "ab,cd"},
      p{v{6}, v{6, 7}, "b,bc"},
      p{v{5, 6, 7}, v{5, 6, 7}, "abc,abc"},
      p{v{5, 6, 7}, v{7, 6, 5}, "abc,cba"},
      p{v{77}, v{77}, "a,a"},
      p{v{77}, v{77, 55}, "a,ab"},
      p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb"},
      p{v{55}, v{}, "a"},
      p{v{11, 111}, v{11}, "ab,a"},
      p{v{16, 34}, v{16, 34}, "ab,ab"},
      p{v{16, 3, 34}, v{3, 16, 34}, "abc,bac"},
      p{v{5, 19}, v{}, "ab"},
      p{v{8, 1, 16, 64}, v{8, 12, 16, 64}, "bqhf,bkhf"},
      p{v{2, 3, 5, 6}, v{2, 3, 6, 7}, "...mk,...kn"},
      p{v{5, 6}, v{}, "...mk"},
      p{v{5, 6, 12, 13}, v{}, "...mk"},
      p{v{5, 6, 12, 13}, v{}, "m...k"},
      p{v{5, 6, 12, 13}, v{}, "mk..."},
      p{v{5, 6}, v{6, 7}, "...mk->km..."},
      p{v{1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
      p{v{3, 1, 2, 5, 6}, v{2, 1, 6, 7}, "...mk,...kn"},
      p{v{1, 2, 5, 6}, v{3, 2, 1, 6, 7}, "...mk,...kn"},
      p{v{16, 16, 16}, v{}, "iii"},
      p{v{1, 2, 2, 3, 77}, v{77, 2, 3, 55, 1, 2}, "...ija,aijb..."},
  };
  return test_cases;
}
1287
// Instantiate SimpleEinsumTest once per entry in GetEinsumTestCases().
INSTANTIATE_TEST_SUITE_P(Einsum, EinsumTest,
                         ::testing::ValuesIn(GetEinsumTestCases()));
1290
// Parameter tuple: (x shape, y shape, expected BatchDot output shape).
using BatchDotParamType = std::tuple<std::vector<int64_t>, std::vector<int64_t>,
                                     std::vector<int64_t>>;
class BatchDotTest : public DotOperationTest,
                     public ::testing::WithParamInterface<BatchDotParamType> {};
XLA_TEST_P(BatchDotTest, BroadcastingBatchDotTest) {
  XlaBuilder builder(TestName());
  // Build fake operand literals for the two parameter shapes.
  // `.value()` replaces the deprecated `.ValueOrDie()`, matching the rest of
  // this file.
  auto x = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<0>(GetParam())))
          .value(),
      &builder);
  auto y = AddParam(
      MakeFakeLiteral(ShapeUtil::MakeShape(F32, std::get<1>(GetParam())))
          .value(),
      &builder);
  auto batch_dot = BatchDot(x, y);
  // First check the inferred output shape, then compare numerically.
  auto output_shape = builder.GetShape(batch_dot).value();
  EXPECT_EQ(output_shape.dimensions(), std::get<2>(GetParam()));
  ComputeAndCompare(&builder, {}, ErrorSpec{1e-3, 1e-3});
}
1310
// Test cases for BroadcastingBatchDotTest: {x shape, y shape, expected output
// shape}. Covers missing, size-1, and full batch dimensions on either side.
std::vector<BatchDotParamType> GetBatchDotTestCases() {
  using v = std::vector<int64_t>;
  using p = BatchDotParamType;
  std::vector<p> test_cases = {
      p{v{5, 6}, v{6, 7}, v{5, 7}},
      p{v{5, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
      p{v{5, 6, 11}, v{11, 7}, v{5, 6, 7}},
      p{v{5, 6, 11}, v{1, 11, 7}, v{5, 6, 7}},
      p{v{6, 11}, v{5, 11, 7}, v{5, 6, 7}},
      p{v{1, 6, 11}, v{5, 11, 7}, v{5, 6, 7}},
      p{v{8, 1, 2, 3}, v{8, 3, 4}, v{8, 8, 2, 4}},
      p{v{8, 8, 2, 3}, v{8, 1, 3, 2}, v{8, 8, 2, 2}},
  };
  return test_cases;
}
1326
// Instantiate BroadcastingBatchDotTest once per entry in
// GetBatchDotTestCases().
INSTANTIATE_TEST_SUITE_P(BatchDot, BatchDotTest,
                         ::testing::ValuesIn(GetBatchDotTestCases()));
1329
1330 class DotOperationTextTest : public HloTestBase {};
1331
XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) {
  // Batch dims are given out of order ({2,0}) relative to operand layout;
  // verifies dot handles reordered batch dimensions.
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
1346
XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) {
  // Like DotReorderedDotDims, but additionally with two contracting
  // dimensions per operand, also given out of order.
  absl::string_view hlo_string =
      R"(
HloModule ComplexDotMultipleNonContracting

ENTRY %test {
  %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0)
  %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1)
  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
1361
XLA_TEST_F(DotOperationTextTest, DotWithNoDnums) {
  // A dot with no contracting or batch dimensions: the result is the rank-4
  // outer product of the two rank-2 operands.
  absl::string_view hlo_string =
      R"(
HloModule DotWithNoDnums

ENTRY %test {
  %lhs = f32[2,3]{1,0} parameter(0)
  %rhs = f32[4,5]{1,0} parameter(1)
  ROOT %dot = f32[2,3,4,5]{3,2,1,0} dot(%lhs, %rhs)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
}
1376
XLA_TEST_F(DotOperationTextTest, Einsum) {
  // Rank-3 x rank-3 dot with one contracting dimension and no batch
  // dimensions, producing a rank-4 result (einsum-style contraction).
  absl::string_view hlo_string =
      R"(
HloModule Einsum

ENTRY %test {
  %lhs = f32[8,64,96]{2,1,0} parameter(0)
  %rhs = f32[96,32,4]{2,1,0} parameter(1)
  ROOT %dot = f32[8,64,32,4]{3,2,1,0} dot(%lhs, %rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1391
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_1) {
  // Tests for a caching bug in the XLA CPU backend.
  // Two dots share the same LHS but contract the RHS along different
  // dimensions; the parameter numbering (rhs_0 = parameter 2) is deliberate.
  absl::string_view hlo_string =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(2)
  rhs_1 = f32[1,40] parameter(1)

  dot_0 = f32[20,1] dot(lhs, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[20,1] dot(lhs, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  ROOT result = f32[20,1] divide(dot_0, dot_1)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1412
XLA_TEST_F(DotOperationTextTest, CpuTiledDotEmitterCachingBug_2) {
  // Tests for a caching bug in the XLA CPU backend.
  // Variant with distinct LHS operands and differently-shaped dot results
  // that are reshaped to matching shapes before the divide.
  absl::string_view hlo_string =
      R"(
HloModule CpuTiledDotEmitterCachingBug

ENTRY main {
  lhs_0 = f32[20,40] parameter(0)
  rhs_0 = f32[40,1] parameter(1)
  lhs_1 = f32[1,40] parameter(2)
  rhs_1 = f32[20,40] parameter(3)

  dot_0 = f32[20,1] dot(lhs_0, rhs_0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  dot_1 = f32[1,20] dot(lhs_1, rhs_1), lhs_contracting_dims={1}, rhs_contracting_dims={1}

  dot_0_reshaped = f32[20] reshape(dot_0)
  dot_1_reshaped = f32[20] reshape(dot_1)

  ROOT result = f32[20] divide(dot_0_reshaped, dot_1_reshaped)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1437
XLA_TEST_F(DotOperationTextTest, S32IotaDot) {
  // s32 batched dot over iota operands; ErrorSpec{0, 0} requires bit-exact
  // integer results.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[5,55,8] iota(), iota_dimension=1
  arg1 = s32[5,8,200] iota(), iota_dimension=2
  ROOT dot = s32[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1452
XLA_TEST_F(DotOperationTextTest, S32IotaSquaredDot) {
  // s32 dot over fourth powers of iota values (multiply applied twice);
  // exercises larger integer magnitudes with an exact comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[16,2] iota(), iota_dimension=0
  a = s32[16,2] multiply(arg0, arg0)
  r = s32[16,2] multiply(a, a)
  arg1 = s32[2,98] iota(), iota_dimension=1
  b = s32[2,98] multiply(arg1, arg1)
  s = s32[2,98] multiply(b, b)
  ROOT dot = s32[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1471
XLA_TEST_F(DotOperationTextTest, U16IotaDot) {
  // u16 batched dot over parameters, converted to s32 at the root; exact
  // comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[5,55,8] parameter(0)
  arg1 = u16[5,8,200] parameter(1)
  dot = u16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  ROOT c = s32[5,55,200] convert(dot)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1487
XLA_TEST_F(DotOperationTextTest, U16IotaSquaredDot) {
  // u16 dot over fourth powers of iota values; exercises unsigned 16-bit
  // arithmetic (with wraparound semantics) under an exact comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = u16[16,2] iota(), iota_dimension=0
  a = u16[16,2] multiply(arg0, arg0)
  r = u16[16,2] multiply(a, a)
  arg1 = u16[2,98] iota(), iota_dimension=1
  b = u16[2,98] multiply(arg1, arg1)
  s = u16[2,98] multiply(b, b)
  ROOT dot = u16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1506
XLA_TEST_F(DotOperationTextTest, S16IotaDot) {
  // s16 batched dot over iota operands; exact comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[5,55,8] iota(), iota_dimension=1
  arg1 = s16[5,8,200] iota(), iota_dimension=2
  ROOT dot = s16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1521
XLA_TEST_F(DotOperationTextTest, S16IotaSquaredDot) {
  // s16 dot over fourth powers of iota values; exact comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s16[16,2] iota(), iota_dimension=0
  a = s16[16,2] multiply(arg0, arg0)
  r = s16[16,2] multiply(a, a)
  arg1 = s16[2,98] iota(), iota_dimension=1
  b = s16[2,98] multiply(arg1, arg1)
  s = s16[2,98] multiply(b, b)
  ROOT dot = s16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1540
XLA_TEST_F(DotOperationTextTest, PREDDot) {
  // Dot over predicate (boolean) operands; exact comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = pred[20,2] parameter(0)
  arg1 = pred[2,20] parameter(1)
  ROOT dot = pred[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1555
XLA_TEST_F(DotOperationTextTest, S8Dot) {
  // Dot over signed 8-bit operands; exact comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s8[20,2] parameter(0)
  arg1 = s8[2,20] parameter(1)
  ROOT dot = s8[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1570
XLA_TEST_F(DotOperationTextTest, S32Dot) {
  // Dot over signed 32-bit operands; exact comparison.
  absl::string_view hlo_string =
      R"(
HloModule SmallIntegerDot

ENTRY SmallIntegerDot {
  arg0 = s32[20,55] parameter(0)
  arg1 = s32[55,20] parameter(1)
  ROOT dot = s32[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0}));
}
1585
XLA_TEST_F(DotOperationTextTest, GpuTransposeOutput) {
  // A transpose consuming the dot output; the dot itself also contracts
  // the LHS along dimension 0.
  absl::string_view hlo_string =
      R"(
HloModule TransposeOutput

ENTRY TransposeOutput {
  p0 = f32[32,32] parameter(0)
  p1 = f32[32,64] parameter(1)
  dot = f32[32,64] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT tr = f32[64,32] transpose(dot), dimensions={1,0}
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1601
XLA_TEST_F(DotOperationTextTest, MatrixVectorComplex) {
  // Complex (c64) matrix-vector dot followed by an add. The module is parsed
  // explicitly (without verification) and handed to RunAndCompare.
  absl::string_view hlo_string =
      R"(
HloModule MatrixVectorComplex

ENTRY MatrixVectorComplex {
  p0 = c64[5,5] parameter(0)
  p1 = c64[5,1] parameter(1)
  p2 = c64[5,1] parameter(2)
  dot = c64[5,1] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT add = c64[5,1] add(dot, p2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  EXPECT_TRUE(RunAndCompare(std::move(hlo_module), ErrorSpec{4e-3, 4e-3}));
}
1620
XLA_TEST_F(DotOperationTextTest, MatrixVectorBF16) {
  // bfloat16 vector-matrix dot followed by an add.
  absl::string_view hlo_string =
      R"(
HloModule MatrixVectorBF16

ENTRY MatrixVectorBF16 {
  p0 = bf16[128] parameter(0)
  p1 = bf16[128,256] parameter(1)
  p2 = bf16[256] parameter(2)
  dot = bf16[256] dot(p0, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT add = bf16[256] add(dot, p2)
}
)";

  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3}));
}
1637
1638 // Regression test for b/138155357, where we were incorrectly creating a dot-add
1639 // fusion where the dot had a batch dimension. This isn't supported on the CPU
1640 // backend.
XLA_TEST_F(DotOperationTextTest, FusedBatchDotRegressionTest) {
  // HLO extracted from a JAX computation. dot.24 and dot.26 both carry
  // batch dimensions and feed add.27 — the shape that previously produced
  // the unsupported dot-add fusion on CPU (b/138155357).
  absl::string_view module_string = R"(
HloModule jaxpr_computation__5.33

jaxpr_computation__6.8 {
  tuple.9 = () tuple()
  parameter.14 = () parameter(4)
  parameter.13 = (f32[2]{0}) parameter(3)
  get-tuple-element.15 = f32[2]{0} get-tuple-element(parameter.13), index=0
  reshape.16 = f32[1,2]{1,0} reshape(get-tuple-element.15)
  parameter.10 = f32[2,2]{1,0} parameter(0)
  reshape.17 = f32[2,1]{1,0} reshape(get-tuple-element.15)
  dot.18 = f32[2,1]{1,0} dot(parameter.10, reshape.17), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  reshape.19 = f32[2]{0} reshape(dot.18)
  reshape.20 = f32[2,1]{1,0} reshape(reshape.19)
  dot.21 = f32[1,1]{1,0} dot(reshape.16, reshape.20), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  reshape.22 = f32[] reshape(dot.21)
  parameter.11 = f32[2,1,2]{2,1,0} parameter(1)
  broadcast.23 = f32[2,2,1]{2,1,0} broadcast(reshape.20), dimensions={1,2}
  dot.24 = f32[2,1,1]{2,1,0} dot(parameter.11, broadcast.23), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  broadcast.25 = f32[2,1,2]{2,1,0} broadcast(reshape.16), dimensions={1,2}
  parameter.12 = f32[2,2,1]{2,1,0} parameter(2)
  dot.26 = f32[2,1,1]{2,1,0} dot(broadcast.25, parameter.12), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
  add.27 = f32[2,1,1]{2,1,0} add(dot.24, dot.26)
  reshape.28 = f32[2]{0} reshape(add.27)
  ROOT tuple.29 = (f32[], f32[2]{0}) tuple(reshape.22, reshape.28)
}

ENTRY jaxpr_computation__5.33 {
  constant.2 = f32[] constant(1)
  broadcast.3 = f32[2,2]{1,0} broadcast(constant.2), dimensions={}
  constant.5 = f32[2,1,2]{2,1,0} constant({ { { 1, 0 } }, { { 0, 1 } } })
  constant.4 = f32[2,2,1]{2,1,0} constant({ { {1}, {1} }, { {1}, {1} } })
  parameter.6 = f32[2]{0} parameter(0)
  tuple.7 = (f32[2]{0}) tuple(parameter.6)
  tuple.1 = () tuple()
  call.30 = (f32[], f32[2]{0}) call(broadcast.3, constant.5, constant.4, tuple.7, tuple.1), to_apply=jaxpr_computation__6.8
  get-tuple-element.31 = f32[] get-tuple-element(call.30), index=0
  ROOT get-tuple-element.32 = f32[2]{0} get-tuple-element(call.30), index=1
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_string));
  // No explicit error bound is passed for the comparison here.
  EXPECT_TRUE(RunAndCompare(std::move(module), /*error=*/std::nullopt));
}
1685
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) {
  // dot(Const, reshape(transpose(param))): the contracting dimension on the
  // non-constant side is produced by a transpose+reshape chain.
  Array3D<float> param_arr(2, 3, 2);
  Array2D<float> constant_arr(2, 6);
  param_arr.FillIota(0);
  constant_arr.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_arr), &builder);
  auto transposed = Transpose(param, {1, 0, 2});
  auto rhs = Reshape(transposed, {6, 2});
  auto lhs = ConstantR2FromArray2D(&builder, constant_arr);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1702
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_LR) {
  // Same reshaped-transpose lhs as the ConstLHS variant, but with the
  // constant on the rhs and contracting lhs dim 0 against rhs dim 1.
  Array3D<float> param_arr(2, 3, 2);
  Array2D<float> constant_arr(2, 6);
  param_arr.FillIota(0);
  constant_arr.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_arr), &builder);
  auto transposed = Transpose(param, {1, 0, 2});
  auto lhs = Reshape(transposed, {6, 2});
  auto rhs = ConstantR2FromArray2D(&builder, constant_arr);

  DotDimensionNumbers dims;
  dims.add_lhs_contracting_dimensions(0);
  dims.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dims);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1723
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_RL) {
  // Rank-4 parameter whose trailing three dims are transposed and collapsed
  // into the contracting dimension before a plain Dot with a constant rhs.
  Array4D<float> param_arr(2, 2, 3, 4);
  Array2D<float> constant_arr(24, 2);
  param_arr.FillIota(0);
  constant_arr.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR4FromArray4D<float>(param_arr), &builder);
  auto transposed = Transpose(param, {0, 2, 3, 1});
  auto lhs = Reshape(transposed, {2, 24});
  auto rhs = ConstantR2FromArray2D(&builder, constant_arr);
  Dot(lhs, rhs);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1740
XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstRHS_MM) {
  // Batched matmul whose lhs contracting dimension comes from a
  // reshape -> transpose -> reshape chain.
  Array3D<float> param_arr(2, 6, 2);
  Array3D<float> constant_arr(2, 6, 3);
  param_arr.FillIota(0);
  constant_arr.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR3FromArray3D<float>(param_arr), &builder);
  auto expanded = Reshape(param, {2, 2, 3, 2});
  auto transposed = Transpose(expanded, {0, 2, 1, 3});
  auto lhs = Reshape(transposed, {2, 6, 2});
  auto rhs = ConstantR3FromArray3D(&builder, constant_arr);

  // Dim 0 is the batch dimension on both operands; contract dim 1 of each.
  DotDimensionNumbers dims;
  dims.add_lhs_contracting_dimensions(1);
  dims.add_rhs_contracting_dimensions(1);
  dims.add_lhs_batch_dimensions(0);
  dims.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dims);

  ComputeAndCompare(&builder, {}, error_spec_);
}
1764
XLA_TEST_F(DotOperationTest, ReorderContractingDims_Multipass) {
  // Two transpose/reshape rounds feed the lhs contracting dimension, so
  // folding them onto the constant side takes more than one simplification
  // pass.
  Array4D<float> param_arr(2, 2, 3, 5);
  Array2D<float> constant_arr(2, 30);
  param_arr.FillIota(0);
  constant_arr.FillIota(0);

  XlaBuilder builder(TestName());
  auto param =
      AddParam(LiteralUtil::CreateR4FromArray4D<float>(param_arr), &builder);
  auto first_transpose = Transpose(param, {0, 2, 1, 3});
  auto collapsed = Reshape(first_transpose, {2, 6, 5});
  auto second_transpose = Transpose(collapsed, {0, 2, 1});
  auto lhs = Reshape(second_transpose, {2, 30});
  auto rhs = ConstantR2FromArray2D(&builder, constant_arr);

  DotDimensionNumbers dims;
  dims.add_lhs_contracting_dimensions(1);
  dims.add_rhs_contracting_dimensions(1);
  DotGeneral(lhs, rhs, dims);

  // Constant folding is disabled by default in unit tests; clear the
  // disabled-pass list so the algsimp optimization can be applied multiple
  // times as the transpose and reshape are moved to the constant side of
  // the dot.
  mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompare(&builder, {}, error_spec_);
}
1791
XLA_TEST_F(DotOperationTextTest, WiderIntegralResultAccumulation) {
  // s8 x s16 operands accumulating directly into a wider s32 result.
  constexpr absl::string_view kHloText = R"(
HloModule WiderIntegralAccumulation

ENTRY MatrixVectorComplex {
  p0 = s8[5,5]{1,0} parameter(0)
  p1 = s16[5,1]{0,1} parameter(1)
  ROOT dot = s32[5,1]{1,0} dot(p0, p1), lhs_contracting_dims={1},
                                        rhs_contracting_dims={0}
}
)";
  EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{4e-3, 4e-3}));
}
1807
1808 // This benchmark is to show the performance impact of the following
1809 // transformation:
1810 // dot(reshape(transpose(A)), Const) ==>
1811 // dot(reshape(A), reshape(transpose(reshape(Const)))),
1812 // and then fold the reshape and transpose on the Const side.
1813 // We can compare performance with and without algsimp pass to see the impact.
DOT_ReorderContracting(::testing::benchmark::State & state)1814 void DOT_ReorderContracting(::testing::benchmark::State& state) {
1815 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
1816 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
1817 se::StreamExecutorMemoryAllocator allocator(platform, executors);
1818
1819 xla::LocalClientOptions client_options;
1820 client_options.set_platform(platform);
1821 auto client =
1822 ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
1823
1824 int device_ordinal = client->default_device_ordinal();
1825
1826 const int64_t d0 = 128;
1827 const int64_t d1 = 128;
1828 const int64_t d2 = 128;
1829 const int64_t d3 = 128;
1830
1831 Array3D<float> input_arr(d0, d1, d2);
1832 Array2D<float> const_arr(d1 * d2, d3);
1833 input_arr.FillIota(0);
1834 const_arr.FillIota(0);
1835 XlaBuilder builder("ReorderContracting");
1836 auto t0 =
1837 Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {d0, d1, d2}), "param0");
1838 auto t1 = Transpose(t0, {0, 2, 1});
1839 auto lhs = Reshape(t1, {d0, d2 * d1});
1840 auto rhs = ConstantR2FromArray2D(&builder, const_arr);
1841 Dot(lhs, rhs);
1842 auto computation = builder.Build().value();
1843
1844 auto input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_arr);
1845 ScopedShapedBuffer buffer0 =
1846 client->LiteralToShapedBuffer(input_literal, device_ordinal).value();
1847
1848 TF_ASSERT_OK_AND_ASSIGN(
1849 auto executables, client->Compile(computation, {&buffer0.on_host_shape()},
1850 ExecutableBuildOptions()));
1851 auto executable = std::move(executables[0]);
1852
1853 se::Stream stream(executors[device_ordinal]);
1854 stream.Init();
1855
1856 ExecutableRunOptions options;
1857 options.set_allocator(&allocator);
1858
1859 const int kWarmups = 2;
1860 for (int i = 0; i < kWarmups; ++i) {
1861 ASSERT_IS_OK(executable->Run({&buffer0}, options));
1862 }
1863
1864 const int64_t total_bytes = d0 * d1 * d2 + d1 * d2 * d3 + d0 * d3;
1865 for (auto s : state) {
1866 ASSERT_IS_OK(executable->Run({&buffer0}, options));
1867 }
1868 state.SetBytesProcessed(state.iterations() * total_bytes * sizeof(float));
1869 }
1870
// Register the benchmark; report wall-clock (real) time rather than CPU time.
BENCHMARK(DOT_ReorderContracting)->UseRealTime();
1872
1873 } // namespace
1874 } // namespace xla
1875