/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/reduce_scatter_combiner.h"

#include <algorithm>
#include <list>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/all_reduce_key.h"
#include "tensorflow/compiler/xla/service/collective_combiner_utils.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

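// Key for bucketing combinable reduce-scatter ops: the generic AllReduceKey
// extended with the scatter dimension, since only ops that scatter along the
// same dimension can be merged into a single instruction.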
using ReduceScatterKey =
    std::tuple<AllReduceKey, /*scatter_dimension*/ int64_t>;

// Combines the elements of to_combine into a single ReduceScatter op. All
// entries in to_combine must be ReduceScatter ops with exactly one operand
// and the same reduction operation.
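//
// Illustrative sketch in simplified HLO (shapes, names, and the add
// computation below are made-up examples, not taken from this file):
//
//   rs0 = f32[4] reduce-scatter(f32[8] a), dimensions={0}, to_apply=add
//   rs1 = f32[2] reduce-scatter(f32[4] b), dimensions={0}, to_apply=add
//
// is rewritten to
//
//   rs  = (f32[4], f32[2]) reduce-scatter(a, b), dimensions={0}, to_apply=add
//   rs0 = f32[4] get-tuple-element(rs), index=0
//   rs1 = f32[2] get-tuple-element(rs), index=1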
Status CombineReduceScatters(absl::Span<HloInstruction* const> to_combine) {
  if (to_combine.size() < 2) {
    return OkStatus();
  }
  VLOG(1) << "Combining " << to_combine.size() << " reduce-scatter ops";

  HloComputation& computation = *to_combine.back()->parent();
  HloComputation* reduction = to_combine[0]->to_apply();
  std::optional<ReductionKind> first_reduction_kind =
      MatchReductionComputation(reduction);
  TF_RET_CHECK(first_reduction_kind);

  // Create a single bigger ReduceScatter of the operands of the smaller
  // ReduceScatters.
  std::vector<HloInstruction*> operands;
  std::vector<const Shape*> output_shapes;
  VLOG(1) << "Combining set";
  for (HloInstruction* hlo : to_combine) {
    VLOG(1) << "Set element: " << hlo->ToString();
    TF_RET_CHECK(hlo->opcode() == HloOpcode::kReduceScatter);
    TF_RET_CHECK(hlo->operands().size() == 1);
    std::optional<ReductionKind> reduction_kind =
        MatchReductionComputation(hlo->to_apply());
    TF_RET_CHECK(reduction_kind);
    TF_RET_CHECK(*reduction_kind == *first_reduction_kind);
    TF_RET_CHECK(hlo->shape().IsArray());
    operands.push_back(hlo->operands()[0]);
    output_shapes.push_back(&hlo->shape());
  }

  const auto* rs = Cast<HloReduceScatterInstruction>(to_combine.front());

  HloInstruction* combined;
  // Reduce-scatter ops with more than one operand produce a tuple.
  TF_RET_CHECK(operands.size() >= 2);
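  // Collective metadata (replica groups, channel id, use_global_device_ids)
  // is read off the first op; the grouping key used by the caller is expected
  // to guarantee that these fields match for every op in to_combine.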
  combined = computation.AddInstruction(HloInstruction::CreateReduceScatter(
      ShapeUtil::MakeTupleShapeWithPtrs(output_shapes), operands, reduction,
      to_combine.front()->replica_groups(),
      /*constrain_layout=*/false, to_combine.front()->channel_id(),
      rs->use_global_device_ids(), rs->scatter_dimension()));

  // We have to propagate the sharding manually because Domain instructions are
  // not guaranteed to preserve it for side-effecting instructions.
  if (to_combine.front()->has_sharding()) {
    combined->set_sharding(to_combine.front()->sharding());
  }
  VLOG(1) << "Replacing with: " << combined->ToString();

  // Replace all the smaller ReduceScatters with elements of the tuple output
  // of the single bigger ReduceScatter.
  for (int64_t i = 0; i < to_combine.size(); ++i) {
    auto replace_with = HloInstruction::CreateGetTupleElement(
        to_combine[i]->shape(), combined, i);
    TF_RETURN_IF_ERROR(computation.ReplaceWithNewInstruction(
        to_combine[i], std::move(replace_with)));
  }
  return OkStatus();
}
}  // namespace

ReduceScatterCombiner::ReduceScatterCombiner(int64_t combine_threshold_in_bytes,
                                             int64_t combine_threshold_count)
    : combine_threshold_in_bytes_(combine_threshold_in_bytes),
      combine_threshold_count_(combine_threshold_count) {}
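
// Sketch of how this pass might be registered in a pipeline. The pipeline
// name and threshold values below are illustrative assumptions, not values
// prescribed by this file:
//
//   HloPassPipeline pipeline("collective-combiners");
//   pipeline.AddPass<ReduceScatterCombiner>(
//       /*combine_threshold_in_bytes=*/30 * 1024 * 1024,
//       /*combine_threshold_count=*/256);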

StatusOr<bool> ReduceScatterCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(1) << "Running ReduceScatterCombiner with threshold of "
          << combine_threshold_in_bytes_ << " bytes";

  if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) {
    VLOG(1) << "Skip ReduceScatterCombiner because the threshold is not "
               "positive";
    return false;
  }

  if (hlo_query::ContainsLayoutConstrainedCollective(
          *module, HloOpcode::kReduceScatter)) {
    VLOG(1) << "Skip ReduceScatterCombiner because the module contains "
               "reduce-scatter with constrained layouts";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));

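    // Bucketing key: instructions that are not reduce-scatters, or that have
    // no valid AllReduceKey, return std::nullopt and are left uncombined.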
    auto key_fn = [&domain_map](const HloInstruction* instruction)
        -> std::optional<ReduceScatterKey> {
      auto* rs = DynCast<HloReduceScatterInstruction>(instruction);
      std::optional<AllReduceKey> key =
          GetAllReduceKey(instruction, domain_map.get());

      if (!rs || !key) {
        return std::nullopt;
      }
      return ReduceScatterKey{std::move(*key), rs->scatter_dimension()};
    };

    TF_ASSIGN_OR_RETURN(
        bool computation_changed,
        CombineInstructionsByKey<ReduceScatterKey>(
            computation, key_fn, &CombineReduceScatters,
            combine_threshold_in_bytes_, combine_threshold_count_));
    changed |= computation_changed;
  }

  return changed;
}

}  // namespace xla