1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/eigh_expander.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/lib/comparators.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/lib/loops.h"
25 #include "tensorflow/compiler/xla/client/lib/math.h"
26 #include "tensorflow/compiler/xla/client/lib/matrix.h"
27 #include "tensorflow/compiler/xla/client/lib/slicing.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/literal_util.h"
30 #include "tensorflow/compiler/xla/primitive_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/util.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/errors.h"
37
38 // Parallel two-sided Jacobi symmetric eigendecomposition.
39 //
40 // The implementation follows the approach described in:
41 // Brent, Richard P., and Franklin T. Luk. "The solution of singular-value and
42 // symmetric eigenvalue problems on multiprocessor arrays." SIAM Journal on
43 // Scientific and Statistical Computing 6.1 (1985): 69-84.
44 //
45 // Where the Brent/Luk paper uses "processors", we use "vector elements".
46 namespace xla {
47
48 namespace {
49
// A 2x2 symmetric Eigendecomposition of a matrix A.
// If
// G = [[ c, s],
//      [-s, c]]
// matmul(G_T, G) = I
// and
// G @ [[rt1, 0  ],  @ G.T = A
//      [ 0,  rt2]]
//
// HermitianEigenDecomposition2x2() fills these fields from matrix diagonals,
// so each XlaOp carries one value per 2x2 problem being solved at once rather
// than a single scalar. For complex inputs that function returns all four
// fields as complex values.
struct Eigh2x2 {
  // Eigenvalues
  XlaOp rt1;
  XlaOp rt2;
  // First row of Eigenvector matrix.
  XlaOp c;  // cosine.
  XlaOp s;  // sine.
};
66
67 // sqrt(x**2 + y**2), calculated avoiding overflow.
Hypot(XlaOp x,XlaOp y)68 XlaOp Hypot(XlaOp x, XlaOp y) {
69 x = Abs(x);
70 y = Abs(y);
71 auto xy_min = Min(x, y);
72 auto xy_max = Max(x, y);
73 auto out = xy_max * Sqrt(ScalarLike(x, 1) + Square(xy_min / xy_max));
74 return Select(Eq(xy_min, xy_max), xy_min * ScalarLike(xy_min, std::sqrt(2.)),
75 out);
76 }
77
78 // Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n,
79 // a Jacobi rotation computes a rotation matrix G = [[c, s], [-s, c]], such that
80 // G_T * A[[p, q], [p, q]] * G
81 // is diagonalized.
82 //
83 // In this parallel Jacobi algorithm, we simultaneously compute Jacobi rotations
84 // for all of the matrix diagonal elements at the same time. The matrix diagonal
85 // elements correspond to different rows and columns of the original matrix and
86 // their rotations do not interfere and hence can be computed in parallel.
87 //
88 // def sym_schur2x2(w_tl, w_tr, w_br):
89 // off_diag = np.diag(w_tr)
90 // tau = (np.diag(w_br) - np.diag(w_tl)) / (2 * off_diag)
91 // t = np.where(tau >= 0, 1.0 / (tau + np.sqrt(1 + tau ** 2)),
92 // -1.0 / (-tau + np.sqrt(1 + tau ** 2)))
93 // pred = np.abs(off_diag) > 1e-6
94 // t = np.where(pred, t, 0.)
95 // c = 1.0 / np.sqrt(1.0 + t ** 2)
96 // s = t * c
97 // rt1 = w_tl - t * w_tr
98 // rt2 = w_br + t * w_tr
99 // return rt1, rt2, c, s
HermitianEigenDecomposition2x2(XlaOp w_tl,XlaOp w_tr,XlaOp w_br)100 StatusOr<Eigh2x2> HermitianEigenDecomposition2x2(XlaOp w_tl, XlaOp w_tr,
101 XlaOp w_br) {
102 TF_ASSIGN_OR_RETURN(Shape w_tl_shape, w_tl.builder()->GetShape(w_tl));
103 bool is_complex = primitive_util::IsComplexType(w_tl_shape.element_type());
104
105 w_tl = GetMatrixDiagonal(Real(w_tl));
106 w_tr = GetMatrixDiagonal(w_tr);
107 w_br = GetMatrixDiagonal(Real(w_br));
108 auto zero = ScalarLike(w_tl, 0.0);
109 auto one = ScalarLike(w_tl, 1.0);
110 auto two = ScalarLike(w_tl, 2.0);
111
112 XlaOp w;
113 if (is_complex) {
114 auto abs_tr = Abs(w_tr);
115 w = Select(Eq(abs_tr, ZerosLike(abs_tr)), FullLike(w_tr, 1),
116 Conj(w_tr) / Complex(abs_tr, ZerosLike(abs_tr)));
117 w_tr = abs_tr;
118 }
119
120 auto tol = ScalarLike(w_tr, 1e-6);
121 auto tau = (w_br - w_tl) / (two * w_tr);
122 auto t = Sqrt(one + Square(tau));
123 t = Reciprocal(tau + Select(Ge(tau, zero), t, Neg(t)));
124 t = Select(Gt(Abs(w_tr), tol), t, ZerosLike(t));
125 auto c = Rsqrt(one + Square(t));
126 auto s = t * c;
127
128 auto rt1 = w_tl - t * w_tr;
129 auto rt2 = w_br + t * w_tr;
130
131 if (is_complex) {
132 rt1 = Complex(rt1, ZerosLike(rt1));
133 rt2 = Complex(rt2, ZerosLike(rt2));
134 c = Complex(c, ZerosLike(c));
135 s = Complex(s, ZerosLike(s)) * w;
136 }
137
138 return Eigh2x2{rt1, rt2, c, s};
139 }
140
141 // tl, tr, bl, br = (
142 // tl * c[:, None] - bl * s[:, None],
143 // tr * c[:, None] - br * s[:, None],
144 // tl * s[:, None] + bl * c[:, None],
145 // tr * s[:, None] + br * c[:, None],
146 // )
ApplyJacobiRotationOverRows(Eigh2x2 rotation,XlaOp & tl,XlaOp & tr,XlaOp & bl,XlaOp & br)147 void ApplyJacobiRotationOverRows(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
148 XlaOp& bl, XlaOp& br) {
149 Shape shape = tl.builder()->GetShape(tl).ValueOrDie();
150 std::vector<int64_t> broadcast_dims(shape.dimensions().size() - 1);
151 absl::c_iota(broadcast_dims, 0);
152 auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
153 auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);
154
155 auto s_conj = MaybeConjugate(s, true);
156 std::tie(tl, tr, bl, br) =
157 std::make_tuple(tl * c - bl * s_conj, tr * c - br * s_conj,
158 tl * s + bl * c, tr * s + br * c);
159 }
160
161 // tl, tr, bl, br = (
162 // tl * c[None, :] - tr * s[None, :],
163 // tl * s[None, :] + tr * c[None, :],
164 // bl * c[None, :] - br * s[None, :],
165 // bl * s[None, :] + br * c[None, :],
166 // )
ApplyJacobiRotationOverCols(Eigh2x2 rotation,XlaOp & tl,XlaOp & tr,XlaOp & bl,XlaOp & br)167 void ApplyJacobiRotationOverCols(Eigh2x2 rotation, XlaOp& tl, XlaOp& tr,
168 XlaOp& bl, XlaOp& br) {
169 Shape shape = tl.builder()->GetShape(tl).ValueOrDie();
170 std::vector<int64_t> broadcast_dims(shape.dimensions().size() - 1);
171 absl::c_iota(broadcast_dims, 0);
172 broadcast_dims.back() = shape.dimensions().size() - 1;
173 auto c = BroadcastInDim(rotation.c, shape.dimensions(), broadcast_dims);
174 auto s = BroadcastInDim(rotation.s, shape.dimensions(), broadcast_dims);
175
176 auto s_conj = MaybeConjugate(s, true);
177 std::tie(tl, tr, bl, br) =
178 std::make_tuple(tl * c - tr * s, tl * s_conj + tr * c, bl * c - br * s,
179 bl * s_conj + br * c);
180 }
181
182 // def permute_rows_in_col(top, bottom):
183 // top_out = np.zeros_like(l)
184 // top_out[0] = top[0]
185 // top_out[1] = bottom[0]
186 // top_out[2:] = top[1:-1]
187 // bottom_out = np.zeros_like(r)
188 // bottom_out[:-1] = bottom[1:]
189 // bottom_out[-1] = top[-1]
190 // return top_out, bottom_out
PermuteRowsInColumn(XlaOp & top,XlaOp & bottom)191 void PermuteRowsInColumn(XlaOp& top, XlaOp& bottom) {
192 XlaBuilder* builder = top.builder();
193 Shape shape = builder->GetShape(top).ValueOrDie();
194 int64_t k = ShapeUtil::GetDimension(shape, -1);
195 if (k <= 1) {
196 return;
197 }
198 int ndim = shape.dimensions_size();
199 std::tie(top, bottom) =
200 std::make_tuple(ConcatInDim(builder,
201 {SliceInMinorDims(top, {0, 0}, {1, k}),
202 SliceInMinorDims(bottom, {0, 0}, {1, k}),
203 SliceInMinorDims(top, {1, 0}, {k - 1, k})},
204 ndim - 2),
205 ConcatInDim(builder,
206 {SliceInMinorDims(bottom, {1, 0}, {k, k}),
207 SliceInMinorDims(top, {k - 1, 0}, {k, k})},
208 ndim - 2));
209 }
210
PermuteColumnsInRow(XlaOp & left,XlaOp & right)211 void PermuteColumnsInRow(XlaOp& left, XlaOp& right) {
212 XlaBuilder* builder = left.builder();
213 Shape shape = builder->GetShape(left).ValueOrDie();
214 int64_t k = ShapeUtil::GetDimension(shape, -1);
215 if (k <= 1) {
216 return;
217 }
218 int ndim = shape.dimensions_size();
219 std::tie(left, right) =
220 std::make_tuple(ConcatInDim(builder,
221 {SliceInMinorDims(left, {0}, {1}),
222 SliceInMinorDims(right, {0}, {1}),
223 SliceInMinorDims(left, {1}, {k - 1})},
224 ndim - 1),
225 ConcatInDim(builder,
226 {SliceInMinorDims(right, {1}, {k}),
227 SliceInMinorDims(left, {k - 1}, {k})},
228 ndim - 1));
229 }
230
231 // Performs one round of parallel Jacobi rotations; n-1 rounds make a sweep.
232 // After each rotation, we permute the rows and columns of the quadrants of the
233 // matrix. The effect of the permutations is that all pairs of rows end up
234 // on the diagonal of the quadrants after n-1 rounds. The permutations are an
235 // implicit way of computing a tournament for n players such that each player
236 // plays every other player exactly once in n - 1 rounds. See the Brent/Luk
237 // paper for more details.
ApplyRotations(int64_t n,XlaOp & w_tl,XlaOp & w_tr,XlaOp & w_bl,XlaOp & w_br,XlaOp & v_tl,XlaOp & v_tr,XlaOp & v_bl,XlaOp & v_br)238 Status ApplyRotations(int64_t n, XlaOp& w_tl, XlaOp& w_tr, XlaOp& w_bl,
239 XlaOp& w_br, XlaOp& v_tl, XlaOp& v_tr, XlaOp& v_bl,
240 XlaOp& v_br) {
241 TF_ASSIGN_OR_RETURN(Eigh2x2 rotation,
242 HermitianEigenDecomposition2x2(w_tl, w_tr, w_br));
243
244 ApplyJacobiRotationOverRows(rotation, w_tl, w_tr, w_bl, w_br);
245 ApplyJacobiRotationOverCols(rotation, w_tl, w_tr, w_bl, w_br);
246 w_tl = SetMatrixDiagonal(w_tl, rotation.rt1);
247 w_tr = SetMatrixDiagonal(w_tr, ZerosLike(rotation.rt1));
248 w_bl = SetMatrixDiagonal(w_bl, ZerosLike(rotation.rt1));
249 w_br = SetMatrixDiagonal(w_br, rotation.rt2);
250
251 PermuteColumnsInRow(w_tl, w_tr);
252 PermuteColumnsInRow(w_bl, w_br);
253 PermuteRowsInColumn(w_tl, w_bl);
254 PermuteRowsInColumn(w_tr, w_br);
255
256 // Apply the rotations to the eigenvector matrix.
257 // TODO(phawkins): we could omit this if we aren't interested in computing the
258 // eigenvectors.
259 ApplyJacobiRotationOverRows(rotation, v_tl, v_tr, v_bl, v_br);
260 PermuteRowsInColumn(v_tl, v_bl);
261 PermuteRowsInColumn(v_tr, v_br);
262 return OkStatus();
263 }
264
// Squared norms of a matrix that is stored as four quadrants, as produced by
// ComputeFrobeniusNorms(). Both values are real-valued and are reduced over
// the two minor (matrix) dimensions only, so any leading dimensions remain.
struct FrobeniusNorms {
  // Sum of |x|^2 over all entries except the diagonals of the top-left and
  // bottom-right quadrants.
  XlaOp off_diagonal_sq_norm;
  // Sum of |x|^2 over every entry of the four quadrants.
  XlaOp frobenius_sq_norm;
};
269
ComputeFrobeniusNorms(XlaOp w_tl,XlaOp w_tr,XlaOp w_bl,XlaOp w_br)270 StatusOr<FrobeniusNorms> ComputeFrobeniusNorms(XlaOp w_tl, XlaOp w_tr,
271 XlaOp w_bl, XlaOp w_br) {
272 XlaBuilder* builder = w_tl.builder();
273 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w_tl));
274 const int64_t num_dims = shape.rank();
275 auto square_norm = [](XlaOp x) -> XlaOp {
276 return Real(x * MaybeConjugate(x, true));
277 };
278 auto off_diag = [](XlaOp x) {
279 return Select(GetDiagonalMask(x), ZerosLike(x), x);
280 };
281 PrimitiveType norm_type =
282 primitive_util::IsComplexType(shape.element_type())
283 ? primitive_util::ComplexComponentType(shape.element_type())
284 : shape.element_type();
285 auto zero = ScalarLike(Real(w_tl), 0.0);
286 FrobeniusNorms norms;
287 norms.frobenius_sq_norm =
288 Reduce(square_norm(w_tl) + square_norm(w_tr) + square_norm(w_bl) +
289 square_norm(w_br),
290 zero, CreateScalarAddComputation(norm_type, builder),
291 {num_dims - 2, num_dims - 1});
292 norms.off_diagonal_sq_norm =
293 Reduce(square_norm(off_diag(w_tl)) + square_norm(w_tr) +
294 square_norm(w_bl) + square_norm(off_diag(w_br)),
295 zero, CreateScalarAddComputation(norm_type, builder),
296 {num_dims - 2, num_dims - 1});
297
298 return norms;
299 }
300
Sweeps(absl::Span<const XlaOp> initial_values,int64_t n,int max_iters,PrimitiveType index_type,XlaBuilder * builder)301 StatusOr<std::vector<XlaOp>> Sweeps(absl::Span<const XlaOp> initial_values,
302 int64_t n, int max_iters,
303 PrimitiveType index_type,
304 XlaBuilder* builder) {
305 auto while_cond_fn = [&](absl::Span<const XlaOp> values,
306 XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
307 auto iter_cond = Lt(values[0], ScalarLike(values[0], max_iters));
308
309 XlaOp w_tl, w_tr, w_bl, w_br;
310 std::tie(w_tl, w_tr, w_bl, w_br) =
311 std::make_tuple(values[2], values[3], values[4], values[5]);
312 TF_ASSIGN_OR_RETURN(auto norms,
313 ComputeFrobeniusNorms(w_tl, w_tr, w_bl, w_br));
314 auto tol = norms.frobenius_sq_norm * Square(values[1]);
315 auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_sq_norm),
316 xla::ConstantR0<bool>(cond_builder, false),
317 CreateScalarOrComputation(PRED, cond_builder));
318
319 return And(iter_cond, tol_cond);
320 };
321
322 auto while_body_fn =
323 [&](absl::Span<const XlaOp> values,
324 XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
325 std::vector<XlaOp> sweep_values(values.begin() + 1, values.end());
326 TF_ASSIGN_OR_RETURN(
327 sweep_values,
328 ForEachIndex(
329 n - 1, S32,
330 [&](XlaOp iter, absl::Span<const XlaOp> values,
331 XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
332 XlaOp tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br;
333 std::tie(tol, w_tl, w_tr, w_bl, w_br, v_tl, v_tr, v_bl, v_br) =
334 std::make_tuple(values[0], values[1], values[2], values[3],
335 values[4], values[5], values[6], values[7],
336 values[8]);
337 TF_RETURN_IF_ERROR(ApplyRotations(n, w_tl, w_tr, w_bl, w_br, v_tl,
338 v_tr, v_bl, v_br));
339 return std::vector<XlaOp>{tol, w_tl, w_tr, w_bl, w_br,
340 v_tl, v_tr, v_bl, v_br};
341 },
342 sweep_values, "ApplyRotations", body_builder));
343 std::vector<XlaOp> output(values.size());
344 output[0] = values[0] + ScalarLike(values[0], 1);
345 std::copy(sweep_values.begin(), sweep_values.end(), output.begin() + 1);
346 return output;
347 };
348 return WhileLoopHelper(while_cond_fn, while_body_fn, initial_values,
349 "EighJacobiSweeps", builder);
350 }
351
352 } // namespace
353
SortByEigenvalues(XlaOp & v,XlaOp & w)354 Status EighExpander::SortByEigenvalues(XlaOp& v, XlaOp& w) {
355 XlaBuilder* builder = v.builder();
356 TF_ASSIGN_OR_RETURN(Shape v_shape, builder->GetShape(v));
357 TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(w));
358 const int64_t num_dims = v_shape.rank();
359 auto dimensions = v_shape.dimensions();
360
361 std::vector<int64_t> broadcast_dims(num_dims - 1);
362 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
363 broadcast_dims[num_dims - 2] = num_dims - 1;
364 w = BroadcastInDim(w, dimensions, broadcast_dims);
365
366 XlaOp sort_result =
367 Sort({w, v},
368 CreateScalarLtComputation(
369 {w_shape.element_type(), v_shape.element_type()}, builder),
370 num_dims - 1);
371 w = GetMatrixDiagonal(GetTupleElement(sort_result, 0));
372 v = GetTupleElement(sort_result, 1);
373 return OkStatus();
374 }
375
376 // This is the cyclic Jacobi iteration.
377 //
378 // def jacobi(A):
379 // n, _ = A.shape
380 // tl = A[:n // 2, :n // 2]
381 // bl = A[n // 2:, :n // 2]
382 // tr = A[:n // 2, n // 2:]
383 // br = A[n // 2:, n // 2:]
384 // v_tl = np.eye(n // 2, dtype=A.dtype)
385 // v_tr = np.zeros((n // 2, n // 2), A.dtype)
386 // v_bl = np.zeros((n // 2, n // 2), A.dtype)
387 // v_br = np.eye(n // 2, dtype=A.dtype)
388 // frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) +
389 // np.square(bl) + np.square(br)))
390 // diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) +
391 // np.square(np.diag(br))))
392 // off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt(
393 // frobenius_norm + diag_norm)
394 // while off_diag_norm > 1e-6 * frobenius_norm:
395 // for i in range(n - 1):
396 // c, s = sym_schur2x2(tl, tr, br)
397 // tl, tr, bl, br = (
398 // tl * c[:, None] - bl * s[:, None],
399 // tr * c[:, None] - br * s[:, None],
400 // tl * s[:, None] + bl * c[:, None],
401 // tr * s[:, None] + br * c[:, None],
402 // )
403 // tl, tr, bl, br = (
404 // tl * c[None, :] - tr * s[None, :],
405 // tl * s[None, :] + tr * c[None, :],
406 // bl * c[None, :] - br * s[None, :],
407 // bl * s[None, :] + br * c[None, :],
408 // )
409 // tl, bl = permute_rows_in_col(tl, bl)
410 // tr, br = permute_rows_in_col(tr, br)
411 // tl, tr = permute_cols_in_row(tl, tr)
412 // bl, br = permute_cols_in_row(bl, br)
413 // v_tl, v_tr, v_bl, v_br = (
414 // v_tl * c[:, None] - v_bl * s[:, None],
415 // v_tr * c[:, None] - v_br * s[:, None],
416 // v_tl * s[:, None] + v_bl * c[:, None],
417 // v_tr * s[:, None] + v_br * c[:, None],
418 // )
419 // v_tl, v_bl = permute_rovs_in_col(v_tl, v_bl)
420 // v_tr, v_br = permute_rovs_in_col(v_tr, v_br)
421 //
422 // frobenius_norm = np.sqrt(np.sum(np.square(tl) + np.square(tr) +
423 // np.square(bl) + np.square(br)))
424 // diag_norm = np.sqrt(np.sum(np.square(np.diag(tl)) +
425 // np.square(np.diag(br))))
426 // off_diag_norm = np.sqrt(frobenius_norm - diag_norm) * np.sqrt(
427 // frobenius_norm + diag_norm)
428 // return A, V
BuildEigh(XlaOp a,bool lower,int64_t max_iter,float tol,bool sort_eigenvalues)429 XlaOp EighExpander::BuildEigh(XlaOp a, bool lower, int64_t max_iter, float tol,
430 bool sort_eigenvalues) {
431 XlaBuilder* builder = a.builder();
432 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
433 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
434 const int64_t num_dims = a_shape.rank();
435 if (num_dims < 2) {
436 return InvalidArgument(
437 "Arguments to Eigen decomposition must have rank >= 2: got shape %s.",
438 a_shape.ToString());
439 }
440 PrimitiveType type = a_shape.element_type();
441 if (!primitive_util::IsFloatingPointType(type) &&
442 !primitive_util::IsComplexType(type)) {
443 return InvalidArgument(
444 "Type of the input matrix must be floating point "
445 "or complex: got %s.",
446 a_shape.ToString());
447 }
448
449 const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
450 const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
451
452 if (m != n) {
453 return InvalidArgument(
454 "Arguments to symmetric eigendecomposition must be square matrices: "
455 "got shape (%d, %d).",
456 m, n);
457 }
458
459 const int64_t num_batch_dims = num_dims - 2;
460 std::vector<int64_t> batch_dims(num_batch_dims);
461 for (int i = 0; i < num_batch_dims; ++i) {
462 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
463 }
464
465 if (m <= 1) {
466 return Tuple(builder, {FullLike(a, 1), GetMatrixDiagonal(Real(a))});
467 }
468
469 a = Symmetrize(a, lower);
470
471 const int64_t k = CeilOfRatio(n, int64_t{2});
472 // tl = A[:n // 2, :n // 2]
473 // bl = A[n // 2:, :n // 2]
474 // tr = A[:n // 2, n // 2:]
475 // br = A[n // 2:, n // 2:]
476 auto tl = SliceInMinorDims(a, {0, 0}, {k, k});
477 auto bl = SliceInMinorDims(a, {k, 0}, {n, k});
478 auto tr = SliceInMinorDims(a, {0, k}, {k, n});
479 auto br = SliceInMinorDims(a, {k, k}, {n, n});
480 if (n % 2) {
481 auto zero = Zero(builder, type);
482 tr = PadInDim(tr, zero, num_dims - 1, /*pad_lo=*/0, /*pad_hi=*/1);
483 bl = PadInDim(bl, zero, num_dims - 2, /*pad_lo=*/0, /*pad_hi=*/1);
484 PaddingConfig config = MakeNoPaddingConfig(num_dims);
485 config.mutable_dimensions(num_dims - 2)->set_edge_padding_high(1);
486 config.mutable_dimensions(num_dims - 1)->set_edge_padding_high(1);
487 br = Pad(br, zero, config);
488 }
489 // v_tl = np.eye(n // 2, dtype=A.dtype)
490 // v_tr = np.zeros((n // 2, n // 2), A.dtype)
491 // v_bl = np.zeros((n // 2, n // 2), A.dtype)
492 // v_br = np.eye(n // 2, dtype=A.dtype)
493 auto v_tl = Broadcast(IdentityMatrix(builder, type, k, k), batch_dims);
494 auto v_br = v_tl;
495 auto v_tr = ZerosLike(v_tl);
496 auto v_bl = v_tr;
497
498 TF_ASSIGN_OR_RETURN(auto output, Sweeps(
499 {
500 Zero(builder, S32),
501 ScalarLike(Real(a), tol),
502 tl,
503 tr,
504 bl,
505 br,
506 v_tl,
507 v_tr,
508 v_bl,
509 v_br,
510 },
511 k * 2, max_iter, S32, builder));
512
513 std::tie(tl, tr, bl, br) =
514 std::make_tuple(output[2], output[3], output[4], output[5]);
515 std::tie(v_tl, v_tr, v_bl, v_br) =
516 std::make_tuple(output[6], output[7], output[8], output[9]);
517
518 auto w = ConcatInDim(
519 builder, {GetMatrixDiagonal(Real(tl)), GetMatrixDiagonal(Real(br))},
520 num_dims - 2);
521 auto v = ConcatInDim(builder,
522 {ConcatInDim(builder, {v_tl, v_tr}, num_dims - 1),
523 ConcatInDim(builder, {v_bl, v_br}, num_dims - 1)},
524 num_dims - 2);
525 if (n % 2) {
526 w = SliceInMinorDims(w, {0}, {n});
527 v = SliceInMinorDims(v, {0, 0}, {n, n});
528 }
529 v = MaybeConjugate(TransposeInMinorDims(v), true);
530
531 if (sort_eigenvalues) {
532 TF_RETURN_IF_ERROR(SortByEigenvalues(v, w));
533 }
534 return Tuple(builder, {v, w});
535 });
536 }
537
// Custom-call target name that marks an instruction as an Eigh operation.
static constexpr char kEighCustomCallName[] = "Eigh";
539
InstructionMatchesPattern(HloInstruction * instruction)540 bool EighExpander::InstructionMatchesPattern(HloInstruction* instruction) {
541 return instruction->opcode() == HloOpcode::kCustomCall &&
542 instruction->custom_call_target() == kEighCustomCallName;
543 }
544
ExpandInstruction(HloInstruction * instruction)545 StatusOr<HloInstruction*> EighExpander::ExpandInstruction(
546 HloInstruction* instruction) {
547 const std::string name =
548 absl::StrFormat("xla.%s_%s", instruction->custom_call_target(),
549 instruction->operand(0)->shape().ToString());
550
551 HloModule* module = instruction->parent()->parent();
552
553 HloComputation*& computation =
554 computation_cache_.emplace(name, nullptr).first->second;
555 if (!computation) {
556 // Builds a new expansion.
557 //
558 // TODO(b/62327888): We do something unusual here: we build the computation
559 // using the XlaBuilder API, which is nominally an XLA client API. We do
560 // this because the external APIs for building complicated computations
561 // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
562 // out, XlaBuilder isn't really a client API—what it does is build a
563 // HloModuleProto protocol buffer, that we can then deserialize and clone
564 // into our HloModule. Ideally we would avoid the protocol buffer step;
565 // that is left as an exercise for future work.
566 XlaBuilder builder(name);
567 TF_RET_CHECK(instruction->operand_count() == 1);
568 XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
569
570 std::vector<std::string> config_strs =
571 absl::StrSplit(instruction->raw_backend_config_string(), ',');
572 int lower;
573 int64_t max_iter;
574 int sort_eigenvalues;
575 float tol;
576 if (config_strs.size() != 4 || !absl::SimpleAtoi(config_strs[0], &lower) ||
577 !absl::SimpleAtoi(config_strs[1], &sort_eigenvalues) ||
578 !absl::SimpleAtoi(config_strs[2], &max_iter) ||
579 !absl::SimpleAtof(config_strs[3], &tol)) {
580 return Internal("Unable to parse arguments to Eigh custom call, got: %s",
581 instruction->raw_backend_config_string());
582 }
583 XlaOp result = BuildEigh(a, lower, max_iter, tol, sort_eigenvalues);
584 TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));
585
586 TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
587 xla_computation.GetProgramShape());
588 HloModuleConfig config(program_shape);
589 TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
590 xla_computation.proto(), config));
591 HloCloneContext context(module);
592 computation =
593 module->DeepCloneComputation(new_module->entry_computation(), &context);
594 }
595
596 return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
597 instruction->shape(), instruction->operands(), computation));
598 }
599
600 } // namespace xla
601