/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_

#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace spmd {

template <typename T>
using IsCompOrCompBuilder =
    typename std::enable_if_t<std::is_same<HloComputation, T>::value ||
                              std::is_same<HloComputation::Builder, T>::value ||
                              std::is_same<SpmdBuilder, T>::value>;

struct GatherParallelDimSharding {
  HloSharding indices_sharding;
  HloSharding operand_sharding;
};

// Returns true if the given sharding contains any replicated sharding.
bool HasReplicatedSharding(const HloSharding& sharding);

// Base for creating constants.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateConstantBase(const Shape& shape, Literal value, T* b,
                                   Literal (*literal_creator)(Literal,
                                                              PrimitiveType)) {
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      elements.push_back(
          CreateConstantBase(ShapeUtil::GetTupleElementShape(shape, i),
                             value.Clone(), b, literal_creator));
    }
    return b->AddInstruction(HloInstruction::CreateTuple(elements));
  }

  if (shape.IsToken()) {
    return b->AddInstruction(HloInstruction::CreateToken());
  }
  auto c = b->AddInstruction(HloInstruction::CreateConstant(
      literal_creator(std::move(value), shape.element_type())));
  if (shape.rank() == 0) {
    return c;
  }
  return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {}));
}

// Creates constant value instructions of the given shape. The literal must be
// scalar and is broadcast to the given shape.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateConstant(const Shape& shape, Literal value, T* b) {
  auto identity = [](Literal value, PrimitiveType primitive_type) {
    CHECK(ShapeUtil::IsScalarWithElementType(value.shape(), primitive_type));
    return value;
  };
  return CreateConstantBase(shape, std::move(value), b, identity);
}

// Creates zero value instructions of the given shape.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateZero(const Shape& shape, T* b) {
  auto zero = [](Literal /*unused*/, PrimitiveType primitive_type) {
    return LiteralUtil::Zero(primitive_type);
  };
  return CreateConstantBase(shape, /*unused*/ Literal(), b, zero);
}

// Creates one value instructions of the given shape.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateOne(const Shape& shape, T* b) {
  auto one = [](Literal /*unused*/, PrimitiveType primitive_type) {
    return LiteralUtil::One(primitive_type);
  };
  return CreateConstantBase(shape, /*unused*/ Literal(), b, one);
}
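
// Usage sketch (illustrative, not part of this API): with an SpmdBuilder* b,
//   HloInstruction* zero = CreateZero(ShapeUtil::MakeShape(F32, {4, 4}), b);
// emits constant(f32[] 0) broadcast to f32[4,4]; CreateOne and CreateConstant
// are analogous with a different scalar value.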

template <typename NativeT, typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value, T* b) {
  auto literal = LiteralUtil::CreateR0(value)
                     .ConvertToShape(ShapeUtil::MakeShape(type, {}))
                     .ValueOrDie();
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

template <typename T, typename = IsCompOrCompBuilder<T>>
inline HloInstruction* CreateFirstWithType(PrimitiveType type, T* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, -float_pad_value, b);
  }
  auto literal = LiteralUtil::MinValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

template <typename T, typename = IsCompOrCompBuilder<T>>
inline HloInstruction* CreateLastWithType(PrimitiveType type, T* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, float_pad_value, b);
  }
  auto literal = LiteralUtil::MaxValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}
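
// Usage sketch (illustrative): CreateFirstWithType(S32, b) emits
// constant(s32[] -2147483648), and CreateLastWithType(F32, b) emits
// constant(f32[] nan), i.e. pad values that sort first/last under a NaN-safe
// comparator.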

// Creates a binary add computation of the given type and adds it to the
// module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);

// Returns true if the shape can be evenly partitioned for the given sharding.
// All tile-sharded dimensions must be evenly divisible, and there must be no
// single-device sharding. Replicated sharding is considered evenly
// partitioned.
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);

// Returns the shard shape of the given shape when it is partitioned for the
// target sharding.
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
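
// Example (illustrative): for base shape f32[9,8] and sharding
// {devices=[2,1]0,1}, the shard shape is f32[5,8]; unevenly divided
// dimensions round up.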

// Similar to ShapeUtil::ByteSizeOf(), but does not check that the shape has a
// dense layout, since this can run before layout assignment.
int64_t ShapeSizeInBytes(const Shape& shape);

// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64_t partition_id);

// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
// If `dims` is non-empty, the generated offsets will only be non-zero for
// those dimensions.
std::vector<HloInstruction*> MakePartitionOffsets(
    const Shape& shape, const HloSharding& sharding,
    HloInstruction* partition_id, SpmdBuilder* b,
    absl::Span<const int64_t> dims = {});
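
// Example (illustrative): for base shape f32[8,16] and sharding
// {devices=[2,2]0,1,2,3}, this returns two s32 scalars per device that
// evaluate to {ordinal0 * 4, ordinal1 * 8}, where ordinalN is the device's
// position along tile dimension N.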

// Returns the offsets of the partition in the tile assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);

// Pads hlo to the desired shape using high padding. Either a builder or a
// computation needs to be supplied, but not both.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, T* b,
                           std::optional<Literal> value = std::nullopt) {
  if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
    return hlo;
  }
  PaddingConfig padding_config;
  for (int64_t i = 0; i < padded_shape.rank(); ++i) {
    auto padding_config_dim = padding_config.add_dimensions();
    padding_config_dim->set_edge_padding_low(0);
    padding_config_dim->set_interior_padding(0);
    padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
                                              hlo->shape().dimensions(i));
  }
  const Shape padding_shape =
      ShapeUtil::MakeScalarShape(hlo->shape().element_type());
  HloInstruction* padding =
      value.has_value() ? CreateConstant(padding_shape, std::move(*value), b)
                        : CreateZero(padding_shape, b);
  return b->AddInstruction(
      HloInstruction::CreatePad(padded_shape, hlo, padding, padding_config));
}
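
// Example (illustrative): PadToShape on an f32[5,8] hlo with padded_shape
// f32[8,8] emits pad(hlo, constant(f32[] 0)) with edge_padding_high={3,0},
// leaving existing elements in place and zero-filling the new ones.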

// Returns the padded shape when combining all partitions.
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding);

// Pads the HLO (with base shape) for uneven tiled partition to make it evenly
// partitionable.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
    HloInstruction* hlo, const HloSharding& sharding, T* b,
    std::optional<Literal> value = std::nullopt) {
  auto padded_base_shape =
      GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
  if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
    return hlo;
  }
  return PadToShape(hlo, padded_base_shape, b, std::move(value));
}

// Returns the index of the unique tile dimension. Returns std::nullopt if the
// given sharding is not tiled or tiled along multiple dimensions.
std::optional<int64_t> UniqueTiledDim(const HloSharding& sharding);

// Utilities for symbolic offset calculation and halo exchange.
class OffsetCalculation;

// Represents a calculation over integers:
//   (shard_ordinal * multiplier + offset) / divisor
class MultiplyAddDivideOffsetCalculation {
 public:
  MultiplyAddDivideOffsetCalculation()
      : multiplier_(0), offset_(0), divisor_(1) {}
  MultiplyAddDivideOffsetCalculation(int64_t multiplier, int64_t offset,
                                     int64_t divisor);

  OffsetCalculation operator-(
      const MultiplyAddDivideOffsetCalculation& other) const;

  bool operator==(const MultiplyAddDivideOffsetCalculation& other) const {
    return multiplier_ == other.multiplier_ && offset_ == other.offset_ &&
           divisor_ == other.divisor_;
  }

  bool IsConstant() const { return multiplier_ == 0; }
  void Simplify();
  int64_t Calculate(int64_t shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64_t MaxInRange(int64_t start_ordinal, int64_t limit_ordinal) const;

 private:
  int64_t multiplier_;
  int64_t offset_;
  int64_t divisor_;
};
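
// Worked example (illustrative): MultiplyAddDivideOffsetCalculation(1, 2, 3)
// computes (shard_ordinal * 1 + 2) / 3 with integer division, so
// Calculate(0) == 0, Calculate(1) == 1, Calculate(4) == 2, and
// MaxInRange(0, 5) == 2.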

// Represents a calculation over integers based on results of other
// calculations, combined by an opcode. If the opcode is kCopy, it simply
// wraps a MultiplyAddDivideOffsetCalculation.
class OffsetCalculation {
 public:
  OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {}
  explicit OffsetCalculation(
      const MultiplyAddDivideOffsetCalculation& copy_from)
      : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {}
  OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; }
  OffsetCalculation(HloOpcode opcode,
                    const MultiplyAddDivideOffsetCalculation& lhs,
                    const MultiplyAddDivideOffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(std::make_unique<OffsetCalculation>(lhs)),
        rhs_(std::make_unique<OffsetCalculation>(rhs)) {}
  OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs,
                    const OffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(std::make_unique<OffsetCalculation>(lhs)),
        rhs_(std::make_unique<OffsetCalculation>(rhs)) {}

  OffsetCalculation& operator=(const OffsetCalculation& other);

  // Returns whether the calculation returns the same value for all shards.
  // This is conservative and could return false even if it is actually
  // constant.
  bool IsConstant() const;

  OffsetCalculation operator-(const OffsetCalculation& other) const;
  bool operator==(const OffsetCalculation& other) const;
  int64_t Calculate(int64_t shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64_t MaxInRange(int64_t start_ordinal, int64_t limit_ordinal) const;

 private:
  HloOpcode opcode_;
  std::unique_ptr<OffsetCalculation> lhs_;
  std::unique_ptr<OffsetCalculation> rhs_;
  MultiplyAddDivideOffsetCalculation copy_from_;
};
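
// Example (illustrative): subtracting two affine calculations stays symbolic:
//   OffsetCalculation a(MultiplyAddDivideOffsetCalculation(2, 0, 1));  // 2*i
//   OffsetCalculation b(MultiplyAddDivideOffsetCalculation(1, 1, 1));  // i+1
//   OffsetCalculation diff = a - b;  // diff.Calculate(i) == 2*i - (i + 1)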

// Performs halo exchange on the given dimension based on the provided
// left/right halo size functions. Returns nullopt if the halo is beyond the
// direct neighbor of the shard.
std::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64_t dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b);
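
// Usage sketch (illustrative): for a size-3 window centered on each element,
// each shard needs one element from each neighbor, so constant halo-size
// functions suffice:
//   OffsetCalculation one(MultiplyAddDivideOffsetCalculation(0, 1, 1));
//   auto exchanged = ExchangeHalo(hlo, one, one, dim, sharding,
//                                 collective_ops_creator, next_channel_id, b);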

// Exchanges halos on all dimensions of the HLO. Returns nullopt if any one of
// the dimensions fails to exchange halo (halo is beyond the neighbor shard).
std::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b);

// Exchanges halos and performs pad/dynamic-slice on the concatenated data such
// that the result starts with the first needed element on each shard. It also
// masks off invalid data due to padding.
// Arguments:
//  hlo: the HLO op before halo exchange
//  explicit_left_padding_on_full_shape: the amount of left padding to be added
//   explicitly by this function on the base shape before partitioning. Without
//   base dilation, this is usually set to the window's padding_low so that the
//   sharded op does not need to add padding_low on the window; however, with
//   base dilation, this can only be set to a custom size.
//  padded_full_shape_size: the size of the padded full shape on the given
//   dimension, which includes explicit_left_padding_on_full_shape and required
//   right padding to make the shape evenly shardable.
//  shard_size_with_halo: the shard size on the dimension after halo exchange.
//   If different shards have different sizes, use the maximum size.
//  offset_on_padded_shape: the offset HLO (S32) that represents the start of
//   each shard on the padded full shape.
//  pad_value: the padding value used on the full shape.
std::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64_t explicit_left_padding_on_full_shape, int64_t padded_full_shape_size,
    int64_t shard_size_with_halo, int64_t dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);

// Uses halo exchange to change from right-padding to left-padding for uneven
// tiled sharding on the given dimensions. Tiled sharding always pads uneven
// partitioned data on the right, but we need to swap it to the left for
// kReverse or kConvolution with window reversal.
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64_t> dims);

// Checks if the computation is a GT (greater-than) comparison that is safe for
// NaNs.
bool IsNanSafeGt(HloComputation* computation);

// Returns k in TopK when the input value is partitioned in the sort dimension.
std::optional<int64_t> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);

// Slices the first k elements at the slice dimension.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                            int64_t slice_dim, int64_t k);

// Returns the number of shards along the given dimension of the sharding.
int64_t ShardCountAtDim(const HloSharding& sharding, int64_t dim);
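
// Example (illustrative): for sharding {devices=[2,4]0,1,2,3,4,5,6,7},
// ShardCountAtDim returns 2 for dim 0 and 4 for dim 1.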

// Returns the list of source-target pairs of dimensions to swap during
// resharding via all-to-all. Resharding can be done by swapping one pair at a
// time.
std::optional<std::vector<std::pair<int64_t, int64_t>>>
GetReshardAllToAllSourceTargetDims(const HloSharding& source,
                                   const HloSharding& target);
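
// Example (illustrative): resharding from {devices=[4,1]0,1,2,3} to
// {devices=[1,4]0,1,2,3} can be done with a single all-to-all, so this would
// return {{0, 1}}.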

// Returns whether the resharding can be done via collective-permute.
bool CanReshardWithCollectivePermute(const HloSharding& source,
                                     const HloSharding& target);

// Returns a new GroupedSharding that has the same group definition as
// `reference`.
hlo_sharding_util::GroupedSharding AlignGroupsWith(
    hlo_sharding_util::GroupedSharding grouped_sharding,
    const hlo_sharding_util::GroupedSharding& reference,
    bool ignore_group_order = false);

// Aligns device groups between the two shardings. Equivalent to calling
// GroupShardingOnDims on the two shardings, then AlignGroupsWith, and then
// UngroupSharding.
HloSharding AlignShardingOnDims(const HloSharding& sharding,
                                absl::Span<const int64_t> sharding_dims,
                                const HloSharding& reference,
                                absl::Span<const int64_t> reference_dims);

// Applies AlignGroupsWith only if doing so doesn't change the sharding when
// ungrouped; otherwise returns std::nullopt.
std::optional<hlo_sharding_util::GroupedSharding> AlignGroupsWithIfCompatible(
    hlo_sharding_util::GroupedSharding grouped_sharding,
    const hlo_sharding_util::GroupedSharding& reference);

// Returns the per-group base shape, i.e., before applying the in-group
// sharding.
Shape GetPerGroupBaseShape(
    const hlo_sharding_util::GroupedSharding& grouped_sharding,
    const Shape& original_base_shape);

// Creates the nested partitioner state for in-group partitioning.
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
    const PartitionedHlo::PartitioningState& state,
    const std::vector<std::vector<int64_t>>& device_groups, SpmdBuilder* b);

// Partially shards a replicated HLO into groups along the group dimensions;
// within each group the data remains replicated.
HloInstruction* PerGroupSliceFromReplicated(
    HloInstruction* replicated, HloInstruction* partition_id,
    const std::vector<std::vector<int64_t>>& device_groups,
    absl::Span<const int64_t> group_dims,
    absl::Span<const int64_t> group_dim_sizes, SpmdBuilder* b);

// Pads the shape from a partial-replicate shape to match `dst_sharding`.
// If dst_sharding needs more padding and the per-shard size increases under
// dst_sharding, a halo exchange on the right side is needed.
std::optional<HloInstruction*> PadFromPartialReplicateShape(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64_t>& expand_tile_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Gets the compatible sharding from a partial replicate sharding to a desired
// target tiled sharding.
// Compatible means the replicated sharding can be transformed to the target
// tile dimensions by dynamic-slice.
// For example, if partial_sharding is
// {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
// and the target sharding is {devices=[2,2]0,1,2,3}, the returned compatible
// sharding will be sharding={devices=[2,2]0,2,1,3}.
// If partial_sharding is not partially replicated or can't be resharded to
// target_tile_dims by dynamic-slice, returns std::nullopt.
// If target_sharding is already compatible, returns it.
std::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
    const HloSharding& partial_sharding, const HloSharding& target_sharding);

// Performs left halo exchange if an all-reduce directly from tile sharding to
// partial replicate sharding would remove useful data from the source.
std::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64_t>& replicate_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Finds a list of dimensions that can be grouped on such that it will have the
// specified device groups. Group order and dimension order are ignored.
std::optional<std::vector<int64_t>> FindMatchingPartitionedDimsForGrouping(
    const HloSharding& sharding,
    const std::vector<std::vector<int64_t>>& device_groups);

// Creates a sharding that matches the provided source sharding on the
// specified dimensions. 'target_dims' and 'source_dims' represent the
// dimensions for which the sharding should match in their respective shape.
// If some devices from the source sharding are left over (because not all the
// devices are allocated to 'source_dims' dimensions), then partial replication
// is employed to make sure the number of devices for the two shardings match.
HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
                                         const HloSharding& source_sharding,
                                         absl::Span<const int64_t> target_dims,
                                         absl::Span<const int64_t> source_dims);

// Returns whether the sharding across operand and indices of a gather is
// across parallel dimensions and matches what the SPMD partitioner supports.
std::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
    const HloInstruction& operand, const HloInstruction& indices,
    const hlo_sharding_util::GatherParallelDims& parallel_dims);

// Pattern rewrite preprocessing utilities.

// Returns rotate_amount if concat(lhs, rhs) is equivalent to rotating the
// elements along the concat dimension to the right by rotate_amount, where the
// input of the rotation is the shared operand of lhs and rhs. Returns -1 if
// the pattern is not found.
int64_t FindRotateRightPattern(const HloInstruction* concat,
                               const HloInstruction* lhs,
                               const HloInstruction* rhs);
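
// Example (illustrative): if x has size n along the concat dimension, then
// concat(slice(x, [n - r, n)), slice(x, [0, n - r))) rotates x right by r, so
// FindRotateRightPattern would return r given the concat and the two slices
// of the shared operand x.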

// Describes the pad-with-wrap pattern.
struct PadWithWrapPattern {
  int64_t lhs_slice_start;
  int64_t rhs_slice_start;
  std::vector<const HloInstruction*> lhs_modifiers;
  std::vector<const HloInstruction*> rhs_modifiers;
};

// Returns the `PadWithWrapPattern` if concat(lhs, mid, rhs) is equivalent to
// padding mid with wrapping (i.e., padding mid with slices of itself). Returns
// std::nullopt if the pattern is not found.
std::optional<PadWithWrapPattern> FindPadWithWrapPattern(
    const HloInstruction* concat, const HloInstruction* lhs,
    const HloInstruction* mid, const HloInstruction* rhs);
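
// Example (illustrative): for mid = [a, b, c, d], the wrap-padded result
// [c, d, a, b, c, d, a, b] equals concat(slice(mid, [2, 4)), mid,
// slice(mid, [0, 2))), the kind of pattern this matcher recognizes; the
// slices may also pass through elementwise modifiers recorded in
// lhs_modifiers/rhs_modifiers.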

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_