xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/svd.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/client/lib/svd.h"
16 
17 #include <memory>
18 #include <numeric>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/comparators.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/lib/loops.h"
26 #include "tensorflow/compiler/xla/client/lib/math.h"
27 #include "tensorflow/compiler/xla/client/lib/matrix.h"
28 #include "tensorflow/compiler/xla/client/lib/slicing.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/literal_util.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 
37 namespace xla {
38 
39 namespace {
40 
41 // Given a matrix A, define the Householder reflector H as
42 //   H = I - beta * v_T * v if v is a row vector, or
43 //   H = I - beta * v * v_T if v is a column vector.
44 // A * H or H * A zeros out the trailing part of some row or column of A, e.g.,
45 //
46 // [x_0, ..., x_{k-1}, x_k, x_{k+1}, ..., x_{n-1}] * H
47 //       = [x_0, ..., x_{k-1}, xnorm, 0, ..., 0]
48 //
49 // where xnorm = norm([x_k, x_{k+1}, ..., x_{n-1}]).
50 struct HouseHolderResult {
51   XlaOp v;
52   XlaOp beta;
53   XlaOp a;
54 };
55 
56 // Jacobi rotation (also known as Givens rotation):
57 // G = [[ c, s],
58 //      [-s, c]]
59 // matmul(G_T, G) = I
60 struct JacobiRotation {
61   XlaOp c;  // cosine.
62   XlaOp s;  // sine.
63 };
64 
65 // JacobiUpdate holds the intermediate orthogonal matrix and the Jacobi-rotated matrix.
66 struct JacobiUpdate {
67   XlaOp v;
68   XlaOp w;
69 };
70 
71 // OneSidedJacobiRotation holds the left and right Jacobi rotations. Refer to
72 // GetOneSidedJacobiRotation for the effect of applying OneSidedJacobiRotation
73 // to a matrix.
74 struct OneSidedJacobiRotation {
75   JacobiRotation rot_l;
76   JacobiRotation rot_r;
77 };
78 
79 // Householder reflection on the trailing elements of a vector.
80 //
81 // H = I - beta * [1, v]' * [1, v]
82 //
83 // H * x = [..., xnorm, 0, ..., 0]
84 //          ..., j, j + 1, ..., n
85 //
86 // def house(x, j, eps):
87 //    sigma = np.linalg.norm(x[(j + 1):])
88 //    v = np.zeros_like(x)
89 //    v[(j + 1):] = x[(j + 1):]
90 //    if sigma < eps:
91 //        beta = 0
92 //    else:
93 //        mu = sigma * np.sqrt((x[j]/sigma)**2 + 1)
94 //        if x[j] <= 0:
95 //            v[j] = x[j] - mu
96 //        else:
97 //            v[j] = -sigma / (x[j] + mu) * sigma
98 //        beta = 2 / ((sigma / v[j])**2 + 1)
99 //        v = v / v[j]
100 //    v[j] = 1
101 //    return v, beta
102 //
103 // Householder reflection on the trailing elements of a row of a matrix. After
104 // applying it to the matrix, all elements in A[i, (j+1):] become zero, i.e.,
105 //
106 // H = I - beta * [1, v]' * [1, v], then,
107 //
108 // A[i, j:] * H = [sigma, 0, 0, ..., 0]
109 //
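// In NumPy terms, the update applied to A is roughly (illustrative sketch;
// v is the 1 x n row vector and beta the scalar computed as in `house` above):
//
//   A_new = A - beta * (A @ v.T) @ v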
110 StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
111                                      PrecisionConfig::Precision precision) {
112   XlaBuilder* builder = a.builder();
113   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
114   const int64_t num_dims = a_shape.rank();
115   const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
116   XlaOp zero = ScalarLike(i, 0);
117   XlaOp x = DynamicSliceInMinorDims(a, {i, zero}, {1, n});
118 
119   const int64_t num_batch_dims = num_dims - 2;
120   std::vector<int64_t> batch_dims(num_batch_dims);
121   for (int k = 0; k < num_batch_dims; ++k) {
122     batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
123   }
124 
125   TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
126   auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
127                   num_dims - 1);
128   auto zeros = ZerosLike(x);
129   auto v = Select(Gt(idx, j), x, zeros);
130 
131   auto one = ScalarLike(v, 1.0);
132 
133   auto sigma =
134       Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
135                   CreateScalarAddComputation(x_shape.element_type(), builder),
136                   {num_dims - 1}));
137 
138   std::vector<int64_t> broadcast_dims(num_dims - 1);
139   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
140   auto x_0j = DynamicSliceInMinorDims(x, {zero, j}, {1, 1});
141   auto mu = Mul(sigma, Sqrt(Square(Div(x_0j, sigma, broadcast_dims)) + one),
142                 broadcast_dims);
143 
144   auto v_0j = Select(
145       Le(x_0j, ScalarLike(x_0j, 0.0)), Sub(x_0j, mu),
146       -Mul(sigma, Div(sigma, Add(x_0j, mu), broadcast_dims), broadcast_dims));
147 
148   auto beta = Div(ScalarLike(v_0j, 2.0),
149                   (Square(Div(sigma, v_0j, broadcast_dims)) + one));
150 
151   v = Select(
152       BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
153       v / v_0j);
154   v = Select(Eq(idx, j), zeros + one, v);
155 
156   beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
157                 ZerosLike(beta), beta);
158 
159   HouseHolderResult result;
160   result.v = v;
161   result.beta = beta;
162   result.a = Sub(a, Mul(beta, BatchDot(BatchDot(a, false, v, true, precision),
163                                        v, precision)));
164 
165   return result;
166 }
167 
168 // Householder reflection on the trailing elements of a column of a matrix.
169 // After applying it to the matrix, all elements in A[(i+1):, j] become zero, i.e.,
170 //
171 // H = I - beta * [1; v] * [1; v]', then,
172 //
173 // H * A[i:, j] = [xnorm, 0, 0, ..., 0]
174 //
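// In NumPy terms, the update applied to A is roughly (illustrative sketch;
// v is the m x 1 column vector and beta the scalar computed as in `house`
// above):
//
//   A_new = A - beta * v @ (v.T @ A)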
175 StatusOr<HouseHolderResult> HouseCol(XlaOp a, XlaOp i, XlaOp j, XlaOp eps,
176                                      PrecisionConfig::Precision precision) {
177   XlaBuilder* builder = a.builder();
178   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
179   const int64_t num_dims = a_shape.rank();
180   const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
181   XlaOp zero = ScalarLike(i, 0);
182   XlaOp x = DynamicSliceInMinorDims(a, {zero, j}, {m, 1});
183 
184   const int64_t num_batch_dims = num_dims - 2;
185   std::vector<int64_t> batch_dims(num_batch_dims);
186   for (int k = 0; k < num_batch_dims; ++k) {
187     batch_dims[k] = ShapeUtil::GetDimension(a_shape, k);
188   }
189 
190   TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
191   auto idx = Iota(builder, ShapeUtil::MakeShape(S32, x_shape.dimensions()),
192                   num_dims - 2);
193   auto zeros = ZerosLike(x);
194   auto v = Select(Gt(idx, i), x, zeros);
195 
196   auto one = ScalarLike(v, 1.0);
197 
198   auto sigma =
199       Sqrt(Reduce(Square(v), ScalarLike(v, 0.0),
200                   CreateScalarAddComputation(x_shape.element_type(), builder),
201                   {num_dims - 2}));
202 
203   std::vector<int64_t> broadcast_dims(num_dims - 1);
204   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
205   broadcast_dims[num_dims - 2] = num_dims - 1;
206   auto x_0i = DynamicSliceInMinorDims(x, {i, zero}, {1, 1});
207   auto mu = Mul(sigma, Sqrt(Square(Div(x_0i, sigma, broadcast_dims)) + one),
208                 broadcast_dims);
209 
210   auto v_0i = Select(
211       Le(x_0i, ScalarLike(x_0i, 0.0)), Sub(x_0i, mu),
212       -Mul(sigma, Div(sigma, Add(x_0i, mu), broadcast_dims), broadcast_dims));
213 
214   auto beta = Div(ScalarLike(v_0i, 2.0),
215                   (Square(Div(sigma, v_0i, broadcast_dims)) + one));
216 
217   v = Select(
218       BroadcastInDim(Lt(sigma, eps), x_shape.dimensions(), broadcast_dims), v,
219       v / v_0i);
220   v = Select(Eq(idx, i), zeros + one, v);
221 
222   beta = Select(Lt(Add(sigma, ZerosLike(beta), broadcast_dims), eps),
223                 ZerosLike(beta), beta);
224 
225   HouseHolderResult result;
226   result.v = v;
227   result.beta = beta;
228   result.a = Sub(
229       a, Mul(beta, BatchDot(v, false, BatchDot(v, true, a, false, precision),
230                             false, precision)));
231 
232   return result;
233 }
234 
235 // Apply column and row Householder reflections for bidiagonalization.
236 //
237 // def house_bidiag(A):
238 //    xz, yz = A.shape
239 //    LL = np.eye(xz)
240 //    RR = np.eye(yz)
241 //    for i in range(yz - 1):
242 //        v, beta = house_col(A, i, i, 1e-8)
243 //        L = np.eye(xz) - beta * np.outer(v, v)
244 //        LL = np.matmul(LL, L)
245 //        A = np.matmul(L, A)
246 //        if i < yz - 2:
247 //            v, beta = house_row(A, i, i + 1, 1e-8)
248 //            R = np.eye(yz) - beta * np.outer(v, v)
249 //            RR = np.matmul(RR, R)
250 //            A = np.matmul(A, R)
251 //    return LL, A, RR
252 //
253 StatusOr<SVDResult> HouseHolderBidiagonalization(
254     XlaOp a, XlaOp eps, PrecisionConfig::Precision precision) {
255   XlaBuilder* builder = a.builder();
256   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
257   const int64_t num_dims = a_shape.rank();
258   const int64_t num_batch_dims = num_dims - 2;
259   std::vector<int64_t> batch_dims(num_batch_dims);
260   for (int i = 0; i < num_batch_dims; ++i) {
261     batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
262   }
263   const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
264   const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
265   XlaOp u_init = Broadcast(
266       IdentityMatrix(builder, a_shape.element_type(), m, m), batch_dims);
267   XlaOp v_init = Broadcast(
268       IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims);
269 
270   auto while_cond_fn = [&](absl::Span<const XlaOp> values,
271                            XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
272     auto i = values[0];
273     return Lt(i, ScalarLike(i, n - 2));
274   };
275   auto while_body_fn =
276       [&](absl::Span<const XlaOp> values,
277           XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
278     auto i = values[0];
279     auto one = ScalarLike(i, 1);
280 
281     auto u = values[1];
282     auto v = values[2];
283     auto a = values[3];
284     auto eps = values[4];
285 
286     TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
287                         HouseCol(a, i, i, eps, precision));
288     u = Sub(u,
289             Mul(house_col.beta, BatchDot(BatchDot(u, house_col.v, precision),
290                                          false, house_col.v, true, precision)));
291     a = house_col.a;
292 
293     TF_ASSIGN_OR_RETURN(HouseHolderResult house_row,
294                         HouseRow(a, i, i + one, eps, precision));
295     v = Sub(v, Mul(house_row.beta,
296                    BatchDot(BatchDot(v, false, house_row.v, true, precision),
297                             house_row.v, precision)));
298     a = house_row.a;
299 
300     std::vector<XlaOp> updated_values;
301     updated_values.reserve(values.size());
302 
303     updated_values.push_back(i + one);
304     updated_values.push_back(u);
305     updated_values.push_back(v);
306     updated_values.push_back(a);
307     updated_values.push_back(eps);
308     return updated_values;
309   };
310 
311   std::vector<XlaOp> values(5);
312   values[0] = Zero(builder, S32);
313   values[1] = u_init;
314   values[2] = v_init;
315   values[3] = a;
316   values[4] = eps;
317 
318   TF_ASSIGN_OR_RETURN(values,
319                       WhileLoopHelper(while_cond_fn, while_body_fn, values,
320                                       "HouseHolderBidiagonalization", builder));
321 
322   for (int k = 2; k > 0; --k) {
323     if (n - k >= 0) {
324       XlaOp index = ScalarLike(values[0], n - k);
325       TF_ASSIGN_OR_RETURN(HouseHolderResult house_col,
326                           HouseCol(values[3], index, index, eps, precision));
327       values[1] = Sub(values[1],
328                       Mul(house_col.beta,
329                           BatchDot(BatchDot(values[1], house_col.v, precision),
330                                    false, house_col.v, true, precision)));
331       values[3] = house_col.a;
332     }
333   }
334 
335   SVDResult result;
336   result.u = values[1];
337   result.v = values[2];
338   result.d = values[3];
339   return result;
340 }
341 
342 // MakeJacobi computes a rotation matrix G = [[c, s], [-s, c]], such that
343 //                        G_T * [[ps, pqs], [pqs, qs]] * G
344 // is diagonal.
345 //
346 //  def make_jacobi(ps, qs, pqs, eps):
347 //     if np.abs(pqs) > eps:
348 //         tau = (qs - ps) / (2 * pqs)
349 //         if tau >= 0:
350 //             t = 1.0 / (tau + np.sqrt(1 + tau ** 2))
351 //         else:
352 //             t = -1.0 / (-tau + np.sqrt(1 + tau ** 2))
353 //         c = 1.0 / np.sqrt(1.0 + t ** 2)
354 //         s = t * c
355 //     else:
356 //         c = 1.0
357 //         s = 0.0
358 //     return c, s
359 //
360 StatusOr<JacobiRotation> MakeJacobi(XlaOp ps, XlaOp qs, XlaOp pqs, XlaOp eps) {
361   auto zero = ScalarLike(ps, 0.0);
362   auto one = ScalarLike(ps, 1.0);
363   auto two = ScalarLike(ps, 2.0);
364 
365   auto tau = (qs - ps) / (pqs * two);
366   auto t_pos = one / (tau + Sqrt(one + Square(tau)));
367   auto t_neg = -one / (-tau + Sqrt(one + Square(tau)));
368   auto t = Select(Ge(tau, zero), t_pos, t_neg);
369 
370   auto c_temp = Rsqrt(one + Square(t));
371   auto s_temp = t * c_temp;
372 
373   auto c = Select(Ge(Abs(pqs), eps), c_temp, ZerosLike(c_temp) + one);
374   auto s = Select(Ge(Abs(pqs), eps), s_temp, ZerosLike(s_temp));
375   // Renormalize c and s to compensate for low precision arithmetic, this step
376   // is redundant if high precision float is used, like float64.
377   auto rnorm = Rsqrt(Square(c) + Square(s));
378 
379   JacobiRotation rot;
380 
381   rot.c = c * rnorm;
382   rot.s = s * rnorm;
383 
384   return rot;
385 }
386 
387 // One-sided Jacobi rotation. For a 2x2 matrix
388 //  [a_pp, a_pq]
389 //  [a_qp, a_qq]
390 // applying Jacobi rotations on both sides diagonalizes it:
391 //  [b_pp, 0]
392 //  [0, b_qq]
393 //
394 // def jacobi_rot(a, p, q, eps):
395 //     t = a[p, p] + a[q, q]
396 //     d = a[q, p] - a[p, q]
397 //
398 //     if np.abs(d) < eps:
399 //         s = 0.0
400 //         c = 1.0
401 //     else:
402 //         u = t / d
403 //         tmp = np.sqrt(1.0 + u**2)
404 //         s = -1.0 / tmp
405 //         c = u / tmp
406 //
407 //     rot = np.array([[c, s], [-s, c]])
408 //     m_tmp = rot.T @ a[np.ix_([p, q], [p, q])]
409 //     c_r, s_r = make_jacobi(m_tmp[0, 0], m_tmp[1, 1], m_tmp[0, 1], eps)
410 //     rot_r = np.array([[c_r, s_r], [-s_r, c_r]])
411 //     rot_l = rot @ rot_r
412 //     return rot_l, rot_r
413 //
414 StatusOr<OneSidedJacobiRotation> GetOneSidedJacobiRotation(XlaOp a, XlaOp p,
415                                                            XlaOp q, XlaOp eps) {
416   XlaOp a_pp = DynamicSliceInMinorDims(a, {p, p}, {1, 1});
417   XlaOp a_pq = DynamicSliceInMinorDims(a, {p, q}, {1, 1});
418   XlaOp a_qp = DynamicSliceInMinorDims(a, {q, p}, {1, 1});
419   XlaOp a_qq = DynamicSliceInMinorDims(a, {q, q}, {1, 1});
420 
421   XlaOp one = ScalarLike(a, 1.0);
422 
423   XlaOp t = a_pp + a_qq;
424   XlaOp d = a_qp - a_pq;
425 
426   XlaOp u = Div(t, d);
427   XlaOp tmp = Rsqrt(one + Square(u));
428 
429   JacobiRotation rot;
430 
431   XlaOp zeros = ZerosLike(tmp);
432   XlaOp ones = zeros + one;
433 
434   rot.s = Select(Lt(Abs(d), eps), zeros, -tmp);
435   rot.c = Select(Lt(Abs(d), eps), ones, Mul(u, tmp));
436 
437   XlaOp a_pp_new = rot.c * a_pp - rot.s * a_qp;
438   XlaOp a_pq_new = rot.c * a_pq - rot.s * a_qq;
439   XlaOp a_qq_new = rot.s * a_pq + rot.c * a_qq;
440 
441   OneSidedJacobiRotation rots;
442   TF_ASSIGN_OR_RETURN(rots.rot_r,
443                       MakeJacobi(a_pp_new, a_qq_new, a_pq_new, eps));
444 
445   rots.rot_l.c = rot.c * rots.rot_r.c - rot.s * rots.rot_r.s;
446   rots.rot_l.s = rot.s * rots.rot_r.c + rot.c * rots.rot_r.s;
447 
448   return rots;
449 }
450 
451 // Apply one-sided Jacobi on elements at indices pp, pq, qp, qq.
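// Roughly, in NumPy terms (illustrative sketch only; rot_l and rot_r are the
// 2x2 rotations returned by GetOneSidedJacobiRotation, and u, v, d are the
// accumulated factors):
//
//   d[[p, q], :] = rot_l.T @ d[[p, q], :]
//   d[:, [p, q]] = d[:, [p, q]] @ rot_r
//   d[p, q] = d[q, p] = 0
//   u[:, [p, q]] = u[:, [p, q]] @ rot_l   # then renormalize the two columns
//   v[:, [p, q]] = v[:, [p, q]] @ rot_r   # then renormalize the two columns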
452 StatusOr<SVDResult> OneSidedJacobiUpdate(SVDResult svd_result, XlaOp p, XlaOp q,
453                                          XlaOp eps) {
454   XlaOp u = svd_result.u;
455   XlaOp v = svd_result.v;
456   XlaOp d = svd_result.d;
457   XlaBuilder* builder = d.builder();
458   TF_ASSIGN_OR_RETURN(Shape d_shape, builder->GetShape(d));
459   const int64_t num_dims = d_shape.rank();
460   const int64_t num_batch_dims = num_dims - 2;
461   std::vector<int64_t> batch_dims(num_batch_dims);
462   for (int i = 0; i < num_batch_dims; ++i) {
463     batch_dims[i] = ShapeUtil::GetDimension(d_shape, i);
464   }
465   const int64_t m = ShapeUtil::GetDimension(d_shape, -2);
466   const int64_t n = ShapeUtil::GetDimension(d_shape, -1);
467 
468   TF_ASSIGN_OR_RETURN(OneSidedJacobiRotation onesided_jacobi,
469                       GetOneSidedJacobiRotation(d, p, q, eps));
470 
471   auto zero = ScalarLike(p, 0);
472 
473   // Zero out a_{pq} explicitly.
474   std::vector<int64_t> pq_dims(batch_dims.begin(), batch_dims.end());
475   pq_dims.push_back(1);
476   pq_dims.push_back(1);
477   auto pq_zero = ScalarLike(d, 0.0);
478   auto pq_zeros = Broadcast(pq_zero, pq_dims);
479 
480   std::vector<int64_t> broadcast_dims(batch_dims.size());
481   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
482   broadcast_dims.push_back(num_dims - 1);
483 
484   // Apply Jacobi Rotation on the left.
485   auto slice_p = DynamicSliceInMinorDims(d, {p, zero}, {1, n});
486   auto slice_q = DynamicSliceInMinorDims(d, {q, zero}, {1, n});
487   auto slice_p_new =
488       onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;
489   auto slice_q_new =
490       onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;
491   d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {p, zero});
492   d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {q, zero});
493 
494   // Apply Jacobi Rotation on the right.
495   slice_p = DynamicSliceInMinorDims(d, {zero, p}, {m, 1});
496   slice_q = DynamicSliceInMinorDims(d, {zero, q}, {m, 1});
497   slice_p_new =
498       onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;
499   slice_q_new =
500       onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;
501   d = DynamicUpdateSliceInMinorDims(d, slice_p_new, {zero, p});
502   d = DynamicUpdateSliceInMinorDims(d, slice_q_new, {zero, q});
503 
504   d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {p, q});
505   d = DynamicUpdateSliceInMinorDims(d, pq_zeros, {q, p});
506 
507   // Apply left Jacobi Rotation on U.
508   slice_p = DynamicSliceInMinorDims(u, {zero, p}, {m, 1});
509   slice_q = DynamicSliceInMinorDims(u, {zero, q}, {m, 1});
510   slice_p_new =
511       onesided_jacobi.rot_l.c * slice_p - onesided_jacobi.rot_l.s * slice_q;
512 
513   slice_p_new = Mul(
514       slice_p_new,
515       Rsqrt(Reduce(Square(slice_p_new), pq_zero,
516                    CreateScalarAddComputation(d_shape.element_type(), builder),
517                    {num_dims - 2})),
518       broadcast_dims);
519 
520   slice_q_new =
521       onesided_jacobi.rot_l.s * slice_p + onesided_jacobi.rot_l.c * slice_q;
522 
523   slice_q_new = Mul(
524       slice_q_new,
525       Rsqrt(Reduce(Square(slice_q_new), pq_zero,
526                    CreateScalarAddComputation(d_shape.element_type(), builder),
527                    {num_dims - 2})),
528       broadcast_dims);
529 
530   u = DynamicUpdateSliceInMinorDims(u, slice_p_new, {zero, p});
531   u = DynamicUpdateSliceInMinorDims(u, slice_q_new, {zero, q});
532 
533   // Apply right Jacobi Rotation on V.
534   slice_p = DynamicSliceInMinorDims(v, {zero, p}, {n, 1});
535   slice_q = DynamicSliceInMinorDims(v, {zero, q}, {n, 1});
536   slice_p_new =
537       onesided_jacobi.rot_r.c * slice_p - onesided_jacobi.rot_r.s * slice_q;
538 
539   slice_p_new = Mul(
540       slice_p_new,
541       Rsqrt(Reduce(Square(slice_p_new), pq_zero,
542                    CreateScalarAddComputation(d_shape.element_type(), builder),
543                    {num_dims - 2})),
544       broadcast_dims);
545 
546   slice_q_new =
547       onesided_jacobi.rot_r.s * slice_p + onesided_jacobi.rot_r.c * slice_q;
548 
549   slice_q_new = Mul(
550       slice_q_new,
551       Rsqrt(Reduce(Square(slice_q_new), pq_zero,
552                    CreateScalarAddComputation(d_shape.element_type(), builder),
553                    {num_dims - 2})),
554       broadcast_dims);
555 
556   v = DynamicUpdateSliceInMinorDims(v, slice_p_new, {zero, p});
557   v = DynamicUpdateSliceInMinorDims(v, slice_q_new, {zero, q});
558 
559   svd_result.d = d;
560   svd_result.u = u;
561   svd_result.v = v;
562 
563   return svd_result;
564 }
565 
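// Returns a boolean matrix whose (i, j) entry is true iff the off-diagonal
// entry w_{i,j} still exceeds the convergence tolerance, i.e.
//   |w_{i,i}| * |w_{j,j}| * epsilon^2 < (w_{i,j})^2.
// The caller ORs all entries to decide whether another Jacobi sweep is needed.
// Roughly, in NumPy terms (illustrative sketch only):
//
//   def tolerance_comparison(w, eps):
//       n = w.shape[-1]
//       d = np.abs(np.diag(w[:n, :n]))
//       tol = np.outer(d, d) * eps * eps
//       off = np.where(np.eye(n, dtype=bool), 0.0, w[:n, :n])
//       return tol < off ** 2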
566 StatusOr<XlaOp> ComputeToleranceComparison(XlaOp w, XlaOp epsilon) {
567   XlaBuilder* builder = w.builder();
568   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(w));
569   auto num_dims = static_cast<int32_t>(shape.rank());
570   int64_t n = shape.dimensions(num_dims - 1);
571   shape.set_dimensions(num_dims - 2, n);
572   auto w_sliced = SliceInMinorDims(w, {0, 0}, {n, n});
573   auto diag = GetMatrixDiagonal(w_sliced);
574   diag = Select(Lt(diag, ZerosLike(diag)), -diag, diag);
575   std::vector<int64_t> broadcasted_dims(num_dims - 1);
576   std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
577   auto broadcast_to_rows =
578       BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
579   broadcasted_dims.back() = num_dims - 1;
580   auto broadcast_to_columns =
581       BroadcastInDim(diag, shape.dimensions(), broadcasted_dims);
582   // Compute tolerance = w_{i,i} * w_{j,j} * epsilon^2
583   // Use at least F32 precision to avoid precision issues with small denormal values.
584   XlaOp tolerance;
585   if (builder->GetShape(epsilon)->element_type() == BF16 ||
586       builder->GetShape(epsilon)->element_type() == F16) {
587     auto upscale_eps = ConvertElementType(epsilon, F32);
588     tolerance = ConvertElementType(broadcast_to_rows, F32) *
589                 ConvertElementType(broadcast_to_columns, F32) * upscale_eps *
590                 upscale_eps;
591     // Convert back into the original precision.
592     tolerance = ConvertElementType(tolerance,
593                                    builder->GetShape(epsilon)->element_type());
594   } else {
595     tolerance = broadcast_to_rows * broadcast_to_columns * epsilon * epsilon;
596   }
597   // tolerance < (w_{i,j})^2
598   return Lt(tolerance, Square(Select(GetDiagonalMask(w_sliced),
599                                      ZerosLike(w_sliced), w_sliced)));
600 }
601 
602 // Main body of the one-sided Jacobi method.
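// Roughly equivalent to the following sketch (illustrative only; the nested
// loops below are expressed as XLA while loops via WhileLoopHelper, and
// `tolerance_comparison` / `one_sided_jacobi_update` stand for
// ComputeToleranceComparison and OneSidedJacobiUpdate):
//
//   for k in range(max_sweep_updates):
//       if not np.any(tolerance_comparison(d, eps)):
//           break
//       for p in range(matrix_dimension - 1):
//           for q in range(p + 1, matrix_dimension):
//               u, v, d = one_sided_jacobi_update(u, v, d, p, q, eps)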
603 StatusOr<std::vector<XlaOp>> WhileLoopFn(
604     absl::Span<const XlaOp> initial_values,  //
605     int matrix_dimension,                    //
606     int max_sweep_updates,                   //
607     absl::string_view name,                  //
608     XlaBuilder* builder) {
609   auto while_cond_fn = [&](absl::Span<const XlaOp> values,
610                            XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
611     auto k = values[0];
612     auto max_sweeps = ScalarLike(k, max_sweep_updates);
613     auto sweep_update_cond = Gt(max_sweeps, k);
614 
615     TF_ASSIGN_OR_RETURN(auto tolerance_comparison,
616                         ComputeToleranceComparison(values[3], values[4]));
617     auto tolerance_cond = ReduceAll(
618         tolerance_comparison, xla::ConstantR0<bool>(cond_builder, false),
619         CreateScalarOrComputation(PRED, cond_builder));
620 
621     return And(sweep_update_cond, tolerance_cond);
622   };
623 
624   auto while_body_fn =
625       [&](absl::Span<const XlaOp> values,
626           XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
627     auto while_cond_fn_inner =
628         [&](absl::Span<const XlaOp> values_inner,
629             XlaBuilder* inner_cond_builder) -> StatusOr<XlaOp> {
630       auto p = values_inner[0];
631       return Lt(p, ScalarLike(p, matrix_dimension - 1));
632     };
633 
634     auto while_body_fn_inner =
635         [&](absl::Span<const XlaOp> values_inner,
636             XlaBuilder* inner_body_builder) -> StatusOr<std::vector<XlaOp>> {
637       auto while_cond_fn_innermost =
638           [&](absl::Span<const XlaOp> values_innermost,
639               XlaBuilder* innermost_cond_builder) -> StatusOr<XlaOp> {
640         auto q = values_innermost[1];
641         return Lt(q, ScalarLike(q, matrix_dimension));
642       };
643       auto while_body_fn_innermost =
644           [&](absl::Span<const XlaOp> values_innermost,
645               XlaBuilder* innermost_body_builder)
646           -> StatusOr<std::vector<XlaOp>> {
647         auto p = values_innermost[0];
648         auto q = values_innermost[1];
649 
650         SVDResult onesided_jacobi_update;
651         onesided_jacobi_update.u = values_innermost[2];
652         onesided_jacobi_update.v = values_innermost[3];
653         onesided_jacobi_update.d = values_innermost[4];
654 
655         auto eps = values_innermost[5];
656 
657         TF_ASSIGN_OR_RETURN(
658             onesided_jacobi_update,
659             OneSidedJacobiUpdate(onesided_jacobi_update, p, q, eps));
660 
661         std::vector<XlaOp> updated_values_innermost;
662         updated_values_innermost.reserve(values_innermost.size());
663 
664         updated_values_innermost.push_back(p);
665         updated_values_innermost.push_back(q + ScalarLike(q, 1));
666         updated_values_innermost.push_back(onesided_jacobi_update.u);
667         updated_values_innermost.push_back(onesided_jacobi_update.v);
668         updated_values_innermost.push_back(onesided_jacobi_update.d);
669         updated_values_innermost.push_back(eps);
670 
671         return updated_values_innermost;
672       };
673 
674       std::vector<XlaOp> values_innermost(6);
675       auto p = values_inner[0];
676       auto q = p + ScalarLike(p, 1);
677       values_innermost[0] = p;                // index p.
678       values_innermost[1] = q;                // index q.
679       values_innermost[2] = values_inner[1];  // u.
680       values_innermost[3] = values_inner[2];  // v.
681       values_innermost[4] = values_inner[3];  // d.
682       values_innermost[5] = values_inner[4];  // eps.
683       TF_ASSIGN_OR_RETURN(
684           values_innermost,
685           WhileLoopHelper(while_cond_fn_innermost, while_body_fn_innermost,
686                           values_innermost, absl::StrCat(name, "-Innermost"),
687                           inner_body_builder));
688 
689       std::vector<XlaOp> updated_values_inner;
690       updated_values_inner.reserve(values_inner.size());
691 
692       updated_values_inner.push_back(p + ScalarLike(p, 1));
693       updated_values_inner.push_back(values_innermost[2]);
694       updated_values_inner.push_back(values_innermost[3]);
695       updated_values_inner.push_back(values_innermost[4]);
696       updated_values_inner.push_back(values_innermost[5]);
697       return updated_values_inner;
698     };
699     // Indexes.
700     XlaOp k = values[0];
701 
702     std::vector<XlaOp> values_inner(5);
703     values_inner[0] = ScalarLike(k, 0);  // index p.
704     values_inner[1] = values[1];         // u.
705     values_inner[2] = values[2];         // v.
706     values_inner[3] = values[3];         // d.
707     values_inner[4] = values[4];         // eps.
708     TF_ASSIGN_OR_RETURN(
709         values_inner,
710         WhileLoopHelper(while_cond_fn_inner, while_body_fn_inner, values_inner,
711                         absl::StrCat(name, "-Inner"), body_builder));
712 
713     std::vector<XlaOp> updated_values;
714     updated_values.reserve(values_inner.size());
715 
716     updated_values.push_back(k + ScalarLike(k, 1));
717     updated_values.push_back(values_inner[1]);
718     updated_values.push_back(values_inner[2]);
719     updated_values.push_back(values_inner[3]);
720     updated_values.push_back(values_inner[4]);
721 
722     return updated_values;
723   };
724   std::vector<XlaOp> values;
725   TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
726                                               initial_values, name, builder));
727 
728   return values;
729 }
730 
731 // Sort the singular values in descending order, and make them non-negative by
732 // flipping the signs of negative diagonal values and transferring those signs
733 // to V. For numerical stability, also renormalize U and V. A rough sketch follows.
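// (Illustrative NumPy sketch only; d is the m x n iterated matrix with m >= n,
// and the actual implementation sorts with an XLA Sort keyed on the broadcast
// diagonal.)
//
//   n = d.shape[-1]
//   s = np.diag(d)
//   v = v * np.where(s >= 0, 1.0, -1.0)   # move signs of negative values to V
//   s = np.abs(s)
//   perm = np.argsort(-s)                 # descending order
//   s, v = s[perm], v[:, perm]
//   u[:, :n] = u[:, perm]                 # only the first n columns move
//   u = u / np.linalg.norm(u, axis=0)     # renormalize columns for stability
//   v = v / np.linalg.norm(v, axis=0)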
734 StatusOr<SVDResult> SortBySingularValuesAndPostProcessing(SVDResult result) {
735   XlaBuilder* builder = result.d.builder();
736   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(result.d));
737   const int64_t num_dims = shape.rank();
738   auto dimensions = shape.dimensions();
739   const int64_t m = ShapeUtil::GetDimension(shape, -2);
740   const int64_t n = ShapeUtil::GetDimension(shape, -1);
741 
742   std::vector<int64_t> broadcast_dims(num_dims - 1);
743   std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
744   broadcast_dims[num_dims - 2] = num_dims - 1;
745 
746   auto d = GetMatrixDiagonal(result.d);
747 
748   auto zeros = ZerosLike(d);
749   auto one = ScalarLike(d, 1.0);
750 
751   // Make all the singular values non-negative by transferring the signs of
752   // negative values to V.
753   auto sign = Select(Ge(d, zeros), zeros + one, zeros - one);
754   d = Select(Ge(d, zeros), d, -d);
755   result.v = Mul(result.v, sign, broadcast_dims);
756 
757   d = BroadcastInDim(d, dimensions, broadcast_dims);
758 
759   // As m >= n, only the first n column vectors need to be permuted; the
760   // remaining m - n vectors are appended after the sorting is done.
761   XlaOp sort_u_result =
762       Sort({d, SliceInMinorDims(result.u, {0, 0}, {m, n})},
763            CreateScalarGtComputation(
764                {shape.element_type(), shape.element_type()}, builder),
765            num_dims - 1);
766 
767   XlaOp sort_v_result =
768       Sort({SliceInMinorDims(d, {0, 0}, {n, n}), result.v},
769            CreateScalarGtComputation(
770                {shape.element_type(), shape.element_type()}, builder),
771            num_dims - 1);
772   result.d = GetMatrixDiagonal(GetTupleElement(sort_v_result, 0));
773 
774   result.v = GetTupleElement(sort_v_result, 1);
775   result.v = Mul(
776       result.v,
777       Rsqrt(Reduce(Square(result.v), ScalarLike(d, 0.0),
778                    CreateScalarAddComputation(shape.element_type(), builder),
779                    {num_dims - 2})),
780       broadcast_dims);
781 
782   // Append the remaining m - n column vectors.
783   result.u = ConcatInDim(builder,
784                          {GetTupleElement(sort_u_result, 1),
785                           SliceInMinorDims(result.u, {0, n}, {m, m})},
786                          num_dims - 1);
787   result.u = Mul(
788       result.u,
789       Rsqrt(Reduce(Square(result.u), ScalarLike(d, 0.0),
790                    CreateScalarAddComputation(shape.element_type(), builder),
791                    {num_dims - 2})),
792       broadcast_dims);
793 
794   return result;
795 }
796 
797 }  // namespace
798 
799 // def jacobi_svd(A):
800 //    U, D, V = house_bidiag(A)
801 //    m, n = D.shape
802 //    iter, max_iter = 0, 100
803 //    frobenius_norm = np.linalg.norm(D)
804 //    diag_norm = np.linalg.norm(np.diag(D))
805 //    off_diag_norm = np.sqrt(
806 //        frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
807 //    while off_diag_norm > 1e-6 * frobenius_norm and iter < max_iter:
808 //        iter += 1
809 //        for p in range(n - 1):
810 //            for q in range(p + 1, n):
811 //                rot_l, rot_r = jacobi_rot(D, p, q, 1e-6)
812 //                D[[p, q], :] = np.matmul(rot_l.T, D[[p, q], :])
813 //                D[:, [p, q]] = np.matmul(D[:, [p, q]], rot_r)
814 //                U[:, [p, q]] = np.matmul(U[:, [p, q]], rot_l)
815 //                V[:, [p, q]] = np.matmul(V[:, [p, q]], rot_r)
816 //        frobenius_norm = np.linalg.norm(D)
817 //        diag_norm = np.linalg.norm(np.diag(D))
818 //        off_diag_norm = np.sqrt(
819 //            frobenius_norm - diag_norm) * np.sqrt(frobenius_norm + diag_norm)
820 //
821 //    return U, np.diag(D), V
822 //
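// A typical call from client code might look like the following (illustrative
// only; `a` is assumed to be a floating-point operand already built on an
// XlaBuilder):
//
//   SVDResult result = SVD(a, /*max_iter=*/100, /*epsilon=*/1e-6,
//                          PrecisionConfig::HIGHEST);
//   // result.u: [..., M, M], result.v: [..., N, N], and result.d holds the
//   // min(M, N) singular values in descending order.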
823 SVDResult SVD(XlaOp a, int64_t max_iter, float epsilon,
824               PrecisionConfig::Precision precision) {
825   XlaBuilder* builder = a.builder();
826   auto return_error = [&](const Status& status) {
827     SVDResult result;
828     result.u = builder->ReportError(status);
829     result.v = builder->ReportError(status);
830     result.d = builder->ReportError(status);
831     return result;
832   };
833   auto shape_with_status = builder->GetShape(a);
834   if (!shape_with_status.status().ok()) {
835     return return_error(shape_with_status.status());
836   }
837   Shape a_shape = shape_with_status.ValueOrDie();
838   const int64_t num_dims = a_shape.rank();
839   const int64_t num_batch_dims = num_dims - 2;
840   std::vector<int64_t> batch_dims(num_batch_dims);
841   for (int i = 0; i < num_batch_dims; ++i) {
842     batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
843   }
844   int64_t m = ShapeUtil::GetDimension(a_shape, -2);
845   int64_t n = ShapeUtil::GetDimension(a_shape, -1);
846   bool maybe_transpose = m < n;
847 
848   if (maybe_transpose) {
849     a = TransposeInMinorDims(a);
850     std::swap(m, n);
851   }
852 
853   auto eps = ScalarLike(a, epsilon);
854 
855   auto svd_result_or = HouseHolderBidiagonalization(a, eps, precision);
856   if (!svd_result_or.ok()) {
857     return return_error(svd_result_or.status());
858   }
859   SVDResult svd_result = svd_result_or.ValueOrDie();
860 
861   auto output_with_status = WhileLoopFn(
862       {
863           Zero(builder, S32),  // k
864           svd_result.u,        // u
865           svd_result.v,        // v
866           svd_result.d,        // d
867           eps,                 // epsilon
868       },                       //
869       n,                       //
870       max_iter,                //
871       "CyclicOneSidedJacobi",  //
872       builder);
873   if (!output_with_status.status().ok()) {
874     return return_error(output_with_status.status());
875   }
876 
877   auto output = output_with_status.ValueOrDie();
878 
879   svd_result.u = output[1];
880   svd_result.v = output[2];
881   svd_result.d = output[3];
882 
883   svd_result_or = SortBySingularValuesAndPostProcessing(svd_result);
884   if (!svd_result_or.ok()) {
885     return return_error(svd_result_or.status());
886   }
887   svd_result = svd_result_or.ValueOrDie();
888 
889   if (maybe_transpose) {
890     std::swap(svd_result.u, svd_result.v);
891   }
892   return svd_result;
893 }
894 
895 }  // namespace xla
896