xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/sharding_propagation.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <list>
21 #include <memory>
22 #include <optional>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/status/status.h"
31 #include "absl/strings/str_split.h"
32 #include "absl/types/span.h"
33 #include "tensorflow/compiler/xla/protobuf_util.h"
34 #include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
35 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
36 #include "tensorflow/compiler/xla/service/hlo_computation.h"
37 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
38 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
39 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
40 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
41 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
42 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
43 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/sharding_op_util.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/util.h"
48 #include "tensorflow/compiler/xla/xla_data.pb.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/statusor.h"
52 
53 namespace xla {
54 namespace {
55 
56 // Returns true iff the specified hlo or sharding has a spatially partitioned
57 // sharding (tiled or replicated) that can be propagated by sharding
58 // propagation.
59 bool IsSpatiallyPartitioned(const HloSharding& sharding) {
60   if (sharding.IsTuple()) {
61     return absl::c_any_of(sharding.tuple_elements(), IsSpatiallyPartitioned);
62   } else {
63     return !sharding.IsTileMaximal() || sharding.IsReplicated();
64   }
65 }
66 bool IsSpatiallyPartitioned(const HloInstruction* hlo) {
67   return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding());
68 }
69 
70 // Updates the sharding of the specified instruction with the specified sharding
71 // if it is better than the current one and returns true if a new sharding has
72 // been applied. If may_combine_partial_sharding is true, this may combine the
73 // new and existing shardings if they both use partial tiling with partial
74 // replication.
75 bool MaybeImproveInstructionSharding(HloSharding sharding,
76                                      HloInstruction* instruction,
77                                      bool may_combine_partial_sharding,
78                                      bool allow_aggressive_resharding = false) {
79   // We don't want to propagate tile maximal shardings.
80   if (!IsSpatiallyPartitioned(sharding)) {
81     return false;
82   }
83   // Any sharding is better than no sharding.
84   if (!instruction->has_sharding()) {
85     instruction->set_sharding(std::move(sharding));
86     return true;
87   }
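  // MergeSharding() below may rewrite `sharding` in place, so record its tile
  // count first: if the count is unchanged after merging, the extra
  // compatibility checks below are applied before overriding an existing
  // tiled sharding.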
88   int64_t sharding_tiles = sharding.NumTiles();
89   if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
90                                        may_combine_partial_sharding)) {
91     // Override existing tiled sharding only when the new sharding is compatible
92     // with the existing one. This avoids unexpected resharding when `sharding`
93     // just has more tiles than existing sharding but they are not mergeable.
94     if (!allow_aggressive_resharding && instruction->shape().IsArray() &&
95         !instruction->sharding().IsTileMaximal() &&
96         sharding.NumTiles() == sharding_tiles) {
97       std::vector<int64_t> diff_dims;
98       for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
99         if (instruction->sharding().tile_assignment().dim(i) ==
100             sharding.tile_assignment().dim(i)) {
101           continue;
102         }
103         if (instruction->sharding().tile_assignment().dim(i) != 1) {
104           VLOG(10) << "Not merging because of dim i = " << i
105                    << " sharded differently";
106           VLOG(10) << "Instr sharding: " << instruction->sharding().ToString();
107           VLOG(10) << "New sharding " << sharding.ToString();
108           return false;
109         }
110         diff_dims.push_back(i);
111       }
112       if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
113               sharding, diff_dims) != instruction->sharding()) {
114         VLOG(10) << "Not merging because of different device distribution";
115         VLOG(10) << "Instr sharding: " << instruction->sharding().ToString();
116         VLOG(10) << "New sharding " << sharding.ToString();
117         return false;
118       }
119     }
120     instruction->set_sharding(std::move(sharding));
121     return true;
122   }
123   return false;
124 }
125 
126 // We consider a convolution kernel to be small iff it is smaller along all
127 // spatial dimensions than the output of the convolution. The rationale is that
128 // we can either shard the kernel or the output and we want to shard the larger
129 // one for better efficiency.
130 bool IsConvolutionKernelSmall(const HloInstruction* instruction) {
131   CHECK_EQ(instruction->opcode(), HloOpcode::kConvolution);
132   const HloInstruction* rhs = instruction->operand(1);
133   const auto& dnums = instruction->convolution_dimension_numbers();
134   int64_t kernel_dim_prod = 1;
135   int64_t output_dim_prod = 1;
136   for (int64_t i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
137     int64_t kernel_dim =
138         rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i));
139     kernel_dim_prod *= kernel_dim;
140     int64_t output_dim =
141         instruction->shape().dimensions(dnums.output_spatial_dimensions(i));
142     output_dim_prod *= output_dim;
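    // Heuristic: treat the kernel as "large" (not small) once a kernel spatial
    // dim is at least as big as the matching output dim and, in addition,
    // either this is one of the first two spatial dims, the kernel dim exceeds
    // 3, or the accumulated kernel volume has reached the accumulated output
    // volume.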
143     if (kernel_dim >= output_dim &&
144         (i < 2 || kernel_dim > 3 || kernel_dim_prod >= output_dim_prod)) {
145       return false;
146     }
147   }
148   return true;
149 }
150 
151 bool IsPassthroughCustomOps(const HloInstruction* hlo) {
152   if (hlo->IsCustomCall("Sharding")) {
153     return true;
154   }
155   if (hlo->IsCustomCall("X64Combine")) {
156     return true;
157   }
158   if (hlo->operand_count() != 1 || !hlo->shape().IsArray() ||
159       !hlo->operand(0)->shape().IsArray() ||
160       hlo->operand(0)->shape().rank() != hlo->shape().rank()) {
161     return false;
162   }
163   return hlo->IsCustomCall("ResizeNearest") ||
164          hlo->IsCustomCall("ResizeBilinear") ||
165          hlo->IsCustomCall("ResizeNearestGrad") ||
166          hlo->IsCustomCall("ResizeBilinearGrad") ||
167          hlo->IsCustomCall("Cholesky");
168 }
169 
170 // Return the operand which is the most suitable for determining the sharding
171 // for the specified instruction or nullptr if there isn't any suitable operand.
172 const HloInstruction* PickRepresentativeOperand(
173     const HloInstruction* instruction) {
174   switch (instruction->opcode()) {
175     case HloOpcode::kMap:
176     case HloOpcode::kPad:
177     case HloOpcode::kPower:
178     case HloOpcode::kOptimizationBarrier:
179     case HloOpcode::kReverse:
180     case HloOpcode::kSlice:
181     case HloOpcode::kShiftLeft:
182     case HloOpcode::kShiftRightArithmetic:
183     case HloOpcode::kShiftRightLogical:
184       // For these opcodes the output sharding has to be determined by the
185       // sharding of the first operand but we can only determine sharding based
186       // on it if it already has a sharding.
187       if (instruction->operand(0)->has_sharding()) {
188         return instruction->operand(0);
189       }
190       return nullptr;
191     case HloOpcode::kAbs:
192     case HloOpcode::kAdd:
193     case HloOpcode::kAnd:
194     case HloOpcode::kAtan2:
195     case HloOpcode::kBitcastConvert:
196     case HloOpcode::kCeil:
197     case HloOpcode::kClamp:
198     case HloOpcode::kClz:
199     case HloOpcode::kCompare:
200     case HloOpcode::kComplex:
201     case HloOpcode::kConcatenate:
202     case HloOpcode::kConvert:
203     case HloOpcode::kCopy:
204     case HloOpcode::kCos:
205     case HloOpcode::kAllGather:
206     case HloOpcode::kAllReduce:
207     case HloOpcode::kReduceScatter:
208     case HloOpcode::kAllToAll:
209     case HloOpcode::kCollectivePermute:
210     case HloOpcode::kDivide:
211     case HloOpcode::kExp:
212     case HloOpcode::kExpm1:
213     case HloOpcode::kFloor:
214     case HloOpcode::kImag:
215     case HloOpcode::kIsFinite:
216     case HloOpcode::kLog:
217     case HloOpcode::kLog1p:
218     case HloOpcode::kLogistic:
219     case HloOpcode::kMaximum:
220     case HloOpcode::kMinimum:
221     case HloOpcode::kMultiply:
222     case HloOpcode::kNegate:
223     case HloOpcode::kNot:
224     case HloOpcode::kOr:
225     case HloOpcode::kPopulationCount:
226     case HloOpcode::kReal:
227     case HloOpcode::kReducePrecision:
228     case HloOpcode::kRemainder:
229     case HloOpcode::kRoundNearestAfz:
230     case HloOpcode::kRoundNearestEven:
231     case HloOpcode::kRsqrt:
232     case HloOpcode::kSelect:
233     case HloOpcode::kSign:
234     case HloOpcode::kSin:
235     case HloOpcode::kSort:
236     case HloOpcode::kSqrt:
237     case HloOpcode::kCbrt:
238     case HloOpcode::kSubtract:
239     case HloOpcode::kTanh:
240     case HloOpcode::kWhile:
241     case HloOpcode::kXor: {
242       // For these opcodes the output sharding can be determined by any operand
243       // so we find the operand with the most specific sharding.
244       const HloInstruction* best_operand = nullptr;
245       for (const HloInstruction* operand : instruction->operands()) {
246         if (operand->has_sharding() &&
247             (best_operand == nullptr ||
248              hlo_sharding_util::IsShardingMoreSpecific(
249                  operand->sharding(), best_operand->sharding()))) {
250           best_operand = operand;
251         }
252       }
253       return best_operand;
254     }
255     case HloOpcode::kCustomCall: {
256       if (IsPassthroughCustomOps(instruction)) {
257         return instruction->operand(0);
258       }
259       return nullptr;
260     }
261     // There is no suitable operand for the rest of the opcodes.
262     case HloOpcode::kAddDependency:
263     case HloOpcode::kAfterAll:
264     case HloOpcode::kAsyncStart:
265     case HloOpcode::kAsyncUpdate:
266     case HloOpcode::kAsyncDone:
267     case HloOpcode::kAllGatherStart:
268     case HloOpcode::kAllGatherDone:
269     case HloOpcode::kAllReduceStart:
270     case HloOpcode::kAllReduceDone:
271     case HloOpcode::kBatchNormGrad:
272     case HloOpcode::kBatchNormInference:
273     case HloOpcode::kBatchNormTraining:
274     case HloOpcode::kBitcast:
275     case HloOpcode::kBroadcast:
276     case HloOpcode::kCall:
277     case HloOpcode::kCholesky:
278     case HloOpcode::kCollectivePermuteDone:
279     case HloOpcode::kCollectivePermuteStart:
280     case HloOpcode::kConditional:
281     case HloOpcode::kConstant:
282     case HloOpcode::kConvolution:
283     case HloOpcode::kCopyDone:
284     case HloOpcode::kCopyStart:
285     case HloOpcode::kDomain:
286     case HloOpcode::kDot:
287     case HloOpcode::kDynamicSlice:
288     case HloOpcode::kDynamicUpdateSlice:
289     case HloOpcode::kDynamicReshape:
290     case HloOpcode::kFft:
291     case HloOpcode::kFusion:
292     case HloOpcode::kGather:
293     case HloOpcode::kGetTupleElement:
294     case HloOpcode::kInfeed:
295     case HloOpcode::kIota:
296     case HloOpcode::kOutfeed:
297     case HloOpcode::kParameter:
298     case HloOpcode::kPartitionId:
299     case HloOpcode::kRecv:
300     case HloOpcode::kRecvDone:
301     case HloOpcode::kReduce:
302     case HloOpcode::kReduceWindow:
303     case HloOpcode::kReplicaId:
304     case HloOpcode::kReshape:
305     case HloOpcode::kRng:
306     case HloOpcode::kRngGetAndUpdateState:
307     case HloOpcode::kRngBitGenerator:
308     case HloOpcode::kScatter:
309     case HloOpcode::kSelectAndScatter:
310     case HloOpcode::kSend:
311     case HloOpcode::kSendDone:
312     case HloOpcode::kTranspose:
313     case HloOpcode::kTriangularSolve:
314     case HloOpcode::kTuple:
315     case HloOpcode::kGetDimensionSize:
316     case HloOpcode::kSetDimensionSize:
317       return nullptr;
318   }
319 }
320 
321 bool SupportSpatialPartitioning(
322     const HloInstruction* instruction,
323     const ShardingPropagation::ComputationMap& computation_map, bool is_spmd,
324     bool allow_spmd_sharding_propagation_to_output,
325     const CustomCallShardingHelper* sharding_helper) {
326   const bool is_entry_root = instruction->parent()
327                                  ->parent()
328                                  ->entry_computation()
329                                  ->root_instruction() == instruction;
330   if (instruction->parent()->root_instruction() == instruction &&
331       computation_map.find(instruction->parent()) == computation_map.end() &&
332       !(is_entry_root && allow_spmd_sharding_propagation_to_output)) {
333     // We don't support sharding the root of a computation unless it is in
334     // computation_map (e.g. a while body) or output propagation is allowed.
335     return false;
336   }
337 
338   if (instruction->IsElementwise() &&
339       (instruction->opcode() != HloOpcode::kRng || is_spmd)) {
340     return true;
341   }
342   switch (instruction->opcode()) {
343     case HloOpcode::kBroadcast:
344     case HloOpcode::kConcatenate:
345     case HloOpcode::kConditional:
346     case HloOpcode::kConstant:
347     case HloOpcode::kConvolution:
348     case HloOpcode::kOptimizationBarrier:
349     case HloOpcode::kDot:
350     case HloOpcode::kDynamicSlice:
351     case HloOpcode::kDynamicUpdateSlice:
352     case HloOpcode::kGather:
353     case HloOpcode::kGetTupleElement:
354     case HloOpcode::kInfeed:
355     case HloOpcode::kIota:
356     case HloOpcode::kPad:
357     case HloOpcode::kReduceWindow:
358     case HloOpcode::kReshape:
359     case HloOpcode::kScatter:
360     case HloOpcode::kSelectAndScatter:
361     case HloOpcode::kSlice:
362     case HloOpcode::kSort:
363     case HloOpcode::kTranspose:
364     case HloOpcode::kTuple:
365     case HloOpcode::kWhile:
366     case HloOpcode::kReduce:
367     case HloOpcode::kRngBitGenerator:
368       return true;
369     case HloOpcode::kAllReduce:
370     case HloOpcode::kReduceScatter:
371       // Only if channel_id is not specified.
372       return instruction->channel_id() == std::nullopt;
373     case HloOpcode::kParameter:
374       return computation_map.find(instruction->parent()) !=
375              computation_map.end();
376     case HloOpcode::kReverse:
377       return is_spmd;
378     case HloOpcode::kCustomCall:
379       return is_spmd && (IsPassthroughCustomOps(instruction) ||
380                          sharding_helper->IsCustomCallShardable(instruction));
381     default:
382       return false;
383   }
384 }
385 
386 bool InferDotShardingFromOperands(
387     HloInstruction* instruction,
388     const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
389     bool may_combine_partial_sharding) {
390   auto from_operand = [&](int64_t operand_index) {
391     auto operand = instruction->operand(operand_index);
392     const HloSharding& operand_sharding = operand->sharding();
393     if (operand_sharding.IsTileMaximal()) {
394       return operand_sharding;
395     }
396     std::vector<int64_t> contracting_dims;
397     contracting_dims.reserve(dnums.contracting_dims.size());
398     for (const auto& dim : dnums.contracting_dims) {
399       contracting_dims.push_back(operand_index == 0 ? dim.lhs : dim.rhs);
400     }
401     // It's possible that some size-1 spatial dims of convolutions are parsed as
402     // non-contracting dims. We might have tiled dimensions on them.
403     for (const auto& dim : operand_index == 0
404                                ? dnums.rhs_non_contracting_dims
405                                : dnums.lhs_non_contracting_dims) {
406       int64_t d = operand_index == 0 ? dim.lhs : dim.rhs;
407       if (d > 0) {
408         contracting_dims.push_back(d);
409       }
410     }
411     auto replicate_contracting_dims =
412         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
413             operand_sharding, contracting_dims);
414     std::vector<int64_t> out_dims_to_op_perm(instruction->shape().rank(), -1);
415     std::vector<int64_t> op_dims_to_output_perm(operand->shape().rank(), -1);
416     for (const auto& dim : dnums.batch_dims) {
417       out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
418       op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
419           dim.output;
420     }
421     for (const auto& dim : operand_index == 0
422                                ? dnums.lhs_non_contracting_dims
423                                : dnums.rhs_non_contracting_dims) {
424       out_dims_to_op_perm[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
425       op_dims_to_output_perm[operand_index == 0 ? dim.lhs : dim.rhs] =
426           dim.output;
427     }
428     return *hlo_sharding_util::TransposeShardingWithCollapsedDims(
429         replicate_contracting_dims, op_dims_to_output_perm,
430         out_dims_to_op_perm);
431   };
432   bool changed = false;
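  // Infer from the larger operand first; when both operands are sharded, the
  // sharding derived from the larger operand is applied first and the one from
  // the smaller operand is merged into it afterwards.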
433   int64_t larger_operand =
434       ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) >=
435               ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())
436           ? 0
437           : 1;
438   if (IsSpatiallyPartitioned(instruction->operand(larger_operand))) {
439     changed |= MaybeImproveInstructionSharding(from_operand(larger_operand),
440                                                instruction,
441                                                may_combine_partial_sharding);
442   }
443   if (IsSpatiallyPartitioned(instruction->operand(1 - larger_operand))) {
444     changed |= MaybeImproveInstructionSharding(from_operand(1 - larger_operand),
445                                                instruction,
446                                                may_combine_partial_sharding);
447   }
448   return changed;
449 }
450 
451 bool InferGatherParallelShardingFromOperands(
452     HloInstruction* instruction,
453     const hlo_sharding_util::GatherParallelDims& parallel_dims,
454     bool may_combine_partial_sharding) {
455   auto from_operand =
456       [instruction](int64_t operand_index,
457                     absl::Span<const int64_t> output_aligned_parallel_dims,
458                     absl::Span<const int64_t> output_parallel_dims) {
459         const HloInstruction* operand = instruction->operand(operand_index);
460         const HloSharding& operand_sharding = operand->sharding();
461         if (operand_sharding.IsTileMaximal()) {
462           return operand_sharding;
463         }
464         auto dnums = instruction->gather_dimension_numbers();
465         std::vector<int64_t> output_tile_dims(instruction->shape().rank(), 1);
466         std::vector<int64_t> index_non_parallel_dims;
467         index_non_parallel_dims.reserve(operand->shape().rank());
468         // Detect non-parallel dimensions in the index.
469         for (int i = 0; i < operand->shape().rank(); ++i) {
470           if (!absl::c_linear_search(output_aligned_parallel_dims, i)) {
471             index_non_parallel_dims.push_back(i);
472           }
473         }
474         // Collect tile dimensions in the operand. The order of the parallel
475         // dimensions in output_aligned_parallel_dims is the same as that of the
476         // output.
477         for (int i = 0; i < output_aligned_parallel_dims.size(); ++i) {
478           const int64_t indices_idx = output_aligned_parallel_dims[i];
479           const int64_t output_idx = output_parallel_dims[i];
480           output_tile_dims[output_idx] =
481               operand_sharding.tile_assignment().dim(indices_idx);
482         }
483         HloSharding replicate_non_parallel_dims =
484             hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
485                 operand_sharding, index_non_parallel_dims);
486         if (replicate_non_parallel_dims.IsTileMaximal()) {
487           return replicate_non_parallel_dims;
488         }
489         for (int64_t i = replicate_non_parallel_dims.TiledDataRank();
490              i < replicate_non_parallel_dims.tile_assignment().num_dimensions();
491              ++i) {
492           output_tile_dims.push_back(
493               replicate_non_parallel_dims.tile_assignment().dim(i));
494         }
495         auto output_tile_assignment =
496             replicate_non_parallel_dims.tile_assignment();
497         output_tile_assignment.Reshape(output_tile_dims);
498         return replicate_non_parallel_dims.ReplicateOnLastTileDim()
499                    ? HloSharding::PartialTile(
500                          output_tile_assignment,
501                          replicate_non_parallel_dims.metadata())
502                    : HloSharding::Subgroup(
503                          output_tile_assignment,
504                          replicate_non_parallel_dims.subgroup_types(),
505                          replicate_non_parallel_dims.metadata());
506       };
507 
508   bool changed = false;
509   auto output_parallel_dims =
510       hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims);
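  // Operand 0 is the gather input and operand 1 is the start indices; for
  // each, project its sharding along the aligned parallel dims onto the
  // gather output.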
511   if (IsSpatiallyPartitioned(instruction->operand(0))) {
512     changed |= MaybeImproveInstructionSharding(
513         from_operand(
514             0,
515             absl::MakeConstSpan(
516                 hlo_sharding_util::GatherOutputAlignedOperandParallelDims(
517                     *instruction, parallel_dims)),
518             absl::MakeConstSpan(output_parallel_dims)),
519         instruction, may_combine_partial_sharding);
520   }
521   if (IsSpatiallyPartitioned(instruction->operand(1))) {
522     changed |= MaybeImproveInstructionSharding(
523         from_operand(1,
524                      absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
525                      absl::MakeConstSpan(output_parallel_dims)),
526         instruction, may_combine_partial_sharding);
527   }
528   return changed;
529 }
530 
531 // Convolution handling for InferShardingFromOperands().
532 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
533                                           int64_t aggressiveness,
534                                           bool may_combine_partial_sharding) {
535   auto get_partitions_for_dims =
536       [&](const HloInstruction* inst,
537           absl::Span<
538               const dot_as_convolution_util::DotConvolutionDimsInfo::DimNums>
539               dims,
540           int lhs_or_rhs) {
541         int64_t partitions = 1;
542         if (!inst->has_sharding()) {
543           return partitions;
544         }
545         const auto& sharding = inst->sharding();
546         if (sharding.IsTileMaximal()) {
547           return partitions;
548         }
549         for (const auto& dim : dims) {
550           if (lhs_or_rhs == 0) {
551             partitions *= sharding.tile_assignment().dim(dim.lhs);
552           } else {
553             CHECK_EQ(lhs_or_rhs, 1);
554             partitions *= sharding.tile_assignment().dim(dim.rhs);
555           }
556         }
557         return partitions;
558       };
559   auto dot_dims =
560       dot_as_convolution_util::ParseConvolutionDimsInfo(instruction);
561   const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
562       instruction->operand(0), dot_dims.conv_spatial_dims, 0);
563   const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
564       instruction->operand(1), dot_dims.conv_spatial_dims, 1);
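  // When there are no convolution spatial dims, or none of them are
  // partitioned and there is no batch/feature grouping, the op behaves like a
  // dot, so reuse the dot inference.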
565   if (dot_dims.conv_spatial_dims.empty() ||
566       (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
567        instruction->batch_group_count() == 1 &&
568        instruction->feature_group_count() == 1)) {
569     return InferDotShardingFromOperands(instruction, dot_dims,
570                                         may_combine_partial_sharding);
571   }
572   const auto& dnums = instruction->convolution_dimension_numbers();
573   const HloInstruction* lhs = instruction->operand(0);
574   auto get_tiled_sharding_based_on_lhs = [&] {
575     CHECK(!lhs->sharding().IsTileMaximal());
576     std::vector<int64_t> output_to_lhs_indices(instruction->shape().rank());
577     output_to_lhs_indices[dnums.output_batch_dimension()] =
578         dnums.input_batch_dimension();
579     output_to_lhs_indices[dnums.output_feature_dimension()] =
580         dnums.input_feature_dimension();
581     for (int64_t i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
582       output_to_lhs_indices[dnums.output_spatial_dimensions(i)] =
583           dnums.input_spatial_dimensions(i);
584     }
585     return hlo_sharding_util::TransposeSharding(lhs->sharding(),
586                                                 output_to_lhs_indices);
587   };
588   if (!IsSpatiallyPartitioned(lhs)) {
589     return false;
590   }
591   if (lhs->sharding().IsTileMaximal()) {
592     return MaybeImproveInstructionSharding(lhs->sharding(), instruction,
593                                            may_combine_partial_sharding);
594   }
595 
596   if (IsConvolutionKernelSmall(instruction)) {
597     // If the kernel is small compared to the input, then we can generate an
598     // output that is sharded the same way as the input.
599     const auto& tile_assignment = lhs->sharding().tile_assignment();
600     if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) {
601       return false;
602     }
603     return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(),
604                                            instruction,
605                                            may_combine_partial_sharding);
606   }
607   // If the kernel is large (e.g. a backward convolution), we only support
608   // replicated output.
609   return MaybeImproveInstructionSharding(
610       hlo_sharding_util::ReplicateAllDataDims(lhs->sharding(),
611                                               instruction->shape().rank()),
612       instruction, may_combine_partial_sharding);
613 }
614 
615 bool CanPropagateThroughAtAggressiveLevel(const HloInstruction& inst,
616                                           int64_t aggressiveness) {
617   // At minimum aggressiveness, only allow pass-through ops.
618   if (aggressiveness < 1 &&
619       !(inst.IsElementwise() || inst.IsCustomCall("Sharding")) &&
620       inst.opcode() != HloOpcode::kTranspose &&
621       inst.opcode() != HloOpcode::kReshape &&
622       inst.opcode() != HloOpcode::kTuple &&
623       inst.opcode() != HloOpcode::kGetTupleElement &&
624       inst.opcode() != HloOpcode::kWhile &&
625       inst.opcode() != HloOpcode::kDynamicSlice &&
626       inst.opcode() != HloOpcode::kOptimizationBarrier &&
627       inst.opcode() != HloOpcode::kConcatenate) {
628     return false;
629   }
630   // Broadcast propagation should have at least aggressiveness 2.
631   if (aggressiveness < 2 && inst.opcode() == HloOpcode::kBroadcast) {
632     return false;
633   }
634   return true;
635 }
636 
637 HloSharding InferDotOperandSharding(
638     const HloInstruction* instruction,
639     const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
640     int64_t operand_index, bool may_combine_partial_sharding) {
641   auto operand = instruction->operand(operand_index);
642   auto other = instruction->operand(1 - operand_index);
643   std::vector<int64_t> output_dims_to_replicate;
644   std::vector<int64_t> other_operand_dims_to_replicate;
645   for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims
646                                             : dnums.lhs_non_contracting_dims) {
647     output_dims_to_replicate.push_back(dim.output);
648     other_operand_dims_to_replicate.push_back(operand_index == 0 ? dim.rhs
649                                                                  : dim.lhs);
650   }
651   // If this dot is interpreted from a conv, then contracting dims may have
652   // corresponding spatial dimensions in the output, and this operand's
653   // non-contracting dims may have corresponding spatial dims in the other
654   // operand.
655   for (const auto& dim : dnums.contracting_dims) {
656     if (dim.output >= 0) {
657       output_dims_to_replicate.push_back(dim.output);
658     }
659   }
660   for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims
661                                             : dnums.rhs_non_contracting_dims) {
662     int64_t other_dim = operand_index == 0 ? dim.rhs : dim.lhs;
663     if (other_dim >= 0) {
664       other_operand_dims_to_replicate.push_back(other_dim);
665     }
666   }
667   auto output_other_dims_replicated =
668       hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
669           instruction->sharding(), output_dims_to_replicate);
670   std::vector<int64_t> output_to_operand_dims(instruction->shape().rank(), -1);
671   std::vector<int64_t> operand_to_output_dims(operand->shape().rank(), -1);
672   for (const auto& dim : dnums.batch_dims) {
673     output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
674     operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output;
675   }
676   for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims
677                                             : dnums.rhs_non_contracting_dims) {
678     output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs;
679     operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output;
680   }
681   auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims(
682       output_other_dims_replicated, output_to_operand_dims,
683       operand_to_output_dims);
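  // `sharding` is now the sharding implied for this operand by the output. If
  // the other operand is sharded too, derive a second candidate from it via
  // the batch and contracting dims and merge it in when compatible.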
684   if (IsSpatiallyPartitioned(other)) {
685     auto other_operand_dims_replicated =
686         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
687             other->sharding(), other_operand_dims_to_replicate);
688     std::vector<int64_t> other_to_operand_dims(other->shape().rank(), -1);
689     std::vector<int64_t> operand_to_other_dims(operand->shape().rank(), -1);
690     for (const auto& dim : dnums.batch_dims) {
691       other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] =
692           operand_index == 0 ? dim.lhs : dim.rhs;
693       operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
694           operand_index == 0 ? dim.rhs : dim.lhs;
695     }
696     for (const auto& dim : dnums.contracting_dims) {
697       other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] =
698           operand_index == 0 ? dim.lhs : dim.rhs;
699       operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] =
700           operand_index == 0 ? dim.rhs : dim.lhs;
701     }
702     HloSharding sharding_from_other =
703         *hlo_sharding_util::TransposeShardingWithCollapsedDims(
704             other_operand_dims_replicated, other_to_operand_dims,
705             operand_to_other_dims);
706     if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other,
707                                          may_combine_partial_sharding)) {
708       sharding = std::move(sharding_from_other);
709     }
710   }
711   return sharding;
712 }
713 
714 // Tries to update the sharding of the specified instruction based on its users
715 // and returns true if the sharding of the instruction has been changed and
716 // false otherwise.
717 bool InferShardingFromUsers(
718     HloInstruction* instruction,
719     const ShardingPropagation::ComputationMap& computation_map,
720     int64_t aggressiveness, bool is_spmd,
721     const CustomCallShardingHelper* sharding_helper) {
722   if (aggressiveness < 2 && instruction->opcode() == HloOpcode::kBroadcast) {
723     return false;
724   }
725   // Do not change manual sharding.
726   if (instruction->has_sharding() && instruction->sharding().IsManual()) {
727     return false;
728   }
729   // Propagate manual sharding.
730   if (!instruction->has_sharding()) {
731     for (const HloInstruction* user : instruction->users()) {
732       if (!user->has_sharding() || !user->sharding().IsManual() ||
733           user->IsCustomCall("SPMDFullToShardShape"))
734         continue;
735       if (instruction->shape().IsArray()) {
736         instruction->set_sharding(
737             HloSharding::Manual(user->sharding().metadata()));
738       } else {
739         std::optional<HloSharding> user_sharding =
740             ShardingPropagation::GetShardingFromUser(*instruction, *user,
741                                                      aggressiveness, is_spmd);
742         if (user_sharding) {
743           instruction->set_sharding(*user_sharding);
744         }
745       }
746       return true;
747     }
748   }
749   if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd, false,
750                                   sharding_helper)) {
751     return false;
752   }
753 
754   bool improved_sharding = false;
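  // Combining partial shardings is only attempted in SPMD mode and above the
  // most conservative aggressiveness level.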
755   const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
756   for (const HloInstruction* user : instruction->users()) {
757     std::optional<HloSharding> user_sharding =
758         ShardingPropagation::GetShardingFromUser(*instruction, *user,
759                                                  aggressiveness, is_spmd);
760     if (user_sharding && sharding_helper->IsCustomCallShardable(instruction)) {
761       user_sharding = sharding_helper->PropagateUserSharding(instruction, user,
762                                                              *user_sharding);
763     }
764     if (user_sharding) {
765       improved_sharding |= MaybeImproveInstructionSharding(
766           std::move(*user_sharding), instruction, may_combine_partial_sharding);
767     }
768   }
769   return improved_sharding;
770 }
771 
772 // Checks if two HloShardings have the same metadata attached.
773 bool SameShardingMetadata(const HloSharding& a, const HloSharding& b) {
774   DCHECK_EQ(a, b);
775 
776   auto same_metadata = [](absl::Span<const OpMetadata> a,
777                           absl::Span<const OpMetadata> b) {
778     if (a.size() != b.size()) return false;
779     for (int i = 0, e = a.size(); i < e; ++i) {
780       if (!protobuf_util::ProtobufEquals(a[i], b[i])) {
781         return false;
782       }
783     }
784     return true;
785   };
786 
787   if (a.IsTuple()) {
788     for (int i = 0, e = a.tuple_elements().size(); i < e; ++i) {
789       if (!same_metadata(a.tuple_elements()[i].metadata(),
790                          b.tuple_elements()[i].metadata())) {
791         return false;
792       }
793     }
794     return true;
795   } else {
796     return same_metadata(a.metadata(), b.metadata());
797   }
798 }
799 
800 // Assigns the instruction's metadata to its sharding, for instructions that
801 // have both. If the sharding already has some metadata, no new metadata will
802 // be added.
803 bool AssignShardingMetadata(
804     HloModule* module,
805     const absl::flat_hash_set<absl::string_view>& execution_threads) {
806   bool changed = false;
807   for (HloComputation* computation : module->computations(execution_threads)) {
808     for (HloInstruction* instruction : computation->instructions()) {
809       const auto& metadata = instruction->metadata();
810       if (!instruction->has_sharding() || metadata.ByteSizeLong() == 0) {
811         continue;
812       }
813 
814       HloSharding sharding_with_metadata =
815           instruction->sharding().WithMetadata({metadata}, /*overwrite=*/false);
816       if (!SameShardingMetadata(instruction->sharding(),
817                                 sharding_with_metadata)) {
818         instruction->set_sharding(std::move(sharding_with_metadata));
819         changed = true;
820       }
821     }
822   }
823   return changed;
824 }
825 
826 // Removes all sharding metadata from shardings on instructions.
827 bool RemoveShardingMetadata(
828     HloModule* module,
829     const absl::flat_hash_set<absl::string_view>& execution_threads) {
830   bool changed = false;
831   for (HloComputation* computation : module->computations(execution_threads)) {
832     for (HloInstruction* instruction : computation->instructions()) {
833       if (!instruction->has_sharding()) {
834         continue;
835       }
836       HloSharding sharding_no_metadata =
837           instruction->sharding().WithoutMetadata();
838       if (!SameShardingMetadata(instruction->sharding(),
839                                 sharding_no_metadata)) {
840         instruction->set_sharding(std::move(sharding_no_metadata));
841         changed = true;
842       }
843     }
844   }
845   return changed;
846 }
847 
848 // If a while contains a channel instruction on device D, check that any other
849 // instructions with a device assignment are on D. Further, annotate the root
850 // instruction of the while body to ensure that HLO partitioning will keep the
851 // entire while instruction on D.
852 Status CheckAndUpdateDeviceAssignmentsInWhileBody(
853     HloInstruction* while_instruction) {
854   auto bad_status = [](HloInstruction* instruction, int64_t device,
855                        HloInstruction* channel_instruction,
856                        int64_t correct_device) {
857     return FailedPrecondition(
858         "Instruction: %s is on device: %d, which conflicts with device: %d "
859         "of channel instruction: %s",
860         instruction->name(), device, correct_device,
861         channel_instruction->name());
862   };
863 
864   CHECK_EQ(while_instruction->opcode(), HloOpcode::kWhile);
865   HloComputation* while_body = while_instruction->while_body();
866   // Maps a device number to an instruction in the while_body with that
867   // device assignment.
868   std::map<int64_t, HloInstruction*> devices_to_instructions;
869   std::optional<int64_t> unique_device = std::nullopt;
870   HloInstruction* channel_instruction = nullptr;
871 
872   for (HloInstruction* instruction : while_body->instructions()) {
873     if (instruction->sharding_unique_device()) {
874       auto opcode = instruction->opcode();
875       int64_t device = *instruction->sharding_unique_device();
876       if (unique_device.has_value()) {
877         if (*unique_device != device) {
878           return bad_status(instruction, device, channel_instruction,
879                             *unique_device);
880         }
881       } else if (opcode == HloOpcode::kSend || opcode == HloOpcode::kRecv ||
882                  // Cross-replica AllReduces don't have a channel_id, and we
883                  // don't enforce any invariant about their device assignment.
884                  ((opcode == HloOpcode::kAllReduce ||
885                    opcode == HloOpcode::kReduceScatter) &&
886                   instruction->channel_id())) {
887         channel_instruction = instruction;
888         unique_device = device;
889         if (!devices_to_instructions.empty()) {
890           for (auto it = devices_to_instructions.begin();
891                it != devices_to_instructions.end(); ++it) {
892             if (*unique_device != it->first) {
893               return bad_status(it->second, it->first, channel_instruction,
894                                 *unique_device);
895             }
896           }
897         }
898       } else {
899         devices_to_instructions[device] = instruction;
900       }
901     }
902   }
903 
904   if (unique_device.has_value()) {
905     auto while_device = while_instruction->sharding_unique_device();
906     if (while_device.has_value() && *unique_device != *while_device) {
907       return bad_status(while_instruction, *while_device, channel_instruction,
908                         *unique_device);
909     }
910     auto body_root = while_body->root_instruction();
911     auto root_device = body_root->sharding_unique_device();
912     if (!root_device.has_value()) {
913       body_root->set_device_sharding(*unique_device);
914     } else if (*unique_device != *root_device) {
915       return bad_status(body_root, *root_device, channel_instruction,
916                         *unique_device);
917     }
918   }
919   return OkStatus();
920 }
921 
922 // Refines a pair of auto/manual shardings based on auto sharding `to_merge`
923 // along `unspecified_dims`. Returns true if anything changed.
924 bool RefineManualAutoShardingFromAuto(
925     const HloSharding& to_merge, absl::Span<const int64_t> unspecified_dims,
926     HloSharding* auto_sharding, HloSharding* manual_sharding) {
927   if (!manual_sharding->IsManualSubgroup() ||
928       auto_sharding->IsManualSubgroup() ||
929       !manual_sharding->HasPartialReplication() ||
930       manual_sharding->subgroup_types().size() != 2) {
931     // We do not support nested subgroup manual. man_conversion_op must have
932     // replication in order to be merged.
933     return false;
934   }
935   HloSharding partial_rep =
936       hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
937           to_merge, unspecified_dims);
938   if (partial_rep.IsTileMaximal()) {
939     return false;
940   }
941 
942   // Merge with the non-manual partial annotation.
943   if (!hlo_sharding_util::MergeShardingIfCompatible(
944           partial_rep, auto_sharding->NumTiles() + 1, auto_sharding)) {
945     return false;
946   }
947 
948   // Merge with the manual partial annotation.
949   const int64_t data_rank = partial_rep.TiledDataRank();
950   // We are also merging the non-manual sharding into the manual sharding. To
951   // leverage the existing merging implementation, we treat the manual dim as a
952   // data dim, and add it right before the replication dim.
953   auto partial_tiling_for_manual = partial_rep.tile_assignment();
954   std::vector<int64_t> partial_manual_shape =
955       partial_tiling_for_manual.dimensions();
956   partial_manual_shape.insert(partial_manual_shape.begin() + data_rank, 1);
957   partial_tiling_for_manual.Reshape(partial_manual_shape);
958   HloSharding partial_rep_for_manual = HloSharding::PartialTile(
959       partial_tiling_for_manual, partial_rep.metadata());
960   Array<int64_t> man_tiling = manual_sharding->tile_assignment();
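  // manual_sharding has two subgroup dims (manual and replicated, in either
  // order); normalize the tiling so the manual dim sits right before the
  // replication dim, matching partial_rep_for_manual, before merging.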
961   if (manual_sharding->subgroup_types().back() != OpSharding::REPLICATED) {
962     // Move the manual dim before replication dim.
963     std::vector<int64_t> transposed_dims = man_tiling.dimensions();
964     transposed_dims[data_rank] = transposed_dims.back();
965     transposed_dims.back() = man_tiling.dim(data_rank);
966     Array<int64_t> transposed(transposed_dims);
967     man_tiling.Each([&](absl::Span<const int64_t> indices, int64_t device) {
968       std::vector<int64_t> xposed_idx(indices.begin(), indices.end() - 2);
969       xposed_idx.push_back(indices.back());
970       xposed_idx.push_back(indices[data_rank]);
971       transposed(xposed_idx) = device;
972     });
973     man_tiling = std::move(transposed);
974   }
975   HloSharding tmp_sharding_for_merging =
976       HloSharding::PartialTile(man_tiling, manual_sharding->metadata());
977   if (!hlo_sharding_util::MergeShardingIfCompatible(
978           partial_rep_for_manual, tmp_sharding_for_merging.NumTiles() + 1,
979           &tmp_sharding_for_merging)) {
980     return false;
981   }
982 
983   std::vector<OpSharding::Type> subgroup_types;
984   subgroup_types.push_back(OpSharding::MANUAL);
985   if (tmp_sharding_for_merging.HasPartialReplication()) {
986     subgroup_types.push_back(OpSharding::REPLICATED);
987   }
988   *manual_sharding = HloSharding::Subgroup(
989       tmp_sharding_for_merging.tile_assignment(), subgroup_types,
990       tmp_sharding_for_merging.metadata());
991   return true;
992 }
993 
994 // Refines a pair of auto/manual shardings based on manual sharding `to_merge`
995 // along `unspecified_dims`. Returns true if anything changed.
996 bool RefineManualAutoShardingFromManual(
997     const HloSharding& to_merge, absl::Span<const int64_t> unspecified_dims,
998     HloSharding* auto_sharding, HloSharding* manual_sharding) {
999   if (!to_merge.IsManualSubgroup() || !manual_sharding->IsManualSubgroup() ||
1000       !manual_sharding->HasPartialReplication() ||
1001       auto_sharding->IsManualSubgroup() ||
1002       manual_sharding->subgroup_types().size() != 2) {
1003     return false;
1004   }
1005   HloSharding partial_rep =
1006       hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1007           to_merge, unspecified_dims);
1008   if (partial_rep.IsTileMaximal()) {
1009     return false;
1010   }
1011   if (!hlo_sharding_util::MergeShardingIfCompatible(
1012           partial_rep, manual_sharding->NumTiles() + 1, manual_sharding)) {
1013     return false;
1014   }
1015   HloSharding partial_rep_for_auto = HloSharding::Subgroup(
1016       partial_rep.tile_assignment(),
1017       std::vector<OpSharding::Type>(partial_rep.subgroup_types().size(),
1018                                     OpSharding::REPLICATED),
1019       partial_rep.metadata());
1020   if (!hlo_sharding_util::MergeShardingIfCompatible(
1021           partial_rep_for_auto, auto_sharding->NumTiles() + 1, auto_sharding)) {
1022     return false;
1023   }
1024   return true;
1025 }
1026 
1027 bool InferUnspecifiedDimsFromOperand(HloInstruction* annotate_op,
1028                                      absl::Span<const int64_t> unspecified_dims,
1029                                      HloInstruction** man_conversion_op_after) {
1030   // ProcessShardingInstruction will either keep the "Sharding" custom call as
1031   // is or replace it with a copy.
1032   CHECK(annotate_op->IsCustomCall("Sharding") ||
1033         annotate_op->opcode() == HloOpcode::kCopy);
1034   if (!IsSpatiallyPartitioned(annotate_op->operand(0))) {
1035     return false;
1036   }
1037   const HloSharding& operand_sharding = annotate_op->operand(0)->sharding();
1038   if (!operand_sharding.IsTiled()) {
1039     return false;
1040   }
1041   HloInstruction* man_conversion_op = nullptr;
1042   if (annotate_op->user_count() == 1) {
1043     HloInstruction* user = annotate_op->users()[0];
1044     if (user->IsCustomCall("SPMDFullToShardShape") ||
1045         user->IsCustomCall("SPMDShardToFullShape")) {
1046       std::vector<int64_t> user_unspec_dims;
1047       if (!sharding_op_util::ParseAttributes(
1048                Cast<HloCustomCallInstruction>(user)->opaque(),
1049                &user_unspec_dims)
1050                .ok()) {
1051         return false;
1052       }
1053       absl::c_sort(user_unspec_dims);
1054       if (unspecified_dims != user_unspec_dims) {
1055         // The manual/auto conversion op must have the same set of unspecified
1056         // dims.
1057         return false;
1058       }
1059       man_conversion_op = user;
1060     }
1061   }
1062   *man_conversion_op_after = man_conversion_op;
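  // Three cases follow: no manual/auto conversion user (merge directly into
  // the annotation), an SPMDFullToShardShape user (annotation holds the auto
  // sharding, the user holds the manual one), or an SPMDShardToFullShape user
  // (the roles are reversed).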
1063   if (man_conversion_op == nullptr) {
1064     HloSharding partial_replicated =
1065         hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1066             operand_sharding, unspecified_dims);
1067     HloSharding sharding = annotate_op->sharding();
1068     if (!hlo_sharding_util::MergeShardingIfCompatible(
1069             partial_replicated, sharding.NumTiles() + 1, &sharding)) {
1070       return false;
1071     }
1072     annotate_op->set_sharding(sharding);
1073     return true;
1074   }
1075   if (man_conversion_op->IsCustomCall("SPMDFullToShardShape")) {
1076     HloSharding auto_sharding = annotate_op->sharding();
1077     HloSharding manual_sharding = man_conversion_op->sharding();
1078     if (!RefineManualAutoShardingFromAuto(operand_sharding, unspecified_dims,
1079                                           &auto_sharding, &manual_sharding)) {
1080       return false;
1081     }
1082     annotate_op->set_sharding(auto_sharding);
1083     man_conversion_op->set_sharding(manual_sharding);
1084     return true;
1085   }
1086 
1087   CHECK(man_conversion_op->IsCustomCall("SPMDShardToFullShape"));
1088   HloSharding manual_sharding = annotate_op->sharding();
1089   HloSharding auto_sharding = man_conversion_op->sharding();
1090   if (!RefineManualAutoShardingFromManual(operand_sharding, unspecified_dims,
1091                                           &auto_sharding, &manual_sharding)) {
1092     return false;
1093   }
1094   annotate_op->set_sharding(manual_sharding);
1095   man_conversion_op->set_sharding(auto_sharding);
1096   return true;
1097 }
1098 
1099 bool InferUnspecifiedDimsFromOneUser(HloInstruction* annotate_op,
1100                                      const HloInstruction* user,
1101                                      int64_t aggressiveness, bool is_spmd,
1102                                      absl::Span<const int64_t> unspecified_dims,
1103                                      HloInstruction* man_conversion_op) {
1104   CHECK(annotate_op->IsCustomCall("Sharding") ||
1105         annotate_op->opcode() == HloOpcode::kCopy);
1106   if (!user->has_sharding() || !user->sharding().IsTiled()) {
1107     return false;
1108   }
1109   std::optional<HloSharding> user_sharding =
1110       ShardingPropagation::GetShardingFromUser(
1111           man_conversion_op == nullptr ? *annotate_op : *man_conversion_op,
1112           *user, aggressiveness, is_spmd);
1113   if (!user_sharding.has_value() || user_sharding->IsTileMaximal()) {
1114     return false;
1115   }
1116   if (man_conversion_op == nullptr) {
1117     HloSharding partial_replicated =
1118         hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
1119             *user_sharding, unspecified_dims);
1120     HloSharding sharding = annotate_op->sharding();
1121     if (!hlo_sharding_util::MergeShardingIfCompatible(
1122             partial_replicated, sharding.NumTiles() + 1, &sharding)) {
1123       return false;
1124     }
1125     annotate_op->set_sharding(sharding);
1126     return true;
1127   }
1128   if (man_conversion_op->IsCustomCall("SPMDFullToShardShape")) {
1129     HloSharding auto_sharding = annotate_op->sharding();
1130     HloSharding manual_sharding = man_conversion_op->sharding();
1131     if (!RefineManualAutoShardingFromManual(*user_sharding, unspecified_dims,
1132                                             &auto_sharding, &manual_sharding)) {
1133       return false;
1134     }
1135     annotate_op->set_sharding(auto_sharding);
1136     man_conversion_op->set_sharding(manual_sharding);
1137     return true;
1138   }
1139   CHECK(man_conversion_op->IsCustomCall("SPMDShardToFullShape"));
1140   HloSharding manual_sharding = annotate_op->sharding();
1141   HloSharding auto_sharding = man_conversion_op->sharding();
1142   if (!RefineManualAutoShardingFromAuto(*user_sharding, unspecified_dims,
1143                                         &auto_sharding, &manual_sharding)) {
1144     return false;
1145   }
1146   annotate_op->set_sharding(manual_sharding);
1147   man_conversion_op->set_sharding(auto_sharding);
1148   return true;
1149 }
1150 
1151 bool InferUnspecifiedDimsFromUsers(HloInstruction* annotate_op,
1152                                    absl::Span<const int64_t> unspecified_dims,
1153                                    int64_t aggressiveness, bool is_spmd,
1154                                    HloInstruction** man_conversion_op_after) {
1155   HloInstruction* man_conversion_op = nullptr;
1156   if (annotate_op->user_count() == 1) {
1157     HloInstruction* user = annotate_op->users()[0];
1158     if (user->IsCustomCall("SPMDFullToShardShape") ||
1159         user->IsCustomCall("SPMDShardToFullShape")) {
1160       std::vector<int64_t> user_unspec_dims;
1161       if (!sharding_op_util::ParseAttributes(
1162                Cast<HloCustomCallInstruction>(user)->opaque(),
1163                &user_unspec_dims)
1164                .ok()) {
1165         return false;
1166       }
1167       absl::c_sort(user_unspec_dims);
1168       if (unspecified_dims != user_unspec_dims) {
1167         // The manual/auto conversion op must have the same set of unspecified
1168         // dims.
1169         return false;
1170       }
1171       man_conversion_op = user;
1172     }
1173   }
1174   *man_conversion_op_after = man_conversion_op;
1175 
1176   HloInstruction* op_for_users =
1177       man_conversion_op == nullptr ? annotate_op : man_conversion_op;
1178   bool changed = false;
1179   for (HloInstruction* user : op_for_users->users()) {
1180     changed |= InferUnspecifiedDimsFromOneUser(
1181         annotate_op, user, aggressiveness, is_spmd, unspecified_dims,
1182         man_conversion_op);
1183   }
1184   return changed;
1185 }
1186 
1187 // Returns whether an op is a target for CSE prevention.
1188 bool IsCSEPreventionTarget(const HloInstruction* instruction) {
1189   // Scalar broadcasts are the most common CSE target that causes cross-layer
1190   // propagation on unrelated subgraphs.
1191   return instruction->opcode() == HloOpcode::kBroadcast &&
1192          instruction->operand(0)->shape().rank() == 0;
1193 }
1194 
1195 // Marks a sharding as being used for CSE prevention.
1196 HloSharding SetCSEPreventionSharding(const HloSharding& sharding) {
1197   OpMetadata metadata;
1198   metadata.set_op_name("_sharding_propagation_cse_prevention");
1199   return sharding.WithMetadata({metadata}, /*overwrite=*/true);
1200 }
1201 
1202 // Returns whether the sharding is for CSE prevention.
1203 bool IsCSEPreventionSharding(const HloSharding& sharding) {
1204   if (sharding.metadata().size() != 1) {
1205     return false;
1206   }
1207   return sharding.metadata()[0].op_name() ==
1208          "_sharding_propagation_cse_prevention";
1209 }
1210 
1211 }  // namespace
1212 
1213 std::optional<HloSharding> InferBroadcastOperandSharding(
1214     const HloInstruction& instruction, bool is_spmd) {
1215   if (instruction.sharding().IsReplicated()) {
1216     return instruction.sharding();
1217   }
1218   std::vector<int64_t> dims_to_replicate;
1219   bool needs_replication = false;
1220   for (int64_t i = 0; i < instruction.shape().rank(); ++i) {
1221     if (absl::c_count(instruction.dimensions(), i) == 0) {
1222       dims_to_replicate.push_back(i);
1223       if (instruction.sharding().tile_assignment().dim(i) > 1) {
1224         needs_replication = true;
1225       }
1226     }
1227   }
1228   // If not SPMD, only support the case where none of the partitioned
1229   // dimensions in the broadcast output are newly added (broadcasted) dimensions.
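  // For example (illustrative): with a broadcast from f32[8] to f32[8,16] and
  // dimensions={0}, output dimension 1 is a new dimension. If the output
  // sharding tiles dimension 1, that tiling cannot be mapped back to the
  // operand, so the new dimension is replicated (SPMD only) before removal.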
1230   if (!is_spmd && needs_replication) {
1231     return std::nullopt;
1232   }
1233   return hlo_sharding_util::RemoveShapeDimensions(
1234       hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1235           instruction.sharding(), dims_to_replicate),
1236       dims_to_replicate);
1237 }
1238 
1239 // Removes Sharding custom-call instructions by folding the sharding attribute
1240 // into their operand. If the operand already has a different sharding, inserts
1241 // a copy node to reshard.
1242 // `unspecified_dims` will be populated with the converted copies if the custom
1243 // call is partially specified.
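// For example (illustrative; the exact HLO syntax may differ slightly), an
// annotation such as
//   %annot = f32[8,16] custom-call(%p0), custom_call_target="Sharding",
//            sharding={devices=[2,1]0,1}
// is, when `replace_sharding_with_copy` is true, rewritten to a kCopy of %p0
// carrying the sharding; any dimensions marked as unspecified in the custom
// call's opaque field are recorded in `unspecified_dims` for later refinement.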
1244 StatusOr<bool> ProcessShardingInstruction(
1245     HloModule* module,
1246     const absl::flat_hash_set<absl::string_view>& execution_threads,
1247     bool replace_sharding_with_copy,
1248     absl::flat_hash_map<const HloInstruction*, std::vector<int64_t>>*
1249         unspecified_dims) {
1250   bool changed = false;
1251 
1252   for (HloComputation* computation : module->computations(execution_threads)) {
1253     auto instructions = computation->MakeInstructionPostOrder();
1254     std::reverse(instructions.begin(), instructions.end());
1255     for (HloInstruction* instruction : instructions) {
1256       if (!instruction->IsCustomCall("Sharding")) {
1257         continue;
1258       }
1259       TF_RET_CHECK(instruction->has_sharding())
1260           << "Sharding instruction must have a sharding attribute";
1261       const HloSharding& sharding = instruction->sharding();
1262 
1263       std::vector<int64_t> unspec_dims;
1264       TF_RETURN_IF_ERROR(sharding_op_util::ParseAttributes(
1265           Cast<HloCustomCallInstruction>(instruction)->opaque(), &unspec_dims));
1266       // Replace it with a copy node so that it does not need special handling.
1267       if (replace_sharding_with_copy) {
1268         auto copy = computation->AddInstruction(
1269             HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kCopy,
1270                                         instruction->mutable_operand(0)));
1271         TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instruction, copy));
1272         copy->set_sharding(sharding);
1273         instruction = copy;
1274         changed = true;
1275       }
1276       if (!unspec_dims.empty()) {
1277         absl::c_sort(unspec_dims);
1278         unspecified_dims->emplace(instruction, std::move(unspec_dims));
1279       } else if (!instruction->operand(0)->has_sharding()) {
1280         instruction->mutable_operand(0)->set_sharding(sharding);
1281       }
1282     }
1283   }
1284   return changed;
1285 }
1286 
1287 /*static*/ Status ShardingPropagation::NormalizeDomain(
1288     const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
1289   if (metadata != nullptr) {
1290     TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
1291                         ShardingMetadata::ToShardingMetadata(metadata));
1292     const auto& sharding = sharding_metadata->sharding();
1293     if (sharding != nullptr) {
1294       bool is_spatially_partitioned = !sharding->HasUniqueDevice();
1295       if (sharding->IsTuple()) {
1296         is_spatially_partitioned = absl::c_any_of(
1297             sharding->tuple_elements(),
1298             [](const HloSharding& s) { return !s.HasUniqueDevice(); });
1299       }
1300       if (is_spatially_partitioned) {
1301         for (HloInstruction* d : domain.exit_domains) {
1302           HloInstruction* operand = d->mutable_operand(0);
1303           // Set the sharding only if it is different. We don't overwrite
1304           // the metadata if the sharding is the same except for metadata.
1305           if (!operand->has_sharding() || operand->sharding() != *sharding) {
1306             d->mutable_operand(0)->set_sharding(*sharding);
1307           }
1308         }
1309         return OkStatus();
1310       }
1311     }
1312   }
1313   return ShardingMetadata::NormalizeShardingDomain(domain, metadata);
1314 }
1315 
1316 // Return the sharding that should be propagated from user to instruction.
1317 std::optional<HloSharding> ShardingPropagation::GetShardingFromUser(
1318     const HloInstruction& instruction, const HloInstruction& user,
1319     int64_t aggressiveness, bool is_spmd) {
1320   if (!CanPropagateThroughAtAggressiveLevel(user, aggressiveness)) {
1321     return std::nullopt;
1322   }
1323   if (!IsSpatiallyPartitioned(&user)) {
1324     return std::nullopt;
1325   }
1326   const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0;
1327 
1328   switch (user.opcode()) {
1329     case HloOpcode::kBroadcast: {
1330       return InferBroadcastOperandSharding(user, is_spmd);
1331     }
1332     case HloOpcode::kConcatenate: {
1333       if (aggressiveness == 0) {
1334         return std::nullopt;
1335       }
1336       if (user.sharding().IsReplicated()) {
1337         return user.sharding();
1338       }
1339 
1340       const int64_t cdim = user.concatenate_dimension();
1341       const Array<int64_t>& tile_assignment = user.sharding().tile_assignment();
1342       if (tile_assignment.dim(cdim) == 1) {
1343         // If we are concatenating along a non-sharded dimension then the
1344         // operands should have the same sharding as the result.
1345         return user.sharding();
1346       }
1347 
1348       if (is_spmd) {
1349         // SPMD doesn't support tiling with only a subset of the devices. Return
1350         // the same sharding.
1351         return user.sharding();
1352       }
1353 
1354       // If we are concatenating along a sharded dimension, then we want each
1355       // operand to be distributed among the devices on which its data is used.
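      // For example (illustrative): concatenating two f32[4,8] operands into
      // f32[8,8] along dim 0 with the result tiled [4,1] over devices {0,1,2,3}
      // gives tile_shape = 2, so the first operand maps to tile rows [0, 2)
      // (devices {0,1}) and the second to tile rows [2, 4) (devices {2,3}).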
1356       int64_t start_offset = 0;
1357       for (HloInstruction* op : user.operands()) {
1358         if (op == &instruction) {
1359           break;
1360         }
1361         start_offset += op->shape().dimensions(cdim);
1362       }
1363       const int64_t tile_shape = CeilOfRatio(
1364           user.shape().dimensions(cdim), tile_assignment.dimensions()[cdim]);
1365       std::vector<int64_t> start_indices(tile_assignment.num_dimensions());
1366       std::vector<int64_t> end_indices = tile_assignment.dimensions();
1367       start_indices[cdim] = start_offset / tile_shape;
1368       end_indices[cdim] = CeilOfRatio(
1369           start_offset + instruction.shape().dimensions(cdim), tile_shape);
1370       auto new_tile_assignment =
1371           tile_assignment.Slice(start_indices, end_indices);
1372       if (new_tile_assignment.num_elements() == 1) {
1373         return HloSharding::AssignDevice(*new_tile_assignment.begin(),
1374                                          user.sharding().metadata());
1375       }
1376       return HloSharding::Tile(new_tile_assignment, user.sharding().metadata());
1377     }
1378     case HloOpcode::kConvolution: {
1379       auto dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(&user);
1380       if (dot_dims.conv_spatial_dims.empty()) {
1381         int64_t op_idx = user.operand_index(&instruction);
1382         return InferDotOperandSharding(&user, dot_dims, op_idx,
1383                                        may_combine_partial_sharding);
1384       }
1385       return std::nullopt;
1386     }
1387     case HloOpcode::kDynamicSlice:
1388     case HloOpcode::kDynamicUpdateSlice: {
1389       if (aggressiveness == 0) {
1390         return std::nullopt;
1391       }
1392       if (user.sharding().IsReplicated()) {
1393         return user.sharding();
1394       }
1395       if (user.opcode() == HloOpcode::kDynamicUpdateSlice &&
1396           &instruction == user.operand(0)) {
1397         return user.sharding();
1398       }
1399       const HloInstruction* operand = user.opcode() == HloOpcode::kDynamicSlice
1400                                           ? user.operand(0)
1401                                           : user.operand(1);
1402       if (&instruction != operand) {
1403         return std::nullopt;
1404       }
1405 
1406       if (is_spmd) {
1407         return user.sharding();
1408       }
1409       const auto& tile_assignment = user.sharding().tile_assignment();
1410       for (int64_t i = 0; i < user.shape().rank(); ++i) {
1411         if (tile_assignment.dim(i) > 1 &&
1412             user.shape().dimensions(i) != operand->shape().dimensions(i)) {
1413           return std::nullopt;
1414         }
1415       }
1416       return user.sharding();
1417     }
1418     case HloOpcode::kReduceWindow: {
1419       auto* reduce_window = Cast<HloReduceWindowInstruction>(&user);
1420       if (!absl::c_linear_search(reduce_window->inputs(), &instruction)) {
1421         return std::nullopt;
1422       }
1423       if (reduce_window->shape().IsTuple()) {
1424         auto sub_sharding = reduce_window->sharding().GetSubSharding(
1425             reduce_window->shape(),
1426             {reduce_window->operand_index(&instruction)});
1427         return sub_sharding;
1428       }
1429       return reduce_window->sharding();
1430     }
1431     case HloOpcode::kReshape: {
1432       auto reshaped_sharding = hlo_sharding_util::ReshapeSharding(
1433           user.shape(), instruction.shape(), user.sharding());
1434       if (reshaped_sharding.has_value()) {
1435         return reshaped_sharding;
1436       }
1437       return hlo_sharding_util::ReplicateAllDataDims(
1438           user.sharding(), instruction.shape().rank());
1439     }
1440     case HloOpcode::kPad: {
1441       if (&instruction != user.operand(0)) {
1442         return std::nullopt;
1443       }
1444       return user.sharding();
1445     }
1446     case HloOpcode::kSlice: {
1447       return user.sharding();
1448     }
1449     case HloOpcode::kTranspose: {
1450       // Calculate the dimension numbers for reversing the current transpose
1451       // and then use TransposeSharding to convert the output sharding to an
1452       // input sharding.
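      // For example (illustrative): for a transpose with dimensions={1,2,0},
      // reverse_dimensions is {2,0,1}, so the user's output sharding is
      // permuted back onto the operand's dimension order.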
1453       std::vector<int64_t> reverse_dimensions(user.dimensions().size());
1454       for (int64_t i = 0; i < user.dimensions().size(); ++i) {
1455         reverse_dimensions[user.dimensions(i)] = i;
1456       }
1457       return hlo_sharding_util::TransposeSharding(user.sharding(),
1458                                                   reverse_dimensions);
1459     }
1460     case HloOpcode::kTuple: {
1461       auto sub_sharding = user.sharding().GetSubSharding(
1462           user.shape(), {user.operand_index(&instruction)});
1463       return sub_sharding;
1464     }
1465     case HloOpcode::kGetTupleElement: {
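      // Compute the flattened leaf index of the element selected by the
      // get-tuple-element within the (possibly nested) tuple sharding. For
      // example (illustrative): for a tuple shape ((f32[4], f32[4]), f32[4])
      // and tuple_index=1, the leaf index is 2.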
1466       int64_t sharding_index = 0;
1467       for (int i = 0; i < instruction.shape().tuple_shapes_size(); ++i) {
1468         if (i == user.tuple_index()) {
1469           break;
1470         }
1471         if (instruction.shape().tuple_shapes(i).IsArray()) {
1472           sharding_index += 1;
1473         } else {
1474           sharding_index +=
1475               ShapeUtil::GetLeafCount(instruction.shape().tuple_shapes(i));
1476         }
1477       }
1478       if (user.shape().IsArray()) {
1479         // Use ReplicateAllDataDims instead of HloSharding::Replicate() to
1480         // preserve manual subgroups.
1481         HloSharding new_sharding =
1482             instruction.has_sharding()
1483                 ? instruction.sharding()
1484                 : HloSharding::SingleTuple(
1485                       instruction.shape(),
1486                       hlo_sharding_util::ReplicateAllDataDims(user.sharding()));
1487         new_sharding.tuple_elements()[sharding_index] = user.sharding();
1488         return new_sharding;
1489       } else {
1490         if (user.sharding().tuple_elements().empty()) {
1491           return std::nullopt;
1492         }
1493         HloSharding new_sharding =
1494             instruction.has_sharding()
1495                 ? instruction.sharding()
1496                 : HloSharding::SingleTuple(
1497                       instruction.shape(),
1498                       hlo_sharding_util::ReplicateAllDataDims(
1499                           user.sharding().tuple_elements()[0]));
1500         for (int64_t i = 0; i < user.sharding().tuple_elements().size(); ++i) {
1501           new_sharding.tuple_elements()[sharding_index + i] =
1502               user.sharding().tuple_elements()[i];
1503         }
1504         return new_sharding;
1505       }
1506     }
1507     case HloOpcode::kDot: {
1508       int64_t op_idx = user.operand_index(&instruction);
1509       auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(&user);
1510       return InferDotOperandSharding(&user, dnums, op_idx,
1511                                      may_combine_partial_sharding);
1512     }
1513     case HloOpcode::kReduce: {
1514       if (instruction.shape().rank() == 0) {
1515         return std::nullopt;
1516       }
1517       auto user_sharding =
1518           user.shape().IsTuple()
1519               ? user.sharding().GetSubSharding(
1520                     user.shape(), {user.operand_index(&instruction)})
1521               : user.sharding();
1522       if (user_sharding.IsTileMaximal()) {
1523         return user_sharding;
1524       }
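      // Map the user's (reduced) output sharding back onto the operand by
      // inserting a tile dimension of 1 at each reduced dimension. For example
      // (illustrative): operand rank 3, dimensions={1}, and an output sharding
      // tiled [2,4] yield an operand tiling of [2,1,4].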
1525       std::vector<int64_t> target_tile_assignment_dimensions(
1526           instruction.shape().rank() +
1527           (user_sharding.ReplicateOnLastTileDim() ? 1 : 0) +
1528           user_sharding.subgroup_types().size());
1529       const auto& dimensions = user.dimensions();
1530       int64_t next_output_dim = 0;
1531       for (int64_t i = 0; i < target_tile_assignment_dimensions.size(); ++i) {
1532         if (absl::c_find(dimensions, i) == dimensions.end()) {
1533           target_tile_assignment_dimensions[i] =
1534               user_sharding.tile_assignment().dim(next_output_dim++);
1535         } else {
1536           target_tile_assignment_dimensions[i] = 1;
1537         }
1538       }
1539       auto tile_assignment = user_sharding.tile_assignment();
1540       tile_assignment.Reshape(target_tile_assignment_dimensions);
1541       return user_sharding.ReplicateOnLastTileDim()
1542                  ? HloSharding::PartialTile(tile_assignment,
1543                                             user_sharding.metadata())
1544                  : HloSharding::Subgroup(tile_assignment,
1545                                          user_sharding.subgroup_types(),
1546                                          user_sharding.metadata());
1547     }
1548     case HloOpcode::kSort: {
1549       HloSharding user_sharding = user.sharding();
1550       if (user_sharding.IsTuple()) {
1551         return user_sharding = user_sharding.GetSubSharding(
1552                    user.shape(), {user.operand_index(&instruction)});
1553       }
1554       return user_sharding;
1555     }
1556     case HloOpcode::kReverse: {
1557       return hlo_sharding_util::ReverseSharding(user.sharding(),
1558                                                 user.dimensions());
1559     }
1560     case HloOpcode::kGather: {
1561       if (&instruction == user.operand(1)) {
1562         return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user);
1563       }
1564       if (is_spmd) {
1565         return hlo_sharding_util::GatherDataOperandShardingFromOutput(
1566             user.sharding(), user);
1567       }
1568       return std::nullopt;
1569     }
1570     case HloOpcode::kScatter: {
1571       auto& scatter_user = *Cast<HloScatterInstruction>(&user);
1572       if (&instruction == scatter_user.operand(0)) {
1573         return user.sharding();
1574       }
1575       if (&instruction == scatter_user.scatter_indices()) {
1576         auto update = scatter_user.scatter_updates()[0];
1577         if (!IsSpatiallyPartitioned(update)) {
1578           return std::nullopt;
1579         }
1580         return hlo_sharding_util::ScatterIndexSharding(update->sharding(),
1581                                                        &scatter_user);
1582       }
1583       CHECK_EQ(&instruction, scatter_user.scatter_updates()[0]);
1584       auto indices = scatter_user.scatter_indices();
1585       if (IsSpatiallyPartitioned(indices)) {
1586         auto from_indices = hlo_sharding_util::ScatterDataSharding(
1587             indices->sharding(), &scatter_user);
1588         if (!from_indices.IsTileMaximal()) {
1589           return from_indices;
1590         }
1591       }
1592       if (is_spmd) {
1593         return hlo_sharding_util::ScatterUpdateShardingFromOutput(
1594             user.sharding(), scatter_user);
1595       }
1596       return std::nullopt;
1597     }
1598     default: {
1599       // If the user output shape is compatible with the current instruction
1600       // shape excluding element type and the current instruction is supported
1601       // by spatial partitioning, then the user sharding can be used for
1602       // propagation to the current instruction.
1603       if (ShapeUtil::CompatibleIgnoringElementType(instruction.shape(),
1604                                                    user.shape())) {
1605         return user.sharding();
1606       }
1607       return std::nullopt;
1608     }
1609   }
1610 }
1611 
1612 // Computes the number of users that are not the root of the computation.
1613 int64_t ComputeNonRootUsers(const HloInstruction* instr) {
1614   int64_t non_root_users = instr->users().size();
1615   for (int i = 0; i < instr->users().size(); ++i) {
1616     if (instr->users()[i] == instr->parent()->root_instruction()) {
1617       --non_root_users;
1618     }
1619   }
1620   return non_root_users;
1621 }
1622 
1623 // Only pass through the sharding annotation at the first iteration when:
1624 //  1. The operand is sharded;  2. Only non-concat dims are sharded;
1625 //  3. The concat is for params in the repeated layers, which follow the
1626 //     pattern of param/gte -> reshape -> concat.
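// For example (illustrative): a reshape of a parameter sharded as
// {devices=[1,4]0,1,2,3} feeding a concatenate along dimension 0 has only the
// non-concat dimension sharded, so its sharding may be passed through even at
// aggressiveness 0.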
1627 bool AggressiveConcatOperandShardingCanPassThrough(
1628     const HloInstruction* concat_operand) {
1629   return (
1630       IsSpatiallyPartitioned(concat_operand) &&
1631       (concat_operand->has_sharding() &&
1632        concat_operand->sharding().NumTiles() > 1) &&
1633       concat_operand->opcode() == HloOpcode::kReshape &&
1634       (concat_operand->operand(0)->opcode() == HloOpcode::kParameter ||
1635        concat_operand->operand(0)->opcode() == HloOpcode::kGetTupleElement));
1636 }
1637 
1638 // DynamicSlice or DynamicUpdateSlice handling for InferShardingFromOperands().
1639 bool InferDynamicSliceOrDynamicUpdateSliceShardingFromOperands(
1640     HloInstruction* instruction, int64_t aggressiveness,
1641     bool may_combine_partial_sharding) {
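  // `operand` is the data being sliced for kDynamicSlice and the update operand
  // for kDynamicUpdateSlice; its sharding is the candidate to propagate.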
1642   const HloInstruction* operand =
1643       instruction->opcode() == HloOpcode::kDynamicSlice
1644           ? instruction->operand(0)
1645           : instruction->operand(1);
1646   auto slice_dim_is_sharded = [&]() {
1647     if (!IsSpatiallyPartitioned(operand) ||
1648         operand->sharding().NumTiles() == 1) {
1649       return false;
1650     }
1651     for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
1652       const auto& tile_assignment = operand->sharding().tile_assignment();
1653       if (tile_assignment.dim(i) > 1 && instruction->shape().dimensions(i) !=
1654                                             operand->shape().dimensions(i)) {
1655         return true;
1656       }
1657     }
1658     return false;
1659   };
1660 
1661   // Do not pass through sharding annotation at the first iteration
1662   // if slice dim is sharded.
1663   if (aggressiveness == 0 && slice_dim_is_sharded()) {
1664     return false;
1665   }
1666 
1667   auto propagate_slicing = [&]() {
1668     if (!IsSpatiallyPartitioned(operand)) {
1669       return false;
1670     }
1671 
1672     if (operand->sharding().NumTiles() == 1) {
1673       return MaybeImproveInstructionSharding(
1674           operand->sharding(), instruction, may_combine_partial_sharding,
1675           /*allow_aggressive_resharding=*/
1676           ComputeNonRootUsers(instruction) == 1);
1677     }
1678 
1679     if (slice_dim_is_sharded()) {
1680       return false;
1681     }
1682     return MaybeImproveInstructionSharding(
1683         operand->sharding(), instruction, may_combine_partial_sharding,
1684         /*allow_aggressive_resharding=*/
1685         ComputeNonRootUsers(instruction) == 1);
1686   };
1687   auto propagate_base = [&]() {
1688     if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) {
1689       return false;
1690     }
1691     if (!IsSpatiallyPartitioned(instruction->operand(0))) {
1692       return false;
1693     }
1694     return MaybeImproveInstructionSharding(instruction->operand(0)->sharding(),
1695                                            instruction,
1696                                            may_combine_partial_sharding);
1697   };
1698   bool changed = propagate_slicing();
1699   changed |= propagate_base();
1700   return changed;
1701 }
1702 
1703 // Tries to update the sharding of the specified instruction based on its
1704 // operands and returns true if the sharding of the instruction has been
1705 // changed, and false otherwise.
1706 bool ShardingPropagation::InferShardingFromOperands(
1707     HloInstruction* instruction, const ComputationMap& computation_map,
1708     int64_t aggressiveness) {
1709   if (!CanPropagateThroughAtAggressiveLevel(*instruction, aggressiveness)) {
1710     return false;
1711   }
1712   // Do not change manual sharding.
1713   if (instruction->has_sharding() && instruction->sharding().IsManual()) {
1714     return false;
1715   }
1716   // Propagate manual sharding. Avoid tuple-shaped HLOs that group independent
1717   // operands together. Reduce, ReduceWindow, and Sort can be tuples, but their
1718   // elements are correlated, so we propagate manual sharding through them.
1719   if (!instruction->has_sharding() &&
1720       (instruction->shape().IsArray() ||
1721        instruction->opcode() == HloOpcode::kReduce ||
1722        instruction->opcode() == HloOpcode::kSort ||
1723        instruction->opcode() == HloOpcode::kReduceWindow)) {
1724     for (const HloInstruction* op : instruction->operands()) {
1725       if (!op->has_sharding() || !op->sharding().IsManual()) continue;
1726       // Do not pass through manual sharding to concat or dynamic slice when
1727       // aggressiveness is 0.
1728       if (aggressiveness == 0 &&
1729           (instruction->opcode() == HloOpcode::kConcatenate ||
1730            instruction->opcode() == HloOpcode::kDynamicSlice)) {
1731         return false;
1732       }
1733       instruction->set_sharding(HloSharding::Manual(op->sharding().metadata()));
1734       return true;
1735     }
1736   }
1737   const bool may_combine_partial_sharding = is_spmd_ && aggressiveness > 0;
1738   if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd_,
1739                                   allow_spmd_sharding_propagation_to_output_,
1740                                   sharding_helper_.get())) {
1741     // If an array-shaped HLO doesn't support spatial partitioning but at least
1742     // one of its operands is replicated, then we make the HLO replicated as well.
1743     if (instruction->shape().IsTuple() || instruction->operand_count() == 0 ||
1744         instruction == instruction->parent()->root_instruction() ||
1745         instruction->HasSideEffect()) {
1746       return false;
1747     }
1748     for (const HloInstruction* op : instruction->operands()) {
1749       if (op->has_sharding() && op->sharding().IsTileMaximal() &&
1750           !op->sharding().HasUniqueDevice()) {
1751         return MaybeImproveInstructionSharding(op->sharding(), instruction,
1752                                                may_combine_partial_sharding);
1753       }
1754     }
1755     return false;
1756   }
1757 
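  // Wraps `sharding` into a tuple sharding when the instruction has a tuple
  // shape (e.g. variadic reduce / reduce-window), applying the same sharding to
  // every tuple element; otherwise returns it unchanged.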
1758   auto get_maybe_tuple_sharding = [&](HloSharding sharding) {
1759     if (instruction->shape().IsArray()) {
1760       return sharding;
1761     }
1762     std::vector<HloSharding> tuple(instruction->shape().tuple_shapes_size(),
1763                                    std::move(sharding));
1764     return HloSharding::Tuple(instruction->shape(), tuple);
1765   };
1766 
1767   switch (instruction->opcode()) {
1768     case HloOpcode::kGetTupleElement: {
1769       const HloInstruction* operand = instruction->operand(0);
1770       if (!IsSpatiallyPartitioned(operand)) {
1771         return false;
1772       }
1773       HloSharding new_sharding = operand->sharding().GetSubSharding(
1774           operand->shape(), {instruction->tuple_index()});
1775       if (new_sharding.IsManual()) {
1776         instruction->set_sharding(new_sharding);
1777         return true;
1778       }
1779       return MaybeImproveInstructionSharding(
1780           std::move(new_sharding), instruction, may_combine_partial_sharding,
1781           /*allow_aggressive_resharding=*/
1782           ComputeNonRootUsers(instruction) == 1);
1783     }
1784     case HloOpcode::kTuple: {
1785       if (absl::c_none_of(instruction->operands(),
1786                           [](const HloInstruction* hlo) {
1787                             return IsSpatiallyPartitioned(hlo);
1788                           })) {
1789         // None of the operands have a spatially partitioned sharding.
1790         return false;
1791       }
1792       const Shape& shape = instruction->shape();
1793       bool changed = false;
1794       if (!instruction->has_sharding()) {
1795         // Set the sharding for all elements in the tuple because it isn't
1796         // possible to set a partial sharding.
1797         changed = true;
1798         for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1799           const HloInstruction* operand = instruction->operand(i);
1800           if (!operand->has_sharding()) {
1801             continue;
1802           }
1803           if (operand->sharding().IsTuple()) {
1804             if (operand->sharding().tuple_elements().empty()) {
1805               continue;
1806             }
1807             // Use ReplicateAllDataDims to preserve manual subgroups.
1808             instruction->set_sharding(HloSharding::SingleTuple(
1809                 instruction->shape(),
1810                 hlo_sharding_util::ReplicateAllDataDims(
1811                     operand->sharding().tuple_elements()[0])
1812                     .WithoutMetadata()));
1813           } else {
1814             instruction->set_sharding(HloSharding::SingleTuple(
1815                 instruction->shape(),
1816                 hlo_sharding_util::ReplicateAllDataDims(operand->sharding())
1817                     .WithoutMetadata()));
1818           }
1819           break;
1820         }
1821       }
1822       if (!instruction->has_sharding()) {
1823         return false;
1824       }
1825       // Go through each operand and if the operand has a sharding that is
1826       // better than the current sharding for that tuple element then update
1827       // it.
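      // Leaf indices follow a pre-order flattening of the tuple shape. For
      // example (illustrative): for operands (f32[4], (f32[4], f32[4])), the
      // second operand's two leaf shardings land at sub_shardings[1] and
      // sub_shardings[2].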
1828       std::vector<HloSharding> sub_shardings =
1829           instruction->sharding().tuple_elements();
1830       int64_t sub_sharding_index = 0;
1831       for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
1832         const HloInstruction* operand = instruction->operand(i);
1833         if (operand->has_sharding()) {
1834           if (operand->shape().IsTuple()) {
1835             for (int64_t i = 0, e = ShapeUtil::GetLeafCount(operand->shape());
1836                  i < e; ++i) {
1837               if (hlo_sharding_util::IsShardingMoreSpecific(
1838                       operand->sharding().tuple_elements()[i],
1839                       sub_shardings[sub_sharding_index + i])) {
1840                 sub_shardings[sub_sharding_index + i] =
1841                     operand->sharding().tuple_elements()[i];
1842               }
1843             }
1844           } else {
1845             if (hlo_sharding_util::IsShardingMoreSpecific(
1846                     operand->sharding(), sub_shardings[sub_sharding_index])) {
1847               sub_shardings[sub_sharding_index] = operand->sharding();
1848             }
1849           }
1850         }
1851         sub_sharding_index += ShapeUtil::GetLeafCount(operand->shape());
1852       }
1853 
1854       HloSharding new_sharding = HloSharding::Tuple(shape, sub_shardings);
1855       if (new_sharding != instruction->sharding()) {
1856         instruction->set_sharding(std::move(new_sharding));
1857         return true;
1858       }
1859       return changed;
1860     }
1861     case HloOpcode::kReduce: {
1862       // Reduce could have a tuple shape, where the first half of operands are
1863       // the arrays to reduce, and the second half of operands are the init
1864       // values.
1865       bool changed = false;
1866       auto* reduce = Cast<HloReduceInstruction>(instruction);
1867       for (HloInstruction* operand : reduce->inputs()) {
1868         if (!IsSpatiallyPartitioned(operand)) {
1869           continue;
1870         }
1871         if (operand->sharding().IsReplicated() ||
1872             (!is_spmd_ &&
1873              absl::c_any_of(instruction->dimensions(), [operand](int64_t dim) {
1874                return operand->sharding().tile_assignment().dim(dim) > 1;
1875              }))) {
1876           // We are reducing along one of the sharded dimensions. We only
1877           // support this in SPMD.
1878           changed |= MaybeImproveInstructionSharding(
1879               get_maybe_tuple_sharding(
1880                   hlo_sharding_util::ReplicateAllDataDims(operand->sharding())),
1881               reduce, may_combine_partial_sharding,
1882               /*allow_aggressive_resharding=*/
1883               ComputeNonRootUsers(instruction) == 1);
1884           continue;
1885         }
1886         auto after_partial_replication =
1887             operand->sharding().IsReplicated()
1888                 ? operand->sharding()
1889                 : hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
1890                       operand->sharding(), reduce->dimensions());
1891         if (after_partial_replication.IsReplicated()) {
1892           changed |= MaybeImproveInstructionSharding(
1893               get_maybe_tuple_sharding(after_partial_replication), reduce,
1894               may_combine_partial_sharding,
1895               /*allow_aggressive_resharding=*/
1896               ComputeNonRootUsers(instruction) == 1);
1897           continue;
1898         }
1899         // Use the same sharding for all tuple elements, because they are part
1900         // of the same reduce instruction.
1901         HloSharding new_sharding =
1902             get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions(
1903                 after_partial_replication, reduce->dimensions()));
1904         changed |= MaybeImproveInstructionSharding(
1905             std::move(new_sharding), reduce, may_combine_partial_sharding,
1906             /*allow_aggressive_resharding=*/
1907             ComputeNonRootUsers(instruction) == 1);
1908       }
1909       return changed;
1910     }
1911     case HloOpcode::kBroadcast: {
1912       // Make forward propagation through broadcast low priority to avoid
1913       // resharding after broadcast.
1914       if (aggressiveness < 3) {
1915         return false;
1916       }
1917       const HloInstruction* op = instruction->operand(0);
1918       if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) {
1919         return false;
1920       }
1921       // The output is tiled along the dimensions mapped from the broadcast
1922       // operand in the same way as the operand, while the newly added
1923       // dimensions are kept non-tiled.
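      // For example (illustrative): an operand f32[8] tiled {devices=[4]0,1,2,3}
      // broadcast to f32[8,16] with dimensions={0} yields target dimensions
      // [4,1], i.e. an output sharding of {devices=[4,1]0,1,2,3}.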
1924       std::vector<int64_t> target_tile_assignment_dimensions;
1925       const auto& dimensions = instruction->dimensions();
1926       for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
1927         auto it = absl::c_find(dimensions, i);
1928         if (it == dimensions.end()) {
1929           target_tile_assignment_dimensions.push_back(1);
1930         } else {
1931           const int64_t source_dim = std::distance(dimensions.begin(), it);
1932           target_tile_assignment_dimensions.push_back(
1933               op->sharding().tile_assignment().dim(source_dim));
1934         }
1935       }
1936       for (int64_t i = op->sharding().TiledDataRank();
1937            i < op->sharding().tile_assignment().num_dimensions(); ++i) {
1938         target_tile_assignment_dimensions.push_back(
1939             op->sharding().tile_assignment().dim(i));
1940       }
1941       Array<int64_t> new_tile_assignment = op->sharding().tile_assignment();
1942       new_tile_assignment.Reshape(target_tile_assignment_dimensions);
1943       HloSharding new_sharding =
1944           op->sharding().ReplicateOnLastTileDim()
1945               ? HloSharding::PartialTile(new_tile_assignment,
1946                                          op->sharding().metadata())
1947               : HloSharding::Subgroup(new_tile_assignment,
1948                                       op->sharding().subgroup_types(),
1949                                       op->sharding().metadata());
1950       return MaybeImproveInstructionSharding(
1951           std::move(new_sharding), instruction, may_combine_partial_sharding,
1952           /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
1953               1);
1954     }
1955     case HloOpcode::kConcatenate: {
1956       const HloInstruction* operand = PickRepresentativeOperand(instruction);
1957       if (!operand || !IsSpatiallyPartitioned(operand)) {
1958         return false;
1959       }
1960 
1961       if (aggressiveness == 0) {
1962         for (const HloInstruction* concat_operand : instruction->operands()) {
1963           if (!AggressiveConcatOperandShardingCanPassThrough(concat_operand)) {
1964             return false;
1965           }
1966           const auto& tile_assignment =
1967               concat_operand->sharding().tile_assignment();
1968           for (int64_t i = 0; i < instruction->shape().rank(); ++i) {
1969             if (absl::c_linear_search(instruction->dimensions(), i) &&
1970                 tile_assignment.dim(i) > 1) {
1971               return false;
1972             }
1973           }
1974         }
1975       }
1976       return MaybeImproveInstructionSharding(
1977           operand->sharding(), instruction, may_combine_partial_sharding,
1978           /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
1979               1);
1980     }
1981     case HloOpcode::kConvolution:
1982       return InferConvolutionShardingFromOperands(instruction, aggressiveness,
1983                                                   may_combine_partial_sharding);
1984     case HloOpcode::kTranspose: {
1985       const HloInstruction* input = instruction->operand(0);
1986       if (!IsSpatiallyPartitioned(input)) {
1987         return false;
1988       }
1989       HloSharding sharding = hlo_sharding_util::TransposeSharding(
1990           input->sharding(), instruction->dimensions());
1991       return MaybeImproveInstructionSharding(
1992           std::move(sharding), instruction, may_combine_partial_sharding,
1993           /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
1994               1);
1995     }
1996     case HloOpcode::kReduceWindow: {
1997       auto* reduce_window = Cast<HloReduceWindowInstruction>(instruction);
1998       auto has_dilation = [](const WindowDimension& dimensions) {
1999         return dimensions.base_dilation() > 1 ||
2000                dimensions.window_dilation() > 1;
2001       };
2002       if (absl::c_any_of(instruction->window().dimensions(), has_dilation)) {
2003         VLOG(2) << "Not applying sharding to reduce window because dilation "
2004                    "isn't supported yet: "
2005                 << reduce_window->ToString();
2006         return false;
2007       }
2008       bool changed = false;
2009       for (HloInstruction* operand : reduce_window->inputs()) {
2010         if (!IsSpatiallyPartitioned(operand)) {
2011           continue;
2012         }
2013         changed |= MaybeImproveInstructionSharding(
2014             get_maybe_tuple_sharding(operand->sharding()), reduce_window,
2015             may_combine_partial_sharding,
2016             /*allow_aggressive_resharding=*/
2017             ComputeNonRootUsers(instruction) == 1);
2018       }
2019       return changed;
2020     }
2021     case HloOpcode::kSelectAndScatter: {
2022       // Shard according to first operand, as output keeps the same shape.
2023       const HloInstruction* lhs = instruction->operand(0);
2024       if (!IsSpatiallyPartitioned(lhs)) {
2025         return false;
2026       }
2027 
2028       auto has_base_dilation = [](const WindowDimension& dimensions) {
2029         return dimensions.base_dilation() > 1;
2030       };
2031       if (absl::c_any_of(instruction->window().dimensions(),
2032                          has_base_dilation)) {
2033         VLOG(2) << "Not applying sharding to select-and-scatter because "
2034                    "base dilation isn't supported yet: "
2035                 << instruction->ToString();
2036         return false;
2037       }
2038       return MaybeImproveInstructionSharding(
2039           lhs->sharding(), instruction, may_combine_partial_sharding,
2040           /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2041               1);
2042     }
2043     case HloOpcode::kReshape: {
2044       if (!IsSpatiallyPartitioned(instruction->operand(0))) {
2045         return false;
2046       }
2047       std::optional<HloSharding> new_sharding =
2048           hlo_sharding_util::ReshapeSharding(
2049               instruction->operand(0)->shape(), instruction->shape(),
2050               instruction->operand(0)->sharding());
2051       if (new_sharding.has_value()) {
2052         return MaybeImproveInstructionSharding(
2053             std::move(*new_sharding), instruction, may_combine_partial_sharding,
2054             /*allow_aggressive_resharding=*/
2055             ComputeNonRootUsers(instruction) == 1);
2056       }
2057       if (!instruction->has_sharding()) {
2058         instruction->set_sharding(hlo_sharding_util::ReplicateAllDataDims(
2059             instruction->operand(0)->sharding(), instruction->shape().rank()));
2060         return true;
2061       }
2062       return false;
2063     }
2064     case HloOpcode::kReverse: {
2065       const HloInstruction* operand = instruction->operand(0);
2066       if (!IsSpatiallyPartitioned(operand)) {
2067         return false;
2068       }
2069       return MaybeImproveInstructionSharding(
2070           hlo_sharding_util::ReverseSharding(operand->sharding(),
2071                                              instruction->dimensions()),
2072           instruction, may_combine_partial_sharding,
2073           /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2074               1);
2075     }
2076     case HloOpcode::kDot: {
2077       const auto& dnums =
2078           dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
2079       return InferDotShardingFromOperands(instruction, dnums,
2080                                           may_combine_partial_sharding);
2081     }
2082     case HloOpcode::kParameter: {
2083       auto parent_it = computation_map.find(instruction->parent());
2084       if (parent_it == computation_map.end()) {
2085         return false;
2086       }
2087       const HloInstruction* parent = parent_it->second;
2088       switch (parent->opcode()) {
2089         case HloOpcode::kConditional: {
2090           for (int64_t i = 1; i < parent->operand_count(); ++i) {
2091             if (parent->called_computations()[i - 1] == instruction->parent()) {
2092               if (parent->operand(i)->has_sharding()) {
2093                 return MaybeImproveInstructionSharding(
2094                     parent->operand(i)->sharding(), instruction,
2095                     may_combine_partial_sharding);
2096               }
2097               return false;
2098             }
2099           }
2100           return false;
2101         }
2102         default:
2103           return false;
2104       }
2105     }
2106     case HloOpcode::kSort: {
2107       const HloInstruction* operand = PickRepresentativeOperand(instruction);
2108       if (!operand || !IsSpatiallyPartitioned(operand)) {
2109         return false;
2110       }
2111 
2112       if (!operand->sharding().IsTileMaximal() &&
2113           operand->sharding().tile_assignment().dim(
2114               instruction->dimensions(0)) != 1) {
2115         // Doesn't support sharding the sorting dimension.
2116         return false;
2117       }
2118 
2119       if (instruction->shape().IsTuple()) {
2120         return MaybeImproveInstructionSharding(
2121             HloSharding::SingleTuple(instruction->shape(), operand->sharding()),
2122             instruction, may_combine_partial_sharding,
2123             /*allow_aggressive_resharding=*/
2124             ComputeNonRootUsers(instruction) == 1);
2125       } else {
2126         return MaybeImproveInstructionSharding(
2127             operand->sharding(), instruction, may_combine_partial_sharding,
2128             /*allow_aggressive_resharding=*/
2129             ComputeNonRootUsers(instruction) == 1);
2130       }
2131     }
2132     case HloOpcode::kDynamicSlice:
2133     case HloOpcode::kDynamicUpdateSlice: {
2134       return InferDynamicSliceOrDynamicUpdateSliceShardingFromOperands(
2135           instruction, aggressiveness, may_combine_partial_sharding);
2136     }
2137     case HloOpcode::kGather: {
2138       bool changed = false;
2139       if (IsSpatiallyPartitioned(instruction->operand(1))) {
2140         HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding(
2141             instruction->operand(1)->sharding(), instruction);
2142         changed |= MaybeImproveInstructionSharding(
2143             std::move(new_sharding), instruction, may_combine_partial_sharding);
2144       }
2145       if (is_spmd_) {
2146         auto gather_parallel_dims =
2147             hlo_sharding_util::GetGatherBatchParallelDims(*instruction);
2148         if (gather_parallel_dims) {
2149           changed |= InferGatherParallelShardingFromOperands(
2150               instruction, *gather_parallel_dims, may_combine_partial_sharding);
2151         }
2152         if (IsSpatiallyPartitioned(instruction->operand(0))) {
2153           absl::Span<const int64_t> operand_parallel_dims;
2154           if (gather_parallel_dims) {
2155             operand_parallel_dims = absl::MakeConstSpan(
2156                 gather_parallel_dims->operand_parallel_dims);
2157           }
2158           HloSharding filtered_operand_sharding =
2159               hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2160                   instruction->operand(0)->sharding(), operand_parallel_dims);
2161           auto maybe_from_data =
2162               hlo_sharding_util::GatherOutputShardingFromDataOperand(
2163                   filtered_operand_sharding, *instruction,
2164                   instruction->gather_slice_sizes(), instruction->shape(),
2165                   instruction->operand(0)->shape());
2166           if (maybe_from_data) {
2167             changed |= MaybeImproveInstructionSharding(
2168                 std::move(*maybe_from_data), instruction,
2169                 may_combine_partial_sharding);
2170           }
2171         }
2172       }
2173       return changed;
2174     }
2175     case HloOpcode::kScatter: {
2176       bool changed = false;
2177       if (is_spmd_ && IsSpatiallyPartitioned(instruction->operand(0))) {
2178         changed |= MaybeImproveInstructionSharding(
2179             instruction->operand(0)->sharding(), instruction,
2180             may_combine_partial_sharding);
2181       }
2182       auto* scatter = Cast<HloScatterInstruction>(instruction);
2183       if (!IsSpatiallyPartitioned(scatter->scatter_indices()) &&
2184           !IsSpatiallyPartitioned(scatter->scatter_updates()[0])) {
2185         return false;
2186       }
2187       if (is_spmd_ && IsSpatiallyPartitioned(scatter->scatter_updates()[0])) {
2188         auto maybe_from_update =
2189             hlo_sharding_util::ScatterOutputShardingFromUpdate(
2190                 scatter->scatter_updates()[0]->sharding(), *scatter);
2191         if (maybe_from_update) {
2192           changed |= MaybeImproveInstructionSharding(
2193               std::move(*maybe_from_update), instruction,
2194               may_combine_partial_sharding);
2195         }
2196       }
2197       if (!is_spmd_) {
2198         changed |= MaybeImproveInstructionSharding(
2199             HloSharding::Replicate(), instruction,
2200             may_combine_partial_sharding);
2201       }
2202       return changed;
2203     }
2204     case HloOpcode::kWhile: {
2205       if (!instruction->operand(0)->has_sharding()) {
2206         return false;
2207       }
2208       auto sharding = instruction->operand(0)->sharding();
2209       if (instruction->has_sharding()) {
2210         hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
2211                                          may_combine_partial_sharding);
2212       }
2213       return MaybeImproveInstructionSharding(std::move(sharding), instruction,
2214                                              may_combine_partial_sharding);
2215     }
2216     case HloOpcode::kCustomCall: {
2217       if (instruction->IsCustomCall("X64Combine")) {
2218         return false;
2219       }
2220       HloSharding inferred_operand_sharding = HloSharding::Replicate();
2221       if (sharding_helper_->IsCustomCallShardable(instruction)) {
2222         if (auto sharding =
2223                 sharding_helper_->InferShardingFromOperands(instruction)) {
2224           inferred_operand_sharding = *sharding;
2225         } else {
2226           return false;
2227         }
2228       } else {
2229         const HloInstruction* operand = PickRepresentativeOperand(instruction);
2230         if (!operand || !IsSpatiallyPartitioned(operand)) {
2231           return false;
2232         }
2233         inferred_operand_sharding = operand->sharding();
2234       }
2235       return MaybeImproveInstructionSharding(
2236           inferred_operand_sharding, instruction, may_combine_partial_sharding,
2237           /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2238               1);
2239     }
2240     default: {
2241       if (instruction->IsElementwise() && may_combine_partial_sharding) {
2242         bool changed = false;
2243         for (auto operand : instruction->operands()) {
2244           if (IsSpatiallyPartitioned(operand)) {
2245             if (instruction->opcode() == HloOpcode::kRng) {
2246               // Rng is considered elementwise but has operands with different
2247               // shapes.
2248               changed |= MaybeImproveInstructionSharding(
2249                   hlo_sharding_util::ReplicateAllDataDims(
2250                       operand->sharding(), instruction->shape().rank()),
2251                   instruction, may_combine_partial_sharding,
2252                   ComputeNonRootUsers(instruction) == 1);
2253               continue;
2254             }
2255             changed |= MaybeImproveInstructionSharding(
2256                 operand->sharding(), instruction, may_combine_partial_sharding,
2257                 /*allow_aggressive_resharding=*/
2258                 instruction->operands().size() == 1 &&
2259                     ComputeNonRootUsers(instruction) == 1);
2260           }
2261         }
2262         return changed;
2263       }
2264       const HloInstruction* operand = PickRepresentativeOperand(instruction);
2265       if (!operand || !IsSpatiallyPartitioned(operand)) {
2266         return false;
2267       }
2268       return MaybeImproveInstructionSharding(
2269           operand->sharding(), instruction, may_combine_partial_sharding,
2270           /*allow_aggressive_resharding=*/ComputeNonRootUsers(instruction) ==
2271               1);
2272     }
2273   }
2274   return false;
2275 }  // NOLINT(readability/fn_size)
2276 
2277 StatusOr<bool> ShardingPropagation::Run(
2278     HloModule* module,
2279     const absl::flat_hash_set<absl::string_view>& execution_threads) {
2280   std::optional<absl::flat_hash_map<const HloInstruction*, HloSharding>>
2281       original_sharding;
2282   bool any_changed = false;
2283   // Preprocessing for CSE prevention propagation: record the original shardings
2284   // so that we can revert to them at the end, and only keep those on CSE
2285   // prevention instructions.
2286   if (cse_prevention_only_) {
2287     original_sharding.emplace();
2288     for (auto computation : module->computations(execution_threads)) {
2289       for (auto instruction : computation->instructions()) {
2290         if (instruction->has_sharding()) {
2291           original_sharding->emplace(instruction, instruction->sharding());
2292         }
2293       }
2294     }
2295   } else {
2296     // The current pass is not for CSE prevention, but we remove the shardings
2297     // added by previous passes for CSE prevention.
2298     for (auto computation : module->computations(execution_threads)) {
2299       for (auto instruction : computation->instructions()) {
2300         if (instruction->has_sharding() &&
2301             IsCSEPreventionSharding(instruction->sharding())) {
2302           instruction->clear_sharding();
2303           any_changed = true;
2304         }
2305       }
2306     }
2307   }
2308   any_changed |= propagate_metadata_
2309                      ? AssignShardingMetadata(module, execution_threads)
2310                      : RemoveShardingMetadata(module, execution_threads);
2311   absl::flat_hash_map<const HloInstruction*, std::vector<int64_t>>
2312       unspecified_dims;
2313   TF_ASSIGN_OR_RETURN(
2314       bool changed,
2315       ProcessShardingInstruction(module, execution_threads,
2316                                  !cse_prevention_only_, &unspecified_dims));
2317   any_changed |= changed;
2318 
2319   // Association of partitionable embedded computations with their parent
2320   // instruction.
2321   ComputationMap computation_map;
2322 
2323   // Instructions that are related through a computation and need to share the
2324   // same sharding.
2325   auto get_related_instructions = [](HloInstruction* inst) {
2326     if (inst->opcode() == HloOpcode::kWhile) {
2327       return std::vector<HloInstruction*>{
2328           inst, inst->while_body()->root_instruction(),
2329           inst->while_body()->parameter_instruction(0),
2330           inst->while_condition()->parameter_instruction(0)};
2331     } else if (inst->opcode() == HloOpcode::kConditional) {
2332       const auto& called_computations = inst->called_computations();
2333       std::vector<HloInstruction*> comps;
2334       comps.reserve(called_computations.size() + 1);
2335       comps.push_back(inst);
2336       for (HloComputation* c : called_computations) {
2337         comps.push_back(c->root_instruction());
2338       }
2339       return comps;
2340     } else {
2341       CHECK(false);
2342     }
2343   };
2344 
2345   // If instruction is a while, or the root or a parameter of a while body,
2346   // then propagate its sharding to the while instruction, to its body root,
2347   // and to its condition parameter.
2348   std::function<void(HloInstruction*, absl::flat_hash_set<HloInstruction*>*)>
2349       maybe_computation_propagation =
2350           [&](HloInstruction* instruction,
2351               absl::flat_hash_set<HloInstruction*>* changed) {
2352             auto propagate_to_instruction = [&](HloInstruction* search_inst) {
2353               auto related_instructions = get_related_instructions(search_inst);
2354               if (absl::c_count(related_instructions, instruction)) {
2355                 for (HloInstruction* inst : related_instructions) {
2356                   if (!inst->has_sharding() ||
2357                       inst->sharding() != instruction->sharding()) {
2358                     VLOG(2) << "Add computation sharding: " << inst->name()
2359                             << " " << instruction->sharding().ToString();
2360                     inst->set_sharding(instruction->sharding());
2361                     changed->insert(inst);
2362                     maybe_computation_propagation(inst, changed);
2363                   }
2364                 }
2365               }
2366             };
2367 
2368             if (instruction->opcode() == HloOpcode::kConditional ||
2369                 instruction->opcode() == HloOpcode::kWhile) {
2370               propagate_to_instruction(instruction);
2371             }
2372 
2373             if (instruction->opcode() == HloOpcode::kParameter ||
2374                 instruction->parent()->root_instruction() == instruction) {
2375               auto it = computation_map.find(instruction->parent());
2376               if (it != computation_map.end()) {
2377                 propagate_to_instruction(it->second);
2378               }
2379             }
2380           };
2381 
2382   for (auto computation : module->computations(execution_threads)) {
2383     for (auto instruction : computation->instructions()) {
2384       if (instruction->opcode() == HloOpcode::kWhile) {
2385         TF_RETURN_IF_ERROR(
2386             CheckAndUpdateDeviceAssignmentsInWhileBody(instruction));
2387       }
2388     }
2389   }
2390 
2391   // Populate computation_map in order to associate while bodies and conditional
2392   // branches with their calling instructions.
2393   for (auto computation : module->computations(execution_threads)) {
2394     for (auto instruction : computation->instructions()) {
2395       if (instruction->opcode() == HloOpcode::kWhile ||
2396           instruction->opcode() == HloOpcode::kConditional) {
2397         // Check if any of the related instructions has a sharding, in which
2398         // case propagate it to the other instructions so that they all share
2399         // the same sharding, in case the user didn't shard all of them. We
2400         // don't check that user shardings are consistent, because such a
2401         // check is already done by HloShardingVerifier.
2402         const HloInstruction* sharded_inst = nullptr;
2403         auto related_instructions = get_related_instructions(instruction);
2404         for (auto inst : related_instructions) {
2405           if (inst->has_sharding()) {
2406             sharded_inst = inst;
2407             break;
2408           }
2409         }
2410         if (sharded_inst != nullptr) {
2411           // Set the same sharding to all the other related instructions.
2412           for (auto inst : related_instructions) {
2413             inst->set_sharding(sharded_inst->sharding());
2414           }
2415         }
2416         if (instruction->opcode() == HloOpcode::kWhile) {
2417           computation_map[instruction->while_body()] = instruction;
2418         } else {
2419           for (HloComputation* c : instruction->called_computations()) {
2420             computation_map[c] = instruction;
2421           }
2422         }
2423       }
2424     }
2425   }
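  // computation_map maps a called computation (a while body or a conditional
  // branch) back to its calling instruction; the while condition computation
  // is not mapped, but its parameter is still handled through
  // get_related_instructions.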
2426 
2427   // Collect all pre-sharded instructions as we aren't allowed to modify their
2428   // sharding.
2429   absl::flat_hash_set<const HloInstruction*> provided_shardings;
2430   for (const HloComputation* computation :
2431        module->computations(execution_threads)) {
2432     for (const HloInstruction* inst : computation->instructions()) {
2433       if (inst->has_sharding()) {
2434         provided_shardings.insert(inst);
2435       }
2436     }
2437   }
2438 
2439   if (!allow_spmd_sharding_propagation_to_output_) {
2440     // Treat the entry computation's root instruction as one with a provided
2441     // sharding, since its sharding has to match the one expected by the host.
2442     provided_shardings.insert(module->entry_computation()->root_instruction());
2443   }
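  // Being in provided_shardings means the passes below will neither assign nor
  // override a sharding for the root (only dims listed in unspecified_dims may
  // still be refined), so the output sharding seen by the host is preserved.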
2444 
2445   // Iterate to a fixpoint that is guaranteed to be reached because we only
2446   // strictly improve the sharding of the graph and it can't be improved
2447   // indefinitely.
2448   int64_t iterations = 0;
2449   auto run_to_fix_point = [&](int64_t aggressiveness) {
2450     absl::flat_hash_set<const HloInstruction*> already_inferred_from_operands;
2451     absl::flat_hash_set<const HloInstruction*> already_inferred_from_users;
2452     bool changed_last_iter = true;
2453     const bool may_merge_partial = is_spmd_ && aggressiveness > 0;
2454     while (changed_last_iter) {
2455       changed_last_iter = false;
2456       int64_t inferred_from_operand_counter = 0;
2457       int64_t inferred_from_user_counter = 0;
2458       int64_t instruction_counter = 0;
2459       int64_t already_sharded_counter = 0;
2460       for (const HloComputation* computation :
2461            module->computations(execution_threads)) {
2462         VLOG(2) << "Consider computation: " << computation->name();
2463         std::vector<HloInstruction*> instructions =
2464             computation->MakeInstructionPostOrder();
2465 
2466         instruction_counter += instructions.size();
2467         already_sharded_counter += absl::c_count_if(
2468             instructions,
2469             [](const HloInstruction* inst) { return inst->has_sharding(); });
2470         auto clear_cache = [&](HloInstruction* hlo,
2471                                HloInstruction* hlo_for_users = nullptr) {
2472           for (auto operand : hlo->operands()) {
2473             already_inferred_from_users.erase(operand);
2474           }
2475           if (hlo_for_users == nullptr) {
2476             hlo_for_users = hlo;
2477           }
2478           for (auto user : hlo_for_users->users()) {
2479             already_inferred_from_operands.erase(user);
2480           }
2481         };
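        // clear_cache is called whenever an instruction's sharding changes:
        // its operands may now infer a better sharding from their users and
        // its users from their operands, so both are dropped from the
        // already-inferred sets and revisited in a later sweep.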
2482         // First iterate the HLO graph in post order taking shardings from
2483         // operands.
2484         for (HloInstruction* instruction : instructions) {
2485           if (already_inferred_from_operands.contains(instruction)) {
2486             continue;
2487           }
2488           if (provided_shardings.contains(instruction)) {
2489             if (!may_merge_partial) {
2490               continue;
2491             }
2492             auto it = unspecified_dims.find(instruction);
2493             HloInstruction* man_conversion_op_after;
2494             if (it != unspecified_dims.end() &&
2495                 InferUnspecifiedDimsFromOperand(instruction, it->second,
2496                                                 &man_conversion_op_after)) {
2497               ++inferred_from_operand_counter;
2498               VLOG(2) << "Refined partial sharding (forward-pass): "
2499                       << instruction->ToString();
2500               clear_cache(instruction, man_conversion_op_after);
2501               already_inferred_from_operands.insert(instruction);
2502               changed_last_iter = true;
2503             }
2504             continue;
2505           }
2506           already_inferred_from_operands.insert(instruction);
2507           if (InferShardingFromOperands(instruction, computation_map,
2508                                         aggressiveness)) {
2509             ++inferred_from_operand_counter;
2510             any_changed = true;
2511             VLOG(2) << "Add sharding (forward-pass): "
2512                     << instruction->ToString();
2513             absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
2514             maybe_computation_propagation(instruction, &changed_in_comp_prop);
2515             clear_cache(instruction);
2516             for (auto hlo : changed_in_comp_prop) {
2517               clear_cache(hlo);
2518             }
2519             changed_last_iter = true;
2520           }
2521         }
2522 
2523         // Then iterate the HLO graph in reverse post order taking shardings
2524         // from users.
2525         for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) {
2526           if ((*it)->IsCustomCall("SPMDFullToShardShape") ||
2527               (*it)->IsCustomCall("SPMDShardToFullShape")) {
2528             // The manual conversion op is processed together with the sharding
2529             // op before it. If the conversion op is removed from the cache, the
2530             // sharding op should also be removed.
2531             if (!already_inferred_from_users.contains(*it)) {
2532               already_inferred_from_users.erase((*it)->operand(0));
2533             }
2534           }
2535           if (already_inferred_from_users.contains(*it)) {
2536             continue;
2537           }
2538           if (provided_shardings.contains(*it)) {
2539             if (!may_merge_partial) {
2540               continue;
2541             }
2542             auto uit = unspecified_dims.find(*it);
2543             HloInstruction* man_conversion_op_after;
2544             if (uit != unspecified_dims.end() &&
2545                 InferUnspecifiedDimsFromUsers(*it, uit->second, aggressiveness,
2546                                               is_spmd_,
2547                                               &man_conversion_op_after)) {
2548               ++inferred_from_user_counter;
2549               VLOG(2) << "Refined partial sharding (backward-pass): "
2550                       << (*it)->ToString();
2551               clear_cache(*it, man_conversion_op_after);
2552               already_inferred_from_users.insert(*it);
2553               if (man_conversion_op_after != nullptr) {
2554                 already_inferred_from_users.insert(man_conversion_op_after);
2555               }
2556               changed_last_iter = true;
2557             }
2558             continue;
2559           }
2560           already_inferred_from_users.insert(*it);
2561           if (InferShardingFromUsers(*it, computation_map, aggressiveness,
2562                                      is_spmd_, sharding_helper_.get())) {
2563             ++inferred_from_user_counter;
2564             any_changed = true;
2565             VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString();
2566             absl::flat_hash_set<HloInstruction*> changed_in_comp_prop;
2567             maybe_computation_propagation(*it, &changed_in_comp_prop);
2568             clear_cache(*it);
2569             for (auto hlo : changed_in_comp_prop) {
2570               clear_cache(hlo);
2571             }
2572             changed_last_iter = true;
2573           }
2574         }
2575       }
2576       VLOG(1) << "Sharding propagation iteration " << iterations << ";";
2577       VLOG(1) << "  total instructions: " << instruction_counter;
2578       VLOG(1) << "  instructions already sharded: " << already_sharded_counter;
2579       VLOG(1) << "  shardings inferred from operands: "
2580               << inferred_from_operand_counter;
2581       VLOG(1) << "  shardings inferred from users: "
2582               << inferred_from_user_counter;
2583       VLOG(1) << "  aggressiveness: " << aggressiveness;
2584       ++iterations;
2585     }
2586     return OkStatus();
2587   };
2588   for (int64_t aggressiveness = 0; aggressiveness < 4; ++aggressiveness) {
2589     TF_RETURN_IF_ERROR(run_to_fix_point(aggressiveness));
2590   }
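  // Four passes with increasing aggressiveness (0..3), each run to its own
  // fixed point. Only levels above 0 (in SPMD mode) may merge partial
  // shardings here; the level is also forwarded to the
  // InferShardingFrom{Operands,Users} helpers, which presumably use it to
  // gate progressively more speculative inferences.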
2591 
2592   // Post-process for CSE prevention.
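  // In cse_prevention_only_ mode the propagated shardings are not meant to
  // survive as regular shardings: CSE-prevention targets that gained a new
  // sharding are re-tagged via SetCSEPreventionSharding, while every other
  // instruction is reverted to its original sharding, or cleared if it had
  // none.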
2593   if (cse_prevention_only_) {
2594     for (auto computation : module->computations(execution_threads)) {
2595       for (auto instruction : computation->instructions()) {
2596         if (!instruction->has_sharding()) {
2597           continue;
2598         }
2599         if (IsCSEPreventionTarget(instruction) && instruction->has_sharding()) {
2600           if (!(*original_sharding).contains(instruction)) {
2601             // Mark the propagated sharding as for CSE prevention.
2602             instruction->set_sharding(
2603                 SetCSEPreventionSharding(instruction->sharding()));
2604           }
2605           continue;
2606         }
2607         auto it = (*original_sharding).find(instruction);
2608         if (it != (*original_sharding).end()) {
2609           // Revert sharding.
2610           instruction->set_sharding(it->second);
2611         } else {
2612           // Clear sharding.
2613           instruction->clear_sharding();
2614         }
2615       }
2616     }
2617   }
2618 
2619   VLOG(1) << "Sharding propagation completed after " << iterations
2620           << " iterations";
2621   return any_changed;
2622 }
2623 
2624 }  // namespace xla
2625