xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dot_merger.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_MERGER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_MERGER_H_
18 
19 #include <functional>
20 #include <utility>
21 
22 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
23 
24 namespace xla {
25 
26 // Merges dots that share an operand.  Transforms
27 //
28 //   x = dot(a, b)
29 //   y = dot(a, c)
30 //
31 // into
32 //
33 //   z = dot(a, concat(b, c))
34 //   x = slice(z)
35 //   y = slice(z).
36 //
37 // This requires that x and y are independent -- that is, x does not
38 // transitively depend on y, and y does not transitively depend on x.
39 //
40 // This is a good transformation if the merged dot runs faster than the original
41 // dots.  On the other hand, merging the dots results in a single result buffer
42 // z whose live range is the union of x and y's live ranges, so can lead to
43 // increased memory pressure.  You probably only want to do this optimization on
44 // "small" dots which cannot saturate your device when run alone.
45 //
46 // We thus allow backends to set a max size above which an op will not be
47 // merged.  The input+output bytes of at least one dot must be below the
48 // threshold otherwise we won't merge.  (We don't require that both dots be
49 // below the threshold because backends likely want to allow merging a "small"
50 // dot into a "large" dot while preventing two large dots from being merged.)
51 //
52 // Assumes DotDecomposer has already canonicalized the gemms and will skip
53 // noncanonical gemms.
54 class DotMerger : public HloModulePass {
55  public:
DotMerger(int64_t max_size_to_merge)56   explicit DotMerger(int64_t max_size_to_merge)
57       : max_size_to_merge_(max_size_to_merge) {}
58 
name()59   absl::string_view name() const override { return "dot-merger"; }
60   using HloPassInterface::Run;
61   StatusOr<bool> Run(
62       HloModule* module,
63       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
64 
65  private:
66   int64_t max_size_to_merge_;
67 };
68 
69 }  // namespace xla
70 
71 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_MERGER_H_
72