/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_gather_combiner.h"

#include <algorithm>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/collective_combiner_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

// Combines the elements of to_combine into a single AllGather op. All entries
// in to_combine must be AllGather ops with exactly one operand and the same
// all_gather_dimension.
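// For example (schematic HLO; shapes and operand names are illustrative), two
// compatible all-gathers such as
//   ag0 = f32[128] all-gather(p0), dimensions={0}
//   ag1 = f32[256] all-gather(p1), dimensions={0}
// are replaced by a single tuple-shaped all-gather
//   ag  = (f32[128], f32[256]) all-gather(p0, p1), dimensions={0}
// whose results are then extracted with get-tuple-element ops.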
Status CombineAllGathers(absl::Span<HloInstruction* const> to_combine) {
  if (to_combine.size() < 2) {
    return OkStatus();
  }
  VLOG(1) << "Combined " << to_combine.size() << " AllGather ops";

  HloComputation& computation = *to_combine.back()->parent();
  int64_t all_gather_dimension =
      Cast<HloAllGatherInstruction>(to_combine.front())->all_gather_dimension();

  // Create a single bigger AllGather of the operands of the smaller AllGathers.
  std::vector<HloInstruction*> operands;
  std::vector<const Shape*> output_shapes;
  VLOG(1) << "Combining set";
  for (HloInstruction* hlo : to_combine) {
    VLOG(1) << "Set element: " << hlo->ToString();
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllGather);
    TF_RET_CHECK(hlo->operands().size() == 1);
    TF_RET_CHECK(Cast<HloAllGatherInstruction>(hlo)->all_gather_dimension() ==
                 all_gather_dimension);
    TF_RET_CHECK(hlo->shape().IsArray());
    for (HloInstruction* operand : hlo->operands()) {
      operands.push_back(operand);
      output_shapes.push_back(&hlo->shape());
    }
  }

  HloInstruction* combined;
  // AllGather ops with more than one operand produce a tuple.
  TF_RET_CHECK(operands.size() >= 2);
  combined = computation.AddInstruction(HloInstruction::CreateAllGather(
      ShapeUtil::MakeTupleShapeWithPtrs(output_shapes), operands,
      all_gather_dimension, to_combine.front()->replica_groups(),
      /*constrain_layout=*/false, to_combine.front()->channel_id(),
      Cast<HloAllGatherInstruction>(to_combine.front())
          ->use_global_device_ids()));

  // We have to propagate the sharding manually because Domain instructions are
  // not guaranteed to preserve it for side effecting instructions.
  if (to_combine.front()->has_sharding()) {
    combined->set_sharding(to_combine.front()->sharding());
  }
  VLOG(1) << "Replacing with: " << combined->ToString();

  // Replace all the smaller AllGathers with elements of the tuple output
  // of the single bigger AllGather.
  for (int64_t i = 0; i < to_combine.size(); ++i) {
    auto replace_with = HloInstruction::CreateGetTupleElement(
        to_combine[i]->shape(), combined, i);
    TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
        to_combine[i], std::move(replace_with)));
  }
  return OkStatus();
}

// The group key encapsulates all of the properties which must match for it to
// be possible to combine the instructions.
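// For all-gather these are, in order: the all_gather_dimension, the domain
// metadata id, whether a channel_id is present, use_global_device_ids, and
// the replica groups (see CombineKey below).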
using GroupKey =
    std::tuple<int64_t, int64_t, bool, bool, std::vector<std::vector<int64_t>>>;

// Returns a key that will be equal for instructions that might be combined, or
// different if not.
std::optional<GroupKey> CombineKey(const HloInstruction* instruction,
                                   const HloDomainMap& domain_map) {
  if (instruction->opcode() != HloOpcode::kAllGather) {
    return std::nullopt;
  }

  const auto* ag = Cast<HloAllGatherInstruction>(instruction);

  std::vector<std::vector<int64_t>> replica_groups;
  replica_groups.reserve(ag->replica_groups().size());
  for (const ReplicaGroup& replica_group : ag->replica_groups()) {
    replica_groups.push_back(
        std::vector<int64_t>(replica_group.replica_ids().begin(),
                             replica_group.replica_ids().end()));
  }

  return GroupKey{ag->all_gather_dimension(),
                  domain_map.GetDomainMetadataId(ag),
                  ag->channel_id().has_value(), ag->use_global_device_ids(),
                  replica_groups};
}

}  // namespace

AllGatherCombiner::AllGatherCombiner(int64_t combine_threshold_in_bytes,
                                     int64_t combine_threshold_count)
    : combine_threshold_in_bytes_(combine_threshold_in_bytes),
      combine_threshold_count_(combine_threshold_count) {}
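
// Usage sketch (the threshold values below are illustrative, not defaults):
// the combiner runs as a regular HLO pass, e.g. on an HloPassPipeline:
//   HloPassPipeline pipeline("all-gather-combiner");
//   pipeline.AddPass<AllGatherCombiner>(
//       /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
//       /*combine_threshold_count=*/256);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));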

StatusOr<bool> AllGatherCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(1) << "Running AllGatherCombiner with threshold of "
          << combine_threshold_in_bytes_ << " bytes";

  if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
    VLOG(1) << "Skip AllGatherCombiner because the threshold is zero";
    return false;
  }

  if (hlo_query::ContainsLayoutConstrainedCollective(*module,
                                                     HloOpcode::kAllGather)) {
    VLOG(1) << "Skip AllGatherCombiner because the module contains "
               "all-gather with constrained layouts";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

    auto key_fn = [&domain_map](const HloInstruction* instruction) {
      return CombineKey(instruction, *domain_map);
    };

    TF_ASSIGN_OR_RETURN(
        bool computation_changed,
        CombineInstructionsByKey<GroupKey>(
            computation, key_fn, &CombineAllGathers,
            combine_threshold_in_bytes_, combine_threshold_count_));
    changed |= computation_changed;
  }

  return changed;
}

}  // namespace xla