/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_

#include <memory>
#include <queue>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {
namespace gpu {

// Multi-output fusion of sibling and producer-consumer instructions for the
// GPU backend to reduce memory bandwidth requirements.
//
//   0) Before multi-    1) Sibling multi-    2) Producer-consumer
//      output fusion       output fusion        multi-output fusion
//
//          p                    p                    p
//          |                    |                    |
//          v                    v                    v
//          A                    A               +-fusion--+
//        /   \                  |               |    A    |
//       |     |            +-fusion--+          |  /   \  |
//       v     v            |  /   \  |          | B     | |
//       B     C            | B     C |          | |     | |
//        \   /             | |     | |          | v     v |
//         v v              | v     v |          | tuple   |
//         ROOT             |  tuple  |          +---------+
//                          +---------+             /    \
//                            /     \            gte_b  gte_a
//                        gte_b     gte_c           |      |
//                           |       |              |      v
//                            \     /               |      C
//                             v   v                 \    /
//                             ROOT                   v  v
//                                                    ROOT
//
// Multi-output fusion ops have a tuple op at their root containing multiple
// elements as outputs. GetTupleElement ops (depicted as gte_* above) are
// inserted to extract tuple elements for consumers.
//
// The two different flavors of multi-output fusion this pass performs are
// depicted above.
// 1) Fusion of sibling ops reduces memory bandwidth requirements, because
//    common input parameters have to be read only once.
// 2) Fusion of producer-consumer ops reduces memory bandwidth requirements by
//    saving one read from memory. In the example above, B does not need to
//    read the output of A from memory, while C still does (using gte_a).
// Note that sibling (1) and producer-consumer (2) multi-output fusion can be
// combined.
//
// The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (defs
// before uses). First, it attempts to fuse the consumer ops of the current op,
// which are siblings (1). Afterwards, it attempts to fuse the current op with
// one of its consumers (2). This order avoids a phase-ordering issue
// (described in go/fusionfusion): it ensures that all GetTupleElement ops
// inserted as a by-product of multi-output fusion occur before the current op
// in the traversal order and hence do not get in the way of subsequent fusion
// attempts.
//
// The GpuMultiOutputFusion pass ensures several conditions are met for fusion.
// Some of them are relevant for correctness. In particular, no cycles must be
// introduced into the HLO module. Moreover, the code emitters for multi-output
// fusion must support the combination of ops and their shapes. Other
// restrictions are rather arbitrary, and lifting them could be beneficial:
// * Sibling fusion (1) requires at least one op to be a kFusion.
// * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e.
//   the fusion kinds must match.
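//
// A minimal usage sketch (the pipeline name and the `module` pointer below are
// illustrative only, and hlo_pass_pipeline.h is assumed to be available; in
// practice this pass is typically scheduled as part of the GPU compiler's
// fusion pipeline):
//
//   HloPassPipeline pipeline("multi-output-fusion-example");
//   pipeline.AddPass<GpuMultiOutputFusion>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));  // HloModule*.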

class GpuMultiOutputFusion : public HloModulePass {
 public:
  GpuMultiOutputFusion() = default;

  absl::string_view name() const override { return "multi_output_fusion"; }

  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

 private:
  bool FuseSiblings(HloInstruction* parent,
                    FusionInfoCache* fusion_info_cache);

  StatusOr<bool> DoMultiOutputFusion();

  // Recompute reachability for the current computation.
  void RecomputeReachability();

  void DumpFusionState(const HloInstruction& consumer, absl::string_view label,
                       const HloInstruction* producer = nullptr);

  // Computation for the pass.
  HloComputation* computation_;

  // The reachability map of the current computation.
  std::unique_ptr<HloReachabilityMap> reachability_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_