xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
17 
18 #include <stdint.h>
19 
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "tensorflow/compiler/xla/debug_options_flags.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
30 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
31 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
35 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
36 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 
39 namespace xla {
40 namespace gpu {
41 
42 namespace {
43 
IsProfitableOperand(HloInstruction * instr)44 bool IsProfitableOperand(HloInstruction* instr) {
45   // kConstant instruction will not have memory reads, so it won't be a profit
46   // source. Skip them.
47   if (instr->opcode() == HloOpcode::kConstant &&
48       ShapeUtil::IsEffectiveScalar(instr->shape())) {
49     return false;
50   }
51   return true;
52 }
53 
LegalToFuse(HloInstruction * instr1,HloInstruction * instr2,FusionInfoCache * fusion_info_cache)54 FusionDecision LegalToFuse(HloInstruction* instr1, HloInstruction* instr2,
55                            FusionInfoCache* fusion_info_cache) {
56   // If we're fusing fusions only do it if the fusion kind matches. Loop fusions
57   // merge into bigger loop fusions and input (reduce) fusions become fusions
58   // with multiple reduce outputs. We could fuse reduce and loop fusions
59   // together too (the result being an input fusion) if we find cases where this
60   // improves things. Also disable fusing standalone input-fusible reduces into
61   // loop fusions.
62   CHECK(instr1->opcode() == HloOpcode::kFusion);
63   if ((instr2->opcode() == HloOpcode::kFusion &&
64        instr1->fusion_kind() != instr2->fusion_kind()) ||
65       (IsReductionFromOrToContiguousDimensions(*instr2) &&
66        instr1->IsLoopFusion())) {
67     return "Can't merge fusions of two different types";
68   }
69   // The emitter only supports in-place DUS for fusions with a single DUS at the
70   // root. Don't sibling fuse DUS for now.
71   // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
72   // share the input and output buffers and add support to the emitter.
73   if (instr1->fused_expression_root()->opcode() ==
74           HloOpcode::kDynamicUpdateSlice ||
75       (instr2->opcode() == HloOpcode::kFusion &&
76        instr2->fused_expression_root()->opcode() ==
77            HloOpcode::kDynamicUpdateSlice)) {
78     return "Can't fuse multiple DUSs";
79   }
80 
81   // Do this check last, as it may be expensive.
82   return FusionFitsInBudget(*instr1, *instr2,
83                             /*is_consumer_producer_fusion=*/false,
84                             fusion_info_cache);
85 }
86 
87 // We prefer multi-output fusions over other fusions over unfused ops, because
88 // we want to preserve fusion opportunities if possible.
FusionPriority(const HloInstruction * instr)89 int FusionPriority(const HloInstruction* instr) {
90   if (instr->IsMultiOutputFusion()) {
91     return 2;
92   }
93   if (instr->opcode() == HloOpcode::kFusion) {
94     return 1;
95   }
96   return 0;
97 }
98 
SelectPreferredFusionCandidate(const std::vector<HloInstruction * > candidates)99 HloInstruction* SelectPreferredFusionCandidate(
100     const std::vector<HloInstruction*> candidates) {
101   if (candidates.empty()) {
102     return nullptr;
103   }
104   return *std::max_element(
105       candidates.begin(), candidates.end(),
106       [](const HloInstruction* a, const HloInstruction* b) {
107         return FusionPriority(a) < FusionPriority(b);
108       });
109 }
110 
GetProducerConsumerMultiOutputFusionCandidates(const HloInstruction * producer,const HloReachabilityMap & reachability,FusionInfoCache * fusion_info_cache)111 std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
112     const HloInstruction* producer, const HloReachabilityMap& reachability,
113     FusionInfoCache* fusion_info_cache) {
114   std::vector<HloInstruction*> fusion_candidates;
115   // If there is only one user, and it is not a multi-output fusion node, this
116   // fusion possibility was already considered and rejected by the FusionMerger
117   // pass. No need to try again!
118   if (producer->user_count() == 1 &&
119       !producer->users()[0]->IsMultiOutputFusion()) {
120     return fusion_candidates;
121   }
122   for (HloInstruction* consumer : producer->users()) {
123     VLOG(3) << "Looking at producer " << producer->name()
124             << " and its consumer " << consumer->name();
125     if (!IsFusibleAsMultiOutputFusionRoot(*consumer)) {
126       VLOG(3) << "Consumer " << consumer->name()
127               << " is not eligible as multi-output fusion root.";
128       continue;
129     }
130     if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) {
131       VLOG(3) << producer->name() << " and " << consumer->name()
132               << " are not fusible.";
133       continue;
134     }
135     // Do not fuse a producer if the other operands of the fusion are
136     // reachable from the producer, this would create a cycle.
137     auto operand_reachable_from_producer = [&](const HloInstruction* operand) {
138       // If a get-tuple-element instruction is not in the reachability
139       // map, it has been created by fusion in this pass. Simply move
140       // on to its operand, which is in the reachability map.
141       if (!reachability.IsPresent(operand) &&
142           operand->opcode() == HloOpcode::kGetTupleElement) {
143         operand = operand->operand(0);
144       }
145       CHECK(reachability.IsPresent(operand) && reachability.IsPresent(producer))
146           << "Reachability map is incomplete. This should never "
147              "happen.";
148       return producer != operand && reachability.IsReachable(producer, operand);
149     };
150     if (absl::c_any_of(consumer->operands(), operand_reachable_from_producer)) {
151       VLOG(3) << producer->name() << " would introduce a cycle when fused.";
152       continue;
153     }
154     if (!FusionFitsInBudget(*producer, *consumer,
155                             /*is_consumer_producer_fusion=*/false,
156                             fusion_info_cache)) {
157       VLOG(3) << producer->name() << " and " << consumer->name()
158               << " would be too large of a fusion.";
159       continue;
160     }
161     // Make sure the emitter can codegen the fusion op efficiently. We currently
162     // can have exponential time/memory requirements for emitting certain fusion
163     // ops, in which case we don't want to fuse.
164     // TODO(b/119692968): Remove this once fixed in the emitter.
165     if (FusedIrEmitter::IsFusedIrEmitterInefficient(*consumer, *producer)) {
166       VLOG(3) << "Fusion of " << producer->name() << " into "
167               << consumer->name()
168               << " would result in overly large code duplication.";
169       continue;
170     }
171     fusion_candidates.push_back(consumer);
172   }
173   return fusion_candidates;
174 }
175 
IsSiblingFusionCandidate(const HloInstruction * instr)176 bool IsSiblingFusionCandidate(const HloInstruction* instr) {
177   if (instr->IsDead()) {
178     return false;
179   }
180   if (!IsFusibleAsMultiOutputFusionRoot(*instr)) {
181     return false;
182   }
183   // Check if the users of multioutput fusion is not a get-tuple-element.
184   // If this is the case, we bail out because the transformation assumes
185   // the users are get-tuple-element.
186   if (instr->IsMultiOutputFusion()) {
187     for (auto user : instr->users()) {
188       if (user->opcode() != HloOpcode::kGetTupleElement) {
189         return false;
190       }
191     }
192   }
193   return true;
194 }
195 
196 }  // namespace
197 
RecomputeReachability()198 void GpuMultiOutputFusion::RecomputeReachability() {
199   reachability_ = HloReachabilityMap::Build(computation_);
200 }
201 
FuseSiblings(HloInstruction * parent,FusionInfoCache * fusion_info_cache)202 bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent,
203                                         FusionInfoCache* fusion_info_cache) {
204   if (!IsProfitableOperand(parent)) {
205     VLOG(3) << "Operand " << parent->ToShortString() << " is not profitable";
206     return false;
207   }
208   bool changed = false;
209   std::vector<HloInstruction*> siblings = parent->users();
210   // Sort the siblings such that multi-output fusion ops occur first, followed
211   // by fusion ops, followed by unfused ops.
212   absl::c_stable_sort(siblings,
213                       [](const HloInstruction* a, const HloInstruction* b) {
214                         return FusionPriority(a) > FusionPriority(b);
215                       });
216   for (auto i = siblings.begin(); i != siblings.end(); ++i) {
217     VLOG(3) << "Considering " << (*i)->name();
218     if ((*i)->opcode() != HloOpcode::kFusion || !IsSiblingFusionCandidate(*i)) {
219       continue;
220     }
221     for (auto j = i + 1; j != siblings.end();) {
222       VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
223       if (!IsSiblingFusionCandidate(*j) || reachability_->IsConnected(*i, *j) ||
224           !ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
225           !LegalToFuse(*i, *j, fusion_info_cache)) {
226         ++j;
227         continue;
228       }
229       if (!ConsumeFuel(name(), [&] {
230             return absl::StrFormat("Not fusing siblings %s and %s.",
231                                    (*i)->name(), (*j)->name());
232           })) {
233         ++j;
234         continue;
235       }
236       VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
237       fusion_info_cache->Invalidate(*i);
238       fusion_info_cache->Invalidate(*j);
239       HloInstruction* remaining = *i;
240       HloInstruction* fused = *j;
241 
242       DumpFusionState(*remaining,
243                       absl::StrCat("About to fuse producer |", fused->name(),
244                                    "| into consumer |", remaining->name(),
245                                    "| inside GPU multi-output fusion"),
246                       /*producer=*/fused);
247 
248       if (fused->opcode() == HloOpcode::kFusion) {
249         remaining->MergeFusionInstructionIntoMultiOutput(fused);
250       } else {
251         remaining->FuseInstructionIntoMultiOutput(fused);
252         CHECK_EQ(0, fused->user_count());
253         TF_CHECK_OK(computation_->RemoveInstruction(fused));
254       }
255       DumpFusionState(*remaining,
256                       absl::StrCat("Fused into consumer |", remaining->name(),
257                                    "| inside GPU multi-output fusion"));
258       changed = true;
259       siblings.erase(j);
260       RecomputeReachability();
261     }
262   }
263   return changed;
264 }
265 
DoMultiOutputFusion()266 StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
267   bool changed = false;
268   RecomputeReachability();
269   std::vector<HloInstruction*> defs_before_uses =
270       computation_->MakeInstructionPostOrder();
271 
272   FusionInfoCache fusion_info_cache;
273   while (!defs_before_uses.empty()) {
274     // Traverse the HLO in uses-before-defs order by removing instruction from
275     // the back of the vector.
276     HloInstruction* producer = defs_before_uses.back();
277 
278     // Copy on purpose: to use after removing the producer.
279     std::string producer_name = producer->name();
280     defs_before_uses.pop_back();
281     // Never multi-output fuse constants.  To the extent that we want to fuse
282     // constants, that should be handled by the regular fusion pass.
283     if (producer->opcode() == HloOpcode::kConstant) {
284       VLOG(3) << producer->name() << " is a constant.";
285       continue;
286     }
287     // First, fuse the consumer ops of the current op, which are siblings.
288     if (FuseSiblings(/*parent=*/producer, &fusion_info_cache)) {
289       changed = true;
290     }
291     // Second, perform producer-consumer multi-output fusion. This order will
292     // ensure that all get-tuple-element ops inserted as a by-product of
293     // multi-output fusion will occur before the current op in the order of
294     // traversal, and hence, not get into the way of subsequent fusion attempts.
295     const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
296         producer, *reachability_, &fusion_info_cache);
297     auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
298     if (consumer_for_fusion == nullptr) {
299       continue;
300     }
301     if (!ConsumeFuel(name(), [&] {
302           return absl::StrFormat("Not fusing %s and %s.", producer->name(),
303                                  consumer_for_fusion->name());
304         })) {
305       continue;
306     }
307     changed = true;
308     fusion_info_cache.Invalidate(producer);
309     fusion_info_cache.Invalidate(consumer_for_fusion);
310 
311     if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
312       VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
313               << consumer_for_fusion->name();
314       DumpFusionState(
315           *consumer_for_fusion,
316           absl::StrCat("About to fuse producer |", producer_name,
317                        "| into consumer |", consumer_for_fusion->name(),
318                        "| inside GPU multi-output fusion"),
319           /*producer=*/producer);
320       if (producer->opcode() == HloOpcode::kFusion) {
321         consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer);
322       } else {
323         consumer_for_fusion->FuseInstructionIntoMultiOutput(producer);
324         CHECK_EQ(0, producer->user_count());
325         TF_CHECK_OK(computation_->RemoveInstruction(producer));
326       }
327 
328       DumpFusionState(
329           *consumer_for_fusion,
330           absl::StrCat("Fusing producer |", producer_name, "| into consumer |",
331                        consumer_for_fusion->name(),
332                        "| inside GPU multi-output fusion"));
333       RecomputeReachability();
334       continue;
335     }
336     HloInstruction* input_fusion =
337         computation_->AddInstruction(HloInstruction::CreateFusion(
338             consumer_for_fusion->shape(),
339             ChooseFusionKind(*producer, *consumer_for_fusion),
340             consumer_for_fusion));
341     VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
342             << consumer_for_fusion->name() << " into " << input_fusion->name();
343     DumpFusionState(
344         *input_fusion,
345         absl::StrCat("About to fuse |", producer_name, "| into consumer |",
346                      input_fusion->name(), "| inside GPU multi-output fusion"),
347         /*producer=*/input_fusion);
348     TF_CHECK_OK(
349         computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
350     if (producer->opcode() == HloOpcode::kFusion) {
351       input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
352     } else {
353       input_fusion->FuseInstructionIntoMultiOutput(producer);
354       CHECK_EQ(0, producer->user_count());
355       TF_CHECK_OK(computation_->RemoveInstruction(producer));
356     }
357 
358     DumpFusionState(
359         *input_fusion,
360         absl::StrCat("Fusing producer |", producer_name, "| into consumer |",
361                      input_fusion->name(), "| inside GPU multi-output fusion"));
362     RecomputeReachability();
363   }
364   return changed;
365 }
366 
DumpFusionState(const HloInstruction & consumer,absl::string_view label,const HloInstruction * producer)367 void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer,
368                                            absl::string_view label,
369                                            const HloInstruction* producer) {
370   if (consumer.GetModule()
371           ->config()
372           .debug_options()
373           .xla_dump_fusion_visualization()) {
374     RegisterFusionState(*computation_, label, consumer, producer);
375   }
376 }
377 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)378 StatusOr<bool> GpuMultiOutputFusion::Run(
379     HloModule* module,
380     const absl::flat_hash_set<absl::string_view>& execution_threads) {
381   bool changed = false;
382   for (auto* computation :
383        module->MakeNonfusionComputations(execution_threads)) {
384     computation_ = computation;
385     TF_ASSIGN_OR_RETURN(bool fusion_changed, DoMultiOutputFusion());
386     if (fusion_changed) {
387       changed = true;
388     }
389   }
390   return changed;
391 }
392 
393 }  // namespace gpu
394 }  // namespace xla
395