/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/collective_ops_utils.h"

#include <optional>

#include "tensorflow/compiler/xla/service/global_device_id.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Match the instruction to a reduction kind. We can represent and/or of pred
// as min/max. This works because pred is stored as an 8-bit int of value 0
// or 1.
std::optional<ReductionKind> MatchReductionInstruction(
    const HloInstruction* hlo) {
  PrimitiveType type = hlo->shape().element_type();
  switch (hlo->opcode()) {
    case HloOpcode::kAdd:
      return ReductionKind::SUM;
    case HloOpcode::kMultiply:
      return ReductionKind::PRODUCT;
    case HloOpcode::kMinimum:
      return ReductionKind::MIN;
    case HloOpcode::kMaximum:
      return ReductionKind::MAX;
    case HloOpcode::kAnd:
      return type == PRED ? std::optional<ReductionKind>(ReductionKind::MIN)
                          : std::nullopt;
    case HloOpcode::kOr:
      return type == PRED ? std::optional<ReductionKind>(ReductionKind::MAX)
                          : std::nullopt;
    default:
      return std::nullopt;
  }
}
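
// Worked example (illustrative): for
//   %and = pred[] and(pred[] %x, pred[] %y)
// this returns ReductionKind::MIN. With pred stored as 0 or 1,
// AND(x, y) == MIN(x, y) and OR(x, y) == MAX(x, y), so the pred logical ops
// reduce to the integer min/max kinds matched above.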

std::optional<ReductionKind> MatchReductionComputation(
    const HloComputation* computation) {
  namespace m = match;
  const HloInstruction* root = computation->root_instruction();
  auto kind = MatchReductionInstruction(root);
  if (kind && !Match(root, m::Op()
                               .WithBinaryOperandsAnyOrder(m::Parameter(0),
                                                           m::Parameter(1))
                               .WithShape(m::Shape().IsEffectiveScalar()))) {
    kind = std::nullopt;
  }
  return kind;
}
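
// Worked example (illustrative): a reduction computation whose root is
//   %add = f32[] add(f32[] parameter(0), f32[] parameter(1))
// matches ReductionKind::SUM. A root that combines the parameters through
// intermediate ops, or that produces a non-effective-scalar shape, fails
// the pattern match and yields std::nullopt.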

StatusOr<std::vector<int>> GetParticipatingIDs(
    int current_id, std::optional<int> total_participant_count,
    absl::Span<const ReplicaGroup> groups) {
  // Empty replica_groups() means that all replicas participate.
  if (groups.empty()) {
    TF_RET_CHECK(total_participant_count.has_value());
    std::vector<int> all_participants(*total_participant_count);
    absl::c_iota(all_participants, 0);
    return all_participants;
  }

  // Figure out the other replicas that go together with this one.
  std::optional<ReplicaGroup> group;
  for (const ReplicaGroup& g : groups) {
    if (absl::c_linear_search(g.replica_ids(), current_id)) {
      TF_RET_CHECK(!group.has_value())
          << "ID " << current_id << " appears twice in replica groups";
      group = g;
    }
  }
  TF_RET_CHECK(group.has_value())
      << "ID " << current_id << " doesn't appear in replica groups";
  return std::vector<int>(group->replica_ids().begin(),
                          group->replica_ids().end());
}
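
// Worked example (illustrative): with groups = {{0, 2}, {1, 3}} and
// current_id = 2, this returns {0, 2}; with empty groups and
// total_participant_count = 4, it returns {0, 1, 2, 3}.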

// Returns the group formation mode implied by (a) whether the operation has
// channel_id and (b) whether it has use_global_device_ids and, if so, its
// value.
StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
    bool has_channel_id, std::optional<bool> use_global_device_ids) {
  if (!has_channel_id) {
    if (!use_global_device_ids.has_value() || !*use_global_device_ids) {
      return CollectiveOpGroupMode::kCrossReplica;
    } else {
      return InvalidArgument(
          "Invalid combination of has_channel_id and use_global_device_ids");
    }
  } else {
    if (!use_global_device_ids.has_value()) {
      return CollectiveOpGroupMode::kCrossPartition;
    } else if (!*use_global_device_ids) {
      return CollectiveOpGroupMode::kCrossReplicaAndPartition;
    } else {
      return CollectiveOpGroupMode::kFlattenedID;
    }
  }
}
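
// Summary of the mapping implemented above:
//   channel_id | use_global_device_ids | group mode
//   absent     | absent or false       | kCrossReplica
//   absent     | true                  | error (invalid combination)
//   present    | absent                | kCrossPartition
//   present    | false                 | kCrossReplicaAndPartition
//   present    | true                  | kFlattenedID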

absl::string_view CollectiveOpGroupModeToString(
    CollectiveOpGroupMode group_mode) {
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica:
      return "kCrossReplica";
    case CollectiveOpGroupMode::kCrossPartition:
      return "kCrossPartition";
    case CollectiveOpGroupMode::kCrossReplicaAndPartition:
      return "kCrossReplicaAndPartition";
    case CollectiveOpGroupMode::kFlattenedID:
      return "kFlattenedID";
  }
}

StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  std::vector<ReplicaGroup> participating_replica_groups =
      SpanToVector(replica_groups);

  // If replica groups are empty, assume a group with all replicas.
  if (replica_groups.empty()) {
    if (group_mode == CollectiveOpGroupMode::kFlattenedID) {
      // Replica groups contain flattened IDs and cannot be empty.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";
    }

    int total_participant_count;
    if (group_mode == CollectiveOpGroupMode::kCrossPartition) {
      // Replica group IDs are partition IDs.
      total_participant_count = partition_count;
    } else {
      // Replica group IDs are replica IDs.
      total_participant_count = replica_count;
    }

    ReplicaGroup replica_group = ReplicaGroup();
    for (int id = 0; id < total_participant_count; id++) {
      replica_group.add_replica_ids(id);
    }
    participating_replica_groups.push_back(replica_group);
  }

  std::vector<std::vector<GlobalDeviceId>> groups;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains replica IDs; for each partition, the
        // participants are the devices for those replica IDs in that
        // partition.
        for (int partition_id = 0; partition_id < partition_count;
             partition_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int replica_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossPartition: {
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains partition IDs; for each replica, the
        // participants are the devices for those partition IDs in that
        // replica.
        for (int replica_id = 0; replica_id < replica_count; replica_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int partition_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size() *
                             partition_count);

        // replica_group contains replica IDs; the participants are the
        // devices for those replica IDs across all partitions.
        for (int replica_id : replica_group.replica_ids()) {
          for (int partition_id = 0; partition_id < partition_count;
               partition_id++) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
        }
        groups.push_back(participants);
      }
      return groups;
    }
    case CollectiveOpGroupMode::kFlattenedID: {
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size());

        for (int flattened_id : replica_group.replica_ids()) {
          // Map from flattened ID back to (replica_id, partition_id).
          int replica_id = flattened_id / partition_count;
          int partition_id = flattened_id % partition_count;
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
        groups.push_back(participants);
      }
      return groups;
    }
  }
}
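
// Worked example (illustrative): with 2 replicas x 2 partitions and
// kFlattenedID groups {{0, 3}, {1, 2}}, flattened ID f maps to
// (replica f / 2, partition f % 2), so the device groups are
//   {device(0, 0), device(1, 1)} and {device(0, 1), device(1, 0)}.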

StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id,
                      device_assignment.LogicalIdForDevice(device_id));
  int current_replica_id = logical_id.replica_id;
  int current_partition_id = logical_id.computation_id;

  std::vector<GlobalDeviceId> participants;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      // This is a cross-replica operation. The replica groups contain replica
      // IDs; use the current replica ID to find the set of participating
      // replicas. If replica groups are empty, assume a group with all
      // replicas.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id,
                                              replica_count, replica_groups));

      // The set of participating devices is the participating replicas within
      // the current partition.
      participants.reserve(participating_replicas.size());
      for (int replica_id : participating_replicas) {
        participants.emplace_back(
            device_assignment(replica_id, current_partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossPartition: {
      // The replica groups contain partition IDs; the group contains the
      // participating partitions for the current replica.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_partitions,
                          GetParticipatingIDs(current_partition_id,
                                              partition_count,
                                              replica_groups));
      participants.reserve(participating_partitions.size());
      for (int partition_id : participating_partitions) {
        participants.emplace_back(
            device_assignment(current_replica_id, partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      // The replica groups contain replica IDs; the group contains those
      // replicas across all partitions.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id,
                                              replica_count, replica_groups));
      participants.reserve(participating_replicas.size() * partition_count);
      for (int replica_id : participating_replicas) {
        for (int partition_id = 0; partition_id < partition_count;
             ++partition_id) {
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
      }
      return participants;
    }

    case CollectiveOpGroupMode::kFlattenedID: {
      // The replica groups contain flattened IDs and cannot be empty.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";

      int current_flattened_id =
          current_replica_id * partition_count + current_partition_id;

      // Find participants based on the flattened ID. replica_groups cannot be
      // empty, so there is no need to pass in total_participant_count.
      TF_ASSIGN_OR_RETURN(
          std::vector<int> participating_flattened_ids,
          GetParticipatingIDs(current_flattened_id,
                              /*total_participant_count=*/std::nullopt,
                              replica_groups));

      participants.reserve(participating_flattened_ids.size());
      for (int flattened_id : participating_flattened_ids) {
        // Map from flattened ID back to (replica_id, partition_id).
        int replica_id = flattened_id / partition_count;
        int partition_id = flattened_id % partition_count;
        participants.emplace_back(device_assignment(replica_id, partition_id));
      }
      return participants;
    }
  }
}
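
// Worked example (illustrative): with 2 replicas x 2 partitions, empty
// replica groups, and kCrossReplicaAndPartition mode, every device
// participates, so for any device_id this returns
//   {device(0, 0), device(0, 1), device(1, 0), device(1, 1)}.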

bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
                             absl::Span<const ReplicaGroup> second) {
  if (first.size() != second[0].replica_ids_size()) {
    return false;
  }
  if (first[0].replica_ids_size() != second.size()) {
    return false;
  }
  // The groups are orthogonal iff, viewed as matrices, one is the transpose
  // of the other: first[i][j] == second[j][i] for all i, j.
  for (int64_t i = 0; i < first.size(); ++i) {
    for (int64_t j = 0; j < first[i].replica_ids_size(); ++j) {
      if (first[i].replica_ids(j) != second[j].replica_ids(i)) {
        return false;
      }
    }
  }
  return true;
}
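
// Worked example (illustrative): first = {{0, 1}, {2, 3}} and
// second = {{0, 2}, {1, 3}} are orthogonal; first[i][j] == second[j][i]
// holds for all i, j.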

bool IsCollective(const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
      return true;
    case HloOpcode::kFusion:
      // A custom fusion is collective if any instruction it fuses is.
      if (instruction->IsCustomFusion()) {
        for (const auto* inner_inst : instruction->fused_instructions()) {
          if (IsCollective(inner_inst)) {
            return true;
          }
        }
      }
      return false;
    default:
      return false;
  }
}

}  // end namespace xla