xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/layout_assignment.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28 #include <utility>
29 
30 #include "absl/algorithm/container.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/permutation_util.h"
38 #include "tensorflow/compiler/xla/service/call_graph.h"
39 #include "tensorflow/compiler/xla/service/computation_layout.h"
40 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_dce.h"
44 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/logical_buffer.h"
48 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
49 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
50 #include "tensorflow/compiler/xla/shape_layout.h"
51 #include "tensorflow/compiler/xla/shape_util.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/compiler/xla/statusor.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/logging.h"
60 #include "tensorflow/core/platform/protobuf.h"
61 
62 namespace xla {
63 
operator <<(std::ostream & out,const LayoutConstraint & constraint)64 std::ostream& operator<<(std::ostream& out,
65                          const LayoutConstraint& constraint) {
66   out << constraint.ToString();
67   return out;
68 }
69 
// Constructs a constraint pinning `buffer` to `layout`. The layout must be
// valid for the buffer's shape (CHECK-enforced below).
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs,
                                               int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      layout_(layout),
      buffer_(&buffer) {
  // Fail fast if the requested layout cannot apply to this buffer's shape.
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
79 
ToString() const80 std::string BufferLayoutConstraint::ToString() const {
81   return absl::StrFormat(
82       "BufferLayoutConstraint (prioity=%d, mandatory=%d, dfs=%d) %s: %s",
83       priority(), mandatory(), dfs(), buffer_->ToString(),
84       LayoutUtil::HumanString(layout_));
85 }
86 
// Attempts to replace this constraint's layout. A non-mandatory update must
// strictly increase the priority to take effect. Returns true iff the stored
// layout actually changed (i.e. the caller must re-propagate).
bool BufferLayoutConstraint::UpdateLayout(int64_t priority,
                                          const Layout& layout, bool mandatory,
                                          bool dfs) {
  if (!mandatory && priority <= priority_) {
    return false;
  }
  // NOTE: the bookkeeping fields are refreshed even when the layout below
  // turns out to be unchanged, so an accepted-but-matching update still bumps
  // mandatory/dfs/priority. Callers relying on rejection (SetBufferLayout)
  // operate on a copy to discard these side effects.
  mandatory_ = mandatory;
  dfs_ = dfs;
  priority_ = priority;
  if (Layout::Equal().MinorToMajorOnly()(layout_, layout)) {
    // New constraint matches the existing one; nothing to propagate.
    return false;
  }
  layout_ = layout;
  return true;
}
103 
// Constructs a constraint requiring operand `operand_no` of `instruction` to
// have `shape_layout`. The layout must be set, and its shape must be
// compatible with the operand's shape (both CHECK-enforced).
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64_t operand_no, bool mandatory, bool dfs, int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
118 
ToString() const119 std::string OperandLayoutConstraint::ToString() const {
120   return absl::StrFormat(
121       "OperandLayoutConstraint (prioity=%d) %s, operand %d: %s", priority(),
122       instruction_->name(), operand_no_, shape_layout_.ToString());
123 }
124 
// Debug string: the raw layout-state value and the current computation layout.
std::string ComputationLayoutConstraint::ToString() const {
  return absl::StrFormat("ComputationLayoutConstraint (status=%d): %s",
                         layout_state_, computation_layout_.ToString());
}
129 
// Per-computation constraint store. Wraps `computation_layout` in a
// ComputationLayoutConstraint recorded at the given priority.
LayoutAssignment::LayoutConstraints::LayoutConstraints(
    HloComputation* computation, ComputationLayout* computation_layout,
    int64_t priority)
    : computation_(computation),
      computation_constraint_(computation, computation_layout, priority) {}
135 
GetBufferSet(const HloInstruction * instruction) const136 PointsToSet::BufferSet* LayoutAssignment::GetBufferSet(
137     const HloInstruction* instruction) const {
138   auto it = buffer_sets_cache_.find(instruction);
139   if (it != buffer_sets_cache_.end()) {
140     return it->second.get();
141   }
142   auto& buffer_set =
143       buffer_sets_cache_
144           .emplace(instruction, std::make_unique<PointsToSet::BufferSet>())
145           .first->second;
146   const auto& points_to_set = points_to_analysis_->GetPointsToSet(instruction);
147   points_to_set.ForEachElement(
148       [&buffer_set](const ShapeIndex& /*index*/,
149                     const PointsToSet::BufferList& buffers) {
150         buffer_set->insert(buffers.begin(), buffers.end());
151       });
152   return buffer_set.get();
153 }
154 
AnyOperandBufferForwarded(const HloInstruction * instruction,int64_t operand_no) const155 bool LayoutAssignment::AnyOperandBufferForwarded(
156     const HloInstruction* instruction, int64_t operand_no) const {
157   // The operand is potentially forwarded if the intersection of points-to sets
158   // of the operand and the instruction is non-empty.
159   PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
160   PointsToSet::BufferSet* operand_buffers =
161       GetBufferSet(instruction->operand(operand_no));
162   return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
163     return operand_buffers->count(b) > 0;
164   });
165 }
166 
// Returns true only when every buffer in the instruction's output points-to
// set also appears in the operand's points-to set, i.e. the output is fully
// composed of buffers forwarded from this operand. (Compare with
// AnyOperandBufferForwarded, which tests for a non-empty intersection.)
bool LayoutAssignment::AllOperandBuffersForwarded(
    const HloInstruction* instruction, int64_t operand_no) const {
  PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
  PointsToSet::BufferSet* operand_buffers =
      GetBufferSet(instruction->operand(operand_no));
  return absl::c_all_of(*output_buffers, [&](const LogicalBuffer* b) {
    return operand_buffers->count(b) > 0;
  });
}
178 
// Records a layout constraint on `buffer`. A pre-existing mandatory
// constraint is never weakened by a non-mandatory request, and updates that
// UpdateLayout rejects are dropped. Returns an error for non-array buffers or
// for layouts that are invalid for the buffer's shape.
Status LayoutAssignment::SetBufferLayout(const Layout& layout,
                                         const LogicalBuffer& buffer,
                                         bool mandatory, bool dfs,
                                         int64_t priority) {
  VLOG(3) << "SetBufferLayout : " << buffer << " : "
          << LayoutUtil::HumanString(layout) << " with priority " << priority;
  TF_RETURN_IF_ERROR(points_to_analysis_->VerifyBuffer(buffer));
  // The buffer is being constrained, so drop it from the unconstrained set.
  if (unconstrained_buffer_ids_.find(buffer.id()) !=
      unconstrained_buffer_ids_.end()) {
    VLOG(3) << "Erase buffer from unconstrained ids\n";
    TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
        << buffer.ToString();
  }

  if (!buffer.IsArray()) {
    return FailedPrecondition(
        "Layout of buffer %s cannot be constrained because buffer is not "
        "array-shaped, has shape: %s",
        buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
  }
  TF_RETURN_IF_ERROR(
      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));

  auto iter = buffer_constraints_.find(&buffer);
  if (iter != buffer_constraints_.end()) {
    // Deliberately work on a copy: UpdateLayout mutates its bookkeeping
    // fields even when it returns false, and those side effects must not
    // reach the stored constraint unless the update is accepted.
    BufferLayoutConstraint curr_constraint = iter->second;
    if (curr_constraint.mandatory() && !mandatory) {
      VLOG(3) << "Buffer" << buffer
              << " already has a mandatory layout constrain, skipping "
                 "non-mandatory new layout";
      return OkStatus();
    } else {
      if (curr_constraint.UpdateLayout(priority, layout, mandatory, dfs)) {
        VLOG(3) << "Updating existing Buffer layout for " << buffer.ToString()
                << " with new layout" << LayoutUtil::HumanString(layout);
        // Accepted: write the updated copy back into the map.
        iter = buffer_constraints_.insert_or_assign(&buffer, curr_constraint)
                   .first;
      } else {
        VLOG(3) << "Unable to update existing Buffer layout for "
                << curr_constraint.ToString() << " with new layout"
                << LayoutUtil::HumanString(layout) << " at priority "
                << priority << "\n";
        return OkStatus();
      }
    }
  } else {
    // First constraint seen for this buffer.
    iter = buffer_constraints_
               .insert(std::make_pair(
                   &buffer, BufferLayoutConstraint(layout, buffer, mandatory,
                                                   dfs, priority)))
               .first;
  }
  // Queue the (new or updated) constraint for propagation.
  added_constraints_.push_back(&iter->second);
  return OkStatus();
}
234 
// Records a layout constraint on operand `operand_no` of `instruction`.
// Existing mandatory or strictly-higher-priority constraints win over the
// new request. Fails when the operand's buffers may be forwarded into the
// instruction's output, since the constraint would then escape this use.
Status LayoutAssignment::SetOperandLayout(const Shape& shape_with_layout,
                                          const HloInstruction* instruction,
                                          int64_t operand_no, bool mandatory,
                                          bool dfs, int64_t priority) {
  LayoutConstraints& constraints =
      *FindOrDie(computation_layouts_, instruction->parent());
  // The second and third operands (operand_no > 0) of a dynamic-update-slice
  // operation typically have much smaller sizes than the first (operand_no==0)
  // operand. It is necessary to downgrade the importance of the smaller
  // operands, so that the overall layout choice of the operation is dictated
  // by operand 0 when possible.
  if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice &&
      operand_no > 0 && !mandatory &&
      priority > LayoutConstraint::kDefaultPriority) {
    dfs = false;
    priority--;
  } else if (instruction->opcode() == HloOpcode::kReshape && !mandatory &&
             instruction->operand(0)->opcode() == HloOpcode::kDynamicSlice) {
    // Similarly downgrade reshapes fed by dynamic-slice.
    dfs = false;
    priority--;
  }
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << " : priority = " << priority << "; mandatory = " << mandatory
          << "; dfs = " << dfs << "\n";
  const OperandLayoutConstraint* curr_shape_layout =
      constraints.GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout, /*minor_to_major_only=*/true)) {
      // New constraint matches existing constraint. Nothing to do.
      return OkStatus();
    }
    if (curr_shape_layout->mandatory() && !mandatory) {
      VLOG(5)
          << "Existing layout is mandatory but the new one is not. Skipping.\n";
      return OkStatus();
    } else if (curr_shape_layout->priority() > priority) {
      VLOG(5) << "Existing layout has higher priority: "
              << curr_shape_layout->priority() << " vs " << priority << "\n";
      return OkStatus();
    }
  }
  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint changes
  // layouts beyond this immediate use and is complicated to handle.
  if (AnyOperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constraint layout of operand %d of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name());
  }

  // Install the new/overriding constraint and queue it for propagation.
  OperandLayoutConstraint new_constraint(ShapeLayout(shape_with_layout),
                                         instruction, operand_no, mandatory,
                                         dfs, priority);
  auto op_constraint = constraints.InsertOperandLayoutConstraint(
      instruction, operand_no, new_constraint);
  PushAddedConstraints(op_constraint);
  return OkStatus();
}
297 
298 OperandLayoutConstraint*
InsertOperandLayoutConstraint(const HloInstruction * instruction,int64_t operand_no,const OperandLayoutConstraint & constraint)299 LayoutAssignment::LayoutConstraints::InsertOperandLayoutConstraint(
300     const HloInstruction* instruction, int64_t operand_no,
301     const OperandLayoutConstraint& constraint) {
302   auto key = std::make_pair(instruction, operand_no);
303   auto iter = operand_constraints_.find(key);
304   if (iter == operand_constraints_.end()) {
305     auto pair = std::make_pair(key, constraint);
306     iter = operand_constraints_.insert(pair).first;
307   } else {
308     iter->second = constraint;
309   }
310   return &iter->second;
311 }
312 
PushAddedConstraints(const LayoutConstraint * constraint)313 void LayoutAssignment::PushAddedConstraints(
314     const LayoutConstraint* constraint) {
315   if (!constraint->dfs()) {
316     // Insert a new constraint to the first location where it's strictly greater
317     // than all the subsequent constraints. Assumes invariant that the list is
318     // sorted.
319     auto it = absl::c_upper_bound(
320         added_constraints_, constraint,
321         [&](const LayoutConstraint* a, const LayoutConstraint* b) {
322           return a->priority() > b->priority();
323         });
324     added_constraints_.insert(it, constraint);
325   } else {
326     added_constraints_.push_back(constraint);
327   }
328 }
329 
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64_t operand_no,bool mandatory,bool dfs,int64_t priority)330 Status LayoutAssignment::SetArrayOperandLayout(
331     const Layout& layout, const HloInstruction* instruction, int64_t operand_no,
332     bool mandatory, bool dfs, int64_t priority) {
333   const HloInstruction* operand = instruction->operand(operand_no);
334   TF_RET_CHECK(operand->shape().IsArray());
335   Shape shape(operand->shape());
336   *shape.mutable_layout() = layout;
337   TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
338   return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs,
339                           priority);
340 }
341 
// Pins the computation's result layout to `shape_with_layout` and queues the
// computation constraint on `assignment`'s worklist for propagation.
Status LayoutAssignment::LayoutConstraints::SetResultLayout(
    LayoutAssignment* assignment, const Shape& shape_with_layout,
    int64_t priority) {
  VLOG(3) << "  : " << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << "; priority = " << priority << ".\n";

  computation_constraint_.ResetResultLayout(ShapeLayout(shape_with_layout),
                                            priority);
  assignment->PushAddedConstraints(&computation_constraint_);
  return OkStatus();
}
353 
// Applies a single `layout` to every array-shaped buffer defined in
// `instruction`'s output. Only valid for array-shaped instructions or for
// opcodes whose outputs all share one shape (kSort/kReduce/kReduceWindow),
// CHECK-enforced below. A negative `priority` means "use current_priority_".
// NOTE(review): the `dfs` parameter is not forwarded — SetBufferLayout is
// always invoked with /*dfs=*/true; confirm whether that is intentional.
Status LayoutAssignment::SetInstructionLayout(const Layout& layout,
                                              const HloInstruction* instruction,
                                              bool mandatory, bool dfs,
                                              bool allow_alias,
                                              int64_t priority) {
  if (priority < 0) {
    priority = current_priority_;
  }
  // Opcodes whose (tuple) outputs are all required to share a single shape,
  // so one Layout can meaningfully apply to every output buffer.
  auto RequiresSameShapeForAllOutput = [](const HloInstruction* op) -> bool {
    switch (op->opcode()) {
      case HloOpcode::kSort:
      case HloOpcode::kReduce:
      case HloOpcode::kReduceWindow:
        return true;
      default:
        return false;
    }
  };
  CHECK(instruction->shape().IsArray() ||
        RequiresSameShapeForAllOutput(instruction));

  return ShapeUtil::ForEachSubshapeWithStatus(
      instruction->shape(),
      [this, layout, instruction, mandatory, allow_alias, priority](
          const Shape& subshape, const ShapeIndex& index) -> Status {
        auto buffers =
            points_to_analysis_->GetPointsToSet(instruction).element(index);
        // Each output element must be defined by exactly one buffer, and by
        // this instruction itself unless aliasing is explicitly allowed.
        CHECK_EQ(1, buffers.size());
        if (!allow_alias) {
          CHECK_EQ(buffers[0]->instruction(), instruction);
        }
        if (subshape.IsArray()) {
          return SetBufferLayout(layout, *buffers[0], mandatory,
                                 /*dfs=*/true, priority);
        } else {
          return OkStatus();
        }
      });
}
393 
// Constrains `instruction`'s output to `shape_with_layout`, creating one
// BufferLayoutConstraint per array subshape that has a layout. Fails if the
// given shape is incompatible with the instruction's shape.
// NOTE(review): the `dfs` parameter is not forwarded — SetBufferLayout is
// always invoked with /*dfs=*/true; confirm whether that is intentional.
Status LayoutAssignment::SetInstructionLayout(const Shape& shape_with_layout,
                                              const HloInstruction* instruction,
                                              bool mandatory, bool dfs,
                                              bool allow_alias,
                                              int64_t priority) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << ": priority = " << priority << " : mandatory = " << mandatory
          << "; dfs = " << dfs << "\n";

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory, allow_alias, priority](
          const Shape& subshape, const ShapeIndex& index) -> Status {
        auto buffers =
            points_to_analysis_->GetPointsToSet(instruction).element(index);
        // Each output element must be defined by exactly one buffer, and by
        // this instruction itself unless aliasing is explicitly allowed.
        CHECK_EQ(1, buffers.size());
        if (!allow_alias) {
          CHECK_EQ(buffers[0]->instruction(), instruction);
        }

        // Subshapes without a layout (or non-arrays) are left unconstrained.
        if (subshape.IsArray() && subshape.has_layout()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory,
                                 /*dfs=*/true, priority);
        } else {
          return OkStatus();
        }
      });
}
432 
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const433 const BufferLayoutConstraint* LayoutAssignment::GetBufferLayoutConstraint(
434     const LogicalBuffer& buffer) const {
435   auto it = buffer_constraints_.find(&buffer);
436   return it == buffer_constraints_.end() ? nullptr : &it->second;
437 }
438 
OperandLayout(const HloInstruction * instruction,int64_t operand_no) const439 const ShapeLayout* LayoutAssignment::LayoutConstraints::OperandLayout(
440     const HloInstruction* instruction, int64_t operand_no) const {
441   if (const auto* constraint =
442           GetOperandLayoutConstraint(instruction, operand_no)) {
443     return &constraint->shape_layout();
444   }
445   return nullptr;
446 }
447 
448 const OperandLayoutConstraint*
GetOperandLayoutConstraint(const HloInstruction * instruction,int64_t operand_no) const449 LayoutAssignment::LayoutConstraints::GetOperandLayoutConstraint(
450     const HloInstruction* instruction, int64_t operand_no) const {
451   auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
452   return it == operand_constraints_.end() ? nullptr : &it->second;
453 }
454 
ResultLayout() const455 const ShapeLayout* LayoutAssignment::LayoutConstraints::ResultLayout() const {
456   return (computation_->IsEntryComputation() ||
457           computation_constraint_.result_layout_is_set())
458              ? &computation_layout().result_layout()
459              : nullptr;
460 }
461 
ToString(const LayoutConstraints & constraints) const462 std::string LayoutAssignment::ToString(
463     const LayoutConstraints& constraints) const {
464   std::string output;
465   absl::StrAppend(&output, "LayoutConstraints for computation ",
466                   constraints.computation()->name(), "\n");
467   for (auto* instruction :
468        constraints.computation()->MakeInstructionPostOrder()) {
469     absl::StrAppend(&output, "  ", instruction->ToShortString(), "\n");
470     for (int64_t i = 0; i < instruction->operand_count(); ++i) {
471       if (constraints.OperandLayout(instruction, i) != nullptr) {
472         absl::StrAppend(
473             &output, "    operand (", i,
474             "): ", constraints.OperandLayout(instruction, i)->ToString(), "\n");
475       }
476     }
477     for (const LogicalBuffer* buffer :
478          points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) {
479       auto* buffer_constraint = GetBufferLayoutConstraint(*buffer);
480       if (buffer_constraint != nullptr) {
481         absl::StrAppend(&output, "    ", buffer->ToString(), " : ",
482                         LayoutUtil::HumanString(buffer_constraint->layout()),
483                         "\n");
484       }
485     }
486   }
487 
488   absl::StrAppend(&output, "  => ",
489                   constraints.computation_constraint().ToString(), "\n");
490   return output;
491 }
492 
493 namespace {
494 
IsHostSendRecv(const HloInstruction * instruction)495 bool IsHostSendRecv(const HloInstruction* instruction) {
496   const HloSendRecvInstruction* send_recv_instr =
497       DynCast<HloSendRecvInstruction>(instruction);
498   return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
499 }
500 
501 }  // namespace
502 
// Seeds host_channel_constraints_ from the layouts carried by host-transfer
// Send/Recv instructions in `computation`. Each host channel may be
// constrained at most once; a second constraint on the same channel fails.
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // The transferred data is tuple element 0 of the Send/Recv shape; it
      // must be an array with a concrete layout.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      // ConstrainChannel returns the previously-registered layout when the
      // channel was already constrained (to a different layout).
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          *send_recv_instr->channel_id(), data_shape.layout());
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return OkStatus();
}
529 
530 namespace {
531 
IsLayoutConstrainedCustomCall(HloInstruction * instruction)532 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
533   const HloCustomCallInstruction* custom_call =
534       DynCast<HloCustomCallInstruction>(instruction);
535   return custom_call != nullptr && custom_call->layout_constrained();
536 }
537 
IsLayoutConstrainedCollective(const HloInstruction * instruction)538 bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
539   const HloCollectiveInstruction* collective =
540       DynCast<HloCollectiveInstruction>(instruction);
541   return collective != nullptr && collective->constrain_layout();
542 }
543 
// Recursively pushes a parameter's layout (`shape`) onto its users:
// get-tuple-element users receive the corresponding tuple-element layout as a
// (non-mandatory) instruction layout and are recursed into; all other users
// get a non-mandatory operand-layout constraint. Tuple users are skipped.
Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
                                       const Shape& shape,
                                       LayoutAssignment* constraints) {
  for (auto* user : instruction->users()) {
    // Excluding tuple operations as they do not participate in layout
    // propagation (they do not create or alias buffers).
    if (user->opcode() == HloOpcode::kTuple) {
      continue;
    }
    VLOG(3) << "Setting  user layout : " << user->ToString();
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      auto tuple_index = user->tuple_index();
      CHECK(shape.IsTuple());
      // Propagate only the layout of the selected tuple element.
      auto elem_shape = shape.tuple_shapes(tuple_index);
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          elem_shape, user, /*mandatory=*/false, /*dfs=*/false,
          /*allow_alias=*/true));
      TF_RETURN_IF_ERROR(
          PropagateParameterLayoutToUsers(user, elem_shape, constraints));
    } else {
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          shape, user, user->operand_index(instruction), /*mandatory=*/false,
          /*dfs=*/false));
    }
  }
  return OkStatus();
}
571 
572 }  // namespace
573 
AddMandatoryConstraints(ChannelLayoutConstraints * channel_constraints,LayoutConstraints * constraints)574 Status LayoutAssignment::AddMandatoryConstraints(
575     ChannelLayoutConstraints* channel_constraints,
576     LayoutConstraints* constraints) {
577   VLOG(2) << "Adding mandatory layout constraints to computation "
578           << constraints->computation()->name();
579 
580   auto get_channel_constraints = [&](const HloInstruction* instruction) {
581     return IsHostSendRecv(instruction) ? &host_channel_constraints_
582                                        : channel_constraints;
583   };
584 
585   // Constrain layouts of instructions which define values with pre-existing
586   // layouts.
587   for (auto* instruction : constraints->computation()->instructions()) {
588     if (instruction->opcode() == HloOpcode::kInfeed) {
589       // Infeed layouts must match the layout of the original inserted
590       // instruction.
591       // TODO(b/31425034): Change infeeds to be more like parameters, with
592       // shapes in the ComputationLayout.
593       TF_RETURN_IF_ERROR(SetInstructionLayout(instruction->shape(), instruction,
594                                               /*mandatory=*/true, /*dfs=*/true,
595                                               /*allow_alias=*/false));
596     } else if (instruction->opcode() == HloOpcode::kOutfeed) {
597       // Constrain the input to the Outfeed instruction to be the expected
598       // layout of the Outfeed.
599       TF_RETURN_IF_ERROR(SetOperandLayout(instruction->outfeed_shape(),
600                                           instruction, 0,
601                                           /*mandatory=*/true, /*dfs=*/true));
602     } else if (instruction->opcode() == HloOpcode::kParameter) {
603       if (reverse_computation_order_ ||
604           (constraints->computation()->IsEntryComputation() &&
605            entry_computation_layout_->LayoutIsSet()) ||
606           (conditional_mismatch_.count(constraints->computation()) == 0 &&
607            constraints->computation_constraint().parameter_layout_is_set())) {
608         const ShapeLayout& parameter_layout =
609             constraints->computation_layout().parameter_layout(
610                 instruction->parameter_number());
611         // Parameter layouts must match the respective layout in
612         // ComputationLayout, if there is one.
613         TF_RETURN_IF_ERROR(
614             SetInstructionLayout(parameter_layout.shape(), instruction));
615         if (reverse_computation_order_) {
616           TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
617               instruction, parameter_layout.shape(), this));
618         }
619       }
620     } else if (IsLayoutConstrainedCustomCall(instruction)) {
621       const HloCustomCallInstruction* custom_call =
622           DynCast<HloCustomCallInstruction>(instruction);
623 
624       TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call,
625                                               /*mandatory=*/true, /*dfs=*/true,
626                                               /*allow_alias=*/true));
627 
628       for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
629         if (AnyOperandBufferForwarded(custom_call, i)) {
630           TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i))
631               << "Partial alias of an operand is not supported";
632         } else {
633           TF_RETURN_IF_ERROR(SetOperandLayout(
634               custom_call->operand_shapes_with_layout()[i], custom_call, i));
635         }
636       }
637     } else if (IsLayoutConstrainedCollective(instruction)) {
638       TF_RETURN_IF_ERROR(
639           SetInstructionLayout(instruction->shape(), instruction));
640     } else if (instruction->IsCrossModuleAllReduce()) {
641       CHECK(get_channel_constraints(instruction))
642           << "Multi-module layout assignment requires ChannelLayoutConstraints";
643       int64_t channel_id = instruction->channel_id().value();
644       if (!get_channel_constraints(instruction)
645                ->IsChannelConstrained(channel_id)) {
646         continue;
647       }
648       // TODO(b/68493863): Change to use SetOperandLayout().
649       const Shape& buffer_shape = instruction->operand(0)->shape();
650       TF_RET_CHECK(buffer_shape.IsArray());
651       Shape new_buffer_shape =
652           get_channel_constraints(instruction)
653               ->LayoutShapeForChannel(buffer_shape, channel_id);
654       TF_RETURN_IF_ERROR(SetInstructionLayout(new_buffer_shape, instruction));
655     }
656   }
657 
658   // Constrain layouts of instructions which call computations which have
659   // already been assigned layouts. Instructions which call computations in a
660   // parallel element-wise context (eg, map or reduce) do not need layout
661   // constraints because they operate on scalars.
662   for (auto* instruction : constraints->computation()->instructions()) {
663     if (instruction->opcode() == HloOpcode::kCall &&
664         computation_layouts_.find(instruction->to_apply()) !=
665             computation_layouts_.end()) {
666       // kCall instruction operands and output must match the ComputationLayout
667       // of the called computation.
668       const ComputationLayout& called_computation_layout =
669           FindOrDie(computation_layouts_, instruction->to_apply())
670               ->computation_layout();
671       TF_RETURN_IF_ERROR(SetInstructionLayout(
672           called_computation_layout.result_layout().shape(), instruction));
673       TF_RET_CHECK(instruction->operand_count() ==
674                    called_computation_layout.parameter_count());
675       for (int64_t i = 0; i < instruction->operand_count(); ++i) {
676         TF_RETURN_IF_ERROR(SetOperandLayout(
677             called_computation_layout.parameter_layout(i).shape(), instruction,
678             i, /*mandatory=*/true, /*dfs=*/true));
679       }
680     } else if (instruction->opcode() == HloOpcode::kWhile &&
681                computation_layouts_.find(instruction->while_body()) !=
682                    computation_layouts_.end()) {
683       // Layout of input and output of kWhile instruction must be equal and must
684       // match both input and output of body computation. Also, the input of
685       // condition computation must match kWhile layout.
686       HloComputation* body = instruction->while_body();
687       HloComputation* condition = instruction->while_condition();
688       const HloInstruction* init = instruction->operand(0);
689       ComputationLayoutConstraint* body_constraint =
690           mutable_computation_constraints(body)
691               ->mutable_computation_constraint();
692       ComputationLayout body_layout = body_constraint->computation_layout();
693       ComputationLayoutConstraint* condition_constraint =
694           mutable_computation_constraints(condition)
695               ->mutable_computation_constraint();
696       ComputationLayout condition_layout =
697           condition_constraint->computation_layout();
698 
699       // Check a few invariants irrespective of layout.
700       CHECK_EQ(1, instruction->operand_count());
701       CHECK_EQ(1, body->num_parameters());
702       CHECK_EQ(1, condition->num_parameters());
703       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
704                                    body_layout.parameter_shape(0)));
705       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
706                                    condition_layout.parameter_shape(0)));
707       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
708 
709       if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
710         VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
711                 << " while=" << instruction->name()
712                 << " shape=" << body_layout.result_layout().ToString();
713         *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
714         body_constraint->ResetComputationLayout(
715             body_layout, current_priority_ + kNumberOfPropagationRounds,
716             /*prop_result_layout=*/true,
717             /*prop_parameter_layout=*/true);
718       }
719       if (condition_layout.parameter_layout(0) !=
720           body_layout.parameter_layout(0)) {
721         VLOG(2) << "Reset %while condition parameter layout: cond="
722                 << condition->name() << " while=" << instruction->name()
723                 << " shape=" << body_layout.parameter_layout(0).ToString();
724         *condition_layout.mutable_parameter_layout(0) =
725             body_layout.parameter_layout(0);
726         condition_constraint->ResetComputationLayout(
727             condition_layout, current_priority_ + kNumberOfPropagationRounds,
728             /*prop_result_layout=*/true, /*prop_parameter_layout=*/true);
729       }
730 
731       // Constrain the output and the operand of the while instruction to match
732       // the computations.
733       TF_RETURN_IF_ERROR(
734           SetOperandLayout(body_layout.result_shape(), instruction, 0));
735       TF_RETURN_IF_ERROR(
736           SetInstructionLayout(body_layout.result_shape(), instruction));
737     } else if (instruction->opcode() == HloOpcode::kConditional &&
738                computation_layouts_.find(instruction->branch_computation(0)) !=
739                    computation_layouts_.end()) {
740       // Find the conditional branch with the most instructions and force all
741       // other computations to match that layout. A potentially better decision
742       // could count the number FLOPs or how constrained the layouts are.
743       int64_t largest_branch = -1;
744       int64_t largest_instruction_count = 0;
745       for (int j = 0; j < instruction->branch_count(); ++j) {
746         const int64_t instruction_count =
747             instruction->branch_computation(j)->instruction_count();
748         if (instruction_count > largest_instruction_count &&
749             !ShapeUtil::IsEmptyTuple(instruction->operand(j + 1)->shape())) {
750           largest_branch = j;
751           largest_instruction_count = instruction_count;
752         }
753       }
754       if (largest_branch == -1) {
755         largest_branch = 0;
756       }
757       const ComputationLayout& best_branch_computation_layout =
758           mutable_computation_constraints(
759               instruction->branch_computation(largest_branch))
760               ->computation_layout();
761       for (int k = 0; k < instruction->branch_count(); ++k) {
762         // Visit the best branch first.
763         int j = (k + largest_branch) % instruction->branch_count();
764         TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
765         ComputationLayout branch_computation_layout =
766             mutable_computation_constraints(instruction->branch_computation(k))
767                 ->computation_layout();
768         if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
769                 best_branch_computation_layout.result_layout().shape(),
770                 /*minor_to_major_only=*/true)) {
771           *branch_computation_layout.mutable_result_layout() =
772               best_branch_computation_layout.result_layout();
773           InsertOrDie(&conditional_mismatch_,
774                       instruction->branch_computation(k),
775                       branch_computation_layout);
776         } else {
777           TF_RETURN_IF_ERROR(SetOperandLayout(
778               branch_computation_layout.parameter_shape(0), instruction, k + 1,
779               /*mandatory=*/true, /*dfs=*/true));
780         }
781       }
782       TF_RETURN_IF_ERROR(
783           SetOperandLayout(best_branch_computation_layout.parameter_shape(0),
784                            instruction, largest_branch + 1,
785                            /*mandatory=*/true, /*dfs=*/true));
786       TF_RETURN_IF_ERROR(SetInstructionLayout(
787           best_branch_computation_layout.result_shape(), instruction,
788           /*mandatory=*/true, /*dfs=*/true, /*allow_alias=*/false));
789     }
790   }
791   // Finally set the result layout to match ComputationLayout, if there is one.
792   if (conditional_mismatch_.count(constraints->computation()) > 0) {
793     VLOG(5) << "Setting mismatching conditional result:"
794             << constraints->computation()->name() << "\n";
795     TF_RETURN_IF_ERROR(constraints->SetResultLayout(
796         this,
797         FindOrDie(conditional_mismatch_, constraints->computation())
798             .result_layout()
799             .shape(),
800         current_priority_ + kNumberOfPropagationRounds));
801   } else if (reverse_computation_order_ ||
802              (constraints->computation()->IsEntryComputation() &&
803               entry_computation_layout_->LayoutIsSet()) ||
804              current_priority_ > LayoutConstraint::kBeginningPriority) {
805     const ShapeLayout* result_layout = constraints->ResultLayout();
806     if (result_layout != nullptr) {
807       VLOG(2) << "Setting computation result layout.\n";
808       PushAddedConstraints(&constraints->computation_constraint());
809     } else {
810       VLOG(2) << "Computation result layout is not set.\n";
811     }
812   }
813   return OkStatus();
814 }
815 
816 namespace {
817 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)818 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
819   if (!lhs.has_layout() && !rhs.has_layout()) {
820     return true;
821   }
822   CHECK(lhs.has_layout() && rhs.has_layout());
823   return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
824 }
825 
826 // The operands of a call must match the layouts of parameters in the
827 // ComputationLayout, and the call instruction itself must match the result
828 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)829 Status CheckCallLayout(HloInstruction* call,
830                        const ComputationLayout& computation_layout) {
831   HloComputation* computation = call->to_apply();
832   TF_RET_CHECK(computation->num_parameters() == call->operand_count());
833   for (int64_t i = 0; i < computation->num_parameters(); ++i) {
834     TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
835         call->operand(i)->shape(), /*minor_to_major_only=*/true));
836   }
837   TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
838       call->shape(), /*minor_to_major_only=*/true));
839   return OkStatus();
840 }
841 
842 // Operands of layout-constrained custom calls must match the expected
843 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)844 Status CheckCustomCallLayout(HloInstruction* instruction) {
845   if (IsLayoutConstrainedCustomCall(instruction)) {
846     const HloCustomCallInstruction* custom_call =
847         DynCast<HloCustomCallInstruction>(instruction);
848     for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
849       TF_RET_CHECK(
850           LayoutsInShapesEqual(custom_call->operand(i)->shape(),
851                                custom_call->operand_shapes_with_layout()[i]));
852     }
853   }
854   return OkStatus();
855 }
856 
857 // For a while instruction, all the following layouts must be the same:
858 //   (1) init operand
859 //   (2) condition computation parameter
860 //   (3) body computation parameter
861 //   (4) body computation result
862 //   (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)863 Status CheckWhileLayout(HloInstruction* while_inst,
864                         const ComputationLayout& condition_computation_layout,
865                         const ComputationLayout& body_computation_layout) {
866   auto init_shape = while_inst->operand(0)->shape();
867   TF_RET_CHECK(
868       condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
869           init_shape, /*minor_to_major_only=*/true));
870   TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
871       init_shape, /*minor_to_major_only=*/true));
872   TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
873       init_shape, /*minor_to_major_only=*/true));
874   TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
875   return OkStatus();
876 }
877 
CheckOptimizationBarrierLayout(HloInstruction * inst)878 Status CheckOptimizationBarrierLayout(HloInstruction* inst) {
879   TF_RET_CHECK(LayoutsInShapesEqual(inst->operand(0)->shape(), inst->shape()));
880   return OkStatus();
881 }
882 
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)883 Status CheckConditionalLayout(
884     HloInstruction* instruction,
885     absl::Span<const ComputationLayout> branch_computation_layouts) {
886   for (int j = 0; j < instruction->branch_count(); ++j) {
887     const HloInstruction* branch_operand = instruction->operand(j + 1);
888     TF_RET_CHECK(
889         branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
890             branch_computation_layouts[j].result_layout().shape(),
891             /*minor_to_major_only=*/true));
892     TF_RET_CHECK(
893         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
894             instruction->shape(), /*minor_to_major_only=*/true));
895     TF_RET_CHECK(
896         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
897             instruction->branch_computation(j)->root_instruction()->shape(),
898             /*minor_to_major_only=*/true))
899         << j << ":"
900         << instruction->branch_computation(j)->root_instruction()->ToString();
901     TF_RET_CHECK(
902         branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
903             branch_operand->shape(), /*minor_to_major_only=*/true));
904   }
905   return OkStatus();
906 }
907 
908 // Fusion parameters must match the layout of the fusion instructions operands,
909 // and the root of the fusion expression must match the layout of the fusion
910 // instruction.
CheckFusionLayout(HloInstruction * fusion)911 Status CheckFusionLayout(HloInstruction* fusion) {
912   TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
913 
914   TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
915                                     fusion->fused_expression_root()->shape()));
916   for (int64_t i = 0; i < fusion->operand_count(); ++i) {
917     TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
918                                       fusion->operand(i)->shape()));
919   }
920   return OkStatus();
921 }
922 
923 // The layout of a parameter must match the respective layout in the
924 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)925 Status CheckParameterLayout(HloInstruction* parameter,
926                             const ComputationLayout& computation_layout) {
927   const ShapeLayout& parameter_layout =
928       computation_layout.parameter_layout(parameter->parameter_number());
929   return ShapeUtil::ForEachSubshapeWithStatus(
930       parameter_layout.shape(),
931       [&](const Shape& subshape, const ShapeIndex& shape_index) {
932         if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
933             !subshape.has_layout()) {
934           return OkStatus();
935         }
936         if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
937                 subshape,
938                 ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
939           return InternalError(
940               "parameter instruction %s does not match layout of computation "
941               "shape: %s",
942               parameter->ToString(), parameter_layout.ToString());
943         }
944         return OkStatus();
945       });
946 }
947 
948 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)949 Status CheckConstantLayout(HloInstruction* constant) {
950   if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
951     return InternalError(
952         "constant instruction %s does not match the layout of its literal %s",
953         constant->ToString(),
954         ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
955   }
956   return OkStatus();
957 }
958 
// Projects the broadcast output `layout` onto the broadcast's operand:
// dimensions that the operand does not supply (i.e. not listed in the
// broadcast's dimensions()) are filtered away, and the layout of the
// remaining shape is returned.
Layout GetBroadcastLayoutFromOutput(const Layout& layout,
                                    const HloInstruction* hlo) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kBroadcast);
  Shape output_shape = hlo->shape();
  *output_shape.mutable_layout() = layout;
  auto is_operand_dim = [&](int64_t dim) {
    return absl::c_linear_search(hlo->dimensions(), dim);
  };
  return ShapeUtil::FilterDimensions(is_operand_dim, output_shape).layout();
}
971 
CheckBroadcastLayout(HloInstruction * broadcast)972 Status CheckBroadcastLayout(HloInstruction* broadcast) {
973   CHECK_EQ(broadcast->opcode(), HloOpcode::kBroadcast);
974   Shape shape = ShapeUtil::FilterDimensions(
975       [&](int64_t dim) {
976         return absl::c_linear_search(broadcast->dimensions(), dim);
977       },
978       broadcast->shape());
979   if (!LayoutsInShapesEqual(shape, broadcast->operand(0)->shape())) {
980     return InternalError(
981         "broadcast instruction %s does not match the layout of its operand %s",
982         broadcast->ToString(), broadcast->operand(0)->ToString());
983   }
984   return OkStatus();
985 }
986 
987 }  // namespace
988 
// Creates and returns a copy of `instruction` whose layout is
// `shape_with_layout`. For array shapes this emits a single kCopy; for tuple
// shapes it deep-copies, recursing into each element and copying only those
// elements whose layout actually differs. The returned instruction is a new
// user-visible value; callers are responsible for rewiring uses.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  // The target shape must be element-wise compatible with the instruction's
  // shape; only layouts may differ.
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract the element so it can be copied (or reused) individually.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // Clear any layout CreateTuple inferred, then stamp the requested layout
    // onto the new tuple shape.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // Track the copy so later passes can account for copies added by layout
    // assignment. NOTE(review): only the array path registers the copy; the
    // tuple path above does not — presumably intentional, confirm if editing.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
1045 
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64_t operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(2) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return OkStatus();
  }
  VLOG(2) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (!reverse_computation_order_ &&
      instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // operand_no > 0 skips the predicate operand; operand k+1 feeds branch k.
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Give the branch parameter the operand's current (unconstrained) layout,
    // then copy it to the required layout inside the branch.
    *param->mutable_shape() = operand->shape();
    // Snapshot users before creating the copy, since the copy itself becomes
    // a new user of `param`.
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(2) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    if (param == branch_comp->root_instruction()) {
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }

    // The branch computation's shape changed, so recompute and reset its
    // layout constraint without propagating it further.
    ComputationLayout computed_computation_layout(
        branch_comp->ComputeProgramShape(),
        /*ignore_layouts=*/false);
    mutable_computation_constraints(branch_comp)
        ->mutable_computation_constraint()
        ->ResetComputationLayout(computed_computation_layout,
                                 current_priority_ + 1,
                                 /* prop_result_layout=*/false,
                                 /*prop_parameter_layout=*/false);
    return OkStatus();
  }

  // General case: insert a copy between the operand and the instruction.
  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
1106 
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)1107 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
1108                                               HloInstruction* copy,
1109                                               const ShapeIndex& index) {
1110   if (instruction.has_sharding()) {
1111     // If the index is empty, we want to copy the whole sharding, in case the
1112     // sharding is a tuple sharding.
1113     HloSharding sharding =
1114         !index.empty() && instruction.sharding().IsTuple()
1115             ? instruction.sharding().GetSubSharding(instruction.shape(), index)
1116             : instruction.sharding();
1117     // We propagate the sharding to the copied instruction only if it is a
1118     // special sharding, like tiled ones.
1119     // Otherwise it is preferable to leave the new instruction without device,
1120     // and let the automatic device placer to choose the best location.
1121     auto device = sharding.UniqueDevice();
1122     if (!device || HloSharding::IsReservedDevice(*device)) {
1123       copy->set_sharding(sharding);
1124     }
1125   }
1126   copy->set_metadata(instruction.metadata());
1127 }
1128 
// Post-pass verification: checks that every instruction in every non-fusion
// computation (on the given execution threads) has a valid layout, that
// layouts agree with points-to analysis across aliasing buffers, and that
// instructions with special layout constraints (call, custom-call, fusion,
// parameter, broadcast, constant, while, optimization-barrier, conditional)
// satisfy them. Returns an error on the first violation found.
Status LayoutAssignment::CheckLayouts(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              // Every buffer that may source this leaf must agree with the
              // instruction's layout (minor-to-major only, ignoring dynamic
              // dimensions).
              for (const LogicalBuffer* buffer : buffers) {
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return OkStatus();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())
                  ->computation_layout()));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())
                  ->computation_layout()));
          break;
        case HloOpcode::kBroadcast:
          TF_RETURN_IF_ERROR(CheckBroadcastLayout(instruction));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition())
                  ->computation_layout(),
              FindOrDie(computation_layouts_, instruction->while_body())
                  ->computation_layout()));
          break;
        case HloOpcode::kOptimizationBarrier:
          TF_RETURN_IF_ERROR(CheckOptimizationBarrierLayout(instruction));
          break;
        case HloOpcode::kConditional: {
          // Gather the layouts of all branch computations and check them
          // jointly against the conditional.
          std::vector<ComputationLayout> branch_computation_layouts;
          const auto& branch_computations = instruction->branch_computations();
          branch_computation_layouts.reserve(branch_computations.size());
          for (const auto branch_computation : branch_computations) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation)
                    ->computation_layout());
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          ->computation_layout()
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return OkStatus();
}
1239 
// Constructs a LayoutAssignment pass.
//
// `entry_computation_layout` gives the (possibly partially set) layout of the
// entry computation; a copy is saved so it can be restored if the pass needs
// to undo side effects. `channel_constraints` (may be null) supplies
// cross-module channel layout constraints; it too is snapshotted for the same
// reason. `reverse_computation_order` flips the order in which computations
// are processed.
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    ChannelLayoutConstraints* channel_constraints,
    bool reverse_computation_order)
    : entry_computation_layout_(entry_computation_layout),
      saved_entry_computation_layout_(*entry_computation_layout),
      reverse_computation_order_(reverse_computation_order),
      channel_layout_constraints_(channel_constraints) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1256 
ChooseOperandLayoutFromOutputLayout(const Layout & output_layout,const HloInstruction * instruction,int64_t operand_no)1257 std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
1258     const Layout& output_layout, const HloInstruction* instruction,
1259     int64_t operand_no) {
1260   const HloInstruction* operand = instruction->operand(operand_no);
1261   CHECK(instruction->shape().IsArray());
1262   CHECK(operand->shape().IsArray());
1263   if (!ShapeUtil::IsScalar(operand->shape()) &&
1264       operand->shape().rank() == instruction->shape().rank() &&
1265       !InstructionCanChangeLayoutInstance(instruction)) {
1266     // Propagate the result layout to the operand layout if the instruction
1267     // requires the same layout out for the result and the operand.
1268     //
1269     // For elementwise operations, using the same layout for the operands and
1270     // the result also has the following benefits:
1271     // 1) the elementwise operation can reuse its operand's buffer, and
1272     // 2) the input and output elements can reuse the same linear index.
1273     return std::make_unique<Layout>(output_layout);
1274   }
1275 
1276   if (instruction->opcode() == HloOpcode::kReshape) {
1277     // Prefer the operand layout that makes the reshape an bitcast. If any
1278     // dimension bound is 1 in the operand shape, there may be several such
1279     // layouts. So if 'output_layout' is the default layout, try if the
1280     // reshape is a bitcast when using the same layout. This may avoid copy
1281     // operations. For similar reasons, if the operand and output have the same
1282     // rank, try to match the operand's layout to the output.
1283     if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1284         ShapeUtil::TrueRank(instruction->shape()) == 1) {
1285       // Don't assign a layout in case of R1 -> effective R1 reshape.
1286       return nullptr;
1287     }
1288 
1289     const Shape& output_shape = instruction->shape();
1290     Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1291         output_shape.element_type(), output_shape.dimensions(),
1292         LayoutUtil::MinorToMajor(output_layout));
1293     Shape operand_shape = operand->shape();
1294     *operand_shape.mutable_layout() =
1295         LayoutUtil::GetDefaultLayoutForShape(operand_shape);
1296     auto aligned_operand_shape =
1297         ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
1298     if (aligned_operand_shape) {
1299       auto operand_layout = aligned_operand_shape.value().layout();
1300       TF_CHECK_OK(
1301           LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
1302       return std::make_unique<Layout>(operand_layout);
1303     }
1304   }
1305 
1306   if (instruction->opcode() == HloOpcode::kTranspose) {
1307     // Pick the operand layout that makes the transpose a bitcast.
1308     int64_t rank = instruction->shape().rank();
1309     std::vector<int64_t> new_minor_to_major(rank);
1310     for (int64_t i = 0; i < rank; ++i) {
1311       int64_t output_dim = LayoutUtil::Minor(output_layout, i);
1312       int64_t operand_dim = instruction->dimensions(output_dim);
1313       new_minor_to_major[i] = operand_dim;
1314     }
1315     Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1316     TF_CHECK_OK(
1317         LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
1318     return std::make_unique<Layout>(operand_layout);
1319   }
1320 
1321   return nullptr;
1322 }
1323 
// Derives the output layout of a kReduce from the layout of its first
// operand: attach 'operand_layout' to the operand shape, then delete the
// reduced dimensions and keep the layout that remains.
static Layout GetReduceLayoutFromOperand(const Layout& operand_layout,
                                         const HloInstruction* hlo) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
  Shape shape_with_layout = hlo->operand(0)->shape();
  *shape_with_layout.mutable_layout() = operand_layout;
  return ShapeUtil::DeleteDimensions(hlo->dimensions(), shape_with_layout)
      .layout();
}
1332 
ChooseOutputLayoutFromOperandLayout(const Layout & operand_layout,const HloInstruction * user,int64_t operand_no)1333 std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
1334     const Layout& operand_layout, const HloInstruction* user,
1335     int64_t operand_no) {
1336   const HloInstruction* operand = user->operand(operand_no);
1337 
1338   // Enforce standard layout on variadic reduction output to avoid having two
1339   // inconsistent layouts.
1340   if (user->opcode() == HloOpcode::kReduce && user->shape().IsTuple()) {
1341     return std::make_unique<Layout>(
1342         GetReduceLayoutFromOperand(operand_layout, user));
1343   }
1344 
1345   CHECK(user->shape().IsArray() && operand->shape().IsArray())
1346       << "Fails on instruction: " << user->ToString();
1347 
1348   if (!ShapeUtil::IsScalar(operand->shape()) &&
1349       operand->shape().rank() == user->shape().rank() &&
1350       !InstructionCanChangeLayoutInstance(user)) {
1351     // Assign users the same layout as the operand.
1352     return std::make_unique<Layout>(operand_layout);
1353   }
1354 
1355   if (user->opcode() == HloOpcode::kReshape) {
1356     // Prefer the user layout that makes the reshape an bitcast. If any
1357     // dimension bound is 1 in the user shape, there may be several such
1358     // layouts. So if 'operand_layout' is the default layout, try if the
1359     // reshape is a bitcast when using the same layout. This may avoid copy
1360     // operations. For similar reasons, if the operand and output have the same
1361     // rank, try to match the outputs's layout to the operand.
1362     if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1363         ShapeUtil::TrueRank(user->shape()) == 1) {
1364       // Don't assign a layout in case of R1 -> effective R1 reshape.
1365       return nullptr;
1366     }
1367     Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1368         operand->shape().element_type(), operand->shape().dimensions(),
1369         LayoutUtil::MinorToMajor(operand_layout));
1370     Shape output_shape = user->shape();
1371     *output_shape.mutable_layout() =
1372         LayoutUtil::GetDefaultLayoutForShape(output_shape);
1373     auto aligned_user_shape =
1374         ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
1375     if (aligned_user_shape) {
1376       auto user_layout = aligned_user_shape.value().layout();
1377       TF_CHECK_OK(
1378           LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
1379       return std::make_unique<Layout>(user_layout);
1380     }
1381   }
1382 
1383   if (user->opcode() == HloOpcode::kTranspose) {
1384     // Pick the user layout that makes the transpose a bitcast.
1385     int64_t rank = user->shape().rank();
1386     std::vector<int64_t> new_minor_to_major(rank);
1387     auto inverse_dimensions = InversePermutation(user->dimensions());
1388     for (int64_t i = 0; i < rank; ++i) {
1389       int64_t operand_dim = LayoutUtil::Minor(operand_layout, i);
1390       int64_t user_dim = inverse_dimensions[operand_dim];
1391       new_minor_to_major[i] = user_dim;
1392     }
1393     Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1394     TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
1395     return std::make_unique<Layout>(user_layout);
1396   }
1397 
1398   return nullptr;
1399 }
1400 
PropagateConstraints(LayoutConstraints * constraints)1401 Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
1402   // Gathers all initial constraints in a worklist and propagates them in
1403   // depth-first order. DFS order seems to be better than BFS because a
1404   // constraint is propagated as far as possible before propagating unrelated
1405   // constraints which makes it less likely that conflicting constraints will be
1406   // propagated to instructions. However, we should experiment with other orders
1407   // too.
1408   std::deque<const LayoutConstraint*> worklist;
1409 
1410   // Lambda for moving newly added constraints to the worklist.
1411   auto add_new_constraints_to_worklist = [this, &worklist]() {
1412     // Add constraints to the front of the deque for DFS ordering.
1413     for (auto* constraint : ConsumeAddedConstraints()) {
1414       if (constraint->dfs()) {
1415         worklist.push_front(constraint);
1416       } else {
1417         VLOG(3) << "push back constraint for propagation : "
1418                 << constraint->ToString();
1419         worklist.push_back(constraint);
1420       }
1421     }
1422   };
1423   add_new_constraints_to_worklist();
1424 
1425   while (!worklist.empty()) {
1426     const LayoutConstraint* layout_constraint = worklist.front();
1427     worklist.pop_front();
1428     VLOG(2) << "Propagating " << layout_constraint->ToString()
1429             << " to its neighbors with priority = "
1430             << layout_constraint->priority() << "\n";
1431     if (auto* buffer_constraint =
1432             dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
1433       TF_RETURN_IF_ERROR(
1434           PropagateBufferConstraint(*buffer_constraint, constraints));
1435     } else if (auto* operand_constraint =
1436                    dynamic_cast<const OperandLayoutConstraint*>(
1437                        layout_constraint)) {
1438       TF_RETURN_IF_ERROR(
1439           PropagateOperandConstraint(*operand_constraint, constraints));
1440     } else if (auto* computation_constraint =
1441                    dynamic_cast<const ComputationLayoutConstraint*>(
1442                        layout_constraint)) {
1443       TF_RETURN_IF_ERROR(
1444           PropagateResultConstraint(*computation_constraint, constraints));
1445     } else {
1446       LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
1447     }
1448     add_new_constraints_to_worklist();
1449   }
1450   return OkStatus();
1451 }
1452 
1453 namespace {
1454 
1455 // Returns a vector containing all array-shaped uses (instruction and operand
1456 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const TuplePointsToAnalysis::BufferAliasVector & aliases)1457 std::vector<std::pair<const HloInstruction*, int64_t>> GetArrayUsesOfBuffer(
1458     const TuplePointsToAnalysis::BufferAliasVector& aliases) {
1459   std::vector<std::pair<const HloInstruction*, int64_t>> uses;
1460   for (const auto& buffer_alias : aliases) {
1461     if (!buffer_alias.instruction()->shape().IsArray()) {
1462       continue;
1463     }
1464     // This alias must be the top-level (index == {}) of the instruction's
1465     // result because the instruction produces an array.
1466     CHECK(buffer_alias.index().empty());
1467 
1468     // Add all uses of the instruction's output.
1469     for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1470       for (int64_t operand_no :
1471            user->OperandIndices(buffer_alias.instruction())) {
1472         uses.emplace_back(user, operand_no);
1473       }
1474     }
1475   }
1476   return uses;
1477 }
1478 
1479 }  // namespace
1480 
PropagateUseConstraintToDefs(const ShapeLayout & shape_layout,const HloInstruction * instruction,LayoutConstraints * constraints,int64_t priority)1481 Status LayoutAssignment::PropagateUseConstraintToDefs(
1482     const ShapeLayout& shape_layout, const HloInstruction* instruction,
1483     LayoutConstraints* constraints, int64_t priority) {
1484   // Try to set all logical buffers which may be sources of the given operand to
1485   // match the given layout.
1486   const PointsToSet& points_to_set =
1487       points_to_analysis_->GetPointsToSet(instruction);
1488   return points_to_set.ForEachElementWithStatus(
1489       [&shape_layout, this, priority](
1490           const ShapeIndex& index,
1491           const PointsToSet::BufferList& buffers) -> Status {
1492         if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
1493           for (const LogicalBuffer* buffer : buffers) {
1494             if (buffer->shape().IsArray() &&
1495                 GetBufferLayoutConstraint(*buffer) == nullptr &&
1496                 (buffer->instruction()->opcode() != HloOpcode::kReduce ||
1497                  !buffer->instruction()->shape().IsTuple())) {
1498               TF_RETURN_IF_ERROR(SetBufferLayout(
1499                   ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
1500                   *buffer, /*mandatory=*/true, /*dfs=*/true, priority));
1501             }
1502           }
1503         }
1504         return OkStatus();
1505       });
1506 }
1507 
1508 namespace {
// A transpose, or a reshape that only changes trivial dimensions, has a
// meaningful layout that is valuable to propagate in a depth-first manner to
// avoid unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo)1512 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
1513   switch (hlo.opcode()) {
1514     case HloOpcode::kFusion:
1515       return hlo.IsCustomFusion();
1516     case HloOpcode::kGather:
1517       return true;
1518     case HloOpcode::kReshape:
1519       return hlo.operand(0)->shape().rank() == 1 ||
1520              hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions().has_value();
1521     case HloOpcode::kScatter:
1522     case HloOpcode::kTranspose:
1523       return true;
1524     default:
1525       return false;
1526   }
1527 }
1528 
1529 }  // namespace
1530 
// Propagates an operand-side layout constraint: backwards to the buffers that
// define the operand, sideways to sibling operands of non-layout-changing
// users, and forwards to the buffers the user instruction defines.
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  VLOG(3) << "Propagate Operand Constraint : " << operand_constraint.ToString()
          << "\n";
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(PropagateUseConstraintToDefs(
      operand_constraint.shape_layout(), operand_constraint.operand(),
      constraints, operand_constraint.priority()));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    return OkStatus();
  }

  if (user->opcode() == HloOpcode::kAllReduce) {
    // A single-operand all-reduce defines the top-level buffer; a variadic
    // one defines a tuple, whose element matching this operand is selected
    // by the operand number.
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(
        const LogicalBuffer* buffer,
        points_to_analysis_->GetBufferDefinedAt(user, shape_index));
    const BufferLayoutConstraint* constraint =
        GetBufferLayoutConstraint(*buffer);
    if (constraint == nullptr) {
      // Suggest (non-mandatory) that the all-reduce output match the
      // constrained operand layout.
      TF_RETURN_IF_ERROR(
          SetBufferLayout(operand_constraint.shape_layout().layout(), *buffer,
                          /*mandatory=*/false, /*dfs=*/true));
    }
  }

  if (InstructionCanChangeLayoutInstance(user) && !user->shape().IsArray() &&
      user->opcode() != HloOpcode::kReduce) {
    return OkStatus();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (AnyOperandBufferForwarded(user, operand_constraint.operand_no())) {
    return OkStatus();
  }

  // Rank <= 1 layouts are trivial; nothing further to propagate.
  int64_t operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return OkStatus();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!InstructionCanChangeLayoutInstance(user)) {
    // Only propagate the layout of the largest concatenate operand.
    if (user->opcode() == HloOpcode::kConcatenate) {
      for (int64_t operand_no = 0; operand_no < user->operand_count();
           ++operand_no) {
        const HloInstruction* sibling = user->operand(operand_no);
        if (sibling == operand) {
          continue;
        }
        // If any sibling is larger along the concatenation dimension, let
        // that sibling's constraint win instead.
        if (sibling->shape().dimensions(user->concatenate_dimension()) >
            operand->shape().dimensions(user->concatenate_dimension())) {
          return OkStatus();
        }
      }
    }
    // Make sure all siblings have the same layout as the operand.
    for (int64_t operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      if (!sibling->shape().IsArray()) {
        continue;
      }
      const int64_t sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        //   operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        //   we continue to propagate our layout to the result, and then
        //   backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false, /*dfs=*/true, operand_constraint.priority()));
    }
    // Also propagate the operand layout forward into the buffers the user
    // defines, for array subshapes of matching rank.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return OkStatus();
          }
          if (subshape.rank() <= 1) {
            return OkStatus();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return OkStatus();
          }
          if (!points_to_analysis_->InstructionDefinesBufferAtIndex(
                  user, shape_index)) {
            return OkStatus();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              points_to_analysis_->GetBufferDefinedAt(user, shape_index));
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          TF_RETURN_IF_ERROR(SetBufferLayout(
              operand_constraint.shape_layout().layout(), *buffer,
              /*mandatory=*/false));
          return OkStatus();
        }));
    return OkStatus();
  }
  // The user may change layouts: ask ChooseOutputLayoutFromOperandLayout for
  // a low-cost output layout and suggest it on each buffer the user defines.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return OkStatus();
        }
        if (subshape.rank() <= 1) {
          return OkStatus();
        }
        if (!points_to_analysis_->InstructionDefinesBufferAtIndex(
                user, shape_index)) {
          return OkStatus();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            points_to_analysis_->GetBufferDefinedAt(user, shape_index));
        auto* buffer_constraint = GetBufferLayoutConstraint(*buffer);
        // Never override an existing mandatory constraint.
        if (buffer_constraint == nullptr || !buffer_constraint->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            TF_RETURN_IF_ERROR(SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user),
                operand_constraint.priority()));
          }
        }
        return OkStatus();
      }));
  return OkStatus();
}
1711 
// Propagates a buffer layout constraint backwards into the operands of the
// instruction that defines the buffer.
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Rank <= 1 shapes have trivial layouts; nothing to propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return OkStatus();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // A single-operand all-reduce maps to operand 0; for a variadic one the
    // buffer's tuple index selects the corresponding operand.
    TF_RETURN_IF_ERROR(SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true, /*dfs=*/true, buffer_constraint.priority()));
    return OkStatus();
  }
  for (int64_t operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!InstructionCanChangeLayoutInstance(instruction)) {
      // Copy the layout to the operand.
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true, /*dfs=*/true, current_priority_));
      }
    } else if (instruction->opcode() == HloOpcode::kBroadcast) {
      // Derive the operand layout from the broadcast's output layout.
      Layout layout =
          GetBroadcastLayoutFromOutput(buffer_constraint.layout(), instruction);
      TF_RETURN_IF_ERROR(SetArrayOperandLayout(
          layout, instruction, operand_no, /*mandatory=*/true,
          /*dfs=*/
          InstructionShouldPropagateDepthFirst(*instruction),
          current_priority_));
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no constraint already.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        // Suggest (non-mandatory) an operand layout that avoids a copy given
        // this output layout.
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(*instruction),
              current_priority_));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return OkStatus();
}
1783 
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1784 Status LayoutAssignment::PropagateBufferConstraint(
1785     const BufferLayoutConstraint& buffer_constraint,
1786     LayoutConstraints* constraints) {
1787   // Only propagate array layouts.
1788   const LogicalBuffer& buffer = buffer_constraint.buffer();
1789   if (!buffer.IsArray()) {
1790     return OkStatus();
1791   }
1792   TF_RETURN_IF_ERROR(
1793       PropagateBufferConstraintToOperands(buffer_constraint, constraints));
1794   return PropagateBufferConstraintToUses(buffer_constraint, constraints);
1795 }
1796 
PropagateBufferConstraintToUses(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1797 Status LayoutAssignment::PropagateBufferConstraintToUses(
1798     const BufferLayoutConstraint& buffer_constraint,
1799     LayoutConstraints* constraints) {
1800   VLOG(5) << "PropagateBufferConstraintToUses: "
1801           << buffer_constraint.ToString();
1802   const LogicalBuffer& buffer = buffer_constraint.buffer();
1803   TF_RET_CHECK(buffer.IsArray());
1804 
1805   // Propagate the layout to all array uses of the logical buffer. This skips
1806   // uses of the buffer where the buffer is the element of a tuple.
1807   for (const auto& user_operand_no :
1808        GetArrayUsesOfBuffer(points_to_analysis_->GetBufferAliases(buffer))) {
1809     const HloInstruction* user = user_operand_no.first;
1810     int64_t operand_no = user_operand_no.second;
1811     // Only add an operand constraint if the user does not forward the buffer
1812     // because this case is not handled is SetOperandLayout.
1813     if (constraints->OperandLayout(user, operand_no) == nullptr &&
1814         !AnyOperandBufferForwarded(user, operand_no)) {
1815       TF_RETURN_IF_ERROR(SetArrayOperandLayout(
1816           buffer_constraint.layout(), user, operand_no, /*mandatory=*/false,
1817           /*dfs=*/true, buffer_constraint.priority()));
1818     }
1819   }
1820 
1821   // Propagate to backedges of kWhile.
1822   CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
1823   if (node.caller_callsites().size() != 1) {
1824     return OkStatus();
1825   }
1826   const HloInstruction* parent = node.caller_callsites()[0].instruction();
1827   if (parent->opcode() != HloOpcode::kWhile) {
1828     return OkStatus();
1829   }
1830 
1831   for (HloInstruction* user : buffer.instruction()->users()) {
1832     if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
1833       continue;
1834     }
1835     if (user->parent()->root_instruction() == user) {
1836       VLOG(3) << "Propagating layout through backedge"
1837               << buffer_constraint.layout().ToString();
1838       int64_t index = user->operand_index(buffer.instruction());
1839       TF_ASSIGN_OR_RETURN(
1840           auto buffer, points_to_analysis_->GetBufferDefinedAt(
1841                            user->parent()->parameter_instruction(0), {index}));
1842 
1843       TF_RETURN_IF_ERROR(SetBufferLayout(buffer_constraint.layout(), *buffer,
1844                                          /*mandatory=*/false));
1845     }
1846   }
1847 
1848   return OkStatus();
1849 }
1850 
PropagateResultConstraint(const ComputationLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1851 Status LayoutAssignment::PropagateResultConstraint(
1852     const ComputationLayoutConstraint& layout_constraint,
1853     LayoutConstraints* constraints) {
1854   // Propagate the use constraint of the root instruction up to the logical
1855   // buffers which make up the result.
1856   return PropagateUseConstraintToDefs(
1857       layout_constraint.computation_layout().result_layout(),
1858       constraints->computation()->root_instruction(), constraints,
1859       current_priority_);
1860 }
1861 
1862 // Infers the layout of the array at the given index in the given instruction's
1863 // output using points-to analysis. Precondition: The given instruction must
1864 // not produce this array value (that is, the array is forwarded from the
1865 // instruction's operands).
StatusOr<Layout> LayoutAssignment::InferArrayLayout(
    const HloInstruction* instruction, const ShapeIndex& index) {
  // All logical buffers this location may point to (the points-to set can be
  // ambiguous for forwarded values).
  const auto& source_buffers =
      points_to_analysis_->GetPointsToSet(instruction).element(index);
  TF_RET_CHECK(!source_buffers.empty());

  // Verify the layout is the same for every LogicalBuffer which this location
  // ('instruction' and 'index') points to.
  const Layout* first_buffer_layout = nullptr;
  for (const LogicalBuffer* source_buffer : source_buffers) {
    VLOG(5) << "Logical buffer: " << source_buffer->ToString() << "\n";
    auto* source_buffer_constraint = GetBufferLayoutConstraint(*source_buffer);
    if (source_buffer_constraint == nullptr) {
      // This should not happen because we've assigned layouts to all
      // instructions preceding this one.
      return InternalError("LogicalBuffer %s does not have a layout",
                           source_buffer->ToString());
    }

    if (first_buffer_layout == nullptr) {
      first_buffer_layout = &source_buffer_constraint->layout();
    } else if (!Layout::Equal().MinorToMajorOnly()(
                   source_buffer->shape().layout(), *first_buffer_layout)) {
      // NOTE(review): the reference layout above is taken from the buffer
      // *constraint*, but subsequent buffers are compared via their shape's
      // layout — presumably the two are in sync by the time this runs;
      // confirm against how SetBufferLayout updates the buffer shape.
      //
      // The points-to set is ambiguous for this index and the different source
      // buffers have different layouts. This case is possible in valid XLA
      // computations because we do not propagate BufferLayoutConstraints to all
      // LogicalBuffers which may alias the constrained LogicalBuffer at some
      // point in the computation.
      return FailedPrecondition(
          "Array at index {%s} in instruction %s aliases buffers %s "
          "and %s which have different layouts",
          absl::StrJoin(index, ","), instruction->name(),
          source_buffers[0]->ToString(), source_buffer->ToString());
    }
  }

  return *first_buffer_layout;
}
1904 
1905 namespace {
1906 
1907 // For fusion instructions, set the layout of each fused parameter instruction
1908 // to match the layout of its corresponding fusion instruction operand. Also,
1909 // set the layout of the fused root to match the layout of the fusion
1910 // instruction itself.
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  // Note: the branch order below matters -- the fused root takes the fusion
  // instruction's layout even when it is itself a GTE or constant.
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameter i mirrors the layout of the fusion's operand i.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return OkStatus();
}
1952 
1953 }  // namespace
1954 
// Materializes the layout constraints accumulated in 'constraints' onto the
// instructions of the computation: sets layouts on defined buffers, infers
// layouts for aliased (forwarded) buffers, inserts copies where operand
// layouts disagree with operand constraints, and fixes up the root to match
// the result layout constraint.
Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) {
  HloComputation* computation = constraints.computation();
  VLOG(2) << "Assigning layouts to computation: " << computation->name();

  XLA_VLOG_LINES(2, ToString(constraints));

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() == HloOpcode::kBitcast) {
      // bitcasts are inherently layout sensitive and so a bitcast instruction
      // present in the IR before layout assignment is a bug.
      return InternalError(
          "Unexpected bitcast operation seen during layout assignment: %s.",
          instruction->ToString());
    }
    // Drop any pre-existing layout so everything below is assigned from the
    // constraints, never inherited accidentally.
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) {
      if (!buffer->shape().IsArray()) {
        continue;
      }
      TF_RET_CHECK(buffer->instruction() == instruction);
      auto* buffer_layout_constraint = GetBufferLayoutConstraint(*buffer);
      TF_RET_CHECK(buffer_layout_constraint != nullptr);
      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(buffer_layout_constraint->layout(),
                                      buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = buffer_layout_constraint->layout();
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, this](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !subshape->IsArray()) {
            return OkStatus();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(instruction, index));
          return OkStatus();
        }));
    VLOG(3) << "Instruction layout:" << instruction->ToString();
    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64_t operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      } else {
        VLOG(2) << "operand " << operand_no << " has no constraint";
      }
    }
    if (instruction->opcode() == HloOpcode::kFusion) {
      // Push the now-final operand/result layouts into the fused computation.
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    VLOG(3) << "Resulting instruction:" << instruction->ToString() << "\n";
    // Execute extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Shape must be valid.
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr) {
    // Layout assignment at this point only does minor-to-major assignment so
    // tiling info should be ignored here for comparison.
    VLOG(5) << "Computation result layout needs root copying\n";
    if (!constraints.ResultLayout()->MatchesLayoutInShape(
            computation->root_instruction()->shape(),
            /*minor_to_major_only=*/true)) {
      // Minor-to-major differs: insert a copy with the required layout and
      // make it the new root.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_root,
          CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                  computation->root_instruction()));
      computation->set_root_instruction(new_root);
    } else {
      // Copy the tiling info specified in result layout.
      auto copy_tiling = [&constraints](xla::Shape* subshape,
                                        const xla::ShapeIndex& index) {
        if (subshape->IsArray()) {
          const Shape& result_shape = ShapeUtil::GetSubshape(
              constraints.ResultLayout()->shape(), index);
          if (result_shape.layout().tiles_size() != 0) {
            subshape->mutable_layout()->mutable_tiles()->assign(
                result_shape.layout().tiles().begin(),
                result_shape.layout().tiles().end());
          }
        }
      };
      xla::ShapeUtil::ForEachMutableSubshape(
          computation->root_instruction()->mutable_shape(), copy_tiling);
    }
  }
  VLOG(5) << "Final computation layout:" << computation->name() << ":"
          << constraints.computation_constraint().ToString() << "\n";
  VLOG(5) << "Root instruction:" << computation->root_instruction()->ToString()
          << "\n";
  return OkStatus();
}
2076 
// Derives ComputationLayouts for computations nested under this one (fusion,
// call, conditional, while) from the layouts already inferred for the caller's
// instructions, and finally refreshes this computation's own layout from its
// root and parameters.
Status LayoutAssignment::CalculateComputationLayout(
    LayoutConstraints* constraints) {
  // Process instructions that contain nested computations and may require
  // additional layouts to be assigned on the instructions nested inside.

  // Copies the inferred array layouts of 'operand' into 'update' (a parameter
  // or result ShapeLayout of a callee). Returns true if any leaf layout was
  // written; indices whose layout cannot be inferred are left untouched.
  auto UpdateLayout = [this](const HloInstruction* operand,
                             ShapeLayout* update) -> bool {
    bool change = false;
    ShapeUtil::ForEachSubshape(
        operand->shape(), [this, &change, operand, update](
                              const Shape& subshape, const ShapeIndex& index) {
          if (subshape.IsTuple()) {
            return;
          }
          auto param_layout = InferArrayLayout(operand, index);
          if (param_layout.ok()) {
            VLOG(5) << index << ":" << param_layout.ValueOrDie().ToString()
                    << "\n";
            update->ResetLayout(param_layout.ValueOrDie(), index);
            change = true;
          }
        });
    return change;
  };

  // Installs a ComputationLayout on 'callee' built from the caller-side
  // 'result' instruction and 'operands', but only when 'priority' outranks the
  // callee's current constraint (or a conditional mismatch forces a reset).
  auto SetCalleeLayout =
      [this, UpdateLayout](const HloInstruction* result,
                           absl::Span<const HloInstruction* const> operands,
                           LayoutConstraints* callee, int priority) -> Status {
    CHECK_NE(result, nullptr);
    ComputationLayoutConstraint* callee_constraint =
        callee->mutable_computation_constraint();
    ComputationLayout callee_layout = callee_constraint->computation_layout();
    if (callee_constraint->priority() < priority ||
        conditional_mismatch_.count(callee->computation()) > 0) {
      // For mismatched conditionals keep the previously-fixed result layout;
      // otherwise pull the result layout from the caller-side instruction.
      if (conditional_mismatch_.count(callee->computation()) == 0 &&
          UpdateLayout(result, callee_layout.mutable_result_layout())) {
        VLOG(2) << "Setting result layout from : " << result->ToString()
                << "\n";
      }
      int64_t operand_no = 0;
      for (auto* operand : operands) {
        if (UpdateLayout(operand,
                         callee_layout.mutable_parameter_layout(operand_no))) {
          VLOG(2) << "Setting callee parameter: " << operand->ToString()
                  << "\n";
        }
        ++operand_no;
      }
      VLOG(2) << "Set callee layout: " << callee->computation()->name() << ":"
              << callee_layout.ToString()
              << "; original priority = " << callee_constraint->priority()
              << "\n";
      callee_constraint->ResetComputationLayout(callee_layout, priority, true,
                                                true);
    }
    return OkStatus();
  };
  for (HloInstruction* instruction :
       constraints->computation()->MakeInstructionPostOrder()) {
    switch (instruction->opcode()) {
      case HloOpcode::kFusion:
        TF_RETURN_IF_ERROR(
            SetCalleeLayout(instruction, instruction->operands(),
                            mutable_computation_constraints(
                                instruction->fused_instructions_computation()),
                            current_priority_ + 1));
        break;
      case HloOpcode::kCall:
        // NOTE(review): a non-OK Status from SetCalleeLayout is silently
        // dropped here (only success is logged) -- presumably intentional for
        // this best-effort reverse-order propagation; confirm.
        if (reverse_computation_order_ &&
            SetCalleeLayout(
                instruction, instruction->operands(),
                mutable_computation_constraints(instruction->to_apply()),
                current_priority_ + 1) == OkStatus()) {
          VLOG(2) << "Successfully propagated to callee layout\n";
        }
        break;
      case HloOpcode::kConditional:
        if (reverse_computation_order_) {
          // If the branches don't yet have layouts, propagate existing layout
          // inside the branches.
          // Operand 0 is the branch selector; operand i+1 feeds branch i.
          for (int i = 0; i < instruction->branch_count(); ++i) {
            TF_RETURN_IF_ERROR(
                SetCalleeLayout(instruction, {instruction->operand(i + 1)},
                                mutable_computation_constraints(
                                    instruction->branch_computation(i)),
                                current_priority_ + 1));
          }
        }
        break;
      case HloOpcode::kWhile:
        // If the loop body doesn't have layouts, propagate existing one inside.
        if (reverse_computation_order_) {
          VLOG(2) << "Populating while loop constraints inside loop body.";
          VLOG(2) << instruction->ToString();
          TF_RETURN_IF_ERROR(SetCalleeLayout(
              instruction, {instruction->operand(0)},
              mutable_computation_constraints(instruction->while_body()),
              current_priority_ + 1));
          VLOG(2) << "Populating while loop constraints inside loop condition.";
          VLOG(2) << instruction->ToString();
          TF_RETURN_IF_ERROR(SetCalleeLayout(
              instruction->operand(0), {instruction->operand(0)},
              mutable_computation_constraints(instruction->while_condition()),
              current_priority_ + 1));
        }
        break;
      default:
        break;
    }
  }
  // Reset the layout of the current computation from its body.
  if (current_priority_ == 0 ||
      conditional_mismatch_.count(constraints->computation()) > 0) {
    TF_RETURN_IF_ERROR(SetCalleeLayout(
        constraints->computation()->root_instruction(),
        constraints->computation()->parameter_instructions(), constraints,
        current_priority_ + kNumberOfPropagationRounds));
    if (constraints->computation()->IsEntryComputation()) {
      // Keep the module-level entry layout in sync with what was derived.
      *entry_computation_layout_ = constraints->computation_layout();
    }
  }
  return OkStatus();
}
2201 
ClearComputationLayouts(HloComputation * computation)2202 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
2203   // Clear existing layouts of the instructions.  All layouts must be assigned
2204   // by the LayoutAssignment pass, except for those on parameters, the
2205   // computation result, and a couple special cases. The former two are
2206   // specified in computation_layout.  Clearing the layouts here avoids hiding
2207   // potential bugs in the layout assignment pass that may accidentally use the
2208   // existing layout.
2209   for (HloInstruction* instruction : computation->instructions()) {
2210     if (instruction->opcode() == HloOpcode::kBitcast) {
2211       // bitcasts are inherently layout sensitive and so a bitcast instruction
2212       // present in the IR before layout assignment is a bug.
2213       return InternalError(
2214           "Unexpected bitcast operation seen during layout assignment: %s.",
2215           instruction->ToString());
2216     }
2217     // Some instructions carry mandatory layouts in their shape.
2218     if (instruction->opcode() != HloOpcode::kInfeed &&
2219         !IsLayoutConstrainedCustomCall(instruction) &&
2220         !IsLayoutConstrainedCollective(instruction)) {
2221       LayoutUtil::ClearLayout(instruction->mutable_shape());
2222     }
2223   }
2224   return OkStatus();
2225 }
2226 
RunOnComputation(LayoutConstraints * constraints,ChannelLayoutConstraints * channel_constraints)2227 Status LayoutAssignment::RunOnComputation(
2228     LayoutConstraints* constraints,
2229     ChannelLayoutConstraints* channel_constraints) {
2230   HloComputation* computation = constraints->computation();
2231   VLOG(1) << "LayoutAssignment::RunOnComputation(" << computation->name()
2232           << ")";
2233   VLOG(4) << computation->ToString() << "\n";
2234 
2235   // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
2236   for (HloInstruction* inst : computation->instructions()) {
2237     points_to_analysis_->GetPointsToSet(inst).ForEachElement(
2238         [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
2239           for (const LogicalBuffer* buffer : buffers) {
2240             // The points to analysis is computed per module, restrict
2241             // constraints to array buffers in this computation.
2242             if (buffer->IsArray() &&
2243                 buffer->instruction()->parent() == computation) {
2244               unconstrained_buffer_ids_.insert(buffer->id());
2245             }
2246           }
2247         });
2248   }
2249 
2250   // Add constraints required for correctness on all backends (eg, entry
2251   // parameter layout constraints).
2252   TF_RETURN_IF_ERROR(AddMandatoryConstraints(channel_constraints, constraints));
2253 
2254   // Add any backend-specific constraints.
2255   TF_RETURN_IF_ERROR(AddBackendConstraints(constraints));
2256 
2257   // Propagates layouts from mandatory and backend constraints.
2258   TF_RETURN_IF_ERROR(PropagateConstraints(constraints));
2259 
2260   // Prior to applying default layouts, we take note of all HLO instructions
2261   // which lack a layout constraint.
2262   for (LogicalBuffer::Id buffer_id : unconstrained_buffer_ids_) {
2263     VLOG(5)
2264         << "unconstrained instruction:"
2265         << points_to_analysis_->GetBuffer(buffer_id).instruction()->ToString()
2266         << "\n";
2267     unconstrained_layout_instructions_.insert(
2268         points_to_analysis_->GetBuffer(buffer_id).instruction());
2269   }
2270 
2271   // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
2272   // layout and propagate the change.
2273   while (!unconstrained_buffer_ids_.empty()) {
2274     int unconstrained_count = unconstrained_buffer_ids_.size();
2275 
2276     // Arbitrarily pick the first unconstrained buffer and give it the default
2277     // layout (or the literal layout, in case of constants). By construction
2278     // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
2279     const LogicalBuffer& buffer =
2280         points_to_analysis_->GetBuffer(*unconstrained_buffer_ids_.begin());
2281     const HloInstruction* instruction = buffer.instruction();
2282     Layout new_layout =
2283         instruction->opcode() == HloOpcode::kConstant
2284             ? ShapeUtil::GetSubshape(instruction->literal().shape(),
2285                                      buffer.index())
2286                   .layout()
2287             : GetUnconstrainedLayout(buffer);
2288     TF_RETURN_IF_ERROR(SetBufferLayout(new_layout, buffer,
2289                                        /*mandatory=*/false));
2290 
2291     TF_RETURN_IF_ERROR(PropagateConstraints(constraints));
2292 
2293     // To verify progress has been made, check that the number of unconstrained
2294     // buffers has been reduced.
2295     CHECK_LT(unconstrained_buffer_ids_.size(), unconstrained_count);
2296   }
2297 
2298   TF_RETURN_IF_ERROR(CalculateComputationLayout(constraints));
2299   // Record the layouts assigned for any communication ops in
2300   // channel_constraints so that they are constrained for future modules.
2301   if (channel_constraints != nullptr) {
2302     TF_RETURN_IF_ERROR(
2303         ConstrainChannelLayouts(computation, channel_constraints));
2304   }
2305 
2306   return OkStatus();
2307 }
2308 
ConstrainChannelLayouts(HloComputation * computation,ChannelLayoutConstraints * channel_constraints)2309 Status LayoutAssignment::ConstrainChannelLayouts(
2310     HloComputation* computation,
2311     ChannelLayoutConstraints* channel_constraints) {
2312   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
2313     if (instruction->IsCrossModuleAllReduce()) {
2314       TF_ASSIGN_OR_RETURN(auto op_layout, InferArrayLayout(instruction, {}));
2315       VLOG(5) << "Constrain cross module all reduce: " << op_layout.ToString()
2316               << "\n";
2317       channel_constraints->ConstrainChannel(instruction->channel_id().value(),
2318                                             op_layout);
2319     }
2320   }
2321   return OkStatus();
2322 }
2323 
PropagateMemorySpace(HloModule * module)2324 Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
2325   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
2326   for (const auto& buffer : alias_analysis->buffers()) {
2327     // First go through values to collect the memory spaces.
2328     int64_t buffer_memory_space = Layout::kDefaultMemorySpace;
2329     for (auto value : buffer.values()) {
2330       const Shape& defining_shape = value->defining_position().shape();
2331       if (!defining_shape.has_layout()) {
2332         continue;
2333       }
2334       int64_t memory_space = defining_shape.layout().memory_space();
2335       if (memory_space != Layout::kDefaultMemorySpace) {
2336         if (buffer_memory_space != Layout::kDefaultMemorySpace &&
2337             memory_space != buffer_memory_space) {
2338           return InternalError(
2339               "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
2340               buffer.id(), value->ToShortString(), buffer_memory_space,
2341               memory_space);
2342         }
2343         buffer_memory_space = memory_space;
2344       }
2345     }
2346 
2347     // If we encounter a memory space other than the default, then propagate all
2348     // the positions with the buffer's memory space.
2349     if (buffer_memory_space != Layout::kDefaultMemorySpace) {
2350       for (auto value : buffer.values()) {
2351         for (auto& position : value->positions()) {
2352           Shape* shape = ShapeUtil::GetMutableSubshape(
2353               position.instruction->mutable_shape(), position.index);
2354           shape->mutable_layout()->set_memory_space(buffer_memory_space);
2355         }
2356       }
2357     }
2358   }
2359   return OkStatus();
2360 }
2361 
// Reconciles the requested 'computation_layout' with the layouts actually
// assigned on the computation's instructions: fills in parameter/result
// layouts that were left unset, and checks that layouts which were set match
// what assignment produced.
Status LayoutAssignment::PropagateComputationLayouts(
    HloComputation* computation, ComputationLayout* computation_layout) {
  // The layout the pass actually produced, read back from the instructions.
  ComputationLayout computed_computation_layout(
      computation->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  for (int64_t i = 0; i < computed_computation_layout.parameter_count(); ++i) {
    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
    bool needs_assign = false;
    // Walk the leaves of parameter i: an unset leaf layout means the whole
    // parameter layout must be taken from the computed layout; a set leaf
    // must agree with what was computed.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        param_layout->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
            return OkStatus();
          }
          if (!subshape.has_layout()) {
            needs_assign = true;
            return OkStatus();
          }
          const auto& computed_subshape = ShapeUtil::GetSubshape(
              computed_computation_layout.parameter_shape(i), shape_index);
          if (subshape.layout() != computed_subshape.layout()) {
            return InternalError(
                "Assigned parameter shape %s does not match layout of "
                "computation shape: %s",
                computed_computation_layout.ToString(),
                computation_layout->ToString());
          }
          return OkStatus();
        }));
    if (needs_assign) {
      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
              << computation->name() << ": "
              << computed_computation_layout.parameter_layout(i).ToString();
      *param_layout = computed_computation_layout.parameter_layout(i);
    }
  }
  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
  if (!result_layout->LayoutIsSet()) {
    // No result layout was requested; adopt the computed one.
    VLOG(4) << "Assigning result layout of computation " << computation->name()
            << ": " << computed_computation_layout.result_layout().ToString();
    *result_layout = computed_computation_layout.result_layout();
  } else {
    // A result layout was requested; the computed result must match it
    // (minor-to-major only, ignoring dynamic dimensions).
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            computed_computation_layout.result_layout().shape(),
            result_layout->shape()));
  }
  return OkStatus();
}
2411 
// Pass entry point. Clears all existing layouts, then runs
// kNumberOfPropagationRounds constraint-propagation rounds over the module's
// non-fusion computations before materializing the final layouts and
// verifying them. Always reports a change (returns true).
StatusOr<bool> LayoutAssignment::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  TF_RETURN_IF_ERROR(Init(module));
  call_graph_ = CallGraph::Build(module);
  // Add copy to the operand of Send instructions, since we cannot call
  // SetOperandLayout on Send instructions as it aliases its input to the
  // output.
  //
  // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
  // operand buffers that aliases with the output.
  for (HloComputation* computation : module->computations(execution_threads)) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kSend) {
        TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
      }
    }
  }

  // Clone Conditional computations with multiple callsites.
  for (HloComputation* computation : module->computations(execution_threads)) {
    CallGraphNode& node = call_graph_->GetNode(computation);
    if (node.caller_callsites().size() == 1) {
      continue;
    }
    if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
          return caller.instruction()->opcode() == HloOpcode::kConditional;
        })) {
      continue;
    }
    // Leave the last callsite pointing at the original computation; each
    // earlier conditional callsite gets its own clone so branch layouts can
    // be assigned independently.
    for (int64_t i = 0; i < node.caller_callsites().size() - 1; ++i) {
      HloInstruction* caller = node.caller_callsites()[i].instruction();
      if (caller->opcode() == HloOpcode::kConditional) {
        for (int64_t k = 0; k < caller->branch_count(); ++k) {
          if (computation == caller->branch_computation(k)) {
            caller->set_branch_computation(
                k, module->AddEmbeddedComputation(computation->Clone()));
            break;
          }
        }
      }
    }
  }

  // Verify computation layout is sane.
  HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
               entry->num_parameters());
  for (int64_t i = 0; i < entry->num_parameters(); ++i) {
    TF_RET_CHECK(
        ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
                              entry->parameter_instruction(i)->shape()));
  }
  TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
                                     entry->root_instruction()->shape()));
  // We do two passes. The first one we pass a nullptr ComputationLayout to
  // the RunOnComputation() calls (for non entry computations), and we register
  // the ComputationLayout which are naturally flowing in DFS fashion to the
  // parameters and root instruction.
  // Walking in DFS mode though, means that we can end up with incorrect layouts
  // when seen from an outer instruction, which has across-computation
  // constraints to impose.
  // For example, the kWhile instruction needs to enforce the same layouts for
  // the parameters and root of the body, as well as the condition parameters.
  // Similarly, the kConditional instruction needs to enforce the same layouts
  // for the root of the true and false computations.
  // So in the first pass, while allowing the layouts to flow to parameters and
  // root, we also fix up the eventually inconsistent ComputationLayout, which
  // will be then made mandatory by the second pass.
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  points_to_analysis_ = std::move(points_to_analysis);
  auto computations_to_work =
      module->MakeNonfusionComputations(execution_threads);
  // If the reverse_comptation_order_ flag is set, reverse the ordering of
  // traversing computations, to generate an alternative layout assignment.
  if (reverse_computation_order_ && !computations_to_work.empty()) {
    absl::c_reverse(computations_to_work);

    VLOG(2) << "reversing traversal order for computation:";
  }
  // Seed the entry computation's constraints. The entry layout is only taken
  // as given (kGivenPriority) when it is fully set; otherwise it starts at
  // default priority and is filled in by the pass.
  computation_layouts_.emplace(
      module->entry_computation(),
      new LayoutConstraints(entry,
                            entry_computation_layout_->LayoutIsSet()
                                ? entry_computation_layout_
                                : nullptr,
                            entry_computation_layout_->LayoutIsSet()
                                ? LayoutConstraint::kGivenPriority
                                : LayoutConstraint::kDefaultPriority));
  for (int64_t i = 0; i < kNumberOfPropagationRounds; ++i) {
    VLOG(1) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
    // Each round starts from a clean module state; constraints accumulated in
    // earlier rounds survive via the (priority-tagged) constraint objects.
    TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module, execution_threads));
    for (auto* computation : computations_to_work) {
      LayoutConstraints* constraints =
          mutable_computation_constraints(computation);
      TF_RETURN_IF_ERROR(
          RunOnComputation(constraints, channel_layout_constraints_));
    }
    current_priority_ += 1;
  }

  for (auto* computation : computations_to_work) {
    LayoutConstraints* constraints =
        FindOrDie(computation_layouts_, computation).get();
    // All logical buffers should have constraints at this point. All that
    // remains is assign the constraints to the buffers and infer layouts for
    // aliased buffers.
    TF_RETURN_IF_ERROR(AssignLayouts(*constraints));
  }
  TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
                                                 entry_computation_layout_));

  TF_RETURN_IF_ERROR(PropagateMemorySpace(module));

  TF_RETURN_IF_ERROR(CheckLayouts(module, execution_threads));

  // All layouts are reset then reassigned by this pass.
  return true;
}
2534 
2535 /* static */
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  // Classifies opcodes by whether the instruction may produce an output whose
  // layout differs from its operands'. Returning false means the pass must
  // propagate the operand layout through to the result unchanged.
  //
  // NOTE: the switch has no default case on purpose — it is exhaustive over
  // HloOpcode, so adding a new opcode forces an explicit decision here.
  switch (instruction->opcode()) {
    // These ops are layout-sensitive (elementwise ops and ops whose
    // implementation depends on the physical order of elements): output
    // layout must follow the operand layout(s).
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOptimizationBarrier:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRoundNearestEven:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kWhile:
    case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic so it needs to be careful to assign the same layout
    // to the corresponding input argument and Tuple index.
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
      return false;
    // These ops may pick an output layout independent of their operands'
    // layouts (data-reorganizing ops, ops with their own layout constraints,
    // and ops that produce fresh values such as constants and parameters).
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kConditional:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
}
2661 
bool LayoutAssignment::InstructionCanChangeLayoutInstance(
    const HloInstruction* instruction) {
  // Per-instance hook: by default defer entirely to the static opcode-based
  // classification. NOTE(review): presumably overridden by backend subclasses
  // to special-case individual instructions — confirm against the header.
  return InstructionCanChangeLayout(instruction);
}
2666 
2667 /* static */
IsAtMostRank1(const Shape & shape)2668 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2669   if (shape.IsArray()) {
2670     return shape.rank() <= 1;
2671   }
2672   return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2673     return IsAtMostRank1(subshape);
2674   });
2675 }
2676 
Init(HloModule * module)2677 Status LayoutAssignment::Init(HloModule* module) {
2678   computation_layouts_.clear();
2679   conditional_mismatch_.clear();
2680   *entry_computation_layout_ = saved_entry_computation_layout_;
2681   current_priority_ = LayoutConstraint::kBeginningPriority;
2682   // Clear all the copies which have been added, and all the related
2683   // instructions (like GTE and tuples).
2684   int64_t removed_copies = 0;
2685   for (HloComputation* computation : module->computations()) {
2686     for (HloInstruction* instruction :
2687          computation->MakeInstructionPostOrder()) {
2688       if (instruction->opcode() == HloOpcode::kCopy &&
2689           added_copies_.contains(instruction)) {
2690         VLOG(5) << "Removing added copy: " << instruction->ToString();
2691         TF_RETURN_IF_ERROR(
2692             instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2693         TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2694         ++removed_copies;
2695       }
2696     }
2697   }
2698   added_copies_.clear();
2699   if (removed_copies > 0) {
2700     TupleSimplifier tuple_simplifier;
2701     HloDCE dce;
2702     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2703     TF_RETURN_IF_ERROR(dce.Run(module).status());
2704     call_graph_ = CallGraph::Build(module);
2705   }
2706   return OkStatus();
2707 }
2708 
ClearPreviousPassSideEffects(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)2709 Status LayoutAssignment::ClearPreviousPassSideEffects(
2710     HloModule* module,
2711     const absl::flat_hash_set<absl::string_view>& execution_threads) {
2712   VLOG(5) << "Clearing previous side effects";
2713   for (HloComputation* computation : module->computations(execution_threads)) {
2714     if (computation_layouts_.find(computation) != computation_layouts_.end()) {
2715       mutable_computation_constraints(computation)->ResetOperandConstraints();
2716     }
2717   }
2718   unconstrained_layout_instructions_.clear();
2719   unconstrained_buffer_ids_.clear();
2720   buffer_constraints_.clear();
2721   buffer_sets_cache_.clear();
2722   return OkStatus();
2723 }
AddCopyForOperand(HloInstruction * instruction,int64_t operand_number)2724 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2725                                            int64_t operand_number) {
2726   HloInstruction* operand = instruction->mutable_operand(operand_number);
2727   if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2728     HloInstruction* copy =
2729         instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2730             operand->shape(), HloOpcode::kCopy, operand));
2731     SetupCopiedInstruction(*operand, copy, {});
2732     LayoutUtil::ClearLayout(copy->mutable_shape());
2733     TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2734   }
2735   return OkStatus();
2736 }
2737 
2738 }  // namespace xla
2739