1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28 #include <utility>
29
30 #include "absl/algorithm/container.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/permutation_util.h"
38 #include "tensorflow/compiler/xla/service/call_graph.h"
39 #include "tensorflow/compiler/xla/service/computation_layout.h"
40 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_dce.h"
44 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/logical_buffer.h"
48 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
49 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
50 #include "tensorflow/compiler/xla/shape_layout.h"
51 #include "tensorflow/compiler/xla/shape_util.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/compiler/xla/statusor.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/logging.h"
60 #include "tensorflow/core/platform/protobuf.h"
61
62 namespace xla {
63
operator <<(std::ostream & out,const LayoutConstraint & constraint)64 std::ostream& operator<<(std::ostream& out,
65 const LayoutConstraint& constraint) {
66 out << constraint.ToString();
67 return out;
68 }
69
// Constructs a constraint binding `layout` to the logical `buffer`.
// `mandatory`, `dfs`, and `priority` are forwarded to the LayoutConstraint
// base; the buffer is held by pointer, so it must outlive this constraint.
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs,
                                               int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      layout_(layout),
      buffer_(&buffer) {
  // Fail fast: the layout must be valid for the buffer's shape.
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
79
ToString() const80 std::string BufferLayoutConstraint::ToString() const {
81 return absl::StrFormat(
82 "BufferLayoutConstraint (prioity=%d, mandatory=%d, dfs=%d) %s: %s",
83 priority(), mandatory(), dfs(), buffer_->ToString(),
84 LayoutUtil::HumanString(layout_));
85 }
86
UpdateLayout(int64_t priority,const Layout & layout,bool mandatory,bool dfs)87 bool BufferLayoutConstraint::UpdateLayout(int64_t priority,
88 const Layout& layout, bool mandatory,
89 bool dfs) {
90 if (!mandatory && priority <= priority_) {
91 return false;
92 }
93 mandatory_ = mandatory;
94 dfs_ = dfs;
95 priority_ = priority;
96 if (Layout::Equal().MinorToMajorOnly()(layout_, layout)) {
97 // New constraint matches existing constraint. Nothing to propagation.
98 return false;
99 }
100 layout_ = layout;
101 return true;
102 }
103
// Constructs a constraint fixing the layout of operand `operand_no` of
// `instruction` to `shape_layout`.  The shape must already carry a layout and
// be compatible with the operand's shape; both are CHECK-verified here.
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64_t operand_no, bool mandatory, bool dfs, int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
118
ToString() const119 std::string OperandLayoutConstraint::ToString() const {
120 return absl::StrFormat(
121 "OperandLayoutConstraint (prioity=%d) %s, operand %d: %s", priority(),
122 instruction_->name(), operand_no_, shape_layout_.ToString());
123 }
124
ToString() const125 std::string ComputationLayoutConstraint::ToString() const {
126 return absl::StrFormat("ComputationLayoutConstraint (status=%d): %s",
127 layout_state_, computation_layout_.ToString());
128 }
129
// Per-computation constraint set.  Wraps the computation's overall layout
// constraint; operand/buffer constraints are added later as assignment runs.
LayoutAssignment::LayoutConstraints::LayoutConstraints(
    HloComputation* computation, ComputationLayout* computation_layout,
    int64_t priority)
    : computation_(computation),
      computation_constraint_(computation, computation_layout, priority) {}
135
GetBufferSet(const HloInstruction * instruction) const136 PointsToSet::BufferSet* LayoutAssignment::GetBufferSet(
137 const HloInstruction* instruction) const {
138 auto it = buffer_sets_cache_.find(instruction);
139 if (it != buffer_sets_cache_.end()) {
140 return it->second.get();
141 }
142 auto& buffer_set =
143 buffer_sets_cache_
144 .emplace(instruction, std::make_unique<PointsToSet::BufferSet>())
145 .first->second;
146 const auto& points_to_set = points_to_analysis_->GetPointsToSet(instruction);
147 points_to_set.ForEachElement(
148 [&buffer_set](const ShapeIndex& /*index*/,
149 const PointsToSet::BufferList& buffers) {
150 buffer_set->insert(buffers.begin(), buffers.end());
151 });
152 return buffer_set.get();
153 }
154
AnyOperandBufferForwarded(const HloInstruction * instruction,int64_t operand_no) const155 bool LayoutAssignment::AnyOperandBufferForwarded(
156 const HloInstruction* instruction, int64_t operand_no) const {
157 // The operand is potentially forwarded if the intersection of points-to sets
158 // of the operand and the instruction is non-empty.
159 PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
160 PointsToSet::BufferSet* operand_buffers =
161 GetBufferSet(instruction->operand(operand_no));
162 return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
163 return operand_buffers->count(b) > 0;
164 });
165 }
166
AllOperandBuffersForwarded(const HloInstruction * instruction,int64_t operand_no) const167 bool LayoutAssignment::AllOperandBuffersForwarded(
168 const HloInstruction* instruction, int64_t operand_no) const {
169 // The operand is potentially forwarded if the intersection of points-to sets
170 // of the operand and the instruction is non-empty.
171 PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
172 PointsToSet::BufferSet* operand_buffers =
173 GetBufferSet(instruction->operand(operand_no));
174 return absl::c_all_of(*output_buffers, [&](const LogicalBuffer* b) {
175 return operand_buffers->count(b) > 0;
176 });
177 }
178
// Registers (or updates) a layout constraint on a single logical buffer.
// Returns an error if the buffer is not array-shaped or the layout is invalid
// for its shape.  A non-mandatory request never overrides a mandatory one.
Status LayoutAssignment::SetBufferLayout(const Layout& layout,
                                         const LogicalBuffer& buffer,
                                         bool mandatory, bool dfs,
                                         int64_t priority) {
  VLOG(3) << "SetBufferLayout : " << buffer << " : "
          << LayoutUtil::HumanString(layout) << " with priority " << priority;
  TF_RETURN_IF_ERROR(points_to_analysis_->VerifyBuffer(buffer));
  // Once any constraint is placed on the buffer it is no longer counted as
  // unconstrained.
  if (unconstrained_buffer_ids_.find(buffer.id()) !=
      unconstrained_buffer_ids_.end()) {
    VLOG(3) << "Erase buffer from unconstrained ids\n";
    TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
        << buffer.ToString();
  }

  if (!buffer.IsArray()) {
    return FailedPrecondition(
        "Layout of buffer %s cannot be constrained because buffer is not "
        "array-shaped, has shape: %s",
        buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
  }
  TF_RETURN_IF_ERROR(
      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));

  auto iter = buffer_constraints_.find(&buffer);
  if (iter != buffer_constraints_.end()) {
    // NOTE: UpdateLayout is applied to a *copy* of the stored constraint.  If
    // the update is rejected the copy is discarded, so the stored constraint
    // keeps its previous flags and priority; only accepted updates are
    // written back via insert_or_assign below.
    BufferLayoutConstraint curr_constraint = iter->second;
    if (curr_constraint.mandatory() && !mandatory) {
      VLOG(3) << "Buffer" << buffer
              << " already has a mandatory layout constrain, skipping "
                 "non-mandatory new layout";
      return OkStatus();
    } else {
      if (curr_constraint.UpdateLayout(priority, layout, mandatory, dfs)) {
        VLOG(3) << "Updating existing Buffer layout for " << buffer.ToString()
                << " with new layout" << LayoutUtil::HumanString(layout);
        iter = buffer_constraints_.insert_or_assign(&buffer, curr_constraint)
                   .first;
      } else {
        VLOG(3) << "Unable to update existing Buffer layout for "
                << curr_constraint.ToString() << " with new layout"
                << LayoutUtil::HumanString(layout) << " at priority "
                << priority << "\n";
        return OkStatus();
      }
    }
  } else {
    iter = buffer_constraints_
               .insert(std::make_pair(
                   &buffer, BufferLayoutConstraint(layout, buffer, mandatory,
                                                   dfs, priority)))
               .first;
  }
  // Queue the (new or updated) constraint for propagation to neighbors.
  added_constraints_.push_back(&iter->second);
  return OkStatus();
}
234
// Registers (or updates) a layout constraint on one operand of `instruction`.
// Fails if any of the operand's buffers are forwarded to the instruction's
// output, since constraining such an operand would affect layouts beyond this
// immediate use.
Status LayoutAssignment::SetOperandLayout(const Shape& shape_with_layout,
                                          const HloInstruction* instruction,
                                          int64_t operand_no, bool mandatory,
                                          bool dfs, int64_t priority) {
  LayoutConstraints& constraints =
      *FindOrDie(computation_layouts_, instruction->parent());
  // The second and third operands (operand_no > 0) of a dynamic-update-slice
  // operation typically have much smaller sizes than the first (operand_no==0)
  // operand. It is necessary to downgrade the importance of the smaller
  // operands, so that the overall layout choice of the operation is dictated
  // by operand 0 when possible.
  if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice &&
      operand_no > 0 && !mandatory &&
      priority > LayoutConstraint::kDefaultPriority) {
    dfs = false;
    priority--;
  } else if (instruction->opcode() == HloOpcode::kReshape && !mandatory &&
             instruction->operand(0)->opcode() == HloOpcode::kDynamicSlice) {
    // Similarly downgrade reshapes fed by dynamic-slice.
    dfs = false;
    priority--;
  }
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << " : priority = " << priority << "; mandatory = " << mandatory
          << "; dfs = " << dfs << "\n";
  const OperandLayoutConstraint* curr_shape_layout =
      constraints.GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout, /*minor_to_major_only=*/true)) {
      // New constraint matches existing constraint. Nothing to do.
      return OkStatus();
    }
    // An existing mandatory or higher-priority constraint wins; drop the
    // incoming request.
    if (curr_shape_layout->mandatory() && !mandatory) {
      VLOG(5)
          << "Existing layout is mandatory but the new one is not. Skipping.\n";
      return OkStatus();
    } else if (curr_shape_layout->priority() > priority) {
      VLOG(5) << "Existing layout has higher priority: "
              << curr_shape_layout->priority() << " vs " << priority << "\n";
      return OkStatus();
    }
  }
  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint changes
  // layouts beyond this immediate use and is complicated to handle.
  if (AnyOperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constraint layout of operand %d of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name());
  }

  // Install the new constraint and queue it for propagation.
  OperandLayoutConstraint new_constraint(ShapeLayout(shape_with_layout),
                                         instruction, operand_no, mandatory,
                                         dfs, priority);
  auto op_constraint = constraints.InsertOperandLayoutConstraint(
      instruction, operand_no, new_constraint);
  PushAddedConstraints(op_constraint);
  return OkStatus();
}
297
298 OperandLayoutConstraint*
InsertOperandLayoutConstraint(const HloInstruction * instruction,int64_t operand_no,const OperandLayoutConstraint & constraint)299 LayoutAssignment::LayoutConstraints::InsertOperandLayoutConstraint(
300 const HloInstruction* instruction, int64_t operand_no,
301 const OperandLayoutConstraint& constraint) {
302 auto key = std::make_pair(instruction, operand_no);
303 auto iter = operand_constraints_.find(key);
304 if (iter == operand_constraints_.end()) {
305 auto pair = std::make_pair(key, constraint);
306 iter = operand_constraints_.insert(pair).first;
307 } else {
308 iter->second = constraint;
309 }
310 return &iter->second;
311 }
312
PushAddedConstraints(const LayoutConstraint * constraint)313 void LayoutAssignment::PushAddedConstraints(
314 const LayoutConstraint* constraint) {
315 if (!constraint->dfs()) {
316 // Insert a new constraint to the first location where it's strictly greater
317 // than all the subsequent constraints. Assumes invariant that the list is
318 // sorted.
319 auto it = absl::c_upper_bound(
320 added_constraints_, constraint,
321 [&](const LayoutConstraint* a, const LayoutConstraint* b) {
322 return a->priority() > b->priority();
323 });
324 added_constraints_.insert(it, constraint);
325 } else {
326 added_constraints_.push_back(constraint);
327 }
328 }
329
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64_t operand_no,bool mandatory,bool dfs,int64_t priority)330 Status LayoutAssignment::SetArrayOperandLayout(
331 const Layout& layout, const HloInstruction* instruction, int64_t operand_no,
332 bool mandatory, bool dfs, int64_t priority) {
333 const HloInstruction* operand = instruction->operand(operand_no);
334 TF_RET_CHECK(operand->shape().IsArray());
335 Shape shape(operand->shape());
336 *shape.mutable_layout() = layout;
337 TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
338 return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs,
339 priority);
340 }
341
// Records `shape_with_layout` as the computation's required result layout and
// queues the computation constraint on `assignment`'s propagation worklist.
Status LayoutAssignment::LayoutConstraints::SetResultLayout(
    LayoutAssignment* assignment, const Shape& shape_with_layout,
    int64_t priority) {
  // NOTE(review): the log line intentionally carries no function-name prefix
  // in the original; kept as-is.
  VLOG(3) << "  : " << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << "; priority = " << priority << ".\n";

  computation_constraint_.ResetResultLayout(ShapeLayout(shape_with_layout),
                                            priority);
  assignment->PushAddedConstraints(&computation_constraint_);
  return OkStatus();
}
353
SetInstructionLayout(const Layout & layout,const HloInstruction * instruction,bool mandatory,bool dfs,bool allow_alias,int64_t priority)354 Status LayoutAssignment::SetInstructionLayout(const Layout& layout,
355 const HloInstruction* instruction,
356 bool mandatory, bool dfs,
357 bool allow_alias,
358 int64_t priority) {
359 if (priority < 0) {
360 priority = current_priority_;
361 }
362 auto RequiresSameShapeForAllOutput = [](const HloInstruction* op) -> bool {
363 switch (op->opcode()) {
364 case HloOpcode::kSort:
365 case HloOpcode::kReduce:
366 case HloOpcode::kReduceWindow:
367 return true;
368 default:
369 return false;
370 }
371 };
372 CHECK(instruction->shape().IsArray() ||
373 RequiresSameShapeForAllOutput(instruction));
374
375 return ShapeUtil::ForEachSubshapeWithStatus(
376 instruction->shape(),
377 [this, layout, instruction, mandatory, allow_alias, priority](
378 const Shape& subshape, const ShapeIndex& index) -> Status {
379 auto buffers =
380 points_to_analysis_->GetPointsToSet(instruction).element(index);
381 CHECK_EQ(1, buffers.size());
382 if (!allow_alias) {
383 CHECK_EQ(buffers[0]->instruction(), instruction);
384 }
385 if (subshape.IsArray()) {
386 return SetBufferLayout(layout, *buffers[0], mandatory,
387 /*dfs=*/true, priority);
388 } else {
389 return OkStatus();
390 }
391 });
392 }
393
// Constrains every array buffer in `instruction`'s output to the layouts
// carried by the corresponding subshapes of `shape_with_layout`.  Fails if
// the shapes are incompatible.  Note: the `dfs` parameter only affects the
// log line; SetBufferLayout is always called with dfs=true (as written).
Status LayoutAssignment::SetInstructionLayout(const Shape& shape_with_layout,
                                              const HloInstruction* instruction,
                                              bool mandatory, bool dfs,
                                              bool allow_alias,
                                              int64_t priority) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << ": priority = " << priority << " : mandatory = " << mandatory
          << "; dfs = " << dfs << "\n";

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory, allow_alias, priority](
          const Shape& subshape, const ShapeIndex& index) -> Status {
        auto buffers =
            points_to_analysis_->GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        // Unless aliasing is allowed, the buffer must be defined by the
        // instruction itself.
        if (!allow_alias) {
          CHECK_EQ(buffers[0]->instruction(), instruction);
        }

        // Tuple (non-array) subshapes and layout-less subshapes are skipped.
        if (subshape.IsArray() && subshape.has_layout()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory,
                                 /*dfs=*/true, priority);
        } else {
          return OkStatus();
        }
      });
}
432
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const433 const BufferLayoutConstraint* LayoutAssignment::GetBufferLayoutConstraint(
434 const LogicalBuffer& buffer) const {
435 auto it = buffer_constraints_.find(&buffer);
436 return it == buffer_constraints_.end() ? nullptr : &it->second;
437 }
438
OperandLayout(const HloInstruction * instruction,int64_t operand_no) const439 const ShapeLayout* LayoutAssignment::LayoutConstraints::OperandLayout(
440 const HloInstruction* instruction, int64_t operand_no) const {
441 if (const auto* constraint =
442 GetOperandLayoutConstraint(instruction, operand_no)) {
443 return &constraint->shape_layout();
444 }
445 return nullptr;
446 }
447
448 const OperandLayoutConstraint*
GetOperandLayoutConstraint(const HloInstruction * instruction,int64_t operand_no) const449 LayoutAssignment::LayoutConstraints::GetOperandLayoutConstraint(
450 const HloInstruction* instruction, int64_t operand_no) const {
451 auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
452 return it == operand_constraints_.end() ? nullptr : &it->second;
453 }
454
ResultLayout() const455 const ShapeLayout* LayoutAssignment::LayoutConstraints::ResultLayout() const {
456 return (computation_->IsEntryComputation() ||
457 computation_constraint_.result_layout_is_set())
458 ? &computation_layout().result_layout()
459 : nullptr;
460 }
461
ToString(const LayoutConstraints & constraints) const462 std::string LayoutAssignment::ToString(
463 const LayoutConstraints& constraints) const {
464 std::string output;
465 absl::StrAppend(&output, "LayoutConstraints for computation ",
466 constraints.computation()->name(), "\n");
467 for (auto* instruction :
468 constraints.computation()->MakeInstructionPostOrder()) {
469 absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
470 for (int64_t i = 0; i < instruction->operand_count(); ++i) {
471 if (constraints.OperandLayout(instruction, i) != nullptr) {
472 absl::StrAppend(
473 &output, " operand (", i,
474 "): ", constraints.OperandLayout(instruction, i)->ToString(), "\n");
475 }
476 }
477 for (const LogicalBuffer* buffer :
478 points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) {
479 auto* buffer_constraint = GetBufferLayoutConstraint(*buffer);
480 if (buffer_constraint != nullptr) {
481 absl::StrAppend(&output, " ", buffer->ToString(), " : ",
482 LayoutUtil::HumanString(buffer_constraint->layout()),
483 "\n");
484 }
485 }
486 }
487
488 absl::StrAppend(&output, " => ",
489 constraints.computation_constraint().ToString(), "\n");
490 return output;
491 }
492
493 namespace {
494
IsHostSendRecv(const HloInstruction * instruction)495 bool IsHostSendRecv(const HloInstruction* instruction) {
496 const HloSendRecvInstruction* send_recv_instr =
497 DynCast<HloSendRecvInstruction>(instruction);
498 return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
499 }
500
501 } // namespace
502
// Scans `computation` for host-transfer Send/Recv instructions and records
// each channel's data layout in host_channel_constraints_.  Returns an error
// if a channel's layout was already constrained differently.
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    // Only host-transfer send/recv instructions participate.
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    // (SendDone/RecvDone are skipped by the opcode check below.)
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // The payload is tuple element 0 of the send/recv shape; it must be a
      // laid-out array.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      // ConstrainChannel returns the previous layout if the channel was
      // already constrained; any prior constraint is an error here.
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          *send_recv_instr->channel_id(), data_shape.layout());
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return OkStatus();
}
529
530 namespace {
531
IsLayoutConstrainedCustomCall(HloInstruction * instruction)532 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
533 const HloCustomCallInstruction* custom_call =
534 DynCast<HloCustomCallInstruction>(instruction);
535 return custom_call != nullptr && custom_call->layout_constrained();
536 }
537
IsLayoutConstrainedCollective(const HloInstruction * instruction)538 bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
539 const HloCollectiveInstruction* collective =
540 DynCast<HloCollectiveInstruction>(instruction);
541 return collective != nullptr && collective->constrain_layout();
542 }
543
// Recursively pushes a parameter's known layout (`shape`) onto the users of
// `instruction` as non-mandatory, non-DFS constraints.  Recurses through
// get-tuple-element users with the matching tuple-element subshape.
Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
                                       const Shape& shape,
                                       LayoutAssignment* constraints) {
  for (auto* user : instruction->users()) {
    // Excluding tuple operations as they do not participate in layout
    // propagations (they do not create or alias buffers).
    if (user->opcode() == HloOpcode::kTuple) {
      continue;
    }
    VLOG(3) << "Setting user layout : " << user->ToString();
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // GTE: constrain the user to the selected tuple element's layout and
      // keep propagating below it.
      auto tuple_index = user->tuple_index();
      CHECK(shape.IsTuple());
      auto elem_shape = shape.tuple_shapes(tuple_index);
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          elem_shape, user, /*mandatory=*/false, /*dfs=*/false,
          /*allow_alias=*/true));
      TF_RETURN_IF_ERROR(
          PropagateParameterLayoutToUsers(user, elem_shape, constraints));
    } else {
      // Other users: constrain only the operand slot that consumes
      // `instruction`.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          shape, user, user->operand_index(instruction), /*mandatory=*/false,
          /*dfs=*/false));
    }
  }
  return OkStatus();
}
571
572 } // namespace
573
AddMandatoryConstraints(ChannelLayoutConstraints * channel_constraints,LayoutConstraints * constraints)574 Status LayoutAssignment::AddMandatoryConstraints(
575 ChannelLayoutConstraints* channel_constraints,
576 LayoutConstraints* constraints) {
577 VLOG(2) << "Adding mandatory layout constraints to computation "
578 << constraints->computation()->name();
579
580 auto get_channel_constraints = [&](const HloInstruction* instruction) {
581 return IsHostSendRecv(instruction) ? &host_channel_constraints_
582 : channel_constraints;
583 };
584
585 // Constrain layouts of instructions which define values with pre-existing
586 // layouts.
587 for (auto* instruction : constraints->computation()->instructions()) {
588 if (instruction->opcode() == HloOpcode::kInfeed) {
589 // Infeed layouts must match the layout of the original inserted
590 // instruction.
591 // TODO(b/31425034): Change infeeds to be more like parameters, with
592 // shapes in the ComputationLayout.
593 TF_RETURN_IF_ERROR(SetInstructionLayout(instruction->shape(), instruction,
594 /*mandatory=*/true, /*dfs=*/true,
595 /*allow_alias=*/false));
596 } else if (instruction->opcode() == HloOpcode::kOutfeed) {
597 // Constrain the input to the Outfeed instruction to be the expected
598 // layout of the Outfeed.
599 TF_RETURN_IF_ERROR(SetOperandLayout(instruction->outfeed_shape(),
600 instruction, 0,
601 /*mandatory=*/true, /*dfs=*/true));
602 } else if (instruction->opcode() == HloOpcode::kParameter) {
603 if (reverse_computation_order_ ||
604 (constraints->computation()->IsEntryComputation() &&
605 entry_computation_layout_->LayoutIsSet()) ||
606 (conditional_mismatch_.count(constraints->computation()) == 0 &&
607 constraints->computation_constraint().parameter_layout_is_set())) {
608 const ShapeLayout& parameter_layout =
609 constraints->computation_layout().parameter_layout(
610 instruction->parameter_number());
611 // Parameter layouts must match the respective layout in
612 // ComputationLayout, if there is one.
613 TF_RETURN_IF_ERROR(
614 SetInstructionLayout(parameter_layout.shape(), instruction));
615 if (reverse_computation_order_) {
616 TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
617 instruction, parameter_layout.shape(), this));
618 }
619 }
620 } else if (IsLayoutConstrainedCustomCall(instruction)) {
621 const HloCustomCallInstruction* custom_call =
622 DynCast<HloCustomCallInstruction>(instruction);
623
624 TF_RETURN_IF_ERROR(SetInstructionLayout(custom_call->shape(), custom_call,
625 /*mandatory=*/true, /*dfs=*/true,
626 /*allow_alias=*/true));
627
628 for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
629 if (AnyOperandBufferForwarded(custom_call, i)) {
630 TF_RET_CHECK(AllOperandBuffersForwarded(custom_call, i))
631 << "Partial alias of an operand is not supported";
632 } else {
633 TF_RETURN_IF_ERROR(SetOperandLayout(
634 custom_call->operand_shapes_with_layout()[i], custom_call, i));
635 }
636 }
637 } else if (IsLayoutConstrainedCollective(instruction)) {
638 TF_RETURN_IF_ERROR(
639 SetInstructionLayout(instruction->shape(), instruction));
640 } else if (instruction->IsCrossModuleAllReduce()) {
641 CHECK(get_channel_constraints(instruction))
642 << "Multi-module layout assignment requires ChannelLayoutConstraints";
643 int64_t channel_id = instruction->channel_id().value();
644 if (!get_channel_constraints(instruction)
645 ->IsChannelConstrained(channel_id)) {
646 continue;
647 }
648 // TODO(b/68493863): Change to use SetOperandLayout().
649 const Shape& buffer_shape = instruction->operand(0)->shape();
650 TF_RET_CHECK(buffer_shape.IsArray());
651 Shape new_buffer_shape =
652 get_channel_constraints(instruction)
653 ->LayoutShapeForChannel(buffer_shape, channel_id);
654 TF_RETURN_IF_ERROR(SetInstructionLayout(new_buffer_shape, instruction));
655 }
656 }
657
658 // Constrain layouts of instructions which call computations which have
659 // already been assigned layouts. Instructions which call computations in a
660 // parallel element-wise context (eg, map or reduce) do not need layout
661 // constraints because they operate on scalars.
662 for (auto* instruction : constraints->computation()->instructions()) {
663 if (instruction->opcode() == HloOpcode::kCall &&
664 computation_layouts_.find(instruction->to_apply()) !=
665 computation_layouts_.end()) {
666 // kCall instruction operands and output must match the ComputationLayout
667 // of the called computation.
668 const ComputationLayout& called_computation_layout =
669 FindOrDie(computation_layouts_, instruction->to_apply())
670 ->computation_layout();
671 TF_RETURN_IF_ERROR(SetInstructionLayout(
672 called_computation_layout.result_layout().shape(), instruction));
673 TF_RET_CHECK(instruction->operand_count() ==
674 called_computation_layout.parameter_count());
675 for (int64_t i = 0; i < instruction->operand_count(); ++i) {
676 TF_RETURN_IF_ERROR(SetOperandLayout(
677 called_computation_layout.parameter_layout(i).shape(), instruction,
678 i, /*mandatory=*/true, /*dfs=*/true));
679 }
680 } else if (instruction->opcode() == HloOpcode::kWhile &&
681 computation_layouts_.find(instruction->while_body()) !=
682 computation_layouts_.end()) {
683 // Layout of input and output of kWhile instruction must be equal and must
684 // match both input and output of body computation. Also, the input of
685 // condition computation must match kWhile layout.
686 HloComputation* body = instruction->while_body();
687 HloComputation* condition = instruction->while_condition();
688 const HloInstruction* init = instruction->operand(0);
689 ComputationLayoutConstraint* body_constraint =
690 mutable_computation_constraints(body)
691 ->mutable_computation_constraint();
692 ComputationLayout body_layout = body_constraint->computation_layout();
693 ComputationLayoutConstraint* condition_constraint =
694 mutable_computation_constraints(condition)
695 ->mutable_computation_constraint();
696 ComputationLayout condition_layout =
697 condition_constraint->computation_layout();
698
699 // Check a few invariants irrespective of layout.
700 CHECK_EQ(1, instruction->operand_count());
701 CHECK_EQ(1, body->num_parameters());
702 CHECK_EQ(1, condition->num_parameters());
703 DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
704 body_layout.parameter_shape(0)));
705 DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
706 condition_layout.parameter_shape(0)));
707 DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
708
709 if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
710 VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
711 << " while=" << instruction->name()
712 << " shape=" << body_layout.result_layout().ToString();
713 *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
714 body_constraint->ResetComputationLayout(
715 body_layout, current_priority_ + kNumberOfPropagationRounds,
716 /*prop_result_layout=*/true,
717 /*prop_parameter_layout=*/true);
718 }
719 if (condition_layout.parameter_layout(0) !=
720 body_layout.parameter_layout(0)) {
721 VLOG(2) << "Reset %while condition parameter layout: cond="
722 << condition->name() << " while=" << instruction->name()
723 << " shape=" << body_layout.parameter_layout(0).ToString();
724 *condition_layout.mutable_parameter_layout(0) =
725 body_layout.parameter_layout(0);
726 condition_constraint->ResetComputationLayout(
727 condition_layout, current_priority_ + kNumberOfPropagationRounds,
728 /*prop_result_layout=*/true, /*prop_parameter_layout=*/true);
729 }
730
731 // Constrain the output and the operand of the while instruction to match
732 // the computations.
733 TF_RETURN_IF_ERROR(
734 SetOperandLayout(body_layout.result_shape(), instruction, 0));
735 TF_RETURN_IF_ERROR(
736 SetInstructionLayout(body_layout.result_shape(), instruction));
737 } else if (instruction->opcode() == HloOpcode::kConditional &&
738 computation_layouts_.find(instruction->branch_computation(0)) !=
739 computation_layouts_.end()) {
740 // Find the conditional branch with the most instructions and force all
741 // other computations to match that layout. A potentially better decision
742 // could count the number FLOPs or how constrained the layouts are.
743 int64_t largest_branch = -1;
744 int64_t largest_instruction_count = 0;
745 for (int j = 0; j < instruction->branch_count(); ++j) {
746 const int64_t instruction_count =
747 instruction->branch_computation(j)->instruction_count();
748 if (instruction_count > largest_instruction_count &&
749 !ShapeUtil::IsEmptyTuple(instruction->operand(j + 1)->shape())) {
750 largest_branch = j;
751 largest_instruction_count = instruction_count;
752 }
753 }
754 if (largest_branch == -1) {
755 largest_branch = 0;
756 }
757 const ComputationLayout& best_branch_computation_layout =
758 mutable_computation_constraints(
759 instruction->branch_computation(largest_branch))
760 ->computation_layout();
761 for (int k = 0; k < instruction->branch_count(); ++k) {
762 // Visit the best branch first.
763 int j = (k + largest_branch) % instruction->branch_count();
764 TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
765 ComputationLayout branch_computation_layout =
766 mutable_computation_constraints(instruction->branch_computation(k))
767 ->computation_layout();
768 if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
769 best_branch_computation_layout.result_layout().shape(),
770 /*minor_to_major_only=*/true)) {
771 *branch_computation_layout.mutable_result_layout() =
772 best_branch_computation_layout.result_layout();
773 InsertOrDie(&conditional_mismatch_,
774 instruction->branch_computation(k),
775 branch_computation_layout);
776 } else {
777 TF_RETURN_IF_ERROR(SetOperandLayout(
778 branch_computation_layout.parameter_shape(0), instruction, k + 1,
779 /*mandatory=*/true, /*dfs=*/true));
780 }
781 }
782 TF_RETURN_IF_ERROR(
783 SetOperandLayout(best_branch_computation_layout.parameter_shape(0),
784 instruction, largest_branch + 1,
785 /*mandatory=*/true, /*dfs=*/true));
786 TF_RETURN_IF_ERROR(SetInstructionLayout(
787 best_branch_computation_layout.result_shape(), instruction,
788 /*mandatory=*/true, /*dfs=*/true, /*allow_alias=*/false));
789 }
790 }
791 // Finally set the result layout to match ComputationLayout, if there is one.
792 if (conditional_mismatch_.count(constraints->computation()) > 0) {
793 VLOG(5) << "Setting mismatching conditional result:"
794 << constraints->computation()->name() << "\n";
795 TF_RETURN_IF_ERROR(constraints->SetResultLayout(
796 this,
797 FindOrDie(conditional_mismatch_, constraints->computation())
798 .result_layout()
799 .shape(),
800 current_priority_ + kNumberOfPropagationRounds));
801 } else if (reverse_computation_order_ ||
802 (constraints->computation()->IsEntryComputation() &&
803 entry_computation_layout_->LayoutIsSet()) ||
804 current_priority_ > LayoutConstraint::kBeginningPriority) {
805 const ShapeLayout* result_layout = constraints->ResultLayout();
806 if (result_layout != nullptr) {
807 VLOG(2) << "Setting computation result layout.\n";
808 PushAddedConstraints(&constraints->computation_constraint());
809 } else {
810 VLOG(2) << "Computation result layout is not set.\n";
811 }
812 }
813 return OkStatus();
814 }
815
816 namespace {
817
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)818 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
819 if (!lhs.has_layout() && !rhs.has_layout()) {
820 return true;
821 }
822 CHECK(lhs.has_layout() && rhs.has_layout());
823 return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
824 }
825
826 // The operands of a call must match the layouts of parameters in the
827 // ComputationLayout, and the call instruction itself must match the result
828 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)829 Status CheckCallLayout(HloInstruction* call,
830 const ComputationLayout& computation_layout) {
831 HloComputation* computation = call->to_apply();
832 TF_RET_CHECK(computation->num_parameters() == call->operand_count());
833 for (int64_t i = 0; i < computation->num_parameters(); ++i) {
834 TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
835 call->operand(i)->shape(), /*minor_to_major_only=*/true));
836 }
837 TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
838 call->shape(), /*minor_to_major_only=*/true));
839 return OkStatus();
840 }
841
842 // Operands of layout-constrained custom calls must match the expected
843 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)844 Status CheckCustomCallLayout(HloInstruction* instruction) {
845 if (IsLayoutConstrainedCustomCall(instruction)) {
846 const HloCustomCallInstruction* custom_call =
847 DynCast<HloCustomCallInstruction>(instruction);
848 for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
849 TF_RET_CHECK(
850 LayoutsInShapesEqual(custom_call->operand(i)->shape(),
851 custom_call->operand_shapes_with_layout()[i]));
852 }
853 }
854 return OkStatus();
855 }
856
857 // For a while instruction, all the following layouts must be the same:
858 // (1) init operand
859 // (2) condition computation parameter
860 // (3) body computation parameter
861 // (4) body computation result
862 // (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)863 Status CheckWhileLayout(HloInstruction* while_inst,
864 const ComputationLayout& condition_computation_layout,
865 const ComputationLayout& body_computation_layout) {
866 auto init_shape = while_inst->operand(0)->shape();
867 TF_RET_CHECK(
868 condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
869 init_shape, /*minor_to_major_only=*/true));
870 TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
871 init_shape, /*minor_to_major_only=*/true));
872 TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
873 init_shape, /*minor_to_major_only=*/true));
874 TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
875 return OkStatus();
876 }
877
CheckOptimizationBarrierLayout(HloInstruction * inst)878 Status CheckOptimizationBarrierLayout(HloInstruction* inst) {
879 TF_RET_CHECK(LayoutsInShapesEqual(inst->operand(0)->shape(), inst->shape()));
880 return OkStatus();
881 }
882
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)883 Status CheckConditionalLayout(
884 HloInstruction* instruction,
885 absl::Span<const ComputationLayout> branch_computation_layouts) {
886 for (int j = 0; j < instruction->branch_count(); ++j) {
887 const HloInstruction* branch_operand = instruction->operand(j + 1);
888 TF_RET_CHECK(
889 branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
890 branch_computation_layouts[j].result_layout().shape(),
891 /*minor_to_major_only=*/true));
892 TF_RET_CHECK(
893 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
894 instruction->shape(), /*minor_to_major_only=*/true));
895 TF_RET_CHECK(
896 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
897 instruction->branch_computation(j)->root_instruction()->shape(),
898 /*minor_to_major_only=*/true))
899 << j << ":"
900 << instruction->branch_computation(j)->root_instruction()->ToString();
901 TF_RET_CHECK(
902 branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
903 branch_operand->shape(), /*minor_to_major_only=*/true));
904 }
905 return OkStatus();
906 }
907
908 // Fusion parameters must match the layout of the fusion instructions operands,
909 // and the root of the fusion expression must match the layout of the fusion
910 // instruction.
CheckFusionLayout(HloInstruction * fusion)911 Status CheckFusionLayout(HloInstruction* fusion) {
912 TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
913
914 TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
915 fusion->fused_expression_root()->shape()));
916 for (int64_t i = 0; i < fusion->operand_count(); ++i) {
917 TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
918 fusion->operand(i)->shape()));
919 }
920 return OkStatus();
921 }
922
923 // The layout of a parameter must match the respective layout in the
924 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)925 Status CheckParameterLayout(HloInstruction* parameter,
926 const ComputationLayout& computation_layout) {
927 const ShapeLayout& parameter_layout =
928 computation_layout.parameter_layout(parameter->parameter_number());
929 return ShapeUtil::ForEachSubshapeWithStatus(
930 parameter_layout.shape(),
931 [&](const Shape& subshape, const ShapeIndex& shape_index) {
932 if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
933 !subshape.has_layout()) {
934 return OkStatus();
935 }
936 if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
937 subshape,
938 ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
939 return InternalError(
940 "parameter instruction %s does not match layout of computation "
941 "shape: %s",
942 parameter->ToString(), parameter_layout.ToString());
943 }
944 return OkStatus();
945 });
946 }
947
948 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)949 Status CheckConstantLayout(HloInstruction* constant) {
950 if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
951 return InternalError(
952 "constant instruction %s does not match the layout of its literal %s",
953 constant->ToString(),
954 ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
955 }
956 return OkStatus();
957 }
958
// Projects a broadcast's output layout onto its operand: applies `layout` to
// the broadcast's output shape, then keeps only the broadcast dimensions
// (those listed in hlo->dimensions()), yielding the layout the operand
// should carry.
Layout GetBroadcastLayoutFromOutput(const Layout& layout,
                                    const HloInstruction* hlo) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kBroadcast);
  Shape output_shape = hlo->shape();
  *output_shape.mutable_layout() = layout;
  const auto is_broadcast_dim = [&](int64_t dim) {
    return absl::c_linear_search(hlo->dimensions(), dim);
  };
  return ShapeUtil::FilterDimensions(is_broadcast_dim, output_shape).layout();
}
971
CheckBroadcastLayout(HloInstruction * broadcast)972 Status CheckBroadcastLayout(HloInstruction* broadcast) {
973 CHECK_EQ(broadcast->opcode(), HloOpcode::kBroadcast);
974 Shape shape = ShapeUtil::FilterDimensions(
975 [&](int64_t dim) {
976 return absl::c_linear_search(broadcast->dimensions(), dim);
977 },
978 broadcast->shape());
979 if (!LayoutsInShapesEqual(shape, broadcast->operand(0)->shape())) {
980 return InternalError(
981 "broadcast instruction %s does not match the layout of its operand %s",
982 broadcast->ToString(), broadcast->operand(0)->ToString());
983 }
984 return OkStatus();
985 }
986
987 } // namespace
988
// Creates a copy of `instruction` whose result carries the layout of
// `shape_with_layout`. Arrays get a single kCopy. Tuples are deep-copied:
// each element is extracted with a GetTupleElement and recursively copied
// only when its layout differs from the target, then the elements are
// reassembled into a new Tuple whose layout mirrors `shape_with_layout`.
// Fails for shapes that are neither arrays nor tuples.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  // The target shape must be fully laid out and structurally compatible with
  // the instruction's shape.
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // Rebuild the tuple's layout from scratch so it matches
    // shape_with_layout at every subshape.
    // NOTE(review): only the array-level kCopy instructions created in the
    // recursion go through RegisterAddedCopy; the Tuple itself does not.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    // Give the copy the requested layout.
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    // Neither tuple nor array — nothing sensible to copy.
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
1045
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
//
// Special case: when the instruction is a conditional and the mismatching
// operand feeds exactly one branch (and is used nowhere else), the copy is
// placed inside that branch computation instead of in front of the
// conditional, so the other branches do not pay for it.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64_t operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(2) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return OkStatus();
  }
  VLOG(2) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (!reverse_computation_order_ &&
      instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // Operand 0 is the predicate/index; operand k feeds branch k - 1.
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Let the parameter take the operand's current (mismatching) layout, then
    // materialize the requested layout with a copy inside the branch and
    // redirect all former users of the parameter to that copy.
    *param->mutable_shape() = operand->shape();
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(2) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    // If the parameter itself was the branch root, the copy becomes the root.
    if (param == branch_comp->root_instruction()) {
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }

    // The branch's program shape changed, so record a refreshed computation
    // layout without propagating it further.
    ComputationLayout computed_computation_layout(
        branch_comp->ComputeProgramShape(),
        /*ignore_layouts=*/false);
    mutable_computation_constraints(branch_comp)
        ->mutable_computation_constraint()
        ->ResetComputationLayout(computed_computation_layout,
                                 current_priority_ + 1,
                                 /* prop_result_layout=*/false,
                                 /*prop_parameter_layout=*/false);
    return OkStatus();
  }

  // General case: copy the operand in front of the instruction.
  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
1106
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)1107 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
1108 HloInstruction* copy,
1109 const ShapeIndex& index) {
1110 if (instruction.has_sharding()) {
1111 // If the index is empty, we want to copy the whole sharding, in case the
1112 // sharding is a tuple sharding.
1113 HloSharding sharding =
1114 !index.empty() && instruction.sharding().IsTuple()
1115 ? instruction.sharding().GetSubSharding(instruction.shape(), index)
1116 : instruction.sharding();
1117 // We propagate the sharding to the copied instruction only if it is a
1118 // special sharding, like tiled ones.
1119 // Otherwise it is preferable to leave the new instruction without device,
1120 // and let the automatic device placer to choose the best location.
1121 auto device = sharding.UniqueDevice();
1122 if (!device || HloSharding::IsReservedDevice(*device)) {
1123 copy->set_sharding(sharding);
1124 }
1125 }
1126 copy->set_metadata(instruction.metadata());
1127 }
1128
// Post-pass verification. For every instruction in every non-fusion
// computation (restricted to `execution_threads`) this checks that:
//  * the instruction has a layout and its shape is valid;
//  * each leaf subshape of the output matches the layout of every logical
//    buffer that may be its source, per tuple points-to analysis;
//  * opcode-specific layout invariants hold (call, custom-call, fusion,
//    parameter, broadcast, constant, while, optimization-barrier,
//    conditional);
//  * the entry computation's result layout, if set, matches the module's
//    result shape.
Status LayoutAssignment::CheckLayouts(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              for (const LogicalBuffer* buffer : buffers) {
                // Dynamic dimensions are ignored; only minor-to-major order
                // is compared.
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return OkStatus();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())
                  ->computation_layout()));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())
                  ->computation_layout()));
          break;
        case HloOpcode::kBroadcast:
          TF_RETURN_IF_ERROR(CheckBroadcastLayout(instruction));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition())
                  ->computation_layout(),
              FindOrDie(computation_layouts_, instruction->while_body())
                  ->computation_layout()));
          break;
        case HloOpcode::kOptimizationBarrier:
          TF_RETURN_IF_ERROR(CheckOptimizationBarrierLayout(instruction));
          break;
        case HloOpcode::kConditional: {
          // Collect each branch's ComputationLayout for a joint check.
          std::vector<ComputationLayout> branch_computation_layouts;
          const auto& branch_computations = instruction->branch_computations();
          branch_computation_layouts.reserve(branch_computations.size());
          for (const auto branch_computation : branch_computations) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation)
                    ->computation_layout());
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          ->computation_layout()
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return OkStatus();
}
1239
// Constructs the pass.
// `entry_computation_layout` holds the (possibly only partially set) layouts
// of the entry computation's parameters and result; a copy is saved so the
// original can be restored if the pass has to undo its side effects.
// `channel_constraints` may be null; when provided, a snapshot is kept for
// the same reason.
// `reverse_computation_order` toggles the alternate traversal direction used
// elsewhere in the pass (see uses of reverse_computation_order_).
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    ChannelLayoutConstraints* channel_constraints,
    bool reverse_computation_order)
    : entry_computation_layout_(entry_computation_layout),
      saved_entry_computation_layout_(*entry_computation_layout),
      reverse_computation_order_(reverse_computation_order),
      channel_layout_constraints_(channel_constraints) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1256
// Suggests a layout for operand `operand_no` of `instruction`, given that the
// instruction's (array-shaped) output uses `output_layout`. Returns nullptr
// when no preference can be derived. Used when propagating a result layout
// constraint down to an instruction's operands.
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64_t operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);
  CHECK(instruction->shape().IsArray());
  CHECK(operand->shape().IsArray());
  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == instruction->shape().rank() &&
      !InstructionCanChangeLayoutInstance(instruction)) {
    // Propagate the result layout to the operand layout if the instruction
    // requires the same layout for the result and the operand.
    //
    // For elementwise operations, using the same layout for the operands and
    // the result also has the following benefits:
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    return std::make_unique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }

    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), output_shape.dimensions(),
        LayoutUtil::MinorToMajor(output_layout));
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    // AlignLayouts yields a laid-out operand shape that makes the reshape a
    // bitcast, if such a layout exists.
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return std::make_unique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
    int64_t rank = instruction->shape().rank();
    std::vector<int64_t> new_minor_to_major(rank);
    for (int64_t i = 0; i < rank; ++i) {
      // The i-th minor output dimension corresponds to operand dimension
      // dimensions(output_dim).
      int64_t output_dim = LayoutUtil::Minor(output_layout, i);
      int64_t operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return std::make_unique<Layout>(operand_layout);
  }

  return nullptr;
}
1323
// Derives the layout of a reduce's result from the layout of its first
// operand: applies `operand_layout` to the operand shape and removes the
// reduced dimensions.
static Layout GetReduceLayoutFromOperand(const Layout& operand_layout,
                                         const HloInstruction* hlo) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
  Shape shape_with_layout = hlo->operand(0)->shape();
  *shape_with_layout.mutable_layout() = operand_layout;
  return ShapeUtil::DeleteDimensions(hlo->dimensions(), shape_with_layout)
      .layout();
}
1332
// Suggests a layout for `user`'s output given that its operand `operand_no`
// uses `operand_layout`. Returns nullptr when no preference can be derived.
// Used when propagating an operand layout constraint up to the user's result.
std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64_t operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  // Enforce standard layout on variadic reduction output to avoid having two
  // inconsistent layouts.
  if (user->opcode() == HloOpcode::kReduce && user->shape().IsTuple()) {
    return std::make_unique<Layout>(
        GetReduceLayoutFromOperand(operand_layout, user));
  }

  CHECK(user->shape().IsArray() && operand->shape().IsArray())
      << "Fails on instruction: " << user->ToString();

  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == user->shape().rank() &&
      !InstructionCanChangeLayoutInstance(user)) {
    // Assign users the same layout as the operand.
    return std::make_unique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the output's layout to the operand.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(), operand->shape().dimensions(),
        LayoutUtil::MinorToMajor(operand_layout));
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    // AlignLayouts yields a laid-out output shape that makes the reshape a
    // bitcast, if such a layout exists.
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return std::make_unique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
    int64_t rank = user->shape().rank();
    std::vector<int64_t> new_minor_to_major(rank);
    // Invert the transpose permutation to map operand dimensions back to
    // output dimensions.
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64_t i = 0; i < rank; ++i) {
      int64_t operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64_t user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return std::make_unique<Layout>(user_layout);
  }

  return nullptr;
}
1400
// Drains the queue of pending layout constraints, propagating each one to its
// neighbors until no new constraints are produced.
Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
  // Gathers all initial constraints in a worklist and propagates them in
  // depth-first order. DFS order seems to be better than BFS because a
  // constraint is propagated as far as possible before propagating unrelated
  // constraints which makes it less likely that conflicting constraints will be
  // propagated to instructions. However, we should experiment with other orders
  // too.
  std::deque<const LayoutConstraint*> worklist;

  // Lambda for moving newly added constraints to the worklist.
  auto add_new_constraints_to_worklist = [this, &worklist]() {
    // Add constraints to the front of the deque for DFS ordering.
    for (auto* constraint : ConsumeAddedConstraints()) {
      if (constraint->dfs()) {
        worklist.push_front(constraint);
      } else {
        // Non-DFS constraints go to the back and are handled breadth-first.
        VLOG(3) << "push back constraint for propagation : "
                << constraint->ToString();
        worklist.push_back(constraint);
      }
    }
  };
  add_new_constraints_to_worklist();

  while (!worklist.empty()) {
    const LayoutConstraint* layout_constraint = worklist.front();
    worklist.pop_front();
    VLOG(2) << "Propagating " << layout_constraint->ToString()
            << " to its neighbors with priority = "
            << layout_constraint->priority() << "\n";
    // Dispatch on the dynamic constraint type: buffer, operand, or whole
    // computation result.
    if (auto* buffer_constraint =
            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateBufferConstraint(*buffer_constraint, constraints));
    } else if (auto* operand_constraint =
                   dynamic_cast<const OperandLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateOperandConstraint(*operand_constraint, constraints));
    } else if (auto* computation_constraint =
                   dynamic_cast<const ComputationLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateResultConstraint(*computation_constraint, constraints));
    } else {
      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
    }
    // Propagation may have enqueued fresh constraints; pick them up before
    // continuing.
    add_new_constraints_to_worklist();
  }
  return OkStatus();
}
1452
1453 namespace {
1454
1455 // Returns a vector containing all array-shaped uses (instruction and operand
1456 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const TuplePointsToAnalysis::BufferAliasVector & aliases)1457 std::vector<std::pair<const HloInstruction*, int64_t>> GetArrayUsesOfBuffer(
1458 const TuplePointsToAnalysis::BufferAliasVector& aliases) {
1459 std::vector<std::pair<const HloInstruction*, int64_t>> uses;
1460 for (const auto& buffer_alias : aliases) {
1461 if (!buffer_alias.instruction()->shape().IsArray()) {
1462 continue;
1463 }
1464 // This alias must be the top-level (index == {}) of the instruction's
1465 // result because the instruction produces an array.
1466 CHECK(buffer_alias.index().empty());
1467
1468 // Add all uses of the instruction's output.
1469 for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1470 for (int64_t operand_no :
1471 user->OperandIndices(buffer_alias.instruction())) {
1472 uses.emplace_back(user, operand_no);
1473 }
1474 }
1475 }
1476 return uses;
1477 }
1478
1479 } // namespace
1480
PropagateUseConstraintToDefs(const ShapeLayout & shape_layout,const HloInstruction * instruction,LayoutConstraints * constraints,int64_t priority)1481 Status LayoutAssignment::PropagateUseConstraintToDefs(
1482 const ShapeLayout& shape_layout, const HloInstruction* instruction,
1483 LayoutConstraints* constraints, int64_t priority) {
1484 // Try to set all logical buffers which may be sources of the given operand to
1485 // match the given layout.
1486 const PointsToSet& points_to_set =
1487 points_to_analysis_->GetPointsToSet(instruction);
1488 return points_to_set.ForEachElementWithStatus(
1489 [&shape_layout, this, priority](
1490 const ShapeIndex& index,
1491 const PointsToSet::BufferList& buffers) -> Status {
1492 if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
1493 for (const LogicalBuffer* buffer : buffers) {
1494 if (buffer->shape().IsArray() &&
1495 GetBufferLayoutConstraint(*buffer) == nullptr &&
1496 (buffer->instruction()->opcode() != HloOpcode::kReduce ||
1497 !buffer->instruction()->shape().IsTuple())) {
1498 TF_RETURN_IF_ERROR(SetBufferLayout(
1499 ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
1500 *buffer, /*mandatory=*/true, /*dfs=*/true, priority));
1501 }
1502 }
1503 }
1504 return OkStatus();
1505 });
1506 }
1507
1508 namespace {
// A transpose or a reshape that only changes trivial dimensions has a
// meaningful layout that is valuable to propagate in a depth-first manner to
// avoid unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo)1512 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
1513 switch (hlo.opcode()) {
1514 case HloOpcode::kFusion:
1515 return hlo.IsCustomFusion();
1516 case HloOpcode::kGather:
1517 return true;
1518 case HloOpcode::kReshape:
1519 return hlo.operand(0)->shape().rank() == 1 ||
1520 hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions().has_value();
1521 case HloOpcode::kScatter:
1522 case HloOpcode::kTranspose:
1523 return true;
1524 default:
1525 return false;
1526 }
1527 }
1528
1529 } // namespace
1530
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  VLOG(3) << "Propagate Operand Constraint : " << operand_constraint.ToString()
          << "\n";
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(PropagateUseConstraintToDefs(
      operand_constraint.shape_layout(), operand_constraint.operand(),
      constraints, operand_constraint.priority()));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    return OkStatus();
  }

  // For all-reduce, suggest (non-mandatory) the constrained operand layout
  // for the user's corresponding output buffer if it is not yet constrained.
  // With a single operand the output buffer is at the top-level index;
  // otherwise it is the tuple element matching the operand number.
  if (user->opcode() == HloOpcode::kAllReduce) {
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(
        const LogicalBuffer* buffer,
        points_to_analysis_->GetBufferDefinedAt(user, shape_index));
    const BufferLayoutConstraint* constraint =
        GetBufferLayoutConstraint(*buffer);
    if (constraint == nullptr) {
      TF_RETURN_IF_ERROR(
          SetBufferLayout(operand_constraint.shape_layout().layout(), *buffer,
                          /*mandatory=*/false, /*dfs=*/true));
    }
  }

  if (InstructionCanChangeLayoutInstance(user) && !user->shape().IsArray() &&
      user->opcode() != HloOpcode::kReduce) {
    return OkStatus();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (AnyOperandBufferForwarded(user, operand_constraint.operand_no())) {
    return OkStatus();
  }

  // Rank-0/rank-1 arrays have only one possible minor-to-major ordering, so
  // there is nothing useful to propagate.
  int64_t operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return OkStatus();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!InstructionCanChangeLayoutInstance(user)) {
    // Only propagate the layout of the largest concatenate operand.
    if (user->opcode() == HloOpcode::kConcatenate) {
      for (int64_t operand_no = 0; operand_no < user->operand_count();
           ++operand_no) {
        const HloInstruction* sibling = user->operand(operand_no);
        if (sibling == operand) {
          continue;
        }
        // Bail out entirely if any sibling is larger along the concatenation
        // dimension; that sibling's layout should win instead.
        if (sibling->shape().dimensions(user->concatenate_dimension()) >
            operand->shape().dimensions(user->concatenate_dimension())) {
          return OkStatus();
        }
      }
    }
    // Make sure all siblings have the same layout as the operand.
    for (int64_t operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      if (!sibling->shape().IsArray()) {
        continue;
      }
      const int64_t sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        //   operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        //   we continue to propagate our layout to the result, and then
        //   backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false, /*dfs=*/true, operand_constraint.priority()));
    }
    // Also propagate the operand layout (non-mandatory) into the user's own
    // output buffers whose rank matches the operand's rank.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return OkStatus();
          }
          if (subshape.rank() <= 1) {
            return OkStatus();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return OkStatus();
          }
          if (!points_to_analysis_->InstructionDefinesBufferAtIndex(
                  user, shape_index)) {
            return OkStatus();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              points_to_analysis_->GetBufferDefinedAt(user, shape_index));
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          TF_RETURN_IF_ERROR(SetBufferLayout(
              operand_constraint.shape_layout().layout(), *buffer,
              /*mandatory=*/false));
          return OkStatus();
        }));
    return OkStatus();
  }
  // The user may change the layout: derive a low-cost output layout from the
  // constrained operand layout via ChooseOutputLayoutFromOperandLayout, but
  // never override an existing mandatory buffer constraint.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return OkStatus();
        }
        if (subshape.rank() <= 1) {
          return OkStatus();
        }
        if (!points_to_analysis_->InstructionDefinesBufferAtIndex(
                user, shape_index)) {
          return OkStatus();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            points_to_analysis_->GetBufferDefinedAt(user, shape_index));
        auto* buffer_constraint = GetBufferLayoutConstraint(*buffer);
        if (buffer_constraint == nullptr || !buffer_constraint->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            // The choice is mandatory only for reduce users.
            TF_RETURN_IF_ERROR(SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user),
                operand_constraint.priority()));
          }
        }
        return OkStatus();
      }));
  return OkStatus();
}
1711
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Rank <= 1 shapes have a unique minor-to-major ordering; nothing to
  // propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return OkStatus();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // For all-reduce, force only the operand corresponding to this buffer to
    // match the buffer's layout. With a single operand the buffer index is
    // empty, so operand 0 is used; otherwise buffer.index()[0] identifies the
    // operand.
    TF_RETURN_IF_ERROR(SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true, /*dfs=*/true, buffer_constraint.priority()));
    return OkStatus();
  }
  for (int64_t operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!InstructionCanChangeLayoutInstance(instruction)) {
      // Copy the layout to the operand. Only applies when the operand's rank
      // matches the constrained layout's rank.
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true, /*dfs=*/true, current_priority_));
      }
    } else if (instruction->opcode() == HloOpcode::kBroadcast) {
      // For broadcasts, derive the operand layout from the output layout
      // (see GetBroadcastLayoutFromOutput); this is a mandatory assignment.
      Layout layout =
          GetBroadcastLayoutFromOutput(buffer_constraint.layout(), instruction);
      TF_RETURN_IF_ERROR(SetArrayOperandLayout(
          layout, instruction, operand_no, /*mandatory=*/true,
          /*dfs=*/
          InstructionShouldPropagateDepthFirst(*instruction),
          current_priority_));
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no constraint already. A non-mandatory
      // existing constraint may be overridden.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(*instruction),
              current_priority_));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return OkStatus();
}
1783
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1784 Status LayoutAssignment::PropagateBufferConstraint(
1785 const BufferLayoutConstraint& buffer_constraint,
1786 LayoutConstraints* constraints) {
1787 // Only propagate array layouts.
1788 const LogicalBuffer& buffer = buffer_constraint.buffer();
1789 if (!buffer.IsArray()) {
1790 return OkStatus();
1791 }
1792 TF_RETURN_IF_ERROR(
1793 PropagateBufferConstraintToOperands(buffer_constraint, constraints));
1794 return PropagateBufferConstraintToUses(buffer_constraint, constraints);
1795 }
1796
Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToUses: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(points_to_analysis_->GetBufferAliases(buffer))) {
    const HloInstruction* user = user_operand_no.first;
    int64_t operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !AnyOperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false,
          /*dfs=*/true, buffer_constraint.priority()));
    }
  }

  // Propagate to backedges of kWhile. This only applies when the buffer's
  // computation has exactly one caller and that caller is a while op.
  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
  if (node.caller_callsites().size() != 1) {
    return OkStatus();
  }
  const HloInstruction* parent = node.caller_callsites()[0].instruction();
  if (parent->opcode() != HloOpcode::kWhile) {
    return OkStatus();
  }

  // If the buffer feeds the computation's root tuple, suggest the same layout
  // (non-mandatory) for the matching element of the computation's parameter,
  // i.e. across the while loop's backedge.
  for (HloInstruction* user : buffer.instruction()->users()) {
    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
      continue;
    }
    if (user->parent()->root_instruction() == user) {
      VLOG(3) << "Propagating layout through backedge"
              << buffer_constraint.layout().ToString();
      int64_t index = user->operand_index(buffer.instruction());
      TF_ASSIGN_OR_RETURN(
          auto buffer, points_to_analysis_->GetBufferDefinedAt(
                           user->parent()->parameter_instruction(0), {index}));

      TF_RETURN_IF_ERROR(SetBufferLayout(buffer_constraint.layout(), *buffer,
                                         /*mandatory=*/false));
    }
  }

  return OkStatus();
}
1850
PropagateResultConstraint(const ComputationLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1851 Status LayoutAssignment::PropagateResultConstraint(
1852 const ComputationLayoutConstraint& layout_constraint,
1853 LayoutConstraints* constraints) {
1854 // Propagate the use constraint of the root instruction up to the logical
1855 // buffers which make up the result.
1856 return PropagateUseConstraintToDefs(
1857 layout_constraint.computation_layout().result_layout(),
1858 constraints->computation()->root_instruction(), constraints,
1859 current_priority_);
1860 }
1861
// Infers the layout of the array at the given index in the given instruction's
// output using points-to analysis. Precondition: The given instruction must
// not produce this array value (that is, the array is forwarded from the
// instruction's operands).
StatusOr<Layout> LayoutAssignment::InferArrayLayout(
    const HloInstruction* instruction, const ShapeIndex& index) {
  const auto& source_buffers =
      points_to_analysis_->GetPointsToSet(instruction).element(index);
  TF_RET_CHECK(!source_buffers.empty());

  // Verify the layout is the same for every LogicalBuffer which this location
  // ('instruction' and 'index') points to.
  const Layout* first_buffer_layout = nullptr;
  for (const LogicalBuffer* source_buffer : source_buffers) {
    VLOG(5) << "Logical buffer: " << source_buffer->ToString() << "\n";
    auto* source_buffer_constraint = GetBufferLayoutConstraint(*source_buffer);
    if (source_buffer_constraint == nullptr) {
      // This should not happen because we've assigned layouts to all
      // instructions preceding this one.
      return InternalError("LogicalBuffer %s does not have a layout",
                           source_buffer->ToString());
    }

    if (first_buffer_layout == nullptr) {
      first_buffer_layout = &source_buffer_constraint->layout();
    } else if (!Layout::Equal().MinorToMajorOnly()(
                   source_buffer->shape().layout(), *first_buffer_layout)) {
      // NOTE(review): this compares the layout stored on the buffer's *shape*
      // against the first buffer's *constraint* layout — confirm the shape's
      // layout is kept in sync with the constraint at this point.
      //
      // The points-to set is ambiguous for this index and the different source
      // buffers have different layouts. This case is possible in valid XLA
      // computations because we do not propagate BufferLayoutConstraints to all
      // LogicalBuffers which may alias the constrained LogicalBuffer at some
      // point in the computation.
      return FailedPrecondition(
          "Array at index {%s} in instruction %s aliases buffers %s "
          "and %s which have different layouts",
          absl::StrJoin(index, ","), instruction->name(),
          source_buffers[0]->ToString(), source_buffer->ToString());
    }
  }

  return *first_buffer_layout;
}
1904
1905 namespace {
1906
// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself. All other fused instructions have their layouts cleared
// (except in custom fusions, and except GTEs, constants and infeeds — see the
// branches below).
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameters mirror the layout of the corresponding fusion
      // operand.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return OkStatus();
}
1952
1953 } // namespace
1954
// Materializes the accumulated layout constraints onto the instructions of
// the computation: clears all existing layouts, writes the constrained
// layouts onto defined buffers, infers the rest via points-to analysis, and
// inserts copies where operand or result layouts still disagree with their
// constraints.
Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) {
  HloComputation* computation = constraints.computation();
  VLOG(2) << "Assigning layouts to computation: " << computation->name();

  XLA_VLOG_LINES(2, ToString(constraints));

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() == HloOpcode::kBitcast) {
      // bitcasts are inherently layout sensitive and so a bitcast instruction
      // present in the IR before layout assignment is a bug.
      return InternalError(
          "Unexpected bitcast operation seen during layout assignment: %s.",
          instruction->ToString());
    }
    // Start from a clean slate so stale layouts cannot leak through.
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         points_to_analysis_->GetBuffersDefinedByInstruction(instruction)) {
      if (!buffer->shape().IsArray()) {
        continue;
      }
      TF_RET_CHECK(buffer->instruction() == instruction);
      auto* buffer_layout_constraint = GetBufferLayoutConstraint(*buffer);
      TF_RET_CHECK(buffer_layout_constraint != nullptr);
      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(buffer_layout_constraint->layout(),
                                      buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = buffer_layout_constraint->layout();
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, this](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !subshape->IsArray()) {
            return OkStatus();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(instruction, index));
          return OkStatus();
        }));
    VLOG(3) << "Instruction layout:" << instruction->ToString();
    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64_t operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      } else {
        VLOG(2) << "operand " << operand_no << " has no constraint";
      }
    }
    // Propagate the finalized layouts into the fused computation.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    VLOG(3) << "Resulting instruction:" << instruction->ToString() << "\n";
    // Execute extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Shape must be valid.
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr) {
    // Layout assignment at this point only does minor-to-major assignment so
    // tiling info should be ignored here for comparison.
    VLOG(5) << "Computation result layout needs root copying\n";
    if (!constraints.ResultLayout()->MatchesLayoutInShape(
            computation->root_instruction()->shape(),
            /*minor_to_major_only=*/true)) {
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_root,
          CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                  computation->root_instruction()));
      computation->set_root_instruction(new_root);
    } else {
      // Copy the tiling info specified in result layout.
      auto copy_tiling = [&constraints](xla::Shape* subshape,
                                        const xla::ShapeIndex& index) {
        if (subshape->IsArray()) {
          const Shape& result_shape = ShapeUtil::GetSubshape(
              constraints.ResultLayout()->shape(), index);
          if (result_shape.layout().tiles_size() != 0) {
            subshape->mutable_layout()->mutable_tiles()->assign(
                result_shape.layout().tiles().begin(),
                result_shape.layout().tiles().end());
          }
        }
      };
      xla::ShapeUtil::ForEachMutableSubshape(
          computation->root_instruction()->mutable_shape(), copy_tiling);
    }
  }
  VLOG(5) << "Final computation layout:" << computation->name() << ":"
          << constraints.computation_constraint().ToString() << "\n";
  VLOG(5) << "Root instruction:" << computation->root_instruction()->ToString()
          << "\n";
  return OkStatus();
}
2076
Status LayoutAssignment::CalculateComputationLayout(
    LayoutConstraints* constraints) {
  // Process instructions that contain nested computations and may require
  // additional layouts to be assigned on the instructions nested inside.

  // Copies into `update` the leaf layouts that can be inferred (via
  // InferArrayLayout / points-to analysis) from `operand`. Returns true if
  // any leaf layout was updated.
  auto UpdateLayout = [this](const HloInstruction* operand,
                             ShapeLayout* update) -> bool {
    bool change = false;
    ShapeUtil::ForEachSubshape(
        operand->shape(), [this, &change, operand, update](
                              const Shape& subshape, const ShapeIndex& index) {
          if (subshape.IsTuple()) {
            return;
          }
          auto param_layout = InferArrayLayout(operand, index);
          if (param_layout.ok()) {
            VLOG(5) << index << ":" << param_layout.ValueOrDie().ToString()
                    << "\n";
            update->ResetLayout(param_layout.ValueOrDie(), index);
            change = true;
          }
        });
    return change;
  };

  // Resets `callee`'s computation layout from `result` (result layout) and
  // `operands` (parameter layouts) — but only when `priority` exceeds the
  // callee's current constraint priority, or the callee was recorded in
  // conditional_mismatch_.
  auto SetCalleeLayout =
      [this, UpdateLayout](const HloInstruction* result,
                           absl::Span<const HloInstruction* const> operands,
                           LayoutConstraints* callee, int priority) -> Status {
    CHECK_NE(result, nullptr);
    ComputationLayoutConstraint* callee_constraint =
        callee->mutable_computation_constraint();
    ComputationLayout callee_layout = callee_constraint->computation_layout();
    if (callee_constraint->priority() < priority ||
        conditional_mismatch_.count(callee->computation()) > 0) {
      // For computations with a recorded conditional mismatch the result
      // layout is left alone; only parameter layouts are refreshed.
      if (conditional_mismatch_.count(callee->computation()) == 0 &&
          UpdateLayout(result, callee_layout.mutable_result_layout())) {
        VLOG(2) << "Setting result layout from : " << result->ToString()
                << "\n";
      }
      int64_t operand_no = 0;
      for (auto* operand : operands) {
        if (UpdateLayout(operand,
                         callee_layout.mutable_parameter_layout(operand_no))) {
          VLOG(2) << "Setting callee parameter: " << operand->ToString()
                  << "\n";
        }
        ++operand_no;
      }
      VLOG(2) << "Set callee layout: " << callee->computation()->name() << ":"
              << callee_layout.ToString()
              << "; original priority = " << callee_constraint->priority()
              << "\n";
      callee_constraint->ResetComputationLayout(callee_layout, priority, true,
                                                true);
    }
    return OkStatus();
  };
  for (HloInstruction* instruction :
       constraints->computation()->MakeInstructionPostOrder()) {
    switch (instruction->opcode()) {
      case HloOpcode::kFusion:
        TF_RETURN_IF_ERROR(
            SetCalleeLayout(instruction, instruction->operands(),
                            mutable_computation_constraints(
                                instruction->fused_instructions_computation()),
                            current_priority_ + 1));
        break;
      case HloOpcode::kCall:
        // NOTE(review): the returned Status is compared against OkStatus()
        // rather than propagated with TF_RETURN_IF_ERROR, so a failure would
        // be silently dropped here — confirm this is intentional.
        if (reverse_computation_order_ &&
            SetCalleeLayout(
                instruction, instruction->operands(),
                mutable_computation_constraints(instruction->to_apply()),
                current_priority_ + 1) == OkStatus()) {
          VLOG(2) << "Successfully propagated to callee layout\n";
        }
        break;
      case HloOpcode::kConditional:
        if (reverse_computation_order_) {
          // If the branches don't yet have layouts, propagate existing layout
          // inside the branches. Operand i + 1 is the argument of branch i
          // (operand 0 is the branch selector).
          for (int i = 0; i < instruction->branch_count(); ++i) {
            TF_RETURN_IF_ERROR(
                SetCalleeLayout(instruction, {instruction->operand(i + 1)},
                                mutable_computation_constraints(
                                    instruction->branch_computation(i)),
                                current_priority_ + 1));
          }
        }
        break;
      case HloOpcode::kWhile:
        // If the loop body doesn't have layouts, propagate existing one inside.
        if (reverse_computation_order_) {
          VLOG(2) << "Populating while loop constraints inside loop body.";
          VLOG(2) << instruction->ToString();
          TF_RETURN_IF_ERROR(SetCalleeLayout(
              instruction, {instruction->operand(0)},
              mutable_computation_constraints(instruction->while_body()),
              current_priority_ + 1));
          VLOG(2) << "Populating while loop constraints inside loop condition.";
          VLOG(2) << instruction->ToString();
          TF_RETURN_IF_ERROR(SetCalleeLayout(
              instruction->operand(0), {instruction->operand(0)},
              mutable_computation_constraints(instruction->while_condition()),
              current_priority_ + 1));
        }
        break;
      default:
        break;
    }
  }
  // Reset the layout of the current computation from its body.
  if (current_priority_ == 0 ||
      conditional_mismatch_.count(constraints->computation()) > 0) {
    TF_RETURN_IF_ERROR(SetCalleeLayout(
        constraints->computation()->root_instruction(),
        constraints->computation()->parameter_instructions(), constraints,
        current_priority_ + kNumberOfPropagationRounds));
    // Keep the cached entry computation layout in sync.
    if (constraints->computation()->IsEntryComputation()) {
      *entry_computation_layout_ = constraints->computation_layout();
    }
  }
  return OkStatus();
}
2201
ClearComputationLayouts(HloComputation * computation)2202 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
2203 // Clear existing layouts of the instructions. All layouts must be assigned
2204 // by the LayoutAssignment pass, except for those on parameters, the
2205 // computation result, and a couple special cases. The former two are
2206 // specified in computation_layout. Clearing the layouts here avoids hiding
2207 // potential bugs in the layout assignment pass that may accidentally use the
2208 // existing layout.
2209 for (HloInstruction* instruction : computation->instructions()) {
2210 if (instruction->opcode() == HloOpcode::kBitcast) {
2211 // bitcasts are inherently layout sensitive and so a bitcast instruction
2212 // present in the IR before layout assignment is a bug.
2213 return InternalError(
2214 "Unexpected bitcast operation seen during layout assignment: %s.",
2215 instruction->ToString());
2216 }
2217 // Some instructions carry mandatory layouts in their shape.
2218 if (instruction->opcode() != HloOpcode::kInfeed &&
2219 !IsLayoutConstrainedCustomCall(instruction) &&
2220 !IsLayoutConstrainedCollective(instruction)) {
2221 LayoutUtil::ClearLayout(instruction->mutable_shape());
2222 }
2223 }
2224 return OkStatus();
2225 }
2226
RunOnComputation(LayoutConstraints * constraints,ChannelLayoutConstraints * channel_constraints)2227 Status LayoutAssignment::RunOnComputation(
2228 LayoutConstraints* constraints,
2229 ChannelLayoutConstraints* channel_constraints) {
2230 HloComputation* computation = constraints->computation();
2231 VLOG(1) << "LayoutAssignment::RunOnComputation(" << computation->name()
2232 << ")";
2233 VLOG(4) << computation->ToString() << "\n";
2234
2235 // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
2236 for (HloInstruction* inst : computation->instructions()) {
2237 points_to_analysis_->GetPointsToSet(inst).ForEachElement(
2238 [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
2239 for (const LogicalBuffer* buffer : buffers) {
2240 // The points to analysis is computed per module, restrict
2241 // constraints to array buffers in this computation.
2242 if (buffer->IsArray() &&
2243 buffer->instruction()->parent() == computation) {
2244 unconstrained_buffer_ids_.insert(buffer->id());
2245 }
2246 }
2247 });
2248 }
2249
2250 // Add constraints required for correctness on all backends (eg, entry
2251 // parameter layout constraints).
2252 TF_RETURN_IF_ERROR(AddMandatoryConstraints(channel_constraints, constraints));
2253
2254 // Add any backend-specific constraints.
2255 TF_RETURN_IF_ERROR(AddBackendConstraints(constraints));
2256
2257 // Propagates layouts from mandatory and backend constraints.
2258 TF_RETURN_IF_ERROR(PropagateConstraints(constraints));
2259
2260 // Prior to applying default layouts, we take note of all HLO instructions
2261 // which lack a layout constraint.
2262 for (LogicalBuffer::Id buffer_id : unconstrained_buffer_ids_) {
2263 VLOG(5)
2264 << "unconstrained instruction:"
2265 << points_to_analysis_->GetBuffer(buffer_id).instruction()->ToString()
2266 << "\n";
2267 unconstrained_layout_instructions_.insert(
2268 points_to_analysis_->GetBuffer(buffer_id).instruction());
2269 }
2270
2271 // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
2272 // layout and propagate the change.
2273 while (!unconstrained_buffer_ids_.empty()) {
2274 int unconstrained_count = unconstrained_buffer_ids_.size();
2275
2276 // Arbitrarily pick the first unconstrained buffer and give it the default
2277 // layout (or the literal layout, in case of constants). By construction
2278 // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
2279 const LogicalBuffer& buffer =
2280 points_to_analysis_->GetBuffer(*unconstrained_buffer_ids_.begin());
2281 const HloInstruction* instruction = buffer.instruction();
2282 Layout new_layout =
2283 instruction->opcode() == HloOpcode::kConstant
2284 ? ShapeUtil::GetSubshape(instruction->literal().shape(),
2285 buffer.index())
2286 .layout()
2287 : GetUnconstrainedLayout(buffer);
2288 TF_RETURN_IF_ERROR(SetBufferLayout(new_layout, buffer,
2289 /*mandatory=*/false));
2290
2291 TF_RETURN_IF_ERROR(PropagateConstraints(constraints));
2292
2293 // To verify progress has been made, check that the number of unconstrained
2294 // buffers has been reduced.
2295 CHECK_LT(unconstrained_buffer_ids_.size(), unconstrained_count);
2296 }
2297
2298 TF_RETURN_IF_ERROR(CalculateComputationLayout(constraints));
2299 // Record the layouts assigned for any communication ops in
2300 // channel_constraints so that they are constrained for future modules.
2301 if (channel_constraints != nullptr) {
2302 TF_RETURN_IF_ERROR(
2303 ConstrainChannelLayouts(computation, channel_constraints));
2304 }
2305
2306 return OkStatus();
2307 }
2308
ConstrainChannelLayouts(HloComputation * computation,ChannelLayoutConstraints * channel_constraints)2309 Status LayoutAssignment::ConstrainChannelLayouts(
2310 HloComputation* computation,
2311 ChannelLayoutConstraints* channel_constraints) {
2312 for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
2313 if (instruction->IsCrossModuleAllReduce()) {
2314 TF_ASSIGN_OR_RETURN(auto op_layout, InferArrayLayout(instruction, {}));
2315 VLOG(5) << "Constrain cross module all reduce: " << op_layout.ToString()
2316 << "\n";
2317 channel_constraints->ConstrainChannel(instruction->channel_id().value(),
2318 op_layout);
2319 }
2320 }
2321 return OkStatus();
2322 }
2323
PropagateMemorySpace(HloModule * module)2324 Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
2325 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
2326 for (const auto& buffer : alias_analysis->buffers()) {
2327 // First go through values to collect the memory spaces.
2328 int64_t buffer_memory_space = Layout::kDefaultMemorySpace;
2329 for (auto value : buffer.values()) {
2330 const Shape& defining_shape = value->defining_position().shape();
2331 if (!defining_shape.has_layout()) {
2332 continue;
2333 }
2334 int64_t memory_space = defining_shape.layout().memory_space();
2335 if (memory_space != Layout::kDefaultMemorySpace) {
2336 if (buffer_memory_space != Layout::kDefaultMemorySpace &&
2337 memory_space != buffer_memory_space) {
2338 return InternalError(
2339 "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
2340 buffer.id(), value->ToShortString(), buffer_memory_space,
2341 memory_space);
2342 }
2343 buffer_memory_space = memory_space;
2344 }
2345 }
2346
2347 // If we encounter a memory space other than the default, then propagate all
2348 // the positions with the buffer's memory space.
2349 if (buffer_memory_space != Layout::kDefaultMemorySpace) {
2350 for (auto value : buffer.values()) {
2351 for (auto& position : value->positions()) {
2352 Shape* shape = ShapeUtil::GetMutableSubshape(
2353 position.instruction->mutable_shape(), position.index);
2354 shape->mutable_layout()->set_memory_space(buffer_memory_space);
2355 }
2356 }
2357 }
2358 }
2359 return OkStatus();
2360 }
2361
// Reconciles the requested `computation_layout` with the layouts actually
// assigned on `computation`'s instructions: fills in any parameter/result
// layouts that were left unset, and verifies that layouts which WERE set
// match what the pass computed.
Status LayoutAssignment::PropagateComputationLayouts(
    HloComputation* computation, ComputationLayout* computation_layout) {
  // Layout actually present on the computation's parameters/root after
  // assignment (ignore_layouts=false keeps the assigned layouts).
  ComputationLayout computed_computation_layout(
      computation->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  for (int64_t i = 0; i < computed_computation_layout.parameter_count(); ++i) {
    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
    // Set when at least one leaf of the requested parameter layout is unset,
    // in which case the whole parameter layout is taken from the computed one.
    bool needs_assign = false;
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        param_layout->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          // Only leaf (array) subshapes carry layouts worth comparing.
          if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
            return OkStatus();
          }
          if (!subshape.has_layout()) {
            needs_assign = true;
            return OkStatus();
          }
          // Leaf has a requested layout: it must agree with the layout the
          // pass actually assigned at the same index.
          const auto& computed_subshape = ShapeUtil::GetSubshape(
              computed_computation_layout.parameter_shape(i), shape_index);
          if (subshape.layout() != computed_subshape.layout()) {
            return InternalError(
                "Assigned parameter shape %s does not match layout of "
                "computation shape: %s",
                computed_computation_layout.ToString(),
                computation_layout->ToString());
          }
          return OkStatus();
        }));
    if (needs_assign) {
      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
              << computation->name() << ": "
              << computed_computation_layout.parameter_layout(i).ToString();
      *param_layout = computed_computation_layout.parameter_layout(i);
    }
  }
  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
  if (!result_layout->LayoutIsSet()) {
    // No result layout was requested: adopt the computed one.
    VLOG(4) << "Assigning result layout of computation " << computation->name()
            << ": " << computed_computation_layout.result_layout().ToString();
    *result_layout = computed_computation_layout.result_layout();
  } else {
    // A result layout was requested: it must match the computed layout up to
    // minor-to-major order (dynamic dimensions ignored).
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            computed_computation_layout.result_layout().shape(),
            result_layout->shape()));
  }
  return OkStatus();
}
2411
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)2412 StatusOr<bool> LayoutAssignment::Run(
2413 HloModule* module,
2414 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2415 VLOG(2) << "Running layout assignment on module " << module->name();
2416 TF_RETURN_IF_ERROR(Init(module));
2417 call_graph_ = CallGraph::Build(module);
2418 // Add copy to the operand of Send instructions, since we cannot call
2419 // SetOperandLayout on Send instructions as it aliases its input to the
2420 // output.
2421 //
2422 // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
2423 // operand buffers that aliases with the output.
2424 for (HloComputation* computation : module->computations(execution_threads)) {
2425 for (HloInstruction* instruction :
2426 computation->MakeInstructionPostOrder()) {
2427 if (instruction->opcode() == HloOpcode::kSend) {
2428 TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
2429 }
2430 }
2431 }
2432
2433 // Clone Conditional computations with multiple callsites.
2434 for (HloComputation* computation : module->computations(execution_threads)) {
2435 CallGraphNode& node = call_graph_->GetNode(computation);
2436 if (node.caller_callsites().size() == 1) {
2437 continue;
2438 }
2439 if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
2440 return caller.instruction()->opcode() == HloOpcode::kConditional;
2441 })) {
2442 continue;
2443 }
2444 for (int64_t i = 0; i < node.caller_callsites().size() - 1; ++i) {
2445 HloInstruction* caller = node.caller_callsites()[i].instruction();
2446 if (caller->opcode() == HloOpcode::kConditional) {
2447 for (int64_t k = 0; k < caller->branch_count(); ++k) {
2448 if (computation == caller->branch_computation(k)) {
2449 caller->set_branch_computation(
2450 k, module->AddEmbeddedComputation(computation->Clone()));
2451 break;
2452 }
2453 }
2454 }
2455 }
2456 }
2457
2458 // Verify computation layout is sane.
2459 HloComputation* entry = module->entry_computation();
2460 TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
2461 entry->num_parameters());
2462 for (int64_t i = 0; i < entry->num_parameters(); ++i) {
2463 TF_RET_CHECK(
2464 ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
2465 entry->parameter_instruction(i)->shape()));
2466 }
2467 TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
2468 entry->root_instruction()->shape()));
2469 // We do two passes. The first one we pass a nullptr ComputationLayout to
2470 // the RunOnComputation() calls (for non entry computations), and we register
2471 // the ComputationLayout which are naturally flowing in DFS fashion to the
2472 // parameters and root instruction.
2473 // Walking in DFS mode though, means that we can end up with incorrect layouts
2474 // when seen from an outer instruction, which has across-computation
2475 // constraints to impose.
2476 // For example, the kWhile instruction needs to enforce the same layouts for
2477 // the parameters and root of the body, as well as the condition parameters.
2478 // Similarly, the kConditional instruction needs to enforce the same layouts
2479 // for the root of the true and false computations.
2480 // So in the first pass, while allowing the layouts to flow to parameters and
2481 // root, we also fix up the eventually inconsistent ComputationLayout, which
2482 // will be then made mandatory by the second pass.
2483 TF_ASSIGN_OR_RETURN(auto points_to_analysis,
2484 TuplePointsToAnalysis::Run(module));
2485 points_to_analysis_ = std::move(points_to_analysis);
2486 auto computations_to_work =
2487 module->MakeNonfusionComputations(execution_threads);
2488 // If the reverse_comptation_order_ flag is set, reverse the ordering of
2489 // traversing computations, to generate an alternative layout assignment.
2490 if (reverse_computation_order_ && !computations_to_work.empty()) {
2491 absl::c_reverse(computations_to_work);
2492
2493 VLOG(2) << "reversing traversal order for computation:";
2494 }
2495 computation_layouts_.emplace(
2496 module->entry_computation(),
2497 new LayoutConstraints(entry,
2498 entry_computation_layout_->LayoutIsSet()
2499 ? entry_computation_layout_
2500 : nullptr,
2501 entry_computation_layout_->LayoutIsSet()
2502 ? LayoutConstraint::kGivenPriority
2503 : LayoutConstraint::kDefaultPriority));
2504 for (int64_t i = 0; i < kNumberOfPropagationRounds; ++i) {
2505 VLOG(1) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
2506 TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module, execution_threads));
2507 for (auto* computation : computations_to_work) {
2508 LayoutConstraints* constraints =
2509 mutable_computation_constraints(computation);
2510 TF_RETURN_IF_ERROR(
2511 RunOnComputation(constraints, channel_layout_constraints_));
2512 }
2513 current_priority_ += 1;
2514 }
2515
2516 for (auto* computation : computations_to_work) {
2517 LayoutConstraints* constraints =
2518 FindOrDie(computation_layouts_, computation).get();
2519 // All logical buffers should have constraints at this point. All that
2520 // remains is assign the constraints to the buffers and infer layouts for
2521 // aliased buffers.
2522 TF_RETURN_IF_ERROR(AssignLayouts(*constraints));
2523 }
2524 TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
2525 entry_computation_layout_));
2526
2527 TF_RETURN_IF_ERROR(PropagateMemorySpace(module));
2528
2529 TF_RETURN_IF_ERROR(CheckLayouts(module, execution_threads));
2530
2531 // All layouts are reset then reassigned by this pass.
2532 return true;
2533 }
2534
/* static */
// Classifies whether layout assignment may pick a different layout for an
// instruction's output than for its operands. Opcodes returning false are
// "layout-preserving": operands and output must share a layout (elementwise
// ops, most in-place/slice ops, and the variadic collectives noted below).
// Opcodes returning true either define their own layout semantics
// (kTranspose, kReshape, kCopy, ...) or connect shapes where operand and
// output layouts are independent (kDot, kConvolution, calls, tuples, ...).
// The switch intentionally has no default case so that adding a new opcode
// forces a conscious decision here (compiler warns on unhandled values).
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOptimizationBarrier:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRoundNearestEven:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kWhile:
    case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic so it needs to be careful to assign the same layout
    // to the corresponding input argument and Tuple index.
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
      return false;
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kConditional:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
}
2661
// Virtual hook allowing backend subclasses to override the layout-change
// classification per instance; the base implementation simply delegates to
// the static InstructionCanChangeLayout().
bool LayoutAssignment::InstructionCanChangeLayoutInstance(
    const HloInstruction* instruction) {
  return InstructionCanChangeLayout(instruction);
}
2666
2667 /* static */
IsAtMostRank1(const Shape & shape)2668 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2669 if (shape.IsArray()) {
2670 return shape.rank() <= 1;
2671 }
2672 return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2673 return IsAtMostRank1(subshape);
2674 });
2675 }
2676
Init(HloModule * module)2677 Status LayoutAssignment::Init(HloModule* module) {
2678 computation_layouts_.clear();
2679 conditional_mismatch_.clear();
2680 *entry_computation_layout_ = saved_entry_computation_layout_;
2681 current_priority_ = LayoutConstraint::kBeginningPriority;
2682 // Clear all the copies which have been added, and all the related
2683 // instructions (like GTE and tuples).
2684 int64_t removed_copies = 0;
2685 for (HloComputation* computation : module->computations()) {
2686 for (HloInstruction* instruction :
2687 computation->MakeInstructionPostOrder()) {
2688 if (instruction->opcode() == HloOpcode::kCopy &&
2689 added_copies_.contains(instruction)) {
2690 VLOG(5) << "Removing added copy: " << instruction->ToString();
2691 TF_RETURN_IF_ERROR(
2692 instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2693 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2694 ++removed_copies;
2695 }
2696 }
2697 }
2698 added_copies_.clear();
2699 if (removed_copies > 0) {
2700 TupleSimplifier tuple_simplifier;
2701 HloDCE dce;
2702 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2703 TF_RETURN_IF_ERROR(dce.Run(module).status());
2704 call_graph_ = CallGraph::Build(module);
2705 }
2706 return OkStatus();
2707 }
2708
ClearPreviousPassSideEffects(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)2709 Status LayoutAssignment::ClearPreviousPassSideEffects(
2710 HloModule* module,
2711 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2712 VLOG(5) << "Clearing previous side effects";
2713 for (HloComputation* computation : module->computations(execution_threads)) {
2714 if (computation_layouts_.find(computation) != computation_layouts_.end()) {
2715 mutable_computation_constraints(computation)->ResetOperandConstraints();
2716 }
2717 }
2718 unconstrained_layout_instructions_.clear();
2719 unconstrained_buffer_ids_.clear();
2720 buffer_constraints_.clear();
2721 buffer_sets_cache_.clear();
2722 return OkStatus();
2723 }
AddCopyForOperand(HloInstruction * instruction,int64_t operand_number)2724 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2725 int64_t operand_number) {
2726 HloInstruction* operand = instruction->mutable_operand(operand_number);
2727 if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2728 HloInstruction* copy =
2729 instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2730 operand->shape(), HloOpcode::kCopy, operand));
2731 SetupCopiedInstruction(*operand, copy, {});
2732 LayoutUtil::ClearLayout(copy->mutable_shape());
2733 TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2734 }
2735 return OkStatus();
2736 }
2737
2738 } // namespace xla
2739