/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"

#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/all_reduce_key.h"
#include "tensorflow/compiler/xla/service/collective_combiner_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

// Combines the elements of to_combine into a single AllReduce op. All
// entries in to_combine must be AllReduce ops with exactly one operand
// and the same reduction operation.
Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
  if (to_combine.size() < 2) {
    return OkStatus();
  }
  VLOG(1) << "Combined " << to_combine.size() << " CRS ops";

  HloComputation& computation = *to_combine.back()->parent();
  HloComputation* reduction = to_combine[0]->to_apply();
  const HloOpcode type = reduction->root_instruction()->opcode();

  // Create a single bigger AllReduce of the operands of the smaller
  // AllReduces.
  std::vector<HloInstruction*> operands;
  std::vector<const Shape*> operand_shapes;
  VLOG(1) << "Combining set";
  for (HloInstruction* hlo : to_combine) {
    VLOG(1) << "Set element: " << hlo->ToString();
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllReduce);
    TF_RET_CHECK(hlo->operands().size() == 1);
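    // The reduction must match the first all-reduce's to_apply computation,
    // either as the identical computation or as a structurally equivalent
    // one: two parameters combined by the same root opcode.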
    TF_RET_CHECK(hlo->to_apply() == reduction ||
                 (hlo->to_apply()->instruction_count() == 3 &&
                  hlo->to_apply()->num_parameters() == 2 &&
                  hlo->to_apply()->root_instruction()->opcode() == type));
    TF_RET_CHECK(hlo->shape().IsArray());
    for (HloInstruction* operand : hlo->operands()) {
      operands.push_back(operand);
      operand_shapes.push_back(&operand->shape());
    }
  }

  // AllReduce ops with more than one operand produce a tuple.
  TF_RET_CHECK(operands.size() >= 2);
  HloInstruction* combined =
      computation.AddInstruction(HloInstruction::CreateAllReduce(
          ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes), operands,
          reduction, to_combine.front()->replica_groups(),
          /*constrain_layout=*/false, to_combine.front()->channel_id(),
          Cast<HloAllReduceInstruction>(to_combine.front())
              ->use_global_device_ids()));
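  // Replica groups, channel id, and use_global_device_ids come from the first
  // op in the set; the caller groups all-reduces by a key that is expected to
  // hold these fields constant within a combined set.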

  // We have to propagate the sharding manually because Domain instructions
  // are not guaranteed to preserve it for side-effecting instructions.
  if (to_combine.front()->has_sharding()) {
    combined->set_sharding(to_combine.front()->sharding());
  }
  VLOG(1) << "Replacing with: " << combined->ToString();

  // Replace all the smaller AllReduces with elements of the tuple output
  // of the single bigger AllReduce.
  for (int64_t i = 0; i < to_combine.size(); ++i) {
    auto replace_with = HloInstruction::CreateGetTupleElement(
        to_combine[i]->shape(), combined, i);
    TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
        to_combine[i], std::move(replace_with)));
  }
  return OkStatus();
}
}  // namespace

AllReduceCombiner::AllReduceCombiner(int64_t combine_threshold_in_bytes,
                                     int64_t combine_threshold_count)
    : combine_threshold_in_bytes_(combine_threshold_in_bytes),
      combine_threshold_count_(combine_threshold_count) {}

StatusOr<bool> AllReduceCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(1) << "Running AllReduceCombiner with threshold of "
          << combine_threshold_in_bytes_ << " bytes";

  if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
    VLOG(1) << "Skip AllReduceCombiner because the threshold is not positive";
    return false;
  }

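  // Conservatively skip the whole module if any all-reduce has a constrained
  // layout: merging such an op into a tuple-shaped all-reduce could violate
  // the layout constraint.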
  if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
    VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce "
               "with constrained layouts";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
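    // The domain map lets the key function below separate all-reduces that
    // live in different HLO domains; ops are never combined across domain
    // boundaries.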
    TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

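    // Map each all-reduce to a grouping key; all other instructions map to
    // nullopt and are left untouched.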
    auto key_fn =
        [&domain_map](
            const HloInstruction* instruction) -> std::optional<AllReduceKey> {
      if (instruction->opcode() != HloOpcode::kAllReduce) {
        return std::nullopt;
      }
      return GetAllReduceKey(instruction, domain_map.get());
    };

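    // Combine same-key all-reduces within this computation, subject to the
    // byte-size and op-count thresholds.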
    TF_ASSIGN_OR_RETURN(
        bool computation_changed,
        CombineInstructionsByKey<AllReduceKey>(
            computation, key_fn, &CombineAllReduces,
            combine_threshold_in_bytes_, combine_threshold_count_));
    changed |= computation_changed;
  }

  return changed;
}

}  // namespace xla