/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_

#include <memory>
#include <queue>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace gpu {

// Multi-output fusion of sibling and producer-consumer instructions for the
// GPU backend to reduce memory bandwidth requirements.
//
//   0) Before multi-    1) Sibling multi-    2) Producer-consumer
//      output fusion       output fusion        multi-output fusion
//
//          p                    p                    p
//          |                    |                    |
//          v                    v                    v
//          A                    A               +-fusion--+
//        /   \                  |               |    A    |
//       |     |            +-fusion--+          |   / \   |
//       v     v            |   / \   |          |  B   |  |
//       B     C            |  B   C  |          |  |   |  |
//        \   /             |  |   |  |          |  v   v  |
//         v v              |  v   v  |          |  tuple  |
//        ROOT              |  tuple  |          +---------+
//                          +---------+            /    \
//                            /    \            gte_b  gte_a
//                         gte_b  gte_c           |      |
//                           |      |             |      v
//                            \    /              |      C
//                             v  v                \    /
//                             ROOT                 v  v
//                                                  ROOT
//
// Multi-output fusion ops have a tuple op at their root containing multiple
// elements as outputs. GetTupleElement ops (depicted as gte_* above) are
// inserted to extract tuple elements for consumers.
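//
// As an illustrative sketch (not part of this pass's API), such a fusion can
// be recognized with the public HloInstruction interface roughly like this:
//
//   if (instr->opcode() == HloOpcode::kFusion &&
//       instr->fused_expression_root()->opcode() == HloOpcode::kTuple) {
//     // `instr` is a multi-output fusion; each consumer reads one element of
//     // its result through a get-tuple-element instruction.
//   }
//
// (HloInstruction::IsMultiOutputFusion() encapsulates essentially this check.)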
//
// The two different flavors of multi-output fusion this pass performs are
// depicted above.
// 1) Fusion of sibling ops reduces memory bandwidth requirements, because
//    common input parameters have to be read only once.
// 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by
//    saving one read from memory. In the example above, B does not need to read
//    the output of A from memory, while C still does (using gte_a).
// Note that sibling (1) and producer-consumer (2) multi-output fusion can be
// combined.
//
// The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (defs
// before uses). First, it attempts to fuse the consumer ops of the current op,
// which are siblings (1). Then it attempts to fuse the current op with one of
// its consumers (2). This order avoids a phase ordering issue (described in
// go/fusionfusion). It ensures that all GetTupleElement ops inserted as a
// by-product of multi-output fusion occur before the current op in the
// traversal order and hence do not get in the way of subsequent fusion
// attempts.
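//
// In rough pseudocode, the traversal described above looks like the following
// (a sketch only; `defs_before_uses` and `TryFuseIntoConsumer` are
// illustrative names, not actual members of this class):
//
//   for (HloInstruction* instr : defs_before_uses) {
//     changed |= FuseSiblings(instr, &fusion_info_cache);  // step (1)
//     changed |= TryFuseIntoConsumer(instr);               // step (2)
//   }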
//
// The GpuMultiOutputFusion pass ensures several conditions are met for fusion.
// Some of them are relevant for correctness. In particular, no cycles must be
// introduced into the HLO module. Moreover, the code emitters for multi-output
// fusion must support the combination of ops and their shapes. Other
// restrictions are rather arbitrary and lifting them could be beneficial.
// * Sibling fusion (1) requires at least one op to be a kFusion.
// * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e.
//   the fusion kinds must match.
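//
// For example, the kind-compatibility check for two sibling fusion
// instructions can be expressed roughly as (illustrative only, assuming both
// siblings are already fusion instructions):
//
//   bool kinds_match = sibling_a->IsLoopFusion() == sibling_b->IsLoopFusion();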

class GpuMultiOutputFusion : public HloModulePass {
 public:
  GpuMultiOutputFusion() = default;

  absl::string_view name() const override { return "multi_output_fusion"; }

  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

 private:
  bool FuseSiblings(HloInstruction* parent, FusionInfoCache* fusion_info_cache);

  StatusOr<bool> DoMultiOutputFusion();

  // Recompute reachability for the current computation.
  void RecomputeReachability();

  void DumpFusionState(const HloInstruction& consumer, absl::string_view label,
                       const HloInstruction* producer = nullptr);

  // Computation for the pass.
  HloComputation* computation_;

  // The reachability map of the current computation.
  std::unique_ptr<HloReachabilityMap> reachability_;
};
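
// Example of running the pass on an existing HloModule (a minimal sketch;
// error handling and the surrounding compilation pipeline are omitted, and
// `module` is assumed to be a valid HloModule*):
//
//   GpuMultiOutputFusion fusion;
//   StatusOr<bool> changed = fusion.Run(module);  // single-argument overload
//                                                 // from HloPassInterface::Run
//
// In a full pipeline the pass is typically registered via
// HloPassPipeline::AddPass<GpuMultiOutputFusion>().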

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_