1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <list>
21 #include <memory>
22 #include <optional>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/status/status.h"
31 #include "absl/strings/str_split.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/protobuf_util.h"
34 #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
35 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
36 #include "tensorflow/compiler/xla/service/hlo_computation.h"
37 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
39 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
40 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
41 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
42 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
43 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/sharding_op_util.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/compiler/xla/xla_data.pb.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/statusor.h"
52
53 namespace xla {
54 namespace {
55
56 // Returns true iff the specified hlo or sharding has a spatially partitioned
57 // sharding (tiled or replicated) that can be propagated by sharding
58 // propagation.
59 bool IsSpatiallyPartitioned(const HloSharding& sharding) {
60 if (sharding.IsTuple()) {
61 return absl::c_any_of(sharding.tuple_elements(), IsSpatiallyPartitioned);
62 } else {
63 return !sharding.IsTileMaximal() || sharding.IsReplicated();
64 }
65 }
66 bool IsSpatiallyPartitioned(const HloInstruction* hlo) {
67 return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding());
68 }
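
// Illustrative note (assumed examples, not from the original file), using
// HloSharding's text form:
//   {replicated}            -> true  (replicated counts as partitioned here)
//   {devices=[2,2]0,1,2,3}  -> true  (tiled across a device mesh)
//   {maximal device=0}      -> false (tile-maximal, assigned to one device)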
69
70 // Updates the sharding of the specified instruction with the specified sharding
71 // if it is better than the current one and returns true if a new sharding has
72 // been applied. If may_combine_partial_sharding is true, this may combine the
73 // new and existing sharding if they are both tiled with partial
74 // replication.
75 bool MaybeImproveInstructionSharding(HloSharding sharding,
76 HloInstruction* instruction,
77 bool may_combine_partial_sharding,
78 bool allow_aggressive_resharding = false) {
79 // We don't want to propagate tile maximal shardings.
80 if (!IsSpatiallyPartitioned(sharding)) {
81 return false;
82 }
83 // Any sharding is better than no sharding.
84 if (!instruction->has_sharding()) {
85 instruction->set_sharding(std::move(sharding));
86 return true;
87 }
88 int64_t sharding_tiles = sharding.NumTiles();
89 if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
90 may_combine_partial_sharding)) {
91 // Override existing tiled sharding only when the new sharding is compatible
92 // with the existing one. This avoids unexpected resharding when `sharding`
93 // just has more tiles than the existing sharding but they are not mergeable.
94 if (!allow_aggressive_resharding && instruction->shape().IsArray() &&
95 !instruction->sharding().IsTileMaximal() &&
96 sharding.NumTiles() == sharding_tiles) {
97 std::vector<int64_t> diff_dims;
98 for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
99 if (instruction->sharding().tile_assignment().dim(i) ==
100 sharding.tile_assignment().dim(i)) {
101 continue;
102 }
103 if (instruction->sharding().tile_assignment().dim(i) != 1) {
104 VLOG(10) << "Not merging because of dim i = " << i
105 << " sharded differently";
106 VLOG(10) << "Instr sharding: " << instruction->sharding().ToString();
107 VLOG(10) << "New sharding " << sharding.ToString();
108 return false;
109 }
110 diff_dims.push_back(i);
111 }
112 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
113 sharding, diff_dims) != instruction->sharding()) {
114 VLOG(10) << "Not merging because of different device distribution";
115 VLOG(10) << "Instr sharding: " << instruction->sharding().ToString();
116 VLOG(10) << "New sharding " << sharding.ToString();
117 return false;
118 }
119 }
120 instruction->set_sharding(std::move(sharding));
121 return true;
122 }
123 return false;
124 }
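
// Illustrative sketch of the merge path above (assumed example, not from the
// original file): with may_combine_partial_sharding=true, two partial
// shardings over the same devices can be combined, e.g.
//   existing: {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}  (splits dim 0)
//   new:      {devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}  (splits dim 1)
//   merged:   {devices=[2,2]0,1,2,3}                            (splits both)
// In that case MaybeImproveInstructionSharding(new_sharding, instr,
// /*may_combine_partial_sharding=*/true) returns true and installs the merged
// sharding on `instr`.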
125
126 // We consider a convolution kernel to be small iff it is smaller along all
127 // spatial dimensions than the output of the convolution. The rationale is that
128 // we can either shard the kernel or the output, and we want to shard the larger
129 // one for better efficiency.
130 bool IsConvolutionKernelSmall(const HloInstruction* instruction) {
131 CHECK_EQ(instruction->opcode(), HloOpcode::kConvolution);
132 const HloInstruction* rhs = instruction->operand(1);
133 const auto& dnums = instruction->convolution_dimension_numbers();
134 int64_t kernel_dim_prod = 1;
135 int64_t output_dim_prod = 1;
136 for (int64_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
137 int64_t kernel_dim =
138 rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i));
139 kernel_dim_prod *= kernel_dim;
140 int64_t output_dim =
141 instruction->shape().dimensions(dnums.output_spatial_dimensions(i));
142 output_dim_prod *= output_dim;
143 if (kernel_dim >= output_dim &&
144 (i < 2 || kernel_dim > 3 || kernel_dim_prod >= output_dim_prod)) {
145 return false;
146 }
147 }
148 return true;
149 }
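
// Worked example (assumed, not from the original file): a forward convolution
// with a 3x3 kernel and a 224x224 output has every kernel spatial dim smaller
// than the matching output dim, so the kernel is "small" and the output can
// inherit the input's sharding. In a backward-filter-style convolution where
// the kernel operand is 224x224 and the output is 3x3, the first spatial dim
// already fails the check, so the kernel is "large" and only a replicated
// output is inferred further below.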
150
151 bool IsPassthroughCustomOps(const HloInstruction* hlo) {
152 if (hlo->IsCustomCall("Sharding")) {
153 return true;
154 }
155 if (hlo->IsCustomCall("X64Combine")) {
156 return true;
157 }
158 if (hlo->operand_count() != 1 || !hlo->shape().IsArray() ||
159 !hlo->operand(0)->shape().IsArray() ||
160 hlo->operand(0)->shape().rank() != hlo->shape().rank()) {
161 return false;
162 }
163 return hlo->IsCustomCall("ResizeNearest") ||
164 hlo->IsCustomCall("ResizeBilinear") ||
165 hlo->IsCustomCall("ResizeNearestGrad") ||
166 hlo->IsCustomCall("ResizeBilinearGrad") ||
167 hlo->IsCustomCall("Cholesky");
168 }
169
170 // Return the operand which is the most suitable for determining the sharding
171 // for the specified instruction or nullptr if there isn't any suitable operand.
172 const HloInstruction* PickRepresentativeOperand(
173 const HloInstruction* instruction) {
174 switch (instruction->opcode()) {
175 case HloOpcode::kMap:
176 case HloOpcode::kPad:
177 case HloOpcode::kPower:
178 case HloOpcode::kOptimizationBarrier:
179 case HloOpcode::kReverse:
180 case HloOpcode::kSlice:
181 case HloOpcode::kShiftLeft:
182 case HloOpcode::kShiftRightArithmetic:
183 case HloOpcode::kShiftRightLogical:
184 // For these opcodes the output sharding has to be determined by the
185 // sharding of the first operand, but we can only determine it if the first
186 // operand already has a sharding.
187 if (instruction->operand(0)->has_sharding()) {
188 return instruction->operand(0);
189 }
190 return nullptr;
191 case HloOpcode::kAbs:
192 case HloOpcode::kAdd:
193 case HloOpcode::kAnd:
194 case HloOpcode::kAtan2:
195 case HloOpcode::kBitcastConvert:
196 case HloOpcode::kCeil:
197 case HloOpcode::kClamp:
198 case HloOpcode::kClz:
199 case HloOpcode::kCompare:
200 case HloOpcode::kComplex:
201 case HloOpcode::kConcatenate:
202 case HloOpcode::kConvert:
203 case HloOpcode::kCopy:
204 case HloOpcode::kCos:
205 case HloOpcode::kAllGather:
206 case HloOpcode::kAllReduce:
207 case HloOpcode::kReduceScatter:
208 case HloOpcode::kAllToAll:
209 case HloOpcode::kCollectivePermute:
210 case HloOpcode::kDivide:
211 case HloOpcode::kExp:
212 case HloOpcode::kExpm1:
213 case HloOpcode::kFloor:
214 case HloOpcode::kImag:
215 case HloOpcode::kIsFinite:
216 case HloOpcode::kLog:
217 case HloOpcode::kLog1p:
218 case HloOpcode::kLogistic:
219 case HloOpcode::kMaximum:
220 case HloOpcode::kMinimum:
221 case HloOpcode::kMultiply:
222 case HloOpcode::kNegate:
223 case HloOpcode::kNot:
224 case HloOpcode::kOr:
225 case HloOpcode::kPopulationCount:
226 case HloOpcode::kReal:
227 case HloOpcode::kReducePrecision:
228 case HloOpcode::kRemainder:
229 case HloOpcode::kRoundNearestAfz:
230 case HloOpcode::kRoundNearestEven:
231 case HloOpcode::kRsqrt:
232 case HloOpcode::kSelect:
233 case HloOpcode::kSign:
234 case HloOpcode::kSin:
235 case HloOpcode::kSort:
236 case HloOpcode::kSqrt:
237 case HloOpcode::kCbrt:
238 case HloOpcode::kSubtract:
239 case HloOpcode::kTanh:
240 case HloOpcode::kWhile:
241 case HloOpcode::kXor: {
242 // For these opcodes the output sharding can be determined by any operand
243 // so we find the operand with the most specific sharding.
244 const HloInstruction* best_operand = nullptr;
245 for (const HloInstruction* operand : instruction->operands()) {
246 if (operand->has_sharding() &&
247 (best_operand == nullptr ||
248 hlo_sharding_util::IsShardingMoreSpecific(
249 operand->sharding(), best_operand->sharding()))) {
250 best_operand = operand;
251 }
252 }
253 return best_operand;
254 }
255 case HloOpcode::kCustomCall: {
256 if (IsPassthroughCustomOps(instruction)) {
257 return instruction->operand(0);
258 }
259 return nullptr;
260 }
261 // There is no suitable operand for the rest of the opcodes.
262 case HloOpcode::kAddDependency:
263 case HloOpcode::kAfterAll:
264 case HloOpcode::kAsyncStart:
265 case HloOpcode::kAsyncUpdate:
266 case HloOpcode::kAsyncDone:
267 case HloOpcode::kAllGatherStart:
268 case HloOpcode::kAllGatherDone:
269 case HloOpcode::kAllReduceStart:
270 case HloOpcode::kAllReduceDone:
271 case HloOpcode::kBatchNormGrad:
272 case HloOpcode::kBatchNormInference:
273 case HloOpcode::kBatchNormTraining:
274 case HloOpcode::kBitcast:
275 case HloOpcode::kBroadcast:
276 case HloOpcode::kCall:
277 case HloOpcode::kCholesky:
278 case HloOpcode::kCollectivePermuteDone:
279 case HloOpcode::kCollectivePermuteStart:
280 case HloOpcode::kConditional:
281 case HloOpcode::kConstant:
282 case HloOpcode::kConvolution:
283 case HloOpcode::kCopyDone:
284 case HloOpcode::kCopyStart:
285 case HloOpcode::kDomain:
286 case HloOpcode::kDot:
287 case HloOpcode::kDynamicSlice:
288 case HloOpcode::kDynamicUpdateSlice:
289 case HloOpcode::kDynamicReshape:
290 case HloOpcode::kFft:
291 case HloOpcode::kFusion:
292 case HloOpcode::kGather:
293 case HloOpcode::kGetTupleElement:
294 case HloOpcode::kInfeed:
295 case HloOpcode::kIota:
296 case HloOpcode::kOutfeed:
297 case HloOpcode::kParameter:
298 case HloOpcode::kPartitionId:
299 case HloOpcode::kRecv:
300 case HloOpcode::kRecvDone:
301 case HloOpcode::kReduce:
302 case HloOpcode::kReduceWindow:
303 case HloOpcode::kReplicaId:
304 case HloOpcode::kReshape:
305 case HloOpcode::kRng:
306 case HloOpcode::kRngGetAndUpdateState:
307 case HloOpcode::kRngBitGenerator:
308 case HloOpcode::kScatter:
309 case HloOpcode::kSelectAndScatter:
310 case HloOpcode::kSend:
311 case HloOpcode::kSendDone:
312 case HloOpcode::kTranspose:
313 case HloOpcode::kTriangularSolve:
314 case HloOpcode::kTuple:
315 case HloOpcode::kGetDimensionSize:
316 case HloOpcode::kSetDimensionSize:
317 return nullptr;
318 }
319 }
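
// Illustrative example (assumed, not from the original file): for an
// elementwise op
//   %add = f32[8,8] add(%x, %y)
// where %x carries {devices=[2,2]0,1,2,3} and %y is {replicated}, the tiled
// sharding of %x is more specific, so %x is picked as the representative
// operand.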
320
321 bool SupportSpatialPartitioning(
322 const HloInstruction* instruction,
323 const ShardingPropagation::ComputationMap& computation_map, bool is_spmd,
324 bool allow_spmd_sharding_propagation_to_output,
325 const CustomCallShardingHelper* sharding_helper) {
326 const bool is_entry_root = instruction->parent()
327 ->parent()
328 ->entry_computation()
329 ->root_instruction() == instruction;
330 if (instruction->parent()->root_instruction() == instruction &&
331 computation_map.find(instruction->parent()) == computation_map.end() &&
332 !(is_entry_root && allow_spmd_sharding_propagation_to_output)) {
333 // We don't support sharding the root instruction of a computation yet,
334 // unless the computation is a while body.
335 return false;
336 }
337
338 if (instruction->IsElementwise() &&
339 (instruction->opcode() != HloOpcode::kRng || is_spmd)) {
340 return true;
341 }
342 switch (instruction->opcode()) {
343 case HloOpcode::kBroadcast:
344 case HloOpcode::kConcatenate:
345 case HloOpcode::kConditional:
346 case HloOpcode::kConstant:
347 case HloOpcode::kConvolution:
348 case HloOpcode::kOptimizationBarrier:
349 case HloOpcode::kDot:
350 case HloOpcode::kDynamicSlice:
351 case HloOpcode::kDynamicUpdateSlice:
352 case HloOpcode::kGather:
353 case HloOpcode::kGetTupleElement:
354 case HloOpcode::kInfeed:
355 case HloOpcode::kIota:
356 case HloOpcode::kPad:
357 case HloOpcode::kReduceWindow:
358 case HloOpcode::kReshape:
359 case HloOpcode::kScatter:
360 case HloOpcode::kSelectAndScatter:
361 case HloOpcode::kSlice:
362 case HloOpcode::kSort:
363 case HloOpcode::kTranspose:
364 case HloOpcode::kTuple:
365 case HloOpcode::kWhile:
366 case HloOpcode::kReduce:
367 case HloOpcode::kRngBitGenerator:
368 return true;
369 case HloOpcode::kAllReduce:
370 case HloOpcode::kReduceScatter:
371 // Only if channel_id is not specified.
372 return instruction->channel_id() == std::nullopt;
373 case HloOpcode::kParameter:
374 return computation_map.find(instruction->parent()) !=
375 computation_map.end();
376 case HloOpcode::kReverse:
377 return is_spmd;
378 case HloOpcode::kCustomCall:
379 return is_spmd && (IsPassthroughCustomOps(instruction) ||
380 sharding_helper->IsCustomCallShardable(instruction));
381 default:
382 return false;
383 }
384 }
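
// Some consequences of the rules above (assumed examples, not from the
// original file):
//   - An all-reduce or reduce-scatter with a channel_id is never spatially
//     partitioned here.
//   - A parameter is partitionable only when its computation is tracked in
//     computation_map (e.g. a while body).
//   - kRng is elementwise but is partitionable only in SPMD mode.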
385
386 bool InferDotShardingFromOperands(
387 HloInstruction* instruction,
388 const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
389 bool may_combine_partial_sharding) {
390 auto from_operand = [&](int64_t operand_index) {
391 auto operand = instruction->operand(operand_index);
392 const HloSharding& operand_sharding = operand->sharding();
393 if (operand_sharding.IsTileMaximal()) {
394 return operand_sharding;
395 }
396 std::vector<int64_t> contracting_dims;
397 contracting_dims.reserve(dnums.contracting_dims.size());
398 for (const auto& dim : dnums.contracting_dims) {
399 contracting_dims.push_back(operand_index == 0 ? dim.lhs : dim.rhs);
400 }
401 // It's possible that some size-1 spatial dims of convolutions are parsed as
402 // non-contracting dims. We might have tiled dimensions on them.
403 for (const auto& dim : operand_index == 0
404 ? dnums.rhs_non_contracting_dims
405 : dnums.lhs_non_contracting_dims) {
406 int64_t d = operand_index == 0 ? dim.lhs : dim.rhs;
407 if (d > 0) {
408 contracting_dims.push_back(d);
409 }
410 }
411 auto replicate_contracting_dims =
412 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
413 operand_sharding, contracting_dims);
414 std::vector<int64_t> out_dims_to_op_perm(instruction->shape().rank(), -1);
415 std::vector<int64_t> op_dims_to_output_perm(operand->shape().rank(), -1);
416 for (const auto& dim : dnums.batch_dims) {
417 out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
418 op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
419 dim.output;
420 }
421 for (const auto& dim : operand_index == 0
422 ? dnums.lhs_non_contracting_dims
423 : dnums.rhs_non_contracting_dims) {
424 out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
425 op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
426 dim.output;
427 }
428 return *hlo_sharding_util::TransposeShardingWithCollapsedDims(
429 replicate_contracting_dims, op_dims_to_output_perm,
430 out_dims_to_op_perm);
431 };
432 bool changed = false;
433 int64_t larger_operand =
434 ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) >=
435 ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())
436 ? 0
437 : 1;
438 if (IsSpatiallyPartitioned(instruction->operand(larger_operand))) {
439 changed |= MaybeImproveInstructionSharding(from_operand(larger_operand),
440 instruction,
441 may_combine_partial_sharding);
442 }
443 if (IsSpatiallyPartitioned(instruction->operand(1 - larger_operand))) {
444 changed |= MaybeImproveInstructionSharding(from_operand(1 - larger_operand),
445 instruction,
446 may_combine_partial_sharding);
447 }
448 return changed;
449 }
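
// Illustrative example (assumed, not from the original file): for
//   %dot = f32[M,N] dot(f32[M,K] %lhs, f32[K,N] %rhs)
// with %lhs sharded {devices=[2,1]0,1} along M (a non-contracting dim), the
// contracting dim K carries no tiling, so the sharding maps straight through
// to the output and the dot is offered {devices=[2,1]0,1} on its M dimension.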
450
451 bool InferGatherParallelShardingFromOperands(
452 HloInstruction* instruction,
453 const hlo_sharding_util::GatherParallelDims& parallel_dims,
454 bool may_combine_partial_sharding) {
455 auto from_operand =
456 [instruction](int64_t operand_index,
457 absl::Span<const int64_t> output_aligned_parallel_dims,
458 absl::Span<const int64_t> output_parallel_dims) {
459 const HloInstruction* operand = instruction->operand(operand_index);
460 const HloSharding& operand_sharding = operand->sharding();
461 if (operand_sharding.IsTileMaximal()) {
462 return operand_sharding;
463 }
464 auto dnums = instruction->gather_dimension_numbers();
465 std::vector<int64_t> output_tile_dims(instruction->shape().rank(), 1);
466 std::vector<int64_t> index_non_parallel_dims;
467 index_non_parallel_dims.reserve(operand->shape().rank());
468 // Detect non-parallel dimensions in the index.
469 for (int i = 0; i < operand->shape().rank(); ++i) {
470 if (!absl::c_linear_search(output_aligned_parallel_dims, i)) {
471 index_non_parallel_dims.push_back(i);
472 }
473 }
474 // Collect tile dimensions in the operand. The order of the parallel
475 // dimensions in output_aligned_parallel_dims is the same as that of the
476 // output.
477 for (int i = 0; i < output_aligned_parallel_dims.size(); ++i) {
478 const int64_t indices_idx = output_aligned_parallel_dims[i];
479 const int64_t output_idx = output_parallel_dims[i];
480 output_tile_dims[output_idx] =
481 operand_sharding.tile_assignment().dim(indices_idx);
482 }
483 HloSharding replicate_non_parallel_dims =
484 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
485 operand_sharding, index_non_parallel_dims);
486 if (replicate_non_parallel_dims.IsTileMaximal()) {
487 return replicate_non_parallel_dims;
488 }
489 for (int64_t i = replicate_non_parallel_dims.TiledDataRank();
490 i < replicate_non_parallel_dims.tile_assignment().num_dimensions();
491 ++i) {
492 output_tile_dims.push_back(
493 replicate_non_parallel_dims.tile_assignment().dim(i));
494 }
495 auto output_tile_assignment =
496 replicate_non_parallel_dims.tile_assignment();
497 output_tile_assignment.Reshape(output_tile_dims);
498 return replicate_non_parallel_dims.ReplicateOnLastTileDim()
499 ? HloSharding::PartialTile(
500 output_tile_assignment,
501 replicate_non_parallel_dims.metadata())
502 : HloSharding::Subgroup(
503 output_tile_assignment,
504 replicate_non_parallel_dims.subgroup_types(),
505 replicate_non_parallel_dims.metadata());
506 };
507
508 bool changed = false;
509 auto output_parallel_dims =
510 hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims);
511 if (IsSpatiallyPartitioned(instruction->operand(0))) {
512 changed |= MaybeImproveInstructionSharding(
513 from_operand(
514 0,
515 absl::MakeConstSpan(
516 hlo_sharding_util::GatherOutputAlignedOperandParallelDims(
517 *instruction, parallel_dims)),
518 absl::MakeConstSpan(output_parallel_dims)),
519 instruction, may_combine_partial_sharding);
520 }
521 if (IsSpatiallyPartitioned(instruction->operand(1))) {
522 changed |= MaybeImproveInstructionSharding(
523 from_operand(1,
524 absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
525 absl::MakeConstSpan(output_parallel_dims)),
526 instruction, may_combine_partial_sharding);
527 }
528 return changed;
529 }
530
531 // Convolution handling for InferShardingFromOperands().
532 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
533 int64_t aggressiveness,
534 bool may_combine_partial_sharding) {
535 auto get_partitions_for_dims =
536 [&](const HloInstruction* inst,
537 absl::Span<
538 const dot_as_convolution_util::DotConvolutionDimsInfo::DimNums>
539 dims,
540 int lhs_or_rhs) {
541 int64_t partitions = 1;
542 if (!inst->has_sharding()) {
543 return partitions;
544 }
545 const auto& sharding = inst->sharding();
546 if (sharding.IsTileMaximal()) {
547 return partitions;
548 }
549 for (const auto& dim : dims) {
550 if (lhs_or_rhs == 0) {
551 partitions *= sharding.tile_assignment().dim(dim.lhs);
552 } else {
553 CHECK_EQ(lhs_or_rhs, 1);
554 partitions *= sharding.tile_assignment().dim(dim.rhs);
555 }
556 }
557 return partitions;
558 };
559 auto dot_dims =
560 dot_as_convolution_util::ParseConvolutionDimsInfo(instruction);
561 const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
562 instruction->operand(0), dot_dims.conv_spatial_dims, 0);
563 const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
564 instruction->operand(1), dot_dims.conv_spatial_dims, 1);
565 if (dot_dims.conv_spatial_dims.empty() ||
566 (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
567 instruction->batch_group_count() == 1 &&
568 instruction->feature_group_count() == 1)) {
569 return InferDotShardingFromOperands(instruction, dot_dims,
570 may_combine_partial_sharding);
571 }
572 const auto& dnums = instruction->convolution_dimension_numbers();
573 const HloInstruction* lhs = instruction->operand(0);
574 auto get_tiled_sharding_based_on_lhs = [&] {
575 CHECK(!lhs->sharding().IsTileMaximal());
576 std::vector<int64_t> output_to_lhs_indices(instruction->shape().rank());
577 output_to_lhs_indices[dnums.output_batch_dimension()] =
578 dnums.input_batch_dimension();
579 output_to_lhs_indices[dnums.output_feature_dimension()] =
580 dnums.input_feature_dimension();
581 for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
582 output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
583 dnums.input_spatial_dimensions(i);
584 }
585 return hlo_sharding_util::TransposeSharding(lhs->sharding(),
586 output_to_lhs_indices);
587 };
588 if (!IsSpatiallyPartitioned(lhs)) {
589 return false;
590 }
591 if (lhs->sharding().IsTileMaximal()) {
592 return MaybeImproveInstructionSharding(lhs->sharding(), instruction,
593 may_combine_partial_sharding);
594 }
595
596 if (IsConvolutionKernelSmall(instruction)) {
597 // If the kernel is small compared to the input then we can generate an
598 // output that is sharded the same way as the input.
599 const auto& tile_assignment = lhs->sharding().tile_assignment();
600 if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
601 return false;
602 }
603 return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
604 instruction,
605 may_combine_partial_sharding);
606 }
607 // If the kernel is large (e.g. backward convolution) then we only support
608 // a replicated output.
609 return MaybeImproveInstructionSharding(
610 hlo_sharding_util::ReplicateAllDataDims(lhs->sharding(),
611 instruction->shape().rank()),
612 instruction, may_combine_partial_sharding);
613 }
614
615 bool CanPropagateThroughAtAggressiveLevel(const HloInstruction& inst,
616 int64_t aggressiveness) {
617 // At minimum aggressiveness, only allow pass-through ops.
618 if (aggressiveness < 1 &&
619 !(inst.IsElementwise() || inst.IsCustomCall("Sharding")) &&
620 inst.opcode() != HloOpcode::kTranspose &&
621 inst.opcode() != HloOpcode::kReshape &&
622 inst.opcode() != HloOpcode::kTuple &&
623 inst.opcode() != HloOpcode::kGetTupleElement &&
624 inst.opcode() != HloOpcode::kWhile &&
625 inst.opcode() != HloOpcode::kDynamicSlice &&
626 inst.opcode() != HloOpcode::kOptimizationBarrier &&
627 inst.opcode() != HloOpcode::kConcatenate) {
628 return false;
629 }
630 // Broadcast propagation should have at least aggressiveness 2.
631 if (aggressiveness < 2 && inst.opcode() == HloOpcode::kBroadcast) {
632 return false;
633 }
634 return true;
635 }
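
// Consequences of the thresholds above (assumed examples, not from the
// original file): kTranspose, kReshape, kTuple, kGetTupleElement, kWhile and
// elementwise ops may propagate even at aggressiveness 0; kBroadcast needs
// aggressiveness >= 2; most other ops (e.g. kDot, kConvolution) need
// aggressiveness >= 1.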
636
637 HloSharding InferDotOperandSharding(
638 const HloInstruction* instruction,
639 const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
640 int64_t operand_index, bool may_combine_partial_sharding) {
641 auto operand = instruction->operand(operand_index);
642 auto other = instruction->operand(1 - operand_index);
643 std::vector<int64_t> output_dims_to_replicate;
644 std::vector<int64_t> other_operand_dims_to_replicate;
645 for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims
646 : dnums.lhs_non_contracting_dims) {
647 output_dims_to_replicate.push_back(dim.output);
648 other_operand_dims_to_replicate.push_back(operand_index == 0 ? dim.rhs
649 : dim.lhs);
650 }
651 // If this dot is interpreted from a conv, then contracting dims may have
652 // corresponding spatial dimensions in the output, and this operand's
653 // non-contracting dims may have corresponding spatial dims in the other
654 // operand.
655 for (const auto& dim : dnums.contracting_dims) {
656 if (dim.output >= 0) {
657 output_dims_to_replicate.push_back(dim.output);
658 }
659 }
660 for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims
661 : dnums.rhs_non_contracting_dims) {
662 int64_t other_dim = operand_index == 0 ? dim.rhs : dim.lhs;
663 if (other_dim >= 0) {
664 other_operand_dims_to_replicate.push_back(other_dim);
665 }
666 }
667 auto output_other_dims_replicated =
668 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
669 instruction->sharding(), output_dims_to_replicate);
670 std::vector<int64_t> output_to_operand_dims(instruction->shape().rank(), -1);
671 std::vector<int64_t> operand_to_output_dims(operand->shape().rank(), -1);
672 for (const auto& dim : dnums.batch_dims) {
673 output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
674 operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output;
675 }
676 for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims
677 : dnums.rhs_non_contracting_dims) {
678 output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
679 operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output;
680 }
681 auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims(
682 output_other_dims_replicated, output_to_operand_dims,
683 operand_to_output_dims);
684 if (IsSpatiallyPartitioned(other)) {
685 auto other_operand_dims_replicated =
686 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
687 other->sharding(), other_operand_dims_to_replicate);
688 std::vector<int64_t> other_to_operand_dims(other->shape().rank(), -1);
689 std::vector<int64_t> operand_to_other_dims(operand->shape().rank(), -1);
690 for (const auto& dim : dnums.batch_dims) {
691 other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] =
692 operand_index == 0 ? dim.lhs : dim.rhs;
693 operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
694 operand_index == 0 ? dim.rhs : dim.lhs;
695 }
696 for (const auto& dim : dnums.contracting_dims) {
697 other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] =
698 operand_index == 0 ? dim.lhs : dim.rhs;
699 operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
700 operand_index == 0 ? dim.rhs : dim.lhs;
701 }
702 HloSharding sharding_from_other =
703 *hlo_sharding_util::TransposeShardingWithCollapsedDims(
704 other_operand_dims_replicated, other_to_operand_dims,
705 operand_to_other_dims);
706 if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other,
707 may_combine_partial_sharding)) {
708 sharding = std::move(sharding_from_other);
709 }
710 }
711 return sharding;
712 }
713
714 // Tries to update the sharding of the specified instruction based on its users
715 // and returns true if the sharding of the instruction has been changed, and
716 // false otherwise.
717 bool InferShardingFromUsers(
718 HloInstruction* instruction,
719 const ShardingPropagation::ComputationMap& computation_map,
720 int64_t aggressiveness, bool is_spmd,
721 const CustomCallShardingHelper* sharding_helper) {
722 if (aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast) {
723 return false;
724 }
725 // Do not change manual sharding.
726 if (instruction->has_sharding() && instruction->sharding().IsManual()) {
727 return false;
728 }
729 // Propagate manual sharding.
730 if (!instruction->has_sharding()) {
731 for (const HloInstruction* user : instruction->users()) {
732 if (!user->has_sharding() || !user->sharding().IsManual() ||
733 user->IsCustomCall("SPMDFullToShardShape"))
734 continue;
735 if (instruction->shape().IsArray()) {
736 instruction->set_sharding(
737 HloSharding::Manual(user->sharding().metadata()));
738 } else {
739 std::optional<HloSharding> user_sharding =
740 ShardingPropagation::GetShardingFromUser(*instruction, *user,
741 aggressiveness, is_spmd);
742 if (user_sharding) {
743 instruction->set_sharding(*user_sharding);
744 }
745 }
746 return true;
747 }
748 }
749 if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd, false,
750 sharding_helper)) {
751 return false;
752 }
753
754 bool improved_sharding = false;
755 const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
756 for (const HloInstruction* user : instruction->users()) {
757 std::optional<HloSharding> user_sharding =
758 ShardingPropagation::GetShardingFromUser(*instruction, *user,
759 aggressiveness, is_spmd);
760 if (user_sharding && sharding_helper->IsCustomCallShardable(instruction)) {
761 user_sharding = sharding_helper->PropagateUserSharding(instruction, user,
762 *user_sharding);
763 }
764 if (user_sharding) {
765 improved_sharding |= MaybeImproveInstructionSharding(
766 std::move(*user_sharding), instruction, may_combine_partial_sharding);
767 }
768 }
769 return improved_sharding;
770 }
771
772 // Checks if two HloShardings have the same metadata attached.
773 bool SameShardingMetadata(const HloSharding& a, const HloSharding& b) {
774 DCHECK_EQ(a, b);
775
776 auto same_metadata = [](absl::Span<const OpMetadata> a,
777 absl::Span<const OpMetadata> b) {
778 if (a.size() != b.size()) return false;
779 for (int i = 0, e = a.size(); i < e; ++i) {
780 if (!protobuf_util::ProtobufEquals(a[i], b[i])) {
781 return false;
782 }
783 }
784 return true;
785 };
786
787 if (a.IsTuple()) {
788 for (int i = 0, e = a.tuple_elements().size(); i < e; ++i) {
789 if (!same_metadata(a.tuple_elements()[i].metadata(),
790 b.tuple_elements()[i].metadata())) {
791 return false;
792 }
793 }
794 return true;
795 } else {
796 return same_metadata(a.metadata(), b.metadata());
797 }
798 }
799
800 // Assigns an instruction's metadata to its sharding, for instructions that
801 // have both. If the sharding already has some metadata, no new metadata will
802 // be added.
803 bool AssignShardingMetadata(
804 HloModule* module,
805 const absl::flat_hash_set<absl::string_view>& execution_threads) {
806 bool changed = false;
807 for (HloComputation* computation : module->computations(execution_threads)) {
808 for (HloInstruction* instruction : computation->instructions()) {
809 const auto& metadata = instruction->metadata();
810 if (!instruction->has_sharding() || metadata.ByteSizeLong() == 0) {
811 continue;
812 }
813
814 HloSharding sharding_with_metadata =
815 instruction->sharding().WithMetadata({metadata}, /*overwrite=*/false);
816 if (!SameShardingMetadata(instruction->sharding(),
817 sharding_with_metadata)) {
818 instruction->set_sharding(std::move(sharding_with_metadata));
819 changed = true;
820 }
821 }
822 }
823 return changed;
824 }
825
826 // Removes all sharding metadata from shardings on instructions.
827 bool RemoveShardingMetadata(
828 HloModule* module,
829 const absl::flat_hash_set<absl::string_view>& execution_threads) {
830 bool changed = false;
831 for (HloComputation* computation : module->computations(execution_threads)) {
832 for (HloInstruction* instruction : computation->instructions()) {
833 if (!instruction->has_sharding()) {
834 continue;
835 }
836 HloSharding sharding_no_metadata =
837 instruction->sharding().WithoutMetadata();
838 if (!SameShardingMetadata(instruction->sharding(),
839 sharding_no_metadata)) {
840 instruction->set_sharding(std::move(sharding_no_metadata));
841 changed = true;
842 }
843 }
844 }
845 return changed;
846 }
847
848 // If a while contains a channel instruction on device D, check that any other
849 // instructions with a device assignment are on D. Further, annotate the root
850 // instruction of the while body to ensure that HLO partitioning will keep the
851 // entire while instruction on D.
852 Status CheckAndUpdateDeviceAssignmentsInWhileBody(
853 HloInstruction* while_instruction) {
854 auto bad_status = [](HloInstruction* instruction, int64_t device,
855 HloInstruction* channel_instruction,
856 int64_t correct_device) {
857 return FailedPrecondition(
858 "Instruction: %s is on device: %d, which conflicts with device: %d "
859 "of channel instruction: %s",
860 instruction->name(), device, correct_device,
861 channel_instruction->name());
862 };
863
864 CHECK_EQ(while_instruction->opcode(), HloOpcode::kWhile);
865 HloComputation* while_body = while_instruction->while_body();
866 // Maps a device number to an instruction in the while_body with that
867 // device assignment.
868 std::map<int64_t, HloInstruction*> devices_to_instructions;
869 std::optional<int64_t> unique_device = std::nullopt;
870 HloInstruction* channel_instruction = nullptr;
871
872 for (HloInstruction* instruction : while_body->instructions()) {
873 if (instruction->sharding_unique_device()) {
874 auto opcode = instruction->opcode();
875 int64_t device = *instruction->sharding_unique_device();
876 if (unique_device.has_value()) {
877 if (*unique_device != device) {
878 return bad_status(instruction, device, channel_instruction,
879 *unique_device);
880 }
881 } else if (opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
882 // Cross-replica AllReduces don't have a channel_id, and we
883 // don't enforce any invariant about their device assignment.
884 ((opcode == HloOpcode::kAllReduce ||
885 opcode == HloOpcode::kReduceScatter) &&
886 instruction->channel_id())) {
887 channel_instruction = instruction;
888 unique_device = device;
889 if (!devices_to_instructions.empty()) {
890 for (auto it = devices_to_instructions.begin();
891 it != devices_to_instructions.end(); ++it) {
892 if (*unique_device != it->first) {
893 return bad_status(it->second, it->first, channel_instruction,
894 *unique_device);
895 }
896 }
897 }
898 } else {
899 devices_to_instructions[device] = instruction;
900 }
901 }
902 }
903
904 if (unique_device.has_value()) {
905 auto while_device = while_instruction->sharding_unique_device();
906 if (while_device.has_value() && *unique_device != *while_device) {
907 return bad_status(while_instruction, *while_device, channel_instruction,
908 *unique_device);
909 }
910 auto body_root = while_body->root_instruction();
911 auto root_device = body_root->sharding_unique_device();
912 if (!root_device.has_value()) {
913 body_root->set_device_sharding(*unique_device);
914 } else if (*unique_device != *root_device) {
915 return bad_status(body_root, *root_device, channel_instruction,
916 *unique_device);
917 }
918 }
919 return OkStatus();
920 }
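
// Illustrative example (assumed, not from the original file): if a while body
// contains a Send assigned {maximal device=0} and another instruction assigned
// {maximal device=1}, this returns a FailedPrecondition error. If instead all
// assigned instructions agree on device 0 and the body root has no device, the
// root is annotated with device 0 so partitioning keeps the whole while there.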
921
922 // Refines a pair of auto/manual shardings based on auto sharding `to_merge`
923 // along `unspecified_dims`. Returns true if anything changed.
924 bool RefineManualAutoShardingFromAuto(
925 const HloSharding& to_merge, absl::Span<const int64_t> unspecified_dims,
926 HloSharding* auto_sharding, HloSharding* manual_sharding) {
927 if (!manual_sharding->IsManualSubgroup() ||
928 auto_sharding->IsManualSubgroup() ||
929 !manual_sharding->HasPartialReplication() ||
930 manual_sharding->subgroup_types().size() != 2) {
931 // We do not support nested subgroup manual. man_conversion_op must have
932 // replication in order to be merged.
933 return false;
934 }
935 HloSharding partial_rep =
936 hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
937 to_merge, unspecified_dims);
938 if (partial_rep.IsTileMaximal()) {
939 return false;
940 }
941
942 // Merge with the non-manual partial annotation.
943 if (!hlo_sharding_util::MergeShardingIfCompatible(
944 partial_rep, auto_sharding->NumTiles() + 1, auto_sharding)) {
945 return false;
946 }
947
948 // Merge with the manual partial annotation.
949 const int64_t data_rank = partial_rep.TiledDataRank();
950 // We are also merging the non-manual sharding into the manual sharding. To
951 // leverage existing merging implementation, we treat the manual dim as a
952 // data dim, and add it right before the replication dim.
953 auto partial_tiling_for_manual = partial_rep.tile_assignment();
954 std::vector<int64_t> partial_manual_shape =
955 partial_tiling_for_manual.dimensions();
956 partial_manual_shape.insert(partial_manual_shape.begin() + data_rank, 1);
957 partial_tiling_for_manual.Reshape(partial_manual_shape);
958 HloSharding partial_rep_for_manual = HloSharding::PartialTile(
959 partial_tiling_for_manual, partial_rep.metadata());
960 Array<int64_t> man_tiling = manual_sharding->tile_assignment();
961 if (manual_sharding->subgroup_types().back() != OpSharding::REPLICATED) {
962 // Move the manual dim before replication dim.
963 std::vector<int64_t> transposed_dims = man_tiling.dimensions();
964 transposed_dims[data_rank] = transposed_dims.back();
965 transposed_dims.back() = man_tiling.dim(data_rank);
966 Array<int64_t> transposed(transposed_dims);
967 man_tiling.Each([&](absl::Span<const int64_t> indices, int64_t device) {
968 std::vector<int64_t> xposed_idx(indices.begin(), indices.end() - 2);
969 xposed_idx.push_back(indices.back());
970 xposed_idx.push_back(indices[data_rank]);
971 transposed(xposed_idx) = device;
972 });
973 man_tiling = std::move(transposed);
974 }
975 HloSharding tmp_sharding_for_merging =
976 HloSharding::PartialTile(man_tiling, manual_sharding->metadata());
977 if (!hlo_sharding_util::MergeShardingIfCompatible(
978 partial_rep_for_manual, tmp_sharding_for_merging.NumTiles() + 1,
979 &tmp_sharding_for_merging)) {
980 return false;
981 }
982
983 std::vector<OpSharding::Type> subgroup_types;
984 subgroup_types.push_back(OpSharding::MANUAL);
985 if (tmp_sharding_for_merging.HasPartialReplication()) {
986 subgroup_types.push_back(OpSharding::REPLICATED);
987 }
988 *manual_sharding = HloSharding::Subgroup(
989 tmp_sharding_for_merging.tile_assignment(), subgroup_types,
990 tmp_sharding_for_merging.metadata());
991 return true;
992 }
993
994 // Refines a pair of auto/manual shardings based on manual sharding `to_merge`
995 // along `unspecified_dims`. Returns true if anything changed.
996 bool RefineManualAutoShardingFromManual(
997 const HloSharding& to_merge, absl::Span<const int64_t> unspecified_dims,
998 HloSharding* auto_sharding, HloSharding* manual_sharding) {
999 if (!to_merge.IsManualSubgroup() || !manual_sharding->IsManualSubgroup() ||
1000 !manual_sharding->HasPartialReplication() ||
1001 auto_sharding->IsManualSubgroup() ||
1002 manual_sharding->subgroup_types().size() != 2) {
1003 return false;
1004 }
1005 HloSharding partial_rep =
1006 hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1007 to_merge, unspecified_dims);
1008 if (partial_rep.IsTileMaximal()) {
1009 return false;
1010 }
1011 if (!hlo_sharding_util::MergeShardingIfCompatible(
1012 partial_rep, manual_sharding->NumTiles() + 1, manual_sharding)) {
1013 return false;
1014 }
1015 HloSharding partial_rep_for_auto = HloSharding::Subgroup(
1016 partial_rep.tile_assignment(),
1017 std::vector<OpSharding::Type>(partial_rep.subgroup_types().size(),
1018 OpSharding::REPLICATED),
1019 partial_rep.metadata());
1020 if (!hlo_sharding_util::MergeShardingIfCompatible(
1021 partial_rep_for_auto, auto_sharding->NumTiles() + 1, auto_sharding)) {
1022 return false;
1023 }
1024 return true;
1025 }
1026
1027 bool InferUnspecifiedDimsFromOperand(HloInstruction* annotate_op,
1028 absl::Span<const int64_t> unspecified_dims,
1029 HloInstruction** man_conversion_op_after) {
1030 // ProcessShardingInstruction will either keep the "Sharding" custom call as
1031 // is or replace it with a copy.
1032 CHECK(annotate_op->IsCustomCall("Sharding") ||
1033 annotate_op->opcode() == HloOpcode::kCopy);
1034 if (!IsSpatiallyPartitioned(annotate_op->operand(0))) {
1035 return false;
1036 }
1037 const HloSharding& operand_sharding = annotate_op->operand(0)->sharding();
1038 if (!operand_sharding.IsTiled()) {
1039 return false;
1040 }
1041 HloInstruction* man_conversion_op = nullptr;
1042 if (annotate_op->user_count() == 1) {
1043 HloInstruction* user = annotate_op->users()[0];
1044 if (user->IsCustomCall("SPMDFullToShardShape") ||
1045 user->IsCustomCall("SPMDShardToFullShape")) {
1046 std::vector<int64_t> user_unspec_dims;
1047 if (!sharding_op_util::ParseAttributes(
1048 Cast<HloCustomCallInstruction>(user)->opaque(),
1049 &user_unspec_dims)
1050 .ok()) {
1051 return false;
1052 }
1053 absl::c_sort(user_unspec_dims);
1054 if (unspecified_dims != user_unspec_dims) {
1055 // The manual/auto conversion op must have the same set of unspecified
1056 // dims.
1057 return false;
1058 }
1059 man_conversion_op = user;
1060 }
1061 }
1062 *man_conversion_op_after = man_conversion_op;
1063 if (man_conversion_op == nullptr) {
1064 HloSharding partial_replicated =
1065 hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1066 operand_sharding, unspecified_dims);
1067 HloSharding sharding = annotate_op->sharding();
1068 if (!hlo_sharding_util::MergeShardingIfCompatible(
1069 partial_replicated, sharding.NumTiles() + 1, &sharding)) {
1070 return false;
1071 }
1072 annotate_op->set_sharding(sharding);
1073 return true;
1074 }
1075 if (man_conversion_op->IsCustomCall("SPMDFullToShardShape")) {
1076 HloSharding auto_sharding = annotate_op->sharding();
1077 HloSharding manual_sharding = man_conversion_op->sharding();
1078 if (!RefineManualAutoShardingFromAuto(operand_sharding, unspecified_dims,
1079 &auto_sharding, &manual_sharding)) {
1080 return false;
1081 }
1082 annotate_op->set_sharding(auto_sharding);
1083 man_conversion_op->set_sharding(manual_sharding);
1084 return true;
1085 }
1086
1087 CHECK(man_conversion_op->IsCustomCall("SPMDShardToFullShape"));
1088 HloSharding manual_sharding = annotate_op->sharding();
1089 HloSharding auto_sharding = man_conversion_op->sharding();
1090 if (!RefineManualAutoShardingFromManual(operand_sharding, unspecified_dims,
1091 &auto_sharding, &manual_sharding)) {
1092 return false;
1093 }
1094 annotate_op->set_sharding(manual_sharding);
1095 man_conversion_op->set_sharding(auto_sharding);
1096 return true;
1097 }
1098
1099 bool InferUnspecifiedDimsFromOneUser(HloInstruction* annotate_op,
1100 const HloInstruction* user,
1101 int64_t aggressiveness, bool is_spmd,
1102 absl::Span<const int64_t> unspecified_dims,
1103 HloInstruction* man_conversion_op) {
1104 CHECK(annotate_op->IsCustomCall("Sharding") ||
1105 annotate_op->opcode() == HloOpcode::kCopy);
1106 if (!user->has_sharding() || !user->sharding().IsTiled()) {
1107 return false;
1108 }
1109 std::optional<HloSharding> user_sharding =
1110 ShardingPropagation::GetShardingFromUser(
1111 man_conversion_op == nullptr ? *annotate_op : *man_conversion_op,
1112 *user, aggressiveness, is_spmd);
1113 if (!user_sharding.has_value() || user_sharding->IsTileMaximal()) {
1114 return false;
1115 }
1116 if (man_conversion_op == nullptr) {
1117 HloSharding partial_replicated =
1118 hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1119 *user_sharding, unspecified_dims);
1120 HloSharding sharding = annotate_op->sharding();
1121 if (!hlo_sharding_util::MergeShardingIfCompatible(
1122 partial_replicated, sharding.NumTiles() + 1, &sharding)) {
1123 return false;
1124 }
1125 annotate_op->set_sharding(sharding);
1126 return true;
1127 }
1128 if (man_conversion_op->IsCustomCall("SPMDFullToShardShape")) {
1129 HloSharding auto_sharding = annotate_op->sharding();
1130 HloSharding manual_sharding = man_conversion_op->sharding();
1131 if (!RefineManualAutoShardingFromManual(*user_sharding, unspecified_dims,
1132 &auto_sharding, &manual_sharding)) {
1133 return false;
1134 }
1135 annotate_op->set_sharding(auto_sharding);
1136 man_conversion_op->set_sharding(manual_sharding);
1137 return true;
1138 }
1139 CHECK(man_conversion_op->IsCustomCall("SPMDShardToFullShape"));
1140 HloSharding manual_sharding = annotate_op->sharding();
1141 HloSharding auto_sharding = man_conversion_op->sharding();
1142 if (!RefineManualAutoShardingFromAuto(*user_sharding, unspecified_dims,
1143 &auto_sharding, &manual_sharding)) {
1144 return false;
1145 }
1146 annotate_op->set_sharding(manual_sharding);
1147 man_conversion_op->set_sharding(auto_sharding);
1148 return true;
1149 }
1150
1151 bool InferUnspecifiedDimsFromUsers(HloInstruction* annotate_op,
1152 absl::Span<const int64_t> unspecified_dims,
1153 int64_t aggressiveness, bool is_spmd,
1154 HloInstruction** man_conversion_op_after) {
1155 HloInstruction* man_conversion_op = nullptr;
1156 if (annotate_op->user_count() == 1) {
1157 HloInstruction* user = annotate_op->users()[0];
1158 if (user->IsCustomCall("SPMDFullToShardShape") ||
1159 user->IsCustomCall("SPMDShardToFullShape")) {
1160 std::vector<int64_t> user_unspec_dims;
1161 absl::c_sort(user_unspec_dims);
1162 if (!sharding_op_util::ParseAttributes(
1163 Cast<HloCustomCallInstruction>(user)->opaque(),
1164 &user_unspec_dims)
1165 .ok() ||
1166 unspecified_dims != user_unspec_dims) {
1167 // The manual/auto conversion op must have the same set of unspecified
1168 // dims.
1169 return false;
1170 }
1171 man_conversion_op = user;
1172 }
1173 }
1174 *man_conversion_op_after = man_conversion_op;
1175
1176 HloInstruction* op_for_users =
1177 man_conversion_op == nullptr ? annotate_op : man_conversion_op;
1178 bool changed = false;
1179 for (HloInstruction* user : op_for_users->users()) {
1180 changed |= InferUnspecifiedDimsFromOneUser(
1181 annotate_op, user, aggressiveness, is_spmd, unspecified_dims,
1182 man_conversion_op);
1183 }
1184 return changed;
1185 }
1186
1187 // Returns whether an op is a target for CSE prevention.
1188 bool IsCSEPreventionTarget(const HloInstruction* instruction) {
1189 // Scalar broadcasts are the most common CSE target that causes cross-layer
1190 // propagation on unrelated subgraphs.
1191 return instruction->opcode() == HloOpcode::kBroadcast &&
1192 instruction->operand(0)->shape().rank() == 0;
1193 }
1194
1195 // Marks a sharding as being for CSE prevention.
1196 HloSharding SetCSEPreventionSharding(const HloSharding& sharding) {
1197 OpMetadata metadata;
1198 metadata.set_op_name("_sharding_propagation_cse_prevention");
1199 return sharding.WithMetadata({metadata}, /*overwrite=*/true);
1200 }
1201
1202 // Returns true if the sharding is for CSE prevention.
1203 bool IsCSEPreventionSharding(const HloSharding& sharding) {
1204 if (sharding.metadata().size() != 1) {
1205 return false;
1206 }
1207 return sharding.metadata()[0].op_name() ==
1208 "_sharding_propagation_cse_prevention";
1209 }
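
// Illustrative example (assumed, not from the original file): passing a
// replicated sharding through SetCSEPreventionSharding() yields roughly
//   {replicated metadata={op_name="_sharding_propagation_cse_prevention"}}
// and IsCSEPreventionSharding() recognizes exactly that single metadata entry.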
1210
1211 } // namespace
1212
1213 std::optional<HloSharding> InferBroadcastOperandSharding(
1214 const HloInstruction& instruction, bool is_spmd) {
1215 if (instruction.sharding().IsReplicated()) {
1216 return instruction.sharding();
1217 }
1218 std::vector<int64_t> dims_to_replicate;
1219 bool needs_replication = false;
1220 for (int64_t i = 0; i < instruction.shape().rank(); ++i) {
1221 if (absl::c_count(instruction.dimensions(), i) == 0) {
1222 dims_to_replicate.push_back(i);
1223 if (instruction.sharding().tile_assignment().dim(i) > 1) {
1224 needs_replication = true;
1225 }
1226 }
1227 }
1228 // If not SPMD, only support when none of the partitioned dimensions in
1229 // the broadcast output belong to new dimensions.
1230 if (!is_spmd && needs_replication) {
1231 return std::nullopt;
1232 }
1233 return hlo_sharding_util::RemoveShapeDimensions(
1234 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1235 instruction.sharding(), dims_to_replicate),
1236 dims_to_replicate);
1237 }
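
// Illustrative example (assumed, not from the original file): for
//   %b = f32[M,N] broadcast(f32[N] %x), dimensions={1}
// with output sharding {devices=[2,2]0,1,2,3}, output dim 0 is a new dimension
// and it is tiled, so in non-SPMD mode this returns nullopt. In SPMD mode the
// new dim is partially replicated and then removed, giving %x roughly
//   {devices=[2,2]0,2,1,3 last_tile_dim_replicate}.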
1238
1239 // Removes a Sharding custom-call instruction by folding its sharding
1240 // attribute into its operand. If the operand already has a different
1241 // sharding, inserts a copy node to reshard.
1242 // `unspecified_dims` will be populated with the converted copies if the custom
1243 // call is partially specified.
1244 StatusOr<bool> ProcessShardingInstruction(
1245 HloModule* module,
1246 const absl::flat_hash_set<absl::string_view>& execution_threads,
1247 bool replace_sharding_with_copy,
1248 absl::flat_hash_map<const HloInstruction*, std::vector<int64_t>>*
1249 unspecified_dims) {
1250 bool changed = false;
1251
1252 for (HloComputation* computation : module->computations(execution_threads)) {
1253 auto instructions = computation->MakeInstructionPostOrder();
1254 std::reverse(instructions.begin(), instructions.end());
1255 for (HloInstruction* instruction : instructions) {
1256 if (!instruction->IsCustomCall("Sharding")) {
1257 continue;
1258 }
1259 TF_RET_CHECK(instruction->has_sharding())
1260 << "Sharding instruction must have a sharding attribute";
1261 const HloSharding& sharding = instruction->sharding();
1262
1263 std::vector<int64_t> unspec_dims;
1264 TF_RETURN_IF_ERROR(sharding_op_util::ParseAttributes(
1265 Cast<HloCustomCallInstruction>(instruction)->opaque(), &unspec_dims));
1266 // Replace it with a copy node so that it does not need special handling.
1267 if (replace_sharding_with_copy) {
1268 auto copy = computation->AddInstruction(
1269 HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kCopy,
1270 instruction->mutable_operand(0)));
1271 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instruction, copy));
1272 copy->set_sharding(sharding);
1273 instruction = copy;
1274 changed = true;
1275 }
1276 if (!unspec_dims.empty()) {
1277 absl::c_sort(unspec_dims);
1278 unspecified_dims->emplace(instruction, std::move(unspec_dims));
1279 } else if (!instruction->operand(0)->has_sharding()) {
1280 instruction->mutable_operand(0)->set_sharding(sharding);
1281 }
1282 }
1283 }
1284 return changed;
1285 }
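
// Illustrative before/after (assumed example, not from the original file):
//   %annotated = f32[8] custom-call(%p0), custom_call_target="Sharding",
//                sharding={devices=[2]0,1}
// is rewritten (when replace_sharding_with_copy is true) into
//   %copy = f32[8] copy(%p0), sharding={devices=[2]0,1}
// and, if %p0 has no sharding yet and no dims were left unspecified, %p0 also
// receives {devices=[2]0,1}.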
1286
1287 /*static*/ Status ShardingPropagation::NormalizeDomain(
1288 const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
1289 if (metadata != nullptr) {
1290 TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
1291 ShardingMetadata::ToShardingMetadata(metadata));
1292 const auto& sharding = sharding_metadata->sharding();
1293 if (sharding != nullptr) {
1294 bool is_spatially_partitioned = !sharding->HasUniqueDevice();
1295 if (sharding->IsTuple()) {
1296 is_spatially_partitioned = absl::c_any_of(
1297 sharding->tuple_elements(),
1298 [](const HloSharding& s) { return !s.HasUniqueDevice(); });
1299 }
1300 if (is_spatially_partitioned) {
1301 for (HloInstruction* d : domain.exit_domains) {
1302 HloInstruction* operand = d->mutable_operand(0);
1303 // Set the sharding only if it is different. If the operand already has
1304 // the same sharding (up to metadata), keep its existing metadata.
1305 if (!operand->has_sharding() || operand->sharding() != *sharding) {
1306 d->mutable_operand(0)->set_sharding(*sharding);
1307 }
1308 }
1309 return OkStatus();
1310 }
1311 }
1312 }
1313 return ShardingMetadata::NormalizeShardingDomain(domain, metadata);
1314 }
1315
1316 // Return the sharding that should be propagated from user to instruction.
1317 std::optional<HloSharding> ShardingPropagation::GetShardingFromUser(
1318 const HloInstruction& instruction, const HloInstruction& user,
1319 int64_t aggressiveness, bool is_spmd) {
1320 if (!CanPropagateThroughAtAggressiveLevel(user, aggressiveness)) {
1321 return std::nullopt;
1322 }
1323 if (!IsSpatiallyPartitioned(&user)) {
1324 return std::nullopt;
1325 }
1326 const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
1327
1328 switch (user.opcode()) {
1329 case HloOpcode::kBroadcast: {
1330 return InferBroadcastOperandSharding(user, is_spmd);
1331 }
1332 case HloOpcode::kConcatenate: {
1333 if (aggressiveness == 0) {
1334 return std::nullopt;
1335 }
1336 if (user.sharding().IsReplicated()) {
1337 return user.sharding();
1338 }
1339
1340 const int64_t cdim = user.concatenate_dimension();
1341 const Array<int64_t>& tile_assignment = user.sharding().tile_assignment();
1342 if (tile_assignment.dim(cdim) == 1) {
1343 // If we are concatenating along a non-sharded dimension then the
1344 // operands should have the same sharding as the result.
1345 return user.sharding();
1346 }
1347
1348 if (is_spmd) {
1349 // SPMD doesn't support tiling with part of the devices. Return the same
1350 // sharding.
1351 return user.sharding();
1352 }
1353
1354 // If we are concatenating along a sharded dimension then we want the
1355 // operands to be distributed among the devices where their data is used.
1356 int64_t start_offset = 0;
1357 for (HloInstruction* op : user.operands()) {
1358 if (op == &instruction) {
1359 break;
1360 }
1361 start_offset += op->shape().dimensions(cdim);
1362 }
1363 const int64_t tile_shape = CeilOfRatio(
1364 user.shape().dimensions(cdim), tile_assignment.dimensions()[cdim]);
1365 std::vector<int64_t> start_indices(tile_assignment.num_dimensions());
1366 std::vector<int64_t> end_indices = tile_assignment.dimensions();
1367 start_indices[cdim] = start_offset / tile_shape;
1368 end_indices[cdim] = CeilOfRatio(
1369 start_offset + instruction.shape().dimensions(cdim), tile_shape);
1370 auto new_tile_assignment =
1371 tile_assignment.Slice(start_indices, end_indices);
1372 if (new_tile_assignment.num_elements() == 1) {
1373 return HloSharding::AssignDevice(*new_tile_assignment.begin(),
1374 user.sharding().metadata());
1375 }
1376 return HloSharding::Tile(new_tile_assignment, user.sharding().metadata());
1377 }
1378 case HloOpcode::kConvolution: {
1379 auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user);
1380 if (dot_dims.conv_spatial_dims.empty()) {
1381 int64_t op_idx = user.operand_index(&instruction);
1382 return InferDotOperandSharding(&user, dot_dims, op_idx,
1383 may_combine_partial_sharding);
1384 }
1385 return std::nullopt;
1386 }
1387 case HloOpcode::kDynamicSlice:
1388 case HloOpcode::kDynamicUpdateSlice: {
1389 if (aggressiveness == 0) {
1390 return std::nullopt;
1391 }
1392 if (user.sharding().IsReplicated()) {
1393 return user.sharding();
1394 }
1395 if (user.opcode() == HloOpcode::kDynamicUpdateSlice &&
1396 &instruction == user.operand(0)) {
1397 return user.sharding();
1398 }
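// Below, the only operand that can take the user's sharding directly is the
// base operand of a dynamic-slice (operand 0) or the update operand of a
// dynamic-update-slice (operand 1).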
1399 const HloInstruction* operand = user.opcode() == HloOpcode::kDynamicSlice
1400 ? user.operand(0)
1401 : user.operand(1);
1402 if (&instruction != operand) {
1403 return std::nullopt;
1404 }
1405
1406 if (is_spmd) {
1407 return user.sharding();
1408 }
1409 const auto& tile_assignment = user.sharding().tile_assignment();
1410 for (int64_t i = 0; i < user.shape().rank(); ++i) {
1411 if (tile_assignment.dim(i) > 1 &&
1412 user.shape().dimensions(i) != operand->shape().dimensions(i)) {
1413 return std::nullopt;
1414 }
1415 }
1416 return user.sharding();
1417 }
1418 case HloOpcode::kReduceWindow: {
1419 auto* reduce_window = Cast<HloReduceWindowInstruction>(&user);
1420 if (!absl::c_linear_search(reduce_window->inputs(), &instruction)) {
1421 return std::nullopt;
1422 }
1423 if (reduce_window->shape().IsTuple()) {
1424 auto sub_sharding = reduce_window->sharding().GetSubSharding(
1425 reduce_window->shape(),
1426 {reduce_window->operand_index(&instruction)});
1427 return sub_sharding;
1428 }
1429 return reduce_window->sharding();
1430 }
1431 case HloOpcode::kReshape: {
1432 auto reshaped_sharding = hlo_sharding_util::ReshapeSharding(
1433 user.shape(), instruction.shape(), user.sharding());
1434 if (reshaped_sharding.has_value()) {
1435 return reshaped_sharding;
1436 }
1437 return hlo_sharding_util::ReplicateAllDataDims(
1438 user.sharding(), instruction.shape().rank());
1439 }
1440 case HloOpcode::kPad: {
1441 if (&instruction != user.operand(0)) {
1442 return std::nullopt;
1443 }
1444 return user.sharding();
1445 }
1446 case HloOpcode::kSlice: {
1447 return user.sharding();
1448 }
1449 case HloOpcode::kTranspose: {
1450 // Calculate the dimension numbers for reversing the current transpose
1451 // and then use TransposeSharding to convert the output sharding to an
1452 // input sharding.
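// For example, if user.dimensions() = {2, 0, 1} (output dim i is taken from
// input dim dimensions(i)), then reverse_dimensions = {1, 2, 0}, the
// permutation mapping each input dim to the output dim it ends up in.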
1453 std::vector<int64_t> reverse_dimensions(user.dimensions().size());
1454 for (int64_t i = 0; i < user.dimensions().size(); ++i) {
1455 reverse_dimensions[user.dimensions(i)] = i;
1456 }
1457 return hlo_sharding_util::TransposeSharding(user.sharding(),
1458 reverse_dimensions);
1459 }
1460 case HloOpcode::kTuple: {
1461 auto sub_sharding = user.sharding().GetSubSharding(
1462 user.shape(), {user.operand_index(&instruction)});
1463 return sub_sharding;
1464 }
1465 case HloOpcode::kGetTupleElement: {
1466 int64_t sharding_index = 0;
1467 for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
1468 if (i == user.tuple_index()) {
1469 break;
1470 }
1471 if (instruction.shape().tuple_shapes(i).IsArray()) {
1472 sharding_index += 1;
1473 } else {
1474 sharding_index +=
1475 ShapeUtil::GetLeafCount(instruction.shape().tuple_shapes(i));
1476 }
1477 }
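// For example, with an operand tuple shape (f32[8], (f32[4], f32[4]), f32[2])
// and user.tuple_index() == 2, the loop accumulates 1 (array) + 2 (leaf count
// of the nested tuple), so sharding_index == 3, the flattened leaf offset of
// the selected element.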
1478 if (user.shape().IsArray()) {
1479 // Use ReplicateAllDataDims instead of HloSharding::Replicate() to
1480 // preserve manual subgroups.
1481 HloSharding new_sharding =
1482 instruction.has_sharding()
1483 ? instruction.sharding()
1484 : HloSharding::SingleTuple(
1485 instruction.shape(),
1486 hlo_sharding_util::ReplicateAllDataDims(user.sharding()));
1487 new_sharding.tuple_elements()[sharding_index] = user.sharding();
1488 return new_sharding;
1489 } else {
1490 if (user.sharding().tuple_elements().empty()) {
1491 return std::nullopt;
1492 }
1493 HloSharding new_sharding =
1494 instruction.has_sharding()
1495 ? instruction.sharding()
1496 : HloSharding::SingleTuple(
1497 instruction.shape(),
1498 hlo_sharding_util::ReplicateAllDataDims(
1499 user.sharding().tuple_elements()[0]));
1500 for (int64_t i = 0; i < user.sharding().tuple_elements().size(); ++i) {
1501 new_sharding.tuple_elements()[sharding_index + i] =
1502 user.sharding().tuple_elements()[i];
1503 }
1504 return new_sharding;
1505 }
1506 }
1507 case HloOpcode::kDot: {
1508 int64_t op_idx = user.operand_index(&instruction);
1509 auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(&user);
1510 return InferDotOperandSharding(&user, dnums, op_idx,
1511 may_combine_partial_sharding);
1512 }
1513 case HloOpcode::kReduce: {
1514 if (instruction.shape().rank() == 0) {
1515 return std::nullopt;
1516 }
1517 auto user_sharding =
1518 user.shape().IsTuple()
1519 ? user.sharding().GetSubSharding(
1520 user.shape(), {user.operand_index(&instruction)})
1521 : user.sharding();
1522 if (user_sharding.IsTileMaximal()) {
1523 return user_sharding;
1524 }
1525 std::vector<int64_t> target_tile_assignment_dimensions(
1526 instruction.shape().rank() +
1527 (user_sharding.ReplicateOnLastTileDim() ? 1 : 0) +
1528 user_sharding.subgroup_types().size());
1529 const auto& dimensions = user.dimensions();
1530 int64_t next_output_dim = 0;
1531 for (int64_t i = 0; i < target_tile_assignment_dimensions.size(); ++i) {
1532 if (absl::c_find(dimensions, i) == dimensions.end()) {
1533 target_tile_assignment_dimensions[i] =
1534 user_sharding.tile_assignment().dim(next_output_dim++);
1535 } else {
1536 target_tile_assignment_dimensions[i] = 1;
1537 }
1538 }
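// For example, with an operand of rank 3, reduce dimensions {1}, and an
// output tile assignment of shape [2,4], the target dimensions become
// [2,1,4]: reduced dims stay untiled and the remaining dims inherit the
// output tiling in order.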
1539 auto tile_assignment = user_sharding.tile_assignment();
1540 tile_assignment.Reshape(target_tile_assignment_dimensions);
1541 return user_sharding.ReplicateOnLastTileDim()
1542 ? HloSharding::PartialTile(tile_assignment,
1543 user_sharding.metadata())
1544 : HloSharding::Subgroup(tile_assignment,
1545 user_sharding.subgroup_types(),
1546 user_sharding.metadata());
1547 }
1548 case HloOpcode::kSort: {
1549 HloSharding user_sharding = user.sharding();
1550 if (user_sharding.IsTuple()) {
1551 return user_sharding = user_sharding.GetSubSharding(
1552 user.shape(), {user.operand_index(&instruction)});
1553 }
1554 return user_sharding;
1555 }
1556 case HloOpcode::kReverse: {
1557 return hlo_sharding_util::ReverseSharding(user.sharding(),
1558 user.dimensions());
1559 }
1560 case HloOpcode::kGather: {
1561 if (&instruction == user.operand(1)) {
1562 return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
1563 }
1564 if (is_spmd) {
1565 return hlo_sharding_util::GatherDataOperandShardingFromOutput(
1566 user.sharding(), user);
1567 }
1568 return std::nullopt;
1569 }
1570 case HloOpcode::kScatter: {
1571 auto& scatter_user = *Cast<HloScatterInstruction>(&user);
1572 if (&instruction == scatter_user.operand(0)) {
1573 return user.sharding();
1574 }
1575 if (&instruction == scatter_user.scatter_indices()) {
1576 auto update = scatter_user.scatter_updates()[0];
1577 if (!IsSpatiallyPartitioned(update)) {
1578 return std::nullopt;
1579 }
1580 return hlo_sharding_util::ScatterIndexSharding(update->sharding(),
1581 &scatter_user);
1582 }
1583 CHECK_EQ(&instruction, scatter_user.scatter_updates()[0]);
1584 auto indices = scatter_user.scatter_indices();
1585 if (IsSpatiallyPartitioned(indices)) {
1586 auto from_indices = hlo_sharding_util::ScatterDataSharding(
1587 indices->sharding(), &scatter_user);
1588 if (!from_indices.IsTileMaximal()) {
1589 return from_indices;
1590 }
1591 }
1592 if (is_spmd) {
1593 return hlo_sharding_util::ScatterUpdateShardingFromOutput(
1594 user.sharding(), scatter_user);
1595 }
1596 return std::nullopt;
1597 }
1598 default: {
1599 // If the user output shape is compatible with the current instruction
1600 // shape excluding element type and the current instruction is supported
1601 // by spatial partitioning, then the user sharding can be used for
1602 // propagation to the current instruction.
1603 if (ShapeUtil::CompatibleIgnoringElementType(instruction.shape(),
1604 user.shape())) {
1605 return user.sharding();
1606 }
1607 return std::nullopt;
1608 }
1609 }
1610 }
1611
1612 // Computes the number of users of the instruction, excluding the computation's root.
1613 int64_t ComputeNonRootUsers(const HloInstruction* instr) {
1614 int64_t non_root_users = instr->users().size();
1615 for (int i = 0; i < instr->users().size(); ++i) {
1616 if (instr->users()[i] == instr->parent()->root_instruction()) {
1617 --non_root_users;
1618 }
1619 }
1620 return non_root_users;
1621 }
1622
1623 // Only pass through the sharding annotation at the first iteration when:
1624 // 1. The operand is sharded; 2. Only non-concat dims are sharded;
1625 // 3. The concat is for params in the repeated layers, which follow the
1626 // pattern of param/gte -> reshape -> concat.
1627 bool AggressiveConcatOperandShardingCanPassThrough(
1628 const HloInstruction* concat_operand) {
1629 return (
1630 IsSpatiallyPartitioned(concat_operand) &&
1631 (concat_operand->has_sharding() &&
1632 concat_operand->sharding().NumTiles() > 1) &&
1633 concat_operand->opcode() == HloOpcode::kReshape &&
1634 (concat_operand->operand(0)->opcode() == HloOpcode::kParameter ||
1635 concat_operand->operand(0)->opcode() == HloOpcode::kGetTupleElement));
1636 }
1637
1638 // DynamicSlice or DynamicUpdateSlice handling for InferShardingFromOperands().
1639 bool InferDynamicSliceOrDynamicUpdateSliceShardingFromOperands(
1640 HloInstruction* instruction, int64_t aggressiveness,
1641 bool may_combine_partial_sharding) {
1642 const HloInstruction* operand =
1643 instruction->opcode() == HloOpcode::kDynamicSlice
1644 ? instruction->operand(0)
1645 : instruction->operand(1);
1646 auto slice_dim_is_sharded = [&]() {
1647 if (!IsSpatiallyPartitioned(operand) ||
1648 operand->sharding().NumTiles() == 1) {
1649 return false;
1650 }
1651 for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
1652 const auto& tile_assignment = operand->sharding().tile_assignment();
1653 if (tile_assignment.dim(i) > 1 && instruction->shape().dimensions(i) !=
1654 operand->shape().dimensions(i)) {
1655 return true;
1656 }
1657 }
1658 return false;
1659 };
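// For example, a dynamic-slice extracting f32[1,256] from an f32[8,256]
// operand tiled [4,1] has its sliced dimension (dim 0) sharded, so the lambda
// returns true; with a [1,4] tiling it returns false.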
1660
1661 // Do not pass through sharding annotation at the first iteration
1662 // if slice dim is sharded.
1663 if (aggressiveness == 0 && slice_dim_is_sharded()) {
1664 return false;
1665 }
1666
1667 auto propagate_slicing = [&]() {
1668 if (!IsSpatiallyPartitioned(operand)) {
1669 return false;
1670 }
1671
1672 if (operand->sharding().NumTiles() == 1) {
1673 return MaybeImproveInstructionSharding(
1674 operand->sharding(), instruction, may_combine_partial_sharding,
1675 /*allow_aggressive_resharding=*/
1676 ComputeNonRootUsers(instruction) == 1);
1677 }
1678
1679 if (slice_dim_is_sharded()) {
1680 return false;
1681 }
1682 return MaybeImproveInstructionSharding(
1683 operand->sharding(), instruction, may_combine_partial_sharding,
1684 /*allow_aggressive_resharding=*/
1685 ComputeNonRootUsers(instruction) == 1);
1686 };
1687 auto propagate_base = [&]() {
1688 if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) {
1689 return false;
1690 }
1691 if (!IsSpatiallyPartitioned(instruction->operand(0))) {
1692 return false;
1693 }
1694 return MaybeImproveInstructionSharding(instruction->operand(0)->sharding(),
1695 instruction,
1696 may_combine_partial_sharding);
1697 };
1698 bool changed = propagate_slicing();
1699 changed |= propagate_base();
1700 return changed;
1701 }
1702
1703 // Tries to update the sharding of the specified instruction based on its
1704 // operands and returns true if the sharding of the instruction has been
1705 // changed and false otherwise.
1706 bool ShardingPropagation::InferShardingFromOperands(
1707 HloInstruction* instruction, const ComputationMap& computation_map,
1708 int64_t aggressiveness) {
1709 if (!CanPropagateThroughAtAggressiveLevel(*instruction, aggressiveness)) {
1710 return false;
1711 }
1712 // Do not change manual sharding.
1713 if (instruction->has_sharding() && instruction->sharding().IsManual()) {
1714 return false;
1715 }
1716 // Propagate manual sharding. Avoid tuple-shaped HLOs that group independent
1717 // operands together. Reduce, ReduceWindow, and Sort can be tuples, but their
1718 // elements are correlated, so we propagate manual sharding through them.
1719 if (!instruction->has_sharding() &&
1720 (instruction->shape().IsArray() ||
1721 instruction->opcode() == HloOpcode::kReduce ||
1722 instruction->opcode() == HloOpcode::kSort ||
1723 instruction->opcode() == HloOpcode::kReduceWindow)) {
1724 for (const HloInstruction* op : instruction->operands()) {
1725 if (!op->has_sharding() || !op->sharding().IsManual()) continue;
1726 // Do not pass through manual sharding to concat or dynamic slice when
1727 // aggressiveness is 0.
1728 if (aggressiveness == 0 &&
1729 (instruction->opcode() == HloOpcode::kConcatenate ||
1730 instruction->opcode() == HloOpcode::kDynamicSlice)) {
1731 return false;
1732 }
1733 instruction->set_sharding(HloSharding::Manual(op->sharding().metadata()));
1734 return true;
1735 }
1736 }
1737 const bool may_combine_partial_sharding = is_spmd_ && aggressiveness > 0;
1738 if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd_,
1739 allow_spmd_sharding_propagation_to_output_,
1740 sharding_helper_.get())) {
1741 // If an array-shaped HLO doesn't support spatial partitioning but at least
1742 // one of its operands is replicated then we make the HLO replicated as well.
1743 if (instruction->shape().IsTuple() || instruction->operand_count() == 0 ||
1744 instruction == instruction->parent()->root_instruction() ||
1745 instruction->HasSideEffect()) {
1746 return false;
1747 }
1748 for (const HloInstruction* op : instruction->operands()) {
1749 if (op->has_sharding() && op->sharding().IsTileMaximal() &&
1750 !op->sharding().HasUniqueDevice()) {
1751 return MaybeImproveInstructionSharding(op->sharding(), instruction,
1752 may_combine_partial_sharding);
1753 }
1754 }
1755 return false;
1756 }
1757
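// Wraps a per-element sharding into a tuple sharding when the instruction has
// a tuple shape (e.g. a variadic reduce), using the same sharding for every
// tuple element.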
1758 auto get_maybe_tuple_sharding = [&](HloSharding sharding) {
1759 if (instruction->shape().IsArray()) {
1760 return sharding;
1761 }
1762 std::vector<HloSharding> tuple(instruction->shape().tuple_shapes_size(),
1763 std::move(sharding));
1764 return HloSharding::Tuple(instruction->shape(), tuple);
1765 };
1766
1767 switch (instruction->opcode()) {
1768 case HloOpcode::kGetTupleElement: {
1769 const HloInstruction* operand = instruction->operand(0);
1770 if (!IsSpatiallyPartitioned(operand)) {
1771 return false;
1772 }
1773 HloSharding new_sharding = operand->sharding().GetSubSharding(
1774 operand->shape(), {instruction->tuple_index()});
1775 if (new_sharding.IsManual()) {
1776 instruction->set_sharding(new_sharding);
1777 return true;
1778 }
1779 return MaybeImproveInstructionSharding(
1780 std::move(new_sharding), instruction, may_combine_partial_sharding,
1781 /*allow_aggressive_resharding=*/
1782 ComputeNonRootUsers(instruction) == 1);
1783 }
1784 case HloOpcode::kTuple: {
1785 if (absl::c_none_of(instruction->operands(),
1786 [](const HloInstruction* hlo) {
1787 return IsSpatiallyPartitioned(hlo);
1788 })) {
1789 // None of the operands have a spatially partitioned sharding.
1790 return false;
1791 }
1792 const Shape& shape = instruction->shape();
1793 bool changed = false;
1794 if (!instruction->has_sharding()) {
1795 // Set the sharding for all elements in the tuple because it isn't
1796 // possible to set a partial sharding.
1797 changed = true;
1798 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1799 const HloInstruction* operand = instruction->operand(i);
1800 if (!operand->has_sharding()) {
1801 continue;
1802 }
1803 if (operand->sharding().IsTuple()) {
1804 if (operand->sharding().tuple_elements().empty()) {
1805 continue;
1806 }
1807 // Use ReplicateAllDataDims to preserve manual subgroups.
1808 instruction->set_sharding(HloSharding::SingleTuple(
1809 instruction->shape(),
1810 hlo_sharding_util::ReplicateAllDataDims(
1811 operand->sharding().tuple_elements()[0])
1812 .WithoutMetadata()));
1813 } else {
1814 instruction->set_sharding(HloSharding::SingleTuple(
1815 instruction->shape(),
1816 hlo_sharding_util::ReplicateAllDataDims(operand->sharding())
1817 .WithoutMetadata()));
1818 }
1819 break;
1820 }
1821 }
1822 if (!instruction->has_sharding()) {
1823 return false;
1824 }
1825 // Go through each operand and if the operand has a sharding that is
1826 // better than the current sharding for that tuple element then update
1827 // it.
1828 std::vector<HloSharding> sub_shardings =
1829 instruction->sharding().tuple_elements();
1830 int64_t sub_sharding_index = 0;
1831 for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1832 const HloInstruction* operand = instruction->operand(i);
1833 if (operand->has_sharding()) {
1834 if (operand->shape().IsTuple()) {
1835 for (int64_t i = 0, e = ShapeUtil::GetLeafCount(operand->shape());
1836 i < e; ++i) {
1837 if (hlo_sharding_util::IsShardingMoreSpecific(
1838 operand->sharding().tuple_elements()[i],
1839 sub_shardings[sub_sharding_index + i])) {
1840 sub_shardings[sub_sharding_index + i] =
1841 operand->sharding().tuple_elements()[i];
1842 }
1843 }
1844 } else {
1845 if (hlo_sharding_util::IsShardingMoreSpecific(
1846 operand->sharding(), sub_shardings[sub_sharding_index])) {
1847 sub_shardings[sub_sharding_index] = operand->sharding();
1848 }
1849 }
1850 }
1851 sub_sharding_index += ShapeUtil::GetLeafCount(operand->shape());
1852 }
1853
1854 HloSharding new_sharding = HloSharding::Tuple(shape, sub_shardings);
1855 if (new_sharding != instruction->sharding()) {
1856 instruction->set_sharding(std::move(new_sharding));
1857 return true;
1858 }
1859 return changed;
1860 }
1861 case HloOpcode::kReduce: {
1862 // Reduce could have a tuple shape, where the first half of operands are
1863 // the arrays to reduce, and the second half of operands are the init
1864 // values.
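// For example, a variadic reduce with inputs (arg0, arg1) and init values
// (init0, init1) has operands {arg0, arg1, init0, init1} and a two-element
// tuple result; reduce->inputs() returns only {arg0, arg1}.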
1865 bool changed = false;
1866 auto* reduce = Cast<HloReduceInstruction>(instruction);
1867 for (HloInstruction* operand : reduce->inputs()) {
1868 if (!IsSpatiallyPartitioned(operand)) {
1869 continue;
1870 }
1871 if (operand->sharding().IsReplicated() ||
1872 (!is_spmd_ &&
1873 absl::c_any_of(instruction->dimensions(), [operand](int64_t dim) {
1874 return operand->sharding().tile_assignment().dim(dim) > 1;
1875 }))) {
1876 // We are reducing along one of the sharded dimensions. We only
1877 // support this in SPMD.
1878 changed |= MaybeImproveInstructionSharding(
1879 get_maybe_tuple_sharding(
1880 hlo_sharding_util::ReplicateAllDataDims(operand->sharding())),
1881 reduce, may_combine_partial_sharding,
1882 /*allow_aggressive_resharding=*/
1883 ComputeNonRootUsers(instruction) == 1);
1884 continue;
1885 }
1886 auto after_partial_replication =
1887 operand->sharding().IsReplicated()
1888 ? operand->sharding()
1889 : hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1890 operand->sharding(), reduce->dimensions());
1891 if (after_partial_replication.IsReplicated()) {
1892 changed |= MaybeImproveInstructionSharding(
1893 get_maybe_tuple_sharding(after_partial_replication), reduce,
1894 may_combine_partial_sharding,
1895 /*allow_aggressive_resharding=*/
1896 ComputeNonRootUsers(instruction) == 1);
1897 continue;
1898 }
1899 // Use the same sharding for all tuple elements, because they are part
1900 // of the same reduce instruction.
1901 HloSharding new_sharding =
1902 get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions(
1903 after_partial_replication, reduce->dimensions()));
1904 changed |= MaybeImproveInstructionSharding(
1905 std::move(new_sharding), reduce, may_combine_partial_sharding,
1906 /*allow_aggressive_resharding=*/
1907 ComputeNonRootUsers(instruction) == 1);
1908 }
1909 return changed;
1910 }
1911 case HloOpcode::kBroadcast: {
1912 // Make forward propagation through broadcast low priority to avoid
1913 // resharding after broadcast.
1914 if (aggressiveness < 3) {
1915 return false;
1916 }
1917 const HloInstruction* op = instruction->operand(0);
1918 if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) {
1919 return false;
1920 }
1921 // The output will be tiled along the broadcasted dimensions the same way
1922 // as the broadcast's input, while the other dimensions are kept
1923 // non-tiled.
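// For example (hypothetical shapes): broadcasting an f32[256] operand tiled
// [4] into f32[8,256] with dimensions={1} yields target dimensions [1,4]; the
// dimension coming from the operand keeps its 4-way tiling and the newly
// added dimension stays untiled.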
1924 std::vector<int64_t> target_tile_assignment_dimensions;
1925 const auto& dimensions = instruction->dimensions();
1926 for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
1927 auto it = absl::c_find(dimensions, i);
1928 if (it == dimensions.end()) {
1929 target_tile_assignment_dimensions.push_back(1);
1930 } else {
1931 const int64_t source_dim = std::distance(dimensions.begin(), it);
1932 target_tile_assignment_dimensions.push_back(
1933 op->sharding().tile_assignment().dim(source_dim));
1934 }
1935 }
1936 for (int64_t i = op->sharding().TiledDataRank();
1937 i < op->sharding().tile_assignment().num_dimensions(); ++i) {
1938 target_tile_assignment_dimensions.push_back(
1939 op->sharding().tile_assignment().dim(i));
1940 }
1941 Array<int64_t> new_tile_assignment = op->sharding().tile_assignment();
1942 new_tile_assignment.Reshape(target_tile_assignment_dimensions);
1943 HloSharding new_sharding =
1944 op->sharding().ReplicateOnLastTileDim()
1945 ? HloSharding::PartialTile(new_tile_assignment,
1946 op->sharding().metadata())
1947 : HloSharding::Subgroup(new_tile_assignment,
1948 op->sharding().subgroup_types(),
1949 op->sharding().metadata());
1950 return MaybeImproveInstructionSharding(
1951 std::move(new_sharding), instruction, may_combine_partial_sharding,
1952 /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
1953 1);
1954 }
1955 case HloOpcode::kConcatenate: {
1956 const HloInstruction* operand = PickRepresentativeOperand(instruction);
1957 if (!operand || !IsSpatiallyPartitioned(operand)) {
1958 return false;
1959 }
1960
1961 if (aggressiveness == 0) {
1962 for (const HloInstruction* concat_operand : instruction->operands()) {
1963 if (!AggressiveConcatOperandShardingCanPassThrough(concat_operand)) {
1964 return false;
1965 }
1966 const auto& tile_assignment =
1967 concat_operand->sharding().tile_assignment();
1968 for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
1969 if (absl::c_linear_search(instruction->dimensions(), i) &&
1970 tile_assignment.dim(i) > 1) {
1971 return false;
1972 }
1973 }
1974 }
1975 }
1976 return MaybeImproveInstructionSharding(
1977 operand->sharding(), instruction, may_combine_partial_sharding,
1978 /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
1979 1);
1980 }
1981 case HloOpcode::kConvolution:
1982 return InferConvolutionShardingFromOperands(instruction, aggressiveness,
1983 may_combine_partial_sharding);
1984 case HloOpcode::kTranspose: {
1985 const HloInstruction* input = instruction->operand(0);
1986 if (!IsSpatiallyPartitioned(input)) {
1987 return false;
1988 }
1989 HloSharding sharding = hlo_sharding_util::TransposeSharding(
1990 input->sharding(), instruction->dimensions());
1991 return MaybeImproveInstructionSharding(
1992 std::move(sharding), instruction, may_combine_partial_sharding,
1993 /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
1994 1);
1995 }
1996 case HloOpcode::kReduceWindow: {
1997 auto* reduce_window = Cast<HloReduceWindowInstruction>(instruction);
1998 auto has_dilation = [](const WindowDimension& dimensions) {
1999 return dimensions.base_dilation() > 1 ||
2000 dimensions.window_dilation() > 1;
2001 };
2002 if (absl::c_any_of(instruction->window().dimensions(), has_dilation)) {
2003 VLOG(2) << "Not applying sharding to reduce window because dilatation "
2004 "isn't supported yet: "
2005 << reduce_window->ToString();
2006 return false;
2007 }
2008 bool changed = false;
2009 for (HloInstruction* operand : reduce_window->inputs()) {
2010 if (!IsSpatiallyPartitioned(operand)) {
2011 continue;
2012 }
2013 changed |= MaybeImproveInstructionSharding(
2014 get_maybe_tuple_sharding(operand->sharding()), reduce_window,
2015 may_combine_partial_sharding,
2016 /*allow_aggressive_resharding=*/
2017 ComputeNonRootUsers(instruction) == 1);
2018 }
2019 return changed;
2020 }
2021 case HloOpcode::kSelectAndScatter: {
2022 // Shard according to the first operand, as the output keeps the same shape.
2023 const HloInstruction* lhs = instruction->operand(0);
2024 if (!IsSpatiallyPartitioned(lhs)) {
2025 return false;
2026 }
2027
2028 auto has_base_dilation = [](const WindowDimension& dimensions) {
2029 return dimensions.base_dilation() > 1;
2030 };
2031 if (absl::c_any_of(instruction->window().dimensions(),
2032 has_base_dilation)) {
2033 VLOG(2) << "Not applying sharding to select-and-scatter because "
2034 "base dilation isn't supported yet: "
2035 << instruction->ToString();
2036 return false;
2037 }
2038 return MaybeImproveInstructionSharding(
2039 lhs->sharding(), instruction, may_combine_partial_sharding,
2040 /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2041 1);
2042 }
2043 case HloOpcode::kReshape: {
2044 if (!IsSpatiallyPartitioned(instruction->operand(0))) {
2045 return false;
2046 }
2047 std::optional<HloSharding> new_sharding =
2048 hlo_sharding_util::ReshapeSharding(
2049 instruction->operand(0)->shape(), instruction->shape(),
2050 instruction->operand(0)->sharding());
2051 if (new_sharding.has_value()) {
2052 return MaybeImproveInstructionSharding(
2053 std::move(*new_sharding), instruction, may_combine_partial_sharding,
2054 /*allow_aggressive_resharding=*/
2055 ComputeNonRootUsers(instruction) == 1);
2056 }
2057 if (!instruction->has_sharding()) {
2058 instruction->set_sharding(hlo_sharding_util::ReplicateAllDataDims(
2059 instruction->operand(0)->sharding(), instruction->shape().rank()));
2060 return true;
2061 }
2062 return false;
2063 }
2064 case HloOpcode::kReverse: {
2065 const HloInstruction* operand = instruction->operand(0);
2066 if (!IsSpatiallyPartitioned(operand)) {
2067 return false;
2068 }
2069 return MaybeImproveInstructionSharding(
2070 hlo_sharding_util::ReverseSharding(operand->sharding(),
2071 instruction->dimensions()),
2072 instruction, may_combine_partial_sharding,
2073 /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2074 1);
2075 }
2076 case HloOpcode::kDot: {
2077 const auto& dnums =
2078 dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
2079 return InferDotShardingFromOperands(instruction, dnums,
2080 may_combine_partial_sharding);
2081 }
2082 case HloOpcode::kParameter: {
2083 auto parent_it = computation_map.find(instruction->parent());
2084 if (parent_it == computation_map.end()) {
2085 return false;
2086 }
2087 const HloInstruction* parent = parent_it->second;
2088 switch (parent->opcode()) {
2089 case HloOpcode::kConditional: {
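// Operand 0 of a conditional is the branch selector; operand i (i >= 1)
// feeds the parameter of called_computations()[i - 1], so take the sharding
// of the conditional operand that matches this parameter's computation.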
2090 for (int64_t i = 1; i < parent->operand_count(); ++i) {
2091 if (parent->called_computations()[i - 1] == instruction->parent()) {
2092 if (parent->operand(i)->has_sharding()) {
2093 return MaybeImproveInstructionSharding(
2094 parent->operand(i)->sharding(), instruction,
2095 may_combine_partial_sharding);
2096 }
2097 return false;
2098 }
2099 }
2100 return false;
2101 }
2102 default:
2103 return false;
2104 }
2105 }
2106 case HloOpcode::kSort: {
2107 const HloInstruction* operand = PickRepresentativeOperand(instruction);
2108 if (!operand || !IsSpatiallyPartitioned(operand)) {
2109 return false;
2110 }
2111
2112 if (!operand->sharding().IsTileMaximal() &&
2113 operand->sharding().tile_assignment().dim(
2114 instruction->dimensions(0)) != 1) {
2115 // Doesn't support sharding the sorting dimension.
2116 return false;
2117 }
2118
2119 if (instruction->shape().IsTuple()) {
2120 return MaybeImproveInstructionSharding(
2121 HloSharding::SingleTuple(instruction->shape(), operand->sharding()),
2122 instruction, may_combine_partial_sharding,
2123 /*allow_aggressive_resharding=*/
2124 ComputeNonRootUsers(instruction) == 1);
2125 } else {
2126 return MaybeImproveInstructionSharding(
2127 operand->sharding(), instruction, may_combine_partial_sharding,
2128 /*allow_aggressive_resharding=*/
2129 ComputeNonRootUsers(instruction) == 1);
2130 }
2131 }
2132 case HloOpcode::kDynamicSlice:
2133 case HloOpcode::kDynamicUpdateSlice: {
2134 return InferDynamicSliceOrDynamicUpdateSliceShardingFromOperands(
2135 instruction, aggressiveness, may_combine_partial_sharding);
2136 }
2137 case HloOpcode::kGather: {
2138 bool changed = false;
2139 if (IsSpatiallyPartitioned(instruction->operand(1))) {
2140 HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
2141 instruction->operand(1)->sharding(), instruction);
2142 changed |= MaybeImproveInstructionSharding(
2143 std::move(new_sharding), instruction, may_combine_partial_sharding);
2144 }
2145 if (is_spmd_) {
2146 auto gather_parallel_dims =
2147 hlo_sharding_util::GetGatherBatchParallelDims(*instruction);
2148 if (gather_parallel_dims) {
2149 changed |= InferGatherParallelShardingFromOperands(
2150 instruction, *gather_parallel_dims, may_combine_partial_sharding);
2151 }
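// Infer the output sharding from the data operand: replicate the operand
// sharding along any batch-parallel dims (handled above) and map the
// remaining tiling through the gather.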
2152 if (IsSpatiallyPartitioned(instruction->operand(0))) {
2153 absl::Span<const int64_t> operand_parallel_dims;
2154 if (gather_parallel_dims) {
2155 operand_parallel_dims = absl::MakeConstSpan(
2156 gather_parallel_dims->operand_parallel_dims);
2157 }
2158 HloSharding filtered_operand_sharding =
2159 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2160 instruction->operand(0)->sharding(), operand_parallel_dims);
2161 auto maybe_from_data =
2162 hlo_sharding_util::GatherOutputShardingFromDataOperand(
2163 filtered_operand_sharding, *instruction,
2164 instruction->gather_slice_sizes(), instruction->shape(),
2165 instruction->operand(0)->shape());
2166 if (maybe_from_data) {
2167 changed |= MaybeImproveInstructionSharding(
2168 std::move(*maybe_from_data), instruction,
2169 may_combine_partial_sharding);
2170 }
2171 }
2172 }
2173 return changed;
2174 }
2175 case HloOpcode::kScatter: {
2176 bool changed = false;
2177 if (is_spmd_ && IsSpatiallyPartitioned(instruction->operand(0))) {
2178 changed |= MaybeImproveInstructionSharding(
2179 instruction->operand(0)->sharding(), instruction,
2180 may_combine_partial_sharding);
2181 }
2182 auto* scatter = Cast<HloScatterInstruction>(instruction);
2183 if (!IsSpatiallyPartitioned(scatter->scatter_indices()) &&
2184 !IsSpatiallyPartitioned(scatter->scatter_updates()[0])) {
2185 return false;
2186 }
2187 if (is_spmd_ && IsSpatiallyPartitioned(scatter->scatter_updates()[0])) {
2188 auto maybe_from_update =
2189 hlo_sharding_util::ScatterOutputShardingFromUpdate(
2190 scatter->scatter_updates()[0]->sharding(), *scatter);
2191 if (maybe_from_update) {
2192 changed |= MaybeImproveInstructionSharding(
2193 std::move(*maybe_from_update), instruction,
2194 may_combine_partial_sharding);
2195 }
2196 }
2197 if (!is_spmd_) {
2198 changed |= MaybeImproveInstructionSharding(
2199 HloSharding::Replicate(), instruction,
2200 may_combine_partial_sharding);
2201 }
2202 return changed;
2203 }
2204 case HloOpcode::kWhile: {
2205 if (!instruction->operand(0)->has_sharding()) {
2206 return false;
2207 }
2208 auto sharding = instruction->operand(0)->sharding();
2209 if (instruction->has_sharding()) {
2210 hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
2211 may_combine_partial_sharding);
2212 }
2213 return MaybeImproveInstructionSharding(std::move(sharding), instruction,
2214 may_combine_partial_sharding);
2215 }
2216 case HloOpcode::kCustomCall: {
2217 if (instruction->IsCustomCall("X64Combine")) {
2218 return false;
2219 }
2220 HloSharding inferred_operand_sharding = HloSharding::Replicate();
2221 if (sharding_helper_->IsCustomCallShardable(instruction)) {
2222 if (auto sharding =
2223 sharding_helper_->InferShardingFromOperands(instruction)) {
2224 inferred_operand_sharding = *sharding;
2225 } else {
2226 return false;
2227 }
2228 } else {
2229 const HloInstruction* operand = PickRepresentativeOperand(instruction);
2230 if (!operand || !IsSpatiallyPartitioned(operand)) {
2231 return false;
2232 }
2233 inferred_operand_sharding = operand->sharding();
2234 }
2235 return MaybeImproveInstructionSharding(
2236 inferred_operand_sharding, instruction, may_combine_partial_sharding,
2237 /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2238 1);
2239 }
2240 default: {
2241 if (instruction->IsElementwise() && may_combine_partial_sharding) {
2242 bool changed = false;
2243 for (auto operand : instruction->operands()) {
2244 if (IsSpatiallyPartitioned(operand)) {
2245 if (instruction->opcode() == HloOpcode::kRng) {
2246 // Rng is considered elementwise but has operands with different
2247 // shapes.
2248 changed |= MaybeImproveInstructionSharding(
2249 hlo_sharding_util::ReplicateAllDataDims(
2250 operand->sharding(), instruction->shape().rank()),
2251 instruction, may_combine_partial_sharding,
2252 ComputeNonRootUsers(instruction) == 1);
2253 continue;
2254 }
2255 changed |= MaybeImproveInstructionSharding(
2256 operand->sharding(), instruction, may_combine_partial_sharding,
2257 /*allow_aggressive_resharding=*/
2258 instruction->operands().size() == 1 &&
2259 ComputeNonRootUsers(instruction) == 1);
2260 }
2261 }
2262 return changed;
2263 }
2264 const HloInstruction* operand = PickRepresentativeOperand(instruction);
2265 if (!operand || !IsSpatiallyPartitioned(operand)) {
2266 return false;
2267 }
2268 return MaybeImproveInstructionSharding(
2269 operand->sharding(), instruction, may_combine_partial_sharding,
2270 /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2271 1);
2272 }
2273 }
2274 return false;
2275 } // NOLINT(readability/fn_size)
2276
2277 StatusOr<bool> ShardingPropagation::Run(
2278 HloModule* module,
2279 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2280 std::optional<absl::flat_hash_map<const HloInstruction*, HloSharding>>
2281 original_sharding;
2282 bool any_changed = false;
2283 // Preprocessing for CSE prevention propagation: record the original shardings
2284 // so that we can revert to them at the end, and only keep those on CSE
2285 // prevention instructions.
2286 if (cse_prevention_only_) {
2287 original_sharding.emplace();
2288 for (auto computation : module->computations(execution_threads)) {
2289 for (auto instruction : computation->instructions()) {
2290 if (instruction->has_sharding()) {
2291 original_sharding->emplace(instruction, instruction->sharding());
2292 }
2293 }
2294 }
2295 } else {
2296 // The current pass is not for CSE prevention, but we remove the shardings
2297 // added by previous passes for CSE prevention.
2298 for (auto computation : module->computations(execution_threads)) {
2299 for (auto instruction : computation->instructions()) {
2300 if (instruction->has_sharding() &&
2301 IsCSEPreventionSharding(instruction->sharding())) {
2302 instruction->clear_sharding();
2303 any_changed = true;
2304 }
2305 }
2306 }
2307 }
2308 any_changed |= propagate_metadata_
2309 ? AssignShardingMetadata(module, execution_threads)
2310 : RemoveShardingMetadata(module, execution_threads);
2311 absl::flat_hash_map<const HloInstruction*, std::vector<int64_t>>
2312 unspecified_dims;
2313 TF_ASSIGN_OR_RETURN(
2314 bool changed,
2315 ProcessShardingInstruction(module, execution_threads,
2316 !cse_prevention_only_, &unspecified_dims));
2317 any_changed |= changed;
2318
2319 // Association of partitionable embedded computations with their parent
2320 // instruction.
2321 ComputationMap computation_map;
2322
2323 // Instructions that are related through a computation and need to share the
2324 // same sharding.
2325 auto get_related_instructions = [](HloInstruction* inst) {
2326 if (inst->opcode() == HloOpcode::kWhile) {
2327 return std::vector<HloInstruction*>{
2328 inst, inst->while_body()->root_instruction(),
2329 inst->while_body()->parameter_instruction(0),
2330 inst->while_condition()->parameter_instruction(0)};
2331 } else if (inst->opcode() == HloOpcode::kConditional) {
2332 const auto& called_computations = inst->called_computations();
2333 std::vector<HloInstruction*> comps;
2334 comps.reserve(called_computations.size() + 1);
2335 comps.push_back(inst);
2336 for (HloComputation* c : called_computations) {
2337 comps.push_back(c->root_instruction());
2338 }
2339 return comps;
2340 } else {
2341 CHECK(false);
2342 }
2343 };
2344
2345 // If instruction is a while, or the root or a parameter of a while body,
2346 // then propagate its sharding to the while instruction, to its body root,
2347 // and to its condition parameter.
2348 std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
2349 maybe_computation_propagation =
2350 [&](HloInstruction* instruction,
2351 absl::flat_hash_set<HloInstruction*>* changed) {
2352 auto propagate_to_instruction = [&](HloInstruction* search_inst) {
2353 auto related_instructions = get_related_instructions(search_inst);
2354 if (absl::c_count(related_instructions, instruction)) {
2355 for (HloInstruction* inst : related_instructions) {
2356 if (!inst->has_sharding() ||
2357 inst->sharding() != instruction->sharding()) {
2358 VLOG(2) << "Add computation sharding: " << inst->name()
2359 << " " << instruction->sharding().ToString();
2360 inst->set_sharding(instruction->sharding());
2361 changed->insert(inst);
2362 maybe_computation_propagation(inst, changed);
2363 }
2364 }
2365 }
2366 };
2367
2368 if (instruction->opcode() == HloOpcode::kConditional ||
2369 instruction->opcode() == HloOpcode::kWhile) {
2370 propagate_to_instruction(instruction);
2371 }
2372
2373 if (instruction->opcode() == HloOpcode::kParameter ||
2374 instruction->parent()->root_instruction() == instruction) {
2375 auto it = computation_map.find(instruction->parent());
2376 if (it != computation_map.end()) {
2377 propagate_to_instruction(it->second);
2378 }
2379 }
2380 };
2381
2382 for (auto computation : module->computations(execution_threads)) {
2383 for (auto instruction : computation->instructions()) {
2384 if (instruction->opcode() == HloOpcode::kWhile) {
2385 TF_RETURN_IF_ERROR(
2386 CheckAndUpdateDeviceAssignmentsInWhileBody(instruction));
2387 }
2388 }
2389 }
2390
2391 // Populate computation_map in order to associate while bodies to their
2392 // while instructions.
2393 for (auto computation : module->computations(execution_threads)) {
2394 for (auto instruction : computation->instructions()) {
2395 if (instruction->opcode() == HloOpcode::kWhile ||
2396 instruction->opcode() == HloOpcode::kConditional) {
2397 // Check if any of the related instructions has sharding, in which case
2398 // propagate it to the other instructions, so they all share the same
2399 // sharding, in case the user didn't shard all of them. We don't check
2400 // that user shardings are consistent, because such a check is already
2401 // done by HloShardingVerifier.
2402 const HloInstruction* sharded_inst = nullptr;
2403 auto related_instructions = get_related_instructions(instruction);
2404 for (auto inst : related_instructions) {
2405 if (inst->has_sharding()) {
2406 sharded_inst = inst;
2407 break;
2408 }
2409 }
2410 if (sharded_inst != nullptr) {
2411 // Set the same sharding to all the other related instructions.
2412 for (auto inst : related_instructions) {
2413 inst->set_sharding(sharded_inst->sharding());
2414 }
2415 }
2416 if (instruction->opcode() == HloOpcode::kWhile) {
2417 computation_map[instruction->while_body()] = instruction;
2418 } else {
2419 for (HloComputation* c : instruction->called_computations()) {
2420 computation_map[c] = instruction;
2421 }
2422 }
2423 }
2424 }
2425 }
2426
2427 // Collect all pre-sharded instructions as we aren't allowed to modify their
2428 // sharding.
2429 absl::flat_hash_set<const HloInstruction*> provided_shardings;
2430 for (const HloComputation* computation :
2431 module->computations(execution_threads)) {
2432 for (const HloInstruction* inst : computation->instructions()) {
2433 if (inst->has_sharding()) {
2434 provided_shardings.insert(inst);
2435 }
2436 }
2437 }
2438
2439 if (!allow_spmd_sharding_propagation_to_output_) {
2440 // Consider the root instruction of the entry computation as one with a
2441 // provided sharding, as its sharding has to match the one expected by the host.
2442 provided_shardings.insert(module->entry_computation()->root_instruction());
2443 }
2444
2445 // Iterate to a fixpoint that is guaranteed to be reached because we only
2446 // strictly improve the sharding of the graph and it can't be improved
2447 // indefinitely.
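// Aggressiveness increases from 0 to 3 below; as used in the inference
// routines above, levels above 0 additionally allow merging partial shardings
// in SPMD mode, and level 3 also enables low-priority propagation such as
// forward propagation through broadcast.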
2448 int64_t iterations = 0;
2449 auto run_to_fix_point = [&](int64_t aggressiveness) {
2450 absl::flat_hash_set<const HloInstruction*> already_inferred_from_operands;
2451 absl::flat_hash_set<const HloInstruction*> already_inferred_from_users;
2452 bool changed_last_iter = true;
2453 const bool may_merge_partial = is_spmd_ && aggressiveness > 0;
2454 while (changed_last_iter) {
2455 changed_last_iter = false;
2456 int64_t inferred_from_operand_counter = 0;
2457 int64_t inferred_from_user_counter = 0;
2458 int64_t instruction_counter = 0;
2459 int64_t already_sharded_counter = 0;
2460 for (const HloComputation* computation :
2461 module->computations(execution_threads)) {
2462 VLOG(2) << "Consider computation: " << computation->name();
2463 std::vector<HloInstruction*> instructions =
2464 computation->MakeInstructionPostOrder();
2465
2466 instruction_counter += instructions.size();
2467 already_sharded_counter += absl::c_count_if(
2468 instructions,
2469 [](const HloInstruction* inst) { return inst->has_sharding(); });
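// When an instruction's sharding changes, its neighbors must be revisited:
// its operands may now infer something new from this user and its users from
// this operand, so drop them from the already-inferred caches.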
2470 auto clear_cache = [&](HloInstruction* hlo,
2471 HloInstruction* hlo_for_users = nullptr) {
2472 for (auto operand : hlo->operands()) {
2473 already_inferred_from_users.erase(operand);
2474 }
2475 if (hlo_for_users == nullptr) {
2476 hlo_for_users = hlo;
2477 }
2478 for (auto user : hlo_for_users->users()) {
2479 already_inferred_from_operands.erase(user);
2480 }
2481 };
2482 // First iterate the HLO graph in post order taking shardings from
2483 // operands.
2484 for (HloInstruction* instruction : instructions) {
2485 if (already_inferred_from_operands.contains(instruction)) {
2486 continue;
2487 }
2488 if (provided_shardings.contains(instruction)) {
2489 if (!may_merge_partial) {
2490 continue;
2491 }
2492 auto it = unspecified_dims.find(instruction);
2493 HloInstruction* man_conversion_op_after;
2494 if (it != unspecified_dims.end() &&
2495 InferUnspecifiedDimsFromOperand(instruction, it->second,
2496 &man_conversion_op_after)) {
2497 ++inferred_from_operand_counter;
2498 VLOG(2) << "Refined partial sharding (forward-pass): "
2499 << instruction->ToString();
2500 clear_cache(instruction, man_conversion_op_after);
2501 already_inferred_from_operands.insert(instruction);
2502 changed_last_iter = true;
2503 }
2504 continue;
2505 }
2506 already_inferred_from_operands.insert(instruction);
2507 if (InferShardingFromOperands(instruction, computation_map,
2508 aggressiveness)) {
2509 ++inferred_from_operand_counter;
2510 any_changed = true;
2511 VLOG(2) << "Add sharding (forward-pass): "
2512 << instruction->ToString();
2513 absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
2514 maybe_computation_propagation(instruction, &changed_in_comp_prop);
2515 clear_cache(instruction);
2516 for (auto hlo : changed_in_comp_prop) {
2517 clear_cache(hlo);
2518 }
2519 changed_last_iter = true;
2520 }
2521 }
2522
2523 // Then iterate the HLO graph in reverse post order taking shardings
2524 // from users.
2525 for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
2526 if ((*it)->IsCustomCall("SPMDFullToShardShape") ||
2527 (*it)->IsCustomCall("SPMDShardToFullShape")) {
2528 // The manual conversion op is processed together with the sharding
2529 // op before it. If the conversion op is removed from cache, the
2530 // sharding op should also be removed.
2531 if (!already_inferred_from_users.contains(*it)) {
2532 already_inferred_from_users.erase((*it)->operand(0));
2533 }
2534 }
2535 if (already_inferred_from_users.contains(*it)) {
2536 continue;
2537 }
2538 if (provided_shardings.contains(*it)) {
2539 if (!may_merge_partial) {
2540 continue;
2541 }
2542 auto uit = unspecified_dims.find(*it);
2543 HloInstruction* man_conversion_op_after;
2544 if (uit != unspecified_dims.end() &&
2545 InferUnspecifiedDimsFromUsers(*it, uit->second, aggressiveness,
2546 is_spmd_,
2547 &man_conversion_op_after)) {
2548 ++inferred_from_user_counter;
2549 VLOG(2) << "Refined partial sharding (backward-pass): "
2550 << (*it)->ToString();
2551 clear_cache(*it, man_conversion_op_after);
2552 already_inferred_from_users.insert(*it);
2553 if (man_conversion_op_after != nullptr) {
2554 already_inferred_from_users.insert(man_conversion_op_after);
2555 }
2556 changed_last_iter = true;
2557 }
2558 continue;
2559 }
2560 already_inferred_from_users.insert(*it);
2561 if (InferShardingFromUsers(*it, computation_map, aggressiveness,
2562 is_spmd_, sharding_helper_.get())) {
2563 ++inferred_from_user_counter;
2564 any_changed = true;
2565 VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
2566 absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
2567 maybe_computation_propagation(*it, &changed_in_comp_prop);
2568 clear_cache(*it);
2569 for (auto hlo : changed_in_comp_prop) {
2570 clear_cache(hlo);
2571 }
2572 changed_last_iter = true;
2573 }
2574 }
2575 }
2576 VLOG(1) << "Sharding propagation iteration " << iterations << ";";
2577 VLOG(1) << " total instructions: " << instruction_counter;
2578 VLOG(1) << " instructions already sharded: " << already_sharded_counter;
2579 VLOG(1) << " shardings inferred from operands: "
2580 << inferred_from_operand_counter;
2581 VLOG(1) << " shardings inferred from users: "
2582 << inferred_from_user_counter;
2583 VLOG(1) << " aggressiveness: " << aggressiveness;
2584 ++iterations;
2585 }
2586 return OkStatus();
2587 };
2588 for (int64_t aggressiveness = 0; aggressiveness < 4; ++aggressiveness) {
2589 TF_RETURN_IF_ERROR(run_to_fix_point(aggressiveness));
2590 }
2591
2592 // Post-process for CSE prevention.
2593 if (cse_prevention_only_) {
2594 for (auto computation : module->computations(execution_threads)) {
2595 for (auto instruction : computation->instructions()) {
2596 if (!instruction->has_sharding()) {
2597 continue;
2598 }
2599 if (IsCSEPreventionTarget(instruction) && instruction->has_sharding()) {
2600 if (!(*original_sharding).contains(instruction)) {
2601 // Mark the propagated sharding as for CSE prevention.
2602 instruction->set_sharding(
2603 SetCSEPreventionSharding(instruction->sharding()));
2604 }
2605 continue;
2606 }
2607 auto it = (*original_sharding).find(instruction);
2608 if (it != (*original_sharding).end()) {
2609 // Revert sharding.
2610 instruction->set_sharding(it->second);
2611 } else {
2612 // Clear sharding.
2613 instruction->clear_sharding();
2614 }
2615 }
2616 }
2617 }
2618
2619 VLOG(1) << "Sharding propagation completed after " << iterations
2620 << " iterations";
2621 return any_changed;
2622 }
2623
2624 } // namespace xla
2625