1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/inlined_vector.h"
20 #include "tensorflow/compiler/xla/service/while_util.h"
21 #include "tensorflow/compiler/xla/util.h"
22
23 namespace xla {
24 namespace {
25
26 // Replaces all uses of old_instr with new_instr except the use at
27 // `while_body_root` (which must be a tuple instruction) at index `tuple_index`.
28 // This utility helps us replace an instruction in the while body with a
29 // constant while still keeping it trivially loop invariant.
ReplaceUsesWhileKeepingLoopInvariance(HloInstruction * old_instr,HloInstruction * new_instr,HloInstruction * while_body_root,int64_t tuple_index)30 Status ReplaceUsesWhileKeepingLoopInvariance(HloInstruction* old_instr,
31 HloInstruction* new_instr,
32 HloInstruction* while_body_root,
33 int64_t tuple_index) {
34 CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple);
35
36 std::vector<HloInstruction*> users;
37 users.reserve(old_instr->user_count());
38 absl::c_copy(old_instr->users(), std::back_inserter(users));
39
40 for (auto* user : users) {
41 for (int64_t i = 0, e = user->operand_count(); i < e; i++) {
42 if (user->operand(i) == old_instr &&
43 !(user == while_body_root && i == tuple_index)) {
44 TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr));
45 }
46 }
47 }
48
49 return OkStatus();
50 }
51
// Materializes a copy of `instruction` inside `computation` and returns it.
//
// Supported inputs are a kConstant, which is cloned with a ".sunk" name
// suffix, and a kBroadcast. For a broadcast, the operand is cloned
// recursively first and the broadcast is then rebuilt on top of that copy.
// Any other opcode is a caller bug and aborts via LOG(FATAL).
HloInstruction* CloneHelper(const HloInstruction* instruction,
                            HloComputation* computation) {
  switch (instruction->opcode()) {
    case HloOpcode::kConstant:
      return computation->AddInstruction(
          instruction->Clone(/*suffix=*/".sunk"));
    case HloOpcode::kBroadcast: {
      // Clone (and add) the broadcast operand before the broadcast itself.
      HloInstruction* cloned_operand =
          CloneHelper(instruction->operand(0), computation);
      return computation->AddInstruction(instruction->CloneWithNewOperands(
          instruction->shape(), {cloned_operand}));
    }
    default:
      LOG(FATAL) << "Unexpected instruction.";
  }
}
64
65 } // namespace
66
TrySinkingConstantsIntoWhileLoop(HloInstruction * while_instr)67 StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop(
68 HloInstruction* while_instr) {
69 HloComputation* while_cond = while_instr->while_condition();
70 HloComputation* while_body = while_instr->while_body();
71
72 const HloInstruction& init_value = *while_instr->operand(0);
73 if (init_value.opcode() != HloOpcode::kTuple) {
74 return false;
75 }
76
77 bool changed = false;
78
79 absl::flat_hash_map<int64_t, absl::InlinedVector<HloInstruction*, 1>>
80 conditional_gte_index_to_insts =
81 WhileUtil::GetGTEsMapForWhileConditional(*while_cond);
82 std::vector<HloInstruction*> invariant_body_gtes =
83 WhileUtil::GetInvariantGTEsForWhileBody(*while_body);
84
85 for (HloInstruction* invariant_body_gte : invariant_body_gtes) {
86 int64_t index = invariant_body_gte->tuple_index();
87 const HloInstruction& invariant_value = *init_value.operand(index);
88
89 // Original value should be a constant or broadcast of constant.
90 if (invariant_value.opcode() != HloOpcode::kConstant &&
91 (!sink_broadcast_of_constants_ ||
92 invariant_value.opcode() != HloOpcode::kBroadcast ||
93 invariant_value.operand(0)->opcode() != HloOpcode::kConstant)) {
94 continue;
95 }
96
97 // Sink into the while_body.
98 // Should have at least one user that's not while_body_root.
99 if (invariant_body_gte->user_count() > 1) {
100 HloInstruction* constant_instr =
101 CloneHelper(&invariant_value, while_body);
102 TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance(
103 invariant_body_gte, constant_instr, while_body->root_instruction(),
104 index));
105 changed = true;
106 }
107
108 // Check if there is a corresponding GTE in while_conditional.
109 auto it = conditional_gte_index_to_insts.find(index);
110 if (it == conditional_gte_index_to_insts.end()) {
111 continue;
112 }
113
114 for (HloInstruction* invariant_cond_gte : it->second) {
115 // Should have at least one user.
116 if (invariant_cond_gte->user_count() > 0) {
117 HloInstruction* constant_instr =
118 CloneHelper(&invariant_value, while_cond);
119 TF_RETURN_IF_ERROR(
120 invariant_cond_gte->ReplaceAllUsesWith(constant_instr));
121 changed = true;
122 }
123 }
124 }
125
126 return changed;
127 }
128
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)129 StatusOr<bool> WhileLoopConstantSinking::Run(
130 HloModule* module,
131 const absl::flat_hash_set<absl::string_view>& execution_threads) {
132 VLOG(2) << "HLO module before WhileLoopConstantSinking:";
133 XLA_VLOG_LINES(2, module->ToString());
134
135 bool changed = false;
136 std::vector<HloInstruction*> while_instrs;
137 for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
138 // Right now we don't particularly care about optimizing while-of-while
139 // patterns. If/When we do, we'll want to visit the outer while (while_0)
140 // before we visit the inner while (while_1):
141 //
142 // while_1_body(state) {
143 // val = gte(state, 0) // Loop invariant
144 // use(val)
145 // }
146 //
147 // while_0_body(state) {
148 // val = gte(state, 0) // Loop invariant
149 // while_1 = while(init=tuple(val, ...), body=while_1_body, ...)
150 // ...
151 // }
152 //
153 // main {
154 // while_0 = while(init=(constant, ...), body=while_0_body, ...)
155 // }
156 //
157 // This will let us sink the constant into the outer while first and then
158 // into the inner while in a single run of this pass.
159 absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
160 [](const HloInstruction* instr) {
161 return instr->opcode() == HloOpcode::kWhile;
162 });
163 }
164
165 for (HloInstruction* while_instr : while_instrs) {
166 TF_ASSIGN_OR_RETURN(bool result,
167 TrySinkingConstantsIntoWhileLoop(while_instr));
168 changed |= result;
169 }
170
171 if (changed) {
172 VLOG(2) << "HLO module after WhileLoopConstantSinking:";
173 XLA_VLOG_LINES(2, module->ToString());
174 } else {
175 VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking";
176 }
177
178 return changed;
179 }
180 } // namespace xla
181