1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_rewriter.h"
17
18 #include <numeric>
19
20 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
21 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
22
23 namespace xla {
24 namespace gpu {
25
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)26 StatusOr<bool> TriangularSolveRewriter::Run(
27 HloModule* module,
28 const absl::flat_hash_set<absl::string_view>& execution_threads) {
29 bool changed = false;
30 for (HloComputation* comp :
31 module->MakeNonfusionComputations(execution_threads)) {
32 std::vector<HloInstruction*> to_rewrite;
33 for (HloInstruction* instr : comp->instructions()) {
34 if (instr->opcode() == HloOpcode::kTriangularSolve) {
35 to_rewrite.push_back(instr);
36 }
37 }
38
39 for (HloInstruction* instr : to_rewrite) {
40 const Shape& b_shape = instr->operand(1)->shape();
41 int64_t batch_size = std::accumulate(
42 b_shape.dimensions().begin(), b_shape.dimensions().end() - 2,
43 int64_t{1}, [](int64_t a, int64_t b) { return a * b; });
44
45 // batch 1 triangular solves get 0 temp bytes, because unbatched trsm()
46 // doesn't require temp memory.
47 int64_t temp_bytes = batch_size == 1 ? 0 : 2 * sizeof(void*) * batch_size;
48 Shape new_shape = ShapeUtil::MakeTupleShape({
49 instr->shape(),
50 ShapeUtil::MakeShape(S8, {temp_bytes}),
51 });
52
53 HloInstruction* custom_call =
54 comp->AddInstruction(HloInstruction::CreateCustomCall(
55 new_shape, instr->operands(), kTriangularSolveCallTarget));
56 module->SetAndUniquifyInstrName(custom_call, "triangular-solve");
57 TF_RETURN_IF_ERROR(
58 custom_call->set_backend_config(instr->triangular_solve_options()));
59
60 // Preserve metadata from `instr`.
61 custom_call->set_metadata(instr->metadata());
62 custom_call->set_frontend_attributes(instr->frontend_attributes());
63
64 // Get the actual result out of the custom call's tuple.
65 TF_ASSIGN_OR_RETURN(HloInstruction * gte,
66 MakeGetTupleElementHlo(custom_call, 0));
67 TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, gte));
68 }
69 }
70 return changed;
71 }
72
73 } // namespace gpu
74 } // namespace xla
75