/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_COMBINER_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_COMBINER_UTILS_H_

#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// Combines instructions with matching keys together.
//
// Instructions are combined in topological post-order.
//
// `key_fn` should return equal keys for any two instructions that may be
// combined together, and std::nullopt for instructions that must not be
// combined. Instructions are greedily added to a combine group until either
// the combined output byte size or the instruction count reaches its
// threshold. (See the key sketch below and the usage sketch after the
// function definition.)
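//
// A minimal illustrative key function, assuming the caller wants to group
// all-reduce ops by element type. This lambda is a hypothetical example, not
// part of this header; real passes typically use richer keys:
//
//   auto key_fn =
//       [](const HloInstruction* hlo) -> std::optional<PrimitiveType> {
//     if (hlo->opcode() != HloOpcode::kAllReduce) return std::nullopt;
//     return hlo->shape().element_type();
//   };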
template <typename K>
StatusOr<bool> CombineInstructionsByKey(
    HloComputation* computation,
    const std::function<std::optional<K>(const HloInstruction*)>& key_fn,
    const std::function<Status(absl::Span<HloInstruction* const>)>& combine_fn,
    int64_t combine_threshold_bytes, int64_t combine_threshold_count) {
  // Cache keys for each instruction and build sets of instructions with the
  // same key that might be combined together.
  absl::flat_hash_map<HloInstruction*, K> keys;
  absl::flat_hash_map<K, absl::flat_hash_set<HloInstruction*>> groups;

  for (HloInstruction* instruction : computation->instructions()) {
    std::optional<K> key = key_fn(instruction);
    if (key) {
      keys.insert({instruction, *key});
      groups[*key].insert(instruction);
    }
  }

  bool changed = false;

  // Keys are removed once an instruction has been combined (or once it is
  // known that it never will be).
  while (!keys.empty()) {
    std::vector<HloInstruction*> to_combine;
    int64_t to_combine_bytes = 0;
    absl::flat_hash_set<HloInstruction*>* group = nullptr;

    // Recompute reachability after every combine group: we cannot maintain a
    // topological order across groups, so fresh transitive-dependency
    // information is needed to detect cycles.
    std::unique_ptr<HloReachabilityMap> reachability =
        HloReachabilityMap::Build(computation);

    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      auto it = keys.find(instruction);
      if (it == keys.end()) continue;

      // If this is the first instruction, set the active group.
      if (to_combine.empty()) {
        group = &groups.find(it->second)->second;
      }

      // Check that the instruction is in the active group.
      if (group->find(instruction) == group->end()) {
        continue;
      }

      VLOG(1) << "Considering HLO " << instruction->ToString()
              << " with current set size of " << to_combine_bytes
              << " and current operand count of " << to_combine.size();

      // Only single-operand ops are handled: multi-operand collectives are
      // produced only by this pass itself, so skipping them keeps the logic
      // simple.
      if (instruction->operands().size() != 1) {
        VLOG(1) << "Skipping due to " << instruction->operands().size()
                << " operands";
        keys.erase(it);
        continue;
      }

      TF_RET_CHECK(instruction->shape().IsArray());
      int64_t instruction_bytes = ShapeUtil::ByteSizeOf(instruction->shape());

      // If the instruction's output alone is larger than the threshold, it
      // can never be combined with anything.
      if (instruction_bytes > combine_threshold_bytes) {
        VLOG(1) << "Size " << instruction_bytes << " above threshold.";
        keys.erase(it);
        continue;
      }

      if (to_combine_bytes + instruction_bytes > combine_threshold_bytes) {
        VLOG(1) << "Combined size threshold exceeded.";
        break;
      }

      // Dependent instructions cannot be combined: doing so would create a
      // cycle.
      bool is_reachable =
          absl::c_any_of(to_combine, [&](HloInstruction* to_combine_inst) {
            return reachability->IsReachable(to_combine_inst, instruction);
          });
      if (is_reachable) {
        VLOG(1) << "Instruction is reachable.";
        break;
      }

      VLOG(1) << "Adding instruction to set.";
      to_combine.push_back(instruction);
      to_combine_bytes += instruction_bytes;
      keys.erase(it);

      if (to_combine.size() >= combine_threshold_count) {
        VLOG(1) << "Combined count threshold reached.";
        break;
      }
    }

    if (to_combine.size() > 1) {
      TF_RETURN_IF_ERROR(combine_fn(to_combine));
      changed = true;
    }
  }

  return changed;
}

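// A minimal usage sketch for a hypothetical combiner pass. The driver below,
// the `CombineAllReduces` callback, and the threshold values are illustrative
// placeholders, not part of this header; `key_fn` is the lambda sketched in
// the comment above CombineInstructionsByKey:
//
//   StatusOr<bool> RunOnModule(HloModule* module) {
//     bool changed = false;
//     for (HloComputation* computation :
//          module->MakeNonfusionComputations()) {
//       TF_ASSIGN_OR_RETURN(
//           bool computation_changed,
//           CombineInstructionsByKey<PrimitiveType>(
//               computation, key_fn, CombineAllReduces,
//               /*combine_threshold_bytes=*/30 * 1024 * 1024,
//               /*combine_threshold_count=*/256));
//       changed |= computation_changed;
//     }
//     return changed;
//   }
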
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COLLECTIVE_COMBINER_UTILS_H_