/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_

#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

// Abstract base class for layout constraints. These constraint objects are
// gathered together in a LayoutConstraints object.
class LayoutConstraint {
 public:
  LayoutConstraint(bool mandatory, bool dfs, int64_t priority)
      : mandatory_(mandatory), dfs_(dfs), priority_(priority) {}
  virtual ~LayoutConstraint() = default;

  virtual std::string ToString() const = 0;

  // True if this constraint cannot be overwritten by a different constraint.
  bool mandatory() const { return mandatory_; }

  // When true, the constraint is propagated in DFS order; when false, it is
  // propagated in BFS order.
  bool dfs() const { return dfs_; }

  // Returns the priority of the current constraint. When conflicting
  // constraints are encountered, the higher-priority one wins.
  int64_t priority() const { return priority_; }
  bool IsDefaultLayout() const { return priority_ == kDefaultPriority; }

  // The priority of all default layouts when not set explicitly.
  static constexpr int64_t kDefaultPriority = -2;
  // The beginning priority of layout assignment.
  static constexpr int64_t kBeginningPriority = 0;
  // The priority of the layout given by the user for the entry computation.
  static constexpr int64_t kGivenPriority = 3;

 protected:
  bool mandatory_;
  bool dfs_;
  int64_t priority_;
};

std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);

// Layout constraint on a single LogicalBuffer. This constrains the layout of
// an array produced by a particular instruction.
class BufferLayoutConstraint : public LayoutConstraint {
 public:
  BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory, bool dfs, int64_t priority);

  const LogicalBuffer& buffer() const { return *buffer_; }
  const Layout& layout() const { return layout_; }
  bool UpdateLayout(int64_t priority, const Layout& layout, bool mandatory,
                    bool dfs);

  std::string ToString() const override;

 private:
  Layout layout_;
  const LogicalBuffer* buffer_;
};
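
// Illustrative sketch (not part of the API): how a buffer constraint might be
// created and later overridden by a higher-priority constraint, following the
// priority semantics described above. `some_buffer` is an assumed
// LogicalBuffer reference; the layouts are arbitrary minor-to-major orders.
//
//   BufferLayoutConstraint constraint(
//       LayoutUtil::MakeLayout({1, 0}), some_buffer,
//       /*mandatory=*/false, /*dfs=*/true,
//       LayoutConstraint::kBeginningPriority);
//   // A later, higher-priority request (e.g. a user-given entry layout) may
//   // replace the layout held by the constraint.
//   constraint.UpdateLayout(LayoutConstraint::kGivenPriority,
//                           LayoutUtil::MakeLayout({0, 1}),
//                           /*mandatory=*/true, /*dfs=*/true);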

// Constraint on the layout of the operand of an instruction. The constrained
// shape can be arbitrarily shaped (array or tuple). This is a constraint on
// the use of a shaped value and is not a hard constraint on the instruction(s)
// which define the value, as copies may be inserted between the definition and
// use.
class OperandLayoutConstraint : public LayoutConstraint {
 public:
  OperandLayoutConstraint(const ShapeLayout& shape_layout,
                          const HloInstruction* instruction, int64_t operand_no,
                          bool mandatory, bool dfs, int64_t priority);

  const ShapeLayout& shape_layout() const { return shape_layout_; }
  const HloInstruction* instruction() const { return instruction_; }
  int64_t operand_no() const { return operand_no_; }
  const HloInstruction* operand() const {
    return instruction_->operand(operand_no_);
  }

  std::string ToString() const override;

 private:
  ShapeLayout shape_layout_;
  const HloInstruction* instruction_;
  int64_t operand_no_;
};

// Constraint on the layout of a computation interface.
class ComputationLayoutConstraint : public LayoutConstraint {
 public:
  static constexpr int64_t kDefaultLayoutIsUsed = 0;
  static constexpr int64_t kResultLayoutIsSet = 1;
  static constexpr int64_t kParameterLayoutIsSet = 2;
  static constexpr int64_t kComputationLayoutIsSet = 3;
  explicit ComputationLayoutConstraint(const HloComputation* computation,
                                       ComputationLayout* computation_layout,
                                       int64_t priority)
      : LayoutConstraint(/*mandatory=*/true, /*dfs=*/true, priority),
        layout_state_((computation_layout == nullptr)
                          ? kDefaultLayoutIsUsed
                          : kComputationLayoutIsSet),
        computation_layout_(
            (computation_layout == nullptr)
                ? ComputationLayout(computation->ComputeProgramShape(),
                                    /*ignore_layouts=*/false)
                : *computation_layout) {}

  const ComputationLayout& computation_layout() const {
    return computation_layout_;
  }
  void ResetComputationLayout(const ComputationLayout& layout, int64_t priority,
                              bool prop_result_layout,
                              bool prop_parameter_layout) {
    computation_layout_ = layout;
    priority_ = priority;
    if (prop_result_layout) {
      layout_state_ |= kResultLayoutIsSet;
    }
    if (prop_parameter_layout) {
      layout_state_ |= kParameterLayoutIsSet;
    }
  }
  void ResetResultLayout(const ShapeLayout& shape_layout, int64_t priority) {
    *computation_layout_.mutable_result_layout() = shape_layout;
    layout_state_ |= kResultLayoutIsSet;
    priority_ = priority;
  }
  bool parameter_layout_is_set() const {
    return layout_state_ & kParameterLayoutIsSet;
  }
  bool result_layout_is_set() const {
    return layout_state_ & kResultLayoutIsSet;
  }
  bool default_layout_is_used() const {
    return layout_state_ == kDefaultLayoutIsUsed;
  }
  std::string ToString() const override;

 private:
  // The layout_state_ variable remembers whether the layout for the overall
  // computation is explicitly set, whether only its result layout is
  // explicitly set, or whether it only stores the default layout of the
  // computation.
  int64_t layout_state_;
  ComputationLayout computation_layout_;
};

// Contains constraints on the layouts of channels (sends and recvs).
class ChannelLayoutConstraints {
 public:
  // Construct an empty constraint set.
  ChannelLayoutConstraints() {}

  // Returns true if channel_id has a layout constraint.
  bool IsChannelConstrained(int64_t channel_id) const {
    return constraints_.contains(channel_id);
  }

  // Given `shape`, applies the layout for `channel_id`. `channel_id` must
  // already be constrained.
  Shape LayoutShapeForChannel(Shape shape, int64_t channel_id) const {
    auto it = constraints_.find(channel_id);
    CHECK(it != constraints_.end()) << "Channel " << channel_id;
    *shape.mutable_layout() = it->second;
    return shape;
  }

  // Returns the layout constraint for `channel_id`, which must already be
  // constrained.
  const Layout& LayoutForChannel(int64_t channel_id) const {
    auto it = constraints_.find(channel_id);
    CHECK(it != constraints_.end()) << "Channel " << channel_id;
    return it->second;
  }

  // Adds a new layout constraint for `channel_id`. Returns nullptr if the
  // constraint was added (or if an identical layout was already registered);
  // otherwise returns the conflicting layout which has already been set for
  // the channel.
  const Layout* ConstrainChannel(int64_t channel_id, const Layout& layout) {
    auto it = constraints_.emplace(std::make_pair(channel_id, layout));
    if (it.second) {
      return nullptr;
    }
    return LayoutUtil::Equal(layout, it.first->second) ? nullptr
                                                       : &it.first->second;
  }

 private:
  absl::flat_hash_map<int64_t, Layout> constraints_;
};
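
// Illustrative sketch (not part of the API): constraining a channel and
// applying its layout to a shape. `kChannelId` and the shape are assumptions
// made up for the example.
//
//   ChannelLayoutConstraints channel_constraints;
//   const int64_t kChannelId = 1;
//   Layout layout = LayoutUtil::MakeLayout({0, 1});
//   CHECK(channel_constraints.ConstrainChannel(kChannelId, layout) == nullptr);
//   // A second, identical constraint is accepted; a conflicting one would
//   // return the previously registered layout instead of nullptr.
//   CHECK(channel_constraints.ConstrainChannel(kChannelId, layout) == nullptr);
//   Shape shape_with_layout = channel_constraints.LayoutShapeForChannel(
//       ShapeUtil::MakeShape(F32, {16, 8}), kChannelId);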

// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
class LayoutAssignment : public HloModulePass {
 public:
  // entry_computation_layout is modified to populate a layout for the result
  // in the case that no particular layout is requested.
  //
  // channel_constraints is both an input and output. Any sends or recvs that
  // are present in channel_constraints will be laid out as constrained. Any
  // unconstrained sends or recvs will be laid out as locally optimal and their
  // layout will be added as a constraint to channel_constraints.
  //
  // If channel_constraints is nullptr, the modules passed to `Run` must not
  // contain any kSend or kRecv instructions.
  explicit LayoutAssignment(
      ComputationLayout* entry_computation_layout,
      ChannelLayoutConstraints* channel_constraints = nullptr,
      bool reverse_computation_order = false);
  ~LayoutAssignment() override {}
  const TuplePointsToAnalysis& points_to_analysis() const {
    return *points_to_analysis_;
  }
  absl::string_view name() const override { return "layout-assignment"; }

  // Assigns layouts to the given module. Returns whether the module was
  // changed (any layouts were changed).
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
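
  // Illustrative sketch (not part of the API): constructing the pass with the
  // module's entry computation layout and running it. `module` is an assumed
  // HloModule*; error handling is elided.
  //
  //   LayoutAssignment layout_assignment(
  //       module->mutable_entry_computation_layout());
  //   StatusOr<bool> changed = layout_assignment.Run(module);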

  // Class encapsulating the layout constraints of the values in an HLO
  // computation.
  class LayoutConstraints {
   public:
    explicit LayoutConstraints(HloComputation* computation,
                               ComputationLayout* computation_layout,
                               int64_t priority);
    ~LayoutConstraints() = default;

    const HloComputation* computation() const { return computation_; }
    HloComputation* computation() { return computation_; }
    void ResetOperandConstraints() { operand_constraints_.clear(); }
    const ShapeLayout* OperandLayout(const HloInstruction* instruction,
                                     int64_t operand_no) const;
    const OperandLayoutConstraint* GetOperandLayoutConstraint(
        const HloInstruction* instruction, int64_t operand_no) const;
    const ShapeLayout* ResultLayout() const;
    OperandLayoutConstraint* InsertOperandLayoutConstraint(
        const HloInstruction* instruction, int64_t operand_no,
        const OperandLayoutConstraint& constraint);
    Status SetResultLayout(LayoutAssignment* assignment,
                           const Shape& shape_with_layout, int64_t priority);

    const ComputationLayout& computation_layout() const {
      return computation_constraint_.computation_layout();
    }
    const ComputationLayoutConstraint& computation_constraint() const {
      return computation_constraint_;
    }
    ComputationLayoutConstraint* mutable_computation_constraint() {
      return &computation_constraint_;
    }

   private:
    // The set of OperandLayoutConstraints applied to the computation.
    using OperandConstraintKey = std::pair<const HloInstruction*, int64_t>;
    std::map<OperandConstraintKey, OperandLayoutConstraint>
        operand_constraints_;

    HloComputation* computation_;
    ComputationLayoutConstraint computation_constraint_;
  };

  // Determines whether an instruction can change layouts. An instruction not
  // being able to change layout means that it requires operands with the same
  // rank as the output to have the same layout as the output.
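  // For example, an elementwise kAdd cannot change layouts: if its output is
  // assigned layout {0, 1}, each of its (same-rank) operands must also use
  // layout {0, 1}, and the pass inserts copies where an operand is constrained
  // otherwise.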
  static bool InstructionCanChangeLayout(const HloInstruction* instruction);

  LayoutConstraints* mutable_computation_constraints(
      HloComputation* computation) {
    auto it = computation_layouts_.find(computation);
    LayoutConstraints* constraints = nullptr;
    if (it == computation_layouts_.end()) {
      computation_layouts_.emplace(
          computation,
          constraints = new LayoutConstraints(
              computation, nullptr, LayoutConstraint::kDefaultPriority));
    } else {
      constraints = (*it).second.get();
    }
    return constraints;
  }
  void PushAddedConstraints(const LayoutConstraint* constraint);

  // In case of an array shape returns true iff it is at most rank 1. In case
  // of a tuple shape returns true iff all leaf shapes are at most rank 1.
  static bool IsAtMostRank1(const Shape& shape);
  // Convenience wrapper around SetOperandLayout for setting the layout of an
  // operand using a Layout object. The operand must be array-shaped.
  Status SetArrayOperandLayout(const Layout& layout,
                               const HloInstruction* instruction,
                               int64_t operand_no, bool mandatory = true,
                               bool dfs = true) {
    return SetArrayOperandLayout(layout, instruction, operand_no, mandatory,
                                 dfs, current_priority_);
  }
  Status SetArrayOperandLayout(const Layout& layout,
                               const HloInstruction* instruction,
                               int64_t operand_no, bool mandatory, bool dfs,
                               int64_t priority);
  // Convenience wrapper around SetBufferLayout. Sets the layouts of all
  // buffers created by the instruction to the layouts in the given shape. The
  // instruction must define every logical buffer in its output.
  // If `allow_alias` is false, the function will check that all output buffers
  // are defined by `instruction`, not aliased to an instruction elsewhere.
  Status SetInstructionLayout(const Shape& shape_with_layout,
                              const HloInstruction* instruction,
                              bool mandatory = true, bool dfs = true,
                              bool allow_alias = false) {
    return SetInstructionLayout(shape_with_layout, instruction, mandatory, dfs,
                                allow_alias, current_priority_);
  }
  Status SetInstructionLayout(const Shape& shape_with_layout,
                              const HloInstruction* instruction, bool mandatory,
                              bool dfs, bool allow_alias, int64_t priority);
  // Sets the same given layout across all components of the instruction's
  // output. It works the same as the API above if the output is a single
  // array.
  Status SetInstructionLayout(const Layout& layout,
                              const HloInstruction* instruction,
                              bool mandatory = true, bool dfs = true,
                              bool allow_alias = false, int64_t priority = -1);
  // Adds a constraint on the layout of a LogicalBuffer, the layout of the
  // operand of the instruction, or the layout of the result of the
  // computation, respectively.
  Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory = true, bool dfs = true) {
    return SetBufferLayout(layout, buffer, mandatory, dfs, current_priority_);
  }
  Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory, bool dfs, int64_t priority);
  Status SetOperandLayout(const Shape& shape_with_layout,
                          const HloInstruction* instruction, int64_t operand_no,
                          bool mandatory = true, bool dfs = true) {
    return SetOperandLayout(shape_with_layout, instruction, operand_no,
                            mandatory, dfs, current_priority_);
  }
  Status SetOperandLayout(const Shape& shape_with_layout,
                          const HloInstruction* instruction, int64_t operand_no,
                          bool mandatory, bool dfs, int64_t priority);
  bool reverse_computation_order() const { return reverse_computation_order_; }

  ComputationLayout& saved_entry_computation_layout() {
    return saved_entry_computation_layout_;
  }

 protected:
  // These methods, invoked by PropagateConstraints, propagate a layout
  // constraint to its neighbors (i.e. operands and users) in order to minimize
  // the cost of the instructions being constrained on. New constraints are
  // added to the given constraint set.
  //
  // Backends can override these methods with backend-specific propagation
  // rules.
  virtual Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints);
  virtual Status PropagateOperandConstraint(
      const OperandLayoutConstraint& operand_constraint,
      LayoutConstraints* constraints);
  virtual Status PropagateResultConstraint(
      const ComputationLayoutConstraint& layout_constraint,
      LayoutConstraints* constraints);

  virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) {
    return LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
  }
  // Called after layouts of an instruction have been finalized to allow
  // subclasses to check for platform-specific assumptions.
  virtual Status Verify(const HloInstruction* instruction) {
    return OkStatus();
  }

  Status PropagateUnconstraintedBuffers(LayoutConstraints* constraints);
  const BufferLayoutConstraint* GetBufferLayoutConstraint(
      const LogicalBuffer& buffer) const;
  // Finds a buffer set in the buffer-set cache. This is useful since we can
  // currently create the flattened buffer set for the same instruction many
  // times, which is often slow.
  PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const;
  // Similar to AnyOperandBufferForwarded below, but returns true only if all
  // buffers associated with that operand are forwarded.
  bool AllOperandBuffersForwarded(const HloInstruction* instruction,
                                  int64_t operand_no) const;
  // Returns true if any buffer in the given operand is forwarded to the output
  // of the given instruction. For example, the Tuple instruction forwards the
  // buffers of its operands and would return true for each of its operands.
  bool AnyOperandBufferForwarded(const HloInstruction* instruction,
                                 int64_t operand_no) const;
  StatusOr<Layout> InferArrayLayout(const HloInstruction* instruction,
                                    const ShapeIndex& index);

  // Propagates a buffer layout constraint into the operands that use it.
  Status PropagateBufferConstraintToUses(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints);

  // Propagates a layout constraint on the use of the result of the given
  // instruction to the definitions of the LogicalBuffers which make up the
  // result.
  Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout,
                                      const HloInstruction* instruction,
                                      LayoutConstraints* constraints,
                                      int64_t priority);

  // Propagates the memory space defined in the entry computation to the called
  // computations.
  Status PropagateMemorySpace(HloModule* module);

  // Chooses a layout of operand `operand_no` of `instruction` that minimizes
  // the cost of `instruction`. `output_layout` is the layout of `instruction`.
  // Returns null if it can't decide the best layout.
  // Precondition: `instruction` and the operand are array-shaped.
  virtual std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout(
      const Layout& output_layout, const HloInstruction* instruction,
      int64_t operand_no);
  // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of
  // `user` that minimizes its cost on that operand.  Returns null if it can't
  // decide the best layout.
  // Precondition: `user` and the operand are array-shaped.
  virtual std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout(
      const Layout& operand_layout, const HloInstruction* user,
      int64_t operand_no);

  // Convenience wrapper for InstructionCanChangeLayout which can be overridden
  // in subclasses.
  virtual bool InstructionCanChangeLayoutInstance(
      const HloInstruction* instruction);

 private:
  // Initializes the layout assignment object for a new Run() call.
  Status Init(HloModule* module);

  // Adds constraints which must be satisfied for correctness on all
  // backends. Called once prior to propagating constraints.
  Status AddMandatoryConstraints(ChannelLayoutConstraints* channel_constraints,
                                 LayoutConstraints* constraints);

  // Returns a vector containing the constraints which have been added to the
  // LayoutConstraints object since the construction of the object or since the
  // last time ConsumeAddedConstraints() was called. This is used to identify
  // newly added constraints when propagating layouts.
  std::vector<const LayoutConstraint*> ConsumeAddedConstraints() {
    std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_));
    added_constraints_.clear();
    return ret_vec;
  }
  void ClearAddedConstraints() { added_constraints_.clear(); }

  // This method can be overridden to add backend-specific constraints to the
  // layout of the instructions of a computation. This method is called after
  // all mandatory constraints have been added via AddMandatoryConstraints
  // and before propagating constraints.
  virtual Status AddBackendConstraints(LayoutConstraints* constraints) {
    return OkStatus();
  }
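
  // Illustrative sketch (not part of the API): a backend-specific subclass
  // typically overrides AddBackendConstraints and uses the Set*Layout helpers
  // declared above. `MyBackendLayoutAssignment` and the kDot special case are
  // assumptions made up for the example.
  //
  //   class MyBackendLayoutAssignment : public LayoutAssignment {
  //    public:
  //     using LayoutAssignment::LayoutAssignment;
  //
  //    private:
  //     Status AddBackendConstraints(LayoutConstraints* constraints) override {
  //       for (HloInstruction* instruction :
  //            constraints->computation()->instructions()) {
  //         if (instruction->opcode() == HloOpcode::kDot) {
  //           // Request a descending (row-major) layout on operand 0.
  //           TF_RETURN_IF_ERROR(SetOperandLayout(
  //               ShapeUtil::MakeShapeWithDescendingLayout(
  //                   instruction->operand(0)->shape().element_type(),
  //                   instruction->operand(0)->shape().dimensions()),
  //               instruction, /*operand_no=*/0));
  //         }
  //       }
  //       return OkStatus();
  //     }
  //   };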

  // Constructs constraints and assigns layouts to all instructions in the
  // computation satisfying the given ComputationLayout, if not nullptr.
  // Otherwise the ComputationLayout will be calculated by propagating the
  // computation instruction constraints.
  // Layout constraints are added, then propagated until all LogicalBuffers in
  // the computation are constrained.
  Status RunOnComputation(LayoutConstraints* constraints,
                          ChannelLayoutConstraints* channel_constraints);

  // Assigns layouts to the instructions of a computation which satisfy the
  // given layout constraints. Copies may be added to satisfy the constraints.
  // The given LayoutConstraints must have layout constraints for every logical
  // buffer in the computation.
  Status AssignLayouts(LayoutConstraints& constraints);

  // Propagates layout constraints from a set of initial constraints in order
  // to minimize the local cost of the computation. This propagation is *not*
  // required for correctness.
  Status PropagateConstraints(LayoutConstraints* constraints);

  Status PropagateBufferConstraintToOperands(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints);

  // Checks that all layouts in the module have been set and satisfy all
  // necessary conditions.
  Status CheckLayouts(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Computes the ComputationLayout of the given constraints based on the
  // layouts assigned to parameters and the root instruction. Also propagates
  // constraints to computations nested inside.
  Status CalculateComputationLayout(LayoutConstraints* constraints);

  // Clears all the layouts which can be cleared within a computation.
  Status ClearComputationLayouts(HloComputation* computation);

  // Clears the side effects of a previous pass, like added copy instructions.
  Status ClearPreviousPassSideEffects(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Propagates the layouts computed by the layout assignment pass on the given
  // computation to the computation layout passed in to this API.
  // This API propagates missing layouts and also checks that the layouts the
  // caller specified have been respected, by comparing them with the
  // parameters and the root instruction of the computation.
  Status PropagateComputationLayouts(HloComputation* computation,
                                     ComputationLayout* computation_layout);

  // The pointer to the ComputationLayout passed as constructor parameter.
  ComputationLayout* entry_computation_layout_;

  // A copy of entry_computation_layout_ used to reset it to the initial values
  // during the multiple passes done by the layout assignment operation.
  ComputationLayout saved_entry_computation_layout_;
  // If true, reverse the computation traversal order when assigning layouts.
  bool reverse_computation_order_;

 protected:
  static constexpr int64_t kNumberOfPropagationRounds = 2;
  // Sets up the copy instruction according to the characteristics (sharding,
  // metadata, ...) of the reference instruction. The index argument is used
  // when the instruction is a tuple; in that case the index represents the
  // location from which the copy instruction was created.
  // If the index is empty, the whole sharding will be propagated, even if the
  // instruction has a tuple sharding.
  static void SetupCopiedInstruction(const HloInstruction& instruction,
                                     HloInstruction* copy,
                                     const ShapeIndex& index);

  // Creates and returns a copy of the given instruction with a different
  // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
  // instruction producing the copy is returned.
  StatusOr<HloInstruction*> CreateCopyWithNewLayout(
      const Shape& shape_with_layout, HloInstruction* instruction);

  // Creates a copy of the given operand if the operand's layout does not match
  // the given layout. This copy replaces the use in the given instruction.
  // Tuple operands will be deep-copied.
  virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
                                            HloInstruction* instruction,
                                            int64_t operand_no);

  // Registers a copy instruction added by the layout assignment pass.
  void RegisterAddedCopy(HloInstruction* copy) {
    CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
    added_copies_.insert(copy);
  }

  // Adds a copy for the operand of an instruction, unless such operand is
  // already a copy and has a single user (which is necessarily the instruction
  // itself).
  Status AddCopyForOperand(HloInstruction* instruction, int64_t operand_number);

  // Applies the channel layout constraints by populating the
  // channel_constraints data structure passed in at constructor time. If
  // necessary, adds copies in case the two ends of a channel ended up with
  // different layouts.
  Status ConstrainChannelLayouts(HloComputation* computation,
                                 ChannelLayoutConstraints* channel_constraints);

  // Resets the input ChannelLayoutConstraints to the original copy received
  // from the constructor input.
  void ResetChannelConstraints() {
    if (channel_layout_constraints_ != nullptr) {
      *channel_layout_constraints_ = channel_constraints_;
    }
  }

  // Adds constraints related to host Send/Recv instructions.
  Status BuildHostChannelConstraints(HloComputation* computation);

  // Module points-to analysis that can be updated for cloned computations.
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // The set of HLO instructions which lacked any layout constraint, thus
  // receiving propagated default layouts.
  absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;

  HloPredicate instruction_can_change_layout_func_;

  // CallGraph of the module, used to track callsites of each computation.
  std::unique_ptr<CallGraph> call_graph_;

  std::string ToString(const LayoutConstraints& constraints) const;

 private:
  // Map containing the layouts of all computations assigned so far.
  // Computations are processed in a topological order in which a computation
  // is handled before its caller instructions, so the layouts of caller
  // instructions can be set to match the computation.
  absl::flat_hash_map<const HloComputation*, std::unique_ptr<LayoutConstraints>>
      computation_layouts_;

  // Map from branch computations to the result layout they should apply.
  absl::flat_hash_map<HloComputation*, ComputationLayout> conditional_mismatch_;

  // Every copy added to the module by the layout assignment pass is registered
  // here.
  absl::flat_hash_set<HloInstruction*> added_copies_;

  // The pointer to the channel layout constraints passed in with the
  // constructor. If not nullptr, this is an input/output argument.
  ChannelLayoutConstraints* channel_layout_constraints_ = nullptr;

  // A copy of the input layout constraints used to reset the above pointer in
  // case we have to undo operations due to the multiple passes over the
  // computations/instructions.
  ChannelLayoutConstraints channel_constraints_;

  // Layout constraints for send/recv instructions which communicate with the
  // host.
  ChannelLayoutConstraints host_channel_constraints_;

  // Array-shaped buffers which have not yet been constrained.
  std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;

  mutable absl::flat_hash_map<const HloInstruction*,
                              std::unique_ptr<PointsToSet::BufferSet>>
      buffer_sets_cache_;

  // The set of BufferLayoutConstraints applied to the computation.
  absl::node_hash_map<const LogicalBuffer*, BufferLayoutConstraint>
      buffer_constraints_;

  // A vector which holds constraints as they are added. Can be cleared with
  // ClearAddedConstraints.
  std::vector<const LayoutConstraint*> added_constraints_;
  int64_t current_priority_ = LayoutConstraint::kBeginningPriority;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_