xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_schedule.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
17 
18 #include <cstdint>
19 #include <ostream>
20 #include <queue>
21 #include <string>
22 #include <tuple>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/strings/str_format.h"
29 #include "absl/strings/str_join.h"
30 #include "absl/strings/string_view.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/gtl/map_util.h"
37 
38 namespace xla {
39 
CreateFromProto(const HloModule * module,const HloScheduleProto & proto)40 /* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
41     const HloModule* module, const HloScheduleProto& proto) {
42   absl::flat_hash_map<int64_t, const HloComputation*> id_to_computation;
43   for (const HloComputation* computation : module->computations()) {
44     id_to_computation[computation->unique_id()] = computation;
45   }
46 
47   HloSchedule schedule(module);
48   for (const auto& id_sequence : proto.sequences()) {
49     int64_t computation_id = id_sequence.first;
50 
51     auto comp_it = id_to_computation.find(computation_id);
52     // Computation could have been removed if unused, so
53     // skip if not found.
54     if (comp_it == id_to_computation.end()) {
55       continue;
56     }
57     const HloComputation* computation = comp_it->second;
58 
59     absl::flat_hash_map<int64_t, HloInstruction*> id_to_instruction;
60     for (HloInstruction* instruction : computation->instructions()) {
61       id_to_instruction[instruction->unique_id()] = instruction;
62     }
63 
64     HloInstructionSequence& sequence =
65         schedule.GetOrCreateSequence(computation);
66     for (const int64_t instruction_id : id_sequence.second.instruction_ids()) {
67       auto instr_it = id_to_instruction.find(instruction_id);
68       TF_RET_CHECK(instr_it != id_to_instruction.end())
69           << "No instruction exists in HLO computation " << computation->name()
70           << " with id " << instruction_id;
71       sequence.push_back(instr_it->second);
72     }
73   }
74   TF_RETURN_IF_ERROR(schedule.Verify());
75   return std::move(schedule);
76 }
77 
ToProto() const78 StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
79   TF_RETURN_IF_ERROR(Verify());
80   HloScheduleProto proto;
81   for (const auto& id_sequence : sequences_) {
82     int64_t computation_id = id_sequence.first;
83     const HloInstructionSequence& sequence = id_sequence.second;
84     HloScheduleProto::InstructionSequence& proto_sequence =
85         (*proto.mutable_sequences())[computation_id];
86     proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
87     for (const int64_t id : sequence.ids()) {
88       proto_sequence.add_instruction_ids(id);
89     }
90   }
91   return std::move(proto);
92 }
93 
set_sequence(const HloComputation * computation,absl::Span<HloInstruction * const> sequence)94 void HloSchedule::set_sequence(const HloComputation* computation,
95                                absl::Span<HloInstruction* const> sequence) {
96   set_sequence(computation, HloInstructionSequence(sequence));
97 }
98 
set_sequence(const HloComputation * computation,HloInstructionSequence sequence)99 void HloSchedule::set_sequence(const HloComputation* computation,
100                                HloInstructionSequence sequence) {
101   CHECK(computation->parent() == module_);
102   sequences_[computation->unique_id()] = std::move(sequence);
103   execution_threads_[computation->unique_id()] =
104       std::string(computation->execution_thread());
105 }
106 
GetOrCreateSequence(const HloComputation * computation)107 HloInstructionSequence& HloSchedule::GetOrCreateSequence(
108     const HloComputation* computation) {
109   auto it = sequences_.find(computation->unique_id());
110   if (it == sequences_.end()) {
111     // No sequence found for computation. Create and return an empty one.
112     CHECK(computation->parent() == module_);
113     execution_threads_[computation->unique_id()] =
114         std::string(computation->execution_thread());
115     return sequences_[computation->unique_id()];
116   } else {
117     return it->second;
118   }
119 }
120 
sequence(const HloComputation * computation) const121 const HloInstructionSequence& HloSchedule::sequence(
122     const HloComputation* computation) const {
123   return sequences_.at(computation->unique_id());
124 }
125 
UpdateComputationSchedule(const HloComputation * computation)126 Status HloSchedule::UpdateComputationSchedule(
127     const HloComputation* computation) {
128   // Map from unique ID to HloInstruction pointer for instructions in the
129   // computation.
130   absl::flat_hash_map<int, HloInstruction*> id_to_instruction;
131   for (HloInstruction* instruction : computation->instructions()) {
132     InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
133   }
134 
135   // Set of all HloInstructions in the schedule.
136   absl::flat_hash_set<int> ids_in_schedule;
137   for (int id : sequences_.at(computation->unique_id()).ids()) {
138     InsertOrDie(&ids_in_schedule, id);
139   }
140 
141   // Map from HloInstruction X to newly added instructions (instruction is in
142   // computation, but not in schedule) which use X. If an instruction is not in
143   // the map, then it has no users which are newly added instructions.
144   absl::flat_hash_map<const HloInstruction*, std::vector<HloInstruction*>>
145       new_instruction_uses;
146 
147   // For each newly added instruction, this is the count of the instruction's
148   // operands that have not yet been scheduled. When this value reaches zero,
149   // then the instruction may be placed in the schedule.
150   absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count;
151 
152   // Create a worklist of newly added instructions which are ready to be added
153   // to the schedule. Initialize worklist with those that have zero operands.
154   std::queue<HloInstruction*> worklist;
155 
156   for (HloInstruction* instruction : computation->instructions()) {
157     if (!ids_in_schedule.contains(instruction->unique_id())) {
158       // This is a newly added instruction which is not in the schedule.
159       if (instruction->operands().empty()) {
160         worklist.push(instruction);
161       } else {
162         for (const HloInstruction* operand : instruction->operands()) {
163           new_instruction_uses[operand].push_back(instruction);
164         }
165         unscheduled_operand_count[instruction] = instruction->operand_count();
166       }
167     }
168   }
169 
170   // Update the schedule with the newly added instructions, and remove any
171   // instructions no longer in the graph.
172   HloInstructionSequence new_sequence;
173 
174   // Lambda which schedules all instructions on the worklist.
175   auto schedule_worklist = [&]() {
176     while (!worklist.empty()) {
177       HloInstruction* instruction = worklist.front();
178       worklist.pop();
179       new_sequence.push_back(instruction);
180       std::vector<HloInstruction*>* new_users =
181           tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
182       if (new_users != nullptr) {
183         // This just-scheduled instruction has users which are newly added to
184         // the module. Update the number of unscheduled operands and push the
185         // newly added instruction to the worklist if it is ready to
186         // schedule.
187         for (HloInstruction* new_user : *new_users) {
188           unscheduled_operand_count.at(new_user)--;
189           CHECK_GE(unscheduled_operand_count.at(new_user), 0);
190           if (unscheduled_operand_count.at(new_user) == 0) {
191             worklist.push(new_user);
192           }
193         }
194       }
195     }
196   };
197 
198   schedule_worklist();
199   for (int id : sequences_.at(computation->unique_id()).ids()) {
200     auto it = id_to_instruction.find(id);
201     if (it == id_to_instruction.end()) {
202       // This instruction in the schedule is no longer in the module. Do not add
203       // it to the new schedule.
204       continue;
205     }
206     worklist.push(it->second);
207     schedule_worklist();
208   }
209 
210   set_sequence(computation, std::move(new_sequence));
211   return OkStatus();
212 }
213 
Update(const absl::flat_hash_set<absl::string_view> & execution_threads)214 Status HloSchedule::Update(
215     const absl::flat_hash_set<absl::string_view>& execution_threads) {
216   // The schedule must contain a sequence for every non-fusion computation in
217   // the module for the specified threads, but can have sequences for
218   // computations which no longer exist (these are removed).
219   std::vector<HloComputation*> nonfusion_computations =
220       module_->MakeNonfusionComputations(execution_threads);
221   for (const HloComputation* computation : nonfusion_computations) {
222     TF_RET_CHECK(sequences_.contains(computation->unique_id()))
223         << "Computation " << computation->name() << " not in HloSchedule.";
224   }
225   auto sum_of_sequences_for_threads = [&]() -> int64_t {
226     if (execution_threads.empty()) {
227       return sequences_.size();
228     }
229     int64_t sequences_num_for_threads = 0;
230     for (const auto& [thread_name, sequence_num] :
231          num_sequences_by_execution_thread()) {
232       sequences_num_for_threads +=
233           execution_threads.contains(thread_name) ? sequence_num : 0;
234     }
235     return sequences_num_for_threads;
236   };
237   int64_t sequence_sum = sum_of_sequences_for_threads();
238   if (sequence_sum > nonfusion_computations.size()) {
239     // Schedule contains some computations which have been removed from the
240     // HloModule. Remove them from the schedule as well.
241     absl::flat_hash_set<int64_t> nonfusion_computations_ids;
242     for (const HloComputation* computation : nonfusion_computations) {
243       nonfusion_computations_ids.insert(computation->unique_id());
244     }
245     for (auto it = sequences_.begin(); it != sequences_.end();) {
246       std::string sequence_thread_name = tensorflow::gtl::FindWithDefault(
247           execution_threads_, it->first, HloInstruction::kMainExecutionThread);
248       bool is_thread_included =
249           execution_threads.empty() ||
250           execution_threads.contains(sequence_thread_name);
251       if (!nonfusion_computations_ids.contains(it->first) &&
252           is_thread_included) {
253         execution_threads_.erase(it->first);
254         sequences_.erase(it++);
255       } else {
256         ++it;
257       }
258     }
259   }
260   sequence_sum = sum_of_sequences_for_threads();
261   CHECK_EQ(sequence_sum, nonfusion_computations.size());
262 
263   for (const HloComputation* computation : nonfusion_computations) {
264     TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
265   }
266 
267   TF_RETURN_IF_ERROR(Verify());
268   return OkStatus();
269 }
270 
271 absl::flat_hash_map<std::string, int64_t>
num_sequences_by_execution_thread() const272 HloSchedule::num_sequences_by_execution_thread() const {
273   absl::flat_hash_map<std::string, int64_t> sequence_num_by_execution_threads;
274   for (const auto& id_sequence_item : sequences_) {
275     ++sequence_num_by_execution_threads[tensorflow::gtl::FindWithDefault(
276         execution_threads_, id_sequence_item.first,
277         HloInstruction::kMainExecutionThread)];
278   }
279   return sequence_num_by_execution_threads;
280 }
281 
Verify() const282 Status HloSchedule::Verify() const {
283   VLOG(2) << "VerifySchedule()";
284   XLA_VLOG_LINES(2, ToString());
285 
286   // Verify schedule contains exactly the same set of non-fusion computations as
287   // module currently does for each thread that has schedule.
288   absl::flat_hash_map<std::string, int64_t> sequence_num_by_execution_threads =
289       num_sequences_by_execution_thread();
290   for (const auto& [thread_name, sequence_size] :
291        sequence_num_by_execution_threads) {
292     std::vector<HloComputation*> nonfusion_computations =
293         module_->MakeNonfusionComputations({thread_name});
294     TF_RET_CHECK(nonfusion_computations.size() == sequence_size)
295         << "For thread " << thread_name << ", schedule has " << sequence_size
296         << " sequences, but module has " << nonfusion_computations.size()
297         << " non-fusion computations for thread " << thread_name;
298     for (const HloComputation* computation : nonfusion_computations) {
299       TF_RET_CHECK(sequences_.contains(computation->unique_id()))
300           << "Computation " << computation->name()
301           << " missing from HLO schedule.";
302     }
303 
304     // For each computation verify the set of instructions is the same and
305     // that each dependency and control edge is honored.
306     for (const HloComputation* computation : nonfusion_computations) {
307       absl::flat_hash_map<const HloInstruction*, int> instruction_position;
308       int pos = 0;
309       for (const HloInstruction* instruction :
310            sequence(computation).instructions()) {
311         TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
312             << "Instruction " << instruction->name()
313             << " appears more than once in the schedule";
314         pos++;
315       }
316 
317       TF_RET_CHECK(instruction_position.size() ==
318                    computation->instruction_count())
319           << "Schedule for computation " << computation->name() << " has "
320           << instruction_position.size() << " instructions, expected "
321           << computation->instruction_count();
322       for (const HloInstruction* instruction : computation->instructions()) {
323         TF_RET_CHECK(instruction_position.contains(instruction))
324             << "Instruction " << instruction->name() << " is not in schedule";
325       }
326 
327       for (const HloInstruction* instruction : computation->instructions()) {
328         for (const HloInstruction* operand : instruction->operands()) {
329           TF_RET_CHECK(instruction_position.at(operand) <
330                        instruction_position.at(instruction))
331               << "Instruction " << instruction->name()
332               << " is not scheduled after its operand " << operand->name();
333         }
334 
335         for (const HloInstruction* pred : instruction->control_predecessors()) {
336           TF_RET_CHECK(instruction_position.at(pred) <
337                        instruction_position.at(instruction))
338               << "Instruction " << instruction->name()
339               << " is not scheduled after its control predecessor "
340               << pred->name();
341         }
342       }
343     }
344   }
345 
346   return OkStatus();
347 }
348 
349 namespace {
350 
351 // Returns the computation in the given module with the given unique ID. Returns
352 // nullptr if no such computation exists.
IdToComputation(const HloModule * module,int64_t id)353 const HloComputation* IdToComputation(const HloModule* module, int64_t id) {
354   for (const HloComputation* computation : module->computations()) {
355     if (computation->unique_id() == id) {
356       return computation;
357     }
358   }
359   return nullptr;
360 }
361 
362 }  // namespace
363 
ToString() const364 std::string HloSchedule::ToString() const {
365   std::vector<std::string> pieces;
366 
367   pieces.push_back("HloSchedule");
368   for (const auto& id_sequence : sequences_) {
369     const HloComputation* computation =
370         IdToComputation(module_, id_sequence.first);
371     if (computation == nullptr) {
372       // The computation is not in the module and may have been deleted so it is
373       // not safe to dereference any HLO pointers. Just use the HLO unique ids
374       // stored in this object.
375       pieces.push_back(
376           absl::StrFormat("computation with id %d (no longer in HLO module):",
377                           id_sequence.first));
378       for (int id : id_sequence.second.ids()) {
379         pieces.push_back(absl::StrCat("  ", id));
380       }
381     } else {
382       pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
383       for (const HloInstruction* instruction :
384            id_sequence.second.instructions()) {
385         pieces.push_back(absl::StrCat("  ", instruction->name()));
386       }
387     }
388   }
389   return absl::StrJoin(pieces, "\n");
390 }
391 
operator <<(std::ostream & out,const HloSchedule & schedule)392 std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
393   out << schedule.ToString();
394   return out;
395 }
396 
397 }  // namespace xla
398