1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_MERGER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_MERGER_H_ 18 19 #include <functional> 20 #include <utility> 21 22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 23 24 namespace xla { 25 26 // Merges dots that share an operand. Transforms 27 // 28 // x = dot(a, b) 29 // y = dot(a, c) 30 // 31 // into 32 // 33 // z = dot(a, concat(b, c)) 34 // x = slice(z) 35 // y = slice(z). 36 // 37 // This requires that x and y are independent -- that is, x does not 38 // transitively depend on y, and y does not transitively depend on x. 39 // 40 // This is a good transformation if the merged dot runs faster than the original 41 // dots. On the other hand, merging the dots results in a single result buffer 42 // z whose live range is the union of x and y's live ranges, so can lead to 43 // increased memory pressure. You probably only want to do this optimization on 44 // "small" dots which cannot saturate your device when run alone. 45 // 46 // We thus allow backends to set a max size above which an op will not be 47 // merged. The input+output bytes of at least one dot must be below the 48 // threshold otherwise we won't merge. (We don't require that both dots be 49 // below the threshold because backends likely want to allow merging a "small" 50 // dot into a "large" dot while preventing two large dots from being merged.) 51 // 52 // Assumes DotDecomposer has already canonicalized the gemms and will skip 53 // noncanonical gemms. 54 class DotMerger : public HloModulePass { 55 public: DotMerger(int64_t max_size_to_merge)56 explicit DotMerger(int64_t max_size_to_merge) 57 : max_size_to_merge_(max_size_to_merge) {} 58 name()59 absl::string_view name() const override { return "dot-merger"; } 60 using HloPassInterface::Run; 61 StatusOr<bool> Run( 62 HloModule* module, 63 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 64 65 private: 66 int64_t max_size_to_merge_; 67 }; 68 69 } // namespace xla 70 71 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_MERGER_H_ 72