/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_concat_code_motion.h"

#include <map>
#include <optional>
#include <set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

// This algorithm tries to group HLO instructions into concat candidates. Each
// instruction can only belong to a single group.
//
// For simplicity, after finding the groups, it updates the first group member
// in place to the full concatenated shape, and replaces non-grouped uses with
// slices of it. It then relies on the TupleSimplifier, WhileLoopSimplifier,
// and DCE passes to remove the other elements.

// Represents a group of elements and how to concat them.
struct ConcatGroup {
  ConcatGroup(std::vector<HloInstruction*> elements, int64_t concat_dim,
              bool inserted_concat_dim)
      : elements(std::move(elements)),
        element_sizes(this->elements.size(), 1),
        element_offsets(this->elements.size(), 0),
        concat_dim(concat_dim),
        inserted_concat_dim(inserted_concat_dim) {
    if (inserted_concat_dim) {
      absl::c_iota(element_offsets, 0);
    } else {
      for (int64_t i = 0; i < element_sizes.size(); ++i) {
        element_sizes[i] = this->elements[i]->shape().dimensions(concat_dim);
        if (i > 0) {
          element_offsets[i] = element_offsets[i - 1] + element_sizes[i - 1];
        }
      }
    }
  }

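  // Returns the shape of the concatenated value. As an illustration
  // (hypothetical shapes): for three f32[2,8] elements, concat_dim = 1 with
  // inserted_concat_dim = false gives f32[2,24], while the same concat_dim
  // with inserted_concat_dim = true gives f32[2,3,8].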
  Shape GetConcatShape() const {
    if (inserted_concat_dim) {
      std::vector<int64_t> dims;
      const Shape& element_shape = elements.back()->shape();
      dims.reserve(element_shape.rank() + 1);
      for (int64_t i = 0; i < element_shape.rank(); ++i) {
        if (i == concat_dim) {
          dims.push_back(elements.size());
        }
        dims.push_back(element_shape.dimensions(i));
      }
      if (dims.size() == concat_dim) {
        dims.push_back(elements.size());
      }
      return ShapeUtil::MakeShape(element_shape.element_type(), dims);
    } else {
      int64_t dim_size = 0;
      for (int64_t size : element_sizes) {
        dim_size += size;
      }
      Shape shape = elements.back()->shape();
      shape.set_dimensions(concat_dim, dim_size);
      return shape;
    }
  }

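  // Creates a slice of `full_data` that recovers element `element_index`,
  // reshaping away the concat dimension if it was an inserted one.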
  HloInstruction* CreateSlice(HloInstruction* full_data, int64_t element_index,
                              HloComputation* comp) const {
    Shape shape = full_data->shape();
    shape.set_dimensions(concat_dim, element_sizes[element_index]);
    std::vector<int64_t> starts(shape.rank(), 0);
    std::vector<int64_t> limits(shape.dimensions().begin(),
                                shape.dimensions().end());
    starts[concat_dim] = element_offsets[element_index];
    limits[concat_dim] += starts[concat_dim];
    auto slice = comp->AddInstruction(
        HloInstruction::CreateSlice(shape, full_data, starts, limits,
                                    std::vector<int64_t>(shape.rank(), 1)));
    if (!inserted_concat_dim) {
      return slice;
    }
    std::vector<int64_t> element_shape;
    element_shape.reserve(shape.rank() - 1);
    for (int64_t i = 0; i < shape.rank(); ++i) {
      if (i != concat_dim) {
        element_shape.push_back(shape.dimensions(i));
      }
    }
    return comp->AddInstruction(HloInstruction::CreateReshape(
        ShapeUtil::MakeShape(shape.element_type(), element_shape), slice));
  }

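  // Concatenates `input_elements` into the combined shape, first reshaping
  // each element to add a size-1 concat dimension if it is an inserted one.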
  HloInstruction* CreateConcat(std::vector<HloInstruction*> input_elements,
                               HloComputation* comp) const {
    if (inserted_concat_dim) {
      for (int64_t i = 0; i < input_elements.size(); ++i) {
        std::vector<int64_t> element_shape;
        element_shape.reserve(input_elements[i]->shape().rank() + 1);
        for (int64_t j = 0; j < input_elements[i]->shape().rank(); ++j) {
          if (j == concat_dim) {
            element_shape.push_back(1);
          }
          element_shape.push_back(input_elements[i]->shape().dimensions(j));
        }
        if (element_shape.size() == concat_dim) {
          element_shape.push_back(1);
        }
        input_elements[i] = comp->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(input_elements[i]->shape().element_type(),
                                 element_shape),
            input_elements[i]));
      }
    }

    return comp->AddInstruction(HloInstruction::CreateConcatenate(
        GetConcatShape(), input_elements, concat_dim));
  }

  std::vector<HloInstruction*> elements;
  std::vector<int64_t> element_sizes;
  std::vector<int64_t> element_offsets;
  int64_t concat_dim;
  // Whether the concat dim is an inserted new dimension.
  bool inserted_concat_dim;
};

// A collection of ConcatGroup's where each HLO can only belong to a single
// group.
class ConcatGroups {
 public:
  // Returns the group index and element index in the group for an HLO, if it
  // belongs to a group.
  std::optional<std::pair<int64_t, int64_t>> GetGroupIndex(
      const HloInstruction* hlo) const {
    auto it = element_to_group_.find(hlo);
    if (it == element_to_group_.end()) {
      return std::nullopt;
    }
    return it->second;
  }

  const ConcatGroup& GetGroup(int64_t index) const { return groups_[index]; }

  // Creates a new group and returns its index if it doesn't exist yet, or
  // returns the existing group index. If the new group doesn't match an
  // existing group exactly but shares some of its elements, returns -1 as the
  // index. It also returns whether a new group was created, so the return
  // value is a pair of {whether created, group index}.
  std::pair<bool, int64_t> MaybeCreateNewGroup(ConcatGroup group) {
    int64_t group_id = -1;
    absl::flat_hash_set<HloInstruction*> elements_dedup;
    for (int64_t i = 0; i < group.elements.size(); ++i) {
      if (!elements_dedup.insert(group.elements[i]).second) {
        VLOG(2) << "Duplicates in group. Element: "
                << group.elements[i]->ToString();
      }
      if (concat_disallowed_.contains(group.elements[i])) {
        VLOG(2) << "Failed creating group. Grouping disallowed on "
                << group.elements[i]->ToString();
        return std::pair<bool, int64_t>(false, -1);
      }
      auto existing = GetGroupIndex(group.elements[i]);
      if (existing.has_value() &&
          (i != existing->second ||
           groups_[existing->first].concat_dim != group.concat_dim)) {
        // We allow mismatched inserted_concat_dim, since that only requires a
        // trivial reshape.
        VLOG(2)
            << "Failed creating group. Different than existing group. Element: "
            << group.elements[i]->ToString();
        return std::pair<bool, int64_t>(false, -1);
      }
      if (i == 0 && existing.has_value()) {
        group_id = existing->first;
      }
      if (i > 0) {
        if (existing.has_value() && existing->first != group_id) {
          VLOG(2) << "Failed creating group. Different than existing group. "
                     "Element: "
                  << group.elements[i]->ToString();
          return std::pair<bool, int64_t>(false, -1);
        }
        if (!existing.has_value() && group_id >= 0) {
          VLOG(2) << "Failed creating group. Different than existing group. "
                     "Element: "
                  << group.elements[i]->ToString();
          return std::pair<bool, int64_t>(false, -1);
        }
      }
    }
    if (group_id >= 0) {
      VLOG(2) << "Group already exists at " << group_id << " for "
              << group.elements[0]->ToString();
      return std::pair<bool, int64_t>(false, group_id);
    }
    int64_t index = groups_.size();
    for (int64_t i = 0; i < group.elements.size(); ++i) {
      element_to_group_[group.elements[i]] =
          std::pair<int64_t, int64_t>(index, i);
    }
    VLOG(2) << "Created new group at " << index << " for "
            << group.elements[0]->ToString()
            << ", concat_dim: " << group.concat_dim
            << ", inserted: " << group.inserted_concat_dim;
    groups_.push_back(std::move(group));
    return std::pair<bool, int64_t>(true, index);
  }

  const std::vector<ConcatGroup>& Groups() const { return groups_; }

  int64_t NextGroupIndex() const { return groups_.size(); }

  void RemoveTrailingGroups(int64_t start_index) {
    while (groups_.size() > start_index) {
      for (auto element : groups_.back().elements) {
        element_to_group_.erase(element);
      }
      groups_.pop_back();
    }
  }

  void DisallowGroupingOn(const HloInstruction* hlo) {
    VLOG(2) << "Disallow grouping on " << hlo->ToString();
    concat_disallowed_.insert(hlo);
  }

 private:
  // element -> {group index in groups_, element index in group}.
  absl::flat_hash_map<const HloInstruction*, std::pair<int64_t, int64_t>>
      element_to_group_;
  std::vector<ConcatGroup> groups_;
  absl::flat_hash_set<const HloInstruction*> concat_disallowed_;
};

// Infers an operand's concat dim and whether it's an inserted dim. For
// example, if hlo is f32[2,4,2] broadcast(f32[2,4]), dimensions={0,1}, and the
// result is concatenated on dim 2, then this function will return {2, true}.
//
// If the operand is already transformed to the combined shape, specify its
// group in combined_operand_group. (Only required for kReshape.)
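//
// Another example (hypothetical): if hlo is f32[2,4] reduce(f32[2,3,4]),
// dimensions={1}, and the result is concatenated on existing dim 1, the
// operand concat dim is 2, because the reduced dimension 1 sits at or before
// it in the operand.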
std::optional<std::pair<int64_t, bool>> GetOperandConcatDim(
    const HloInstruction* hlo, int64_t operand_index, int64_t hlo_concat_dim,
    bool hlo_inserted_concat_dim,
    const ConcatGroup* combined_operand_group = nullptr) {
  if (hlo->IsElementwise() || hlo->opcode() == HloOpcode::kAllReduce) {
    return std::pair<int64_t, bool>(hlo_concat_dim, hlo_inserted_concat_dim);
  }
  int64_t operand_concat_dim = -1;
  bool operand_inserted_concat_dim = false;
  const Shape& operand_shape =
      combined_operand_group == nullptr
          ? hlo->operand(operand_index)->shape()
          : combined_operand_group->elements.back()->shape();
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    operand_concat_dim = 0;
    operand_inserted_concat_dim = true;
    // If the concat dim does not exist in the operand, try to place
    // operand_concat_dim so that it stays adjacent to the same dims as in the
    // output.
    int64_t min_dist_to_concat_dim = hlo->shape().rank();
    for (int64_t i = 0; i < operand_shape.rank(); ++i) {
      if (hlo->dimensions(i) == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = hlo_inserted_concat_dim;
        break;
      }
      if (hlo->dimensions(i) < hlo_concat_dim &&
          min_dist_to_concat_dim > hlo_concat_dim - hlo->dimensions(i)) {
        operand_concat_dim = i + 1;
        min_dist_to_concat_dim = hlo_concat_dim - hlo->dimensions(i);
      }
      if (hlo->dimensions(i) > hlo_concat_dim &&
          min_dist_to_concat_dim > hlo->dimensions(i) - hlo_concat_dim) {
        operand_concat_dim = i;
        min_dist_to_concat_dim = hlo->dimensions(i) - hlo_concat_dim;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReduce) {
    if (operand_index != 0) {
      return std::nullopt;
    }
    operand_concat_dim = hlo_concat_dim;
    operand_inserted_concat_dim = hlo_inserted_concat_dim;
    std::set<int64_t> sorted_reduce_dims;
    for (int64_t dim : hlo->dimensions()) {
      sorted_reduce_dims.insert(dim);
    }
    for (int64_t dim : sorted_reduce_dims) {
      if ((hlo_inserted_concat_dim && dim < operand_concat_dim) ||
          (!hlo_inserted_concat_dim && dim <= operand_concat_dim)) {
        operand_concat_dim++;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReshape) {
    int64_t i = 0;
    int64_t j = 0;
    operand_inserted_concat_dim = false;
    // Only support adding/removing trivial dims.
    while (i < operand_shape.rank() || j <= hlo_concat_dim) {
      if (i < operand_shape.rank() && j < hlo->shape().rank() &&
          operand_shape.dimensions(i) == hlo->shape().dimensions(j)) {
        if (j == hlo_concat_dim) {
          operand_inserted_concat_dim =
              hlo_inserted_concat_dim && operand_shape.dimensions(i) != 1;
          operand_concat_dim = i;
          break;
        }
        i++;
        j++;
        continue;
      }
      if (i < operand_shape.rank() && operand_shape.dimensions(i) == 1) {
        if (j == hlo_concat_dim && hlo_inserted_concat_dim) {
          operand_concat_dim = i;
          break;
        }
        i++;
        continue;
      }
      if (j == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = true;
        break;
      }
      if (j < hlo->shape().rank() && hlo->shape().dimensions(j) == 1) {
        j++;
        continue;
      }
      return std::nullopt;
    }
  } else {
    return std::nullopt;
  }
  CHECK_GE(operand_concat_dim, 0);
  return std::pair<int64_t, bool>(operand_concat_dim,
                                  operand_inserted_concat_dim);
}

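// Rewrites hlo's shape to the group's concatenated shape, and updates
// shape-dependent attributes (broadcast dimensions, reduce dimensions) to
// match the new shape.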
void ModifyHloPropertiesForConcatShape(const ConcatGroup& group,
                                       HloInstruction* hlo) {
  *hlo->mutable_shape() = group.GetConcatShape();
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    // Use the last element to infer the operand concat dim, since the first
    // element's operand might have been rewritten.
    auto operand_dim = GetOperandConcatDim(
        group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
    CHECK(operand_dim.has_value());
    int64_t operand_concat_dim = operand_dim->first;
    bool operand_inserted_concat_dim = operand_dim->second;
    if (operand_inserted_concat_dim) {
      // We should have added a dimension on the operand.
      CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size() + 1)
          << hlo->ToString();
    } else {
      CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size());
    }
    std::vector<int64_t> dims;
    const int64_t rank = hlo->operand(0)->shape().rank();
    dims.reserve(rank);
    for (int64_t i = 0; i < rank; ++i) {
      if (i == operand_concat_dim && operand_inserted_concat_dim) {
        dims.push_back(group.concat_dim);
      } else {
        if (i > operand_concat_dim && operand_inserted_concat_dim) {
          dims.push_back(hlo->dimensions(i - 1));
        } else {
          dims.push_back(hlo->dimensions(i));
        }
        if (group.inserted_concat_dim && dims.back() >= group.concat_dim) {
          dims.back()++;
        }
      }
    }
    *hlo->mutable_dimensions() = std::move(dims);
  } else if (hlo->opcode() == HloOpcode::kReduce) {
    auto operand_dim = GetOperandConcatDim(
        group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
    CHECK(operand_dim.has_value());
    int64_t operand_concat_dim = operand_dim->first;
    bool operand_inserted_concat_dim = operand_dim->second;
    if (operand_inserted_concat_dim) {
      auto dims = hlo->mutable_dimensions();
      for (int64_t i = 0; i < dims->size(); ++i) {
        if ((*dims)[i] >= operand_concat_dim) {
          (*dims)[i]++;
        }
      }
    }
  }
}

// Main method to assign groups to HLOs, based on a concat.
bool GroupHlosForConcat(
    HloComputation* body, HloInstruction* concat,
    absl::flat_hash_map<const HloInstruction*, int64_t> topological_order,
    ConcatGroups* groups) {
  const int64_t group_size = concat->operand_count();
  absl::flat_hash_set<int64_t> used_groups;
  auto root_tuple = body->root_instruction();
  CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
  absl::flat_hash_map<HloInstruction*, int64_t> root_tuple_element_use_count;
  for (auto operand : root_tuple->operands()) {
    root_tuple_element_use_count.emplace(operand, 0).first->second++;
  }
  // Priority queue sorted in reverse topological order, so users come before
  // operands; it uses -topological_order[element0] as the key. We start with
  // the concat operands.
  std::multimap<int64_t, ConcatGroup> pq;
  const int64_t first_group_id_to_create = groups->NextGroupIndex();
  auto fail_and_cleanup = [&] {
    VLOG(1) << "Failed to get the subcomputation to optimize for "
            << concat->ToString() << ", clear groups starting at "
            << first_group_id_to_create;
    groups->RemoveTrailingGroups(first_group_id_to_create);
    return false;
  };
  struct GroupUse {
    int64_t group_id;
    bool newly_created;
    bool already_used_by_subcomp;
  };
  auto maybe_create_group = [&](ConcatGroup group) {
    auto res = groups->MaybeCreateNewGroup(std::move(group));
    GroupUse use{res.second, false, false};
    if (res.second < 0) {
      return use;
    }
    use.newly_created = res.first;
    use.already_used_by_subcomp = !used_groups.insert(res.second).second;
    return use;
  };
  std::vector<HloInstruction*> concat_operands(concat->operands().begin(),
                                               concat->operands().end());
  int64_t concat_operand_order = -topological_order[concat_operands[0]];
  pq.emplace(concat_operand_order,
             ConcatGroup(std::move(concat_operands),
                         concat->concatenate_dimension(), false));

  // Find the subcomputation on elements to combine, in order to move `concat`
  // out of the loop without adding new concats. We start from the concat's
  // operands, and the priority queue is ordered in reverse topological order
  // so we process outputs before inputs. Each entry in the queue is a group of
  // elements to combine. A legitimate group consists of identical ops, except
  // that they each operate on one element. When a group of loop inputs is
  // processed, we also enqueue the corresponding loop outputs to keep their
  // shapes matched.
  while (!pq.empty()) {
    auto group = std::move(pq.begin()->second);
    pq.erase(pq.begin());
    const auto& hlos = group.elements;
    VLOG(2) << "GroupHlosForConcat dequeued " << hlos[0]->ToString();
    bool group_is_param_gtes = false;
    if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
          return element == hlos[0];
        })) {
      // Shared operand.
      if (groups->GetGroupIndex(hlos[0]).has_value()) {
        VLOG(1) << "We do not support the case where a shared operand is also "
                   "part of a group: "
                << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      groups->DisallowGroupingOn(hlos[0]);
      continue;
    }
    if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
          return element->opcode() == HloOpcode::kGetTupleElement &&
                 element->operand(0) == body->parameter_instruction(0);
        })) {
      group_is_param_gtes = true;
    } else if (((hlos[0]->IsElementwise() ||
                 hlos[0]->opcode() == HloOpcode::kAllReduce) &&
                !hlos[0]->HasSideEffect()) ||
               hlos[0]->opcode() == HloOpcode::kBroadcast ||
               hlos[0]->opcode() == HloOpcode::kReduce ||
               hlos[0]->opcode() == HloOpcode::kReshape ||
               hlos[0]->IsCustomCall("Sharding")) {
      if (hlos[0]->opcode() == HloOpcode::kAllReduce &&
          (!hlos[0]->shape().IsArray() || hlos[0]->IsCrossModuleAllReduce())) {
        VLOG(2) << "Unsupported allreduce: " << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      // Check if these elements can be concatenated.
      if (absl::c_any_of(hlos, [&](const HloInstruction* element) {
            auto eq_operand = [](const HloInstruction* a,
                                 const HloInstruction* b) {
              return ShapeUtil::Compatible(a->shape(), b->shape());
            };
            auto eq_computations = [](const HloComputation* lhs,
                                      const HloComputation* rhs) {
              return lhs->Equal(*rhs, /*is_layout_sensitive=*/false);
            };
            if (!hlos[0]->Identical(*element, eq_operand, eq_computations,
                                    /*layout_sensitive=*/false)) {
              return true;
            }
            if (element->opcode() == HloOpcode::kReduce &&
                (element->operand_count() != 2 ||
                 element->operand(1) != hlos[0]->operand(1))) {
              return true;
            }
            return false;
          })) {
        VLOG(2) << "Different types of elements. First element: "
                << hlos[0]->ToString();
        return fail_and_cleanup();
      }
      // Now enqueue the inputs.
      int64_t input_count = hlos[0]->operand_count();
      if (hlos[0]->opcode() == HloOpcode::kReduce) {
        CHECK_EQ(input_count, 2);
        // Exclude the init value that we have checked to be the same.
        input_count = 1;
      }
      for (int64_t i = 0; i < input_count; ++i) {
        std::vector<HloInstruction*> elements(group_size);
        for (int64_t j = 0; j < group_size; ++j) {
          elements[j] = hlos[j]->mutable_operand(i);
        }
        auto maybe_new_concat_dim = GetOperandConcatDim(
            hlos[0], i, group.concat_dim, group.inserted_concat_dim);
        if (!maybe_new_concat_dim.has_value()) {
          VLOG(2) << "Cannot find operand concat dimension for operand " << i
                  << " of " << hlos[0]->ToString();
          return fail_and_cleanup();
        }
        int64_t new_group_concat_dim = maybe_new_concat_dim->first;
        bool inserted_concat_dim = maybe_new_concat_dim->second;
        // Enqueue the input group.
        int64_t element_order = -topological_order[elements[0]];
        pq.emplace(element_order,
                   ConcatGroup(std::move(elements), new_group_concat_dim,
                               inserted_concat_dim));
      }
    } else if (hlos[0]->opcode() == HloOpcode::kSlice) {
      int64_t offset = 0;
      auto operand = hlos[0]->operand(0);
      if (group.inserted_concat_dim) {
        VLOG(2) << "Slices cannot be grouped on a new dimension.";
        return fail_and_cleanup();
      }
      if (groups->GetGroupIndex(operand).has_value()) {
        // Should not slice an operand to be grouped.
        return fail_and_cleanup();
      }
      groups->DisallowGroupingOn(operand);
      for (int64_t i = 0; i < group_size; ++i) {
        if (hlos[i]->operand(0) != operand) {
          VLOG(2) << "Slices of different operands.";
          return fail_and_cleanup();
        }
        for (int64_t j = 0; j < hlos[i]->shape().rank(); ++j) {
          if (hlos[i]->slice_strides(j) != 1) {
            VLOG(2) << "Slices with strides.";
            return fail_and_cleanup();
          }
          if (j == group.concat_dim) {
            if (hlos[i]->slice_starts(j) != offset) {
              VLOG(2) << "Slices with unsupported offsets.";
              return fail_and_cleanup();
            }
            offset += hlos[i]->shape().dimensions(j);
          } else {
            if (hlos[i]->slice_starts(j) != 0 ||
                hlos[i]->slice_limits(j) != operand->shape().dimensions(j)) {
              VLOG(2) << "Slice with unsupported offsets at dimension " << j
                      << ", " << hlos[i]->ToString();
              return fail_and_cleanup();
            }
          }
        }
      }
      if (offset != operand->shape().dimensions(group.concat_dim)) {
        VLOG(2) << "Slices with unsupported sizes.";
        return fail_and_cleanup();
      }
    } else {
      VLOG(2) << "Unsupported opcode: " << hlos[0]->ToString();
      return fail_and_cleanup();
    }
    auto guse = maybe_create_group(std::move(group));
    if (guse.group_id < 0) {
      VLOG(2) << "Failed to create group.";
      return fail_and_cleanup();
    }
    const auto& registered_group = groups->GetGroup(guse.group_id);
    if (!guse.already_used_by_subcomp && group_is_param_gtes) {
      // When we processed a group of parameter GTEs, we should also enqueue
      // the corresponding root tuple operands, so that they have matching
      // shapes.
      std::vector<HloInstruction*> new_outputs(group_size);
      for (int64_t i = 0; i < group_size; ++i) {
        new_outputs[i] = root_tuple->mutable_operand(
            registered_group.elements[i]->tuple_index());
      }
      int64_t new_output_order = -topological_order[new_outputs[0]];
      pq.emplace(
          new_output_order,
          ConcatGroup(std::move(new_outputs), registered_group.concat_dim,
                      registered_group.inserted_concat_dim));
    }
  }
  return groups->Groups().size() > first_group_id_to_create;
}

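// Returns a mask of the tuple elements that the while condition reads. If the
// condition uses its parameter other than through get-tuple-element, all
// elements are conservatively marked as used.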
std::vector<bool> TupleElementsUsedInCond(HloInstruction* loop) {
  std::vector<bool> result(loop->shape().tuple_shapes_size(), false);
  for (auto user : loop->while_condition()->parameter_instruction(0)->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      absl::c_fill(result, true);
      return result;
    }
    result[user->tuple_index()] = true;
  }
  return result;
}

// Adds copies to returned values to keep RewriteLoopWithConcatGroups simple:
// the copies do not have other users and only appear once in the root tuple.
Status AddCopiesToRoot(HloComputation* body,
                       absl::Span<HloInstruction* const> param_gtes,
                       ConcatGroups* groups) {
  auto root = body->root_instruction();
  CHECK_EQ(root->opcode(), HloOpcode::kTuple);
  std::vector<HloInstruction*> copies(root->operand_count(), nullptr);
  for (int64_t i = 0; i < copies.size(); ++i) {
    auto element = root->mutable_operand(i);
    if (!element->shape().IsArray()) {
      continue;
    }
    copies[i] = body->AddInstruction(HloInstruction::CreateUnary(
        element->shape(), HloOpcode::kCopy, element));
    TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copies[i]));
  }
  for (int64_t i = 0; i < copies.size(); ++i) {
    auto copy = copies[i];
    if (groups->GetGroupIndex(copy).has_value()) {
      // Already handled by earlier group members.
      continue;
    }
    auto param_group_index = groups->GetGroupIndex(param_gtes[i]);
    if (!param_group_index.has_value()) {
      continue;
    }
    const auto& param_group = groups->GetGroup(param_group_index->first);
    std::vector<HloInstruction*> copy_group(param_group.elements.size());
    for (int64_t j = 0; j < copy_group.size(); ++j) {
      copy_group[j] = copies[param_group.elements[j]->tuple_index()];
    }
    CHECK(groups
              ->MaybeCreateNewGroup(
                  ConcatGroup(std::move(copy_group), param_group.concat_dim,
                              param_group.inserted_concat_dim))
              .first);
  }
  return OkStatus();
}

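// Removes the copies added by AddCopiesToRoot, forwarding each copy's operand
// directly to the root tuple.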
Status RemoveCopiesFromRoot(HloComputation* body) {
  auto root = body->root_instruction();
  CHECK_EQ(root->opcode(), HloOpcode::kTuple);
  for (int64_t i = 0; i < root->operand_count(); ++i) {
    auto copy = root->mutable_operand(i);
    if (copy->opcode() == HloOpcode::kCopy) {
      TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copy->mutable_operand(0)));
    }
  }
  return OkStatus();
}

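// Rewrites the loop so that each group's first element carries the full
// concatenated data: adjusts the loop signature, init values, users, and the
// loop body accordingly.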
Status RewriteLoopWithConcatGroups(HloInstruction* loop,
                                   absl::Span<HloInstruction* const> param_gtes,
                                   ConcatGroups& groups) {
  VLOG(1) << "RewriteLoopWithConcatGroups with " << groups.Groups().size()
          << " groups.";
  // For simplicity, for each group we rewrite the first element into the full
  // shape and leave the other elements unchanged. Non-grouped users will have
  // slices of the expanded first element as their new input. Later
  // simplification and DCE passes can remove the other elements.
  absl::flat_hash_set<int64_t> processed_groups;
  auto body = loop->while_body();
  auto param = body->parameter_instruction(0);
  auto cond_param = loop->while_condition()->parameter_instruction(0);

  // First, modify loop signature and operands/users.
  std::vector<HloInstruction*> init_elements(loop->shape().tuple_shapes_size());
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    init_elements[i] =
        loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
            loop->shape().tuple_shapes(i), loop->mutable_operand(0), i));
  }
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    // Change body parameter shape.
    *param_gtes[i]->mutable_shape() = group.GetConcatShape();
    *param->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
    *body->root_instruction()->mutable_shape()->mutable_tuple_shapes(i) =
        param_gtes[i]->shape();
    *cond_param->mutable_shape()->mutable_tuple_shapes(i) =
        param_gtes[i]->shape();
    *loop->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
    processed_groups.insert(group_and_index->first);
    std::vector<HloInstruction*> input_concat_elements;
    input_concat_elements.reserve(group.elements.size());
    for (auto param_gte : group.elements) {
      input_concat_elements.push_back(init_elements[param_gte->tuple_index()]);
    }
    init_elements[i] =
        group.CreateConcat(std::move(input_concat_elements), loop->parent());
  }
  TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(
      0, loop->parent()->AddInstruction(
             HloInstruction::CreateTuple(init_elements))));
  // Adjust loop users.
  auto original_loop_users = loop->users();
  const bool loop_is_root = loop == loop->parent()->root_instruction();
  std::vector<HloInstruction*> output_elements(
      loop->shape().tuple_shapes_size());
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    output_elements[i] =
        loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
            init_elements[i]->shape(), loop, i));
  }
  for (int64_t i = 0; i < param_gtes.size(); ++i) {
    const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    auto concat_output = output_elements[group.elements[0]->tuple_index()];
    for (int64_t j = 0; j < group.elements.size(); ++j) {
      const auto param_gte = group.elements[j];
      output_elements[param_gte->tuple_index()] =
          group.CreateSlice(concat_output, j, loop->parent());
    }
  }
  auto new_output_tuple = loop->parent()->AddInstruction(
      HloInstruction::CreateTuple(output_elements));
  for (auto user : original_loop_users) {
    TF_RETURN_IF_ERROR(
        loop->ReplaceUseWithDifferentShape(user, new_output_tuple));
  }
  if (loop_is_root) {
    loop->parent()->set_root_instruction(new_output_tuple,
                                         /*accept_different_shape=*/true);
  }

  // Now rewrite the loop body.
  std::vector<HloInstruction*> slices_to_remove;
  absl::flat_hash_set<HloInstruction*> new_reshapes;
  for (auto hlo : body->MakeInstructionPostOrder()) {
    const auto& group_and_index = groups.GetGroupIndex(hlo);
    if (!group_and_index.has_value() || group_and_index->second != 0) {
      continue;
    }

    if (!processed_groups.insert(group_and_index->first).second) {
      // Already processed the group at the first element.
      continue;
    }
    const auto& group = groups.GetGroup(group_and_index->first);
    if (hlo->opcode() == HloOpcode::kSlice) {
      // We could just replace hlo with its operand; however, to follow the
      // practice of using the first element as full data, we defer that
      // replacement.
      slices_to_remove.push_back(hlo);
    } else {
      int64_t operand_count_to_adjust = hlo->operand_count();
      if (hlo->opcode() == HloOpcode::kReduce) {
        CHECK_EQ(operand_count_to_adjust, 2);
        operand_count_to_adjust = 1;
      }
      for (int64_t i = 0; i < operand_count_to_adjust; ++i) {
        auto operand_group_index = groups.GetGroupIndex(hlo->operand(i));
        const ConcatGroup* operand_group =
            operand_group_index.has_value()
                ? &groups.GetGroup(operand_group_index->first)
                : nullptr;
        auto maybe_operand_concat_dim = GetOperandConcatDim(
            hlo, i, group.concat_dim, group.inserted_concat_dim, operand_group);
        CHECK(maybe_operand_concat_dim.has_value())
            << "Operand " << i << " of " << hlo->ToString();
        int64_t operand_concat_dim = maybe_operand_concat_dim->first;
        bool operand_inserted_concat_dim = maybe_operand_concat_dim->second;
        if (operand_group != nullptr) {
          CHECK_EQ(operand_concat_dim, operand_group->concat_dim);
          if (operand_inserted_concat_dim !=
              operand_group->inserted_concat_dim) {
            // The operand's actual inserted_concat_dim doesn't match the
            // expected operand_inserted_concat_dim. Need a reshape.
            std::vector<int64_t> new_dims;
            int64_t d = 0;
            for (; d < operand_concat_dim; ++d) {
              new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
            }
            if (operand_inserted_concat_dim) {
              // Split operand concat dim.
              new_dims.push_back(group.elements.size());
              new_dims.push_back(
                  hlo->operand(i)->shape().dimensions(operand_concat_dim) /
                  group.elements.size());
              d = operand_concat_dim + 1;
            } else {
              // Combine operand concat dim with the next.
              new_dims.push_back(
                  group.elements.size() *
                  hlo->operand(i)->shape().dimensions(operand_concat_dim + 1));
              d = operand_concat_dim + 2;
            }
            for (; d < hlo->operand(i)->shape().rank(); ++d) {
              new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
            }
            auto reshape = body->AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::MakeShape(hlo->operand(i)->shape().element_type(),
                                     new_dims),
                hlo->mutable_operand(i)));
            new_reshapes.insert(reshape);
            TF_RETURN_IF_ERROR(
                hlo->ReplaceOperandWithDifferentShape(i, reshape));
          }
          continue;
        }
        // This is a shared operand; we need to broadcast it.
        CHECK(
            absl::c_all_of(group.elements, [&](const HloInstruction* element) {
              return element->operand(i) == hlo->operand(i);
            }));
        VLOG(2) << "Broadcasting shared operand "
                << hlo->operand(i)->ToString();
        Shape data_shape = hlo->operand(i)->shape();
        std::vector<int64_t> broadcast_dims;
        std::vector<int64_t> broadcast_shape;
        const int64_t data_shape_rank = data_shape.rank();
        broadcast_dims.reserve(data_shape_rank);
        broadcast_shape.reserve(data_shape_rank + 1);
        for (int64_t j = 0; j < data_shape_rank; ++j) {
          if (j < operand_concat_dim) {
            broadcast_dims.push_back(j);
          } else {
            broadcast_dims.push_back(j + 1);
          }
          if (j == operand_concat_dim) {
            broadcast_shape.push_back(group.elements.size());
          }
          broadcast_shape.push_back(data_shape.dimensions(j));
        }
        if (broadcast_shape.size() == data_shape.rank()) {
          // New dim at the end.
          broadcast_shape.push_back(group.elements.size());
        }
        auto broadcast = body->AddInstruction(HloInstruction::CreateBroadcast(
            ShapeUtil::MakeShape(data_shape.element_type(), broadcast_shape),
            hlo->mutable_operand(i), broadcast_dims));

        if (!operand_inserted_concat_dim) {
          // Concat on existing dim. Reshape to merge the broadcast dim.
          data_shape.set_dimensions(
              operand_concat_dim,
              data_shape.dimensions(operand_concat_dim) *
                  group.elements.size());
          broadcast = body->AddInstruction(
              HloInstruction::CreateReshape(data_shape, broadcast));
        }
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, broadcast));
      }
    }
    VLOG(2) << "Modifying HLO to full shape " << hlo->ToString();
    ModifyHloPropertiesForConcatShape(group, hlo);
    VLOG(2) << "Modified HLO to full shape " << hlo->ToString();
  }

  // For non-grouped HLOs, replace grouped inputs with slices. Also include
  // grouped reduce HLOs because their init values are not grouped.
  for (auto hlo : body->MakeInstructionPostOrder()) {
    if (new_reshapes.contains(hlo)) {
      continue;
    }
    const auto& group_and_index = groups.GetGroupIndex(hlo);
    if ((!group_and_index.has_value() || hlo->opcode() == HloOpcode::kReduce) &&
        hlo != body->root_instruction()) {
      auto operands = hlo->operands();
      if (group_and_index.has_value()) {
        // Only handle the reduce init value.
        CHECK_EQ(operands.size(), 2);
        CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
        operands.erase(operands.begin());
      }
      for (int64_t i = 0; i < operands.size(); ++i) {
        auto operand = operands[i];
        auto operand_group_index = groups.GetGroupIndex(operand);
        if (!operand_group_index.has_value()) {
          continue;
        }
        const auto& operand_group = groups.GetGroup(operand_group_index->first);
        auto slice = operand_group.CreateSlice(
            operand_group.elements[0], operand_group_index->second, body);
        TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, slice));
      }
    }
  }
  for (auto slice : slices_to_remove) {
    TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(slice->mutable_operand(0)));
    TF_RETURN_IF_ERROR(body->RemoveInstruction(slice));
  }
  return OkStatus();
}

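// Runs the transformation on a single while loop: finds concatenates with at
// least min_operand_count_to_optimize operands, groups the subcomputations
// that produce their operands, and rewrites the loop if any group was formed.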
StatusOr<bool> RunOnLoop(HloInstruction* loop,
                         int64_t min_operand_count_to_optimize) {
  auto body = loop->while_body();
  auto param = body->parameter_instruction(0);
  auto root = body->root_instruction();
  if (!param->shape().IsTuple() || root->opcode() != HloOpcode::kTuple) {
    return false;
  }
  std::vector<HloInstruction*> gtes(param->shape().tuple_shapes_size(),
                                    nullptr);
  ConcatGroups groups;
  auto indices_used_in_cond = TupleElementsUsedInCond(loop);
  for (auto user : param->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      // Unhandled user opcode.
      return false;
    }
    int64_t idx = user->tuple_index();
    if (gtes[idx] != nullptr) {
      // Seen this index before.
      return false;
    }
    gtes[idx] = user;
    if (indices_used_in_cond[idx]) {
      groups.DisallowGroupingOn(user);
    }
  }
  std::vector<HloInstruction*> concats;
  auto body_instructions = body->MakeInstructionPostOrder();
  absl::flat_hash_map<const HloInstruction*, int64_t> topological_order;
  for (int64_t i = 0; i < body_instructions.size(); ++i) {
    auto hlo = body_instructions[i];
    topological_order[hlo] = i;
    if (hlo->opcode() == HloOpcode::kConcatenate &&
        hlo->operand_count() >= min_operand_count_to_optimize) {
      concats.push_back(hlo);
    }
  }

  for (auto& concat : concats) {
    if (!GroupHlosForConcat(body, concat, topological_order, &groups)) {
      concat = nullptr;
    }
  }
  if (groups.Groups().empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(AddCopiesToRoot(body, gtes, &groups));
  TF_RETURN_IF_ERROR(RewriteLoopWithConcatGroups(loop, gtes, groups));
  for (auto concat : concats) {
    if (concat == nullptr) {
      continue;
    }
    // We have replaced the operands of the concat with slices of full data.
    auto new_slice = concat->mutable_operand(0);
    CHECK_EQ(new_slice->opcode(), HloOpcode::kSlice);
    TF_RETURN_IF_ERROR(
        concat->ReplaceAllUsesWith(new_slice->mutable_operand(0)));
    TF_RETURN_IF_ERROR(body->RemoveInstruction(concat));
  }
  TF_RETURN_IF_ERROR(RemoveCopiesFromRoot(body));
  // Finally, pass through the replaced elements from parameter to root, so
  // that the while loop simplifier can get rid of them.
  for (auto gte : gtes) {
    auto group_index = groups.GetGroupIndex(gte);
    if (group_index.has_value() && group_index->second > 0) {
      TF_RETURN_IF_ERROR(root->ReplaceOperandWith(gte->tuple_index(), gte));
    }
  }
  return true;
}

}  // namespace

StatusOr<bool> WhileLoopConcatCodeMotion::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* comp :
       module->MakeComputationPostOrder(execution_threads)) {
    for (HloInstruction* hlo : comp->MakeInstructionPostOrder()) {
      if (hlo->opcode() == HloOpcode::kWhile) {
        TF_ASSIGN_OR_RETURN(bool loop_changed,
                            RunOnLoop(hlo, min_operand_count_to_optimize_));
        changed |= loop_changed;
      }
    }
  }
  if (changed) {
    HloPassPipeline pipeline("loop-concat-motion-cleanup");
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<WhileLoopSimplifier>();
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<HloDCE>();
    TF_RETURN_IF_ERROR(pipeline.Run(module, execution_threads).status());
  }
  return changed;
}

}  // namespace xla