1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_COMBINER_UTILS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_COMBINER_UTILS_H_
18
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
34
35 namespace xla {
36
37 // Combines instructions with matching keys together.
38 //
39 // Instructions are combined in topological post-order.
40 //
41 // `key_fn` should return equal keys for two instructions that might be combined
42 // together. Instructions will be combined until the threshold for output byte
43 // size or instruction count is reached.
44 template <typename K>
CombineInstructionsByKey(HloComputation * computation,const std::function<std::optional<K> (const HloInstruction *)> & key_fn,const std::function<Status (absl::Span<HloInstruction * const>)> & combine_fn,int64_t combine_threshold_bytes,int64_t combine_threshold_count)45 StatusOr<bool> CombineInstructionsByKey(
46 HloComputation* computation,
47 const std::function<std::optional<K>(const HloInstruction*)>& key_fn,
48 const std::function<Status(absl::Span<HloInstruction* const>)>& combine_fn,
49 int64_t combine_threshold_bytes, int64_t combine_threshold_count) {
50 // Cache keys for each instruction and build sets of instructions with the
51 // same key that might be combined together.
52 absl::flat_hash_map<HloInstruction*, K> keys;
53 absl::flat_hash_map<K, absl::flat_hash_set<HloInstruction*>> groups;
54
55 for (HloInstruction* instruction : computation->instructions()) {
56 std::optional<K> key = key_fn(instruction);
57 if (key) {
58 keys.insert({instruction, *key});
59 groups[*key].insert(instruction);
60 }
61 }
62
63 bool changed = false;
64
65 // Keys are removed after the instruction is combined (or never will be).
66 while (!keys.empty()) {
67 std::vector<HloInstruction*> to_combine;
68 int64_t to_combine_bytes = 0;
69 absl::flat_hash_set<HloInstruction*>* group = nullptr;
70
71 // Recompute reachability after every combine group because we can't
72 // maintain a cross group topological order to be able to rely on the
73 // transitive dependencies to detect cycles.
74 std::unique_ptr<HloReachabilityMap> reachability =
75 HloReachabilityMap::Build(computation);
76
77 for (HloInstruction* instruction :
78 computation->MakeInstructionPostOrder()) {
79 auto it = keys.find(instruction);
80 if (it == keys.end()) continue;
81
82 // If this is the first instruction, set the active group.
83 if (to_combine.empty()) {
84 group = &groups.find(it->second)->second;
85 }
86
87 // Check instruction is in the active group.
88 if (group->find(instruction) == group->end()) {
89 continue;
90 }
91
92 VLOG(1) << "Considering HLO " << instruction->ToString()
93 << " with current set size of " << to_combine_bytes
94 << " and current operand count of " << to_combine.size();
95
96 // We do not handle ops that have more than one operand since that is
97 // simpler and this pass is the only way to generate such ops.
98 if (instruction->operands().size() != 1) {
99 VLOG(1) << "Skipping due to " << instruction->operands().size()
100 << " operands";
101 keys.erase(it);
102 continue;
103 }
104
105 TF_RET_CHECK(instruction->shape().IsArray());
106 int64_t instruction_bytes = ShapeUtil::ByteSizeOf(instruction->shape());
107
108 // If the instruction is greater than the threshold, then we can never
109 // combine it with anything.
110 if (instruction_bytes > combine_threshold_bytes) {
111 VLOG(1) << "Size " << instruction_bytes << " above threshold.";
112 keys.erase(it);
113 continue;
114 }
115
116 if (to_combine_bytes + instruction_bytes > combine_threshold_bytes) {
117 VLOG(1) << "Combined size threshold exceeded.";
118 break;
119 }
120
121 // We can't combine dependent instructions.
122 bool is_reachable =
123 absl::c_any_of(to_combine, [&](HloInstruction* to_combine_inst) {
124 return reachability->IsReachable(to_combine_inst, instruction);
125 });
126 if (is_reachable) {
127 VLOG(1) << "Instruction is reachable.";
128 break;
129 }
130
131 VLOG(1) << "Adding instruction to set.";
132 to_combine.push_back(instruction);
133 to_combine_bytes += instruction_bytes;
134 keys.erase(it);
135
136 if (to_combine.size() >= combine_threshold_count) {
137 VLOG(1) << "Combined count threshold reached.";
138 break;
139 }
140 }
141
142 if (to_combine.size() > 1) {
143 TF_RETURN_IF_ERROR(combine_fn(to_combine));
144 changed = true;
145 }
146 }
147
148 return changed;
149 }
150
151 } // namespace xla
152
153 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_COMBINER_UTILS_H_
154