1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_concat_code_motion.h"
17
18 #include <map>
19 #include <optional>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_dce.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_module.h"
32 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
33 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
34 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
35 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/status.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/statusor.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/xla_data.pb.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/platform/errors.h"
44 #include "tensorflow/core/platform/status.h"
45 #include "tensorflow/stream_executor/lib/statusor.h"
46
47 namespace xla {
48
49 namespace {
50
51 // This algorithm tries to group HLO instructions into concat candidates. Each
52 // instruction can only belong to a single group.
53 //
// For simplicity, after finding the groups, it updates the first member of each
// group in place to the full concatenated shape, and replaces non-grouped uses
// with slices of it.
56 // Then it relies on TupleSimplifier, WhileLoopSimplifier, and DCE passes to
57 // remove other elements.
58
59 // Represents a group of elements and how to concat them.
60 struct ConcatGroup {
ConcatGroupxla::__anonbf9fd6d50111::ConcatGroup61 ConcatGroup(std::vector<HloInstruction*> elements, int64_t concat_dim,
62 bool inserted_concat_dim)
63 : elements(std::move(elements)),
64 element_sizes(this->elements.size(), 1),
65 element_offsets(this->elements.size(), 0),
66 concat_dim(concat_dim),
67 inserted_concat_dim(inserted_concat_dim) {
68 if (inserted_concat_dim) {
69 absl::c_iota(element_offsets, 0);
70 } else {
71 for (int64_t i = 0; i < element_sizes.size(); ++i) {
72 element_sizes[i] = this->elements[i]->shape().dimensions(concat_dim);
73 if (i > 0) {
74 element_offsets[i] = element_offsets[i - 1] + element_sizes[i - 1];
75 }
76 }
77 }
78 }
79
GetConcatShapexla::__anonbf9fd6d50111::ConcatGroup80 Shape GetConcatShape() const {
81 if (inserted_concat_dim) {
82 std::vector<int64_t> dims;
83 const Shape& element_shape = elements.back()->shape();
84 dims.reserve(element_shape.rank() + 1);
85 for (int64_t i = 0; i < element_shape.rank(); ++i) {
86 if (i == concat_dim) {
87 dims.push_back(elements.size());
88 }
89 dims.push_back(element_shape.dimensions(i));
90 }
91 if (dims.size() == concat_dim) {
92 dims.push_back(elements.size());
93 }
94 return ShapeUtil::MakeShape(element_shape.element_type(), dims);
95 } else {
96 int64_t dim_size = 0;
97 for (int64_t size : element_sizes) {
98 dim_size += size;
99 }
100 Shape shape = elements.back()->shape();
101 shape.set_dimensions(concat_dim, dim_size);
102 return shape;
103 }
104 }
105
CreateSlicexla::__anonbf9fd6d50111::ConcatGroup106 HloInstruction* CreateSlice(HloInstruction* full_data, int64_t element_index,
107 HloComputation* comp) const {
108 Shape shape = full_data->shape();
109 shape.set_dimensions(concat_dim, element_sizes[element_index]);
110 std::vector<int64_t> starts(shape.rank(), 0);
111 std::vector<int64_t> limits(shape.dimensions().begin(),
112 shape.dimensions().end());
113 starts[concat_dim] = element_offsets[element_index];
114 limits[concat_dim] += starts[concat_dim];
115 auto slice = comp->AddInstruction(
116 HloInstruction::CreateSlice(shape, full_data, starts, limits,
117 std::vector<int64_t>(shape.rank(), 1)));
118 if (!inserted_concat_dim) {
119 return slice;
120 }
121 std::vector<int64_t> element_shape;
122 element_shape.reserve(shape.rank() - 1);
123 for (int64_t i = 0; i < shape.rank(); ++i) {
124 if (i != concat_dim) {
125 element_shape.push_back(shape.dimensions(i));
126 }
127 }
128 return comp->AddInstruction(HloInstruction::CreateReshape(
129 ShapeUtil::MakeShape(shape.element_type(), element_shape), slice));
130 }
131
CreateConcatxla::__anonbf9fd6d50111::ConcatGroup132 HloInstruction* CreateConcat(std::vector<HloInstruction*> input_elements,
133 HloComputation* comp) const {
134 if (inserted_concat_dim) {
135 for (int64_t i = 0; i < input_elements.size(); ++i) {
136 std::vector<int64_t> element_shape;
137 element_shape.reserve(input_elements[i]->shape().rank() + 1);
138 for (int64_t j = 0; j < input_elements[i]->shape().rank(); ++j) {
139 if (j == concat_dim) {
140 element_shape.push_back(1);
141 }
142 element_shape.push_back(input_elements[i]->shape().dimensions(j));
143 }
144 if (element_shape.size() == concat_dim) {
145 element_shape.push_back(1);
146 }
147 input_elements[i] = comp->AddInstruction(HloInstruction::CreateReshape(
148 ShapeUtil::MakeShape(input_elements[i]->shape().element_type(),
149 element_shape),
150 input_elements[i]));
151 }
152 }
153
154 return comp->AddInstruction(HloInstruction::CreateConcatenate(
155 GetConcatShape(), input_elements, concat_dim));
156 }
157
158 std::vector<HloInstruction*> elements;
159 std::vector<int64_t> element_sizes;
160 std::vector<int64_t> element_offsets;
161 int64_t concat_dim;
162 // Whether the concat dim is an inserted new dimension.
163 bool inserted_concat_dim;
164 };
165
166 // A collection of ConcatGroup's where each HLO can only belong to a single
167 // group.
168 class ConcatGroups {
169 public:
170 // Returns the group index and element index in group for an HLO, if it
171 // belongs to a group.
GetGroupIndex(const HloInstruction * hlo) const172 std::optional<std::pair<int64_t, int64_t>> GetGroupIndex(
173 const HloInstruction* hlo) const {
174 auto it = element_to_group_.find(hlo);
175 if (it == element_to_group_.end()) {
176 return std::nullopt;
177 }
178 return it->second;
179 }
180
GetGroup(int64_t index) const181 const ConcatGroup& GetGroup(int64_t index) const { return groups_[index]; }
182
183 // Creates a new group and returns the index if it doesn't exist, or returns
184 // existing group index. If the new group doesn't match exactly with an
185 // existing group but shared some of the elements, returns -1 as the index.
186 // It also returns whether a new group is created. So the return value is a
187 // pair of {whether created, group index}.
MaybeCreateNewGroup(ConcatGroup group)188 std::pair<bool, int64_t> MaybeCreateNewGroup(ConcatGroup group) {
189 int64_t group_id = -1;
190 absl::flat_hash_set<HloInstruction*> elements_dedup;
191 for (int64_t i = 0; i < group.elements.size(); ++i) {
192 if (!elements_dedup.insert(group.elements[i]).second) {
193 VLOG(2) << "Duplicates in group. Element: "
194 << group.elements[i]->ToString();
195 }
196 if (concat_disallowed_.contains(group.elements[i])) {
197 VLOG(2) << "Failed creating group. Grouping disallowed on "
198 << group.elements[i]->ToString();
199 return std::pair<bool, int64_t>(false, -1);
200 }
201 auto existing = GetGroupIndex(group.elements[i]);
202 if (existing.has_value() &&
203 (i != existing->second ||
204 groups_[existing->first].concat_dim != group.concat_dim)) {
205 // We allow mismatched inserted_concat_dim, since that only requires a
206 // trivial reshape.
207 VLOG(2)
208 << "Failed creating group. Different than existing group. Element: "
209 << group.elements[i]->ToString();
210 return std::pair<bool, int64_t>(false, -1);
211 }
212 if (i == 0 && existing.has_value()) {
213 group_id = existing->first;
214 }
215 if (i > 0) {
216 if (existing.has_value() && existing->first != group_id) {
217 VLOG(2) << "Failed creating group. Different than existing group. "
218 "Element: "
219 << group.elements[i]->ToString();
220 return std::pair<bool, int64_t>(false, -1);
221 }
222 if (!existing.has_value() && group_id >= 0) {
223 VLOG(2) << "Failed creating group. Different than existing group. "
224 "Element: "
225 << group.elements[i]->ToString();
226 return std::pair<bool, int64_t>(false, -1);
227 }
228 }
229 }
230 if (group_id >= 0) {
231 VLOG(2) << "Group already exists at " << group_id << " for "
232 << group.elements[0]->ToString();
233 return std::pair<bool, int64_t>(false, group_id);
234 }
235 int64_t index = groups_.size();
236 for (int64_t i = 0; i < group.elements.size(); ++i) {
237 element_to_group_[group.elements[i]] =
238 std::pair<int64_t, int64_t>(index, i);
239 }
240 VLOG(2) << "Created new group at " << index << " for "
241 << group.elements[0]->ToString()
242 << ", concat_dim: " << group.concat_dim
243 << ", inserted: " << group.inserted_concat_dim;
244 groups_.push_back(std::move(group));
245 return std::pair<bool, int64_t>(true, index);
246 }
247
Groups() const248 const std::vector<ConcatGroup>& Groups() const { return groups_; }
249
NextGroupIndex() const250 int64_t NextGroupIndex() const { return groups_.size(); }
251
RemoveTailingGroups(int64_t start_index)252 void RemoveTailingGroups(int64_t start_index) {
253 while (groups_.size() > start_index) {
254 for (auto element : groups_.back().elements) {
255 element_to_group_.erase(element);
256 }
257 groups_.pop_back();
258 }
259 }
260
DisallowGroupingOn(const HloInstruction * hlo)261 void DisallowGroupingOn(const HloInstruction* hlo) {
262 VLOG(2) << "Disallow grouping on " << hlo->ToString();
263 concat_disallowed_.insert(hlo);
264 }
265
266 private:
267 // element -> {group index in groups_, element index in group}.
268 absl::flat_hash_map<const HloInstruction*, std::pair<int64_t, int64_t>>
269 element_to_group_;
270 std::vector<ConcatGroup> groups_;
271 absl::flat_hash_set<const HloInstruction*> concat_disallowed_;
272 };
273
// Infers an operand's concat dim and whether it's an inserted dim. For example,
// if hlo is f32[2,4,2] broadcast(f32[2,4]), dimensions={0,1} concatenated on
// dim 2, then this function will return {2, true}.
//
// If the operand is already transformed to the combined shape, specify its
// group in combined_operand_group. (Only required for kReshape.) When it is
// non-null, the operand shape is read from that group's last element instead
// of from hlo->operand(operand_index).
//
// Parameters:
//   hlo: the instruction whose operand is analyzed.
//   operand_index: which operand of `hlo` to analyze.
//   hlo_concat_dim, hlo_inserted_concat_dim: how `hlo`'s own group is
//     concatenated.
// Returns {operand concat dim, operand inserted_concat_dim}, or std::nullopt
// when the case is unsupported. Supported cases are elementwise ops,
// kAllReduce, kBroadcast, operand 0 of kReduce, and kReshape that only
// adds/removes size-1 dims.
std::optional<std::pair<int64_t, bool>> GetOperandConcatDim(
    const HloInstruction* hlo, int64_t operand_index, int64_t hlo_concat_dim,
    bool hlo_inserted_concat_dim,
    const ConcatGroup* combined_operand_group = nullptr) {
  // Operands of elementwise ops and all-reduce are treated as having the same
  // concat dim as the output.
  if (hlo->IsElementwise() || hlo->opcode() == HloOpcode::kAllReduce) {
    return std::pair<int64_t, bool>(hlo_concat_dim, hlo_inserted_concat_dim);
  }
  int64_t operand_concat_dim = -1;
  bool operand_inserted_concat_dim = false;
  const Shape& operand_shape =
      combined_operand_group == nullptr
          ? hlo->operand(operand_index)->shape()
          : combined_operand_group->elements.back()->shape();
  if (hlo->opcode() == HloOpcode::kBroadcast) {
    // Default: insert a new dim at the front of the operand.
    operand_concat_dim = 0;
    operand_inserted_concat_dim = true;
    // Try to place operand_concat_dim adjacent to dims the same way as the
    // output, if it does not exist in the operand.
    int64_t min_dist_to_concat_dim = hlo->shape().rank();
    for (int64_t i = 0; i < operand_shape.rank(); ++i) {
      // The output concat dim is mapped from operand dim i: reuse it.
      if (hlo->dimensions(i) == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = hlo_inserted_concat_dim;
        break;
      }
      // Nearest mapped dim before the concat dim: insert right after it.
      if (hlo->dimensions(i) < hlo_concat_dim &&
          min_dist_to_concat_dim > hlo_concat_dim - hlo->dimensions(i)) {
        operand_concat_dim = i + 1;
        min_dist_to_concat_dim = hlo_concat_dim - hlo->dimensions(i);
      }
      // Nearest mapped dim after the concat dim: insert right before it.
      if (hlo->dimensions(i) > hlo_concat_dim &&
          min_dist_to_concat_dim > hlo->dimensions(i) - hlo_concat_dim) {
        operand_concat_dim = i;
        min_dist_to_concat_dim = hlo->dimensions(i) - hlo_concat_dim;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReduce) {
    // Only the reduced input (operand 0) is supported; the init value is not.
    if (operand_index != 0) {
      return std::nullopt;
    }
    operand_concat_dim = hlo_concat_dim;
    operand_inserted_concat_dim = hlo_inserted_concat_dim;
    // Map the output concat dim back into the operand. Each reduced dim that
    // precedes it in the operand shifts it right by one. Visiting reduced dims
    // in ascending order keeps the running comparison correct. For an
    // inserted dim, a reduced dim at the same index comes after the new dim,
    // so it does not shift it.
    std::set<int64_t> sorted_reduce_dims;
    for (int64_t dim : hlo->dimensions()) {
      sorted_reduce_dims.insert(dim);
    }
    for (int64_t dim : sorted_reduce_dims) {
      if ((hlo_inserted_concat_dim && dim < operand_concat_dim) ||
          (!hlo_inserted_concat_dim && dim <= operand_concat_dim)) {
        operand_concat_dim++;
      }
    }
  } else if (hlo->opcode() == HloOpcode::kReshape) {
    // i walks operand dims and j walks output dims in lockstep, until the
    // output concat dim is reached.
    int64_t i = 0;
    int64_t j = 0;
    operand_inserted_concat_dim = false;
    // Only support adding/removing trivial dims.
    while (i < operand_shape.rank() || j <= hlo_concat_dim) {
      // Matching sizes on both sides: the dims correspond.
      if (i < operand_shape.rank() && j < hlo->shape().rank() &&
          operand_shape.dimensions(i) == hlo->shape().dimensions(j)) {
        if (j == hlo_concat_dim) {
          // A matched size-1 dim can absorb an inserted concat dim, so the
          // operand dim is only reported as inserted when its size is not 1.
          operand_inserted_concat_dim =
              hlo_inserted_concat_dim && operand_shape.dimensions(i) != 1;
          operand_concat_dim = i;
          break;
        }
        i++;
        j++;
        continue;
      }
      // Unmatched size-1 operand dim: either it hosts an inserted concat dim,
      // or it is dropped by the reshape and skipped.
      if (i < operand_shape.rank() && operand_shape.dimensions(i) == 1) {
        if (j == hlo_concat_dim && hlo_inserted_concat_dim) {
          operand_concat_dim = i;
          break;
        }
        i++;
        continue;
      }
      // Reached the output concat dim without a corresponding operand dim:
      // insert a new dim at operand position i.
      if (j == hlo_concat_dim) {
        operand_concat_dim = i;
        operand_inserted_concat_dim = true;
        break;
      }
      // Size-1 output dim added by the reshape: skip it.
      if (j < hlo->shape().rank() && hlo->shape().dimensions(j) == 1) {
        j++;
        continue;
      }
      // Anything else is a non-trivial reshape.
      return std::nullopt;
    }
  } else {
    return std::nullopt;
  }
  CHECK_GE(operand_concat_dim, 0);
  return std::pair<int64_t, bool>(operand_concat_dim,
                                  operand_inserted_concat_dim);
}
376
ModifyHloPropertiesForConcatShape(const ConcatGroup & group,HloInstruction * hlo)377 void ModifyHloPropertiesForConcatShape(const ConcatGroup& group,
378 HloInstruction* hlo) {
379 *hlo->mutable_shape() = group.GetConcatShape();
380 if (hlo->opcode() == HloOpcode::kBroadcast) {
381 // Use the last element to infer the operand concat dim, since the first
382 // element's operand might have been rewriten.
383 auto operand_dim = GetOperandConcatDim(
384 group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
385 CHECK(operand_dim.has_value());
386 int64_t operand_concat_dim = operand_dim->first;
387 bool operand_inserted_concat_dim = operand_dim->second;
388 if (operand_inserted_concat_dim) {
389 // We should have added an dimension on the operand.
390 CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size() + 1)
391 << hlo->ToString();
392 } else {
393 CHECK_EQ(hlo->operand(0)->shape().rank(), hlo->dimensions().size());
394 }
395 std::vector<int64_t> dims;
396 const int64_t rank = hlo->operand(0)->shape().rank();
397 dims.reserve(rank);
398 for (int64_t i = 0; i < rank; ++i) {
399 if (i == operand_concat_dim && operand_inserted_concat_dim) {
400 dims.push_back(group.concat_dim);
401 } else {
402 if (i > operand_concat_dim && operand_inserted_concat_dim) {
403 dims.push_back(hlo->dimensions(i - 1));
404 } else {
405 dims.push_back(hlo->dimensions(i));
406 }
407 if (group.inserted_concat_dim && dims.back() >= group.concat_dim) {
408 dims.back()++;
409 }
410 }
411 }
412 *hlo->mutable_dimensions() = std::move(dims);
413 } else if (hlo->opcode() == HloOpcode::kReduce) {
414 auto operand_dim = GetOperandConcatDim(
415 group.elements.back(), 0, group.concat_dim, group.inserted_concat_dim);
416 int64_t operand_concat_dim = operand_dim->first;
417 bool operand_inserted_concat_dim = operand_dim->second;
418 CHECK(operand_dim.has_value());
419 if (operand_inserted_concat_dim) {
420 auto dims = hlo->mutable_dimensions();
421 for (int64_t i = 0; i < dims->size(); ++i) {
422 if ((*dims)[i] >= operand_concat_dim) {
423 (*dims)[i]++;
424 }
425 }
426 }
427 }
428 }
429
430 // Main method to assign groups to HLOs, based on a concat.
GroupHlosForConcat(HloComputation * body,HloInstruction * concat,absl::flat_hash_map<const HloInstruction *,int64_t> topological_order,ConcatGroups * groups)431 bool GroupHlosForConcat(
432 HloComputation* body, HloInstruction* concat,
433 absl::flat_hash_map<const HloInstruction*, int64_t> topological_order,
434 ConcatGroups* groups) {
435 const int64_t group_size = concat->operand_count();
436 absl::flat_hash_set<int64_t> used_groups;
437 auto root_tuple = body->root_instruction();
438 CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
439 absl::flat_hash_map<HloInstruction*, int64_t> root_tuple_element_use_count;
440 for (auto operand : root_tuple->operands()) {
441 root_tuple_element_use_count.emplace(operand, 0).first->second++;
442 }
443 // Priority Queue sorted by topological order. Users come before operands, so
444 // it uses -topological_order[element0] as the key. We start with the concat
445 // operands.
446 std::multimap<int64_t, ConcatGroup> pq;
447 const int64_t first_group_id_to_create = groups->NextGroupIndex();
448 auto fail_and_cleanup = [&] {
449 VLOG(1) << "Failed to get the subcomputation to optimize for "
450 << concat->ToString() << ", clear groups starting at "
451 << first_group_id_to_create;
452 groups->RemoveTailingGroups(first_group_id_to_create);
453 return false;
454 };
455 struct GroupUse {
456 int64_t group_id;
457 bool newly_created;
458 bool already_used_by_subcomp;
459 };
460 auto maybe_create_group = [&](ConcatGroup group) {
461 auto res = groups->MaybeCreateNewGroup(std::move(group));
462 GroupUse use{res.second, false, false};
463 if (res.second < 0) {
464 return use;
465 }
466 use.newly_created = res.first;
467 use.already_used_by_subcomp = !used_groups.insert(res.second).second;
468 return use;
469 };
470 std::vector<HloInstruction*> concat_operands(concat->operands().begin(),
471 concat->operands().end());
472 int64_t concat_operand_order = -topological_order[concat_operands[0]];
473 pq.emplace(concat_operand_order,
474 ConcatGroup(std::move(concat_operands),
475 concat->concatenate_dimension(), false));
476
477 // Find the subcomputation on elements to combine, in order to move `concat`
478 // out of the loop without adding new concats. We start from the concat's
479 // operands, and the priority queue is ordered in reverse topological order
480 // so we process outputs before inputs. Each entry in the queue is a group of
481 // elements to combine. A legitimate group consists of identical ops, except
482 // that they each operate on one element. When a group of loop inputs are
483 // processed, we also enqueue the corresponding loop outputs to keep them
484 // match in shape.
485 while (!pq.empty()) {
486 auto group = std::move(pq.begin()->second);
487 pq.erase(pq.begin());
488 const auto& hlos = group.elements;
489 VLOG(2) << "GroupHlosForConcat dequeued " << hlos[0]->ToString();
490 bool group_is_param_gtes = false;
491 if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
492 return element == hlos[0];
493 })) {
494 // Shared operand.
495 if (groups->GetGroupIndex(hlos[0]).has_value()) {
496 VLOG(1) << "We do not support the case if a shared operand also part "
497 "of a group: "
498 << hlos[0]->ToString();
499 return fail_and_cleanup();
500 }
501 groups->DisallowGroupingOn(hlos[0]);
502 continue;
503 }
504 if (absl::c_all_of(hlos, [&](const HloInstruction* element) {
505 return element->opcode() == HloOpcode::kGetTupleElement &&
506 element->operand(0) == body->parameter_instruction(0);
507 })) {
508 group_is_param_gtes = true;
509 } else if (((hlos[0]->IsElementwise() ||
510 hlos[0]->opcode() == HloOpcode::kAllReduce) &&
511 !hlos[0]->HasSideEffect()) ||
512 hlos[0]->opcode() == HloOpcode::kBroadcast ||
513 hlos[0]->opcode() == HloOpcode::kReduce ||
514 hlos[0]->opcode() == HloOpcode::kReshape ||
515 hlos[0]->IsCustomCall("Sharding")) {
516 if (hlos[0]->opcode() == HloOpcode::kAllReduce &&
517 (!hlos[0]->shape().IsArray() || hlos[0]->IsCrossModuleAllReduce())) {
518 VLOG(2) << "Unsupported allreduce: " << hlos[0]->ToString();
519 return fail_and_cleanup();
520 }
521 // Check if these elements can be concatenated.
522 if (absl::c_any_of(hlos, [&](const HloInstruction* element) {
523 auto eq_operand = [](const HloInstruction* a,
524 const HloInstruction* b) {
525 return ShapeUtil::Compatible(a->shape(), b->shape());
526 };
527 auto eq_computations = [](const HloComputation* lhs,
528 const HloComputation* rhs) {
529 return lhs->Equal(*rhs, /*is_layout_sensitive=*/false);
530 };
531 if (!hlos[0]->Identical(*element, eq_operand, eq_computations,
532 /*layout_sensitive=*/false)) {
533 return true;
534 }
535 if (element->opcode() == HloOpcode::kReduce &&
536 (element->operand_count() != 2 ||
537 element->operand(1) != hlos[0]->operand(1))) {
538 return true;
539 }
540 return false;
541 })) {
542 VLOG(2) << "Different types of elements. First element: "
543 << hlos[0]->ToString();
544 return fail_and_cleanup();
545 }
546 // Now enqueue the inputs.
547 int64_t input_count = hlos[0]->operand_count();
548 if (hlos[0]->opcode() == HloOpcode::kReduce) {
549 CHECK_EQ(input_count, 2);
550 // Exclude the init value that we have checked to be the same.
551 input_count = 1;
552 }
553 for (int64_t i = 0; i < input_count; ++i) {
554 std::vector<HloInstruction*> elements(group_size);
555 for (int64_t j = 0; j < group_size; ++j) {
556 elements[j] = hlos[j]->mutable_operand(i);
557 }
558 auto maybe_new_concat_dim = GetOperandConcatDim(
559 hlos[0], i, group.concat_dim, group.inserted_concat_dim);
560 if (!maybe_new_concat_dim.has_value()) {
561 VLOG(2) << "Cannot find operand concat dimension for operand " << i
562 << " of " << hlos[0]->ToString();
563 return fail_and_cleanup();
564 }
565 int64_t new_group_concat_dim = maybe_new_concat_dim->first;
566 bool inserted_concat_dim = maybe_new_concat_dim->second;
567 // Enqueue the input group.
568 int64_t element_order = -topological_order[elements[0]];
569 pq.emplace(element_order,
570 ConcatGroup(std::move(elements), new_group_concat_dim,
571 inserted_concat_dim));
572 }
573 } else if (hlos[0]->opcode() == HloOpcode::kSlice) {
574 int64_t offset = 0;
575 auto operand = hlos[0]->operand(0);
576 if (group.inserted_concat_dim) {
577 VLOG(2) << "Slices cannot be grouped on new dimension.";
578 return fail_and_cleanup();
579 }
580 if (groups->GetGroupIndex(operand).has_value()) {
581 // Should not slice an operand to be grouped.
582 return fail_and_cleanup();
583 }
584 groups->DisallowGroupingOn(operand);
585 for (int64_t i = 0; i < group_size; ++i) {
586 if (hlos[i]->operand(0) != operand) {
587 VLOG(2) << "Slices of different operands.";
588 return fail_and_cleanup();
589 }
590 for (int64_t j = 0; j < hlos[i]->shape().rank(); ++j) {
591 if (hlos[i]->slice_strides(j) != 1) {
592 VLOG(2) << "Slices with strides.";
593 return fail_and_cleanup();
594 }
595 if (j == group.concat_dim) {
596 if (hlos[i]->slice_starts(j) != offset) {
597 VLOG(2) << "Slices with unsupported offsets.";
598 return fail_and_cleanup();
599 }
600 offset += hlos[i]->shape().dimensions(j);
601 } else {
602 if (hlos[i]->slice_starts(j) != 0 ||
603 hlos[i]->slice_limits(j) != operand->shape().dimensions(j)) {
604 VLOG(2) << "Slice with unsupported offsets at dimension " << j
605 << ", " << hlos[i]->ToString();
606 return fail_and_cleanup();
607 }
608 }
609 }
610 }
611 if (offset != operand->shape().dimensions(group.concat_dim)) {
612 VLOG(2) << "Slices with unsupported sizes.";
613 return fail_and_cleanup();
614 }
615 } else {
616 VLOG(2) << "Unsupported opcode: " << hlos[0]->ToString();
617 return fail_and_cleanup();
618 }
619 auto guse = maybe_create_group(std::move(group));
620 if (guse.group_id < 0) {
621 VLOG(2) << "Failed to create group.";
622 return fail_and_cleanup();
623 }
624 const auto& registered_group = groups->GetGroup(guse.group_id);
625 if (!guse.already_used_by_subcomp && group_is_param_gtes) {
626 // When we processed a group of parameter GTEs, we should also enqueue the
627 // corresponding root tuple operands, so that they have matching shapes.
628 std::vector<HloInstruction*> new_outputs(group_size);
629 for (int64_t i = 0; i < group_size; ++i) {
630 new_outputs[i] = root_tuple->mutable_operand(
631 registered_group.elements[i]->tuple_index());
632 }
633 int64_t new_output_order = -topological_order[new_outputs[0]];
634 pq.emplace(
635 new_output_order,
636 ConcatGroup(std::move(new_outputs), registered_group.concat_dim,
637 registered_group.inserted_concat_dim));
638 }
639 }
640 return groups->Groups().size() > first_group_id_to_create;
641 }
642
TupleElementsUsedInCond(HloInstruction * loop)643 std::vector<bool> TupleElementsUsedInCond(HloInstruction* loop) {
644 std::vector<bool> result(loop->shape().tuple_shapes_size(), false);
645 for (auto user : loop->while_condition()->parameter_instruction(0)->users()) {
646 if (user->opcode() != HloOpcode::kGetTupleElement) {
647 absl::c_fill(result, true);
648 return result;
649 }
650 result[user->tuple_index()] = true;
651 }
652 return result;
653 }
654
655 // Adds copies to returned values to keep RewriteLoopWithConcatGroups simple:
656 // the copies do not have other users and only appear once in the root tuple.
AddCopiesToRoot(HloComputation * body,absl::Span<HloInstruction * const> param_gtes,ConcatGroups * groups)657 Status AddCopiesToRoot(HloComputation* body,
658 absl::Span<HloInstruction* const> param_gtes,
659 ConcatGroups* groups) {
660 auto root = body->root_instruction();
661 CHECK_EQ(root->opcode(), HloOpcode::kTuple);
662 std::vector<HloInstruction*> copies(root->operand_count(), nullptr);
663 for (int64_t i = 0; i < copies.size(); ++i) {
664 auto element = root->mutable_operand(i);
665 if (!element->shape().IsArray()) {
666 continue;
667 }
668 copies[i] = body->AddInstruction(HloInstruction::CreateUnary(
669 element->shape(), HloOpcode::kCopy, element));
670 TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copies[i]));
671 }
672 for (int64_t i = 0; i < copies.size(); ++i) {
673 auto copy = copies[i];
674 if (groups->GetGroupIndex(copy).has_value()) {
675 // Already handled by earlier group members.
676 continue;
677 }
678 auto param_group_index = groups->GetGroupIndex(param_gtes[i]);
679 if (!param_group_index.has_value()) {
680 continue;
681 }
682 const auto& param_group = groups->GetGroup(param_group_index->first);
683 std::vector<HloInstruction*> copy_group(param_group.elements.size());
684 for (int64_t j = 0; j < copy_group.size(); ++j) {
685 copy_group[j] = copies[param_group.elements[j]->tuple_index()];
686 }
687 CHECK(groups
688 ->MaybeCreateNewGroup(
689 ConcatGroup(std::move(copy_group), param_group.concat_dim,
690 param_group.inserted_concat_dim))
691 .first);
692 }
693 return OkStatus();
694 }
695
RemoveCopiesFromRoot(HloComputation * body)696 Status RemoveCopiesFromRoot(HloComputation* body) {
697 auto root = body->root_instruction();
698 CHECK_EQ(root->opcode(), HloOpcode::kTuple);
699 for (int64_t i = 0; i < root->operand_count(); ++i) {
700 auto copy = root->mutable_operand(i);
701 if (copy->opcode() == HloOpcode::kCopy) {
702 TF_RETURN_IF_ERROR(root->ReplaceOperandWith(i, copy->mutable_operand(0)));
703 }
704 }
705 return OkStatus();
706 }
707
RewriteLoopWithConcatGroups(HloInstruction * loop,absl::Span<HloInstruction * const> param_gtes,ConcatGroups & groups)708 Status RewriteLoopWithConcatGroups(HloInstruction* loop,
709 absl::Span<HloInstruction* const> param_gtes,
710 ConcatGroups& groups) {
711 VLOG(1) << "RewriteLoopWithConcatGroups with " << groups.Groups().size()
712 << " groups.";
713 // For simplicity, for each group, we rewrite the first element into full
714 // shape, and leave the other elements unchagned. Non-grouped users will be
715 // have slices of the expanded first element as the new input. Later
716 // simplification and DCE passes can remove the other elements.
717 absl::flat_hash_set<int64_t> processed_groups;
718 auto body = loop->while_body();
719 auto param = body->parameter_instruction(0);
720 auto cond_param = loop->while_condition()->parameter_instruction(0);
721
722 // First, modify loop signature and operands/users.
723 std::vector<HloInstruction*> init_elements(loop->shape().tuple_shapes_size());
724 for (int64_t i = 0; i < param_gtes.size(); ++i) {
725 init_elements[i] =
726 loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
727 loop->shape().tuple_shapes(i), loop->mutable_operand(0), i));
728 }
729 for (int64_t i = 0; i < param_gtes.size(); ++i) {
730 const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
731 if (!group_and_index.has_value() || group_and_index->second != 0) {
732 continue;
733 }
734 const auto& group = groups.GetGroup(group_and_index->first);
735 // Change body parameter shape.
736 *param_gtes[i]->mutable_shape() = group.GetConcatShape();
737 *param->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
738 *body->root_instruction()->mutable_shape()->mutable_tuple_shapes(i) =
739 param_gtes[i]->shape();
740 *cond_param->mutable_shape()->mutable_tuple_shapes(i) =
741 param_gtes[i]->shape();
742 *loop->mutable_shape()->mutable_tuple_shapes(i) = param_gtes[i]->shape();
743 processed_groups.insert(group_and_index->first);
744 std::vector<HloInstruction*> input_concat_elements;
745 input_concat_elements.reserve(group.elements.size());
746 for (auto param_gte : group.elements) {
747 input_concat_elements.push_back(init_elements[param_gte->tuple_index()]);
748 }
749 init_elements[i] =
750 group.CreateConcat(std::move(input_concat_elements), loop->parent());
751 }
752 TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(
753 0, loop->parent()->AddInstruction(
754 HloInstruction::CreateTuple(init_elements))));
755 // Adjust loop users.
756 auto original_loop_users = loop->users();
757 const bool loop_is_root = loop == loop->parent()->root_instruction();
758 std::vector<HloInstruction*> output_elements(
759 loop->shape().tuple_shapes_size());
760 for (int64_t i = 0; i < param_gtes.size(); ++i) {
761 output_elements[i] =
762 loop->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
763 init_elements[i]->shape(), loop, i));
764 }
765 for (int64_t i = 0; i < param_gtes.size(); ++i) {
766 const auto& group_and_index = groups.GetGroupIndex(param_gtes[i]);
767 if (!group_and_index.has_value() || group_and_index->second != 0) {
768 continue;
769 }
770 const auto& group = groups.GetGroup(group_and_index->first);
771 auto concat_output = output_elements[group.elements[0]->tuple_index()];
772 for (int64_t j = 0; j < group.elements.size(); ++j) {
773 const auto param_gte = group.elements[j];
774 output_elements[param_gte->tuple_index()] =
775 group.CreateSlice(concat_output, j, loop->parent());
776 }
777 }
778 auto new_output_tuple = loop->parent()->AddInstruction(
779 HloInstruction::CreateTuple(output_elements));
780 for (auto user : original_loop_users) {
781 TF_RETURN_IF_ERROR(
782 loop->ReplaceUseWithDifferentShape(user, new_output_tuple));
783 }
784 if (loop_is_root) {
785 loop->parent()->set_root_instruction(new_output_tuple,
786 /*accept_different_shape=*/true);
787 }
788
789 // Now rewrite the loop body.
790 std::vector<HloInstruction*> slices_to_remove;
791 absl::flat_hash_set<HloInstruction*> new_reshapes;
792 for (auto hlo : body->MakeInstructionPostOrder()) {
793 const auto& group_and_index = groups.GetGroupIndex(hlo);
794 if (!group_and_index.has_value() || group_and_index->second != 0) {
795 continue;
796 }
797
798 if (!processed_groups.insert(group_and_index->first).second) {
799 // Already processed the group at the first element.
800 continue;
801 }
802 const auto& group = groups.GetGroup(group_and_index->first);
803 if (hlo->opcode() == HloOpcode::kSlice) {
804 // We could just replace hlo with its operand; however, to follow the
805 // practice of using the first element as full data, we defer that
806 // replacement.
807 slices_to_remove.push_back(hlo);
808 } else {
809 int64_t operand_count_to_adjust = hlo->operand_count();
810 if (hlo->opcode() == HloOpcode::kReduce) {
811 CHECK_EQ(operand_count_to_adjust, 2);
812 operand_count_to_adjust = 1;
813 }
814 for (int64_t i = 0; i < operand_count_to_adjust; ++i) {
815 auto operand_group_index = groups.GetGroupIndex(hlo->operand(i));
816 const ConcatGroup* operand_group =
817 operand_group_index.has_value()
818 ? &groups.GetGroup(operand_group_index->first)
819 : nullptr;
820 auto maybe_operand_concat_dim = GetOperandConcatDim(
821 hlo, i, group.concat_dim, group.inserted_concat_dim, operand_group);
822 CHECK(maybe_operand_concat_dim.has_value())
823 << "Operand " << i << " of " << hlo->ToString();
824 int64_t operand_concat_dim = maybe_operand_concat_dim->first;
825 bool operand_inserted_concat_dim = maybe_operand_concat_dim->second;
826 if (operand_group != nullptr) {
827 CHECK_EQ(operand_concat_dim, operand_group->concat_dim);
828 if (operand_inserted_concat_dim !=
829 operand_group->inserted_concat_dim) {
830 // The operand's actual inserted_concat_dim doesn't match the
831 // expected operand_inserted_concat_dim. Need a reshape.
832 std::vector<int64_t> new_dims;
833 int64_t d = 0;
834 for (; d < operand_concat_dim; ++d) {
835 new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
836 }
837 if (operand_inserted_concat_dim) {
838 // Split operand concat dim.
839 new_dims.push_back(group.elements.size());
840 new_dims.push_back(
841 hlo->operand(i)->shape().dimensions(operand_concat_dim) /
842 group.elements.size());
843 d = operand_concat_dim + 1;
844 } else {
845 // Combine operand concat dim with the next.
846 new_dims.push_back(
847 group.elements.size() *
848 hlo->operand(i)->shape().dimensions(operand_concat_dim + 1));
849 d = operand_concat_dim + 2;
850 }
851 for (; d < hlo->operand(i)->shape().rank(); ++d) {
852 new_dims.push_back(hlo->operand(i)->shape().dimensions(d));
853 }
854 auto reshape = body->AddInstruction(HloInstruction::CreateReshape(
855 ShapeUtil::MakeShape(hlo->operand(i)->shape().element_type(),
856 new_dims),
857 hlo->mutable_operand(i)));
858 new_reshapes.insert(reshape);
859 TF_RETURN_IF_ERROR(
860 hlo->ReplaceOperandWithDifferentShape(i, reshape));
861 }
862 continue;
863 }
864 // This is a shared operand, we need to broadcast it.
865 CHECK(
866 absl::c_all_of(group.elements, [&](const HloInstruction* element) {
867 return element->operand(i) == hlo->operand(i);
868 }));
869 VLOG(2) << "Broadcasting shared operand "
870 << hlo->operand(i)->ToString();
871 Shape data_shape = hlo->operand(i)->shape();
872 std::vector<int64_t> broadcast_dims;
873 std::vector<int64_t> broadcast_shape;
874 const int64_t data_shape_rank = data_shape.rank();
875 broadcast_dims.reserve(data_shape_rank);
876 broadcast_shape.reserve(data_shape_rank + 1);
877 for (int64_t j = 0; j < data_shape_rank; ++j) {
878 if (j < operand_concat_dim) {
879 broadcast_dims.push_back(j);
880 } else {
881 broadcast_dims.push_back(j + 1);
882 }
883 if (j == operand_concat_dim) {
884 broadcast_shape.push_back(group.elements.size());
885 }
886 broadcast_shape.push_back(data_shape.dimensions(j));
887 }
888 if (broadcast_shape.size() == data_shape.rank()) {
889 // New dim at the end.
890 broadcast_shape.push_back(group.elements.size());
891 }
892 auto broadcast = body->AddInstruction(HloInstruction::CreateBroadcast(
893 ShapeUtil::MakeShape(data_shape.element_type(), broadcast_shape),
894 hlo->mutable_operand(i), broadcast_dims));
895
896 if (!operand_inserted_concat_dim) {
897 // Concat on existing dim. Reshape to merge the broadcast dim.
898 data_shape.set_dimensions(
899 operand_concat_dim,
900 data_shape.dimensions(operand_inserted_concat_dim) *
901 group.elements.size());
902 broadcast = body->AddInstruction(
903 HloInstruction::CreateReshape(data_shape, broadcast));
904 }
905 TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, broadcast));
906 }
907 }
908 VLOG(2) << "Modifying HLO to full shape " << hlo->ToString();
909 ModifyHloPropertiesForConcatShape(group, hlo);
910 VLOG(2) << "Modified HLO to full shape " << hlo->ToString();
911 }
912
913 // For non-grouped HLOs, replace grouped inputs with slices. Also inlcude
914 // grouped reduce HLOs because their init values are not grouped.
915 for (auto hlo : body->MakeInstructionPostOrder()) {
916 if (new_reshapes.contains(hlo)) {
917 continue;
918 }
919 const auto& group_and_index = groups.GetGroupIndex(hlo);
920 if ((!group_and_index.has_value() || hlo->opcode() == HloOpcode::kReduce) &&
921 hlo != body->root_instruction()) {
922 auto operands = hlo->operands();
923 if (group_and_index.has_value()) {
924 // Only handle reduce init value.
925 CHECK_EQ(operands.size(), 2);
926 CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
927 operands.erase(operands.begin());
928 }
929 for (int64_t i = 0; i < operands.size(); ++i) {
930 auto operand = operands[i];
931 auto operand_group_index = groups.GetGroupIndex(operand);
932 if (!operand_group_index.has_value()) {
933 continue;
934 }
935 const auto& operand_group = groups.GetGroup(operand_group_index->first);
936 auto slice = operand_group.CreateSlice(
937 operand_group.elements[0], operand_group_index->second, body);
938 TF_RETURN_IF_ERROR(hlo->ReplaceOperandWithDifferentShape(i, slice));
939 }
940 }
941 }
942 for (auto slice : slices_to_remove) {
943 TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(slice->mutable_operand(0)));
944 TF_RETURN_IF_ERROR(body->RemoveInstruction(slice));
945 }
946 return OkStatus();
947 }
948
RunOnLoop(HloInstruction * loop,int64_t min_operand_count_to_optimize)949 StatusOr<bool> RunOnLoop(HloInstruction* loop,
950 int64_t min_operand_count_to_optimize) {
951 auto body = loop->while_body();
952 auto param = body->parameter_instruction(0);
953 auto root = body->root_instruction();
954 if (!param->shape().IsTuple() || root->opcode() != HloOpcode::kTuple) {
955 return false;
956 }
957 std::vector<HloInstruction*> gtes(param->shape().tuple_shapes_size(),
958 nullptr);
959 ConcatGroups groups;
960 auto indices_used_in_cond = TupleElementsUsedInCond(loop);
961 for (auto user : param->users()) {
962 if (user->opcode() != HloOpcode::kGetTupleElement) {
963 // Unhandled user opcode.
964 return false;
965 }
966 int64_t idx = user->tuple_index();
967 if (gtes[idx] != nullptr) {
968 // Seen this index before.
969 return false;
970 }
971 gtes[idx] = user;
972 if (indices_used_in_cond[idx]) {
973 groups.DisallowGroupingOn(user);
974 }
975 }
976 std::vector<HloInstruction*> concats;
977 auto body_instructions = body->MakeInstructionPostOrder();
978 absl::flat_hash_map<const HloInstruction*, int64_t> topological_order;
979 for (int64_t i = 0; i < body_instructions.size(); ++i) {
980 auto hlo = body_instructions[i];
981 topological_order[hlo] = i;
982 if (hlo->opcode() == HloOpcode::kConcatenate &&
983 hlo->operand_count() >= min_operand_count_to_optimize) {
984 concats.push_back(hlo);
985 }
986 }
987
988 for (auto& concat : concats) {
989 if (!GroupHlosForConcat(body, concat, topological_order, &groups)) {
990 concat = nullptr;
991 }
992 }
993 if (groups.Groups().empty()) {
994 return false;
995 }
996
997 TF_RETURN_IF_ERROR(AddCopiesToRoot(body, gtes, &groups));
998 TF_RETURN_IF_ERROR(RewriteLoopWithConcatGroups(loop, gtes, groups));
999 for (auto concat : concats) {
1000 if (concat == nullptr) {
1001 continue;
1002 }
1003 // We have repalced the operands of the concat with slices of full data.
1004 auto new_slice = concat->mutable_operand(0);
1005 CHECK_EQ(new_slice->opcode(), HloOpcode::kSlice);
1006 TF_RETURN_IF_ERROR(
1007 concat->ReplaceAllUsesWith(new_slice->mutable_operand(0)));
1008 TF_RETURN_IF_ERROR(body->RemoveInstruction(concat));
1009 }
1010 TF_RETURN_IF_ERROR(RemoveCopiesFromRoot(body));
1011 // Finally pass-through replaced elements from parameter to root, so that
1012 // while loop simplifier can get rid of them.
1013 for (auto gte : gtes) {
1014 auto group_index = groups.GetGroupIndex(gte);
1015 if (group_index.has_value() && group_index->second > 0) {
1016 TF_RETURN_IF_ERROR(root->ReplaceOperandWith(gte->tuple_index(), gte));
1017 }
1018 }
1019 return true;
1020 }
1021
1022 } // namespace
1023
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)1024 StatusOr<bool> WhileLoopConcatCodeMotion::Run(
1025 HloModule* module,
1026 const absl::flat_hash_set<absl::string_view>& execution_threads) {
1027 bool changed = false;
1028 for (HloComputation* comp :
1029 module->MakeComputationPostOrder(execution_threads)) {
1030 for (HloInstruction* hlo : comp->MakeInstructionPostOrder()) {
1031 if (hlo->opcode() == HloOpcode::kWhile) {
1032 TF_ASSIGN_OR_RETURN(bool loop_changed,
1033 RunOnLoop(hlo, min_operand_count_to_optimize_));
1034 changed |= loop_changed;
1035 }
1036 }
1037 }
1038 if (changed) {
1039 HloPassPipeline pipeline("loop-concat-motion-cleanup");
1040 pipeline.AddPass<TupleSimplifier>();
1041 pipeline.AddPass<HloDCE>();
1042 pipeline.AddPass<WhileLoopSimplifier>();
1043 pipeline.AddPass<TupleSimplifier>();
1044 pipeline.AddPass<HloDCE>();
1045 TF_RETURN_IF_ERROR(pipeline.Run(module, execution_threads).status());
1046 }
1047 return changed;
1048 }
1049
1050 } // namespace xla
1051