/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h"

#include <deque>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"

namespace xla {
namespace gpu {

namespace {

bool ShouldScheduleAsEarlyAsPossible(const HloInstruction& instr) {
  switch (instr.opcode()) {
    case HloOpcode::kAllReduceStart:
      return true;
    case HloOpcode::kCustomCall:
      return static_cast<const HloCustomCallInstruction&>(instr)
                 .custom_call_schedule() ==
             CustomCallSchedule::SCHEDULE_EARLIEST;
    default:
      return false;
  }
}

bool ShouldScheduleSuccessor(const HloInstruction& successor,
                             const HloPredicate& is_scheduled) {
  return ShouldScheduleAsEarlyAsPossible(successor) &&
         absl::c_all_of(successor.operands(), is_scheduled) &&
         absl::c_all_of(successor.control_predecessors(), is_scheduled);
}

bool ShouldScheduleAsLateAsPossible(const HloInstruction& instr) {
  switch (instr.opcode()) {
    case HloOpcode::kAllReduceDone:
      return true;
    case HloOpcode::kCustomCall:
      return static_cast<const HloCustomCallInstruction&>(instr)
                 .custom_call_schedule() == CustomCallSchedule::SCHEDULE_LATEST;
    default:
      return false;
  }
}

bool ShouldSchedulePredecessor(const HloInstruction& predecessor,
                               const HloPredicate& is_scheduled) {
  return ShouldScheduleAsLateAsPossible(predecessor) &&
         absl::c_all_of(predecessor.users(), is_scheduled) &&
         absl::c_all_of(predecessor.control_successors(), is_scheduled);
}

// Schedules certain ops as early or late as possible. This supports a
// custom-call use case, where a logical operation is lowered into two HLOs
// (e.g., PerformX and PerformXDone). We use this mechanism either to hide
// host latencies between the pair of custom-calls or to more accurately
// identify the def-use relationship of the two calls: typically, PerformX is
// scheduled right after all of its producers have been scheduled, and
// PerformXDone is scheduled right before its first consumer.
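//
// For illustration (hypothetical HLO; the instruction names are made up):
// given the input sequence
//   {p0, p1, start = all-reduce-start(p0), done = all-reduce-done(start),
//    mul = multiply(p0, p1), out = add(mul, done)},
// the first pass pulls `start` directly after `p0`, and the second pass
// pushes `done` down to just before its first consumer `out`, yielding
//   {p0, start, p1, mul, done, out},
// so the collective overlaps with the independent `mul`.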
HloInstructionSequence PostprocessorToScheduleAsEarlyOrLateAsPossible(
    const HloInstructionSequence& input) {
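  // First pass: emit instructions in input order, but pull each
  // "schedule-earliest" op (see ShouldScheduleAsEarlyAsPossible) up to the
  // point where its operands and control predecessors are all scheduled.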
  std::vector<HloInstruction*> earliest_scheduled;
  {
    absl::flat_hash_set<HloInstruction*> scheduled;
    auto is_scheduled = [&](const HloInstruction* instr) -> bool {
      return scheduled.contains(instr);
    };
    auto add_to_schedule = [&](HloInstruction* instr) {
      earliest_scheduled.push_back(instr);
      scheduled.insert(instr);
    };
    for (HloInstruction* instr : input.instructions()) {
      if (is_scheduled(instr)) {
        continue;
      }

      add_to_schedule(instr);

      // Schedule any successor that should be scheduled as early as possible
      // if all of its producers and control_predecessors have been scheduled.
      for (HloInstruction* user : instr->users()) {
        if (ShouldScheduleSuccessor(*user, is_scheduled)) {
          add_to_schedule(user);
        }
      }
      for (HloInstruction* successor : instr->control_successors()) {
        if (ShouldScheduleSuccessor(*successor, is_scheduled)) {
          add_to_schedule(successor);
        }
      }
    }
  }

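  // Second pass: walk the first-pass order in reverse, building the final
  // sequence back-to-front in a deque so each "schedule-latest" op (see
  // ShouldScheduleAsLateAsPossible) lands just before its first consumer.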
  std::deque<HloInstruction*> latest_scheduled;
  {
    absl::flat_hash_set<HloInstruction*> scheduled;
    auto is_scheduled = [&](const HloInstruction* instr) -> bool {
      return scheduled.contains(instr);
    };
    auto add_to_schedule = [&](HloInstruction* instr) {
      latest_scheduled.push_front(instr);
      scheduled.insert(instr);
    };
    for (auto it = earliest_scheduled.rbegin(); it != earliest_scheduled.rend();
         it++) {
      if (is_scheduled(*it)) {
        continue;
      }

      add_to_schedule(*it);

      // Schedule any predecessor that should be scheduled as late as possible
      // if all of its users and control_successors have been scheduled.
      for (HloInstruction* operand : (*it)->operands()) {
        if (ShouldSchedulePredecessor(*operand, is_scheduled)) {
          add_to_schedule(operand);
        }
      }
      for (HloInstruction* predecessor : (*it)->control_predecessors()) {
        if (ShouldSchedulePredecessor(*predecessor, is_scheduled)) {
          add_to_schedule(predecessor);
        }
      }
    }
  }

  HloInstructionSequence result;
  absl::c_for_each(latest_scheduled,
                   [&](HloInstruction* i) { result.push_back(i); });
  return result;
}

}  // namespace

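// Produces a sequential schedule for `module` by running the default memory
// scheduler on each computation and then applying the early/late
// postprocessor above, sizing buffers with ShapeUtil::ByteSizeOf at the
// given pointer size.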
StatusOr<HloSchedule> ScheduleGpuModule(const HloModule* module,
                                        int64_t pointer_size) {
  return ScheduleModule(
      module,
      [pointer_size](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
      },
      ComputationSchedulerToModuleScheduler(
          DefaultMemoryScheduler,
          PostprocessorToScheduleAsEarlyOrLateAsPossible));
}

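// A minimal usage sketch (hypothetical caller; assumes a non-const
// HloModule* and 64-bit device pointers):
//
//   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
//                       ScheduleGpuModule(module, /*pointer_size=*/8));
//   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
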
}  // namespace gpu
}  // namespace xla