/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_schedule.h"

#include <cstdint>
#include <ostream>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {

/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
    const HloModule* module, const HloScheduleProto& proto) {
  absl::flat_hash_map<int64_t, const HloComputation*> id_to_computation;
  for (const HloComputation* computation : module->computations()) {
    id_to_computation[computation->unique_id()] = computation;
  }

  HloSchedule schedule(module);
  for (const auto& id_sequence : proto.sequences()) {
    int64_t computation_id = id_sequence.first;

    auto comp_it = id_to_computation.find(computation_id);
    // Computation could have been removed if unused, so skip if not found.
    if (comp_it == id_to_computation.end()) {
      continue;
    }
    const HloComputation* computation = comp_it->second;

    absl::flat_hash_map<int64_t, HloInstruction*> id_to_instruction;
    for (HloInstruction* instruction : computation->instructions()) {
      id_to_instruction[instruction->unique_id()] = instruction;
    }

    HloInstructionSequence& sequence =
        schedule.GetOrCreateSequence(computation);
    for (const int64_t instruction_id : id_sequence.second.instruction_ids()) {
      auto instr_it = id_to_instruction.find(instruction_id);
      TF_RET_CHECK(instr_it != id_to_instruction.end())
          << "No instruction exists in HLO computation " << computation->name()
          << " with id " << instruction_id;
      sequence.push_back(instr_it->second);
    }
  }
  TF_RETURN_IF_ERROR(schedule.Verify());
  return std::move(schedule);
}

StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
  TF_RETURN_IF_ERROR(Verify());
  HloScheduleProto proto;
  for (const auto& id_sequence : sequences_) {
    int64_t computation_id = id_sequence.first;
    const HloInstructionSequence& sequence = id_sequence.second;
    HloScheduleProto::InstructionSequence& proto_sequence =
        (*proto.mutable_sequences())[computation_id];
    proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
    for (const int64_t id : sequence.ids()) {
      proto_sequence.add_instruction_ids(id);
    }
  }
  return std::move(proto);
}

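// Usage sketch (illustrative, not part of this file): a schedule can be
// round-tripped through its proto form. Assumes `module` points to an
// HloModule that already has a schedule.
//
//   TF_ASSIGN_OR_RETURN(HloScheduleProto proto,
//                       module->schedule().ToProto());
//   TF_ASSIGN_OR_RETURN(HloSchedule restored,
//                       HloSchedule::CreateFromProto(module, proto));
//   TF_RETURN_IF_ERROR(restored.Verify());
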
void HloSchedule::set_sequence(const HloComputation* computation,
                               absl::Span<HloInstruction* const> sequence) {
  set_sequence(computation, HloInstructionSequence(sequence));
}

void HloSchedule::set_sequence(const HloComputation* computation,
                               HloInstructionSequence sequence) {
  CHECK(computation->parent() == module_);
  sequences_[computation->unique_id()] = std::move(sequence);
  execution_threads_[computation->unique_id()] =
      std::string(computation->execution_thread());
}

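// Example (illustrative only): one straightforward valid sequence for a
// computation is its post order, since operands are visited before their
// users.
//
//   schedule.set_sequence(computation,
//                         computation->MakeInstructionPostOrder());
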
HloInstructionSequence& HloSchedule::GetOrCreateSequence(
    const HloComputation* computation) {
  auto it = sequences_.find(computation->unique_id());
  if (it == sequences_.end()) {
    // No sequence found for computation. Create and return an empty one.
    CHECK(computation->parent() == module_);
    execution_threads_[computation->unique_id()] =
        std::string(computation->execution_thread());
    return sequences_[computation->unique_id()];
  } else {
    return it->second;
  }
}

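// Example (illustrative only): a scheduler that emits instructions one at a
// time can build the sequence incrementally instead of setting it in one
// shot. `scheduled_order` is a hypothetical container of HloInstruction*.
//
//   HloInstructionSequence& seq = schedule.GetOrCreateSequence(computation);
//   for (HloInstruction* instruction : scheduled_order) {
//     seq.push_back(instruction);
//   }
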
const HloInstructionSequence& HloSchedule::sequence(
    const HloComputation* computation) const {
  return sequences_.at(computation->unique_id());
}

Status HloSchedule::UpdateComputationSchedule(
    const HloComputation* computation) {
  // Map from unique ID to HloInstruction pointer for instructions in the
  // computation.
  absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
  for (HloInstruction* instruction : computation->instructions()) {
    InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
  }

  // Set of all HloInstructions in the schedule.
  absl::flat_hash_set<int> ids_in_schedule;
  for (int id : sequences_.at(computation->unique_id()).ids()) {
    InsertOrDie(&ids_in_schedule, id);
  }

  // Map from HloInstruction X to the newly added instructions (instructions
  // which are in the computation but not in the schedule) which use X. If an
  // instruction is not in the map, then it has no users which are newly added
  // instructions.
  absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
      new_instruction_uses;

  // For each newly added instruction, this is the count of the instruction's
  // operands that have not yet been scheduled. When this value reaches zero,
  // then the instruction may be placed in the schedule.
  absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count;

  // Create a worklist of newly added instructions which are ready to be added
  // to the schedule. Initialize the worklist with those that have zero
  // operands.
  std::queue<HloInstruction*> worklist;

  for (HloInstruction* instruction : computation->instructions()) {
    if (!ids_in_schedule.contains(instruction->unique_id())) {
      // This is a newly added instruction which is not in the schedule.
      if (instruction->operands().empty()) {
        worklist.push(instruction);
      } else {
        for (const HloInstruction* operand : instruction->operands()) {
          new_instruction_uses[operand].push_back(instruction);
        }
        unscheduled_operand_count[instruction] = instruction->operand_count();
      }
    }
  }

  // Update the schedule with the newly added instructions, and remove any
  // instructions no longer in the graph.
  HloInstructionSequence new_sequence;

  // Lambda which schedules all instructions on the worklist.
  auto schedule_worklist = [&]() {
    while (!worklist.empty()) {
      HloInstruction* instruction = worklist.front();
      worklist.pop();
      new_sequence.push_back(instruction);
      std::vector<HloInstruction*>* new_users =
          tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
      if (new_users != nullptr) {
        // This just-scheduled instruction has users which are newly added to
        // the module. Update their unscheduled operand counts, and push any
        // user that has become ready onto the worklist.
        for (HloInstruction* new_user : *new_users) {
          unscheduled_operand_count.at(new_user)--;
          CHECK_GE(unscheduled_operand_count.at(new_user), 0);
          if (unscheduled_operand_count.at(new_user) == 0) {
            worklist.push(new_user);
          }
        }
      }
    }
  };

  schedule_worklist();
  for (int id : sequences_.at(computation->unique_id()).ids()) {
    auto it = id_to_instruction.find(id);
    if (it == id_to_instruction.end()) {
      // This instruction in the schedule is no longer in the module. Do not
      // add it to the new schedule.
      continue;
    }
    worklist.push(it->second);
    schedule_worklist();
  }

  set_sequence(computation, std::move(new_sequence));
  return OkStatus();
}

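// Worked example (illustrative only): suppose the old sequence for a
// computation is {p0, p1, add}, and a pass replaces `add` with
// `mul = Multiply(p0, p1)` followed by `neg = Negate(mul)`, removing `add`.
// Neither `mul` nor `neg` has zero operands, so the initial worklist is
// empty. Replaying the old sequence keeps `p0` and `p1`, drops `add` (it is
// no longer in the module), and once `p1` is emitted `mul` becomes ready,
// which in turn makes `neg` ready. The resulting sequence is
// {p0, p1, mul, neg}.
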
Status HloSchedule::Update(
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // The schedule must contain a sequence for every non-fusion computation in
  // the module for the specified threads, but can have sequences for
  // computations which no longer exist (these are removed).
  std::vector<HloComputation*> nonfusion_computations =
      module_->MakeNonfusionComputations(execution_threads);
  for (const HloComputation* computation : nonfusion_computations) {
    TF_RET_CHECK(sequences_.contains(computation->unique_id()))
        << "Computation " << computation->name() << " not in HloSchedule.";
  }
  auto sum_of_sequences_for_threads = [&]() -> int64_t {
    if (execution_threads.empty()) {
      return sequences_.size();
    }
    int64_t sequences_num_for_threads = 0;
    for (const auto& [thread_name, sequence_num] :
         num_sequences_by_execution_thread()) {
      sequences_num_for_threads +=
          execution_threads.contains(thread_name) ? sequence_num : 0;
    }
    return sequences_num_for_threads;
  };
  int64_t sequence_sum = sum_of_sequences_for_threads();
  if (sequence_sum > nonfusion_computations.size()) {
    // Schedule contains some computations which have been removed from the
    // HloModule. Remove them from the schedule as well.
    absl::flat_hash_set<int64_t> nonfusion_computations_ids;
    for (const HloComputation* computation : nonfusion_computations) {
      nonfusion_computations_ids.insert(computation->unique_id());
    }
    for (auto it = sequences_.begin(); it != sequences_.end();) {
      std::string sequence_thread_name = tensorflow::gtl::FindWithDefault(
          execution_threads_, it->first, HloInstruction::kMainExecutionThread);
      bool is_thread_included =
          execution_threads.empty() ||
          execution_threads.contains(sequence_thread_name);
      if (!nonfusion_computations_ids.contains(it->first) &&
          is_thread_included) {
        execution_threads_.erase(it->first);
        sequences_.erase(it++);
      } else {
        ++it;
      }
    }
  }
  sequence_sum = sum_of_sequences_for_threads();
  CHECK_EQ(sequence_sum, nonfusion_computations.size());

  for (const HloComputation* computation : nonfusion_computations) {
    TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
  }

  TF_RETURN_IF_ERROR(Verify());
  return OkStatus();
}

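// Usage sketch (illustrative, not part of this file): after a pass adds or
// removes instructions in an already-scheduled module, the schedule can be
// patched up and re-verified rather than recomputed from scratch. Assumes
// `module` is a scheduled HloModule*.
//
//   HloSchedule schedule = module->schedule();
//   // ... run a pass that mutates the module ...
//   TF_RETURN_IF_ERROR(schedule.Update(/*execution_threads=*/{}));
//   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
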
absl::flat_hash_map<std::string, int64_t>
HloSchedule::num_sequences_by_execution_thread() const {
  absl::flat_hash_map<std::string, int64_t> sequence_num_by_execution_threads;
  for (const auto& id_sequence_item : sequences_) {
    ++sequence_num_by_execution_threads[tensorflow::gtl::FindWithDefault(
        execution_threads_, id_sequence_item.first,
        HloInstruction::kMainExecutionThread)];
  }
  return sequence_num_by_execution_threads;
}

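// For example (illustrative only), a schedule holding three computations on
// the default (main) execution thread and one on a thread named "host" would
// return a count of 3 keyed by the main thread's name and 1 keyed by "host".
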
Status HloSchedule::Verify() const {
  VLOG(2) << "VerifySchedule()";
  XLA_VLOG_LINES(2, ToString());

  // Verify that the schedule contains exactly the same set of non-fusion
  // computations as the module currently does, for each thread that has a
  // schedule.
  absl::flat_hash_map<std::string, int64_t> sequence_num_by_execution_threads =
      num_sequences_by_execution_thread();
  for (const auto& [thread_name, sequence_size] :
       sequence_num_by_execution_threads) {
    std::vector<HloComputation*> nonfusion_computations =
        module_->MakeNonfusionComputations({thread_name});
    TF_RET_CHECK(nonfusion_computations.size() == sequence_size)
        << "For thread " << thread_name << ", schedule has " << sequence_size
        << " sequences, but module has " << nonfusion_computations.size()
        << " non-fusion computations for thread " << thread_name;
    for (const HloComputation* computation : nonfusion_computations) {
      TF_RET_CHECK(sequences_.contains(computation->unique_id()))
          << "Computation " << computation->name()
          << " missing from HLO schedule.";
    }

    // For each computation, verify that the set of instructions is the same
    // and that each dependency and control edge is honored.
    for (const HloComputation* computation : nonfusion_computations) {
      absl::flat_hash_map<const HloInstruction*, int> instruction_position;
      int pos = 0;
      for (const HloInstruction* instruction :
           sequence(computation).instructions()) {
        TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
            << "Instruction " << instruction->name()
            << " appears more than once in the schedule";
        pos++;
      }

      TF_RET_CHECK(instruction_position.size() ==
                   computation->instruction_count())
          << "Schedule for computation " << computation->name() << " has "
          << instruction_position.size() << " instructions, expected "
          << computation->instruction_count();
      for (const HloInstruction* instruction : computation->instructions()) {
        TF_RET_CHECK(instruction_position.contains(instruction))
            << "Instruction " << instruction->name() << " is not in schedule";
      }

      for (const HloInstruction* instruction : computation->instructions()) {
        for (const HloInstruction* operand : instruction->operands()) {
          TF_RET_CHECK(instruction_position.at(operand) <
                       instruction_position.at(instruction))
              << "Instruction " << instruction->name()
              << " is not scheduled after its operand " << operand->name();
        }

        for (const HloInstruction* pred :
             instruction->control_predecessors()) {
          TF_RET_CHECK(instruction_position.at(pred) <
                       instruction_position.at(instruction))
              << "Instruction " << instruction->name()
              << " is not scheduled after its control predecessor "
              << pred->name();
        }
      }
    }
  }

  return OkStatus();
}

namespace {

// Returns the computation in the given module with the given unique ID.
// Returns nullptr if no such computation exists.
const HloComputation* IdToComputation(const HloModule* module, int64_t id) {
  for (const HloComputation* computation : module->computations()) {
    if (computation->unique_id() == id) {
      return computation;
    }
  }
  return nullptr;
}

}  // namespace

std::string HloSchedule::ToString() const {
  std::vector<std::string> pieces;

  pieces.push_back("HloSchedule");
  for (const auto& id_sequence : sequences_) {
    const HloComputation* computation =
        IdToComputation(module_, id_sequence.first);
    if (computation == nullptr) {
      // The computation is not in the module and may have been deleted, so it
      // is not safe to dereference any HLO pointers. Just use the HLO unique
      // ids stored in this object.
      pieces.push_back(
          absl::StrFormat("computation with id %d (no longer in HLO module):",
                          id_sequence.first));
      for (int id : id_sequence.second.ids()) {
        pieces.push_back(absl::StrCat(" ", id));
      }
    } else {
      pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
      for (const HloInstruction* instruction :
           id_sequence.second.instructions()) {
        pieces.push_back(absl::StrCat(" ", instruction->name()));
      }
    }
  }
  return absl::StrJoin(pieces, "\n");
}

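// The textual form produced by ToString() (and operator<< below) looks
// roughly like the following; names and ids are illustrative only:
//
//   HloSchedule
//   computation entry:
//    param.0
//    param.1
//    add.2
//   computation with id 7 (no longer in HLO module):
//    12
//    13
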
std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
  out << schedule.ToString();
  return out;
}

}  // namespace xla