/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_

#include <memory>
#include <optional>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace spmd {

struct SpmdPartitionerOptions {
  // Always exchange halo on LHS for all convolutions. If false, backprop
  // filter convolutions exchange halo on RHS.
  bool conv_halo_exchange_always_on_lhs = true;

  // The number of instructions to report as having the highest memory
  // profiles.
  int64_t report_instruction_count = 5;

  // The minimum size in MiB of an einsum operand for the windowed
  // implementation in an HLO loop to be considered.
  int64_t threshold_for_windowed_einsum_mib = 256;

  // Whether to unroll the windowed einsum loop by a factor of two.
  bool unroll_windowed_einsum = false;

  // Whether to use bidirectional collective permutes in the windowed einsum
  // loop.
  bool bidirectional_windowed_einsum = false;

  // Whether the entry computation's signature may change after partitioning.
  bool allow_module_signature_change = false;

  // Whether to use a cached all-gather to avoid repeatedly replicating a
  // tiled tensor. If set to false, the result tends to be more
  // memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE
  // pass to CSE some all-gathers that are relatively close to each other.
  bool cache_all_gather = true;

  // When trading off windowed einsum speed against memory usage, prefer the
  // former if true.
  bool choose_faster_windowed_einsum_over_mem = false;

  // Whether to use bidirectional communication when decomposing independent
  // all-gathers.
  bool bidirectional_decomposed_all_gather = false;
};
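// Illustrative usage (not part of this header; the values below are
// hypothetical): a caller would typically tweak a few options before
// constructing the SpmdPartitioner pass declared further down in this file.
//
//   SpmdPartitionerOptions options;
//   options.threshold_for_windowed_einsum_mib = 512;
//   options.choose_faster_windowed_einsum_over_mem = true;
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               options);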
// Class to wrap the computation builder to capture information during SPMD
// transformation.
class SpmdBuilder : public HloComputation::Builder {
 public:
  SpmdBuilder(const std::string& name, HloInstruction* hlo)
      : HloComputation::Builder(name) {
    visiting_hlo_ = hlo;
  }

  HloInstruction* AddInstruction(
      std::unique_ptr<HloInstruction> instruction) override;

  const std::vector<HloInstruction*>& derived_instructions(
      HloInstruction* hlo) {
    return instructions_.at(hlo);
  }

  void set_visiting_hlo(HloInstruction* hlo) {
    visiting_hlo_ = hlo;
    // Ensure an (initially empty) entry exists for the new visiting hlo.
    instructions_[hlo];
  }

  HloInstruction* visiting_hlo() const { return visiting_hlo_; }

  // Wrapper around queries to broadcast_dims_.
  std::optional<const absl::flat_hash_set<int64_t>*>
  BroadcastDimsForCreatedHlo(const HloInstruction* hlo) {
    auto it = broadcast_dims_.find(hlo);
    if (it == broadcast_dims_.end()) {
      return std::nullopt;
    }
    return &it->second;
  }

 private:
  // Currently visiting instruction.
  HloInstruction* visiting_hlo_;

  // Map from the currently visiting (old) instruction to new instructions
  // created during SPMD partitioning.
  HloInstructionMap<std::vector<HloInstruction*>> instructions_;

  // Maps each created instruction to a set of dimensions that come from
  // broadcasts or elementwise ops over broadcasts. Elements along these
  // dimensions have the same value.
  absl::flat_hash_map<const HloInstruction*, absl::flat_hash_set<int64_t>>
      broadcast_dims_;
};

// A set of functions that create the cross-partition collective ops.
struct SPMDCollectiveOpsCreator {
  // Function used to create a partition ID HLO.
  std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;

  // Function used to create a cross-partition all-reduce HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
      const std::vector<std::vector<int64_t>>& partition_subgroups,
      int64_t channel_id)>
      create_cross_partition_all_reduce;

  // Function used to create a cross-partition collective-permute HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand,
      std::vector<std::pair<int64_t, int64_t>>& src_dst_pairs,
      int64_t next_channel_id)>
      create_cross_partition_collective_permute;

  // Function used to create a cross-partition all-to-all HLO.
  std::function<HloInstruction*(
      SpmdBuilder*, absl::Span<HloInstruction* const> operands,
      const std::vector<std::vector<int64_t>>& partition_subgroups,
      int64_t channel_id, std::optional<int64_t> split_dimension)>
      create_cross_partition_all_to_all;

  // Function used to create a cross-partition all-gather HLO. This is
  // optional: if it is nullptr, the partitioner will use all-reduce instead.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
      const std::vector<std::vector<int64_t>>& partition_subgroups,
      int64_t channel_id, int64_t all_gather_dimension)>
      create_cross_partition_all_gather;
};

// Creates a default SPMDCollectiveOpsCreator.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions,
                                                        int64_t num_replicas);
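// Illustrative sketch (hypothetical; not part of this header): callers can
// start from the default creator and override individual factory functions.
// The lambda body below assumes the usual HloInstruction/ShapeUtil factory
// APIs.
//
//   SPMDCollectiveOpsCreator creator =
//       GetDefaultCollectiveOpsCreator(/*num_partitions=*/8,
//                                      /*num_replicas=*/1);
//   creator.create_partition_id = [](SpmdBuilder* b) {
//     return b->AddInstruction(
//         HloInstruction::CreatePartitionId(ShapeUtil::MakeShape(U32, {})));
//   };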
// Logger to report memory usage during SPMD partitioning.
class SpmdLogger {
 public:
  SpmdLogger(int64_t report_instruction_count, bool disabled)
      : report_instruction_count_(report_instruction_count),
        disabled_(disabled) {}
  static std::string ReportBeforePartition(const HloModule& module,
                                           int64_t report_instruction_count);
  static std::string ReportAfterPartition(const HloModule& module,
                                          int64_t report_instruction_count);

  // Registers a log entry for the group of instructions created to transform
  // the given hlo.
  void RegisterLogEntry(HloInstruction* hlo,
                        const std::vector<HloInstruction*>& group);

  std::string MakeReport();

 private:
  template <typename F>
  static std::string ReportMemoryUsage(const HloModule& module,
                                       const F& filter,
                                       int64_t report_instruction_count);

  // A vector of logging messages (one for each original HLO instruction),
  // where the first element of each pair records the size of the HBM used.
  std::vector<std::pair<int64_t, std::string>> entries_;

  int64_t report_instruction_count_;

  // Note that we allow creating a *disabled* logger when logging is not
  // enabled, in which case it is supposed to avoid doing any potentially
  // expensive work. The logger is still created in this case and passed to
  // the users, to avoid changing current call sites.
  const bool disabled_;
};

class SpmdPartitioningVisitor;

class SpmdPartitioner : public HloModulePass {
 public:
  SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
                  SpmdPartitionerOptions options);
  SpmdPartitioner(int64_t num_partitions, int64_t num_replicas,
                  SpmdPartitionerOptions options,
                  SPMDCollectiveOpsCreator collective_ops_creator)
      : num_partitions_(num_partitions),
        num_replicas_(num_replicas),
        options_(std::move(options)),
        collective_ops_creator_(std::move(collective_ops_creator)) {}
  absl::string_view name() const override { return "spmd-partitioning"; }
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads)
      override;

  // Transforms the given computation with SPMD instructions, replacing it
  // with a new computation.
  StatusOr<bool> PartitionComputation(HloComputation* computation,
                                      const HloSharding& root_sharding,
                                      int64_t* next_channel_id,
                                      SpmdLogger* logger);

  // Creates all-gather(s) based on HloSharding. Can be overridden to
  // customize. The default uses a single all-gather even if there are
  // multiple sharded dimensions, and adds potential reshapes and transposes
  // to achieve that. If it returns nullptr, the partitioner will fall back to
  // all-reduce. `selected_dims` specifies the dimensions along which the
  // all-gather happens in the tiled sharding, which allows potentially
  // creating a subgroup all-gather.
  virtual HloInstruction* AllGatherShards(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator);

  // Creates all-reduce(s) across devices along selected_dims in sharding.
  // Can be overridden to customize.
  virtual HloInstruction* AllReduceAlongShardingDims(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction);

  const SpmdPartitionerOptions& options() { return options_; }

 protected:
  virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
      HloComputation* computation, int64_t num_partitions,
      int64_t num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64_t* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options);

  HloInstruction* AllGatherShardsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator, bool per_dim_ag);
  HloInstruction* AllReduceAlongShardingDimsInternal(
      SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
      int64_t* next_channel_id, absl::Span<const int64_t> selected_dims,
      const SPMDCollectiveOpsCreator& collectives_creator,
      HloComputation* reduction, bool per_dim_ar);

  // Verifies that the shardings of instructions in the module are valid, and
  // also fills in missing sharding information.
  virtual Status PreprocessSharding(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Returns true if the given side-effecting instruction is allowed to have
  // replicated sharding.
  virtual bool CanSideEffectingHaveReplicatedSharding(
      const HloInstruction* hlo) {
    return hlo->opcode() == HloOpcode::kInfeed ||
           hlo->opcode() == HloOpcode::kOutfeed;
  }

  // Preprocesses the graph to simplify some communication patterns. E.g.,
  // merges pad->slice into a single pad with potentially negative padding to
  // avoid multiple halo exchanges.
  Status PreprocessHlos(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  const int64_t num_partitions_;
  const int64_t num_replicas_;

  SpmdPartitionerOptions options_;
  SPMDCollectiveOpsCreator collective_ops_creator_;
  std::vector<std::vector<int64_t>> device_groups_;
};
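// Illustrative sketch (hypothetical; not part of this header): the pass runs
// like any other HloModulePass, after shardings have been assigned to the
// module. The single-argument Run overload comes from HloPassInterface.
//
//   SpmdPartitioner partitioner(/*num_partitions=*/8, /*num_replicas=*/1,
//                               SpmdPartitionerOptions());
//   StatusOr<bool> changed = partitioner.Run(module);
//   // On success, *changed says whether the module was rewritten.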
// Describes the partition state of data represented by an HLO created during
// the SPMD partitioning pass.
//
// Data on some devices may include a padding region if the base (full) shape
// cannot be evenly partitioned.
class PartitionedHlo {
 public:
  // Return value for ReshardAsWindowedInput, which describes the resharded
  // HLO, the window for the user on the shard, and if necessary, the dynamic
  // slice offsets to be applied to the output of the op being sharded.
  struct WindowedInputShardReturnValue {
    HloInstruction* sharded_input;
    Window shard_window;
    std::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
  };
  // A cache for resharding each partitioned HLO.
  struct ReshardCache {
    struct PerHloCache {
      absl::flat_hash_map<HloSharding, PartitionedHlo> reshard_cache;
      std::vector<
          std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
          window_reshard_cache;
    };
    // Use absl::node_hash_map for pointer stability.
    absl::node_hash_map<HloInstruction*, PerHloCache> per_hlo_cache;
    // Caches for nested partitioning of grouped sharding. Each string key
    // represents a unique way of grouping devices.
    absl::flat_hash_map<std::string, std::unique_ptr<ReshardCache>>
        groupd_caches;
  };
  struct PartitioningState {
    SpmdBuilder* b;
    HloModule* module;
    int64_t num_replicas;
    HloInstruction* partition_id;
    SPMDCollectiveOpsCreator collective_ops_creator;
    int64_t* next_channel_id;
    ReshardCache* reshard_cache;
    SpmdPartitioner* partitioner;
  };
  PartitionedHlo(HloInstruction* hlo, Shape base_shape,
                 PartitioningState state)
      : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) {
    CHECK(hlo->has_sharding())
        << "PartitionedHlo is missing sharding:" << hlo->ToString();
    // If a tuple-shaped instruction does not have a tuple sharding, reassign
    // it to use a tuple sharding. The Reshard() implementation assumes this.
    if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
      hlo_->set_sharding(
          hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
    }
  }

  // Reshards the current SPMD instruction to a new sharding, with an optional
  // pad value used during resharding. May modify only the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target,
                         std::optional<Literal> pad_value = std::nullopt);

  // Pads the garbage area of the output with the provided value. Normally,
  // unevenly partitioned dimensions are padded on the right, but this
  // function allows specifying left-padded dimensions, which can be used
  // during the handling of kReverse, etc.
  PartitionedHlo PadWithValue(
      HloInstruction* pad_value,
      absl::Span<const int64_t> left_padded_dims = {},
      absl::Span<const int64_t> skipped_dims = {}) const;

  // Same as PadWithValue but does not create a new PartitionedHlo.
  HloInstruction* PadWithValueHlo(
      HloInstruction* pad_value,
      absl::Span<const int64_t> left_padded_dims = {},
      absl::Span<const int64_t> skipped_dims = {}) const;

  PartitionedHlo PadWithZero(
      absl::Span<const int64_t> left_padded_dims = {},
      absl::Span<const int64_t> skipped_dims = {}) const;

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }

  // Returns the sharding of the SPMD instruction.
  const HloSharding& sharding() const { return hlo_->sharding(); }

  // Original full shape of the data.
  const Shape& base_shape() const { return base_shape_; }

  int64_t NewChannel() const { return (*state_.next_channel_id)++; }

  // Reshards the HLO to a usable partitioned input for a windowed user. May
  // modify only the reshard cache.
  std::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

  const PartitioningState& state() const { return state_; }

  // Helper function to replicate the data on all devices. May modify only the
  // reshard cache.
  PartitionedHlo Replicate();

  // Helper function to replicate the data for partitions along the given
  // dims.
  HloInstruction* ReplicatePartial(absl::Span<const int64_t> dims);
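  // Illustrative sketch (hypothetical; not part of this header): a visitor
  // handler typically reshards an operand to the sharding it needs, e.g. full
  // replication, before emitting per-shard compute. GetPartitionedHlo below
  // belongs to SpmdPartitioningVisitor, declared later in this file.
  //
  //   PartitionedHlo operand = GetPartitionedHlo(hlo->operand(0));
  //   PartitionedHlo replicated = operand.Reshard(HloSharding::Replicate());
  //   HloInstruction* per_shard_input = replicated.hlo();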
  // Sets the state of the partitioned HLO.
  void set_state(PartitioningState state) { state_ = std::move(state); }

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it may modify it indirectly by calling Replicate().
  PartitionedHlo ReshardNoCache(
      const HloSharding& target,
      std::optional<Literal> pad_value = std::nullopt,
      bool allow_full_replication = true);

  // Helper function to broadcast data from a single device to all devices.
  PartitionedHlo Broadcast() const;

  // Tries to perform complicated reshard handling by splitting a big reshard
  // into multiple reshards that can be handled directly.
  std::optional<PartitionedHlo> TryComplexReshardHandling(
      const HloSharding& target);

  // Helper function to reshard the tensor using AllToAll (instead of the
  // default of Replicate followed by Slice).
  PartitionedHlo ReshardWithAllToAll(
      const HloSharding& target,
      absl::Span<const std::pair<int64_t, int64_t>> source_target_dims) const;

  // Helper function to reshard the tensor using CollectivePermute.
  PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;

  // Helper function to reshard to partial replicate using AllGather.
  std::optional<PartitionedHlo> ReshardToPartialReplicateWithAllGather(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using DynamicSlice.
  std::optional<PartitionedHlo> ReshardFromPartialReplicateWithDynamicSlice(
      const HloSharding& target);

  // Helper function to reshard from partial replicate using AllToAll.
  std::optional<PartitionedHlo> ReshardPartialReplicateWithAllToAll(
      const HloSharding& target);

  // SPMD instruction.
  HloInstruction* hlo_;

  // The original shape of the data before SPMD transformation is applied.
  Shape base_shape_;

  PartitioningState state_;
};

struct DotConvDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1.
  struct DimsMapping {
    int64_t lhs;
    int64_t rhs;
    int64_t output;
    // For convolution spatial dimensions: the index into
    // input_spatial_dimensions().
    int64_t spatial;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
  std::vector<DimsMapping> conv_spatial_dims;
};
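// Illustrative example (hypothetical; not part of this header): for a plain
// dot with lhs shape [batch, m, k], rhs shape [batch, k, n], and output shape
// [batch, m, n], the mapping would be populated as:
//
//   DotConvDimsMapping mapping;
//   mapping.batch_dims.push_back(
//       {/*lhs=*/0, /*rhs=*/0, /*output=*/0, /*spatial=*/-1});
//   mapping.contracting_dims.push_back(
//       {/*lhs=*/2, /*rhs=*/1, /*output=*/-1, /*spatial=*/-1});
//   mapping.lhs_non_contracting_dims.push_back(
//       {/*lhs=*/1, /*rhs=*/-1, /*output=*/1, /*spatial=*/-1});
//   mapping.rhs_non_contracting_dims.push_back(
//       {/*lhs=*/-1, /*rhs=*/2, /*output=*/2, /*spatial=*/-1});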
class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
 public:
  SpmdPartitioningVisitor(
      HloComputation* computation, int64_t num_partitions,
      int64_t num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64_t* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options, SpmdPartitioner* partitioner);

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleBroadcast(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* hlo) override;
  Status HandleCustomCall(HloInstruction* hlo) override;
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleFft(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
  Status HandleOptimizationBarrier(HloInstruction* hlo) override;
  Status HandleOutfeed(HloInstruction* hlo) override;
  Status HandlePad(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* hlo) override;
  Status HandleReduce(HloInstruction* hlo) override;
  Status HandleReverse(HloInstruction* hlo) override;
  Status HandleWhile(HloInstruction* hlo) override;
  Status HandleConditional(HloInstruction* hlo) override;
  Status HandleReduceWindow(HloInstruction* hlo) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* hlo) override;
  Status HandleConvolution(HloInstruction* hlo) override;
  Status HandleConcatenate(HloInstruction* hlo) override;
  Status HandleScatter(HloInstruction* hlo) override;
  Status HandleSlice(HloInstruction* hlo) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleTranspose(HloInstruction* hlo) override;
  Status HandleReshape(HloInstruction* hlo) override;
  Status HandleIota(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Implementation of dot partitioning given a DotConvDimsMapping.
  Status HandleDotHelper(HloInstruction* hlo,
                         const DotConvDimsMapping& dims_mapping,
                         const std::function<StatusOr<HloInstruction*>(
                             HloInstruction*, HloInstruction*, SpmdBuilder*,
                             const Window& conv_window)>& create_sharded_dot);

  // Common handler for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);

  // Common handler for HLOs that run on a single device.
  Status HandleSingleDevice(const HloInstruction* hlo);

  // CustomCall handlers per call target.
  Status HandleCustomCallTopK(HloInstruction* hlo);
  // Convenient custom ops defined by the partitioner itself.
  Status HandleCustomCallSPMDInternal_RotateRight(HloInstruction* hlo);

  // Returns the PartitionedHlo that corresponds to the original hlo.
  PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 1);
    return partitioned_instructions_.find(hlo)->second;
  }

  // Sets the PartitionedHlo for the original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const PartitionedHlo& partitioned_hlo) {
    CHECK_EQ(partitioned_instructions_.count(hlo), 0);
    partitioned_instructions_.emplace(hlo, partitioned_hlo);
    changed_ = true;
  }

  // Convenient wrapper that creates a PartitionedHlo from the result of the
  // func and maps it to the given original hlo.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const std::function<HloInstruction*()>& func) {
    HloInstruction* new_hlo = func();
    new_hlo->set_sharding(hlo->sharding());
    SetPartitionedHlo(
        hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
    changed_ = true;
  }

  int64_t NewChannel() { return (*next_channel_id_)++; }

  PartitionedHlo::PartitioningState MakePartitioningState();

  SpmdBuilder* builder() { return &b_; }

  virtual StatusOr<bool> DoPartition(HloComputation* computation,
                                     const HloSharding& root_sharding,
                                     const SpmdPartitionerOptions& options);

  virtual double GetComputationTimeInMilliSec(HloInstruction* hlo) {
    return 0.0;
  }

  virtual double GetCommunicationTimeInMilliSec(
      int64_t bytes, absl::Span<const ReplicaGroup> device_groups) {
    return 0.0;
  }

  virtual int GetCommunicationMultiplier(
      absl::Span<const ReplicaGroup> device_groups) {
    return 1;
  }

  std::vector<ReplicaGroup> CreateReplicaGroups(
      std::vector<std::vector<int64_t>>& groups);

  // Information about a loop created for windowed dot-general. Used when
  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
  // finishes traversing the graph.
  struct WindowedDotGeneralLoop {
    HloInstruction* while_loop;
    int64_t windowed_operand;
    bool windowed_in_contracting_dims;
    bool windowed_in_batch_dims;
    bool operands_sharded_at_contracting_dims;
    int64_t num_partitions;
    std::vector<ReplicaGroup> loop_replica_groups;
  };

 protected:
  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // Performs code motion for windowed dot-general loops in
  // windowed_dot_general_loops_. Invoked after the visitor finishes
  // traversing the graph.
  Status DoCodeMotionForWindowedDotGeneralLoops(
      HloComputation* computation, const SpmdPartitionerOptions& options);

  bool changed_;
  HloModule* module_;
  int64_t num_partitions_;
  int64_t num_replicas_;

  SPMDCollectiveOpsCreator collective_ops_creator_;

  // Tracks the next channel id to use for cross-partition all-reduce.
  int64_t* next_channel_id_;
  SpmdBuilder b_;

  std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

  HloInstruction* partition_id_;

 private:
  PartitionedHlo::ReshardCache reshard_cache_;

  // Mapping from the instruction in the original computation to the new SPMD
  // partitioned instruction.
  ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

  HloInstruction* visiting_hlo_;
  SpmdLogger* logger_;
  const SpmdPartitionerOptions options_;
  SpmdPartitioner* partitioner_;
  std::vector<HloSharding> visiting_hlo_operand_shardings_;
  std::optional<HloSharding> visiting_hlo_sharding_;
  std::optional<int64_t> visiting_num_partitions_;
  std::optional<SPMDCollectiveOpsCreator> visiting_collective_ops_creator_;
  std::optional<HloInstruction*> visiting_partition_id_;
  std::vector<PartitionedHlo::PartitioningState> visiting_state_;
  std::vector<std::vector<int64_t>> device_groups_;
};

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_