/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"

#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/all_reduce_key.h"
#include "tensorflow/compiler/xla/service/collective_combiner_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

// Combines the elements of to_combine into a single AllReduce op. All
// entries in to_combine must be AllReduce ops with exactly one operand
// and the same reduction operation.
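//
// For example (illustrative HLO; the shapes, names, and %sum reduction are
// arbitrary), two compatible all-reduces
//   %ar0 = f32[32] all-reduce(%a), to_apply=%sum
//   %ar1 = f32[64] all-reduce(%b), to_apply=%sum
// become a single tuple-shaped all-reduce whose results are extracted with
// get-tuple-element:
//   %ar   = (f32[32], f32[64]) all-reduce(%a, %b), to_apply=%sum
//   %gte0 = f32[32] get-tuple-element(%ar), index=0
//   %gte1 = f32[64] get-tuple-element(%ar), index=1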
Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
  if (to_combine.size() < 2) {
    return OkStatus();
  }
  VLOG(1) << "Combining " << to_combine.size() << " all-reduce ops";

  HloComputation& computation = *to_combine.back()->parent();
  HloComputation* reduction = to_combine[0]->to_apply();
  const HloOpcode type = reduction->root_instruction()->opcode();

  // Create a single bigger AllReduce of the operands of the smaller
  // AllReduces.
  std::vector<HloInstruction*> operands;
  std::vector<const Shape*> operand_shapes;
  VLOG(1) << "Combining set";
  for (HloInstruction* hlo : to_combine) {
    VLOG(1) << "Set element: " << hlo->ToString();
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kAllReduce);
    TF_RET_CHECK(hlo->operands().size() == 1);
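    // The reduction computations need not be pointer-identical: a computation
    // that is structurally equivalent (two parameters feeding a single root
    // instruction of the same opcode) is also accepted.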
    TF_RET_CHECK(hlo->to_apply() == reduction ||
                 (hlo->to_apply()->instruction_count() == 3 &&
                  hlo->to_apply()->num_parameters() == 2 &&
                  hlo->to_apply()->root_instruction()->opcode() == type));
    TF_RET_CHECK(hlo->shape().IsArray());
    for (HloInstruction* operand : hlo->operands()) {
      operands.push_back(operand);
      operand_shapes.push_back(&operand->shape());
    }
  }

  HloInstruction* combined;
  // AllReduce ops with more than one operand produce a tuple.
  TF_RET_CHECK(operands.size() >= 2);
  combined = computation.AddInstruction(HloInstruction::CreateAllReduce(
      ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes), operands, reduction,
      to_combine.front()->replica_groups(),
      /*constrain_layout=*/false, to_combine.front()->channel_id(),
      Cast<HloAllReduceInstruction>(to_combine.front())
          ->use_global_device_ids()));

  // We have to propagate the sharding manually because Domain instructions are
  // not guaranteed to preserve it for side-effecting instructions.
  if (to_combine.front()->has_sharding()) {
    combined->set_sharding(to_combine.front()->sharding());
  }
  VLOG(1) << "Replacing with: " << combined->ToString();

  // Replace all the smaller AllReduces with elements of the tuple output
  // of the single bigger AllReduce.
  for (int64_t i = 0; i < to_combine.size(); ++i) {
    auto replace_with = HloInstruction::CreateGetTupleElement(
        to_combine[i]->shape(), combined, i);
    TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
        to_combine[i], std::move(replace_with)));
  }
  return OkStatus();
}
}  // namespace

AllReduceCombiner::AllReduceCombiner(int64_t combine_threshold_in_bytes,
                                     int64_t combine_threshold_count)
    : combine_threshold_in_bytes_(combine_threshold_in_bytes),
      combine_threshold_count_(combine_threshold_count) {}
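
// Example of registering the pass on a pipeline (a sketch; the threshold
// values below are illustrative, not tuned defaults):
//   HloPassPipeline pipeline("all-reduce-combiner");
//   pipeline.AddPass<AllReduceCombiner>(
//       /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
//       /*combine_threshold_count=*/256);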

StatusOr<bool> AllReduceCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(1) << "Running AllReduceCombiner with threshold of "
          << combine_threshold_in_bytes_ << " bytes";

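  // Non-positive thresholds are treated as "combining disabled".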
  if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
    VLOG(1) << "Skip AllReduceCombiner because the threshold is zero or "
               "negative";
    return false;
  }

  if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) {
    VLOG(1) << "Skip AllReduceCombiner because the module contains "
               "all-reduces with constrained layouts";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

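    // Bucket instructions by AllReduceKey so that only compatible all-reduces
    // (matching replica groups, channel/domain properties, and so on) are
    // considered for combining; instructions without a key are left untouched.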
    auto key_fn =
        [&domain_map](
            const HloInstruction* instruction) -> std::optional<AllReduceKey> {
      if (instruction->opcode() != HloOpcode::kAllReduce) {
        return std::nullopt;
      }
      return GetAllReduceKey(instruction, domain_map.get());
    };

    TF_ASSIGN_OR_RETURN(
        bool computation_changed,
        CombineInstructionsByKey<AllReduceKey>(
            computation, key_fn, &CombineAllReduces,
            combine_threshold_in_bytes_, combine_threshold_count_));
    changed |= computation_changed;
  }

  return changed;
}

}  // namespace xla