/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_

#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace spmd {

template <typename T>
using IsCompOrCompBuilder =
    typename std::enable_if_t<std::is_same<HloComputation, T>::value ||
                              std::is_same<HloComputation::Builder, T>::value ||
                              std::is_same<SpmdBuilder, T>::value>;

struct GatherParallelDimSharding {
  HloSharding indices_sharding;
  HloSharding operand_sharding;
};

// Returns true if the given sharding contains any replicated sharding.
bool HasReplicatedSharding(const HloSharding& sharding);

// Base for creating constants.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateConstantBase(const Shape& shape, Literal value, T* b,
                                   Literal (*literal_creator)(Literal,
                                                              PrimitiveType)) {
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
      elements.push_back(
          CreateConstantBase(ShapeUtil::GetTupleElementShape(shape, i),
                             value.Clone(), b, literal_creator));
    }
    return b->AddInstruction(HloInstruction::CreateTuple(elements));
  }

  if (shape.IsToken()) {
    return b->AddInstruction(HloInstruction::CreateToken());
  }
  auto c = b->AddInstruction(HloInstruction::CreateConstant(
      literal_creator(std::move(value), shape.element_type())));
  if (shape.rank() == 0) {
    return c;
  }
  return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {}));
}

// Creates constant value instructions of the given shape. The literal must be
// a scalar and is broadcast to the given shape.
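// For example (an illustrative call, not taken from this file): passing
// ShapeUtil::MakeShape(F32, {4, 4}) as the shape and
// LiteralUtil::CreateR0<float>(2.0f) as the value produces a constant 2.0
// broadcast to a 4x4 F32 array.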
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateConstant(const Shape& shape, Literal value, T* b) {
  auto identity = [](Literal value, PrimitiveType primitive_type) {
    CHECK(ShapeUtil::IsScalarWithElementType(value.shape(), primitive_type));
    return value;
  };
  return CreateConstantBase(shape, std::move(value), b, identity);
}

// Creates zero value instructions of the given shape.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateZero(const Shape& shape, T* b) {
  auto zero = [](Literal /*unused*/, PrimitiveType primitive_type) {
    return LiteralUtil::Zero(primitive_type);
  };
  return CreateConstantBase(shape, /*unused*/ Literal(), b, zero);
}

// Creates one value instructions of the given shape.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateOne(const Shape& shape, T* b) {
  auto one = [](Literal /*unused*/, PrimitiveType primitive_type) {
    return LiteralUtil::One(primitive_type);
  };
  return CreateConstantBase(shape, /*unused*/ Literal(), b, one);
}

template <typename NativeT, typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value, T* b) {
  auto literal = LiteralUtil::CreateR0(value)
                     .ConvertToShape(ShapeUtil::MakeShape(type, {}))
                     .ValueOrDie();
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

template <typename T, typename = IsCompOrCompBuilder<T>>
inline HloInstruction* CreateFirstWithType(PrimitiveType type, T* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, -float_pad_value, b);
  }
  auto literal = LiteralUtil::MinValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

template <typename T, typename = IsCompOrCompBuilder<T>>
inline HloInstruction* CreateLastWithType(PrimitiveType type, T* b) {
  if (type == F32) {
    auto float_pad_value = std::numeric_limits<float>::quiet_NaN();
    return CreateR0WithType(type, float_pad_value, b);
  }
  auto literal = LiteralUtil::MaxValue(type);
  return b->AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
}

// Creates a binary add computation of the given type and adds it to the
// module.
HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module);

// Returns true if the shape can be evenly partitioned for the given sharding.
// All tile sharded dimensions should be evenly divisible and there should be
// no single-device sharding. Replicated sharding is considered evenly
// partitioned.
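// For example, an f32[8,16] shape with sharding {devices=[2,1]0,1} partitions
// evenly (each shard is f32[4,16]), while an f32[7,16] shape with the same
// sharding does not.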
bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding);

// Returns the shard shape of the given shape when it is partitioned for the
// target sharding.
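// For example, an f32[7,11] shape with sharding {devices=[2,1]0,1} yields the
// shard shape f32[4,11]: unevenly sharded dimensions are rounded up.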
Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);

// Similar to ShapeUtil::ByteSizeOf(), but does not check that the shape has a
// dense layout, since this can be called before layout assignment.
int64_t ShapeSizeInBytes(const Shape& shape);

// Returns the shard shape for a partition without padding due to uneven
// sharding.
Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape,
                                          const HloSharding& sharding,
                                          int64_t partition_id);

// Generates the HLO instructions that represent the dimension offsets on any
// device. The size of the returned vector is the rank of the given shape.
// If `dims` is non-empty, the generated offsets will only be non-zero for those
// dimensions.
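// For example (illustrative), for an f32[16,16] shape with sharding
// {devices=[2,1]0,1}, the returned instructions compute roughly
// {partition_id * 8, 0} as S32 scalars.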
std::vector<HloInstruction*> MakePartitionOffsets(
    const Shape& shape, const HloSharding& sharding,
    HloInstruction* partition_id, SpmdBuilder* b,
    absl::Span<const int64_t> dims = {});

// Returns the offsets of the partition in the tile assignment.
std::vector<HloInstruction*> MakeTiledPartitionOrdinals(
    const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);

// Pads hlo to the desired shape using high padding. Either a builder or a
// computation needs to be supplied, but not both.
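// For example, padding an f32[5] operand to an f32[8] padded_shape appends
// three elements of the pad value (zero unless `value` is provided) at the
// high end of the dimension.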
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* PadToShape(HloInstruction* hlo, const Shape& padded_shape, T* b,
                           std::optional<Literal> value = std::nullopt) {
  if (ShapeUtil::Compatible(hlo->shape(), padded_shape)) {
    return hlo;
  }
  PaddingConfig padding_config;
  for (int64_t i = 0; i < padded_shape.rank(); ++i) {
    auto padding_config_dim = padding_config.add_dimensions();
    padding_config_dim->set_edge_padding_low(0);
    padding_config_dim->set_interior_padding(0);
    padding_config_dim->set_edge_padding_high(padded_shape.dimensions(i) -
                                              hlo->shape().dimensions(i));
  }
  const Shape padding_shape =
      ShapeUtil::MakeScalarShape(hlo->shape().element_type());
  HloInstruction* padding =
      value.has_value() ? CreateConstant(padding_shape, std::move(*value), b)
                        : CreateZero(padding_shape, b);
  return b->AddInstruction(
      HloInstruction::CreatePad(padded_shape, hlo, padding, padding_config));
}

// Returns the padded shape when combining all partitions.
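// For example, an f32[7] base shape tiled across 2 partitions is padded to
// f32[8] so that every partition gets the same shard size.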
Shape GetPaddedShapeForUnevenPartitioning(const Shape& base_shape,
                                          const HloSharding& sharding);

// Pads the HLO (with base shape) for uneven tiled partition to make it evenly
// partitionable.
template <typename T, typename = IsCompOrCompBuilder<T>>
HloInstruction* PadBaseShapeBeforeUnevenTiledSharding(
    HloInstruction* hlo, const HloSharding& sharding, T* b,
    std::optional<Literal> value = std::nullopt) {
  auto padded_base_shape =
      GetPaddedShapeForUnevenPartitioning(hlo->shape(), sharding);
  if (ShapeUtil::Compatible(padded_base_shape, hlo->shape())) {
    return hlo;
  }
  return PadToShape(hlo, padded_base_shape, b, std::move(value));
}

// Returns the index of the unique tile dimension. Returns std::nullopt if the
// given sharding is not tiled or tiled along multiple dimensions.
std::optional<int64_t> UniqueTiledDim(const HloSharding& sharding);

// Utilities for symbolic offset calculation and halo exchange.
class OffsetCalculation;

// Represents a calculation over integers:
//   (shard_ordinal * multiplier + offset) / divisor
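// For example, with multiplier=3, offset=1 and divisor=2, shard ordinal 5
// yields (5 * 3 + 1) / 2 = 8 (integer division).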
class MultiplyAddDivideOffsetCalculation {
 public:
  MultiplyAddDivideOffsetCalculation()
      : multiplier_(0), offset_(0), divisor_(1) {}
  MultiplyAddDivideOffsetCalculation(int64_t multiplier, int64_t offset,
                                     int64_t divisor);

  OffsetCalculation operator-(
      const MultiplyAddDivideOffsetCalculation& other) const;

  bool operator==(const MultiplyAddDivideOffsetCalculation& other) const {
    return multiplier_ == other.multiplier_ && offset_ == other.offset_ &&
           divisor_ == other.divisor_;
  }

  bool IsConstant() const { return multiplier_ == 0; }
  void Simplify();
  int64_t Calculate(int64_t shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64_t MaxInRange(int64_t start_ordinal, int64_t limit_ordinal) const;

 private:
  int64_t multiplier_;
  int64_t offset_;
  int64_t divisor_;
};

// Represents a calculation over integers based on results of other calculations
// defined by an opcode. If the opcode is kCopy, it simply wraps a
// MultiplyAddDivideOffsetCalculation.
class OffsetCalculation {
 public:
  OffsetCalculation() : opcode_(HloOpcode::kCopy), copy_from_() {}
  explicit OffsetCalculation(
      const MultiplyAddDivideOffsetCalculation& copy_from)
      : opcode_(HloOpcode::kCopy), copy_from_(copy_from) {}
  OffsetCalculation(const OffsetCalculation& copy_from) { *this = copy_from; }
  OffsetCalculation(HloOpcode opcode,
                    const MultiplyAddDivideOffsetCalculation& lhs,
                    const MultiplyAddDivideOffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(std::make_unique<OffsetCalculation>(lhs)),
        rhs_(std::make_unique<OffsetCalculation>(rhs)) {}
  OffsetCalculation(HloOpcode opcode, const OffsetCalculation& lhs,
                    const OffsetCalculation& rhs)
      : opcode_(opcode),
        lhs_(std::make_unique<OffsetCalculation>(lhs)),
        rhs_(std::make_unique<OffsetCalculation>(rhs)) {}

  OffsetCalculation& operator=(const OffsetCalculation& other);

  // Returns whether the calculation returns the same value for all shards. This
  // is conservative and could return false even if it is actually constant.
  bool IsConstant() const;

  OffsetCalculation operator-(const OffsetCalculation& other) const;
  bool operator==(const OffsetCalculation& other) const;
  int64_t Calculate(int64_t shard_ordinal) const;
  HloInstruction* Calculate(HloInstruction* shard_ordinal,
                            SpmdBuilder* b) const;

  // Returns the maximum result for shard ordinals in the range
  // [start_ordinal, limit_ordinal).
  int64_t MaxInRange(int64_t start_ordinal, int64_t limit_ordinal) const;

 private:
  HloOpcode opcode_;
  std::unique_ptr<OffsetCalculation> lhs_;
  std::unique_ptr<OffsetCalculation> rhs_;
  MultiplyAddDivideOffsetCalculation copy_from_;
};

// Performs halo exchange on the given dimension based on the provided
// left/right halo size functions. Returns nullopt if the halo is beyond the
// direct neighbor of the shard.
std::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo, const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function, int64_t dim,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b);

// Exchanges halos on all dimensions of the HLO. Returns nullopt if any one of
// the dimensions fails to exchange halo (halo is beyond the neighbor shard).
std::optional<HloInstruction*> ExchangeHalo(
    HloInstruction* hlo,
    std::vector<OffsetCalculation> left_halo_size_functions,
    std::vector<OffsetCalculation> right_halo_size_functions,
    const HloSharding& target,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b);

// Exchanges halos and performs pad/dynamic-slice on the concatenated data such
// that the result starts with the first needed element on each shard. It also
// masks off invalid data due to padding.
// Arguments:
//   hlo: the HLO op before halo exchange
//   explicit_left_padding_on_full_shape: the amount of left padding to be
//     added explicitly by this function on the base shape before partitioning.
//     Without base dilation, this is usually set to the window's padding_low
//     so that the sharded op does not need to add padding_low on the window;
//     however, with base dilation, this could only be set to a custom size.
//   padded_full_shape_size: the size of the padded full shape on the given
//     dimension, which includes explicit_left_padding_on_full_shape and the
//     required right padding to make the shape evenly shardable.
//   shard_size_with_halo: the shard size on the dimension after halo exchange.
//     If different shards have different sizes, use the maximum size.
//   offset_on_padded_shape: the offset HLO (S32) that represents the start of
//     each shard on the padded full shape.
//   pad_value: the padding value used on the full shape.
std::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    HloInstruction* hlo, const Shape& base_shape,
    const OffsetCalculation& left_halo_size_function,
    const OffsetCalculation& right_halo_size_function,
    int64_t explicit_left_padding_on_full_shape, int64_t padded_full_shape_size,
    int64_t shard_size_with_halo, int64_t dim, const HloSharding& target,
    HloInstruction* offset_on_padded_shape, HloInstruction* pad_value,
    HloInstruction* partition_ordinal,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);

// Uses halo exchange to change from right-padding to left-padding for uneven
// tiled sharding on the given dimensions. Tiled sharding always pads unevenly
// partitioned data on the right, but we need to swap it to the left for
// kReverse or kConvolution with window reversal.
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64_t> dims);

// Checks if the computation is a GT comparison that is safe for NaNs.
bool IsNanSafeGt(HloComputation* computation);

// Returns k in TopK when the input value is partitioned in the sort dimension.
std::optional<int64_t> GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo);

// Slices the first k elements along the slice dimension.
HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder,
                            int64_t slice_dim, int64_t k);

// Returns the number of shards along the given dimension.
int64_t ShardCountAtDim(const HloSharding& sharding, int64_t dim);

// Returns the list of source-target pairs of dimensions to swap during
// resharding via all-to-all. Reshard can be done by swapping each pair at a
// time.
std::optional<std::vector<std::pair<int64_t, int64_t>>>
GetReshardAllToAllSourceTargetDims(const HloSharding& source,
                                   const HloSharding& target);

// Returns whether the resharding can be done via collective-permute.
bool CanReshardWithCollectivePermute(const HloSharding& source,
                                     const HloSharding& target);

// Returns a new GroupedSharding that has the same group definition as
// `reference`.
hlo_sharding_util::GroupedSharding AlignGroupsWith(
    hlo_sharding_util::GroupedSharding grouped_sharding,
    const hlo_sharding_util::GroupedSharding& reference,
    bool ignore_group_order = false);

// Aligns device groups between the two shardings. Equivalent to calling
// GroupShardingOnDims on the two shardings, then AlignGroupsWith, and then
// UngroupSharding.
HloSharding AlignShardingOnDims(const HloSharding& sharding,
                                absl::Span<const int64_t> sharding_dims,
                                const HloSharding& reference,
                                absl::Span<const int64_t> reference_dims);

// Calls AlignGroupsWith only if it doesn't change the sharding when ungrouped.
std::optional<hlo_sharding_util::GroupedSharding> AlignGroupsWithIfCompatible(
    hlo_sharding_util::GroupedSharding grouped_sharding,
    const hlo_sharding_util::GroupedSharding& reference);

// Returns the per-group base shape, i.e., before applying the in-group
// sharding.
Shape GetPerGroupBaseShape(
    const hlo_sharding_util::GroupedSharding& grouped_sharding,
    const Shape& original_base_shape);

// Creates the nested partitioner state for in-group partitioning.
PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
    const PartitionedHlo::PartitioningState& state,
    const std::vector<std::vector<int64_t>>& device_groups, SpmdBuilder* b);

// Partially shards a replicated HLO into groups along the group dimensions,
// and within each group data is still replicated.
HloInstruction* PerGroupSliceFromReplicated(
    HloInstruction* replicated, HloInstruction* partition_id,
    const std::vector<std::vector<int64_t>>& device_groups,
    absl::Span<const int64_t> group_dims,
    absl::Span<const int64_t> group_dim_sizes, SpmdBuilder* b);

// Pads the shape from the partial replicate shape for `dst_sharding`.
// If dst_sharding needs more padding and per_shard_size increases in
// dst_sharding, halo exchange on the right side is needed.
std::optional<HloInstruction*> PadFromPartialReplicateShape(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64_t>& expand_tile_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Gets the compatible sharding from a partial replicate sharding to a desired
// target tiled sharding.
// Compatible means the replicated sharding can transform to the target tile
// dimensions by dynamic slice.
// For example, if partial_sharding is
// {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}
// and target_sharding is {devices=[2,2]0,1,2,3}, the returned compatible
// sharding will be sharding={devices=[2,2]0,2,1,3}.
// If partial_sharding is not partially replicated or can't reshard to
// target_tile_dims by dynamic slice, returns std::nullopt.
// If target_sharding is already compatible, returns it.
std::optional<HloSharding> PartialReplicateReshardCompatibleSharding(
    const HloSharding& partial_sharding, const HloSharding& target_sharding);

// Performs left halo exchange if all-reduce directly from tile sharding to
// partial replicate sharding will remove useful data from the source.
std::optional<HloInstruction*> TileToPartialReplicateHaloExchange(
    HloInstruction* hlo, const Shape& base_shape,
    const HloSharding& src_sharding, const HloSharding& dst_sharding,
    const std::vector<int64_t>& replicate_dims,
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64_t* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b);

// Finds a list of dimensions that can be grouped on such that it will have the
// specified device groups. Group order and dimension order are ignored.
std::optional<std::vector<int64_t>> FindMatchingPartitionedDimsForGrouping(
    const HloSharding& sharding,
    const std::vector<std::vector<int64_t>>& device_groups);

// Creates a sharding that matches the provided source sharding on the
// specified dimensions. 'target_dims' and 'source_dims' represent the
// dimensions for which the sharding should match in their respective shape.
// If some devices from the source sharding are left over (because not all the
// devices are allocated to 'source_dims' dimensions) then partial replication
// is employed to make sure the number of devices for the two shardings match.
HloSharding CreateMatchingShardingOnDims(const Shape& target_shape,
                                         const HloSharding& source_sharding,
                                         absl::Span<const int64_t> target_dims,
                                         absl::Span<const int64_t> source_dims);

// Returns whether the sharding across operand and indices of a gather is
// across parallel dimensions and matches what the SPMD partitioner supports.
std::optional<GatherParallelDimSharding>
GatherOperandsShardedAcrossParallelDims(
    const HloInstruction& operand, const HloInstruction& indices,
    const hlo_sharding_util::GatherParallelDims& parallel_dims);

// Pattern rewrite preprocessing utilities.

// Returns rotate_amount if the concat(lhs, rhs) is equivalent to rotating the
// elements along the concat dimension to the right by rotate_amount, where the
// input of rotation is the shared operand of lhs and rhs. Returns -1 if the
// pattern is not found.
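// For example (illustrative), if the shared operand x has 8 elements along the
// concat dimension, lhs = slice(x, [6, 8)) and rhs = slice(x, [0, 6)), then
// concat(lhs, rhs) rotates x right by 2, so the returned rotate_amount is 2.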
int64_t FindRotateRightPattern(const HloInstruction* concat,
                               const HloInstruction* lhs,
                               const HloInstruction* rhs);

// Describes the pad with wrap pattern.
struct PadWithWrapPattern {
  int64_t lhs_slice_start;
  int64_t rhs_slice_start;
  std::vector<const HloInstruction*> lhs_modifiers;
  std::vector<const HloInstruction*> rhs_modifiers;
};

// Returns the `PadWithWrapPattern` if the concat(lhs,mid,rhs) is equivalent to
// padding mid with wrapping (i.e., padding mid with slices of itself). Returns
// std::nullopt if the pattern is not found.
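// For example (illustrative), if mid has 8 elements along the concat
// dimension, lhs = slice(mid, [6, 8)) and rhs = slice(mid, [0, 3)), then
// concat(lhs, mid, rhs) pads mid with 2 wrapped elements on the left and 3 on
// the right; lhs_slice_start would be 6 and rhs_slice_start would be 0.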
std::optional<PadWithWrapPattern> FindPadWithWrapPattern(
    const HloInstruction* concat, const HloInstruction* lhs,
    const HloInstruction* mid, const HloInstruction* rhs);

}  // namespace spmd
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_UTIL_H_