xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/collective_ops_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
17 
18 #include <optional>
19 
20 #include "tensorflow/compiler/xla/service/global_device_id.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
24 #include "tensorflow/compiler/xla/util.h"
25 
26 namespace xla {
27 
28 // Match the instruction to a reduction kind. We can represent and/or of pred as
29 // min/max. This works because pred is stored as an 8-bit int of value 0 or 1.
MatchReductionInstruction(const HloInstruction * hlo)30 std::optional<ReductionKind> MatchReductionInstruction(
31     const HloInstruction* hlo) {
32   PrimitiveType type = hlo->shape().element_type();
33   switch (hlo->opcode()) {
34     case HloOpcode::kAdd:
35       return ReductionKind::SUM;
36     case HloOpcode::kMultiply:
37       return ReductionKind::PRODUCT;
38     case HloOpcode::kMinimum:
39       return ReductionKind::MIN;
40     case HloOpcode::kMaximum:
41       return ReductionKind::MAX;
42     case HloOpcode::kAnd:
43       return type == PRED ? std::optional<ReductionKind>(ReductionKind::MIN)
44                           : std::nullopt;
45     case HloOpcode::kOr:
46       return type == PRED ? std::optional<ReductionKind>(ReductionKind::MAX)
47                           : std::nullopt;
48     default:
49       return std::nullopt;
50   }
51 }
52 
MatchReductionComputation(const HloComputation * computation)53 std::optional<ReductionKind> MatchReductionComputation(
54     const HloComputation* computation) {
55   namespace m = match;
56   const HloInstruction* root = computation->root_instruction();
57   auto kind = MatchReductionInstruction(root);
58   if (kind && !Match(root, m::Op()
59                                .WithBinaryOperandsAnyOrder(m::Parameter(0),
60                                                            m::Parameter(1))
61                                .WithShape(m::Shape().IsEffectiveScalar()))) {
62     kind = std::nullopt;
63   }
64   return kind;
65 }
66 
GetParticipatingIDs(int current_id,std::optional<int> total_participant_count,absl::Span<const ReplicaGroup> groups)67 StatusOr<std::vector<int>> GetParticipatingIDs(
68     int current_id, std::optional<int> total_participant_count,
69     absl::Span<const ReplicaGroup> groups) {
70   // Empty replica_groups() means that all replicas participate.
71   if (groups.empty()) {
72     TF_RET_CHECK(total_participant_count.has_value());
73     std::vector<int> all_participants(*total_participant_count);
74     absl::c_iota(all_participants, 0);
75     return all_participants;
76   }
77 
78   // Figure out the other replicas that go together with this one.
79   std::optional<ReplicaGroup> group;
80   for (const ReplicaGroup& g : groups) {
81     if (absl::c_linear_search(g.replica_ids(), current_id)) {
82       TF_RET_CHECK(!group.has_value())
83           << "ID " << current_id << " appears twice in replica groups";
84       group = g;
85     }
86   }
87   TF_RET_CHECK(group.has_value())
88       << "ID " << current_id << " doesn't appear in replica groups";
89   return std::vector<int>(group->replica_ids().begin(),
90                           group->replica_ids().end());
91 }
92 
93 // Returns the group formation mode implied by (a) whether the operation has
94 // channel_id and (b) if it has use_global_device_ids and if yes, its value.
GetCollectiveOpGroupMode(bool has_channel_id,std::optional<bool> use_global_device_ids)95 StatusOr<CollectiveOpGroupMode> GetCollectiveOpGroupMode(
96     bool has_channel_id, std::optional<bool> use_global_device_ids) {
97   if (!has_channel_id) {
98     if (!use_global_device_ids.has_value() || !*use_global_device_ids) {
99       return CollectiveOpGroupMode::kCrossReplica;
100     } else {
101       return InvalidArgument(
102           "Invalid combination of has_channel_id and use_global_device_ids");
103     }
104   } else {
105     if (!use_global_device_ids.has_value()) {
106       return CollectiveOpGroupMode::kCrossPartition;
107     } else if (!*use_global_device_ids) {
108       return CollectiveOpGroupMode::kCrossReplicaAndPartition;
109     } else {
110       return CollectiveOpGroupMode::kFlattenedID;
111     }
112   }
113 }
114 
CollectiveOpGroupModeToString(CollectiveOpGroupMode group_mode)115 absl::string_view CollectiveOpGroupModeToString(
116     CollectiveOpGroupMode group_mode) {
117   switch (group_mode) {
118     case CollectiveOpGroupMode::kCrossReplica:
119       return "kCrossReplica";
120     case CollectiveOpGroupMode::kCrossPartition:
121       return "kCrossPartition";
122     case CollectiveOpGroupMode::kCrossReplicaAndPartition:
123       return "kCrossReplicaAndPartition";
124     case CollectiveOpGroupMode::kFlattenedID:
125       return "kFlattenedID";
126   }
127 }
128 
// Enumerates the groups of global device ids that participate together in a
// collective, given the replica groups and the group formation mode. Returns
// one vector of GlobalDeviceIds per participating group; the meaning of the
// ids inside `replica_groups` (replica, partition, or flattened id) depends
// on `group_mode`.
StatusOr<std::vector<std::vector<GlobalDeviceId>>>
GetParticipatingDevicesGroups(const DeviceAssignment& device_assignment,
                              absl::Span<const ReplicaGroup> replica_groups,
                              CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  std::vector<ReplicaGroup> participating_replica_groups =
      SpanToVector(replica_groups);

  // If replica groups are empty, assume a group with all replicas.
  if (replica_groups.empty()) {
    if (group_mode == CollectiveOpGroupMode::kFlattenedID) {
      // replica groups contain flattened-ids and cannot be empty.
      // NOTE: inside this branch replica_groups is known to be empty, so
      // this check always fails and unconditionally returns an error.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";
    }

    // The size of the implicit all-participants group depends on what the
    // ids denote under this mode.
    int total_participant_count;
    if (group_mode == CollectiveOpGroupMode::kCrossPartition) {
      // replica group are partition ids.
      total_participant_count = partition_count;
    } else {
      // replica group are replica ids.
      total_participant_count = replica_count;
    }

    // Synthesize the single group {0, 1, ..., total_participant_count - 1}.
    ReplicaGroup replica_group = ReplicaGroup();
    for (int id = 0; id < total_participant_count; id++) {
      replica_group.add_replica_ids(id);
    }
    participating_replica_groups.push_back(replica_group);
  }

  std::vector<std::vector<GlobalDeviceId>> groups;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      // One device group per (replica group, partition) pair.
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains replica id, participants contains all
        // replica_group's replica_ids for the current partition.
        for (int partition_id = 0; partition_id < partition_count;
             partition_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int replica_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossPartition: {
      // One device group per (replica group, replica) pair.
      for (const auto& replica_group : participating_replica_groups) {
        // replica_group contains partition id, participants contains all
        // replica_group's partition_ids for the current replica_id.
        for (int replica_id = 0; replica_id < replica_count; replica_id++) {
          std::vector<GlobalDeviceId> participants;
          participants.reserve(replica_group.replica_ids().size());

          for (int partition_id : replica_group.replica_ids()) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
          groups.push_back(participants);
        }
      }
      return groups;
    }
    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      // One device group per replica group, spanning all partitions.
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size() *
                             partition_count);

        // replica_group contains replica id, participants contains all
        // replica_group's replica_ids for all partitions.
        for (int replica_id : replica_group.replica_ids()) {
          for (int partition_id = 0; partition_id < partition_count;
               partition_id++) {
            participants.emplace_back(
                device_assignment(replica_id, partition_id));
          }
        }
        groups.push_back(participants);
      }
      return groups;
    }
    case CollectiveOpGroupMode::kFlattenedID: {
      // One device group per replica group; each id encodes both replica
      // and partition as replica_id * partition_count + partition_id.
      for (const auto& replica_group : participating_replica_groups) {
        std::vector<GlobalDeviceId> participants;
        participants.reserve(replica_group.replica_ids().size());

        for (int flattened_id : replica_group.replica_ids()) {
          // Map from flattened id back to replica_id, partition_id.
          int replica_id = flattened_id / partition_count;
          int partition_id = flattened_id % partition_count;
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
        groups.push_back(participants);
      }
      return groups;
    }
  }
}
237 
// Returns the global device ids that participate in the same collective
// group as `device_id`, given the replica groups and the group formation
// mode. The interpretation of the ids in `replica_groups` (replica,
// partition, or flattened id) depends on `group_mode`.
StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(
    GlobalDeviceId device_id, const DeviceAssignment& device_assignment,
    absl::Span<const ReplicaGroup> replica_groups,
    CollectiveOpGroupMode group_mode) {
  int replica_count = device_assignment.replica_count();
  int partition_count = device_assignment.computation_count();

  // Map the global device id back to its (replica, partition) coordinates.
  TF_ASSIGN_OR_RETURN(const DeviceAssignment::LogicalID logical_id,
                      device_assignment.LogicalIdForDevice(device_id));
  int current_replica_id = logical_id.replica_id;
  int current_partition_id = logical_id.computation_id;

  std::vector<GlobalDeviceId> participants;
  switch (group_mode) {
    case CollectiveOpGroupMode::kCrossReplica: {
      // This is a cross replica operation. replica group contains replica id.
      // use current replica id to find the set of participating replicas. If
      // replica groups are empty, assume a group with all replicas.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id, replica_count,
                                              replica_groups));

      // The set of participating devices is the replicas from the current
      // partition.
      participants.reserve(participating_replicas.size());
      for (int replica_id : participating_replicas) {
        participants.emplace_back(
            device_assignment(replica_id, current_partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossPartition: {
      // replica_groups contain partition_id, group contains all partitions for
      // the current replica.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_partitions,
                          GetParticipatingIDs(current_partition_id,
                                              partition_count, replica_groups));
      participants.reserve(participating_partitions.size());
      for (int partition_id : participating_partitions) {
        participants.emplace_back(
            device_assignment(current_replica_id, partition_id));
      }
      return participants;
    }

    case CollectiveOpGroupMode::kCrossReplicaAndPartition: {
      // replica_groups contain replica_ids. Group contains replicas for all
      // partitions.
      TF_ASSIGN_OR_RETURN(std::vector<int> participating_replicas,
                          GetParticipatingIDs(current_replica_id, replica_count,
                                              replica_groups));
      participants.reserve(participating_replicas.size() * partition_count);
      for (int replica_id : participating_replicas) {
        for (int partition_id = 0; partition_id < partition_count;
             ++partition_id) {
          participants.emplace_back(
              device_assignment(replica_id, partition_id));
        }
      }
      return participants;
    }

    case CollectiveOpGroupMode::kFlattenedID: {
      // replica groups contain flattened-ids and cannot be empty.
      TF_RET_CHECK(!replica_groups.empty())
          << "replica groups cannot be empty for kFlattenedID mode";

      // Flattened id encodes both coordinates:
      // replica_id * partition_count + partition_id.
      int current_flattened_id =
          current_replica_id * partition_count + current_partition_id;

      // Find participants based on flattened id. replica_groups cannot be empty
      // so no need to pass in total_participant_count.
      TF_ASSIGN_OR_RETURN(
          std::vector<int> participating_flattened_ids,
          GetParticipatingIDs(current_flattened_id,
                              /*total_participant_count=*/std::nullopt,
                              replica_groups));

      participants.reserve(participating_flattened_ids.size());
      for (int flattened_id : participating_flattened_ids) {
        // Map from flattened id back to replica_id, partition_id.
        int replica_id = flattened_id / partition_count;
        int partition_id = flattened_id % partition_count;
        participants.emplace_back(device_assignment(replica_id, partition_id));
      }
      return participants;
    }
  }
}
328 
ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,absl::Span<const ReplicaGroup> second)329 bool ReplicaGroupsOrthogonal(absl::Span<const ReplicaGroup> first,
330                              absl::Span<const ReplicaGroup> second) {
331   if (first.size() != second[0].replica_ids_size()) {
332     return false;
333   }
334   if (first[0].replica_ids_size() != second.size()) {
335     return false;
336   }
337   for (int64_t i = 0; i < first.size(); ++i) {
338     for (int64_t j = 0; j < first[i].replica_ids_size(); ++j) {
339       if (first[i].replica_ids(j) != second[j].replica_ids(i)) {
340         return false;
341       }
342     }
343   }
344   return true;
345 }
346 
IsCollective(const HloInstruction * instruction)347 bool IsCollective(const HloInstruction* instruction) {
348   switch (instruction->opcode()) {
349     case HloOpcode::kAllReduce:
350     case HloOpcode::kAllReduceStart:
351     case HloOpcode::kAllReduceDone:
352     case HloOpcode::kAllGather:
353     case HloOpcode::kAllGatherStart:
354     case HloOpcode::kAllGatherDone:
355     case HloOpcode::kAllToAll:
356     case HloOpcode::kCollectivePermute:
357     case HloOpcode::kCollectivePermuteStart:
358     case HloOpcode::kCollectivePermuteDone:
359       return true;
360     case HloOpcode::kFusion:
361       if (instruction->IsCustomFusion()) {
362         for (const auto* inner_inst : instruction->fused_instructions()) {
363           if (IsCollective(inner_inst)) {
364             return true;
365           }
366         }
367       }
368       return false;
369     default:
370       return false;
371   }
372 }
373 
374 }  // end namespace xla
375