/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/eigh_expander.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

// Parallel two-sided Jacobi symmetric eigendecomposition.
//
// The implementation follows the approach described in:
// Brent, Richard P., and Franklin T. Luk. "The solution of singular-value and
// symmetric eigenvalue problems on multiprocessor arrays." SIAM Journal on
// Scientific and Statistical Computing 6.1 (1985): 69-84.
//
// Where the Brent/Luk paper uses "processors", we use "vector elements".
namespace xla {

namespace {

// A 2x2 symmetric eigendecomposition of a matrix A.
// If
//   G = [[ c, s],
//        [-s, c]]
// then matmul(G_T, G) = I and
//   G @ [[rt1, 0  ],
//        [  0, rt2]] @ G.T = A
struct Eigh2x2 {
  // Eigenvalues.
  XlaOp rt1;
  XlaOp rt2;
  // First row of the eigenvector matrix.
  XlaOp c;  // cosine.
  XlaOp s;  // sine.
};

// sqrt(x**2 + y**2), calculated while avoiding overflow.
XlaOp Hypot(XlaOp x, XlaOp y) {
  x = Abs(x);
  y = Abs(y);
  auto xy_min = Min(x, y);
  auto xy_max = Max(x, y);
  auto out = xy_max * Sqrt(ScalarLike(x, 1) + Square(xy_min / xy_max));
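  // When x == y, xy_min / xy_max is either exactly 1 or, for x == y == 0, the
  // indeterminate 0/0; in both cases hypot(x, y) == xy_min * sqrt(2), so the
  // Select below returns that value directly.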
  return Select(Eq(xy_min, xy_max), xy_min * ScalarLike(xy_min, std::sqrt(2.)),
                out);
}

// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n,
// a Jacobi rotation computes a rotation matrix G = [[c, s], [-s, c]] such that
//   G_T * A[[p, q], [p, q]] * G
// is diagonal.
//
// In this parallel Jacobi algorithm, we compute Jacobi rotations for all of the
// matrix diagonal elements simultaneously. The diagonal elements correspond to
// distinct rows and columns of the original matrix, so their rotations do not
// interfere and can be computed in parallel.
//
// def sym_schur2x2(w_tl, w_tr, w_br):
//   off_diag = np.diag(w_tr)
//   tau = (np.diag(w_br) - np.diag(w_tl)) / (2 * off_diag)
//   t = np.where(tau >= 0, 1.0 / (tau + np.sqrt(1 + tau ** 2)),
//                -1.0 / (-tau + np.sqrt(1 + tau ** 2)))
//   pred = np.abs(off_diag) > 1e-6
//   t = np.where(pred, t, 0.)
//   c = 1.0 / np.sqrt(1.0 + t ** 2)
//   s = t * c
//   rt1 = w_tl - t * w_tr
//   rt2 = w_br + t * w_tr
//   return rt1, rt2, c, s
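//
// A standalone NumPy sanity check of the scalar 2x2 case (an illustrative
// sketch only; schur2x2 and the sample values below are not part of this
// file). It verifies that G = [[c, s], [-s, c]] built from (c, s) diagonalizes
// a real symmetric 2x2 matrix, i.e. G.T @ A @ G == diag(rt1, rt2) and
// G @ diag(rt1, rt2) @ G.T == A (the small-off-diagonal threshold is omitted
// for brevity):
//
//   import numpy as np
//
//   def schur2x2(app, apq, aqq):
//     tau = (aqq - app) / (2 * apq)
//     t = np.where(tau >= 0, 1.0 / (tau + np.sqrt(1 + tau ** 2)),
//                  -1.0 / (-tau + np.sqrt(1 + tau ** 2)))
//     c = 1.0 / np.sqrt(1.0 + t ** 2)
//     s = t * c
//     return app - t * apq, aqq + t * apq, c, s
//
//   app, apq, aqq = 1.0, 2.0, -3.0
//   rt1, rt2, c, s = schur2x2(app, apq, aqq)
//   G = np.array([[c, s], [-s, c]])
//   A = np.array([[app, apq], [apq, aqq]])
//   assert np.allclose(G.T @ A @ G, np.diag([rt1, rt2]))
//   assert np.allclose(G @ np.diag([rt1, rt2]) @ G.T, A)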
StatusOr<Eigh2x2> HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr,
                                                 XlaOp w_br) {
  TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl));
  bool is_complex = primitive_util::IsComplexType(w_tl_shape.element_type());

  w_tl = GetMatrixDiagonal(Real(w_tl));
  w_tr = GetMatrixDiagonal(w_tr);
  w_br = GetMatrixDiagonal(Real(w_br));
  auto zero = ScalarLike(w_tl, 0.0);
  auto one = ScalarLike(w_tl, 1.0);
  auto two = ScalarLike(w_tl, 2.0);

  XlaOp w;
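  // For complex (Hermitian) inputs, factor the off-diagonal element as
  // w_tr = |w_tr| * e^{i*theta} and remember w = e^{-i*theta}. The real-valued
  // computation below then runs on |w_tr|, and the phase is folded back into
  // the sine term at the end (s *= w).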
  if (is_complex) {
    auto abs_tr = Abs(w_tr);
    w = Select(Eq(abs_tr, ZerosLike(abs_tr)), FullLike(w_tr, 1),
               Conj(w_tr) / Complex(abs_tr, ZerosLike(abs_tr)));
    w_tr = abs_tr;
  }

  auto tol = ScalarLike(w_tr, 1e-6);
  auto tau = (w_br - w_tl) / (two * w_tr);
  auto t = Sqrt(one + Square(tau));
  t = Reciprocal(tau + Select(Ge(tau, zero), t, Neg(t)));
  t = Select(Gt(Abs(w_tr), tol), t, ZerosLike(t));
  auto c = Rsqrt(one + Square(t));
  auto s = t * c;

  auto rt1 = w_tl - t * w_tr;
  auto rt2 = w_br + t * w_tr;

  if (is_complex) {
    rt1 = Complex(rt1, ZerosLike(rt1));
    rt2 = Complex(rt2, ZerosLike(rt2));
    c = Complex(c, ZerosLike(c));
    s = Complex(s, ZerosLike(s)) * w;
  }

  return Eigh2x2{rt1, rt2, c, s};
}

// tl, tr, bl, br = (
//   tl * c[:, None] - bl * s[:, None],
//   tr * c[:, None] - br * s[:, None],
//   tl * s[:, None] + bl * c[:, None],
//   tr * s[:, None] + br * c[:, None],
// )
void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
                                 XlaOp& bl, XlaOp& br) {
  Shape shape = tl.builder()->GetShape(tl).ValueOrDie();
  std::vector<int64_t> broadcast_dims(shape.dimensions().size() - 1);
  absl::c_iota(broadcast_dims, 0);
  auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
  auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);

  auto s_conj = MaybeConjugate(s, true);
  std::tie(tl, tr, bl, br) =
      std::make_tuple(tl * c - bl * s_conj, tr * c - br * s_conj,
                      tl * s + bl * c, tr * s + br * c);
}

// tl, tr, bl, br = (
//   tl * c[None, :] - tr * s[None, :],
//   tl * s[None, :] + tr * c[None, :],
//   bl * c[None, :] - br * s[None, :],
//   bl * s[None, :] + br * c[None, :],
// )
void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
                                 XlaOp& bl, XlaOp& br) {
  Shape shape = tl.builder()->GetShape(tl).ValueOrDie();
  std::vector<int64_t> broadcast_dims(shape.dimensions().size() - 1);
  absl::c_iota(broadcast_dims, 0);
  broadcast_dims.back() = shape.dimensions().size() - 1;
  auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
  auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);

  auto s_conj = MaybeConjugate(s, true);
  std::tie(tl, tr, bl, br) =
      std::make_tuple(tl * c - tr * s, tl * s_conj + tr * c, bl * c - br * s,
                      bl * s_conj + br * c);
}

// def permute_rows_in_col(top, bottom):
//   top_out = np.zeros_like(top)
//   top_out[0] = top[0]
//   top_out[1] = bottom[0]
//   top_out[2:] = top[1:-1]
//   bottom_out = np.zeros_like(bottom)
//   bottom_out[:-1] = bottom[1:]
//   bottom_out[-1] = top[-1]
//   return top_out, bottom_out
void PermuteRowsInColumn(XlaOp& top, XlaOp& bottom) {
  XlaBuilder* builder = top.builder();
  Shape shape = builder->GetShape(top).ValueOrDie();
  int64_t k = ShapeUtil::GetDimension(shape, -1);
  if (k <= 1) {
    return;
  }
  int ndim = shape.dimensions_size();
  std::tie(top, bottom) =
      std::make_tuple(ConcatInDim(builder,
                                  {SliceInMinorDims(top, {0, 0}, {1, k}),
                                   SliceInMinorDims(bottom, {0, 0}, {1, k}),
                                   SliceInMinorDims(top, {1, 0}, {k - 1, k})},
                                  ndim - 2),
                      ConcatInDim(builder,
                                  {SliceInMinorDims(bottom, {1, 0}, {k, k}),
                                   SliceInMinorDims(top, {k - 1, 0}, {k, k})},
                                  ndim - 2));
}

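// By analogy with permute_rows_in_col above, a sketch (illustrative only, not
// part of the original pseudocode) of the column permutation implemented by
// PermuteColumnsInRow below:
//
// def permute_cols_in_row(left, right):
//   left_out = np.zeros_like(left)
//   left_out[:, 0] = left[:, 0]
//   left_out[:, 1] = right[:, 0]
//   left_out[:, 2:] = left[:, 1:-1]
//   right_out = np.zeros_like(right)
//   right_out[:, :-1] = right[:, 1:]
//   right_out[:, -1] = left[:, -1]
//   return left_out, right_out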
void PermuteColumnsInRow(XlaOp& left, XlaOp& right) {
  XlaBuilder* builder = left.builder();
  Shape shape = builder->GetShape(left).ValueOrDie();
  int64_t k = ShapeUtil::GetDimension(shape, -1);
  if (k <= 1) {
    return;
  }
  int ndim = shape.dimensions_size();
  std::tie(left, right) =
      std::make_tuple(ConcatInDim(builder,
                                  {SliceInMinorDims(left, {0}, {1}),
                                   SliceInMinorDims(right, {0}, {1}),
                                   SliceInMinorDims(left, {1}, {k - 1})},
                                  ndim - 1),
                      ConcatInDim(builder,
                                  {SliceInMinorDims(right, {1}, {k}),
                                   SliceInMinorDims(left, {k - 1}, {k})},
                                  ndim - 1));
}

// Performs one round of parallel Jacobi rotations; n-1 rounds make a sweep.
// After each rotation, we permute the rows and columns of the quadrants of the
// matrix. The effect of the permutations is that, over the course of n-1
// rounds, every pair of rows appears on the diagonal of the quadrants. The
// permutations are an implicit way of computing a round-robin tournament for n
// players in which each player plays every other player exactly once in n - 1
// rounds. See the Brent/Luk paper for more details.
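//
// A small Python sketch (illustrative only) of this round-robin schedule:
// track the original row index held by each position of top and bottom, apply
// the same permutation as permute_rows_in_col after every round, and check
// that every pair meets exactly once in n - 1 rounds and that the original
// order is restored at the end of the sweep.
//
//   k = 4  # n = 2 * k
//   top, bottom = list(range(k)), list(range(k, 2 * k))
//   pairs = set()
//   for _ in range(2 * k - 1):
//     pairs |= {frozenset(p) for p in zip(top, bottom)}
//     top, bottom = [top[0], bottom[0]] + top[1:-1], bottom[1:] + [top[-1]]
//   assert len(pairs) == k * (2 * k - 1)  # every pair met exactly once
//   assert (top, bottom) == (list(range(k)), list(range(k, 2 * k)))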
Status ApplyRotations(int64_t n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl,
                      XlaOp& w_br, XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl,
                      XlaOp& v_br) {
  TF_ASSIGN_OR_RETURN(Eigh2x2 rotation,
                      HermitianEigenDecomposition2x2(w_tl, w_tr, w_br));

  ApplyJacobiRotationOverRows(rotation, w_tl, w_tr, w_bl, w_br);
  ApplyJacobiRotationOverCols(rotation, w_tl, w_tr, w_bl, w_br);
  w_tl = SetMatrixDiagonal(w_tl, rotation.rt1);
  w_tr = SetMatrixDiagonal(w_tr, ZerosLike(rotation.rt1));
  w_bl = SetMatrixDiagonal(w_bl, ZerosLike(rotation.rt1));
  w_br = SetMatrixDiagonal(w_br, rotation.rt2);

  PermuteColumnsInRow(w_tl, w_tr);
  PermuteColumnsInRow(w_bl, w_br);
  PermuteRowsInColumn(w_tl, w_bl);
  PermuteRowsInColumn(w_tr, w_br);

  // Apply the rotations to the eigenvector matrix.
  // TODO(phawkins): we could omit this if we aren't interested in computing the
  // eigenvectors.
  ApplyJacobiRotationOverRows(rotation, v_tl, v_tr, v_bl, v_br);
  PermuteRowsInColumn(v_tl, v_bl);
  PermuteRowsInColumn(v_tr, v_br);
  return OkStatus();
}

struct FrobeniusNorms {
  XlaOp off_diagonal_sq_norm;
  XlaOp frobenius_sq_norm;
};

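// Computes, for each batch element, the squared Frobenius norm of the full
// matrix assembled from the four quadrants, and the squared norm of its
// off-diagonal part (everything except the diagonals of w_tl and w_br, which
// together form the diagonal of the full matrix).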
StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr,
                                               XlaOp w_bl, XlaOp w_br) {
  XlaBuilder* builder = w_tl.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w_tl));
  const int64_t num_dims = shape.rank();
  auto square_norm = [](XlaOp x) -> XlaOp {
    return Real(x * MaybeConjugate(x, true));
  };
  auto off_diag = [](XlaOp x) {
    return Select(GetDiagonalMask(x), ZerosLike(x), x);
  };
  PrimitiveType norm_type =
      primitive_util::IsComplexType(shape.element_type())
          ? primitive_util::ComplexComponentType(shape.element_type())
          : shape.element_type();
  auto zero = ScalarLike(Real(w_tl), 0.0);
  FrobeniusNorms norms;
  norms.frobenius_sq_norm =
      Reduce(square_norm(w_tl) + square_norm(w_tr) + square_norm(w_bl) +
                 square_norm(w_br),
             zero, CreateScalarAddComputation(norm_type, builder),
             {num_dims - 2, num_dims - 1});
  norms.off_diagonal_sq_norm =
      Reduce(square_norm(off_diag(w_tl)) + square_norm(w_tr) +
                 square_norm(w_bl) + square_norm(off_diag(w_br)),
             zero, CreateScalarAddComputation(norm_type, builder),
             {num_dims - 2, num_dims - 1});

  return norms;
}

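// Runs up to max_iters sweeps, where one sweep is n - 1 rounds of
// ApplyRotations. The loop exits early once, for every batch element, the
// squared off-diagonal norm is no longer greater than tol^2 times the squared
// Frobenius norm (values[0] is the sweep counter, values[1] is tol).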
StatusOr<std::vector<XlaOp>> Sweeps(absl::Span<const XlaOp> initial_values,
                                    int64_t n, int max_iters,
                                    PrimitiveType index_type,
                                    XlaBuilder* builder) {
  auto while_cond_fn = [&](absl::Span<const XlaOp> values,
                           XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
    auto iter_cond = Lt(values[0], ScalarLike(values[0], max_iters));

    XlaOp w_tl, w_tr, w_bl, w_br;
    std::tie(w_tl, w_tr, w_bl, w_br) =
        std::make_tuple(values[2], values[3], values[4], values[5]);
    TF_ASSIGN_OR_RETURN(auto norms,
                        ComputeFrobeniusNorms(w_tl, w_tr, w_bl, w_br));
    auto tol = norms.frobenius_sq_norm * Square(values[1]);
    auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_sq_norm),
                              xla::ConstantR0<bool>(cond_builder, false),
                              CreateScalarOrComputation(PRED, cond_builder));

    return And(iter_cond, tol_cond);
  };

  auto while_body_fn =
      [&](absl::Span<const XlaOp> values,
          XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
    std::vector<XlaOp> sweep_values(values.begin() + 1, values.end());
    TF_ASSIGN_OR_RETURN(
        sweep_values,
        ForEachIndex(
            n - 1, S32,
            [&](XlaOp iter, absl::Span<const XlaOp> values,
                XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
              XlaOp tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br;
              std::tie(tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br) =
                  std::make_tuple(values[0], values[1], values[2], values[3],
                                  values[4], values[5], values[6], values[7],
                                  values[8]);
              TF_RETURN_IF_ERROR(ApplyRotations(n, w_tl, w_tr, w_bl, w_br, v_tl,
                                                v_tr, v_bl, v_br));
              return std::vector<XlaOp>{tol,  w_tl, w_tr, w_bl, w_br,
                                        v_tl, v_tr, v_bl, v_br};
            },
            sweep_values, "ApplyRotations", body_builder));
    std::vector<XlaOp> output(values.size());
    output[0] = values[0] + ScalarLike(values[0], 1);
    std::copy(sweep_values.begin(), sweep_values.end(), output.begin() + 1);
    return output;
  };
  return WhileLoopHelper(while_cond_fn, while_body_fn, initial_values,
                         "EighJacobiSweeps", builder);
}

}  // namespace

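// Sorts the eigenvalues in w into ascending order and applies the same
// permutation to the columns of v. Roughly, in NumPy (an illustrative sketch
// of the keyed Sort below, not code from this file):
//
//   order = np.argsort(w, axis=-1)
//   w = np.take_along_axis(w, order, axis=-1)
//   v = np.take_along_axis(v, order[..., None, :], axis=-1)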
Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) {
  XlaBuilder* builder = v.builder();
  TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v));
  TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w));
  const int64_t num_dims = v_shape.rank();
  auto dimensions = v_shape.dimensions();

  std::vector<int64_t> broadcast_dims(num_dims - 1);
  std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
  broadcast_dims[num_dims - 2] = num_dims - 1;
  w = BroadcastInDim(w, dimensions, broadcast_dims);

  XlaOp sort_result =
      Sort({w, v},
           CreateScalarLtComputation(
               {w_shape.element_type(), v_shape.element_type()}, builder),
           num_dims - 1);
  w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
  v = GetTupleElement(sort_result, 1);
  return OkStatus();
}

// This is the cyclic Jacobi iteration.
//
// def jacobi(A):
//   n, _ = A.shape
//   tl = A[:n // 2, :n // 2]
//   bl = A[n // 2:, :n // 2]
//   tr = A[:n // 2, n // 2:]
//   br = A[n // 2:, n // 2:]
//   v_tl = np.eye(n // 2, dtype=A.dtype)
//   v_tr = np.zeros((n // 2, n // 2), A.dtype)
//   v_bl = np.zeros((n // 2, n // 2), A.dtype)
//   v_br = np.eye(n // 2, dtype=A.dtype)
//   frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) +
//                            np.square(bl) + np.square(br)))
//   diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) +
//                              np.square(np.diag(br))))
//   off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt(
//       frobenius_norm + diag_norm)
//   while off_diag_norm > 1e-6 * frobenius_norm:
//     for i in range(n - 1):
//       _, _, c, s = sym_schur2x2(tl, tr, br)
//       tl, tr, bl, br = (
//         tl * c[:, None] - bl * s[:, None],
//         tr * c[:, None] - br * s[:, None],
//         tl * s[:, None] + bl * c[:, None],
//         tr * s[:, None] + br * c[:, None],
//       )
//       tl, tr, bl, br = (
//         tl * c[None, :] - tr * s[None, :],
//         tl * s[None, :] + tr * c[None, :],
//         bl * c[None, :] - br * s[None, :],
//         bl * s[None, :] + br * c[None, :],
//       )
//       tl, bl = permute_rows_in_col(tl, bl)
//       tr, br = permute_rows_in_col(tr, br)
//       tl, tr = permute_cols_in_row(tl, tr)
//       bl, br = permute_cols_in_row(bl, br)
//       v_tl, v_tr, v_bl, v_br = (
//         v_tl * c[:, None] - v_bl * s[:, None],
//         v_tr * c[:, None] - v_br * s[:, None],
//         v_tl * s[:, None] + v_bl * c[:, None],
//         v_tr * s[:, None] + v_br * c[:, None],
//       )
//       v_tl, v_bl = permute_rows_in_col(v_tl, v_bl)
//       v_tr, v_br = permute_rows_in_col(v_tr, v_br)
//
//     frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) +
//                              np.square(bl) + np.square(br)))
//     diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) +
//                         np.square(np.diag(br))))
//     off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt(
//         frobenius_norm + diag_norm)
//   w = np.concatenate([np.diag(tl), np.diag(br)])
//   v = np.block([[v_tl, v_tr], [v_bl, v_br]]).T
//   return v, w
XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64_t max_iter, float tol,
                              bool sort_eigenvalues) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int64_t num_dims = a_shape.rank();
    if (num_dims < 2) {
      return InvalidArgument(
          "Arguments to Eigen decomposition must have rank >= 2: got shape %s.",
          a_shape.ToString());
    }
    PrimitiveType type = a_shape.element_type();
    if (!primitive_util::IsFloatingPointType(type) &&
        !primitive_util::IsComplexType(type)) {
      return InvalidArgument(
          "Type of the input matrix must be floating point "
          "or complex: got %s.",
          a_shape.ToString());
    }

    const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
    const int64_t n = ShapeUtil::GetDimension(a_shape, -1);

    if (m != n) {
      return InvalidArgument(
          "Arguments to symmetric eigendecomposition must be square matrices: "
          "got shape (%d, %d).",
          m, n);
    }

    const int64_t num_batch_dims = num_dims - 2;
    std::vector<int64_t> batch_dims(num_batch_dims);
    for (int i = 0; i < num_batch_dims; ++i) {
      batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
    }

    if (m <= 1) {
      return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(Real(a))});
    }

    a = Symmetrize(a, lower);

    const int64_t k = CeilOfRatio(n, int64_t{2});
    // tl = A[:n // 2, :n // 2]
    // bl = A[n // 2:, :n // 2]
    // tr = A[:n // 2, n // 2:]
    // br = A[n // 2:, n // 2:]
    auto tl = SliceInMinorDims(a, {0, 0}, {k, k});
    auto bl = SliceInMinorDims(a, {k, 0}, {n, k});
    auto tr = SliceInMinorDims(a, {0, k}, {k, n});
    auto br = SliceInMinorDims(a, {k, k}, {n, n});
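    // If n is odd, zero-pad each block out to k x k. The padded row and column
    // stay zero throughout the iteration (a zero off-diagonal entry yields a
    // zero rotation angle in HermitianEigenDecomposition2x2), and the extra
    // eigenvalue/eigenvector they contribute is sliced away again below.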
    if (n % 2) {
      auto zero = Zero(builder, type);
      tr = PadInDim(tr, zero, num_dims - 1, /*pad_lo=*/0, /*pad_hi=*/1);
      bl = PadInDim(bl, zero, num_dims - 2, /*pad_lo=*/0, /*pad_hi=*/1);
      PaddingConfig config = MakeNoPaddingConfig(num_dims);
      config.mutable_dimensions(num_dims - 2)->set_edge_padding_high(1);
      config.mutable_dimensions(num_dims - 1)->set_edge_padding_high(1);
      br = Pad(br, zero, config);
    }
    // v_tl = np.eye(n // 2, dtype=A.dtype)
    // v_tr = np.zeros((n // 2, n // 2), A.dtype)
    // v_bl = np.zeros((n // 2, n // 2), A.dtype)
    // v_br = np.eye(n // 2, dtype=A.dtype)
    auto v_tl = Broadcast(IdentityMatrix(builder, type, k, k), batch_dims);
    auto v_br = v_tl;
    auto v_tr = ZerosLike(v_tl);
    auto v_bl = v_tr;

    TF_ASSIGN_OR_RETURN(auto output, Sweeps(
                                         {
                                             Zero(builder, S32),
                                             ScalarLike(Real(a), tol),
                                             tl,
                                             tr,
                                             bl,
                                             br,
                                             v_tl,
                                             v_tr,
                                             v_bl,
                                             v_br,
                                         },
                                         k * 2, max_iter, S32, builder));

    std::tie(tl, tr, bl, br) =
        std::make_tuple(output[2], output[3], output[4], output[5]);
    std::tie(v_tl, v_tr, v_bl, v_br) =
        std::make_tuple(output[6], output[7], output[8], output[9]);

    auto w = ConcatInDim(
        builder, {GetMatrixDiagonal(Real(tl)), GetMatrixDiagonal(Real(br))},
        num_dims - 2);
    auto v = ConcatInDim(builder,
                         {ConcatInDim(builder, {v_tl, v_tr}, num_dims - 1),
                          ConcatInDim(builder, {v_bl, v_br}, num_dims - 1)},
                         num_dims - 2);
    if (n % 2) {
      w = SliceInMinorDims(w, {0}, {n});
      v = SliceInMinorDims(v, {0, 0}, {n, n});
    }
    v = MaybeConjugate(TransposeInMinorDims(v), true);

    if (sort_eigenvalues) {
      TF_RETURN_IF_ERROR(SortByEigenvalues(v, w));
    }
    return Tuple(builder, {v, w});
  });
}

static const char* kEighCustomCallName = "Eigh";

bool EighExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCustomCall &&
         instruction->custom_call_target() == kEighCustomCallName;
}

StatusOr<HloInstruction*> EighExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const std::string name =
      absl::StrFormat("xla.%s_%s", instruction->custom_call_target(),
                      instruction->operand(0)->shape().ToString());

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API: what it does is build an
    // HloModuleProto protocol buffer, which we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    TF_RET_CHECK(instruction->operand_count() == 1);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");

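    // The raw backend config is a comma-separated list of four values:
    // "lower,sort_eigenvalues,max_iter,tol", parsed below in that order.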
    std::vector<std::string> config_strs =
        absl::StrSplit(instruction->raw_backend_config_string(), ',');
    int lower;
    int64_t max_iter;
    int sort_eigenvalues;
    float tol;
    if (config_strs.size() != 4 || !absl::SimpleAtoi(config_strs[0], &lower) ||
        !absl::SimpleAtoi(config_strs[1], &sort_eigenvalues) ||
        !absl::SimpleAtoi(config_strs[2], &max_iter) ||
        !absl::SimpleAtof(config_strs[3], &tol)) {
      return Internal("Unable to parse arguments to Eigh custom call, got: %s",
                      instruction->raw_backend_config_string());
    }
    XlaOp result = BuildEigh(a, lower, max_iter, tol, sort_eigenvalues);
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla