1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/client/lib/svd.h"
16
17 #include <memory>
18 #include <numeric>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/comparators.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/loops.h"
26 #include "tensorflow/compiler/xla/client/lib/math.h"
27 #include "tensorflow/compiler/xla/client/lib/matrix.h"
28 #include "tensorflow/compiler/xla/client/lib/slicing.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/literal_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/errors.h"
36
37 namespace xla {
38
39 namespace {
40
41 // Given a matrix A, define H,
// H = (I - beta * v_T * v) if v is a row vector, or
43 // H = (I - beta * v * v_T) if v is column vector.
44 // A * H or H * A zeros out trailing part of some row or column of A.
45 //
46 // [x0, ..., x_{k-1}, xk, x_{k+1}, ..., x_{n-1}] * H
47 // = [x0, ..., x_{k-1}, xnorm, 0, ..., 0]
48 //
49 // Here xnorm = norm([x_k, x_{k+1}, ..., x_{n - 1}])
// Result of building and applying one Householder reflection (see HouseRow
// and HouseCol, which populate every field).
struct HouseHolderResult {
  XlaOp v;     // Reflection vector, with a 1 at the pivot position.
  XlaOp beta;  // Scaling coefficient of the reflection; 0 when it is skipped.
  XlaOp a;     // The input matrix after the reflection has been applied.
};
55
56 // Jacobi rotation (also known as Givens rotation):
57 // G = [[ c, s],
58 // [-s, c]]
59 // matmul(G_T, G) = I
// Cosine/sine pair that parameterizes the 2x2 rotation G = [[c, s], [-s, c]].
struct JacobiRotation {
  XlaOp c;  // cosine.
  XlaOp s;  // sine.
};
64
65 // JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix.
// NOTE(review): not referenced anywhere in the visible code. Presumably `v` is
// the intermediate orthogonal matrix and `w` the Jacobi-rotated matrix, in the
// order given by the comment above -- confirm before relying on it.
struct JacobiUpdate {
  XlaOp v;
  XlaOp w;
};
70
71 // OneSidedJacobiRotation holds the left and right Jacobi rotations. Refer to
72 // GetOneSidedJacobiRotation for the effect of applying OneSidedJacobiRotation
73 // to a matrix.
// Pair of rotations produced by GetOneSidedJacobiRotation.
struct OneSidedJacobiRotation {
  JacobiRotation rot_l;  // Rotation applied on the left (to rows of D, and U).
  JacobiRotation rot_r;  // Rotation applied on the right (to columns of D, V).
};
78
79 // Householder reflection on the trailing elements of a vector.
80 //
81 // H = I - beta * [1, v]' * [1, v]
82 //
83 // H * x = [..., xnorm, 0, ..., 0]
84 // ..., j, j + 1, ..., n
85 //
86 // def house(x, j, eps):
87 // sigma = np.linalg.norm(x[(j + 1):])
88 // v = np.zeros_like(x)
89 // v[(j + 1):] = x[(j + 1):]
90 // if sigma < eps:
91 // beta = 0
92 // else:
93 // mu = sigma * np.sqrt((x[j]/sigma)**2 + 1)
94 // if x[j] <= 0:
95 // v[j] = x[j] - mu
96 // else:
97 // v[j] = -sigma / (x[j] + mu) * sigma
98 // beta = 2 / ((sigma / v[j])**2 + 1)
99 // v = v / v[j]
100 // v[j] = 1
101 // return v, beta
102 //
103 // Householder reflection on the trailing elements of a row of a matrix. After
104 // applying it on the matrix, all elements in [i, (j+1):] become zeros, i.e.,
105 //
106 // H = I - beta * [1, v]' * [1, v], then,
107 //
108 // A[i, j:] * H = [sigma, 0, 0, ..., 0]
109 //
HouseRow(XlaOp a,XlaOp i,XlaOp j,XlaOp eps,PrecisionConfig::Precision precision)110 StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
111 PrecisionConfig::Precision precision) {
112 XlaBuilder* builder = a.builder();
113 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
114 const int64_t num_dims = a_shape.rank();
115 const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
116 XlaOp zero = ScalarLike(i, 0);
117 XlaOp x = DynamicSliceInMinorDims(a, {i, zero}, {1, n});
118
119 const int64_t num_batch_dims = num_dims - 2;
120 std::vector<int64_t> batch_dims(num_batch_dims);
121 for (int k = 0; k < num_batch_dims; ++k) {
122 batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
123 }
124
125 TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
126 auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
127 num_dims - 1);
128 auto zeros = ZerosLike(x);
129 auto v = Select(Gt(idx, j), x, zeros);
130
131 auto one = ScalarLike(v, 1.0);
132
133 auto sigma =
134 Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
135 CreateScalarAddComputation(x_shape.element_type(), builder),
136 {num_dims - 1}));
137
138 std::vector<int64_t> broadcast_dims(num_dims - 1);
139 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
140 auto x_0j = DynamicSliceInMinorDims(x, {zero, j}, {1, 1});
141 auto mu = Mul(sigma, Sqrt(Square(Div(x_0j, sigma, broadcast_dims)) + one),
142 broadcast_dims);
143
144 auto v_0j = Select(
145 Le(x_0j, ScalarLike(x_0j, 0.0)), Sub(x_0j, mu),
146 -Mul(sigma, Div(sigma, Add(x_0j, mu), broadcast_dims), broadcast_dims));
147
148 auto beta = Div(ScalarLike(v_0j, 2.0),
149 (Square(Div(sigma, v_0j, broadcast_dims)) + one));
150
151 v = Select(
152 BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
153 v / v_0j);
154 v = Select(Eq(idx, j), zeros + one, v);
155
156 beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
157 ZerosLike(beta), beta);
158
159 HouseHolderResult result;
160 result.v = v;
161 result.beta = beta;
162 result.a = Sub(a, Mul(beta, BatchDot(BatchDot(a, false, v, true, precision),
163 v, precision)));
164
165 return result;
166 }
167
168 // Householder reflection on the trailing elements of a col of a matrix. After
169 // applying it on the matrix, all elements in [(i+1):, j] become zeros, i.e.,
170 //
171 // H = I - beta * [1; v] * [1; v]', then,
172 //
173 // H * A[i:, j] = [xnorm, 0, 0, ..., 0]
174 //
HouseCol(XlaOp a,XlaOp i,XlaOp j,XlaOp eps,PrecisionConfig::Precision precision)175 StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
176 PrecisionConfig::Precision precision) {
177 XlaBuilder* builder = a.builder();
178 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
179 const int64_t num_dims = a_shape.rank();
180 const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
181 XlaOp zero = ScalarLike(i, 0);
182 XlaOp x = DynamicSliceInMinorDims(a, {zero, j}, {m, 1});
183
184 const int64_t num_batch_dims = num_dims - 2;
185 std::vector<int64_t> batch_dims(num_batch_dims);
186 for (int k = 0; k < num_batch_dims; ++k) {
187 batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
188 }
189
190 TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
191 auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
192 num_dims - 2);
193 auto zeros = ZerosLike(x);
194 auto v = Select(Gt(idx, i), x, zeros);
195
196 auto one = ScalarLike(v, 1.0);
197
198 auto sigma =
199 Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
200 CreateScalarAddComputation(x_shape.element_type(), builder),
201 {num_dims - 2}));
202
203 std::vector<int64_t> broadcast_dims(num_dims - 1);
204 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
205 broadcast_dims[num_dims - 2] = num_dims - 1;
206 auto x_0i = DynamicSliceInMinorDims(x, {i, zero}, {1, 1});
207 auto mu = Mul(sigma, Sqrt(Square(Div(x_0i, sigma, broadcast_dims)) + one),
208 broadcast_dims);
209
210 auto v_0i = Select(
211 Le(x_0i, ScalarLike(x_0i, 0.0)), Sub(x_0i, mu),
212 -Mul(sigma, Div(sigma, Add(x_0i, mu), broadcast_dims), broadcast_dims));
213
214 auto beta = Div(ScalarLike(v_0i, 2.0),
215 (Square(Div(sigma, v_0i, broadcast_dims)) + one));
216
217 v = Select(
218 BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
219 v / v_0i);
220 v = Select(Eq(idx, i), zeros + one, v);
221
222 beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
223 ZerosLike(beta), beta);
224
225 HouseHolderResult result;
226 result.v = v;
227 result.beta = beta;
228 result.a = Sub(
229 a, Mul(beta, BatchDot(v, false, BatchDot(v, true, a, false, precision),
230 false, precision)));
231
232 return result;
233 }
234
235 // Apply column and row householder reflections for bidiagonalization.
236 //
237 // def house_bidiag(A):
238 // xz, yz = A.shape
239 // LL = np.eye(xz)
240 // RR = np.eye(yz)
241 // for i in range(yz - 1):
242 // v, beta = house_col(A, i, i, 1e-8)
243 // L = np.eye(xz) - beta * np.outer(v, v)
244 // LL = np.matmul(LL, L)
245 // A = np.matmul(L, A)
246 // if i < yz - 2:
247 // v, beta = house_row(A, i, i + 1, 1e-8)
248 // R = np.eye(yz) - beta * np.outer(v, v)
249 // RR = np.matmul(RR, R)
250 // A = np.matmul(A, R)
251 // return LL, A, RR
252 //
HouseHolderBidiagonalization(XlaOp a,XlaOp eps,PrecisionConfig::Precision precision)253 StatusOr<SVDResult> HouseHolderBidiagonalization(
254 XlaOp a, XlaOp eps, PrecisionConfig::Precision precision) {
255 XlaBuilder* builder = a.builder();
256 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
257 const int64_t num_dims = a_shape.rank();
258 const int64_t num_batch_dims = num_dims - 2;
259 std::vector<int64_t> batch_dims(num_batch_dims);
260 for (int i = 0; i < num_batch_dims; ++i) {
261 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
262 }
263 const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
264 const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
265 XlaOp u_init = Broadcast(
266 IdentityMatrix(builder, a_shape.element_type(), m, m), batch_dims);
267 XlaOp v_init = Broadcast(
268 IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims);
269
270 auto while_cond_fn = [&](absl::Span<const XlaOp> values,
271 XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
272 auto i = values[0];
273 return Lt(i, ScalarLike(i, n - 2));
274 };
275 auto while_body_fn =
276 [&](absl::Span<const XlaOp> values,
277 XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
278 auto i = values[0];
279 auto one = ScalarLike(i, 1);
280
281 auto u = values[1];
282 auto v = values[2];
283 auto a = values[3];
284 auto eps = values[4];
285
286 TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
287 HouseCol(a, i, i, eps, precision));
288 u = Sub(u,
289 Mul(house_col.beta, BatchDot(BatchDot(u, house_col.v, precision),
290 false, house_col.v, true, precision)));
291 a = house_col.a;
292
293 TF_ASSIGN_OR_RETURN(HouseHolderResult house_row,
294 HouseRow(a, i, i + one, eps, precision));
295 v = Sub(v, Mul(house_row.beta,
296 BatchDot(BatchDot(v, false, house_row.v, true, precision),
297 house_row.v, precision)));
298 a = house_row.a;
299
300 std::vector<XlaOp> updated_values;
301 updated_values.reserve(values.size());
302
303 updated_values.push_back(i + one);
304 updated_values.push_back(u);
305 updated_values.push_back(v);
306 updated_values.push_back(a);
307 updated_values.push_back(eps);
308 return updated_values;
309 };
310
311 std::vector<XlaOp> values(5);
312 values[0] = Zero(builder, S32);
313 values[1] = u_init;
314 values[2] = v_init;
315 values[3] = a;
316 values[4] = eps;
317
318 TF_ASSIGN_OR_RETURN(values,
319 WhileLoopHelper(while_cond_fn, while_body_fn, values,
320 "HouseHolderBidiagonalization", builder));
321
322 for (int k = 2; k > 0; --k) {
323 if (n - k >= 0) {
324 XlaOp index = ScalarLike(values[0], n - k);
325 TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
326 HouseCol(values[3], index, index, eps, precision));
327 values[1] = Sub(values[1],
328 Mul(house_col.beta,
329 BatchDot(BatchDot(values[1], house_col.v, precision),
330 false, house_col.v, true, precision)));
331 values[3] = house_col.a;
332 }
333 }
334
335 SVDResult result;
336 result.u = values[1];
337 result.v = values[2];
338 result.d = values[3];
339 return result;
340 }
341
342 // MakeJacobi computes a rotation matrix G = [[c, s], [-s, c]], such that
343 // G_T * [[ps, pqs], [pqs, qs]] * G
344 // is diagonalized.
345 //
346 // def make_jacobi(ps, qs, pqs, eps):
347 // if np.abs(a_pq) > eps:
348 // tau = (a_qq - a_pp) / (2 * a_pq)
349 // if tau >= 0:
350 // t = 1.0 / (tau + np.sqrt(1 + tau ** 2))
351 // else:
352 // t = -1.0 / (-tau + np.sqrt(1 + tau ** 2))
353 // c = 1.0 / np.sqrt(1.0 + t ** 2)
354 // s = t * c
355 // else:
356 // c = 1.0
357 // s = 0.0
358 // return c, s
359 //
MakeJacobi(XlaOp ps,XlaOp qs,XlaOp pqs,XlaOp eps)360 StatusOr<JacobiRotation> MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, XlaOp eps) {
361 auto zero = ScalarLike(ps, 0.0);
362 auto one = ScalarLike(ps, 1.0);
363 auto two = ScalarLike(ps, 2.0);
364
365 auto tau = (qs - ps) / (pqs * two);
366 auto t_pos = one / (tau + Sqrt(one + Square(tau)));
367 auto t_neg = -one / (-tau + Sqrt(one + Square(tau)));
368 auto t = Select(Ge(tau, zero), t_pos, t_neg);
369
370 auto c_temp = Rsqrt(one + Square(t));
371 auto s_temp = t * c_temp;
372
373 auto c = Select(Ge(Abs(pqs), eps), c_temp, ZerosLike(c_temp) + one);
374 auto s = Select(Ge(Abs(pqs), eps), s_temp, ZerosLike(s_temp));
375 // Renormalize c and s to compensate for low precision arithmetic, this step
376 // is redundant if high precision float is used, like float64.
377 auto rnorm = Rsqrt(Square(c) + Square(s));
378
379 JacobiRotation rot;
380
381 rot.c = c * rnorm;
382 rot.s = s * rnorm;
383
384 return rot;
385 }
386
387 // One sided Jacobi rotations. For a matrix,
388 // [a_pp, a_pq]
389 // [a_qp, a_qq]
390 // After applying Jacobi rotations on both sides, the matrix is diagonalized.
391 // [b_pp, 0]
392 // [0, b_qq]
393 //
394 // def jacobi_rot(a, p, q, eps):
395 // t = a[p, p] + a[q, q]
396 // d = a[q, p] - a[p, q]
397 //
398 // if np.abs(d) < eps:
399 // s = 0.0
400 // c = 1.0
401 // else:
402 // u = t / d
403 // tmp = np.sqrt(1.0 + u**2)
404 // s = -1.0 / tmp
405 // c = u / tmp
406 //
407 // rot = np.array([[c, s], [-s, c]])
408 // m_tmp = rot.T @ a[[p, q], [p, q]]
409 // c_r, s_r = make_jacobi(m_tmp[0, 0], m_tmp[1, 1], m_tmp[0, 1])
410 // rot_r = np.array([[c_r, s_r], [-s_r, c_r]])
411 // rot_l = rot @ rot_r
412 // return rot_l, rot_r
413 //
GetOneSidedJacobiRotation(XlaOp a,XlaOp p,XlaOp q,XlaOp eps)414 StatusOr<OneSidedJacobiRotation> GetOneSidedJacobiRotation(XlaOp a, XlaOp p,
415 XlaOp q, XlaOp eps) {
416 XlaOp a_pp = DynamicSliceInMinorDims(a, {p, p}, {1, 1});
417 XlaOp a_pq = DynamicSliceInMinorDims(a, {p, q}, {1, 1});
418 XlaOp a_qp = DynamicSliceInMinorDims(a, {q, p}, {1, 1});
419 XlaOp a_qq = DynamicSliceInMinorDims(a, {q, q}, {1, 1});
420
421 XlaOp one = ScalarLike(a, 1.0);
422
423 XlaOp t = a_pp + a_qq;
424 XlaOp d = a_qp - a_pq;
425
426 XlaOp u = Div(t, d);
427 XlaOp tmp = Rsqrt(one + Square(u));
428
429 JacobiRotation rot;
430
431 XlaOp zeros = ZerosLike(tmp);
432 XlaOp ones = zeros + one;
433
434 rot.s = Select(Lt(Abs(d), eps), zeros, -tmp);
435 rot.c = Select(Lt(Abs(d), eps), ones, Mul(u, tmp));
436
437 XlaOp a_pp_new = rot.c * a_pp - rot.s * a_qp;
438 XlaOp a_pq_new = rot.c * a_pq - rot.s * a_qq;
439 XlaOp a_qq_new = rot.s * a_pq + rot.c * a_qq;
440
441 OneSidedJacobiRotation rots;
442 TF_ASSIGN_OR_RETURN(rots.rot_r,
443 MakeJacobi(a_pp_new, a_qq_new, a_pq_new, eps));
444
445 rots.rot_l.c = rot.c * rots.rot_r.c - rot.s * rots.rot_r.s;
446 rots.rot_l.s = rot.s * rots.rot_r.c + rot.c * rots.rot_r.s;
447
448 return rots;
449 }
450
451 // Apply one-sided Jacobi on elements at indices pp, pq, qp, qq.
OneSidedJacobiUpdate(SVDResult svd_result,XlaOp p,XlaOp q,XlaOp eps)452 StatusOr<SVDResult> OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q,
453 XlaOp eps) {
454 XlaOp u = svd_result.u;
455 XlaOp v = svd_result.v;
456 XlaOp d = svd_result.d;
457 XlaBuilder* builder = d.builder();
458 TF_ASSIGN_OR_RETURN(Shape d_shape, builder->GetShape(d));
459 const int64_t num_dims = d_shape.rank();
460 const int64_t num_batch_dims = num_dims - 2;
461 std::vector<int64_t> batch_dims(num_batch_dims);
462 for (int i = 0; i < num_batch_dims; ++i) {
463 batch_dims[i] = ShapeUtil::GetDimension(d_shape, i);
464 }
465 const int64_t m = ShapeUtil::GetDimension(d_shape, -2);
466 const int64_t n = ShapeUtil::GetDimension(d_shape, -1);
467
468 TF_ASSIGN_OR_RETURN(OneSidedJacobiRotation onesided_jacobi,
469 GetOneSidedJacobiRotation(d, p, q, eps));
470
471 auto zero = ScalarLike(p, 0);
472
473 // Zero out a_{pq} explicitly.
474 std::vector<int64_t> pq_dims(batch_dims.begin(), batch_dims.end());
475 pq_dims.push_back(1);
476 pq_dims.push_back(1);
477 auto pq_zero = ScalarLike(d, 0.0);
478 auto pq_zeros = Broadcast(pq_zero, pq_dims);
479
480 std::vector<int64_t> broadcast_dims(batch_dims.size());
481 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
482 broadcast_dims.push_back(num_dims - 1);
483
484 // Apply Jacobi Rotation on the left.
485 auto slice_p = DynamicSliceInMinorDims(d, {p, zero}, {1, n});
486 auto slice_q = DynamicSliceInMinorDims(d, {q, zero}, {1, n});
487 auto slice_p_new =
488 onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;
489 auto slice_q_new =
490 onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;
491 d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {p, zero});
492 d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {q, zero});
493
494 // Apply Jacobi Rotation on the right.
495 slice_p = DynamicSliceInMinorDims(d, {zero, p}, {m, 1});
496 slice_q = DynamicSliceInMinorDims(d, {zero, q}, {m, 1});
497 slice_p_new =
498 onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;
499 slice_q_new =
500 onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;
501 d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {zero, p});
502 d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {zero, q});
503
504 d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {p, q});
505 d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {q, p});
506
507 // Apply left Jacobi Rotation on U.
508 slice_p = DynamicSliceInMinorDims(u, {zero, p}, {m, 1});
509 slice_q = DynamicSliceInMinorDims(u, {zero, q}, {m, 1});
510 slice_p_new =
511 onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;
512
513 slice_p_new = Mul(
514 slice_p_new,
515 Rsqrt(Reduce(Square(slice_p_new), pq_zero,
516 CreateScalarAddComputation(d_shape.element_type(), builder),
517 {num_dims - 2})),
518 broadcast_dims);
519
520 slice_q_new =
521 onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;
522
523 slice_q_new = Mul(
524 slice_q_new,
525 Rsqrt(Reduce(Square(slice_q_new), pq_zero,
526 CreateScalarAddComputation(d_shape.element_type(), builder),
527 {num_dims - 2})),
528 broadcast_dims);
529
530 u = DynamicUpdateSliceInMinorDims(u, slice_p_new, {zero, p});
531 u = DynamicUpdateSliceInMinorDims(u, slice_q_new, {zero, q});
532
533 // Apply right Jacobi Rotation on V.
534 slice_p = DynamicSliceInMinorDims(v, {zero, p}, {n, 1});
535 slice_q = DynamicSliceInMinorDims(v, {zero, q}, {n, 1});
536 slice_p_new =
537 onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;
538
539 slice_p_new = Mul(
540 slice_p_new,
541 Rsqrt(Reduce(Square(slice_p_new), pq_zero,
542 CreateScalarAddComputation(d_shape.element_type(), builder),
543 {num_dims - 2})),
544 broadcast_dims);
545
546 slice_q_new =
547 onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;
548
549 slice_q_new = Mul(
550 slice_q_new,
551 Rsqrt(Reduce(Square(slice_q_new), pq_zero,
552 CreateScalarAddComputation(d_shape.element_type(), builder),
553 {num_dims - 2})),
554 broadcast_dims);
555
556 v = DynamicUpdateSliceInMinorDims(v, slice_p_new, {zero, p});
557 v = DynamicUpdateSliceInMinorDims(v, slice_q_new, {zero, q});
558
559 svd_result.d = d;
560 svd_result.u = u;
561 svd_result.v = v;
562
563 return svd_result;
564 }
565
ComputeToleranceComparison(XlaOp w,XlaOp epsilon)566 StatusOr<XlaOp> ComputeToleranceComparison(XlaOp w, XlaOp epsilon) {
567 XlaBuilder* builder = w.builder();
568 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
569 auto num_dims = static_cast<int32_t>(shape.rank());
570 int64_t n = shape.dimensions(num_dims - 1);
571 shape.set_dimensions(num_dims - 2, n);
572 auto w_sliced = SliceInMinorDims(w, {0, 0}, {n, n});
573 auto diag = GetMatrixDiagonal(w_sliced);
574 diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag);
575 std::vector<int64_t> broadcasted_dims(num_dims - 1);
576 std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
577 auto broadcast_to_rows =
578 BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
579 broadcasted_dims.back() = num_dims - 1;
580 auto broadcast_to_columns =
581 BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
582 // Compute tolerance = w_{i,i} * w_{j,j} * epsilon^2
583 // Use at least F32 precision to avoid precision issues with small denormal.
584 XlaOp tolerance;
585 if (builder->GetShape(epsilon)->element_type() == BF16 ||
586 builder->GetShape(epsilon)->element_type() == F16) {
587 auto upscale_eps = ConvertElementType(epsilon, F32);
588 tolerance = ConvertElementType(broadcast_to_rows, F32) *
589 ConvertElementType(broadcast_to_columns, F32) * upscale_eps *
590 upscale_eps;
591 // Convert back into the original precision.
592 tolerance = ConvertElementType(tolerance,
593 builder->GetShape(epsilon)->element_type());
594 } else {
595 tolerance = broadcast_to_rows * broadcast_to_columns * epsilon * epsilon;
596 }
597 // tolerance < (w_{i,j})^2
598 return Lt(tolerance, Square(Select(GetDiagonalMask(w_sliced),
599 ZerosLike(w_sliced), w_sliced)));
600 }
601
// Main body of One-sided Jacobi Method.
WhileLoopFn(absl::Span<const XlaOp> initial_values,int matrix_dimension,int max_sweep_updates,absl::string_view name,XlaBuilder * builder)603 StatusOr<std::vector<XlaOp>> WhileLoopFn(
604 absl::Span<const XlaOp> initial_values, //
605 int matrix_dimension, //
606 int max_sweep_updates, //
607 absl::string_view name, //
608 XlaBuilder* builder) {
609 auto while_cond_fn = [&](absl::Span<const XlaOp> values,
610 XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
611 auto k = values[0];
612 auto max_sweeps = ScalarLike(k, max_sweep_updates);
613 auto sweep_update_cond = Gt(max_sweeps, k);
614
615 TF_ASSIGN_OR_RETURN(auto tolerance_comparison,
616 ComputeToleranceComparison(values[3], values[4]));
617 auto tolerance_cond = ReduceAll(
618 tolerance_comparison, xla::ConstantR0<bool>(cond_builder, false),
619 CreateScalarOrComputation(PRED, cond_builder));
620
621 return And(sweep_update_cond, tolerance_cond);
622 };
623
624 auto while_body_fn =
625 [&](absl::Span<const XlaOp> values,
626 XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
627 auto while_cond_fn_inner =
628 [&](absl::Span<const XlaOp> values_inner,
629 XlaBuilder* inner_cond_builder) -> StatusOr<XlaOp> {
630 auto p = values_inner[0];
631 return Lt(p, ScalarLike(p, matrix_dimension - 1));
632 };
633
634 auto while_body_fn_inner =
635 [&](absl::Span<const XlaOp> values_inner,
636 XlaBuilder* inner_body_builder) -> StatusOr<std::vector<XlaOp>> {
637 auto while_cond_fn_innermost =
638 [&](absl::Span<const XlaOp> values_innermost,
639 XlaBuilder* innermost_cond_builder) -> StatusOr<XlaOp> {
640 auto q = values_innermost[1];
641 return Lt(q, ScalarLike(q, matrix_dimension));
642 };
643 auto while_body_fn_innermost =
644 [&](absl::Span<const XlaOp> values_innermost,
645 XlaBuilder* innermost_body_builder)
646 -> StatusOr<std::vector<XlaOp>> {
647 auto p = values_innermost[0];
648 auto q = values_innermost[1];
649
650 SVDResult onesided_jacobi_update;
651 onesided_jacobi_update.u = values_innermost[2];
652 onesided_jacobi_update.v = values_innermost[3];
653 onesided_jacobi_update.d = values_innermost[4];
654
655 auto eps = values_innermost[5];
656
657 TF_ASSIGN_OR_RETURN(
658 onesided_jacobi_update,
659 OneSidedJacobiUpdate(onesided_jacobi_update, p, q, eps));
660
661 std::vector<XlaOp> updated_values_innermost;
662 updated_values_innermost.reserve(values_innermost.size());
663
664 updated_values_innermost.push_back(p);
665 updated_values_innermost.push_back(q + ScalarLike(q, 1));
666 updated_values_innermost.push_back(onesided_jacobi_update.u);
667 updated_values_innermost.push_back(onesided_jacobi_update.v);
668 updated_values_innermost.push_back(onesided_jacobi_update.d);
669 updated_values_innermost.push_back(eps);
670
671 return updated_values_innermost;
672 };
673
674 std::vector<XlaOp> values_innermost(6);
675 auto p = values_inner[0];
676 auto q = p + ScalarLike(p, 1);
677 values_innermost[0] = p; // index p.
678 values_innermost[1] = q; // index q.
679 values_innermost[2] = values_inner[1]; // u.
680 values_innermost[3] = values_inner[2]; // v.
681 values_innermost[4] = values_inner[3]; // d.
682 values_innermost[5] = values_inner[4]; // eps.
683 TF_ASSIGN_OR_RETURN(
684 values_innermost,
685 WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost,
686 values_innermost, absl::StrCat(name, "-Innermost"),
687 inner_body_builder));
688
689 std::vector<XlaOp> updated_values_inner;
690 updated_values_inner.reserve(values_inner.size());
691
692 updated_values_inner.push_back(p + ScalarLike(p, 1));
693 updated_values_inner.push_back(values_innermost[2]);
694 updated_values_inner.push_back(values_innermost[3]);
695 updated_values_inner.push_back(values_innermost[4]);
696 updated_values_inner.push_back(values_innermost[5]);
697 return updated_values_inner;
698 };
699 // Indexes.
700 XlaOp k = values[0];
701
702 std::vector<XlaOp> values_inner(5);
703 values_inner[0] = ScalarLike(k, 0); // index p.
704 values_inner[1] = values[1]; // u.
705 values_inner[2] = values[2]; // v.
706 values_inner[3] = values[3]; // d.
707 values_inner[4] = values[4]; // eps.
708 TF_ASSIGN_OR_RETURN(
709 values_inner,
710 WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner,
711 absl::StrCat(name, "-Inner"), body_builder));
712
713 std::vector<XlaOp> updated_values;
714 updated_values.reserve(values_inner.size());
715
716 updated_values.push_back(k + ScalarLike(k, 1));
717 updated_values.push_back(values_inner[1]);
718 updated_values.push_back(values_inner[2]);
719 updated_values.push_back(values_inner[3]);
720 updated_values.push_back(values_inner[4]);
721
722 return updated_values;
723 };
724 std::vector<XlaOp> values;
725 TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
726 initial_values, name, builder));
727
728 return values;
729 }
730
// Sort singular values in descending order, and make sure they are non-negative
732 // by flipping the signs of negative diagonal values and transferring the signs
733 // to V. And for numeric stability, renormalize U and V.
SortBySingularValuesAndPostProcessing(SVDResult result)734 StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
735 XlaBuilder* builder = result.d.builder();
736 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d));
737 const int64_t num_dims = shape.rank();
738 auto dimensions = shape.dimensions();
739 const int64_t m = ShapeUtil::GetDimension(shape, -2);
740 const int64_t n = ShapeUtil::GetDimension(shape, -1);
741
742 std::vector<int64_t> broadcast_dims(num_dims - 1);
743 std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
744 broadcast_dims[num_dims - 2] = num_dims - 1;
745
746 auto d = GetMatrixDiagonal(result.d);
747
748 auto zeros = ZerosLike(d);
749 auto one = ScalarLike(d, 1.0);
750
751 // Make all the singular values to be non-negative by transferring the signs
752 // to V.
753 auto sign = Select(Ge(d, zeros), zeros + one, zeros - one);
754 d = Select(Ge(d, zeros), d, -d);
755 result.v = Mul(result.v, sign, broadcast_dims);
756
757 d = BroadcastInDim(d, dimensions, broadcast_dims);
758
759 // As m >= n, only first n column vectors need to be permuted, and the rest of
760 // m - n vectors are appended after the sorting is done.
761 XlaOp sort_u_result =
762 Sort({d, SliceInMinorDims(result.u, {0, 0}, {m, n})},
763 CreateScalarGtComputation(
764 {shape.element_type(), shape.element_type()}, builder),
765 num_dims - 1);
766
767 XlaOp sort_v_result =
768 Sort({SliceInMinorDims(d, {0, 0}, {n, n}), result.v},
769 CreateScalarGtComputation(
770 {shape.element_type(), shape.element_type()}, builder),
771 num_dims - 1);
772 result.d = GetMatrixDiagonal(GetTupleElement(sort_v_result, 0));
773
774 result.v = GetTupleElement(sort_v_result, 1);
775 result.v = Mul(
776 result.v,
777 Rsqrt(Reduce(Square(result.v), ScalarLike(d, 0.0),
778 CreateScalarAddComputation(shape.element_type(), builder),
779 {num_dims - 2})),
780 broadcast_dims);
781
782 // Append the rest of m - n vectors.
783 result.u = ConcatInDim(builder,
784 {GetTupleElement(sort_u_result, 1),
785 SliceInMinorDims(result.u, {0, n}, {m, m})},
786 num_dims - 1);
787 result.u = Mul(
788 result.u,
789 Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),
790 CreateScalarAddComputation(shape.element_type(), builder),
791 {num_dims - 2})),
792 broadcast_dims);
793
794 return result;
795 }
796
797 } // namespace
798
799 // def jacobi_svd(A):
800 // U, D, V = house_bidiag(A)
801 // m, n = D.shape
802 // iter, max_iter = 0, 100
803 // frobenius_norm = np.linalg.norm(D)
804 // diag_norm = np.linalg.norm(np.diag(D))
805 // off_diag_norm = np.sqrt(
806 // frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
807 // while off_diag_norm > 1e-6 * frobenius_norm and iter < max_iter:
808 // iter += 1
809 // for p in range(m - 1):
810 // for q in range(p + 1, n):
811 // rot_l, rot_r = jacobi_rot(D[p][p], D[p][q], D[q][p], D[q][q])
812 // D[[p, q], :] = np.matmul(rot_l.T, D[[p, q], :])
813 // D[:, [p, q]] = np.matmul(D[:, [p, q]], rot_r)
814 // U[:, [p, q]] = np.matmul(U[:, [p, q]], rot_l)
815 // V[:, [p, q]] = np.matmul(V[:, [p, q]], rot_r)
816 // frobenius_norm = np.linalg.norm(D)
817 // diag_norm = np.linalg.norm(np.diag(D))
818 // off_diag_norm = np.sqrt(
819 // frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
820 //
821 // return U, np.diag(D), V
822 //
SVD(XlaOp a,int64_t max_iter,float epsilon,PrecisionConfig::Precision precision)823 SVDResult SVD(XlaOp a, int64_t max_iter, float epsilon,
824 PrecisionConfig::Precision precision) {
825 XlaBuilder* builder = a.builder();
826 auto return_error = [&](const Status& status) {
827 SVDResult result;
828 result.u = builder->ReportError(status);
829 result.v = builder->ReportError(status);
830 result.d = builder->ReportError(status);
831 return result;
832 };
833 auto shape_with_status = builder->GetShape(a);
834 if (!shape_with_status.status().ok()) {
835 return return_error(shape_with_status.status());
836 }
837 Shape a_shape = shape_with_status.ValueOrDie();
838 const int64_t num_dims = a_shape.rank();
839 const int64_t num_batch_dims = num_dims - 2;
840 std::vector<int64_t> batch_dims(num_batch_dims);
841 for (int i = 0; i < num_batch_dims; ++i) {
842 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
843 }
844 int64_t m = ShapeUtil::GetDimension(a_shape, -2);
845 int64_t n = ShapeUtil::GetDimension(a_shape, -1);
846 bool maybe_transpose = m < n;
847
848 if (maybe_transpose) {
849 a = TransposeInMinorDims(a);
850 std::swap(m, n);
851 }
852
853 auto eps = ScalarLike(a, epsilon);
854
855 auto svd_result_or = HouseHolderBidiagonalization(a, eps, precision);
856 if (!svd_result_or.ok()) {
857 return return_error(svd_result_or.status());
858 }
859 SVDResult svd_result = svd_result_or.ValueOrDie();
860
861 auto output_with_status = WhileLoopFn(
862 {
863 Zero(builder, S32), // k
864 svd_result.u, // u
865 svd_result.v, // v
866 svd_result.d, // d
867 eps, // epsilon
868 }, //
869 n, //
870 max_iter, //
871 "CyclicOneSidedJacobi", //
872 builder);
873 if (!output_with_status.status().ok()) {
874 return return_error(output_with_status.status());
875 }
876
877 auto output = output_with_status.ValueOrDie();
878
879 svd_result.u = output[1];
880 svd_result.v = output[2];
881 svd_result.d = output[3];
882
883 svd_result_or = SortBySingularValuesAndPostProcessing(svd_result);
884 if (!svd_result_or.ok()) {
885 return return_error(svd_result_or.status());
886 }
887 svd_result = svd_result_or.ValueOrDie();
888
889 if (maybe_transpose) {
890 std::swap(svd_result.u, svd_result.v);
891 }
892 return svd_result;
893 }
894
895 } // namespace xla
896