/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"

#include <deque>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"

namespace xla {
namespace gpu {

namespace {

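// Returns true if `instr` should be scheduled as early as possible, i.e. it
// is the start of an asynchronous pair (kAllReduceStart) or a custom call
// explicitly annotated with CustomCallSchedule::SCHEDULE_EARLIEST.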
bool ShouldScheduleAsEarlyAsPossible(const HloInstruction& instr) {
  switch (instr.opcode()) {
    case HloOpcode::kAllReduceStart:
      return true;
    case HloOpcode::kCustomCall:
      return static_cast<const HloCustomCallInstruction&>(instr)
                 .custom_call_schedule() ==
             CustomCallSchedule::SCHEDULE_EARLIEST;
    default:
      return false;
  }
}

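// Returns true if `successor` wants to be scheduled as early as possible and
// every one of its operands and control predecessors has already been
// scheduled, so it can be emitted immediately.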
bool ShouldScheduleSuccessor(const HloInstruction& successor,
                             const HloPredicate& is_scheduled) {
  return ShouldScheduleAsEarlyAsPossible(successor) &&
         absl::c_all_of(successor.operands(), is_scheduled) &&
         absl::c_all_of(successor.control_predecessors(), is_scheduled);
}

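// Returns true if `instr` should be scheduled as late as possible, i.e. it is
// the end of an asynchronous pair (kAllReduceDone) or a custom call
// explicitly annotated with CustomCallSchedule::SCHEDULE_LATEST.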
bool ShouldScheduleAsLateAsPossible(const HloInstruction& instr) {
  switch (instr.opcode()) {
    case HloOpcode::kAllReduceDone:
      return true;
    case HloOpcode::kCustomCall:
      return static_cast<const HloCustomCallInstruction&>(instr)
                 .custom_call_schedule() == CustomCallSchedule::SCHEDULE_LATEST;
    default:
      return false;
  }
}

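// Returns true if `predecessor` wants to be scheduled as late as possible and
// every one of its users and control successors has already been scheduled,
// so emitting it now places it right before its first consumer.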
bool ShouldSchedulePredecessor(const HloInstruction& predecessor,
                               const HloPredicate& is_scheduled) {
  return ShouldScheduleAsLateAsPossible(predecessor) &&
         absl::c_all_of(predecessor.users(), is_scheduled) &&
         absl::c_all_of(predecessor.control_successors(), is_scheduled);
}

// Schedules certain ops as early or late as possible. This supports a
// custom-call use case, where a logical operation is lowered into two HLOs
// (e.g., PerformX and PerformXDone). We use this mechanism either to hide
// host latencies between the pair of custom calls, or to identify the
// def-use relationship of the two calls more accurately (typically, PerformX
// is scheduled right after all of its producers have been scheduled, and
// PerformXDone is scheduled right before its first consumer).
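//
// As an illustrative sketch (hypothetical HLO, not taken from a real module):
// given the input sequence [x, y, ars, ard] with ars = all-reduce-start(x),
// ard = all-reduce-done(ars), and y = f(x), the first pass below hoists `ars`
// to right after its producer `x`, so the result [x, ars, y, ard] lets the
// collective overlap with the computation of `y`.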
HloInstructionSequence PostprocessorToScheduleAsEarlyOrLateAsPossible(
    const HloInstructionSequence& input) {
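  // Pass 1: walk `input` in order; after emitting each instruction, greedily
  // pull forward any successor that wants to run as early as possible and
  // whose dependencies are now all satisfied.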
  std::vector<HloInstruction*> earliest_scheduled;
  {
    absl::flat_hash_set<HloInstruction*> scheduled;
    auto is_scheduled = [&](const HloInstruction* instr) -> bool {
      return scheduled.contains(instr);
    };
    auto add_to_schedule = [&](HloInstruction* instr) {
      earliest_scheduled.push_back(instr);
      scheduled.insert(instr);
    };
    for (HloInstruction* instr : input.instructions()) {
      if (is_scheduled(instr)) {
        continue;
      }

      add_to_schedule(instr);

      // Schedule any successor that should be scheduled as early as possible
      // if all of its producers and control_predecessors have been scheduled.
      for (HloInstruction* user : instr->users()) {
        if (ShouldScheduleSuccessor(*user, is_scheduled)) {
          add_to_schedule(user);
        }
      }
      for (HloInstruction* successor : instr->control_successors()) {
        if (ShouldScheduleSuccessor(*successor, is_scheduled)) {
          add_to_schedule(successor);
        }
      }
    }
  }

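  // Pass 2: walk the pass-1 sequence in reverse; after emitting each
  // instruction, greedily push back any predecessor that wants to run as late
  // as possible and whose users are now all satisfied. Using push_front keeps
  // the final order forward.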
  std::deque<HloInstruction*> latest_scheduled;
  {
    absl::flat_hash_set<HloInstruction*> scheduled;
    auto is_scheduled = [&](const HloInstruction* instr) -> bool {
      return scheduled.contains(instr);
    };
    auto add_to_schedule = [&](HloInstruction* instr) {
      latest_scheduled.push_front(instr);
      scheduled.insert(instr);
    };
    for (auto it = earliest_scheduled.rbegin(); it != earliest_scheduled.rend();
         it++) {
      if (is_scheduled(*it)) {
        continue;
      }

      add_to_schedule(*it);

      // Schedule any predecessor that should be scheduled as late as possible
      // if all of its users and control_successors have been scheduled.
      for (HloInstruction* operand : (*it)->operands()) {
        if (ShouldSchedulePredecessor(*operand, is_scheduled)) {
          add_to_schedule(operand);
        }
      }
      for (HloInstruction* predecessor : (*it)->control_predecessors()) {
        if (ShouldSchedulePredecessor(*predecessor, is_scheduled)) {
          add_to_schedule(predecessor);
        }
      }
    }
  }

  HloInstructionSequence result;
  absl::c_for_each(latest_scheduled,
                   [&](HloInstruction* i) { result.push_back(i); });
  return result;
}

}  // end namespace

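// Returns a sequential schedule for every computation in `module`, sizing
// buffers with ShapeUtil::ByteSizeOf (where `pointer_size` is the byte size
// assumed for pointers in tuple and opaque shapes) and post-processing each
// computation's sequence with PostprocessorToScheduleAsEarlyOrLateAsPossible.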
StatusOr<HloSchedule> ScheduleGpuModule(const HloModule* module,
                                        int64_t pointer_size) {
  return ScheduleModule(
      module,
      [pointer_size](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
      },
      ComputationSchedulerToModuleScheduler(
          DefaultMemoryScheduler,
          PostprocessorToScheduleAsEarlyOrLateAsPossible));
}
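//
// A minimal usage sketch (hypothetical caller, mirroring how a GPU compiler
// pass would consume this; `module` must be non-const in the caller):
//
//   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
//                       ScheduleGpuModule(module, pointer_size));
//   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));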

}  // namespace gpu
}  // namespace xla