1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
17
18 #include <stdint.h>
19
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/compiler/xla/debug_options_flags.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
30 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
31 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
35 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
36 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38
39 namespace xla {
40 namespace gpu {
41
42 namespace {
43
IsProfitableOperand(HloInstruction * instr)44 bool IsProfitableOperand(HloInstruction* instr) {
45 // kConstant instruction will not have memory reads, so it won't be a profit
46 // source. Skip them.
47 if (instr->opcode() == HloOpcode::kConstant &&
48 ShapeUtil::IsEffectiveScalar(instr->shape())) {
49 return false;
50 }
51 return true;
52 }
53
LegalToFuse(HloInstruction * instr1,HloInstruction * instr2,FusionInfoCache * fusion_info_cache)54 FusionDecision LegalToFuse(HloInstruction* instr1, HloInstruction* instr2,
55 FusionInfoCache* fusion_info_cache) {
56 // If we're fusing fusions only do it if the fusion kind matches. Loop fusions
57 // merge into bigger loop fusions and input (reduce) fusions become fusions
58 // with multiple reduce outputs. We could fuse reduce and loop fusions
59 // together too (the result being an input fusion) if we find cases where this
60 // improves things. Also disable fusing standalone input-fusible reduces into
61 // loop fusions.
62 CHECK(instr1->opcode() == HloOpcode::kFusion);
63 if ((instr2->opcode() == HloOpcode::kFusion &&
64 instr1->fusion_kind() != instr2->fusion_kind()) ||
65 (IsReductionFromOrToContiguousDimensions(*instr2) &&
66 instr1->IsLoopFusion())) {
67 return "Can't merge fusions of two different types";
68 }
69 // The emitter only supports in-place DUS for fusions with a single DUS at the
70 // root. Don't sibling fuse DUS for now.
71 // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
72 // share the input and output buffers and add support to the emitter.
73 if (instr1->fused_expression_root()->opcode() ==
74 HloOpcode::kDynamicUpdateSlice ||
75 (instr2->opcode() == HloOpcode::kFusion &&
76 instr2->fused_expression_root()->opcode() ==
77 HloOpcode::kDynamicUpdateSlice)) {
78 return "Can't fuse multiple DUSs";
79 }
80
81 // Do this check last, as it may be expensive.
82 return FusionFitsInBudget(*instr1, *instr2,
83 /*is_consumer_producer_fusion=*/false,
84 fusion_info_cache);
85 }
86
87 // We prefer multi-output fusions over other fusions over unfused ops, because
88 // we want to preserve fusion opportunities if possible.
FusionPriority(const HloInstruction * instr)89 int FusionPriority(const HloInstruction* instr) {
90 if (instr->IsMultiOutputFusion()) {
91 return 2;
92 }
93 if (instr->opcode() == HloOpcode::kFusion) {
94 return 1;
95 }
96 return 0;
97 }
98
SelectPreferredFusionCandidate(const std::vector<HloInstruction * > candidates)99 HloInstruction* SelectPreferredFusionCandidate(
100 const std::vector<HloInstruction*> candidates) {
101 if (candidates.empty()) {
102 return nullptr;
103 }
104 return *std::max_element(
105 candidates.begin(), candidates.end(),
106 [](const HloInstruction* a, const HloInstruction* b) {
107 return FusionPriority(a) < FusionPriority(b);
108 });
109 }
110
GetProducerConsumerMultiOutputFusionCandidates(const HloInstruction * producer,const HloReachabilityMap & reachability,FusionInfoCache * fusion_info_cache)111 std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
112 const HloInstruction* producer, const HloReachabilityMap& reachability,
113 FusionInfoCache* fusion_info_cache) {
114 std::vector<HloInstruction*> fusion_candidates;
115 // If there is only one user, and it is not a multi-output fusion node, this
116 // fusion possibility was already considered and rejected by the FusionMerger
117 // pass. No need to try again!
118 if (producer->user_count() == 1 &&
119 !producer->users()[0]->IsMultiOutputFusion()) {
120 return fusion_candidates;
121 }
122 for (HloInstruction* consumer : producer->users()) {
123 VLOG(3) << "Looking at producer " << producer->name()
124 << " and its consumer " << consumer->name();
125 if (!IsFusibleAsMultiOutputFusionRoot(*consumer)) {
126 VLOG(3) << "Consumer " << consumer->name()
127 << " is not eligible as multi-output fusion root.";
128 continue;
129 }
130 if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) {
131 VLOG(3) << producer->name() << " and " << consumer->name()
132 << " are not fusible.";
133 continue;
134 }
135 // Do not fuse a producer if the other operands of the fusion are
136 // reachable from the producer, this would create a cycle.
137 auto operand_reachable_from_producer = [&](const HloInstruction* operand) {
138 // If a get-tuple-element instruction is not in the reachability
139 // map, it has been created by fusion in this pass. Simply move
140 // on to its operand, which is in the reachability map.
141 if (!reachability.IsPresent(operand) &&
142 operand->opcode() == HloOpcode::kGetTupleElement) {
143 operand = operand->operand(0);
144 }
145 CHECK(reachability.IsPresent(operand) && reachability.IsPresent(producer))
146 << "Reachability map is incomplete. This should never "
147 "happen.";
148 return producer != operand && reachability.IsReachable(producer, operand);
149 };
150 if (absl::c_any_of(consumer->operands(), operand_reachable_from_producer)) {
151 VLOG(3) << producer->name() << " would introduce a cycle when fused.";
152 continue;
153 }
154 if (!FusionFitsInBudget(*producer, *consumer,
155 /*is_consumer_producer_fusion=*/false,
156 fusion_info_cache)) {
157 VLOG(3) << producer->name() << " and " << consumer->name()
158 << " would be too large of a fusion.";
159 continue;
160 }
161 // Make sure the emitter can codegen the fusion op efficiently. We currently
162 // can have exponential time/memory requirements for emitting certain fusion
163 // ops, in which case we don't want to fuse.
164 // TODO(b/119692968): Remove this once fixed in the emitter.
165 if (FusedIrEmitter::IsFusedIrEmitterInefficient(*consumer, *producer)) {
166 VLOG(3) << "Fusion of " << producer->name() << " into "
167 << consumer->name()
168 << " would result in overly large code duplication.";
169 continue;
170 }
171 fusion_candidates.push_back(consumer);
172 }
173 return fusion_candidates;
174 }
175
IsSiblingFusionCandidate(const HloInstruction * instr)176 bool IsSiblingFusionCandidate(const HloInstruction* instr) {
177 if (instr->IsDead()) {
178 return false;
179 }
180 if (!IsFusibleAsMultiOutputFusionRoot(*instr)) {
181 return false;
182 }
183 // Check if the users of multioutput fusion is not a get-tuple-element.
184 // If this is the case, we bail out because the transformation assumes
185 // the users are get-tuple-element.
186 if (instr->IsMultiOutputFusion()) {
187 for (auto user : instr->users()) {
188 if (user->opcode() != HloOpcode::kGetTupleElement) {
189 return false;
190 }
191 }
192 }
193 return true;
194 }
195
196 } // namespace
197
RecomputeReachability()198 void GpuMultiOutputFusion::RecomputeReachability() {
199 reachability_ = HloReachabilityMap::Build(computation_);
200 }
201
FuseSiblings(HloInstruction * parent,FusionInfoCache * fusion_info_cache)202 bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent,
203 FusionInfoCache* fusion_info_cache) {
204 if (!IsProfitableOperand(parent)) {
205 VLOG(3) << "Operand " << parent->ToShortString() << " is not profitable";
206 return false;
207 }
208 bool changed = false;
209 std::vector<HloInstruction*> siblings = parent->users();
210 // Sort the siblings such that multi-output fusion ops occur first, followed
211 // by fusion ops, followed by unfused ops.
212 absl::c_stable_sort(siblings,
213 [](const HloInstruction* a, const HloInstruction* b) {
214 return FusionPriority(a) > FusionPriority(b);
215 });
216 for (auto i = siblings.begin(); i != siblings.end(); ++i) {
217 VLOG(3) << "Considering " << (*i)->name();
218 if ((*i)->opcode() != HloOpcode::kFusion || !IsSiblingFusionCandidate(*i)) {
219 continue;
220 }
221 for (auto j = i + 1; j != siblings.end();) {
222 VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
223 if (!IsSiblingFusionCandidate(*j) || reachability_->IsConnected(*i, *j) ||
224 !ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
225 !LegalToFuse(*i, *j, fusion_info_cache)) {
226 ++j;
227 continue;
228 }
229 if (!ConsumeFuel(name(), [&] {
230 return absl::StrFormat("Not fusing siblings %s and %s.",
231 (*i)->name(), (*j)->name());
232 })) {
233 ++j;
234 continue;
235 }
236 VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
237 fusion_info_cache->Invalidate(*i);
238 fusion_info_cache->Invalidate(*j);
239 HloInstruction* remaining = *i;
240 HloInstruction* fused = *j;
241
242 DumpFusionState(*remaining,
243 absl::StrCat("About to fuse producer |", fused->name(),
244 "| into consumer |", remaining->name(),
245 "| inside GPU multi-output fusion"),
246 /*producer=*/fused);
247
248 if (fused->opcode() == HloOpcode::kFusion) {
249 remaining->MergeFusionInstructionIntoMultiOutput(fused);
250 } else {
251 remaining->FuseInstructionIntoMultiOutput(fused);
252 CHECK_EQ(0, fused->user_count());
253 TF_CHECK_OK(computation_->RemoveInstruction(fused));
254 }
255 DumpFusionState(*remaining,
256 absl::StrCat("Fused into consumer |", remaining->name(),
257 "| inside GPU multi-output fusion"));
258 changed = true;
259 siblings.erase(j);
260 RecomputeReachability();
261 }
262 }
263 return changed;
264 }
265
DoMultiOutputFusion()266 StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
267 bool changed = false;
268 RecomputeReachability();
269 std::vector<HloInstruction*> defs_before_uses =
270 computation_->MakeInstructionPostOrder();
271
272 FusionInfoCache fusion_info_cache;
273 while (!defs_before_uses.empty()) {
274 // Traverse the HLO in uses-before-defs order by removing instruction from
275 // the back of the vector.
276 HloInstruction* producer = defs_before_uses.back();
277
278 // Copy on purpose: to use after removing the producer.
279 std::string producer_name = producer->name();
280 defs_before_uses.pop_back();
281 // Never multi-output fuse constants. To the extent that we want to fuse
282 // constants, that should be handled by the regular fusion pass.
283 if (producer->opcode() == HloOpcode::kConstant) {
284 VLOG(3) << producer->name() << " is a constant.";
285 continue;
286 }
287 // First, fuse the consumer ops of the current op, which are siblings.
288 if (FuseSiblings(/*parent=*/producer, &fusion_info_cache)) {
289 changed = true;
290 }
291 // Second, perform producer-consumer multi-output fusion. This order will
292 // ensure that all get-tuple-element ops inserted as a by-product of
293 // multi-output fusion will occur before the current op in the order of
294 // traversal, and hence, not get into the way of subsequent fusion attempts.
295 const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
296 producer, *reachability_, &fusion_info_cache);
297 auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
298 if (consumer_for_fusion == nullptr) {
299 continue;
300 }
301 if (!ConsumeFuel(name(), [&] {
302 return absl::StrFormat("Not fusing %s and %s.", producer->name(),
303 consumer_for_fusion->name());
304 })) {
305 continue;
306 }
307 changed = true;
308 fusion_info_cache.Invalidate(producer);
309 fusion_info_cache.Invalidate(consumer_for_fusion);
310
311 if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
312 VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
313 << consumer_for_fusion->name();
314 DumpFusionState(
315 *consumer_for_fusion,
316 absl::StrCat("About to fuse producer |", producer_name,
317 "| into consumer |", consumer_for_fusion->name(),
318 "| inside GPU multi-output fusion"),
319 /*producer=*/producer);
320 if (producer->opcode() == HloOpcode::kFusion) {
321 consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer);
322 } else {
323 consumer_for_fusion->FuseInstructionIntoMultiOutput(producer);
324 CHECK_EQ(0, producer->user_count());
325 TF_CHECK_OK(computation_->RemoveInstruction(producer));
326 }
327
328 DumpFusionState(
329 *consumer_for_fusion,
330 absl::StrCat("Fusing producer |", producer_name, "| into consumer |",
331 consumer_for_fusion->name(),
332 "| inside GPU multi-output fusion"));
333 RecomputeReachability();
334 continue;
335 }
336 HloInstruction* input_fusion =
337 computation_->AddInstruction(HloInstruction::CreateFusion(
338 consumer_for_fusion->shape(),
339 ChooseFusionKind(*producer, *consumer_for_fusion),
340 consumer_for_fusion));
341 VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
342 << consumer_for_fusion->name() << " into " << input_fusion->name();
343 DumpFusionState(
344 *input_fusion,
345 absl::StrCat("About to fuse |", producer_name, "| into consumer |",
346 input_fusion->name(), "| inside GPU multi-output fusion"),
347 /*producer=*/input_fusion);
348 TF_CHECK_OK(
349 computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
350 if (producer->opcode() == HloOpcode::kFusion) {
351 input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
352 } else {
353 input_fusion->FuseInstructionIntoMultiOutput(producer);
354 CHECK_EQ(0, producer->user_count());
355 TF_CHECK_OK(computation_->RemoveInstruction(producer));
356 }
357
358 DumpFusionState(
359 *input_fusion,
360 absl::StrCat("Fusing producer |", producer_name, "| into consumer |",
361 input_fusion->name(), "| inside GPU multi-output fusion"));
362 RecomputeReachability();
363 }
364 return changed;
365 }
366
DumpFusionState(const HloInstruction & consumer,absl::string_view label,const HloInstruction * producer)367 void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer,
368 absl::string_view label,
369 const HloInstruction* producer) {
370 if (consumer.GetModule()
371 ->config()
372 .debug_options()
373 .xla_dump_fusion_visualization()) {
374 RegisterFusionState(*computation_, label, consumer, producer);
375 }
376 }
377
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)378 StatusOr<bool> GpuMultiOutputFusion::Run(
379 HloModule* module,
380 const absl::flat_hash_set<absl::string_view>& execution_threads) {
381 bool changed = false;
382 for (auto* computation :
383 module->MakeNonfusionComputations(execution_threads)) {
384 computation_ = computation;
385 TF_ASSIGN_OR_RETURN(bool fusion_changed, DoMultiOutputFusion());
386 if (fusion_changed) {
387 changed = true;
388 }
389 }
390 return changed;
391 }
392
393 } // namespace gpu
394 } // namespace xla
395