/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cholesky_expander.h"

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {

// The Cholesky–Banachiewicz algorithm. See
// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
// for a description.
//
// def cholesky_unblocked(a):
//   assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
//   n = a.shape[-2]
//   l = np.zeros_like(a)
//   for j in range(n):
//     mask = np.zeros_like(a)
//     mask[i, k] = 1 where i >= k and k == j
//     l_square = np.dot(l, l.T)
//     temp = a - l_square
//     l_jj = np.sqrt(temp[j, j])
//     l = temp / l_jj * mask + l
//   return l
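//
// For column j this computes, for each row i >= j,
//   l[i, j] = (a[i, j] - sum_k l[i, k] * l[j, k]) / l[j, j]
// with l[j, j] = sqrt(a[j, j] - sum_k l[j, k]**2), i.e. the standard column
// update of the factorization (the diagonal falls out of the same
// expression, since temp[j, j] / l_jj == l_jj).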
// Returns a (result, error) pair, where `error` is a [..., 1, 1] PRED per
// batch element that is true if the factorization produced a NaN (i.e. that
// element's input was not positive definite).
StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
    XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int ndims = a_shape.rank();
  const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
  std::vector<int64_t> error_dims(a_shape.dimensions().begin(),
                                  a_shape.dimensions().end());
  error_dims.back() = error_dims.at(ndims - 2) = 1;
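  // One PRED error flag per batch element: keep the batch dimensions and
  // collapse the two minor (matrix) dimensions to 1x1.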

  auto matrix_dims = a_shape.dimensions().subspan(
      /*pos=*/0,
      /*len=*/ndims);

  XlaOp l = ZerosLike(a);

  // Construct the for-loop body; each iteration computes one column of l.
  auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
                     XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
    auto body_a = loop_vars[0];
    auto body_l = loop_vars[1];
    auto seen_error = loop_vars[2];
    auto iota_row =
        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
    auto iota_col =
        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);

    auto mask_pred = Ge(iota_col, iota_row);
    mask_pred = And(mask_pred, Eq(iota_row, i));
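    // iota_col holds each entry's row index and iota_row its column index,
    // so mask_pred selects (row >= col) && (col == i): the entries of
    // column i at or below the diagonal.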
    auto mask_zeros =
        Zeros(body_builder,
              ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
    // L * L.T. This product does a lot of multiplying by zero (namely,
    // L[:, i:] = 0) and redundant computation, but it is still faster than
    // slicing.
    auto l_square =
        BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);

    // A - L*L.T
    l_square = body_a - l_square;
    auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
    if (ShapeUtil::ElementIsComplex(a_shape)) {
      auto sqrt = Sqrt(Real(l_ii));
      l_ii = Complex(sqrt, ZerosLike(sqrt));
      seen_error = Or(seen_error, IsNan(sqrt));
    } else {
      l_ii = Sqrt(l_ii);
      seen_error = Or(seen_error, IsNan(l_ii));
    }
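    // Sqrt produces NaN for a negative diagonal entry, so IsNan flags batch
    // elements that are not positive definite; Or accumulates the flag
    // across loop iterations.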
    // L = (A - L*L.T) / l_ii * mask + L
    body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;

    return std::vector<XlaOp>{body_a, body_l, seen_error};
  };

  TF_ASSIGN_OR_RETURN(
      auto cholesky_while,
      ForEachIndex(
          n, S32, body_fn,
          {a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
          "unblocked", builder));

  return std::make_pair(cholesky_while[1], cholesky_while[2]);
}

XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64_t block_size,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
          "Argument to Cholesky must have rank >= 2; shape was %s",
          a_shape.ToString());
    }

    const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
    if (n != ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "Argument to Cholesky must be batched square matrices; got shape %s",
          ShapeUtil::HumanString(a_shape));
    }

    if (block_size < 1) {
      return InvalidArgument(
          "block_size argument to Cholesky must be >= 1; got %d", block_size);
    }

    std::vector<int64_t> error_dims(a_shape.dimensions().begin(),
                                    a_shape.dimensions().end());
    error_dims.back() = error_dims.at(ndims - 2) = 1;
    std::vector<int64_t> error_dim_indices(ndims);
    absl::c_iota(error_dim_indices, 0);

    // Blocked left-looking Cholesky factorization.
    // Algorithm 1 from
    // Haidar, Azzam, et al. "High-performance Cholesky factorization for
    // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
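    //
    // A NumPy-style sketch of the loop below (illustrative only; it omits
    // batching, complex conjugation, and the error flag, and assumes a
    // scipy-style solve_triangular):
    //
    //   l = np.zeros_like(a)
    //   for i in range(0, n, block_size):
    //     k = min(block_size, n - i)
    //     panel = a[i:, i:i+k] - np.dot(l[i:, :i], l[i:i+k, :i].T)
    //     l[i:i+k, i:i+k] = cholesky_unblocked(panel[:k, :k])
    //     if i + k < n:
    //       # solve x @ l[i:i+k, i:i+k].T = panel[k:, :k] for x
    //       l[i+k:, i:i+k] = solve_triangular(
    //           l[i:i+k, i:i+k], panel[k:, :k].T, lower=True).T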
    XlaOp l = ZerosLike(a);
    XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
    for (int64_t i = 0; i < n; i += block_size) {
      int64_t k = std::min(block_size, n - i);
      auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
      if (i > 0) {
        // TODO(phawkins): consider implementing SYRK for the diagonal part of
        // the panel.
        // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
        auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
        auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
        auto delta =
            BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
        panel = panel - delta;
      }

      // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
      auto x = SliceInMinorDims(panel, {0, 0}, {k, k});
      XlaOp factorized;
      // TODO(b/167896062): A failure in one element of a batch shouldn't fail
      // other elements.
      XlaOp factorized_error;
      if (k == 1) {
        if (ShapeUtil::ElementIsComplex(a_shape)) {
          auto sqrt = Sqrt(Real(x));
          factorized = Complex(sqrt, ZerosLike(sqrt));
          factorized_error = IsNan(sqrt);
        } else {
          factorized = Sqrt(x);
          factorized_error = IsNan(factorized);
        }
      } else {
        TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
        std::tie(factorized, factorized_error) = tile_output;
      }
      seen_error = Or(seen_error, factorized_error);
      l = UpdateSliceInMinorDims(l, factorized, {i, i});

      if (i + k < n) {
        // l[i+k:, i:i+k] =
        //     trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
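        //
        // With /*left_side=*/false and ADJOINT, TriangularSolve solves
        // x * l11^H = a21 for x, i.e. x = a21 * inv(l11)^H, where l11 is the
        // factorized diagonal block and a21 the rows of the panel below it.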
        auto update = TriangularSolve(
            factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}),
            /*left_side=*/false,
            /*lower=*/true,
            /*unit_diagonal=*/false,
            /*transpose_a=*/TriangularSolveOptions::ADJOINT);
        l = UpdateSliceInMinorDims(l, update, {i + k, i});
      }
    }
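    // Broadcast each batch element's [..., 1, 1] error flag over its whole
    // matrix and replace failed factorizations with NaNs.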
    return Select(
        BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
        FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
  });
}

bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCholesky;
}

StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const CholeskyOptions& options = instruction->cholesky_options();
  const std::string name = absl::StrFormat(
      "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
      options.lower() ? "lower" : "upper");

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
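  // emplace() returns the existing entry when this shape/triangle
  // combination has already been expanded, so each signature is built once.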
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API: what it does is build an
    // HloModuleProto protocol buffer, which we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
                            /*block_size=*/128,
                            /*precision=*/PrecisionConfig::HIGHEST);
    // Transpose back for the upper-triangular case; XlaBuilder::Build() uses
    // the last op added to the builder as the computation's root.
    l = MaybeTransposeInMinorDims(l, !options.lower());

    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla