xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/matrix_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <limits>
#include <map>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
33 
34 namespace xla {
35 namespace {
36 
37 class MatrixTest : public ClientLibraryTestBase {
38  protected:
39   template <typename T>
40   void TestMatrixDiagonal();
41   template <typename T>
42   void TestMatrixDiagonal4D();
43   template <typename T>
44   void TestSetMatrixDiagonal();
45 
46   template <typename T>
k_and_expected() const47   std::map<int, Array2D<T>> k_and_expected() const {
48     return std::map<int, Array2D<T>>{
49         {0, {{0, 5, 10}, {12, 17, 22}}},
50         {1, {{1, 6, 11}, {13, 18, 23}}},
51         {2, {{2, 7}, {14, 19}}},
52         {3, {{3}, {15}}},
53         {4, {{}, {}}},
54         {-1, {{4, 9}, {16, 21}}},
55         {-2, {{8}, {20}}},
56         {-3, {{}, {}}},
57         {-4, {{}, {}}},
58     };
59   }
60 };
61 
XLA_TEST_F(MatrixTest, Triangle) {
  // Two 3x4 matrices holding 0..11 and 12..23. LowerTriangle keeps the entries
  // on and below the main diagonal and zeroes everything above it.
  Array3D<int32_t> expected({{{0, 0, 0, 0}, {4, 5, 0, 0}, {8, 9, 10, 0}},
                             {{12, 0, 0, 0}, {16, 17, 0, 0}, {20, 21, 22, 0}}});

  XlaBuilder builder(TestName());
  Array3D<int32_t> input(2, 3, 4);
  input.FillIota(0);

  XlaOp x;
  auto x_data = CreateR3Parameter<int32_t>(input, 0, "a", &builder, &x);
  LowerTriangle(x);

  ComputeAndCompareR3<int32_t>(&builder, expected, {x_data.get()});
}
75 
XLA_TEST_F(MatrixTest, Symmetrize) {
  // Only the lower triangle carries data. The strict upper triangle is NaN, so
  // the comparison fails if any of those entries leak into the result.
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const Array<float> input = {
      {1, nan, nan},
      {2, 3, nan},
      {4, 5, 6},
  };
  const Array<float> expected = {
      {1, 2, 4},
      {2, 3, 5},
      {4, 5, 6},
  };

  for (const bool lower : {false, true}) {
    XlaBuilder builder(TestName());
    XlaOp x;
    auto x_data = CreateParameter<float>(input, 0, "a", &builder, &x);
    // For the upper-triangle variant, transpose first so the data sits above
    // the diagonal.
    XlaOp operand = lower ? x : TransposeInMinorDims(x);
    Symmetrize(operand, /*lower=*/lower);
    ComputeAndCompare<float>(&builder, expected, {x_data.get()});
  }
}
99 
XLA_TEST_F(MatrixTest, SymmetrizeComplex) {
  // The lower triangle holds the data. The strict upper triangle and the
  // imaginary parts of the diagonal are NaN. The expected result is the
  // Hermitian completion (upper = conjugate of lower) with a real diagonal.
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const Array<complex64> input = {
      {complex64{1, nan}, nan, nan},
      {complex64{2, 7}, complex64{3, nan}, nan},
      {complex64{4, 8}, complex64{5, 9}, complex64{6, nan}},
  };
  const Array<complex64> expected = {
      {1, complex64{2, -7}, complex64{4, -8}},
      {complex64{2, 7}, 3, complex64{5, -9}},
      {complex64{4, 8}, complex64{5, 9}, 6},
  };

  for (const bool lower : {false, true}) {
    XlaBuilder builder(TestName());
    XlaOp x;
    auto x_data = CreateParameter<complex64>(input, 0, "a", &builder, &x);
    // For the upper-triangle variant, feed the conjugate transpose so the data
    // sits above the diagonal.
    XlaOp operand = lower ? x : Conj(TransposeInMinorDims(x));
    Symmetrize(operand, /*lower=*/lower);
    ComputeAndCompare<complex64>(&builder, expected, {x_data.get()});
  }
}
123 
171 
172 template <typename T>
TestMatrixDiagonal()173 void MatrixTest::TestMatrixDiagonal() {
174   XlaBuilder builder("SetMatrixDiagonal");
175   Array3D<T> input(2, 3, 4);
176   input.FillIota(0);
177   for (const auto& kv : k_and_expected<T>()) {
178     XlaOp a;
179     auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
180     GetMatrixDiagonal(a, kv.first);
181 
182     ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get()});
183   }
184 }
185 
186 template <typename T>
TestSetMatrixDiagonal()187 void MatrixTest::TestSetMatrixDiagonal() {
188   XlaBuilder builder("GetMatrixDiagonal");
189   Array3D<T> input(2, 3, 4);
190   input.FillIota(0);
191   for (const auto& kv : k_and_expected<T>()) {
192     XlaOp a;
193     XlaOp b;
194     auto a_data = CreateR3Parameter<T>(input, 0, "a", &builder, &a);
195     auto new_diag =
196         CreateR2Parameter<T>(Array2D<T>{kv.second}, 1, "d", &builder, &b);
197 
198     GetMatrixDiagonal(SetMatrixDiagonal(a, b + ScalarLike(b, 1), kv.first),
199                       kv.first) -
200         ScalarLike(b, 1);
201 
202     ComputeAndCompareR2<T>(&builder, kv.second, {a_data.get(), new_diag.get()});
203   }
204 }
205 
// SetMatrixDiagonal round trip, instantiated for each element type.
XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S32) {
  TestSetMatrixDiagonal<int32_t>();
}

XLA_TEST_F(MatrixTest, SetMatrixDiagonal_S64) {
  TestSetMatrixDiagonal<int64_t>();
}

XLA_TEST_F(MatrixTest, SetMatrixDiagonal_F32) {
  TestSetMatrixDiagonal<float>();
}
215 
// GetMatrixDiagonal on a batch of 3x4 matrices, instantiated for each element
// type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S32) {
  TestMatrixDiagonal<int32_t>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal_S64) {
  TestMatrixDiagonal<int64_t>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal_F32) {
  TestMatrixDiagonal<float>();
}
221 
222 template <typename T>
TestMatrixDiagonal4D()223 void MatrixTest::TestMatrixDiagonal4D() {
224   XlaBuilder builder("GetMatrixDiagonal");
225   Array4D<T> input(2, 2, 4, 3);
226   input.FillIota(0);
227   std::map<int, Array3D<T>> k_and_expected = {
228       {0, {{{0, 4, 8}, {12, 16, 20}}, {{24, 28, 32}, {36, 40, 44}}}},
229       {1, {{{1, 5}, {13, 17}}, {{25, 29}, {37, 41}}}},
230       {2, {{{2}, {14}}, {{26}, {38}}}},
231       {3, {{{}, {}}, {{}, {}}}},
232       {4, {{{}, {}}, {{}, {}}}},
233       {-1, {{{3, 7, 11}, {15, 19, 23}}, {{27, 31, 35}, {39, 43, 47}}}},
234       {-2, {{{6, 10}, {18, 22}}, {{30, 34}, {42, 46}}}},
235       {-3, {{{9}, {21}}, {{33}, {45}}}},
236       {-4, {{{}, {}}, {{}, {}}}},
237   };
238   for (const auto& kv : k_and_expected) {
239     XlaOp a;
240     auto a_data = CreateR4Parameter<T>(input, 0, "a", &builder, &a);
241     GetMatrixDiagonal(a, kv.first);
242 
243     ComputeAndCompareR3<T>(&builder, kv.second, {a_data.get()});
244   }
245 }
246 
// GetMatrixDiagonal on a 2x2 batch of 4x3 matrices, instantiated for each
// element type.
XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S32) {
  TestMatrixDiagonal4D<int32_t>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_S64) {
  TestMatrixDiagonal4D<int64_t>();
}

XLA_TEST_F(MatrixTest, GetMatrixDiagonal4D_F32) {
  TestMatrixDiagonal4D<float>();
}
258 
BatchedAValsFull()259 Array3D<float> BatchedAValsFull() {
260   return {{
261               {2, 0, 1, 2},
262               {3, 6, 0, 1},
263               {4, 7, 9, 0},
264               {5, 8, 10, 11},
265           },
266           {
267               {16, 24, 8, 12},
268               {24, 61, 82, 48},
269               {8, 82, 456, 106},
270               {12, 48, 106, 62},
271           }};
272 }
273 
XLA_TEST_F(MatrixTest, RowBatchDot) {
  XlaBuilder builder(TestName());
  const int n = 4;  // Number of columns in each BatchedAValsFull() matrix.

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // Select {{3, 6, 0, 1}, {24, 61,  82,  48}} out of BatchedAValsFull().
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // Slice one whole row (row `index`, starting at column 0) from every batch.
  XlaOp selected_row = DynamicSliceInMinorDims(
      a, {index, ConstantR0<int32_t>(&builder, 0)}, {1, n});
  // [2, 1, 4] x [2, 4, 1] -> [2, 1, 1]: one dot product per batch.
  BatchDot(selected_row, TransposeInMinorDims(row));

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
293 
XLA_TEST_F(MatrixTest, Einsum) {
  XlaBuilder builder(TestName());
  const int n = 4;  // Number of columns in each BatchedAValsFull() matrix.

  XlaOp a;
  XlaOp row;
  XlaOp index;
  auto a_data =
      CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
  auto row_data = CreateR3Parameter<float>({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1,
                                           "row", &builder, &row);
  // Select {{3, 6, 0, 1}, {24, 61,  82,  48}} out of BatchedAValsFull().
  auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);

  // Slice one whole row (row `index`, starting at column 0) from every batch.
  XlaOp selected_row = DynamicSliceInMinorDims(
      a, {index, ConstantR0<int32_t>(&builder, 0)}, {1, n});
  // Both operands are [2, 1, 4]. Label `a` is a batch dimension and the shared
  // label `c` is contracted, giving a [2, 1, 1] result.
  Einsum(selected_row, row, "abc,adc->abd");

  // 3*9 + 6*1 = 33 and 24*2 + 61*4 = 292, the same as RowBatchDot.
  ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
                             {a_data.get(), row_data.get(), index_data.get()});
}
314 
XLA_TEST_F(MatrixTest, ParseEinsumString) {
  // Encodes an operand spec the way the parse result is compared below.
  // Letters map to their character codes. The three dots of an ellipsis map to
  // -3, -2 and -1, in order.
  auto to_vec = [](absl::string_view s) {
    std::vector<int64_t> v;
    v.reserve(s.size());
    int e = -3;
    for (auto c : s) {
      v.push_back(c == '.' ? e++ : int64_t{c});
    }
    return v;
  };

  // Joins the two input specs and the output spec into "x,y->o".
  auto to_string = [](absl::string_view x, absl::string_view y,
                      absl::string_view o) {
    return absl::StrCat(x, ",", y, "->", o);
  };

  // Each case is {lhs spec, rhs spec, output spec}.
  const std::vector<std::vector<std::string>> good_test_cases = {
      {"ab", "bc", "ac"},
      {"Bab", "Bbc", "Bac"},
      {"ab", "cd", "dcba"},
      {"abc", "abd", "cbd"},
      {"...ab", "...bc", "...ac"},
      {"a...bc", "...abd", "cbd..."},
      {"...ab", "...bc", "ac"},
      {"...b", "...bc", "...c"},
      {"...abz", "...bc", "...ac"},
      {"...ab", "...bcz", "...ac"},
      {"abz", "bc", "ac"},
      {"ab", "bcz", "ac"},

      {"a", "b", "c"},
      {"...a", "...b", "...c"},
      {"abb", "bcc", "ac"},
      {"ab", "bc", "ad"},
  };
  for (const auto& test_case : good_test_cases) {
    const std::string einsum =
        to_string(test_case[0], test_case[1], test_case[2]);
    // Attach the einsum string to every failure reported for this case.
    SCOPED_TRACE(einsum);
    auto parse_result_or_status =
        ParseEinsumString(einsum, test_case[0].size(), test_case[1].size());
    // Report and skip an unexpected error instead of calling ValueOrDie() on
    // it. ValueOrDie() on an error status would crash the test binary rather
    // than record a failure, and the remaining cases would never run.
    if (!parse_result_or_status.status().ok()) {
      ADD_FAILURE() << "unexpected parse error: "
                    << parse_result_or_status.status().ToString();
      continue;
    }
    const auto& parse_result = parse_result_or_status.ValueOrDie();
    for (int i = 0; i < 3; ++i) {
      EXPECT_EQ(parse_result[i], to_vec(test_case[i]));
    }
  }

  const std::vector<std::string> einsum_strings_that_fail_parsing = {
      "", "a", "ab->ba", "ab,bc,cd->ad", "a...b...,bc->a...c",
  };
  for (const auto& test_case : einsum_strings_that_fail_parsing) {
    auto parse_result_or_status = ParseEinsumString(test_case, 3, 3);
    EXPECT_FALSE(parse_result_or_status.status().ok())
        << "expected a parse error for: \"" << test_case << "\"";
  }
}
369 
XLA_TEST_F(MatrixTest, NormalizeEinsumString) {
  // Each case pairs an einsum config with the string NormalizeEinsumString is
  // expected to return. Configs without "->" gain an explicit output spec. The
  // config that already contains "->" is expected to map to "".
  struct Case {
    const char* config;
    const char* normalized;
  };
  const Case cases[] = {
      {"a,b->ab", ""},
      {"ba", "ba->ab"},
      {"ab,dc", "ab,dc->abcd"},
      {"a,b", "a,b->ab"},
      {"...ba,ca...", "...ba,ca...->...bc"},
  };
  for (const Case& c : cases) {
    EXPECT_EQ(NormalizeEinsumString(c.config), c.normalized)
        << "config: " << c.config;
  }
}
377 
378 }  // namespace
379 }  // namespace xla
380