/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/reduce_scatter_combiner.h"

#include <algorithm>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/all_reduce_key.h"
#include "tensorflow/compiler/xla/service/collective_combiner_utils.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

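// Bucketing key for compatible reduce-scatter ops: the generic all-reduce
// key (reduction computation, replica groups, channel id, domain, etc.)
// extended with the scatter dimension, since ops can only be fused when
// they scatter along the same dimension.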
using ReduceScatterKey =
    std::tuple<AllReduceKey, /*scatter_dimension*/ int64_t>;

// Combines the elements of to_combine into a single ReduceScatter op. All
// entries in to_combine must be ReduceScatter ops with exactly one operand
// and the same reduction operation.
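//
// Schematically, two compatible ops such as
//   rs0 = f32[4] reduce-scatter(f32[8] p0), dimensions={0}
//   rs1 = f32[8] reduce-scatter(f32[16] p1), dimensions={0}
// are replaced by one tuple-shaped op
//   rs = (f32[4], f32[8]) reduce-scatter(p0, p1), dimensions={0}
// followed by get-tuple-element extractions of the individual results.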
Status CombineReduceScatters(absl::Span<HloInstruction* const> to_combine) {
  if (to_combine.size() < 2) {
    return OkStatus();
  }
  VLOG(1) << "Combined " << to_combine.size() << " reduce-scatter ops";

  HloComputation& computation = *to_combine.back()->parent();
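  // Every op in the group must use the same recognized reduction kind
  // (e.g. add, min, max); the per-op checks below enforce this.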
  HloComputation* reduction = to_combine[0]->to_apply();
  std::optional<ReductionKind> first_reduction_kind =
      MatchReductionComputation(reduction);
  TF_RET_CHECK(first_reduction_kind);

  // Create a single bigger ReduceScatter of the operands of the smaller
  // ReduceScatters.
  std::vector<HloInstruction*> operands;
  std::vector<const Shape*> output_shapes;
  VLOG(1) << "Combining set";
  for (HloInstruction* hlo : to_combine) {
    VLOG(1) << "Set element: " << hlo->ToString();
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kReduceScatter);
    TF_RET_CHECK(hlo->operands().size() == 1);
    std::optional<ReductionKind> reduction_kind =
        MatchReductionComputation(hlo->to_apply());
    TF_RET_CHECK(reduction_kind);
    TF_RET_CHECK(*reduction_kind == *first_reduction_kind);
    TF_RET_CHECK(hlo->shape().IsArray());
    operands.push_back(hlo->operands()[0]);
    output_shapes.push_back(&hlo->shape());
  }

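  // The first op in the group serves as the template for the combined op's
  // collective attributes (replica groups, channel id, use_global_device_ids,
  // and scatter dimension), which the bucketing key keeps consistent across
  // the group.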
  const auto* rs = Cast<HloReduceScatterInstruction>(to_combine.front());

  HloInstruction* combined;
  // ReduceScatter ops with more than one operand produce a tuple.
  TF_RET_CHECK(operands.size() >= 2);
  combined = computation.AddInstruction(HloInstruction::CreateReduceScatter(
      ShapeUtil::MakeTupleShapeWithPtrs(output_shapes), operands, reduction,
      to_combine.front()->replica_groups(),
      /*constrain_layout=*/false, to_combine.front()->channel_id(),
      rs->use_global_device_ids(), rs->scatter_dimension()));

  // We have to propagate the sharding manually because Domain instructions are
  // not guaranteed to preserve it for side-effecting instructions.
  if (to_combine.front()->has_sharding()) {
    combined->set_sharding(to_combine.front()->sharding());
  }
  VLOG(1) << "Replacing with: " << combined->ToString();

  // Replace all the smaller ReduceScatters with elements of the tuple output
  // of the single bigger ReduceScatter.
  for (int64_t i = 0; i < to_combine.size(); ++i) {
    auto replace_with = HloInstruction::CreateGetTupleElement(
        to_combine[i]->shape(), combined, i);
    TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
        to_combine[i], std::move(replace_with)));
  }
  return OkStatus();
}
}  // namespace

ReduceScatterCombiner::ReduceScatterCombiner(int64_t combine_threshold_in_bytes,
                                             int64_t combine_threshold_count)
    : combine_threshold_in_bytes_(combine_threshold_in_bytes),
      combine_threshold_count_(combine_threshold_count) {}

StatusOr<bool> ReduceScatterCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(1) << "Running ReduceScatterCombiner with threshold of "
          << combine_threshold_in_bytes_ << " bytes";

  if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
    VLOG(1) << "Skip ReduceScatterCombiner because a combine threshold is "
               "zero or negative";
    return false;
  }

  if (hlo_query::ContainsLayoutConstrainedCollective(
          *module, HloOpcode::kReduceScatter)) {
    VLOG(1) << "Skip ReduceScatterCombiner because the module contains "
               "reduce-scatter with constrained layouts";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

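    // Map each instruction to the key it is bucketed under, or nullopt if it
    // cannot participate (not a reduce-scatter, or no valid all-reduce key).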
    auto key_fn = [&domain_map](const HloInstruction* instruction)
        -> std::optional<ReduceScatterKey> {
      auto* rs = DynCast<HloReduceScatterInstruction>(instruction);
      std::optional<AllReduceKey> key =
          GetAllReduceKey(instruction, domain_map.get());

      if (!rs || !key) {
        return std::nullopt;
      }
      return ReduceScatterKey{std::move(*key), rs->scatter_dimension()};
    };
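    // Greedily groups instructions that share a key and combines each group,
    // honoring the byte-size and op-count thresholds.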
    TF_ASSIGN_OR_RETURN(
        bool computation_changed,
        CombineInstructionsByKey<ReduceScatterKey>(
            computation, key_fn, &CombineReduceScatters,
            combine_threshold_in_bytes_, combine_threshold_count_));
    changed |= computation_changed;
  }

  return changed;
}

}  // namespace xla