/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_gather_combiner.h"

#include <algorithm>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/collective_combiner_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

// Combines the elements of to_combine into a single AllGather op. All entries
// in to_combine must be AllGather ops with exactly one operand and the same
// all_gather_dimension.
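//
// For example (shapes and replica-group size here are illustrative, not taken
// from any real module), combining
//
//   ag0 = f32[128] all-gather(f32[32] p0), dimensions={0}
//   ag1 = f32[512] all-gather(f32[128] p1), dimensions={0}
//
// yields a single tuple-shaped all-gather whose results are read back out via
// get-tuple-element:
//
//   combined = (f32[128], f32[512]) all-gather(p0, p1), dimensions={0}
//   gte0 = f32[128] get-tuple-element(combined), index=0
//   gte1 = f32[512] get-tuple-element(combined), index=1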
Status CombineAllGathers(absl::Span<HloInstruction* const> to_combine) {
  if (to_combine.size() < 2) {
    return OkStatus();
  }
  VLOG(1) << "Combined " << to_combine.size() << " AllGather ops";

  HloComputation& computation = *to_combine.back()->parent();
  int64_t all_gather_dimension =
      Cast<HloAllGatherInstruction>(to_combine.front())->all_gather_dimension();

  // Create a single bigger AllGather of the operands of the smaller
  // AllGathers.
  std::vector<HloInstruction*> operands;
  std::vector<const Shape*> output_shapes;
  VLOG(1) << "Combining set";
  for (HloInstruction* hlo : to_combine) {
    VLOG(1) << "Set element: " << hlo->ToString();
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllGather);
    TF_RET_CHECK(hlo->operands().size() == 1);
    TF_RET_CHECK(Cast<HloAllGatherInstruction>(hlo)->all_gather_dimension() ==
                 all_gather_dimension);
    TF_RET_CHECK(hlo->shape().IsArray());
    for (HloInstruction* operand : hlo->operands()) {
      operands.push_back(operand);
      output_shapes.push_back(&hlo->shape());
    }
  }

  HloInstruction* combined;
  // AllGather ops with more than one operand produce a tuple.
  TF_RET_CHECK(operands.size() >= 2);
  combined = computation.AddInstruction(HloInstruction::CreateAllGather(
      ShapeUtil::MakeTupleShapeWithPtrs(output_shapes), operands,
      all_gather_dimension, to_combine.front()->replica_groups(),
      /*constrain_layout=*/false, to_combine.front()->channel_id(),
      Cast<HloAllGatherInstruction>(to_combine.front())
          ->use_global_device_ids()));

  // We have to propagate the sharding manually because Domain instructions
  // are not guaranteed to preserve it for side-effecting instructions.
  if (to_combine.front()->has_sharding()) {
    combined->set_sharding(to_combine.front()->sharding());
  }
  VLOG(1) << "Replacing with: " << combined->ToString();

  // Replace all the smaller AllGathers with elements of the tuple output
  // of the single bigger AllGather.
  for (int64_t i = 0; i < to_combine.size(); ++i) {
    auto replace_with = HloInstruction::CreateGetTupleElement(
        to_combine[i]->shape(), combined, i);
    TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
        to_combine[i], std::move(replace_with)));
  }
  return OkStatus();
}

// The group key encapsulates all of the properties which must match for it to
// be possible to combine the instructions.
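// In order, the fields are: the all_gather_dimension, the domain metadata id,
// whether a channel_id is present, use_global_device_ids, and the replica
// groups.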
using GroupKey =
    std::tuple<int64_t, int64_t, bool, bool, std::vector<std::vector<int64_t>>>;

// Returns a key that will be equal for instructions that might be combined, or
// different if not.
std::optional<GroupKey> CombineKey(const HloInstruction* instruction,
                                   const HloDomainMap& domain_map) {
  if (instruction->opcode() != HloOpcode::kAllGather) {
    return std::nullopt;
  }

  const auto* ag = Cast<HloAllGatherInstruction>(instruction);

  std::vector<std::vector<int64_t>> replica_groups;
  replica_groups.reserve(ag->replica_groups().size());
  for (const ReplicaGroup& replica_group : ag->replica_groups()) {
    replica_groups.push_back(
        std::vector<int64_t>(replica_group.replica_ids().begin(),
                             replica_group.replica_ids().end()));
  }

  return GroupKey{ag->all_gather_dimension(),
                  domain_map.GetDomainMetadataId(ag),
                  ag->channel_id().has_value(), ag->use_global_device_ids(),
                  replica_groups};
}

}  // namespace

AllGatherCombiner::AllGatherCombiner(int64_t combine_threshold_in_bytes,
                                     int64_t combine_threshold_count)
    : combine_threshold_in_bytes_(combine_threshold_in_bytes),
      combine_threshold_count_(combine_threshold_count) {}
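// Illustrative pipeline registration (the threshold values here are
// hypothetical, not defaults taken from any particular backend):
//
//   HloPassPipeline pipeline("collective-combiners");
//   pipeline.AddPass<AllGatherCombiner>(
//       /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
//       /*combine_threshold_count=*/256);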

StatusOr<bool> AllGatherCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(1) << "Running AllGatherCombiner with threshold of "
          << combine_threshold_in_bytes_ << " bytes";

  if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
    VLOG(1) << "Skip AllGatherCombiner because the threshold is zero";
    return false;
  }

  if (hlo_query::ContainsLayoutConstrainedCollective(*module,
                                                     HloOpcode::kAllGather)) {
    VLOG(1) << "Skip AllGatherCombiner because the module contains "
               "all-gather with constrained layouts";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

    auto key_fn = [&domain_map](const HloInstruction* instruction) {
      return CombineKey(instruction, *domain_map);
    };

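    // CombineInstructionsByKey groups instructions whose keys compare equal
    // and hands each group to CombineAllGathers, batching so that a group's
    // accumulated operand size stays under combine_threshold_in_bytes_ and
    // its instruction count stays under combine_threshold_count_.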
    TF_ASSIGN_OR_RETURN(
        bool computation_changed,
        CombineInstructionsByKey<GroupKey>(
            computation, key_fn, &CombineAllGathers,
            combine_threshold_in_bytes_, combine_threshold_count_));
    changed |= computation_changed;
  }

  return changed;
}

}  // namespace xla