1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/loops.h"
24 #include "tensorflow/compiler/xla/client/lib/math.h"
25 #include "tensorflow/compiler/xla/client/lib/matrix.h"
26 #include "tensorflow/compiler/xla/client/lib/slicing.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/primitive_util.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35
36 namespace xla {
37
38 // The Cholesky–Banachiewicz algorithm. See
39 // https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
40 // for a description.
41 //
42 // def cholesky_unblocked(a):
43 // assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1]
44 // n = a.shape[-2]
45 // l = np.zeros_like(a)
46 // for j in xrange(n):
47 // mask = np.zeros_like(a)
48 // mask[i, k] == 1 when i >= k and k == j
49 // l_square = np.dot(l, l_t)
50 // temp = a - l_square
51 // l[..., j, j] = temp(j, j)
52 // l = temp / l[..., j, j) * mask + l
53 // return l
54 // Returns a (result, error) pair.
CholeskyUnblocked(XlaOp a,PrecisionConfig::Precision precision)55 StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
56 XlaOp a, PrecisionConfig::Precision precision) {
57 XlaBuilder* builder = a.builder();
58 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
59 const int ndims = a_shape.rank();
60 const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
61 std::vector<int64_t> error_dims(a_shape.dimensions().begin(),
62 a_shape.dimensions().end());
63 error_dims.back() = error_dims.at(ndims - 2) = 1;
64
65 auto major_dims = a_shape.dimensions().subspan(
66 /*pos=*/0,
67 /*len=*/ndims - 2);
68
69 auto matrix_dims = a_shape.dimensions().subspan(
70 /*pos=*/0,
71 /*len=*/ndims);
72
73 XlaOp l = ZerosLike(a);
74
75 // Construct the for loop body to iterate over rows.
76 auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
77 XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
78 std::vector<int64_t> row_shape_dims(major_dims.begin(), major_dims.end());
79 std::vector<int64_t> col_shape_dims(major_dims.begin(), major_dims.end());
80 auto body_a = loop_vars[0];
81 auto body_l = loop_vars[1];
82 auto seen_error = loop_vars[2];
83 auto iota_row =
84 Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
85 auto iota_col =
86 Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);
87
88 auto mask_pred = Ge(iota_col, iota_row);
89 mask_pred = And(mask_pred, Eq(iota_row, i));
90 auto mask_zeros =
91 Zeros(body_builder,
92 ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
93 // L * L.T, This matrix has of a lot of multiplying with zero
94 // (namely, L[:, j:] = 0) and redundant computation, but it is faster
95 // than slice.
96 auto l_square =
97 BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);
98
99 // A - L*L.T
100 l_square = body_a - l_square;
101 auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
102 if (ShapeUtil::ElementIsComplex(a_shape)) {
103 auto sqrt = Sqrt(Real(l_ii));
104 l_ii = Complex(sqrt, ZerosLike(sqrt));
105 seen_error = Or(seen_error, IsNan(sqrt));
106 } else {
107 l_ii = Sqrt(l_ii);
108 seen_error = Or(seen_error, IsNan(l_ii));
109 }
110 // L = (A - L*L.T) / l_ii * mask + L
111 body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
112
113 return std::vector<XlaOp>{body_a, body_l, seen_error};
114 };
115
116 TF_ASSIGN_OR_RETURN(
117 auto cholesky_while,
118 ForEachIndex(
119 n, S32, body_fn,
120 {a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
121 "unblocked", builder));
122
123 return std::make_pair(cholesky_while[1], cholesky_while[2]);
124 }
125
BuildCholesky(XlaOp a,int64_t block_size,PrecisionConfig::Precision precision)126 XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64_t block_size,
127 PrecisionConfig::Precision precision) {
128 XlaBuilder* builder = a.builder();
129 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
130 TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
131 const int ndims = a_shape.rank();
132 if (ndims < 2) {
133 return InvalidArgument(
134 "Argument to Cholesky must have rank >= 2; shape was %s",
135 a_shape.ToString());
136 }
137
138 const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
139 if (n != ShapeUtil::GetDimension(a_shape, -2)) {
140 return InvalidArgument(
141 "Argument to Cholesky must be batched square matrices; got shape %s",
142 ShapeUtil::HumanString(a_shape));
143 }
144
145 if (block_size < 1) {
146 return InvalidArgument(
147 "block_size argument to Cholesky must be >= 1; got %d", block_size);
148 }
149
150 std::vector<int64_t> error_dims(a_shape.dimensions().begin(),
151 a_shape.dimensions().end());
152 error_dims.back() = error_dims.at(ndims - 2) = 1;
153 std::vector<int64_t> error_dim_indices(ndims);
154 absl::c_iota(error_dim_indices, 0);
155
156 // Blocked left-looking Cholesky factorization.
157 // Algorithm 1 from
158 // Haidar, Azzam, et al. "High-performance Cholesky factorization for
159 // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
160 XlaOp l = ZerosLike(a);
161 XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
162 for (int64_t i = 0; i < n; i += block_size) {
163 int64_t k = std::min(block_size, n - i);
164 auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
165 if (i > 0) {
166 // TODO(phawkins): consider implementing SYRK for the diagonal part of
167 // the panel.
168 // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
169 auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
170 auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
171 auto delta =
172 BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
173 panel = panel - delta;
174 }
175
176 // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
177 auto x = SliceInMinorDims(panel, {0, 0}, {k, k});
178 XlaOp factorized;
179 // TODO(b/167896062): A failure in one element of a batch shouldn't fail
180 // other elements.
181 XlaOp factorized_error;
182 if (k == 1) {
183 if (ShapeUtil::ElementIsComplex(a_shape)) {
184 auto sqrt = Sqrt(Real(x));
185 factorized = Complex(sqrt, ZerosLike(sqrt));
186 factorized_error = IsNan(sqrt);
187 } else {
188 factorized = Sqrt(x);
189 factorized_error = IsNan(factorized);
190 }
191 } else {
192 TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
193 std::tie(factorized, factorized_error) = tile_output;
194 }
195 seen_error = Or(seen_error, factorized_error);
196 l = UpdateSliceInMinorDims(l, factorized, {i, i});
197
198 if (i + k < n) {
199 // l[i+k:, i:i+k] =
200 // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
201 auto update = TriangularSolve(
202 factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}),
203 /*left_side=*/false,
204 /*lower=*/true,
205 /*unit_diagonal=*/false,
206 /*transpose_a=*/TriangularSolveOptions::ADJOINT);
207 l = UpdateSliceInMinorDims(l, update, {i + k, i});
208 }
209 }
210 return Select(
211 BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
212 FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
213 });
214 }
215
InstructionMatchesPattern(HloInstruction * instruction)216 bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
217 return instruction->opcode() == HloOpcode::kCholesky;
218 }
219
ExpandInstruction(HloInstruction * instruction)220 StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
221 HloInstruction* instruction) {
222 const CholeskyOptions& options = instruction->cholesky_options();
223 const std::string name = absl::StrFormat(
224 "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
225 options.lower() ? "lower" : "upper");
226
227 HloModule* module = instruction->parent()->parent();
228
229 HloComputation*& computation =
230 computation_cache_.emplace(name, nullptr).first->second;
231 if (!computation) {
232 // Builds a new expansion.
233 //
234 // TODO(b/62327888): We do something unusual here: we build the computation
235 // using the XlaBuilder API, which is nominally an XLA client API. We do
236 // this because the external APIs for building complicated computations
237 // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
238 // out, XlaBuilder isn't really a client API—what it does is build a
239 // HloModuleProto protocol buffer, that we can then deserialize and clone
240 // into our HloModule. Ideally we would avoid the protocol buffer step;
241 // that is left as an exercise for future work.
242 XlaBuilder builder(name);
243 XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
244 XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
245 /*block_size=*/128,
246 /*precision=*/PrecisionConfig::HIGHEST);
247 MaybeTransposeInMinorDims(l, !options.lower());
248
249 TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
250
251 TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
252 xla_computation.GetProgramShape());
253 HloModuleConfig config(program_shape);
254 TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
255 xla_computation.proto(), config));
256 HloCloneContext context(module);
257 computation =
258 module->DeepCloneComputation(new_module->entry_computation(), &context);
259 }
260
261 return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
262 instruction->shape(), instruction->operands(), computation));
263 }
264
265 } // namespace xla
266