1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 18 19 #include <iosfwd> 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "absl/container/flat_hash_map.h" 28 #include "absl/container/flat_hash_set.h" 29 #include "absl/container/node_hash_map.h" 30 #include "tensorflow/compiler/xla/layout_util.h" 31 #include "tensorflow/compiler/xla/service/call_graph.h" 32 #include "tensorflow/compiler/xla/service/computation_layout.h" 33 #include "tensorflow/compiler/xla/service/hlo_computation.h" 34 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 35 #include "tensorflow/compiler/xla/service/hlo_module.h" 36 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 37 #include "tensorflow/compiler/xla/service/logical_buffer.h" 38 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 39 #include "tensorflow/compiler/xla/shape_layout.h" 40 #include "tensorflow/compiler/xla/shape_util.h" 41 #include "tensorflow/compiler/xla/statusor.h" 42 #include "tensorflow/compiler/xla/types.h" 43 #include "tensorflow/compiler/xla/xla_data.pb.h" 44 #include "tensorflow/core/lib/core/status.h" 45 46 namespace xla { 
// Abstract base class for layout constraints. Concrete constraint objects are
// collected together in a LayoutConstraints object.
class LayoutConstraint {
 public:
  LayoutConstraint(bool mandatory, bool dfs, int64_t priority)
      : mandatory_(mandatory), dfs_(dfs), priority_(priority) {}
  virtual ~LayoutConstraint() = default;

  // Returns a human-readable description of this constraint.
  virtual std::string ToString() const = 0;

  // True if this constraint cannot be overwritten by a different constraint.
  bool mandatory() const { return mandatory_; }

  // When true, propagate in DFS. When false, constraint will propagate in BFS.
  bool dfs() const { return dfs_; }

  // Return the priority of the current constraint. When conflicting
  // constraints are encountered, the higher priority one should win.
  int64_t priority() const { return priority_; }
  bool IsDefaultLayout() const { return priority_ == kDefaultPriority; }

  // The priority of all default layouts when not set explicitly.
  static constexpr int64_t kDefaultPriority = -2;
  // The beginning priority of layout assignment.
  static constexpr int64_t kBeginningPriority = 0;
  // The priority of layout assignment given by the user for entry computation.
  static constexpr int64_t kGivenPriority = 3;

 protected:
  bool mandatory_;
  bool dfs_;
  int64_t priority_;
};

std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);

// Layout constraint on a single LogicalBuffer. This constrains the layout of an
// array produced by a particular instruction.
86 class BufferLayoutConstraint : public LayoutConstraint { 87 public: 88 BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, 89 bool mandatory, bool dfs, int64_t priority); 90 buffer()91 const LogicalBuffer& buffer() const { return *buffer_; } layout()92 const Layout& layout() const { return layout_; } 93 bool UpdateLayout(int64_t priority, const Layout& layout, bool mandatory, 94 bool dfs); 95 96 std::string ToString() const override; 97 98 private: 99 Layout layout_; 100 const LogicalBuffer* buffer_; 101 }; 102 103 // Constraint on the layout of the operand of an instruction. The constrained 104 // shape can be arbitrarily shaped (array or tuple). This is a constraint on the 105 // use of a shaped value and is not a hard constraint on the instruction(s) 106 // which define the value as copies may be inserted between the definition and 107 // use. 108 class OperandLayoutConstraint : public LayoutConstraint { 109 public: 110 OperandLayoutConstraint(const ShapeLayout& shape_layout, 111 const HloInstruction* instruction, int64_t operand_no, 112 bool mandatory, bool dfs, int64_t priority); 113 shape_layout()114 const ShapeLayout& shape_layout() const { return shape_layout_; } instruction()115 const HloInstruction* instruction() const { return instruction_; } operand_no()116 const int64_t operand_no() const { return operand_no_; } operand()117 const HloInstruction* operand() const { 118 return instruction_->operand(operand_no_); 119 } 120 121 std::string ToString() const override; 122 123 private: 124 ShapeLayout shape_layout_; 125 const HloInstruction* instruction_; 126 int64_t operand_no_; 127 }; 128 129 // Constraint on the layout of a computation interface. 
130 class ComputationLayoutConstraint : public LayoutConstraint { 131 public: 132 static constexpr int64_t kDefaultLayoutIsUsed = 0; 133 static constexpr int64_t kResultLayoutIsSet = 1; 134 static constexpr int64_t kParameterLayoutIsSet = 2; 135 static constexpr int64_t kComputationLayoutIsSet = 3; ComputationLayoutConstraint(const HloComputation * computation,ComputationLayout * computation_layout,int64_t priority)136 explicit ComputationLayoutConstraint(const HloComputation* computation, 137 ComputationLayout* computation_layout, 138 int64_t priority) 139 : LayoutConstraint(/*mandatory=*/true, /*dfs=*/true, priority), 140 layout_state_((computation_layout == nullptr) 141 ? kDefaultLayoutIsUsed 142 : kComputationLayoutIsSet), 143 computation_layout_( 144 (computation_layout == nullptr) 145 ? ComputationLayout(computation->ComputeProgramShape(), 146 /*ignore_layouts=*/false) 147 : *computation_layout) {} 148 computation_layout()149 const ComputationLayout& computation_layout() const { 150 return computation_layout_; 151 } ResetComputationLayout(const ComputationLayout & layout,int64_t priority,bool prop_result_layout,bool prop_parameter_layout)152 void ResetComputationLayout(const ComputationLayout& layout, int64_t priority, 153 bool prop_result_layout, 154 bool prop_parameter_layout) { 155 computation_layout_ = layout; 156 priority_ = priority; 157 if (prop_result_layout) { 158 layout_state_ |= kResultLayoutIsSet; 159 } 160 if (prop_parameter_layout) { 161 layout_state_ |= kParameterLayoutIsSet; 162 } 163 } ResetResultLayout(const ShapeLayout & shape_layout,int64_t priority)164 void ResetResultLayout(const ShapeLayout& shape_layout, int64_t priority) { 165 *computation_layout_.mutable_result_layout() = shape_layout; 166 layout_state_ |= kResultLayoutIsSet; 167 priority_ = priority; 168 } parameter_layout_is_set()169 bool parameter_layout_is_set() const { 170 return layout_state_ & kParameterLayoutIsSet; 171 } result_layout_is_set()172 bool result_layout_is_set() 
const { 173 return layout_state_ & kResultLayoutIsSet; 174 } default_layout_is_used()175 bool default_layout_is_used() const { 176 return layout_state_ == kDefaultLayoutIsUsed; 177 } 178 std::string ToString() const override; 179 180 private: 181 // The layout_state_ variable is used to remember whether the layout for 182 // the overall computation is explicitly set, whether its result layout is 183 // explicitly set, or whether it only stores the default layout of the 184 // computation. 185 int64_t layout_state_; 186 ComputationLayout computation_layout_; 187 }; 188 189 // Contains constraints on the layout of channels; sends and recvs. 190 class ChannelLayoutConstraints { 191 public: 192 // Construct an empty constraint set. ChannelLayoutConstraints()193 ChannelLayoutConstraints() {} 194 195 // Returns true if channel_id has a layout constraint. IsChannelConstrained(int64_t channel_id)196 bool IsChannelConstrained(int64_t channel_id) const { 197 return constraints_.contains(channel_id); 198 } 199 200 // Given `shape`, apply the layout for `channel_id`. `channel_id` must already 201 // be constrained. LayoutShapeForChannel(Shape shape,int64_t channel_id)202 Shape LayoutShapeForChannel(Shape shape, int64_t channel_id) const { 203 auto it = constraints_.find(channel_id); 204 CHECK(it != constraints_.end()) << "Channel " << channel_id; 205 *shape.mutable_layout() = it->second; 206 return shape; 207 } 208 209 // Returns the layout constraint for `channel_id`, which must already be 210 // constrained. LayoutForChannel(int64_t channel_id)211 const Layout& LayoutForChannel(int64_t channel_id) const { 212 auto it = constraints_.find(channel_id); 213 CHECK(it != constraints_.end()) << "Channel " << channel_id; 214 return it->second; 215 } 216 217 // Adds a new layout constraint for `channel_id`. If a constraint for 218 // `channel_id` has been added, this API returns nullptr, otherwise returns 219 // the layout which has already been set for the channel. 
ConstrainChannel(int64_t channel_id,const Layout & layout)220 const Layout* ConstrainChannel(int64_t channel_id, const Layout& layout) { 221 auto it = constraints_.emplace(std::make_pair(channel_id, layout)); 222 if (it.second) { 223 return nullptr; 224 } 225 return LayoutUtil::Equal(layout, it.first->second) ? nullptr 226 : &it.first->second; 227 } 228 229 private: 230 absl::flat_hash_map<int64_t, Layout> constraints_; 231 }; 232 233 // HLO pass which assigns layouts to all instructions in the HLO module while 234 // satisfying all necessary invariants and minimizing cost. 235 class LayoutAssignment : public HloModulePass { 236 public: 237 // entry_computation_layout is modified to populate a layout for the result in 238 // the case that no particular layout is requested. 239 // 240 // channel_constraints is both an input and output. Any sends or recvs that 241 // are present in channel_constraints will be laid out as constrained. Any 242 // unconstrained sends or recvs will be laid out as locally optimal and their 243 // layout will be added as a constraint to channel_constraints. 244 // 245 // If channel_constraints is nullptr, no kSend or kRecvs must be contained 246 // within any module passed to `Run`. 247 explicit LayoutAssignment( 248 ComputationLayout* entry_computation_layout, 249 ChannelLayoutConstraints* channel_constraints = nullptr, 250 bool reverse_computation_order = false); ~LayoutAssignment()251 ~LayoutAssignment() override {} points_to_analysis()252 const TuplePointsToAnalysis& points_to_analysis() const { 253 return *points_to_analysis_; 254 } name()255 absl::string_view name() const override { return "layout-assignment"; } 256 257 // Assign layouts to the given module. Returns whether the module was changed 258 // (any layouts were changed). 
259 using HloPassInterface::Run; 260 StatusOr<bool> Run( 261 HloModule* module, 262 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 263 264 // Class encapsulating the layout constraints of the values in a HLO 265 // computation. 266 class LayoutConstraints { 267 public: 268 explicit LayoutConstraints(HloComputation* computation, 269 ComputationLayout* computation_layout, 270 int64_t priority); 271 ~LayoutConstraints() = default; 272 computation()273 const HloComputation* computation() const { return computation_; } computation()274 HloComputation* computation() { return computation_; } ResetOperandConstraints()275 void ResetOperandConstraints() { operand_constraints_.clear(); } 276 const ShapeLayout* OperandLayout(const HloInstruction* instruction, 277 int64_t operand_no) const; 278 const OperandLayoutConstraint* GetOperandLayoutConstraint( 279 const HloInstruction* instruction, int64_t operand_no) const; 280 const ShapeLayout* ResultLayout() const; 281 OperandLayoutConstraint* InsertOperandLayoutConstraint( 282 const HloInstruction* instruction, int64_t operand_no, 283 const OperandLayoutConstraint& constraint); 284 Status SetResultLayout(LayoutAssignment* assignment, 285 const Shape& shape_with_layout, int64_t priority); 286 computation_layout()287 const ComputationLayout& computation_layout() const { 288 return computation_constraint_.computation_layout(); 289 } computation_constraint()290 const ComputationLayoutConstraint& computation_constraint() const { 291 return computation_constraint_; 292 } mutable_computation_constraint()293 ComputationLayoutConstraint* mutable_computation_constraint() { 294 return &computation_constraint_; 295 } 296 297 private: 298 // The set of OperandLayoutConstraints applied to the computation. 
299 using OperandConstraintKey = std::pair<const HloInstruction*, int64_t>; 300 std::map<OperandConstraintKey, OperandLayoutConstraint> 301 operand_constraints_; 302 303 HloComputation* computation_; 304 ComputationLayoutConstraint computation_constraint_; 305 }; 306 307 // Determines whether an instruction can change layouts. An instruction not 308 // being able to change layout means that it requires operands with the same 309 // rank as the output to have the same layout as the output. 310 static bool InstructionCanChangeLayout(const HloInstruction* instruction); 311 mutable_computation_constraints(HloComputation * computation)312 LayoutConstraints* mutable_computation_constraints( 313 HloComputation* computation) { 314 auto it = computation_layouts_.find(computation); 315 LayoutConstraints* constraints = nullptr; 316 if (it == computation_layouts_.end()) { 317 computation_layouts_.emplace( 318 computation, 319 constraints = new LayoutConstraints( 320 computation, nullptr, LayoutConstraint::kDefaultPriority)); 321 } else { 322 constraints = (*it).second.get(); 323 } 324 return constraints; 325 } 326 void PushAddedConstraints(const LayoutConstraint* constraint); 327 328 // In case of an array shape returns true iff it is at most rank 1. In case of 329 // a tuple shape returns true iff all leaf shapes are at most rank 1. 330 static bool IsAtMostRank1(const Shape& shape); 331 // Convenience wrapper around SetOperandLayout for setting the layout of a 332 // operand using a Layout object. The operand must be array-shaped. 
333 Status SetArrayOperandLayout(const Layout& layout, 334 const HloInstruction* instruction, 335 int64_t operand_no, bool mandatory = true, 336 bool dfs = true) { 337 return SetArrayOperandLayout(layout, instruction, operand_no, mandatory, 338 dfs, current_priority_); 339 } 340 Status SetArrayOperandLayout(const Layout& layout, 341 const HloInstruction* instruction, 342 int64_t operand_no, bool mandatory, bool dfs, 343 int64_t priority); 344 // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers 345 // created by the instruction to the layouts in the given shape. The 346 // instruction must define every logical buffer in its output. 347 // If `allow_alias` is false, the function will check that all output buffers 348 // are defined by `instruction`, not aliased to an instruction elsewhere. 349 Status SetInstructionLayout(const Shape& shape_with_layout, 350 const HloInstruction* instruction, 351 bool mandatory = true, bool dfs = true, 352 bool allow_alias = false) { 353 return SetInstructionLayout(shape_with_layout, instruction, mandatory, dfs, 354 allow_alias, current_priority_); 355 } 356 Status SetInstructionLayout(const Shape& shape_with_layout, 357 const HloInstruction* instruction, bool mandatory, 358 bool dfs, bool allow_alias, int64_t priority); 359 // Set the same given layout across all components of the instruction output. 360 // It works the same as the API above if the output is a single array. 361 Status SetInstructionLayout(const Layout& layout, 362 const HloInstruction* instruction, 363 bool mandatory = true, bool dfs = true, 364 bool allow_alias = false, int64_t priority = -1); 365 // Add a constraint on the layout of a LogicalBuffer, the layout of the 366 // operand of the instruction, or the layout of the result of the computation, 367 // respectively. 
368 Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, 369 bool mandatory = true, bool dfs = true) { 370 return SetBufferLayout(layout, buffer, mandatory, dfs, current_priority_); 371 } 372 Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, 373 bool mandatory, bool dfs, int64_t priority); 374 Status SetOperandLayout(const Shape& shape_with_layout, 375 const HloInstruction* instruction, int64_t operand_no, 376 bool mandatory = true, bool dfs = true) { 377 return SetOperandLayout(shape_with_layout, instruction, operand_no, 378 mandatory, dfs, current_priority_); 379 } 380 Status SetOperandLayout(const Shape& shape_with_layout, 381 const HloInstruction* instruction, int64_t operand_no, 382 bool mandatory, bool dfs, int64_t priority); reverse_computation_order()383 bool reverse_computation_order() const { return reverse_computation_order_; } 384 saved_entry_computation_layout()385 ComputationLayout& saved_entry_computation_layout() { 386 return saved_entry_computation_layout_; 387 } 388 389 protected: 390 // These methods, invoked by PropagateConstraints, propagate a layout 391 // constraint to its neighbors (i.e. operands and users) in order to minimize 392 // the cost of the instructions being constrainted on. New constraints are 393 // added to the given constraint set. 394 // 395 // Backends can override these methods with backend-specific propagation 396 // rules. 
397 virtual Status PropagateBufferConstraint( 398 const BufferLayoutConstraint& buffer_constraint, 399 LayoutConstraints* constraints); 400 virtual Status PropagateOperandConstraint( 401 const OperandLayoutConstraint& operand_constraint, 402 LayoutConstraints* constraints); 403 virtual Status PropagateResultConstraint( 404 const ComputationLayoutConstraint& layout_constraint, 405 LayoutConstraints* constraints); 406 GetUnconstrainedLayout(const LogicalBuffer & buffer)407 virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) { 408 return LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); 409 } 410 // Called after layouts of an instruction have been finalized to allow 411 // subclasses to check for platform specific assumptions. Verify(const HloInstruction * instruction)412 virtual Status Verify(const HloInstruction* instruction) { 413 return OkStatus(); 414 } 415 416 Status PropagateUnconstraintedBuffers(LayoutConstraints* constraints); 417 const BufferLayoutConstraint* GetBufferLayoutConstraint( 418 const LogicalBuffer& buffer) const; 419 // Find a bufferset in the bufferset cache. This is useful since we can 420 // currently create the flattened buffer set for the same instruction many 421 // times, which is often slow. 422 PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const; 423 // Similar to above, but returns true only if all buffers associated with that 424 // operand are forwarded. 425 bool AllOperandBuffersForwarded(const HloInstruction* instruction, 426 int64_t operand_no) const; 427 // Returns true if any buffer in the given operand is forwarded to the output 428 // of the given instruction. For example, the Tuple instruction forwards the 429 // buffers of its operands and would return true for each of its operands. 
430 bool AnyOperandBufferForwarded(const HloInstruction* instruction, 431 int64_t operand_no) const; 432 StatusOr<Layout> InferArrayLayout(const HloInstruction* instruction, 433 const ShapeIndex& index); 434 435 // Propagates a buffer layout constraint into the operands that use it. 436 Status PropagateBufferConstraintToUses( 437 const BufferLayoutConstraint& buffer_constraint, 438 LayoutConstraints* constraints); 439 440 // Propagates a layout constraint on the use of the result of the given 441 // instruction to the definitions of the LogicalBuffers which make up the 442 // result. 443 Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, 444 const HloInstruction* instruction, 445 LayoutConstraints* constraints, 446 int64_t priority); 447 448 // Propagates the memory space defined in the entry computation to the called 449 // computations. 450 Status PropagateMemorySpace(HloModule* module); 451 452 // Chooses a layout of operand `operand_no` of `instruction` that minimizes 453 // the cost of `instruction`. `output_layout` is the layout of `instruction`. 454 // Returns null if it can't decide the best layout. 455 // Precondition: `instruction` and the operand are array-shaped. 456 virtual std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout( 457 const Layout& output_layout, const HloInstruction* instruction, 458 int64_t operand_no); 459 // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of 460 // `user` that minimizes its cost on that operand. Returns null if it can't 461 // decide the best layout. 462 // Precondition: `user` and the operand are array-shaped. 463 virtual std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout( 464 const Layout& operand_layout, const HloInstruction* user, 465 int64_t operand_no); 466 467 // Convenient wrapper for InstructionCanChangeLayout which can be overridden 468 // in subclasses. 
469 virtual bool InstructionCanChangeLayoutInstance( 470 const HloInstruction* instruction); 471 472 private: 473 // Initializes the layout assignment object for a new Run() call. 474 Status Init(HloModule* module); 475 476 // Adds constraints which must be satisfied for correctness on all 477 // backends. Called once prior to propagating constraints. 478 Status AddMandatoryConstraints(ChannelLayoutConstraints* channel_constraints, 479 LayoutConstraints* constraints); 480 481 // Return a vector containing the constraints which have been added to the 482 // LayoutConstraints object since the construction of the object or since the 483 // last time ConsumeAddedConstraints() has been called. This is used to 484 // identify newly added constraints when propagating layouts. ConsumeAddedConstraints()485 std::vector<const LayoutConstraint*> ConsumeAddedConstraints() { 486 std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_)); 487 added_constraints_.clear(); 488 return ret_vec; 489 } ClearAddedConstraints()490 void ClearAddedConstraints() { added_constraints_.clear(); } 491 492 // This method can be overridden to add backend-specific constraints to the 493 // layout of the instructions of a computation. This method is called after 494 // all mandatory constraints have been added via AddMandatoryConstraints 495 // and before propagating constraints. AddBackendConstraints(LayoutConstraints * constraints)496 virtual Status AddBackendConstraints(LayoutConstraints* constraints) { 497 return OkStatus(); 498 } 499 500 // Construct constraints and assign layouts to all instructions in the 501 // computation satisfying the given ComputationLayout, if not nullptr. 502 // Otherwise the ComputationLayout will be calculated by propagating the 503 // computation instruction constraints. 504 // Layouts constraints are added, then propagated until all LogicalBuffers in 505 // the computation are constrained. 
506 Status RunOnComputation(LayoutConstraints* constraints, 507 ChannelLayoutConstraints* channel_constraints); 508 509 // Assign layouts to the instructions of a computation which satisfy the given 510 // layout constraints. Copies may be added to satisfy the constraints. The 511 // given LayoutConstraints must have layout constraints every logical buffer 512 // in the computation. 513 Status AssignLayouts(LayoutConstraints& constraints); 514 515 // Propagates layout constraints from a set of initial constraints in order to 516 // minimize the local cost of the computation. This propagation is *not* 517 // required for correctness. 518 Status PropagateConstraints(LayoutConstraints* constraints); 519 520 Status PropagateBufferConstraintToOperands( 521 const BufferLayoutConstraint& buffer_constraint, 522 LayoutConstraints* constraints); 523 524 // Check that all layouts in the module have been set and satisfy all 525 // necessary conditions. 526 Status CheckLayouts( 527 HloModule* module, 528 const absl::flat_hash_set<absl::string_view>& execution_threads); 529 530 // Computes the ComputationLayout of the given constraints based of the 531 // layouts assigned to parameters and root instruction. Also propagate 532 // constraints to computation nested inside. 533 Status CalculateComputationLayout(LayoutConstraints* constraints); 534 535 // Clears all the layouts which can be cleared within a computation. 536 Status ClearComputationLayouts(HloComputation* computation); 537 538 // Clears the side effects of a previous pass, like added copy instructions. 539 Status ClearPreviousPassSideEffects( 540 HloModule* module, 541 const absl::flat_hash_set<absl::string_view>& execution_threads); 542 543 // Propagates the layouts computed by the layout assignment pass on the given 544 // computation, to the computation layout passed in to this API. 
545 // This API propagates missing layout, and also checks that the caller 546 // specified have been respected, by comparing those with the parameters and 547 // root computation instruction. 548 Status PropagateComputationLayouts(HloComputation* computation, 549 ComputationLayout* computation_layout); 550 551 // The pointer to the ComputationLayout passed as constructor parameter. 552 ComputationLayout* entry_computation_layout_; 553 554 // A copy of entry_computation_layout_ used to reset it to the initial values 555 // during the multiple passes done by the layout assignment operation. 556 ComputationLayout saved_entry_computation_layout_; 557 // If set true, reverse the computation traversal order when assigning layout. 558 bool reverse_computation_order_; 559 560 protected: 561 static constexpr int64_t kNumberOfPropagationRounds = 2; 562 // Sets up the copy instruction according to the characteristic (sharding, 563 // metadata, ...) of the reference instruction. The index argument is used 564 // when the instruction is a tuple, and in such case the index represents 565 // the location from where the copy instruction was created from. 566 // If the index is empty, the whole sharding will be propagated, even in case 567 // the instruction has a tuple sharding. 568 static void SetupCopiedInstruction(const HloInstruction& instruction, 569 HloInstruction* copy, 570 const ShapeIndex& index); 571 572 // Creates and returns a copy of the given instruction with a different 573 // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple 574 // instruction producing the copy is returned. 575 StatusOr<HloInstruction*> CreateCopyWithNewLayout( 576 const Shape& shape_with_layout, HloInstruction* instruction); 577 578 // Creates a copy of the given operand if the operand's layout does not match 579 // the given layout. This copy replaces the use in the given instruction. 580 // Tuple operands will be deep-copied. 
581 virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, 582 HloInstruction* instruction, 583 int64_t operand_no); 584 585 // Registers a copy instruction added by the layout assignment pass. RegisterAddedCopy(HloInstruction * copy)586 void RegisterAddedCopy(HloInstruction* copy) { 587 CHECK_EQ(copy->opcode(), HloOpcode::kCopy); 588 added_copies_.insert(copy); 589 } 590 591 // Adds a copy for the operand of an instruction, unless such operand is 592 // already a copy, and has a single user (which is forcibly the instruction 593 // itself). 594 Status AddCopyForOperand(HloInstruction* instruction, int64_t operand_number); 595 596 // Apply the channel layout constraints by populating the channel_constraints 597 // data structure passed in at constructor time. Eventually adds copies in 598 // case two ends of a channel ended up with a different leyout. 599 Status ConstrainChannelLayouts(HloComputation* computation, 600 ChannelLayoutConstraints* channel_constraints); 601 602 // Resets the input ChannelLayoutConstraints to the original copy received 603 // from the constructor input. ResetChannelConstraints()604 void ResetChannelConstraints() { 605 if (channel_layout_constraints_ != nullptr) { 606 *channel_layout_constraints_ = channel_constraints_; 607 } 608 } 609 610 // Adds constraints related to host Send/Recv instructions. 611 Status BuildHostChannelConstraints(HloComputation* computation); 612 613 // Module points to analysis that can be updated for cloned computations. 614 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; 615 616 // The set of HLO instructions which lacked any layout constraint, thus 617 // receiving propagated default layouts. 618 absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_; 619 620 HloPredicate instruction_can_change_layout_func_; 621 622 // CallGraph of the module, used to track callsites of each computation. 
623 std::unique_ptr<CallGraph> call_graph_; 624 625 std::string ToString(const LayoutConstraints& constraints) const; 626 627 private: 628 // Map containing the layouts of all computations assigned so 629 // far. Computations are handled in a topological sort where computations are 630 // handled before their caller instructions so the layouts of caller 631 // instructions can be set to match the computation. 632 absl::flat_hash_map<const HloComputation*, std::unique_ptr<LayoutConstraints>> 633 computation_layouts_; 634 635 // Map from branch computations to the result layout they should apply. 636 absl::flat_hash_map<HloComputation*, ComputationLayout> conditional_mismatch_; 637 638 // Every copy added to the module by the layout assignment pass is registered 639 // here. 640 absl::flat_hash_set<HloInstruction*> added_copies_; 641 642 // The pointer to the channel layout constraints passed in with the 643 // constructor. If not nullptr, this is an input/output argument. 644 ChannelLayoutConstraints* channel_layout_constraints_ = nullptr; 645 646 // A copy of the input layout constraints used to reset the above pointer in 647 // case we have to undo operations due to the multiple passes over the 648 // computations/instructions. 649 ChannelLayoutConstraints channel_constraints_; 650 651 // Layout constraints for send/recv instructions which communicate with the 652 // host. 653 ChannelLayoutConstraints host_channel_constraints_; 654 655 // Array-shaped buffers which have not yet been constrained. 656 std::set<LogicalBuffer::Id> unconstrained_buffer_ids_; 657 658 mutable absl::flat_hash_map<const HloInstruction*, 659 std::unique_ptr<PointsToSet::BufferSet>> 660 buffer_sets_cache_; 661 662 // The set of BufferLayoutConstraints applied to the computation. 663 absl::node_hash_map<const LogicalBuffer*, BufferLayoutConstraint> 664 buffer_constraints_; 665 666 // A vector which holds constraints as they are added. Can be cleared with 667 // ClearAddedConstraints. 
668 std::vector<const LayoutConstraint*> added_constraints_; 669 int64_t current_priority_ = LayoutConstraint::kBeginningPriority; 670 }; 671 672 } // namespace xla 673 674 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 675