/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/reduce_scatter_utils.h"

#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {
namespace {

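// Matches a "table lookup": a dynamic-slice of a rank-1 S32/U32 constant or
// iota, looking through any bitcast/reshape/copy chain on top of it. A sketch
// of the matched pattern (names and values are illustrative only):
//
//   %table = s32[8] constant({0, 8, 16, 24, 32, 40, 48, 56})
//   %pid   = u32[] partition-id()
//   %slice = s32[1] dynamic-slice(%table, %pid), dynamic_slice_sizes={1}
//   %off   = s32[] reshape(%slice)
//
// IsTableLookup(%off) peels the reshape and returns true.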
bool IsTableLookup(const HloInstruction* hlo) {
  while (hlo->opcode() == HloOpcode::kBitcast ||
         hlo->opcode() == HloOpcode::kReshape ||
         hlo->opcode() == HloOpcode::kCopy) {
    hlo = hlo->operand(0);
  }
  return hlo->opcode() == HloOpcode::kDynamicSlice &&
         (hlo->operand(0)->IsConstant() ||
          hlo->operand(0)->opcode() == HloOpcode::kIota) &&
         hlo->operand(0)->shape().rank() == 1 &&
         (hlo->operand(0)->shape().element_type() == S32 ||
          hlo->operand(0)->shape().element_type() == U32);
}

// Function to map a replica/partition/global ID to an offset in the offset
// table, based on the given scalar offset HLO. For example, if the HLO is
// kPartitionId but the all-reduce uses global IDs, then the function maps
// global IDs to partition IDs. It returns -1 if the HLO cannot be understood.
using MapIdToTableOffset =
    std::function<int64_t(const HloInstruction*, int64_t)>;

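// Resolves the scalar `index` HLO to the table position selected by `id`:
// either map_id() understands `index` directly, or `index` is a (possibly
// nested) table lookup whose own index is resolved recursively. As a sketch
// with illustrative values: if `index` is a dynamic-slice of the constant
// table {3, 2, 1, 0} indexed by partition-id, and map_id maps partition-id to
// the partition number, then GetIndexForId(index, /*id=*/1, map_id) looks up
// position 1 and returns 2. Returns -1 if `index` cannot be understood.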
int64_t GetIndexForId(const HloInstruction* index, int64_t id,
                      const MapIdToTableOffset& map_id) {
  // ID itself.
  int64_t maybe_mapped_id = map_id(index, id);
  if (maybe_mapped_id >= 0) {
    return maybe_mapped_id;
  }
  if (!IsTableLookup(index)) {
    VLOG(2) << "Index is not table lookup " << index->ToString();
    return -1;
  }
  while (index->opcode() == HloOpcode::kReshape ||
         index->opcode() == HloOpcode::kBitcast ||
         index->opcode() == HloOpcode::kCopy) {
    index = index->operand(0);
  }
  int64_t inner_index = GetIndexForId(index->operand(1), id, map_id);
  if (inner_index < 0) {
    VLOG(2) << "Failed to get inner index.";
    return -1;
  }
  if (index->operand(0)->opcode() == HloOpcode::kIota) {
    return inner_index;
  }
  // A table lookup.
  const auto& table = index->operand(0)->literal();
  return *table.GetIntegralAsS64({inner_index});
}

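// Checks whether, for each member of each replica group, the per-dimension
// table-lookup `offsets` linearize to shard_size * j, where j is the member's
// position within its group. As a sketch: with slice_group_sizes={2, 4}, the
// entries (a, b) looked up for the j-th member must satisfy
// a * 4 + b == shard_size * j, i.e. the offsets enumerate the shards in
// row-major order over the split dimensions.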
bool IsPerIdOffsets(absl::Span<const HloInstruction*> offsets,
                    int64_t shard_size, const MapIdToTableOffset& map_id,
                    std::vector<int64_t> slice_group_sizes,
                    const HloAllReduceInstruction* ar) {
  if (offsets.size() != slice_group_sizes.size()) {
    return false;
  }
  if (!ar->IsCrossModuleAllReduce() || !ar->use_global_device_ids()) {
    return false;
  }

  int num_groups = ar->replica_groups().size();
  int num_split_dims = slice_group_sizes.size();

  for (int64_t i = 0; i < num_groups; ++i) {
    for (int64_t j = 0; j < Product(slice_group_sizes); ++j) {
      int64_t final_table_entry = 0;
      int64_t id = ar->replica_groups()[i].replica_ids(j);
      int64_t slice_group_size = Product(slice_group_sizes);
      for (int dim = 0; dim < num_split_dims; dim++) {
        auto scalar_offset = offsets[dim];
        while (scalar_offset->opcode() == HloOpcode::kReshape ||
               scalar_offset->opcode() == HloOpcode::kBitcast ||
               scalar_offset->opcode() == HloOpcode::kCopy) {
          scalar_offset = scalar_offset->operand(0);
        }
        if (!IsTableLookup(scalar_offset)) {
          return false;
        }
        int64_t table_index =
            GetIndexForId(scalar_offset->operand(1), id, map_id);
        if (table_index < 0) {
          return false;
        }

        int64_t table_entry;
        if (scalar_offset->operand(0)->opcode() == HloOpcode::kIota) {
          table_entry = table_index;
        } else {
          table_entry = *scalar_offset->operand(0)->literal().GetIntegralAsS64(
              {table_index});
        }
        slice_group_size /= slice_group_sizes[dim];
        final_table_entry += table_entry * slice_group_size;
      }
      if (final_table_entry != shard_size * j) {
        return false;
      }
    }
  }

  return true;
}

// Returns true if `offset` == shard_size * id.
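// As a sketch with illustrative values: offset = multiply(partition-id(),
// constant(8)) with shard_size=8 peels the constant multiplier (8 divides 8)
// and recurses with shard_size=1, where partition-id maps group member j to
// j == 1 * j. Table-lookup offsets are instead validated entry by entry
// against shard_size * j.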
bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size,
                   const MapIdToTableOffset& map_id, int64_t group_size,
                   const HloAllReduceInstruction* ar) {
  const bool iota_group =
      ar->replica_groups().empty() ||
      (ar->IsCrossModuleAllReduce() && !ar->use_global_device_ids());

  if (offset->opcode() == HloOpcode::kMultiply) {
    // Check if it's constant * IsPerIdOffset(..., shard_size / constant, ...)
    if (offset->shape().rank() != 0) {
      VLOG(2) << "Offset is not a scalar " << offset->ToString();
      return false;
    }
    int64_t const_operand = -1;
    if (offset->operand(0)->IsConstant()) {
      const_operand = 0;
    } else if (offset->operand(1)->IsConstant()) {
      const_operand = 1;
    } else {
      VLOG(2) << "Offset is not multiple(const, ...) " << offset->ToString();
      return false;
    }
    auto multiplier =
        offset->operand(const_operand)->literal().GetIntegralAsS64({});
    if (!multiplier || shard_size % *multiplier != 0) {
      VLOG(2) << "Multiplier is unknown or cannot evenly divide shard size "
              << offset->operand(const_operand);
      return false;
    }
    return IsPerIdOffset(offset->operand(1 - const_operand),
                         shard_size / *multiplier, map_id, group_size, ar);
  }
  if (shard_size == 1 && iota_group) {
    bool id_mapping_is_identity = true;
    for (int64_t id = 0; id < group_size; ++id) {
      int64_t mapped_id = map_id(offset, id);
      if (mapped_id != id) {
        id_mapping_is_identity = false;
        break;
      }
    }
    if (id_mapping_is_identity) {
      return true;
    }
  }
  if (offset->opcode() == HloOpcode::kBitcast ||
      offset->opcode() == HloOpcode::kReshape ||
      offset->opcode() == HloOpcode::kCopy) {
    return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size,
                         ar);
  }

  if (offset->opcode() == HloOpcode::kConvert &&
      offset->operand(0)->shape().IsInteger() &&
      primitive_util::BitWidth(offset->operand(0)->shape().element_type()) <=
          primitive_util::BitWidth(offset->shape().element_type())) {
    return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size,
                         ar);
  }

  if (offset->opcode() == HloOpcode::kClamp) {
    auto lower_bound = offset->operand(0)->literal().GetIntegralAsS64({});
    auto upper_bound = offset->operand(2)->literal().GetIntegralAsS64({});
    if (!lower_bound || !upper_bound || *lower_bound != 0 ||
        *upper_bound < (group_size - 1) * shard_size) {
      VLOG(2) << "Boundaries of the clamp are not legal: "
              << offset->ToString();
      return false;
    }
    return IsPerIdOffset(offset->operand(1), shard_size, map_id, group_size,
                         ar);
  }

  const int64_t num_groups = iota_group ? 1 : ar->replica_groups().size();
  if (IsTableLookup(offset)) {
    // Check the values of the offset table, and see if they are shard_index *
    // shard_size.
    for (int64_t i = 0; i < num_groups; ++i) {
      for (int64_t j = 0; j < group_size; ++j) {
        int64_t id = iota_group ? j : ar->replica_groups()[i].replica_ids(j);
        int64_t table_index = GetIndexForId(offset->operand(1), id, map_id);
        if (table_index < 0) {
          VLOG(2) << "Failed to infer table index from "
                  << offset->operand(1)->ToString();
          return false;
        }

        int64_t table_entry;
        if (offset->operand(0)->opcode() == HloOpcode::kIota) {
          table_entry = table_index;
        } else {
          table_entry =
              *offset->operand(0)->literal().GetIntegralAsS64({table_index});
        }
        if (table_entry != shard_size * j) {
          VLOG(2) << "Unexpected offset from table.";
          return false;
        }
      }
    }

    // All table entries are good.
    return true;
  }

  // Check if the offset is the id itself and it has the right values.
  for (int64_t i = 0; i < num_groups; ++i) {
    for (int64_t j = 0; j < group_size; ++j) {
      int64_t id = iota_group ? j : ar->replica_groups()[i].replica_ids(j);
      int64_t mapped_id = map_id(offset, id);
      if (mapped_id != shard_size * j) {
        VLOG(2) << "Mapping of " << id << " to " << mapped_id
                << " not matching expected value " << shard_size * j << ": "
                << offset->ToString();
        return false;
      }
    }
  }

  return true;
}

}  // namespace

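// A sketch of an HLO pattern that matches as a reduce-scatter (shapes and
// values are illustrative only):
//
//   %ar   = f32[32,8] all-reduce(%p), replica_groups={{0,1,2,3}},
//           use_global_device_ids=true, to_apply=%sum
//   %pid  = u32[] partition-id()
//   %off  = u32[] multiply(%pid, u32[] constant(8))
//   %zero = u32[] constant(0)
//   %ds   = f32[8,8] dynamic-slice(%ar, %off, %zero),
//           dynamic_slice_sizes={8,8}
//
// Here split_dim=0 and group_size=4: each participant keeps the 8-row shard
// at offset shard_size * id, so the pair behaves like a reduce-scatter.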
std::optional<ReduceScatterSpec> MatchReduceScatter(
    const HloAllReduceInstruction* ar, int64_t num_partitions,
    int64_t num_replicas, bool allow_multiple_split_dims,
    bool allow_intervening_reshape, int64_t min_rank) {
  HloPredicate match_partition_id = [](const HloInstruction* i) {
    return i->opcode() == HloOpcode::kPartitionId;
  };
  HloPredicate match_replica_id = [](const HloInstruction* i) {
    return i->opcode() == HloOpcode::kReplicaId;
  };
  return MatchReduceScatter(ar, num_partitions, num_replicas,
                            allow_multiple_split_dims,
                            allow_intervening_reshape, min_rank,
                            match_partition_id, match_replica_id);
}

std::optional<ReduceScatterSpec> MatchReduceScatter(
    const HloAllReduceInstruction* ar, int64_t num_partitions,
    int64_t num_replicas, bool allow_multiple_split_dims,
    bool allow_intervening_reshape, int64_t min_rank,
    HloPredicate match_partition_id, HloPredicate match_replica_id) {
  if (!ar->shape().IsArray() || ar->constrain_layout() ||
      (ar->IsCrossModuleAllReduce() &&
       !ar->GetModule()->config().use_spmd_partitioning())) {
    VLOG(2) << "Unsupported all-reduce: " << ar->ToString();
    return std::nullopt;
  }
  if (ar->shape().rank() - absl::c_count(ar->shape().dimensions(), 1) <
      min_rank) {
    VLOG(2) << " Should be at least rank-" << min_rank
            << " excluding trivial dimensions " << ar->ToString();
    return std::nullopt;
  }
  if (ar->user_count() != 1) {
    VLOG(2) << "All-reduce user_count != 1 " << ar->ToString();
    return std::nullopt;
  }
  if (ar->replica_groups().size() > 1) {
    const int64_t size = ar->replica_groups()[0].replica_ids_size();
    absl::Span<const ReplicaGroup> rgs = ar->replica_groups();
    const bool has_uniform_size =
        absl::c_all_of(rgs.subspan(1), [size](const ReplicaGroup& group) {
          return group.replica_ids_size() == size;
        });
    if (!has_uniform_size) {
      VLOG(2) << "Unsupported non-uniform replica group size "
              << ar->ToString();
      return std::nullopt;
    }
  }

  HloInstruction* user = ar->users()[0];
  HloInstruction* reshape = nullptr;
  if (allow_intervening_reshape && user->opcode() == HloOpcode::kReshape) {
    // Allow the intervening reshape if it reshapes just the non-scattered
    // dimensions (checked later).
    reshape = user;
    if (reshape->user_count() != 1) {
      VLOG(2) << "Reshape following all-reduce has user count > 1 "
              << reshape->ToString();
      return std::nullopt;
    }
    user = reshape->users().front();
  }
  if (user->opcode() != HloOpcode::kDynamicSlice) {
    VLOG(2) << "All-reduce user is not dynamic slice " << user->ToString();
    return std::nullopt;
  }
  ReduceScatterSpec spec;
  int64_t group_size;
  MapIdToTableOffset map_id;
  spec.dynamic_slice = user;
  if (!ar->IsCrossModuleAllReduce()) {
    spec.sharded_replicas = num_replicas;
    group_size = ar->replica_groups().empty()
                     ? num_replicas
                     : ar->replica_groups()[0].replica_ids_size();
    map_id = [&](const HloInstruction* hlo, int64_t id) {
      return match_replica_id(hlo) ? id : -1;
    };
  } else if (ar->use_global_device_ids()) {
    spec.sharded_replicas = num_replicas;
    spec.sharded_partitions = num_partitions;
    group_size = ar->replica_groups()[0].replica_ids_size();
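    // Determine whether global IDs decompose as
    // replica_id * num_partitions + partition_id, with each group spanning a
    // single replica and each partition at a consistent position. As a
    // sketch: with num_partitions=4 and group {4,5,6,7}, every member has
    // replica 1 and partition id % 4, so partition IDs can be recovered from
    // global IDs by id % num_partitions below.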
    bool orthogonal_replicas = true;
    std::vector<int64_t> partition_id_to_index(num_partitions, -1);
    for (int64_t g = 0; g < ar->replica_groups().size(); ++g) {
      const auto& group = ar->replica_groups()[g];
      for (int64_t i = 0; i < group.replica_ids_size(); ++i) {
        int64_t global_id = group.replica_ids(i);
        int64_t partition_id = global_id % num_partitions;
        if (partition_id_to_index[partition_id] == -1) {
          partition_id_to_index[partition_id] = i;
          continue;
        }
        if (partition_id_to_index[partition_id] != i ||
            global_id / num_partitions !=
                group.replica_ids(0) / num_partitions) {
          orthogonal_replicas = false;
          break;
        }
      }
    }
    map_id = [&, orthogonal_replicas](const HloInstruction* hlo, int64_t id) {
      if (match_replica_id(hlo)) {
        return num_partitions == 1 ? id : -1;
      }
      if (match_partition_id(hlo)) {
        if (num_replicas == 1) {
          return id;
        }
        return orthogonal_replicas ? id % num_partitions : -1;
      }
      auto is_replica_mul_num_partitions = [&](const HloInstruction* operand) {
        return operand->opcode() == HloOpcode::kMultiply &&
               ((operand->operand(0)->opcode() == HloOpcode::kReplicaId &&
                 operand->operand(1)->IsConstant() &&
                 operand->operand(1)->literal().GetIntegralAsS64({}) ==
                     num_partitions) ||
                (operand->operand(1)->opcode() == HloOpcode::kReplicaId &&
                 operand->operand(0)->IsConstant() &&
                 operand->operand(0)->literal().GetIntegralAsS64({}) ==
                     num_partitions));
      };
      if (hlo->opcode() == HloOpcode::kAdd &&
          ((match_partition_id(hlo->operand(0)) &&
            is_replica_mul_num_partitions(hlo->operand(1))) ||
           (match_partition_id(hlo->operand(1)) &&
            is_replica_mul_num_partitions(hlo->operand(0))))) {
        return id;
      }
      return int64_t{-1};
    };
  } else {
    // Right now all cross-partition all-reduces' subgroups refer to replicas
    // unless they use use_global_device_ids.
    if (ar->replica_groups().size() != num_replicas ||
        ar->replica_groups()[0].replica_ids_size() != 1) {
      VLOG(2) << "Unsupported size > 1 replica groups for cross-partition, "
                 "non-global ID "
              << ar->ToString();
      return std::nullopt;
    }
    spec.sharded_partitions = num_partitions;
    group_size = num_partitions;
    map_id = [&](const HloInstruction* hlo, int64_t id) {
      return match_partition_id(hlo) ? id : -1;
    };
  }
  if (group_size < 2) {
    VLOG(2) << "Group_size < 2, nothing to do " << ar->ToString();
    return std::nullopt;
  }
  spec.group_size = group_size;
  spec.split_dim = -1;
  std::vector<int64_t> split_dims;
  // First find a single dimension where the input and output of dynamic slice
  // differ.
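  // As a sketch with illustrative shapes: for an all-reduce of shape
  // f32[32,16] and a dynamic-slice with dynamic_slice_sizes={8,16}, only
  // dim 0 differs, so split_dim=0 and each of the 32/8=4 shards spans 8 rows.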
  int num_dims = 0;
  for (int64_t dim = 0; dim < ar->shape().rank(); ++dim) {
    if (ar->shape().dimensions(dim) == user->shape().dimensions(dim)) {
      continue;
    }
    num_dims++;
    VLOG(2) << "select dim: " << dim;
    spec.split_dim = dim;
  }
  if (spec.split_dim != -1) {
    if (num_dims == 1) {
      split_dims.push_back(spec.split_dim);
    } else {
      // Recompute split dim.
      spec.split_dim = -1;
    }
  }
  const Shape& shape = user->operand(0)->shape();
  if (spec.split_dim == -1) {
    for (int64_t dim = 0; dim < shape.rank(); ++dim) {
      auto offset = user->operand(dim + 1);
      // Skip trivial (1) dimensions or if the index is a constant 0.
      if (shape.dimensions(dim) == 1 ||
          (offset->opcode() == HloOpcode::kConstant &&
           offset->literal().IsZero({}))) {
        continue;
      }
      split_dims.push_back(dim);
      if (spec.split_dim != -1) {
        if (!allow_multiple_split_dims || spec.split_dim != (dim - 1)) {
          VLOG(2) << "Only support split on consecutive dims "
                  << user->ToString();
          return std::nullopt;
        }
        continue;
      }
      spec.split_dim = dim;
    }
  }

  std::vector<int64_t> group_sizes;
  group_sizes.reserve(split_dims.size());
  for (auto dim : split_dims) {
    group_sizes.push_back(user->operand(0)->shape().dimensions(dim) /
                          user->dynamic_slice_sizes()[dim]);
  }
  if (Product(group_sizes) != group_size) {
    VLOG(2) << "Group size mismatch " << user->ToString() << " vs "
            << ar->ToString();
    return std::nullopt;
  }

  if (split_dims.size() > 1) {
    std::vector<const HloInstruction*> offsets;
    int shard_size = 1;
    for (auto dim : split_dims) {
      offsets.push_back(user->operand(dim + 1));
      shard_size *= user->dynamic_slice_sizes()[dim];
    }

    if (!IsPerIdOffsets(absl::MakeSpan(offsets), shard_size, map_id,
                        group_sizes, ar)) {
      VLOG(2) << "IsPerIdOffsets() failed " << ar->ToString();
      return std::nullopt;
    }
  } else {
    if (!IsPerIdOffset(user->operand(spec.split_dim + 1),
                       user->dynamic_slice_sizes()[spec.split_dim], map_id,
                       group_size, ar)) {
      VLOG(2) << "IsPerIdOffset() failed " << ar->ToString();
      return std::nullopt;
    }
  }

  // If there was a reshape, allow it only if the split dims are left
  // unmodified by the reshape. Also rewrite the split dims so that they are
  // in terms of the shape of the all-reduce as opposed to that of the reshape.
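  // As a sketch: a reshape f32[8,32] -> f32[8,4,8] with split dim 0 leaves
  // dim 0 unmodified, so split dim 0 of the reshape output maps back to dim 0
  // of the all-reduce result; a split on the reshaped dims would bail out.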
  if (reshape) {
    std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
        ShapeUtil::DimensionsUnmodifiedByReshape(reshape->operand(0)->shape(),
                                                 reshape->shape());
    // Map each unmodified output dim of the reshape to the corresponding
    // input dim.
    absl::flat_hash_map<int64_t, int64_t> unmodified_output_to_input_map;
    for (const std::pair<int64_t, int64_t>& io_pair : unmodified_dims) {
      unmodified_output_to_input_map.insert({io_pair.second, io_pair.first});
    }

    bool all_split_dims_unmodified =
        absl::c_all_of(split_dims, [&](int64_t out_dim) {
          return unmodified_output_to_input_map.count(out_dim) != 0;
        });
    if (!all_split_dims_unmodified) {
      VLOG(2) << "Split dimensions are modified by reshape";
      return std::nullopt;
    }

    // Rewrite the split dim and original_split_dims to be in terms of the
    // shape of the all-reduce.
    spec.split_dim = unmodified_output_to_input_map.at(spec.split_dim);
    for (int64_t& split_dim : split_dims) {
      split_dim = unmodified_output_to_input_map.at(split_dim);
    }
  }

  spec.original_split_dims = split_dims;
  return spec;
}

}  // namespace xla