/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/reduce_scatter_utils.h"

#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {
namespace {

// Returns true if `hlo` (looking through bitcasts/reshapes/copies) is a
// dynamic-slice of a rank-1 S32/U32 constant or iota, i.e. a lookup into an
// offset table.
bool IsTableLookup(const HloInstruction* hlo) {
  while (hlo->opcode() == HloOpcode::kBitcast ||
         hlo->opcode() == HloOpcode::kReshape ||
         hlo->opcode() == HloOpcode::kCopy) {
    hlo = hlo->operand(0);
  }
  return hlo->opcode() == HloOpcode::kDynamicSlice &&
         (hlo->operand(0)->IsConstant() ||
          hlo->operand(0)->opcode() == HloOpcode::kIota) &&
         hlo->operand(0)->shape().rank() == 1 &&
         (hlo->operand(0)->shape().element_type() == S32 ||
          hlo->operand(0)->shape().element_type() == U32);
}

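// For example, IsTableLookup() accepts an offset pattern like the following
// (schematic HLO, an illustrative sketch rather than code from a real
// module), possibly wrapped in reshapes/bitcasts/copies:
//
//   table  = s32[8] constant({0, 4, 8, 12, 16, 20, 24, 28})
//   index  = u32[] partition-id()
//   offset = s32[1] dynamic-slice(table, index), dynamic_slice_sizes={1}
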
// Function to map a replica/partition/global ID to an offset in the offset
// table, based on the given scalar offset HLO. For example, if the HLO is
// kPartitionId but the all-reduce uses global IDs, then the function maps
// global IDs to partition IDs. It returns -1 if the HLO cannot be understood.
using MapIdToTableOffset =
    std::function<int64_t(const HloInstruction*, int64_t)>;

// Resolves the scalar `index` HLO to a concrete table offset for the given
// `id`, recursing through nested table lookups. Returns -1 if `index` cannot
// be understood.
int64_t GetIndexForId(const HloInstruction* index, int64_t id,
                      const MapIdToTableOffset& map_id) {
  // ID itself.
  int64_t maybe_mapped_id = map_id(index, id);
  if (maybe_mapped_id >= 0) {
    return maybe_mapped_id;
  }
  if (!IsTableLookup(index)) {
    VLOG(2) << "Index is not a table lookup " << index->ToString();
    return -1;
  }
  while (index->opcode() == HloOpcode::kReshape ||
         index->opcode() == HloOpcode::kBitcast ||
         index->opcode() == HloOpcode::kCopy) {
    index = index->operand(0);
  }
  int64_t inner_index = GetIndexForId(index->operand(1), id, map_id);
  if (inner_index < 0) {
    VLOG(2) << "Failed to get inner index.";
    return -1;
  }
  if (index->operand(0)->opcode() == HloOpcode::kIota) {
    return inner_index;
  }
  // A table lookup.
  const auto& table = index->operand(0)->literal();
  return *table.GetIntegralAsS64({inner_index});
}

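// GetIndexForId() sees through nested lookups. Schematically (an
// illustrative sketch, not taken from a real module), with id = 3 and a
// map_id that resolves the innermost index to the id itself:
//
//   inner = dynamic-slice(iota, id)              -> 3
//   index = dynamic-slice({5, 6, 7, 8}, inner)   -> 8
//
// operand(1) is resolved recursively first, then used to index the table.
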
// Returns true if, for every id in the all-reduce's replica groups, the
// multi-dimensional `offsets` linearize to `shard_size` times the id's
// position within its group. Only supports cross-module all-reduces that use
// global device IDs.
bool IsPerIdOffsets(absl::Span<const HloInstruction*> offsets,
                    int64_t shard_size, const MapIdToTableOffset& map_id,
                    std::vector<int64_t> slice_group_sizes,
                    const HloAllReduceInstruction* ar) {
  if (offsets.size() != slice_group_sizes.size()) {
    return false;
  }
  if (!ar->IsCrossModuleAllReduce() || !ar->use_global_device_ids()) {
    return false;
  }

  int num_groups = ar->replica_groups().size();
  int num_split_dims = slice_group_sizes.size();

  for (int64_t i = 0; i < num_groups; ++i) {
    for (int64_t j = 0; j < Product(slice_group_sizes); ++j) {
      int64_t final_table_entry = 0;
      int64_t id = ar->replica_groups()[i].replica_ids(j);
      int64_t slice_group_size = Product(slice_group_sizes);
      for (int dim = 0; dim < num_split_dims; dim++) {
        auto scalar_offset = offsets[dim];
        while (scalar_offset->opcode() == HloOpcode::kReshape ||
               scalar_offset->opcode() == HloOpcode::kBitcast ||
               scalar_offset->opcode() == HloOpcode::kCopy) {
          scalar_offset = scalar_offset->operand(0);
        }
        if (!IsTableLookup(scalar_offset)) {
          return false;
        }
        int64_t table_index =
            GetIndexForId(scalar_offset->operand(1), id, map_id);
        if (table_index < 0) {
          return false;
        }

        int64_t table_entry;
        if (scalar_offset->operand(0)->opcode() == HloOpcode::kIota) {
          table_entry = table_index;
        } else {
          table_entry = *scalar_offset->operand(0)->literal().GetIntegralAsS64(
              {table_index});
        }
        slice_group_size /= slice_group_sizes[dim];
        final_table_entry += table_entry * slice_group_size;
      }
      if (final_table_entry != shard_size * j) {
        return false;
      }
    }
  }

  return true;
}

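// Illustrative instance of the check above (hypothetical numbers): with
// slice_group_sizes = {2, 4} and shard_size = 16, the per-dimension table
// entries (t0, t1) looked up for the id at group position j must satisfy
//
//   t0 * 4 + t1 * 1 == 16 * j
//
// i.e. the row-major linearization of the per-dimension offsets must equal
// the contiguous shard offset of position j.
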
// Returns true if `offset` == shard_size * id.
bool IsPerIdOffset(const HloInstruction* offset, int64_t shard_size,
                   const MapIdToTableOffset& map_id, int64_t group_size,
                   const HloAllReduceInstruction* ar) {
  const bool iota_group =
      ar->replica_groups().empty() ||
      (ar->IsCrossModuleAllReduce() && !ar->use_global_device_ids());

  if (offset->opcode() == HloOpcode::kMultiply) {
    // Check if it's constant * IsPerIdOffset(..., shard_size / constant, ...)
    if (offset->shape().rank() != 0) {
      VLOG(2) << "Offset is not a scalar " << offset->ToString();
      return false;
    }
    int64_t const_operand = -1;
    if (offset->operand(0)->IsConstant()) {
      const_operand = 0;
    } else if (offset->operand(1)->IsConstant()) {
      const_operand = 1;
    } else {
      VLOG(2) << "Offset is not multiply(const, ...) " << offset->ToString();
      return false;
    }
    auto multiplier =
        offset->operand(const_operand)->literal().GetIntegralAsS64({});
    if (!multiplier || shard_size % *multiplier != 0) {
      VLOG(2) << "Multiplier is unknown or does not evenly divide shard size "
              << offset->operand(const_operand);
      return false;
    }
    return IsPerIdOffset(offset->operand(1 - const_operand),
                         shard_size / *multiplier, map_id, group_size, ar);
  }
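  // For instance (hypothetical values): offset = multiply(constant(8),
  // partition-id()) with shard_size = 8 recurses on partition-id() with
  // shard_size = 8 / 8 = 1, which then succeeds via the identity-mapping
  // fast path below when the group is an iota over the participating ids.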
  if (shard_size == 1 && iota_group) {
    bool id_mapping_is_identity = true;
    for (int64_t id = 0; id < group_size; ++id) {
      int64_t mapped_id = map_id(offset, id);
      if (mapped_id != id) {
        id_mapping_is_identity = false;
        break;
      }
    }
    if (id_mapping_is_identity) {
      return true;
    }
  }
  if (offset->opcode() == HloOpcode::kBitcast ||
      offset->opcode() == HloOpcode::kReshape ||
      offset->opcode() == HloOpcode::kCopy) {
    return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size,
                         ar);
  }

  if (offset->opcode() == HloOpcode::kConvert &&
      offset->operand(0)->shape().IsInteger() &&
      primitive_util::BitWidth(offset->operand(0)->shape().element_type()) <=
          primitive_util::BitWidth(offset->shape().element_type())) {
    return IsPerIdOffset(offset->operand(0), shard_size, map_id, group_size,
                         ar);
  }

  if (offset->opcode() == HloOpcode::kClamp) {
    auto lower_bound = offset->operand(0)->literal().GetIntegralAsS64({});
    auto upper_bound = offset->operand(2)->literal().GetIntegralAsS64({});
    if (!lower_bound || !upper_bound || *lower_bound != 0 ||
        *upper_bound < (group_size - 1) * shard_size) {
      VLOG(2) << "Bounds of the clamp are not legal: " << offset->ToString();
      return false;
    }
    return IsPerIdOffset(offset->operand(1), shard_size, map_id, group_size,
                         ar);
  }

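  // E.g. (hypothetical values): clamp(constant(0), per_id_offset,
  // constant((group_size - 1) * shard_size)) is accepted above, since such a
  // clamp never changes an already-valid per-id offset.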
  const int64_t num_groups = iota_group ? 1 : ar->replica_groups().size();
  if (IsTableLookup(offset)) {
    // Check the values of the offset table, and see if they are shard_index *
    // shard_size.
    for (int64_t i = 0; i < num_groups; ++i) {
      for (int64_t j = 0; j < group_size; ++j) {
        int64_t id = iota_group ? j : ar->replica_groups()[i].replica_ids(j);
        int64_t table_index = GetIndexForId(offset->operand(1), id, map_id);
        if (table_index < 0) {
          VLOG(2) << "Failed to infer table index from "
                  << offset->operand(1)->ToString();
          return false;
        }

        int64_t table_entry;
        if (offset->operand(0)->opcode() == HloOpcode::kIota) {
          table_entry = table_index;
        } else {
          table_entry =
              *offset->operand(0)->literal().GetIntegralAsS64({table_index});
        }
        if (table_entry != shard_size * j) {
          VLOG(2) << "Unexpected offset from table.";
          return false;
        }
      }
    }

    // All table entries are good.
    return true;
  }

  // Check if the offset is the id itself and it has the right values.
  for (int64_t i = 0; i < num_groups; ++i) {
    for (int64_t j = 0; j < group_size; ++j) {
      int64_t id = iota_group ? j : ar->replica_groups()[i].replica_ids(j);
      int64_t mapped_id = map_id(offset, id);
      if (mapped_id != shard_size * j) {
        VLOG(2) << "Mapping of " << id << " to " << mapped_id
                << " does not match expected value " << shard_size * j << ": "
                << offset->ToString();
        return false;
      }
    }
  }

  return true;
}

}  // namespace

std::optional<ReduceScatterSpec> MatchReduceScatter(
    const HloAllReduceInstruction* ar, int64_t num_partitions,
    int64_t num_replicas, bool allow_multiple_split_dims,
    bool allow_intervening_reshape, int64_t min_rank) {
  HloPredicate match_partition_id = [](const HloInstruction* i) {
    return i->opcode() == HloOpcode::kPartitionId;
  };
  HloPredicate match_replica_id = [](const HloInstruction* i) {
    return i->opcode() == HloOpcode::kReplicaId;
  };
  return MatchReduceScatter(ar, num_partitions, num_replicas,
                            allow_multiple_split_dims,
                            allow_intervening_reshape, min_rank,
                            match_partition_id, match_replica_id);
}

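// A minimal usage sketch (hypothetical caller, not part of this file): a
// rewrite pass would typically call
//
//   auto spec = MatchReduceScatter(ar, config.num_partitions(),
//                                  config.replica_count(),
//                                  /*allow_multiple_split_dims=*/false,
//                                  /*allow_intervening_reshape=*/true,
//                                  /*min_rank=*/1);
//   if (spec) { /* replace all-reduce + dynamic-slice with reduce-scatter */ }
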
std::optional<ReduceScatterSpec> MatchReduceScatter(
    const HloAllReduceInstruction* ar, int64_t num_partitions,
    int64_t num_replicas, bool allow_multiple_split_dims,
    bool allow_intervening_reshape, int64_t min_rank,
    HloPredicate match_partition_id, HloPredicate match_replica_id) {
  if (!ar->shape().IsArray() || ar->constrain_layout() ||
      (ar->IsCrossModuleAllReduce() &&
       !ar->GetModule()->config().use_spmd_partitioning())) {
    VLOG(2) << "Unsupported all-reduce: " << ar->ToString();
    return std::nullopt;
  }
  if (ar->shape().rank() - absl::c_count(ar->shape().dimensions(), 1) <
      min_rank) {
    VLOG(2) << "Should be at least rank-" << min_rank
            << " excluding trivial dimensions " << ar->ToString();
    return std::nullopt;
  }
  if (ar->user_count() != 1) {
    VLOG(2) << "All-reduce user_count != 1 " << ar->ToString();
    return std::nullopt;
  }
  if (ar->replica_groups().size() > 1) {
    const int64_t size = ar->replica_groups()[0].replica_ids_size();
    absl::Span<const ReplicaGroup> rgs = ar->replica_groups();
    const bool has_uniform_size =
        absl::c_all_of(rgs.subspan(1), [size](const ReplicaGroup& group) {
          return group.replica_ids_size() == size;
        });
    if (!has_uniform_size) {
      VLOG(2) << "Unsupported non-uniform replica group size "
              << ar->ToString();
      return std::nullopt;
    }
  }

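  // E.g. replica_groups={{0,1},{2,3}} passes the uniformity check above,
  // while replica_groups={{0,1},{2}} is rejected.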
  HloInstruction* user = ar->users()[0];
  HloInstruction* reshape = nullptr;
  if (allow_intervening_reshape && user->opcode() == HloOpcode::kReshape) {
    // Allow the intervening reshape if it reshapes just the non-scattered
    // dimensions (checked later).
    reshape = user;
    if (reshape->user_count() != 1) {
      VLOG(2) << "Reshape following all-reduce has user count > 1 "
              << reshape->ToString();
      return std::nullopt;
    }
    user = reshape->users().front();
  }
  if (user->opcode() != HloOpcode::kDynamicSlice) {
    VLOG(2) << "All-reduce user is not dynamic-slice " << user->ToString();
    return std::nullopt;
  }
  ReduceScatterSpec spec;
  int64_t group_size;
  MapIdToTableOffset map_id;
  spec.dynamic_slice = user;
  if (!ar->IsCrossModuleAllReduce()) {
    spec.sharded_replicas = num_replicas;
    group_size = ar->replica_groups().empty()
                     ? num_replicas
                     : ar->replica_groups()[0].replica_ids_size();
    map_id = [&](const HloInstruction* hlo, int64_t id) {
      return match_replica_id(hlo) ? id : -1;
    };
  } else if (ar->use_global_device_ids()) {
    spec.sharded_replicas = num_replicas;
    spec.sharded_partitions = num_partitions;
    group_size = ar->replica_groups()[0].replica_ids_size();
    bool orthogonal_replicas = true;
    std::vector<int64_t> partition_id_to_index(num_partitions, -1);
    for (int64_t g = 0; g < ar->replica_groups().size(); ++g) {
      const auto& group = ar->replica_groups()[g];
      for (int64_t i = 0; i < group.replica_ids_size(); ++i) {
        int64_t global_id = group.replica_ids(i);
        int64_t partition_id = global_id % num_partitions;
        if (partition_id_to_index[partition_id] == -1) {
          partition_id_to_index[partition_id] = i;
          continue;
        }
        if (partition_id_to_index[partition_id] != i ||
            global_id / num_partitions !=
                group.replica_ids(0) / num_partitions) {
          orthogonal_replicas = false;
          break;
        }
      }
    }
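    // With use_global_device_ids, the global ID of a device is
    //   global_id = replica_id * num_partitions + partition_id,
    // so e.g. with 2 replicas and 4 partitions, partition 3 of replica 1 has
    // global ID 1 * 4 + 3 = 7. The loop above verifies that every partition
    // appears at the same position in each group that contains it
    // ("orthogonal" replicas), which lets id % num_partitions stand in for
    // the position within the group.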
    map_id = [&, orthogonal_replicas](const HloInstruction* hlo, int64_t id) {
      if (match_replica_id(hlo)) {
        return num_partitions == 1 ? id : -1;
      }
      if (match_partition_id(hlo)) {
        if (num_replicas == 1) {
          return id;
        }
        return orthogonal_replicas ? id % num_partitions : -1;
      }
      auto is_replica_mul_num_partitions = [&](const HloInstruction* operand) {
        return operand->opcode() == HloOpcode::kMultiply &&
               ((operand->operand(0)->opcode() == HloOpcode::kReplicaId &&
                 operand->operand(1)->IsConstant() &&
                 operand->operand(1)->literal().GetIntegralAsS64({}) ==
                     num_partitions) ||
                (operand->operand(1)->opcode() == HloOpcode::kReplicaId &&
                 operand->operand(0)->IsConstant() &&
                 operand->operand(0)->literal().GetIntegralAsS64({}) ==
                     num_partitions));
      };
      if (hlo->opcode() == HloOpcode::kAdd &&
          ((match_partition_id(hlo->operand(0)) &&
            is_replica_mul_num_partitions(hlo->operand(1))) ||
           (match_partition_id(hlo->operand(1)) &&
            is_replica_mul_num_partitions(hlo->operand(0))))) {
        return id;
      }
      return int64_t{-1};
    };
  } else {
    // Right now all cross-partition all-reduces' subgroups refer to replicas,
    // unless use_global_device_ids is set.
    if (ar->replica_groups().size() != num_replicas ||
        ar->replica_groups()[0].replica_ids_size() != 1) {
      VLOG(2) << "Unsupported size > 1 replica groups for cross-partition, "
                 "non-global ID "
              << ar->ToString();
      return std::nullopt;
    }
    spec.sharded_partitions = num_partitions;
    group_size = num_partitions;
    map_id = [&](const HloInstruction* hlo, int64_t id) {
      return match_partition_id(hlo) ? id : -1;
    };
  }
  if (group_size < 2) {
    VLOG(2) << "Group size < 2, nothing to do " << ar->ToString();
    return std::nullopt;
  }
  spec.group_size = group_size;
  spec.split_dim = -1;
  std::vector<int64_t> split_dims;
  // First find a single dimension where the input and output of the dynamic
  // slice differ.
  int num_dims = 0;
  for (int64_t dim = 0; dim < ar->shape().rank(); ++dim) {
    if (ar->shape().dimensions(dim) == user->shape().dimensions(dim)) {
      continue;
    }
    num_dims++;
    VLOG(2) << "select dim: " << dim;
    spec.split_dim = dim;
  }
  if (spec.split_dim != -1) {
    if (num_dims == 1) {
      split_dims.push_back(spec.split_dim);
    } else {
      // Recompute the split dim.
      spec.split_dim = -1;
    }
  }
  const Shape& shape = user->operand(0)->shape();
  if (spec.split_dim == -1) {
    for (int64_t dim = 0; dim < shape.rank(); ++dim) {
      auto offset = user->operand(dim + 1);
      // Skip trivial (size 1) dimensions or if the index is a constant 0.
      if (shape.dimensions(dim) == 1 ||
          (offset->opcode() == HloOpcode::kConstant &&
           offset->literal().IsZero({}))) {
        continue;
      }
      split_dims.push_back(dim);
      if (spec.split_dim != -1) {
        if (!allow_multiple_split_dims || spec.split_dim != (dim - 1)) {
          VLOG(2) << "Only support split on consecutive dims "
                  << user->ToString();
          return std::nullopt;
        }
        continue;
      }
      spec.split_dim = dim;
    }
  }

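  // Example (hypothetical shapes): an all-reduce result of f32[8,128]
  // dynamic-sliced to f32[1,128] gives the single split dim 0. With
  // allow_multiple_split_dims, f32[4,2,128] sliced to f32[1,1,128] yields
  // the consecutive split dims {0, 1}.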
  std::vector<int64_t> group_sizes;
  group_sizes.reserve(split_dims.size());
  for (auto dim : split_dims) {
    group_sizes.push_back(user->operand(0)->shape().dimensions(dim) /
                          user->dynamic_slice_sizes()[dim]);
  }
  if (Product(group_sizes) != group_size) {
    VLOG(2) << "Group size mismatch " << user->ToString() << " vs "
            << ar->ToString();
    return std::nullopt;
  }

  if (split_dims.size() > 1) {
    std::vector<const HloInstruction*> offsets;
    int64_t shard_size = 1;
    for (auto dim : split_dims) {
      offsets.push_back(user->operand(dim + 1));
      shard_size *= user->dynamic_slice_sizes()[dim];
    }

    if (!IsPerIdOffsets(absl::MakeSpan(offsets), shard_size, map_id,
                        group_sizes, ar)) {
      VLOG(2) << "IsPerIdOffsets() failed " << ar->ToString();
      return std::nullopt;
    }
  } else {
    if (!IsPerIdOffset(user->operand(spec.split_dim + 1),
                       user->dynamic_slice_sizes()[spec.split_dim], map_id,
                       group_size, ar)) {
      VLOG(2) << "IsPerIdOffset() failed " << ar->ToString();
      return std::nullopt;
    }
  }

  // If there was a reshape, allow it only if the split dims are left
  // unmodified by the reshape. Also rewrite the split dims so that they are
  // in terms of the shape of the all-reduce as opposed to that of the
  // reshape.
  if (reshape) {
    std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
        ShapeUtil::DimensionsUnmodifiedByReshape(reshape->operand(0)->shape(),
                                                 reshape->shape());
    // Map each unmodified output dim of the reshape to the corresponding
    // input dim.
    absl::flat_hash_map<int64_t, int64_t> unmodified_output_to_input_map;
    for (const std::pair<int64_t, int64_t>& io_pair : unmodified_dims) {
      unmodified_output_to_input_map.insert({io_pair.second, io_pair.first});
    }

    bool all_split_dims_unmodified =
        absl::c_all_of(split_dims, [&](int64_t out_dim) {
          return unmodified_output_to_input_map.count(out_dim) != 0;
        });
    if (!all_split_dims_unmodified) {
      VLOG(2) << "Split dimensions are modified by the reshape";
      return std::nullopt;
    }

    // Rewrite the split dim and original_split_dims to be in terms of the
    // shape of the all-reduce.
    spec.split_dim = unmodified_output_to_input_map.at(spec.split_dim);
    for (int64_t& split_dim : split_dims) {
      split_dim = unmodified_output_to_input_map.at(split_dim);
    }
  }

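  // Example (hypothetical shapes): for an all-reduce of shape f32[8,32]
  // reshaped to f32[8,4,8] and dynamic-sliced along dim 0,
  // DimensionsUnmodifiedByReshape() reports the pair (0, 0), so split dim 0
  // of the reshape output maps back to dim 0 of the all-reduce.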
  spec.original_split_dims = split_dims;
  return spec;
}

}  // namespace xla