1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/matrix.h"
17
#include <limits>
#include <map>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
33
34 namespace xla {
35 namespace {
36
37 class MatrixTest : public ClientLibraryTestBase {
38 protected:
39 template <typename T>
40 void TestMatrixDiagonal();
41 template <typename T>
42 void TestMatrixDiagonal4D();
43 template <typename T>
44 void TestSetMatrixDiagonal();
45
46 template <typename T>
k_and_expected() const47 std::map<int, Array2D<T>> k_and_expected() const {
48 return std::map<int, Array2D<T>>{
49 {0, {{0, 5, 10}, {12, 17, 22}}},
50 {1, {{1, 6, 11}, {13, 18, 23}}},
51 {2, {{2, 7}, {14, 19}}},
52 {3, {{3}, {15}}},
53 {4, {{}, {}}},
54 {-1, {{4, 9}, {16, 21}}},
55 {-2, {{8}, {20}}},
56 {-3, {{}, {}}},
57 {-4, {{}, {}}},
58 };
59 }
60 };
61
// LowerTriangle keeps the main diagonal and everything below it in each matrix
// of the batch, and zeroes the rest.
XLA_TEST_F(MatrixTest, Triangle) {
  XlaBuilder builder(TestName());

  Array3D<int32_t> values(2, 3, 4);
  values.FillIota(0);
  XlaOp operand;
  auto operand_data =
      CreateR3Parameter<int32_t>(values, 0, "a", &builder, &operand);
  LowerTriangle(operand);

  const Array3D<int32_t> expected(
      {{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
       {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}});
  ComputeAndCompareR3<int32_t>(&builder, expected, {operand_data.get()});
}
75
// Symmetrize(x, lower) mirrors one triangle of x across the main diagonal.
// The triangle that must be ignored is filled with NaN, so any read of it
// would surface as NaN in the result. For lower=false the input is transposed
// first, so the meaningful values sit in the upper triangle instead.
XLA_TEST_F(MatrixTest, Symmetrize) {
  for (bool lower : {false, true}) {
    // Both variants share TestName(), so tag failures with the variant.
    SCOPED_TRACE(lower ? "lower=true" : "lower=false");
    XlaBuilder builder(TestName());
    float nan = std::numeric_limits<float>::quiet_NaN();
    Array<float> input = {
        {1, nan, nan},
        {2, 3, nan},
        {4, 5, 6},
    };

    XlaOp a;
    auto a_data = CreateParameter<float>(input, 0, "a", &builder, &a);
    Symmetrize(lower ? a : TransposeInMinorDims(a), /*lower=*/lower);

    Array<float> expected = {
        {1, 2, 4},
        {2, 3, 5},
        {4, 5, 6},
    };

    ComputeAndCompare<float>(&builder, expected, {a_data.get()});
  }
}

// Complex variant. The expected output is Hermitian: the mirrored off-diagonal
// entries are conjugated, and the diagonal keeps only its real part (the input
// diagonal has NaN imaginary parts while the expected diagonal is real). For
// lower=false the input is conjugate-transposed so the upper triangle holds
// the conjugated values.
XLA_TEST_F(MatrixTest, SymmetrizeComplex) {
  for (bool lower : {false, true}) {
    SCOPED_TRACE(lower ? "lower=true" : "lower=false");
    XlaBuilder builder(TestName());
    float nan = std::numeric_limits<float>::quiet_NaN();
    Array<complex64> input = {
        {complex64{1, nan}, nan, nan},
        {complex64{2, 7}, complex64{3, nan}, nan},
        {complex64{4, 8}, complex64{5, 9}, complex64{6, nan}},
    };

    XlaOp a;
    auto a_data = CreateParameter<complex64>(input, 0, "a", &builder, &a);
    Symmetrize(lower ? a : Conj(TransposeInMinorDims(a)), /*lower=*/lower);

    Array<complex64> expected = {
        {1, complex64{2, -7}, complex64{4, -8}},
        {complex64{2, 7}, 3, complex64{5, -9}},
        {complex64{4, 8}, complex64{5, 9}, 6},
    };

    ComputeAndCompare<complex64>(&builder, expected, {a_data.get()});
  }
}
171
172 template <typename T>
TestMatrixDiagonal()173 void MatrixTest::TestMatrixDiagonal() {
174 XlaBuilder builder("SetMatrixDiagonal");
175 Array3D<T> input(2, 3, 4);
176 input.FillIota(0);
177 for (const auto& kv : k_and_expected<T>()) {
178 XlaOp a;
179 auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
180 GetMatrixDiagonal(a, kv.first);
181
182 ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
183 }
184 }
185
186 template <typename T>
TestSetMatrixDiagonal()187 void MatrixTest::TestSetMatrixDiagonal() {
188 XlaBuilder builder("GetMatrixDiagonal");
189 Array3D<T> input(2, 3, 4);
190 input.FillIota(0);
191 for (const auto& kv : k_and_expected<T>()) {
192 XlaOp a;
193 XlaOp b;
194 auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
195 auto new_diag =
196 CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
197
198 GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
199 kv.first) -
200 ScalarLike(b, 1);
201
202 ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
203 }
204 }
205
// SetMatrixDiagonal round-trip, instantiated per element type.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
  TestSetMatrixDiagonal<int32_t>();
}
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
  TestSetMatrixDiagonal<int64_t>();
}
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
  TestSetMatrixDiagonal<float>();
}

// GetMatrixDiagonal on rank-3 input, instantiated per element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  TestMatrixDiagonal<int32_t>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  TestMatrixDiagonal<int64_t>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  TestMatrixDiagonal<float>();
}
221
222 template <typename T>
TestMatrixDiagonal4D()223 void MatrixTest::TestMatrixDiagonal4D() {
224 XlaBuilder builder("GetMatrixDiagonal");
225 Array4D<T> input(2, 2, 4, 3);
226 input.FillIota(0);
227 std::map<int, Array3D<T>> k_and_expected = {
228 {0, {{{0, 4, 8}, {12, 16, 20}}, {{24, 28, 32}, {36, 40, 44}}}},
229 {1, {{{1, 5}, {13, 17}}, {{25, 29}, {37, 41}}}},
230 {2, {{{2}, {14}}, {{26}, {38}}}},
231 {3, {{{}, {}}, {{}, {}}}},
232 {4, {{{}, {}}, {{}, {}}}},
233 {-1, {{{3, 7, 11}, {15, 19, 23}}, {{27, 31, 35}, {39, 43, 47}}}},
234 {-2, {{{6, 10}, {18, 22}}, {{30, 34}, {42, 46}}}},
235 {-3, {{{9}, {21}}, {{33}, {45}}}},
236 {-4, {{{}, {}}, {{}, {}}}},
237 };
238 for (const auto& kv : k_and_expected) {
239 XlaOp a;
240 auto a_data = CreateR4Parameter<T>(input, 0, "a", &builder, &a);
241 GetMatrixDiagonal(a, kv.first);
242
243 ComputeAndCompareR3<T>(&builder, kv.second, {a_data.get()});
244 }
245 }
246
// GetMatrixDiagonal on rank-4 input, instantiated per element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) {
  TestMatrixDiagonal4D<int32_t>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) {
  TestMatrixDiagonal4D<int64_t>();
}
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) {
  TestMatrixDiagonal4D<float>();
}
258
BatchedAValsFull()259 Array3D<float> BatchedAValsFull() {
260 return {{
261 {2, 0, 1, 2},
262 {3, 6, 0, 1},
263 {4, 7, 9, 0},
264 {5, 8, 10, 11},
265 },
266 {
267 {16, 24, 8, 12},
268 {24, 61, 82, 48},
269 {8, 82, 456, 106},
270 {12, 48, 106, 62},
271 }};
272 }
273
// Dots one dynamically selected row of each batch matrix with the matching
// per-batch `row` vector using BatchDot.
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());
  const int n = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // index = 1 picks {3, 6, 0, 1} and {24, 61, 82, 48} from BatchedAValsFull().
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  XlaOp zero = ConstantR0<int32_t>(&builder, 0);
  XlaOp selected_row = DynamicSliceInMinorDims(a, {index, zero}, {1, n});
  BatchDot(selected_row, TransposeInMinorDims(row));

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
293
// Same computation as RowBatchDot, expressed as an Einsum that contracts the
// selected row of each batch matrix with `row` over the minor dimension.
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());
  constexpr int kCols = 4;

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // index = 1 picks {3, 6, 0, 1} and {24, 61, 82, 48} from BatchedAValsFull().
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  XlaOp zero = ConstantR0<int32_t>(&builder, 0);
  XlaOp picked = DynamicSliceInMinorDims(a, {index, zero}, {1, kCols});
  // a: batch, b: selected row, c: contracted column, d: row of `row`.
  Einsum(picked, row, "abc,adc->abd");

  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
314
// Checks ParseEinsumString on well-formed specs, comparing each parsed operand
// and the output against a reference encoding, and on malformed specs that
// must be rejected.
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Reference encoding the expectations compare against: every letter maps to
  // its character code, and the i-th '.' of a string maps to -3 + i, so "..."
  // becomes {-3, -2, -1}.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64_t> v;
    v.reserve(s.size());
    int e = -3;
    for (auto c : s) {
      v.push_back(c == '.' ? e++ : int64_t{c});
    }
    return v;
  };

  // Joins operand specs x, y and output spec o into "x,y->o".
  auto to_string = [](absl::string_view x, absl::string_view y,
                      absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  // Each case is {lhs, rhs, output}. The group after the blank line looks
  // semantically unusual (e.g. output labels that appear in no input), but
  // parsing is still expected to accept it.
  const std::vector<std::vector<std::string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"},
      {"...ab", "...bc", "...ac"},
      {"a...bc", "...abd", "cbd..."},
      {"...ab", "...bc", "ac"},
      {"...b", "...bc", "...c"},
      {"...abz", "...bc", "...ac"},
      {"...ab", "...bcz", "...ac"},
      {"abz", "bc", "ac"},
      {"ab", "bcz", "ac"},

      {"a", "b", "c"},
      {"...a", "...b", "...c"},
      {"abb", "bcc", "ac"},
      {"ab", "bc", "ad"},
  };
  for (const auto& test_case : good_test_cases) {
    const std::string einsum_config =
        to_string(test_case[0], test_case[1], test_case[2]);
    SCOPED_TRACE(einsum_config);
    auto parse_result_or_status = ParseEinsumString(
        einsum_config, test_case[0].size(), test_case[1].size());
    EXPECT_TRUE(parse_result_or_status.status().ok())
        << parse_result_or_status.status().ToString();
    // ValueOrDie() must not be called on an error status; record the failure
    // above and move on to the next case instead of crashing the test.
    if (!parse_result_or_status.status().ok()) {
      continue;
    }
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
    }
  }

  const std::vector<std::string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
  };
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    SCOPED_TRACE(test_case);
    auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
    EXPECT_FALSE(parse_result_or_status.status().ok());
  }
}
369
// Table of einsum specs and the normalized form NormalizeEinsumString is
// expected to return. The spec that already spells out "->ab" expects "".
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
  struct NormalizeCase {
    const char* input;
    const char* expected;
  };
  const NormalizeCase cases[] = {
      {"a,b->ab", ""},
      {"ba", "ba->ab"},
      {"ab,dc", "ab,dc->abcd"},
      {"a,b", "a,b->ab"},
      {"...ba,ca...", "...ba,ca...->...bc"},
  };
  for (const NormalizeCase& c : cases) {
    EXPECT_EQ(NormalizeEinsumString(c.input), c.expected);
  }
}
377
378 } // namespace
379 } // namespace xla
380