/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_

#include <memory>
#include <optional>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

struct SpmdPartitionerOptions {
  // Always exchange halo on the LHS for all convolutions. If false, backprop
  // filter convolutions exchange halo on the RHS.
  bool conv_halo_exchange_always_on_lhs = true;

  // The number of instructions to report in the highest-memory-usage profile.
  int64_t report_instruction_count = 5;

  // The minimum size in MiB of an einsum operand for it to be considered for
  // a windowed implementation in an HLO loop.
  int64_t threshold_for_windowed_einsum_mib = 256;

  // Whether to unroll the windowed einsum loop by a factor of two.
  bool unroll_windowed_einsum = false;

  // Whether to use bidirectional collective permutes in the windowed einsum
  // loop.
  bool bidirectional_windowed_einsum = false;

  // Whether the entry computation's signature may change after partitioning.
  bool allow_module_signature_change = false;

  // Whether to use a cached all-gather to avoid repeatedly replicating a
  // tiled tensor. If set to false, the result tends to be more
  // memory-efficient, and the compiler can run the ScheduleAwareAllGatherCSE
  // pass to CSE all-gathers that are relatively close to each other.
  bool cache_all_gather = true;

  // When making a compromise between windowed einsum speed and memory usage,
  // prefer the former if true.
  bool choose_faster_windowed_einsum_over_mem = false;

  // Whether to use bidirectional communication when decomposing independent
  // all-gathers.
  bool bidirectional_decomposed_all_gather = false;
};
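
// A minimal sketch of configuring these options (the field values below are
// arbitrary examples, not recommendations):
//
//   SpmdPartitionerOptions options;
//   options.threshold_for_windowed_einsum_mib = 1024;
//   options.unroll_windowed_einsum = true;
//   options.allow_module_signature_change = true;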

// Class to wrap the computation builder to capture information during SPMD
// transformation.
class SpmdBuilder : public HloComputation::Builder {
 public:
  SpmdBuilder(const std::string& name, HloInstruction* hlo)
      : HloComputation::Builder(name) {
    visiting_hlo_ = hlo;
  }

  HloInstruction* AddInstruction(
      std::unique_ptr<HloInstruction> instruction) override;

  const std::vector<HloInstruction*>& derived_instructions(
      HloInstruction* hlo) {
    return instructions_.at(hlo);
  }

  void set_visiting_hlo(HloInstruction* hlo) {
    visiting_hlo_ = hlo;
    // Ensure an (initially empty) entry exists for this instruction.
    instructions_[hlo];
  }

  HloInstruction* visiting_hlo() const { return visiting_hlo_; }

  // Wrapper of queries to broadcast_dims_.
  std::optional<const absl::flat_hash_set<int64_t>*> BroadcastDimsForCreatedHlo(
      const HloInstruction* hlo) {
    auto it = broadcast_dims_.find(hlo);
    if (it == broadcast_dims_.end()) {
      return std::nullopt;
    }
    return &it->second;
  }

 private:
  // The instruction currently being visited.
  HloInstruction* visiting_hlo_;

  // Map from the (old) instruction currently being visited to the new
  // instructions created for it during SPMD partitioning.
  HloInstructionMap<std::vector<HloInstruction*>> instructions_;

  // Maps from each created instruction to a set of dimensions that come from
  // broadcasts or elementwise ops over broadcasts. Elements along these
  // dimensions have the same value.
  absl::flat_hash_map<const HloInstruction*, absl::flat_hash_set<int64_t>>
      broadcast_dims_;
};
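
// A hypothetical sketch of how the partitioner drives this builder: mark the
// original instruction being visited, then create the new per-shard
// instructions, which are recorded against it (names are illustrative):
//
//   SpmdBuilder b("spmd_computation", /*hlo=*/original_add);
//   b.set_visiting_hlo(original_add);
//   HloInstruction* local_add = b.AddInstruction(HloInstruction::CreateBinary(
//       shard_shape, HloOpcode::kAdd, lhs_shard, rhs_shard));
//   // local_add now appears in b.derived_instructions(original_add).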

// A set of functions that create the cross-partition collective ops.
struct SPMDCollectiveOpsCreator {
  // Function used to create a partition ID HLO.
  std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;

  // Function used to create a cross-partition all-reduce HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
      const std::vector<std::vector<int64_t>>& partition_subgroups,
      int64_t channel_id)>
      create_cross_partition_all_reduce;

  // Function used to create a cross-partition collective-permute HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand,
      std::vector<std::pair<int64_t, int64_t>>& src_dst_pairs,
      int64_t next_channel_id)>
      create_cross_partition_collective_permute;

  // Function used to create a cross-partition all-to-all HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, absl::Span<HloInstruction* const> operands,
      const std::vector<std::vector<int64_t>>& partition_subgroups,
      int64_t channel_id, std::optional<int64_t> split_dimension)>
      create_cross_partition_all_to_all;

  // Function used to create a cross-partition all-gather HLO. This is
  // optional: if it is nullptr, the partitioner will use all-reduce instead.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
      const std::vector<std::vector<int64_t>>& partition_subgroups,
      int64_t channel_id, int64_t all_gather_dimension)>
      create_cross_partition_all_gather;
};

// Creates a default SPMDCollectiveOpsCreator.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
                                                        int64_t num_replicas);
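
// A hypothetical sketch of customizing the collectives (illustrative only):
// start from the default creator and clear the dedicated all-gather so the
// partitioner falls back to all-reduce:
//
//   SPMDCollectiveOpsCreator creator =
//       GetDefaultCollectiveOpsCreator(/*num_partitions=*/8,
//                                      /*num_replicas=*/1);
//   creator.create_cross_partition_all_gather = nullptr;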

// Logger to report memory usage during SPMD partitioning.
class SpmdLogger {
 public:
  SpmdLogger(int64_t report_instruction_count, bool disabled)
      : report_instruction_count_(report_instruction_count),
        disabled_(disabled) {}
  static std::string ReportBeforePartition(const HloModule& module,
                                           int64_t report_instruction_count);
  static std::string ReportAfterPartition(const HloModule& module,
                                          int64_t report_instruction_count);

  // Registers the logging for the group of instructions created to transform
  // the given hlo.
  void RegisterLogEntry(HloInstruction* hlo,
                        const std::vector<HloInstruction*>& group);

  std::string MakeReport();

 private:
  template <typename F>
  static std::string ReportMemoryUsage(const HloModule& module, const F& filter,
                                       int64_t report_instruction_count);

  // A vector of logging messages (one for each original HLO instruction),
  // where the first element of each pair is the size of the HBM used.
  std::vector<std::pair<int64_t, std::string>> entries_;

  int64_t report_instruction_count_;

  // Note that we allow creating a *disabled* logger when logging is not
  // enabled, in which case it is supposed to avoid doing any potentially
  // expensive work. The logger is still created in this case and passed to
  // the users to help avoid changing current call sites.
  const bool disabled_;
};
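
// A minimal sketch of the intended logging flow (assumed from this interface;
// the surrounding setup is hypothetical):
//
//   SpmdLogger logger(/*report_instruction_count=*/5, /*disabled=*/false);
//   // ... during partitioning, call logger.RegisterLogEntry(original_hlo,
//   // new_instructions) for each transformed HLO ...
//   LOG(INFO) << logger.MakeReport();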

class SpmdPartitioningVisitor;

class SpmdPartitioner : public HloModulePass {
 public:
  SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
                  SpmdPartitionerOptions options);
  SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
                  SpmdPartitionerOptions options,
                  SPMDCollectiveOpsCreator collective_ops_creator)
      : num_partitions_(num_partitions),
        num_replicas_(num_replicas),
        options_(std::move(options)),
        collective_ops_creator_(std::move(collective_ops_creator)) {}
  absl::string_view name() const override { return "spmd-partitioning"; }
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

  // Transforms the given computation with SPMD instructions, replacing it with
  // a new computation.
  StatusOr<bool> PartitionComputation(HloComputation* computation,
                                      const HloSharding& root_sharding,
                                      int64_t* next_channel_id,
                                      SpmdLogger* logger);

  // Creates all-gather(s) based on the HloSharding. Can be overridden to
  // customize. The default uses a single all-gather even if there are multiple
  // sharded dimensions, and adds potential reshapes and transposes to achieve
  // that. If it returns nullptr, the partitioner will fall back to all-reduce.
  // `selected_dims` specifies the dimensions along which the all-gather
  // happens in the tiled sharding, which allows potentially creating a
  // subgroup all-gather.
  virtual HloInstruction* AllGatherShards(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator);

  // Creates all-reduce(s) across devices along `selected_dims` in `sharding`.
  // Can be overridden to customize.
  virtual HloInstruction* AllReduceAlongShardingDims(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction);

  const SpmdPartitionerOptions& options() { return options_; }

 protected:
  virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
      HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64_t* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options);

  HloInstruction* AllGatherShardsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
  HloInstruction* AllReduceAlongShardingDimsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction, bool per_dim_ar);

  // Verifies that the shardings of instructions in the module are valid, and
  // also fills in missing sharding information.
  virtual Status PreprocessSharding(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Returns whether the given side-effecting instruction is allowed to have
  // replicated sharding.
  virtual bool CanSideEffectingHaveReplicatedSharding(
      const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kInfeed ||
           hlo->opcode() == HloOpcode::kOutfeed;
  }

  // Preprocesses the graph to simplify some communication patterns. E.g.,
  // merges pad->slice into a single pad with potentially negative padding to
  // avoid multiple halo exchanges.
  Status PreprocessHlos(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  const int64_t num_partitions_;
  const int64_t num_replicas_;

  SpmdPartitionerOptions options_;
  SPMDCollectiveOpsCreator collective_ops_creator_;
  std::vector<std::vector<int64_t>> device_groups_;
};
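
// A hypothetical sketch of running the pass (pipeline setup is illustrative;
// `module` is assumed to already carry sharding annotations):
//
//   SpmdPartitionerOptions options;
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               options);
//   TF_ASSIGN_OR_RETURN(bool changed, partitioner.Run(module));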

// Class that describes the partition state of data represented by an HLO
// created during the SPMD partitioning pass.
//
// Data on some devices may include a padding region, if the base (full) shape
// could not be evenly partitioned.
class PartitionedHlo {
 public:
  // Return value for ReshardAsWindowedInput, which describes the resharded
  // HLO, the window for the user on the shard, and if necessary, the dynamic
  // slice offsets to be applied to the output of the op being sharded.
  struct WindowedInputShardReturnValue {
    HloInstruction* sharded_input;
    Window shard_window;
    std::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
  };
  // A cache for resharding each partitioned HLO.
  struct ReshardCache {
    struct PerHloCache {
      absl::flat_hash_map<HloSharding, PartitionedHlo> reshard_cache;
      std::vector<
          std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
          window_reshard_cache;
    };
    // Use absl::node_hash_map for pointer stability.
    absl::node_hash_map<HloInstruction*, PerHloCache> per_hlo_cache;
    // Caches for nested partitioning of grouped sharding. Each string key
    // represents a unique way of grouping devices.
    absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
        groupd_caches;
  };
  struct PartitioningState {
    SpmdBuilder* b;
    HloModule* module;
    int64_t num_replicas;
    HloInstruction* partition_id;
    SPMDCollectiveOpsCreator collective_ops_creator;
    int64_t* next_channel_id;
    ReshardCache* reshard_cache;
    SpmdPartitioner* partitioner;
  };
  PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state)
      : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
    CHECK(hlo->has_sharding())
        << "PartitionedHlo is missing sharding: " << hlo->ToString();
    // If a tuple-shaped instruction does not have a tuple sharding, reassign
    // it to use the tuple sharding. The Reshard() implementation assumes this.
    if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
      hlo_->set_sharding(
          hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
    }
  }

  // Reshards the current SPMD instruction to a new sharding, with an optional
  // pad value used during resharding. Only modifies the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target,
                         std::optional<Literal> pad_value = std::nullopt);

  // Pads the garbage area of the output with the provided value. Normally,
  // unevenly partitioned dimensions are padded on the right, but this function
  // allows specifying left-padded dimensions, which can be used during the
  // handling of kReverse, etc.
  PartitionedHlo PadWithValue(
      HloInstruction* pad_value,
      absl::Span<const int64_t> left_padded_dims = {},
      absl::Span<const int64_t> skipped_dims = {}) const;

  // Same as PadWithValue but does not create a new PartitionedHlo.
  HloInstruction* PadWithValueHlo(
      HloInstruction* pad_value,
      absl::Span<const int64_t> left_padded_dims = {},
      absl::Span<const int64_t> skipped_dims = {}) const;

  PartitionedHlo PadWithZero(absl::Span<const int64_t> left_padded_dims = {},
                             absl::Span<const int64_t> skipped_dims = {}) const;

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }

  // Returns the sharding of the SPMD instruction.
  const HloSharding& sharding() const { return hlo_->sharding(); }

  // Original full shape of the data.
  const Shape& base_shape() const { return base_shape_; }

  int64_t NewChannel() const { return (*state_.next_channel_id)++; }

  // Reshards the HLO to a usable partitioned input for a windowed user. Only
  // modifies the reshard cache.
  std::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

  const PartitioningState& state() const { return state_; }

  // Helper function to replicate the data on all devices. Only modifies the
  // reshard cache.
  PartitionedHlo Replicate();

  // Helper function to replicate the data for partitions along the given dims.
  HloInstruction* ReplicatePartial(absl::Span<const int64_t> dims);

  // Sets the state of the partitioned HLO.
  void set_state(PartitioningState state) { state_ = std::move(state); }

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it may modify it indirectly by calling Replicate().
  PartitionedHlo ReshardNoCache(const HloSharding& target,
                                std::optional<Literal> pad_value = std::nullopt,
                                bool allow_full_replication = true);

  // Helper function to broadcast data from a single device to all devices.
  PartitionedHlo Broadcast() const;

  // Tries to perform complicated reshard handling by splitting a big reshard
  // into multiple reshards that can be handled directly.
  std::optional<PartitionedHlo> TryComplexReshardHandling(
      const HloSharding& target);

  // Helper function to reshard the tensor using AllToAll (instead of the
  // default of Replicate followed by Slice).
  PartitionedHlo ReshardWithAllToAll(
      const HloSharding& target,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_dims) const;

  // Helper function to reshard the tensor using CollectivePermute.
  PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;

  // Helper function to reshard to partial replicate using AllGather.
  std::optional<PartitionedHlo> ReshardToPartialReplicateWithAllGather(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using DynamicSlice.
  std::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using AllToAll.
  std::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
      const HloSharding& target);

  // SPMD instruction.
  HloInstruction* hlo_;

  // The original shape of the data before SPMD transformation is applied.
  Shape base_shape_;

  PartitioningState state_;
};
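
// A hypothetical sketch of the central reshard idiom (names are
// illustrative). Note that the wrapped instruction must already carry an
// HloSharding; the constructor CHECKs this.
//
//   PartitionedHlo lhs(local_lhs, full_lhs_shape, partitioning_state);
//   // Move the data into the sharding the current operator needs; the
//   // result may come from the reshard cache.
//   PartitionedHlo replicated = lhs.Reshard(HloSharding::Replicate());
//   HloInstruction* full_value = replicated.hlo();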

struct DotConvDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64_t lhs;
    int64_t rhs;
    int64_t output;
    // For convolutions, the index of this dimension in
    // input_spatial_dimensions().
    int64_t spatial;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
  std::vector<DimsMapping> conv_spatial_dims;
};
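
// As an illustrative example (derived from the definitions above), a plain
// matmul out[i, j] = sum_k lhs[i, k] * rhs[k, j] could be described as:
//
//   DotConvDimsMapping mapping;
//   // k: contracting dim 1 of lhs, dim 0 of rhs; absent (-1) in the output.
//   mapping.contracting_dims.push_back(
//       {/*lhs=*/1, /*rhs=*/0, /*output=*/-1, /*spatial=*/-1});
//   // i: non-contracting dim 0 of lhs, output dim 0.
//   mapping.lhs_non_contracting_dims.push_back({0, -1, 0, -1});
//   // j: non-contracting dim 1 of rhs, output dim 1.
//   mapping.rhs_non_contracting_dims.push_back({-1, 1, 1, -1});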

class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
 public:
  SpmdPartitioningVisitor(
      HloComputation* computation, int64_t num_partitions, int64_t num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64_t* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options, SpmdPartitioner* partitioner);

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleBroadcast(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* hlo) override;
  Status HandleCustomCall(HloInstruction* hlo) override;
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleFft(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
  Status HandleOptimizationBarrier(HloInstruction* hlo) override;
  Status HandleOutfeed(HloInstruction* hlo) override;
  Status HandlePad(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* hlo) override;
  Status HandleReduce(HloInstruction* hlo) override;
  Status HandleReverse(HloInstruction* hlo) override;
  Status HandleWhile(HloInstruction* hlo) override;
  Status HandleConditional(HloInstruction* hlo) override;
  Status HandleReduceWindow(HloInstruction* hlo) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* hlo) override;
  Status HandleConvolution(HloInstruction* hlo) override;
  Status HandleConcatenate(HloInstruction* hlo) override;
  Status HandleScatter(HloInstruction* hlo) override;
  Status HandleSlice(HloInstruction* hlo) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleTranspose(HloInstruction* hlo) override;
  Status HandleReshape(HloInstruction* hlo) override;
  Status HandleIota(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Implementation of dot partitioning given a DotConvDimsMapping.
  Status HandleDotHelper(HloInstruction* hlo,
                         const DotConvDimsMapping& dims_mapping,
                         const std::function<StatusOr<HloInstruction*>(
                             HloInstruction*, HloInstruction*, SpmdBuilder*,
                             const Window& conv_window)>& create_sharded_dot);

  // Common handler for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);

  // Common handler for HLOs that run on a single device.
  Status HandleSingleDevice(const HloInstruction* hlo);

  // CustomCall handlers per call target.
  Status HandleCustomCallTopK(HloInstruction* hlo);
  // Convenient custom ops defined by the partitioner itself.
  Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);

  // Returns the PartitionedHlo that corresponds to the original hlo.
  PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 1);
    return partitioned_instructions_.find(hlo)->second;
  }

  // Sets the PartitionedHlo for the original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const PartitionedHlo& partitioned_hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 0);
    partitioned_instructions_.emplace(hlo, partitioned_hlo);
    changed_ = true;
  }

  // Convenient wrapper that creates a PartitionedHlo from the result of the
  // func and maps it to the given original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const std::function<HloInstruction*()>& func) {
    HloInstruction* new_hlo = func();
    new_hlo->set_sharding(hlo->sharding());
    SetPartitionedHlo(
        hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
    changed_ = true;
  }
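
  // A hypothetical sketch of the common handler idiom built on the wrapper
  // above (HandleFoo and its body are illustrative, not a real handler):
  //
  //   Status HandleFoo(HloInstruction* hlo) {
  //     PartitionedHlo operand = GetPartitionedHlo(hlo->operand(0));
  //     SetPartitionedHlo(hlo, [&] {
  //       return builder()->AddInstruction(hlo->CloneWithNewOperands(
  //           operand.hlo()->shape(), {operand.hlo()}));
  //     });
  //     return OkStatus();
  //   }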

  int64_t NewChannel() { return (*next_channel_id_)++; }

  PartitionedHlo::PartitioningState MakePartitioningState();

  SpmdBuilder* builder() { return &b_; }

  virtual StatusOr<bool> DoPartition(HloComputation* computation,
                                     const HloSharding& root_sharding,
                                     const SpmdPartitionerOptions& options);

  // Cost-model hooks that subclasses can override. The defaults treat
  // computation and communication as free.
  virtual double GetComputationTimeInMilliSec(HloInstruction* hlo) {
    return 0.0;
  }

  virtual double GetCommunicationTimeInMilliSec(
      int64_t bytes, absl::Span<const ReplicaGroup> device_groups) {
    return 0.0;
  }

  virtual int GetCommunicationMultiplier(
      absl::Span<const ReplicaGroup> device_groups) {
    return 1;
  }

  std::vector<ReplicaGroup> CreateReplicaGroups(
      std::vector<std::vector<int64_t>>& groups);

  // Information about a loop created for windowed dot-general. Used when
  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
  // finishes traversing the graph.
  struct WindowedDotGeneralLoop {
    HloInstruction* while_loop;
    int64_t windowed_operand;
    bool windowed_in_contracting_dims;
    bool windowed_in_batch_dims;
    bool operands_sharded_at_contracting_dims;
    int64_t num_partitions;
    std::vector<ReplicaGroup> loop_replica_groups;
  };

 protected:
  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // Performs code motion for windowed dot-general loops in
  // windowed_dot_general_loops_. Invoked after the visitor finishes traversing
  // the graph.
  Status DoCodeMotionForWindowedDotGeneralLoops(
      HloComputation* computation, const SpmdPartitionerOptions& options);

  bool changed_;
  HloModule* module_;
  int64_t num_partitions_;
  int64_t num_replicas_;

  SPMDCollectiveOpsCreator collective_ops_creator_;

  // Tracks the next channel id to use for cross-partition all-reduce.
  int64_t* next_channel_id_;
  SpmdBuilder b_;

  std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

  HloInstruction* partition_id_;

 private:
  PartitionedHlo::ReshardCache reshard_cache_;

  // Mapping from an instruction in the original computation to the new SPMD
  // partitioned instruction.
  ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

  HloInstruction* visiting_hlo_;
  SpmdLogger* logger_;
  const SpmdPartitionerOptions options_;
  SpmdPartitioner* partitioner_;
  std::vector<HloSharding> visiting_hlo_operand_shardings_;
  std::optional<HloSharding> visiting_hlo_sharding_;
  std::optional<int64_t> visiting_num_partitions_;
  std::optional<SPMDCollectiveOpsCreator> visiting_collective_ops_creator_;
  std::optional<HloInstruction*> visiting_partition_id_;
  std::vector<PartitionedHlo::PartitioningState> visiting_state_;
  std::vector<std::vector<int64_t>> device_groups_;
};

}  // namespace spmd
}  // namespace xla
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_