1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
35
36 namespace xla {
37 namespace {
38
39 using TriangularSolveTest = ClientLibraryTestBase;
40 using TriangularSolveLeftLookingTest = ClientLibraryTestBase;
41
42 static constexpr float kNan = std::numeric_limits<float>::quiet_NaN();
43
AValsLower()44 Array2D<float> AValsLower() {
45 return {{2, kNan, kNan, kNan},
46 {3, 6, kNan, kNan},
47 {4, 7, 9, kNan},
48 {5, 8, 10, 11}};
49 }
50
AValsUpper()51 Array2D<float> AValsUpper() {
52 return {{2, 3, 4, 5},
53 {kNan, 6, 7, 8},
54 {kNan, kNan, 9, 10},
55 {kNan, kNan, kNan, 11}};
56 }
57
AValsLowerUnitDiagonal()58 Array2D<float> AValsLowerUnitDiagonal() {
59 return {{kNan, kNan, kNan, kNan},
60 {3, kNan, kNan, kNan},
61 {4, 7, kNan, kNan},
62 {5, 8, 10, kNan}};
63 }
64
AValsUpperUnitDiagonal()65 Array2D<float> AValsUpperUnitDiagonal() {
66 return {{kNan, 3, 4, 5},
67 {kNan, kNan, 7, 8},
68 {kNan, kNan, kNan, 10},
69 {kNan, kNan, kNan, kNan}};
70 }
71
BValsRight()72 Array2D<float> BValsRight() {
73 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
74 }
75
BValsLeft()76 Array2D<float> BValsLeft() {
77 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
78 }
79
80 static constexpr complex64 kNanC64 = complex64(kNan, kNan);
81
AValsLowerComplex()82 Array2D<complex64> AValsLowerComplex() {
83 return {{2, kNanC64, kNanC64, kNanC64},
84 {complex64(3, 1), 6, kNanC64, kNanC64},
85 {4, complex64(7, 2), 9, kNanC64},
86 {5, 8, complex64(10, 3), 11}};
87 }
88
AValsUpperComplex()89 Array2D<complex64> AValsUpperComplex() {
90 return {{2, 3, complex64(4, 3), 5},
91 {kNanC64, 6, complex64(7, 2), 8},
92 {kNanC64, kNanC64, complex64(9, 1), 10},
93 {kNanC64, kNanC64, kNanC64, 11}};
94 }
95
BValsRightComplex()96 Array2D<complex64> BValsRightComplex() {
97 return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
98 }
99
BValsLeftComplex()100 Array2D<complex64> BValsLeftComplex() {
101 return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
102 }
103
// Degenerate shapes: a 0x0 A with a 0x10 B must yield an empty 0x10 result.
XLA_TEST_F(TriangularSolveTest, EmptyArrays) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(Array2D<float>(0, 0), 0, "a",
                                        &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(Array2D<float>(0, 10), 1, "b",
                                        &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 10),
                             {a_arg.get(), b_arg.get()});
}
120
// Solves X * A^T = B with A = AValsLower() (4x4) and B = BValsRight() (3x4).
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> expected({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
141
// Solves X * A = B with A = AValsLower() (4x4) and B = BValsRight() (3x4).
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> expected({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
162
// Solves X * A^T = B with A = AValsUpper(). Because AValsUpper()'s triangle is
// the transpose of AValsLower()'s, the expected X equals that of
// SimpleRightLowerNotranspose.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> expected({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
183
// Solves X * A = B with A = AValsUpper(); equivalent to
// SimpleRightLowerTranspose, so the expected X is the same.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> expected({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
204
// Solves A^T * X = B with A = AValsLower() (4x4) and B = BValsLeft() (4x3).
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> expected({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
226
// Solves A * X = B (forward substitution) with A = AValsLower() and
// B = BValsLeft().
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
248
// Solves A * X = B where A is lower-triangular with an implicit unit diagonal
// (the diagonal stored in AValsLowerUnitDiagonal() is NaN and must be ignored).
// The expected X has exact integer entries from forward substitution.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsLowerUnitDiagonal(), 0, "a",
                                        &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> expected({{1., 2., 3.},
                                 {1., -1., -3.},
                                 {-4., 7., 18.},
                                 {37., -61., -159.}});
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
267
// Solves A * X = B with A = AValsLower() and B = BValsLeft().
// NOTE(review): the inputs, options and expected values here are identical to
// SimpleLeftLowerNotranspose, so this test currently adds no coverage. The
// "Irregularblock" name suggests it was meant to exercise a block size that
// does not divide the matrix size, but this call passes no such option —
// confirm the intent and either restore that variation or remove the test.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
289
// Solves A^T * X = B with A = AValsUpper(); equivalent to
// SimpleLeftLowerNotranspose, so the expected X is the same.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
311
// Solves A * X = B (back substitution) with A = AValsUpper(); equivalent to
// SimpleLeftLowerTranspose, so the expected X is the same.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> expected({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
333
// Solves A * X = B where A is upper-triangular with an implicit unit diagonal
// (the stored diagonal in AValsUpperUnitDiagonal() is NaN and must be ignored).
// The expected X has exact integer entries from back substitution.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<float>(AValsUpperUnitDiagonal(), 0, "a",
                                        &builder, &a_op);
  auto b_arg = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> expected({{-1402., -1538., -1674.},
                                 {575., 631., 687.},
                                 {-93., -102., -111.},
                                 {10., 11., 12.}});
  ComputeAndCompareR2<float>(&builder, expected, {a_arg.get(), b_arg.get()},
                             ErrorSpec(1e-2, 1e-2));
}
354
// Complex right-side solve with the adjoint: X * A^H = B, where
// A = AValsLowerComplex() and B = BValsRightComplex().
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a",
                                            &builder, &a_op);
  auto b_arg = CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b",
                                            &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::ADJOINT);

  const Array2D<complex64> expected({
      {0.5, complex64(0.08333333, 0.08333333),
       complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
      {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
       complex64(0.08670034, -0.02104377)},
      {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
       complex64(0.11026936, -0.03114478)},
  });
  ComputeAndCompareR2<complex64>(&builder, expected,
                                 {a_arg.get(), b_arg.get()},
                                 ErrorSpec(1e-2, 1e-2));
}
380
// Complex left-side solve with a plain (non-conjugating) transpose:
// A^T * X = B, where A = AValsUpperComplex() and B = BValsLeftComplex().
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
  XlaBuilder builder(TestName());

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a",
                                            &builder, &a_op);
  auto b_arg = CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b",
                                            &builder, &b_op);
  TriangularSolve(a_op, b_op, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<complex64> expected({
      {0.5, 1., 1.5},
      {0.41666667, 0.33333333, 0.25},
      {complex64(0.20020325, -2.81504065e-01),
       complex64(0.13821138, -4.22764228e-01),
       complex64(0.07621951, -5.64024390e-01)},
      {complex64(0.19678492, 2.55912786e-01),
       complex64(0.17738359, 3.84331116e-01),
       complex64(0.15798226, 5.12749446e-01)},
  });
  ComputeAndCompareR2<complex64>(&builder, expected,
                                 {a_arg.get(), b_arg.get()},
                                 ErrorSpec(1e-2, 1e-2));
}
408
// Solves seven independent 5x5 upper-triangular systems A * X = B in one
// batched call. Instead of hard-coding X, the test multiplies the solution back
// by A and checks that the product reproduces B.
XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
  XlaBuilder builder(TestName());

  Array3D<float> bvals(7, 5, 5);
  bvals.FillIota(1.);

  // Set avals to the upper triangle of bvals (zeros strictly below the
  // diagonal), so multiplying by the full `a` is the same as multiplying by
  // the triangle the solver reads.
  Array3D<float> avals = bvals;
  avals.Each([](absl::Span<const int64_t> indices, float* value) {
    if (indices[1] > indices[2]) {
      *value = 0;
    }
  });

  XlaOp a, b;
  auto a_data = CreateR3Parameter<float>(avals, 0, "a", &builder, &a);
  auto b_data = CreateR3Parameter<float>(bvals, 1, "b", &builder, &b);
  XlaOp x = TriangularSolve(
      a, b, /*left_side=*/true, /*lower=*/false, /*unit_diagonal=*/false,
      /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  // Reuse parameter `a` rather than embedding a second constant copy of avals.
  // Request HIGHEST precision, as the parametric Random test does, so the
  // reconstruction check does not depend on a backend's default matmul
  // precision (which may be reduced on some accelerators).
  BatchDot(a, x, xla::PrecisionConfig::HIGHEST);

  ComputeAndCompareR3<float>(&builder, bvals, {a_data.get(), b_data.get()},
                             ErrorSpec(1e-2, 1e-2));
}
436
437 struct TriangularSolveTestSpec {
438 std::vector<int64_t> dims; // [..., m, n] A is mxm, B is mxn
439 bool left_side;
440 bool lower;
441 TriangularSolveOptions::Transpose transpose_a;
442 };
443
// Value-parameterized fixture: the XLA client test harness combined with
// gtest's WithParamInterface, so each TEST_P instance obtains its
// TriangularSolveTestSpec through GetParam().
class TriangularSolveParametricTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<TriangularSolveTestSpec> {};
447
// Randomized round-trip check: solve with a random triangular A and random B,
// then multiply the solution back by op(triangle(A)) and compare with B.
XLA_TEST_P(TriangularSolveParametricTest, Random) {
  const TriangularSolveTestSpec& spec = GetParam();

  XlaBuilder builder(TestName());

  CHECK_GE(spec.dims.size(), 2);
  const auto rank = spec.dims.size();

  // A is square in its two minor dims: start from spec.dims ([..., m, n]) and
  // replace n with m.
  std::vector<int64_t> a_shape(spec.dims);
  a_shape[rank - 1] = a_shape[rank - 2];
  Array<float> a_vals(a_shape);
  a_vals.FillRandom(1.0);
  // Add 30 to every diagonal entry, presumably to keep the triangular system
  // well conditioned so the round trip stays within tolerance.
  a_vals.Each([](absl::Span<const int64_t> index, float* elem) {
    const auto last = index.size() - 1;
    if (index[last] == index[last - 1]) {
      *elem += 30;
    }
  });

  // B is [..., m, n] for left-side solves and [..., n, m] for right-side ones,
  // so that the m x m A multiplies it from the correct side.
  std::vector<int64_t> b_shape(spec.dims);
  if (!spec.left_side) {
    std::swap(b_shape[rank - 1], b_shape[rank - 2]);
  }
  Array<float> b_vals(b_shape);
  b_vals.FillRandom(1.0);

  XlaOp a_op;
  XlaOp b_op;
  auto a_arg = CreateParameter<float>(a_vals, 0, "a", &builder, &a_op);
  auto b_arg = CreateParameter<float>(b_vals, 1, "b", &builder, &b_op);
  XlaOp x = TriangularSolve(a_op, b_op, spec.left_side, spec.lower,
                            /*unit_diagonal=*/false, spec.transpose_a);

  // op(A), restricted to the triangle the solver was told to use.
  const bool transposed =
      spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE;
  XlaOp op_a =
      MaybeTransposeInMinorDims(Triangle(a_op, spec.lower), transposed);
  if (spec.left_side) {
    BatchDot(op_a, x, xla::PrecisionConfig::HIGHEST);  // op(A) * X
  } else {
    BatchDot(x, op_a, xla::PrecisionConfig::HIGHEST);  // X * op(A)
  }

  ComputeAndCompare<float>(&builder, b_vals, {a_arg.get(), b_arg.get()},
                           ErrorSpec(3e-2, 3e-2));
}
488
TriangularSolveTests()489 std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
490 std::vector<TriangularSolveTestSpec> specs;
491 for (auto batch : {std::vector<int64_t>{}, std::vector<int64_t>{1},
492 std::vector<int64_t>{5}, std::vector<int64_t>{65},
493 std::vector<int64_t>{129}}) {
494 for (int m : {5, 10, 150}) {
495 for (int n : {5, 150}) {
496 for (bool left_side : {false, true}) {
497 for (bool lower : {false, true}) {
498 for (TriangularSolveOptions::Transpose transpose_a :
499 {TriangularSolveOptions::NO_TRANSPOSE,
500 TriangularSolveOptions::TRANSPOSE}) {
501 std::vector<int64_t> dims(batch.begin(), batch.end());
502 dims.push_back(m);
503 dims.push_back(n);
504 specs.push_back({dims, left_side, lower, transpose_a});
505 }
506 }
507 }
508 }
509 }
510 }
511 return specs;
512 }
513
514 INSTANTIATE_TEST_SUITE_P(
515 TriangularSolveParametricTestInstantiation, TriangularSolveParametricTest,
516 ::testing::ValuesIn(TriangularSolveTests()),
__anon3c5653df0402(const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) 517 [](const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) {
518 const TriangularSolveTestSpec& spec = info.param;
519 std::string name = absl::StrCat(
520 absl::StrJoin(spec.dims, "_"), "_", spec.left_side ? "left" : "right",
521 "_", spec.lower ? "lower" : "upper", "_",
522 absl::AsciiStrToLower(
523 TriangularSolveOptions_Transpose_Name(spec.transpose_a)));
524 return name;
525 });
526
527 } // namespace
528 } // namespace xla
529