/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_

#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace gpu {

// This optimization pass horizontally fuses computations to reduce kernel
// launch overhead while increasing the kernel launch dims on the GPU. The
// initial motivation for horizontal fusion is the observation that the
// training-optimizer phase (e.g., AdamOptimizer, L2Loss) typically produces
// many small kernels, because the same formula is applied to many training
// parameters (or variables, in TensorFlow terms). Fusing these small kernels
// therefore provides a performance gain.
//
// In principle, we could run a cycle-detection algorithm to make sure that no
// cycles are created after fusion. However, the cycle-detection check is
// somewhat cumbersome; moreover, we observe that naive horizontal fusion of
// arbitrary kernels may not be profitable, because of control divergence and
// a possible increase in memory-bandwidth pressure from uncoalesced memory
// accesses (note that horizontal fusion does not change the total amount of
// memory read and written). In practice, a simple yet effective heuristic
// avoids these issues while covering the known beneficial cases: we search
// for fusion candidates by looking for instructions whose outputs are all
// consumed by the same instruction. This catches the cases in the training
// optimizer phase, since the candidate instructions are typically consumed
// only by the ROOT tuple of the entry computation.
//
// The following example illustrates the mechanism of horizontal fusion.
// Before fusion, there are two trivial kernels: one contains only a Mul op,
// the other only an Add op. Because their outputs are consumed only by the
// same (ROOT) tuple instruction, horizontal fusion is triggered.
//
// i0 i1   i2 i3
//  | |     | |
//  v v     v v
//  Mul     Add
//   |       |
//   v       v
//  (ROOT) tuple
//
// We horizontally fuse them into the pattern below.
//
// i0 i1   i2 i3       +++ (Slice) Input Fusion
//  | |     | |          +
//  v v     v v          +
//  Mul     Add          +
//   |       |           +
//   v       v           +
// Reshape0  Reshape1    +
//   |       |           +
//   v       v           +
//  Concatenate          +
//   |       |           +
//   v       v           +
//  Slice0  Slice1     +++
//   |       |
//   v       v
// Reshape2  Reshape3
//   |       |
//   v       v
//  (ROOT) tuple
//
// Note that this fusion style has an important advantage: kernels of
// different shapes can be horizontally fused. The first pair of reshapes
// (Reshape0 and Reshape1) flatten the outputs to one dimension, so that the
// outputs of the fused kernels can always be concatenated. The second pair of
// reshapes (Reshape2 and Reshape3) restore the original shapes of the output
// tensors.
//
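// As an illustrative (made-up) shape example: if Mul produces f32[3,4] and
// Add produces f32[8], then Reshape0 flattens f32[3,4] to f32[12], Reshape1
// flattens f32[8] to f32[8] (already one-dimensional), and the Concatenate
// yields f32[20]. Slice0 then takes elements [0, 12), Slice1 takes elements
// [12, 20), and Reshape2 and Reshape3 restore f32[3,4] and f32[8],
// respectively.
//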
// Horizontal fusion introduces no extra copies. Apart from Reshape2 and
// Reshape3, the other instructions are fused into an input fusion; the output
// dims of the Concatenate are used as the kernel launch dims. Bitcasts can be
// used in place of Reshape2 and Reshape3 as long as the outputs of Mul and
// Add are row-major.
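//
// A minimal usage sketch (hedged: the pipeline name is illustrative, and this
// header does not prescribe where the pass sits in a real GPU pipeline):
//
//   HloPassPipeline pipeline("horizontal-fusion-example");
//   pipeline.AddPass<GpuHorizontalLoopFusion>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));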
class GpuHorizontalLoopFusion : public HloModulePass {
 public:
  GpuHorizontalLoopFusion() {}
  explicit GpuHorizontalLoopFusion(absl::string_view prefix)
      : prefix_(prefix) {}

  absl::string_view name() const override {
    return "gpu_horizontal_loop_fusion";
  }

  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

 private:
  StatusOr<bool> RunOnComputation(HloComputation*);

  std::string prefix_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_