xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/triangular_solve_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace xla {
37 namespace {
38 
39 using TriangularSolveTest = ClientLibraryTestBase;
40 using TriangularSolveLeftLookingTest = ClientLibraryTestBase;
41 
// Quiet NaN used to fill matrix entries a triangular solve must never read:
// the unused triangle and, in the unit-diagonal inputs, the diagonal. If the
// op read one of them, the NaN would propagate into the computed result.
// (No `static` needed: this lives in an anonymous namespace.)
constexpr float kNan = std::numeric_limits<float>::quiet_NaN();
43 
AValsLower()44 Array2D<float> AValsLower() {
45   return {{2, kNan, kNan, kNan},
46           {3, 6, kNan, kNan},
47           {4, 7, 9, kNan},
48           {5, 8, 10, 11}};
49 }
50 
AValsUpper()51 Array2D<float> AValsUpper() {
52   return {{2, 3, 4, 5},
53           {kNan, 6, 7, 8},
54           {kNan, kNan, 9, 10},
55           {kNan, kNan, kNan, 11}};
56 }
57 
AValsLowerUnitDiagonal()58 Array2D<float> AValsLowerUnitDiagonal() {
59   return {{kNan, kNan, kNan, kNan},
60           {3, kNan, kNan, kNan},
61           {4, 7, kNan, kNan},
62           {5, 8, 10, kNan}};
63 }
64 
AValsUpperUnitDiagonal()65 Array2D<float> AValsUpperUnitDiagonal() {
66   return {{kNan, 3, 4, 5},
67           {kNan, kNan, 7, 8},
68           {kNan, kNan, kNan, 10},
69           {kNan, kNan, kNan, kNan}};
70 }
71 
BValsRight()72 Array2D<float> BValsRight() {
73   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
74 }
75 
BValsLeft()76 Array2D<float> BValsLeft() {
77   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
78 }
79 
80 static constexpr complex64 kNanC64 = complex64(kNan, kNan);
81 
AValsLowerComplex()82 Array2D<complex64> AValsLowerComplex() {
83   return {{2, kNanC64, kNanC64, kNanC64},
84           {complex64(3, 1), 6, kNanC64, kNanC64},
85           {4, complex64(7, 2), 9, kNanC64},
86           {5, 8, complex64(10, 3), 11}};
87 }
88 
AValsUpperComplex()89 Array2D<complex64> AValsUpperComplex() {
90   return {{2, 3, complex64(4, 3), 5},
91           {kNanC64, 6, complex64(7, 2), 8},
92           {kNanC64, kNanC64, complex64(9, 1), 10},
93           {kNanC64, kNanC64, kNanC64, 11}};
94 }
95 
BValsRightComplex()96 Array2D<complex64> BValsRightComplex() {
97   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
98 }
99 
BValsLeftComplex()100 Array2D<complex64> BValsLeftComplex() {
101   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
102 }
103 
// Degenerate shapes: a 0x0 A with a 0x10 B must produce an empty 0x10 result.
XLA_TEST_F(TriangularSolveTest, EmptyArrays) {
  XlaBuilder builder(TestName());

  const Array2D<float> empty_a(0, 0);
  const Array2D<float> empty_b(0, 10);

  XlaOp a;
  XlaOp b;
  auto a_data = CreateR2Parameter<float>(empty_a, 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(empty_b, 1, "b", &builder, &b);
  TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  // The solution has B's (empty) shape, so B itself is the expected value.
  ComputeAndCompareR2<float>(&builder, empty_b, {a_data.get(), b_data.get()});
}
120 
// Right side, lower A, TRANSPOSE: solves X * A^T = B for X.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> want({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
141 
// Right side, lower A, NO_TRANSPOSE: solves X * A = B for X.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> want({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
162 
// Right side, upper U, TRANSPOSE: solves X * U^T = B for X. AValsUpper() is
// the transpose of AValsLower(), so U^T is the lower matrix and the expected
// result equals that of SimpleRightLowerNotranspose.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> want({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
183 
// Right side, upper U, NO_TRANSPOSE: solves X * U = B for X. Since U is the
// transpose of AValsLower(), the expected result equals that of
// SimpleRightLowerTranspose.
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/false, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> want({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
204 
// Left side, lower A, TRANSPOSE: solves A^T * X = B for X.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> want({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
226 
// Left side, lower A, NO_TRANSPOSE: solves A * X = B for X.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> want({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
248 
// Left side, lower A, unit diagonal: solves A * X = B with A's diagonal taken
// as ones. The input diagonal is NaN, so the op must not read it.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle = CreateR2Parameter<float>(AValsLowerUnitDiagonal(), 0, "a",
                                           &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> want({
      {1., 2., 3.},
      {1., -1., -3.},
      {-4., 7., 18.},
      {37., -61., -159.},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
267 
// Left side, lower A, NO_TRANSPOSE: solves A * X = B for X.
// NOTE(review): inputs, options and expected values are identical to
// SimpleLeftLowerNotranspose. The name suggests it once exercised a
// non-default block size, but no such option is passed here — consider
// removing this test or restoring what it was meant to cover.
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> want({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
289 
// Left side, upper U, TRANSPOSE: solves U^T * X = B for X. U^T is
// AValsLower(), so the expected result equals that of
// SimpleLeftLowerNotranspose.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<float> want({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
311 
// Left side, upper U, NO_TRANSPOSE: solves U * X = B for X. U is the
// transpose of AValsLower(), so the expected result equals that of
// SimpleLeftLowerTranspose.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle =
      CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> want({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
333 
// Left side, upper A, unit diagonal: solves A * X = B with A's diagonal taken
// as ones. The input diagonal is NaN, so the op must not read it.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle = CreateR2Parameter<float>(AValsUpperUnitDiagonal(), 0, "a",
                                           &builder, &a_param);
  auto b_handle =
      CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/true,
                  /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE);

  const Array2D<float> want({
      {-1402., -1538., -1674.},
      {575., 631., 687.},
      {-93., -102., -111.},
      {10., 11., 12.},
  });
  ComputeAndCompareR2<float>(&builder, want, {a_handle.get(), b_handle.get()},
                             ErrorSpec(1e-2, 1e-2));
}
354 
// Complex, right side, lower A, ADJOINT: solves X * A^H = B for X, where A^H
// is the conjugate transpose of A.
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle = CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a",
                                               &builder, &a_param);
  auto b_handle = CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b",
                                               &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/false, /*lower=*/true,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::ADJOINT);

  const Array2D<complex64> want({
      {0.5, complex64(0.08333333, 0.08333333),
       complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
      {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
       complex64(0.08670034, -0.02104377)},
      {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
       complex64(0.11026936, -0.03114478)},
  });
  ComputeAndCompareR2<complex64>(&builder, want,
                                 {a_handle.get(), b_handle.get()},
                                 ErrorSpec(1e-2, 1e-2));
}
380 
// Complex, left side, upper A, plain TRANSPOSE (no conjugation): solves
// A^T * X = B for X.
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
  XlaBuilder builder(TestName());

  XlaOp a_param;
  XlaOp b_param;
  auto a_handle = CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a",
                                               &builder, &a_param);
  auto b_handle = CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b",
                                               &builder, &b_param);
  TriangularSolve(a_param, b_param, /*left_side=*/true, /*lower=*/false,
                  /*unit_diagonal=*/false,
                  /*transpose_a=*/TriangularSolveOptions::TRANSPOSE);

  const Array2D<complex64> want({
      {0.5, 1., 1.5},
      {0.41666667, 0.33333333, 0.25},
      {complex64(0.20020325, -2.81504065e-01),
       complex64(0.13821138, -4.22764228e-01),
       complex64(0.07621951, -5.64024390e-01)},
      {complex64(0.19678492, 2.55912786e-01),
       complex64(0.17738359, 3.84331116e-01),
       complex64(0.15798226, 5.12749446e-01)},
  });
  ComputeAndCompareR2<complex64>(&builder, want,
                                 {a_handle.get(), b_handle.get()},
                                 ErrorSpec(1e-2, 1e-2));
}
408 
// Batched left-side solve of A * X = B for seven 5x5 problems, with A the
// upper triangle of B. Multiplying the solution back by A must reproduce B.
XLA_TEST_F(TriangularSolveTest, BatchedLeftUpper) {
  XlaBuilder builder(TestName());

  Array3D<float> bvals(7, 5, 5);
  bvals.FillIota(1.);

  // Set avals to the upper triangle of bvals.
  Array3D<float> avals = bvals;
  avals.Each([](absl::Span<const int64_t> indices, float* value) {
    if (indices[1] > indices[2]) {
      *value = 0;
    }
  });

  XlaOp a;
  XlaOp b;
  auto a_data = CreateR3Parameter<float>(avals, 0, "a", &builder, &a);
  auto b_data = CreateR3Parameter<float>(bvals, 1, "b", &builder, &b);
  // The check-back product uses HIGHEST precision, as the parametric Random
  // test does, so the verification matmul itself does not add error on
  // backends where DEFAULT dot precision may be reduced.
  BatchDot(
      ConstantR3FromArray3D(&builder, avals),
      TriangularSolve(a, b,
                      /*left_side=*/true, /*lower=*/false,
                      /*unit_diagonal=*/false,
                      /*transpose_a=*/TriangularSolveOptions::NO_TRANSPOSE),
      xla::PrecisionConfig::HIGHEST);

  ComputeAndCompareR3<float>(&builder, bvals, {a_data.get(), b_data.get()},
                             ErrorSpec(1e-2, 1e-2));
}
436 
437 struct TriangularSolveTestSpec {
438   std::vector<int64_t> dims;  // [..., m, n] A is mxm, B is mxn
439   bool left_side;
440   bool lower;
441   TriangularSolveOptions::Transpose transpose_a;
442 };
443 
444 class TriangularSolveParametricTest
445     : public ClientLibraryTestBase,
446       public ::testing::WithParamInterface<TriangularSolveTestSpec> {};
447 
// Randomized round trip. Builds a random triangular A whose diagonal is
// shifted by +30 and a random B, solves the system, then multiplies the
// solution back by op(A) and compares against B:
//   left side:  op(A) * X == B
//   right side: X * op(A) == B
// op(A) is the selected triangle of A, transposed whenever transpose_a is not
// NO_TRANSPOSE.
XLA_TEST_P(TriangularSolveParametricTest, Random) {
  const TriangularSolveTestSpec spec = GetParam();

  XlaBuilder builder(TestName());

  CHECK_GE(spec.dims.size(), 2);
  const size_t rank = spec.dims.size();

  // A is square: its two minor dimensions are both m.
  std::vector<int64_t> a_dims = spec.dims;
  a_dims[rank - 1] = a_dims[rank - 2];
  Array<float> avals(a_dims);
  avals.FillRandom(1.0);
  // NOTE(review): the +30 diagonal shift presumably keeps the random triangle
  // well conditioned so the round trip stays within the error tolerance.
  avals.Each([](absl::Span<const int64_t> index, float* value) {
    const size_t last = index.size() - 1;
    if (index[last] == index[last - 1]) {
      *value += 30;
    }
  });

  // B is [m, n] for a left-side solve and [n, m] for a right-side solve.
  std::vector<int64_t> b_dims = spec.dims;
  if (!spec.left_side) {
    std::swap(b_dims[rank - 1], b_dims[rank - 2]);
  }
  Array<float> bvals(b_dims);
  bvals.FillRandom(1.0);

  XlaOp a;
  XlaOp b;
  auto a_data = CreateParameter<float>(avals, 0, "a", &builder, &a);
  auto b_data = CreateParameter<float>(bvals, 1, "b", &builder, &b);
  const XlaOp x = TriangularSolve(a, b, spec.left_side, spec.lower,
                                  /*unit_diagonal=*/false, spec.transpose_a);

  const bool transpose =
      spec.transpose_a != TriangularSolveOptions::NO_TRANSPOSE;
  const XlaOp op_a =
      MaybeTransposeInMinorDims(Triangle(a, spec.lower), transpose);

  if (!spec.left_side) {
    BatchDot(x, op_a, xla::PrecisionConfig::HIGHEST);
  } else {
    BatchDot(op_a, x, xla::PrecisionConfig::HIGHEST);
  }

  ComputeAndCompare<float>(&builder, bvals, {a_data.get(), b_data.get()},
                           ErrorSpec(3e-2, 3e-2));
}
488 
TriangularSolveTests()489 std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
490   std::vector<TriangularSolveTestSpec> specs;
491   for (auto batch : {std::vector<int64_t>{}, std::vector<int64_t>{1},
492                      std::vector<int64_t>{5}, std::vector<int64_t>{65},
493                      std::vector<int64_t>{129}}) {
494     for (int m : {5, 10, 150}) {
495       for (int n : {5, 150}) {
496         for (bool left_side : {false, true}) {
497           for (bool lower : {false, true}) {
498             for (TriangularSolveOptions::Transpose transpose_a :
499                  {TriangularSolveOptions::NO_TRANSPOSE,
500                   TriangularSolveOptions::TRANSPOSE}) {
501               std::vector<int64_t> dims(batch.begin(), batch.end());
502               dims.push_back(m);
503               dims.push_back(n);
504               specs.push_back({dims, left_side, lower, transpose_a});
505             }
506           }
507         }
508       }
509     }
510   }
511   return specs;
512 }
513 
514 INSTANTIATE_TEST_SUITE_P(
515     TriangularSolveParametricTestInstantiation, TriangularSolveParametricTest,
516     ::testing::ValuesIn(TriangularSolveTests()),
__anon3c5653df0402(const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) 517     [](const ::testing::TestParamInfo<TriangularSolveTestSpec>& info) {
518       const TriangularSolveTestSpec& spec = info.param;
519       std::string name = absl::StrCat(
520           absl::StrJoin(spec.dims, "_"), "_", spec.left_side ? "left" : "right",
521           "_", spec.lower ? "lower" : "upper", "_",
522           absl::AsciiStrToLower(
523               TriangularSolveOptions_Transpose_Name(spec.transpose_a)));
524       return name;
525     });
526 
527 }  // namespace
528 }  // namespace xla
529