xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/dot_handler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 #include <deque>
18 #include <optional>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/cleanup/cleanup.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
31 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
32 #include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
33 #include "tensorflow/compiler/xla/service/shape_inference.h"
34 #include "tensorflow/compiler/xla/service/sharding_propagation.h"
35 #include "tensorflow/compiler/xla/service/spmd/convolution_handler.h"
36 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
37 #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/compiler/xla/window_util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 
44 namespace xla {
45 namespace spmd {
46 
47 namespace {
48 using hlo_sharding_util::GroupedSharding;
49 }  // namespace
50 
HandleDot(HloInstruction * hlo)51 Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
52   DotConvDimsMapping mapping;
53   const auto& dnums = hlo->dot_dimension_numbers();
54   int64_t next_output_dim = 0;
55   for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) {
56     mapping.batch_dims.emplace_back();
57     mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i);
58     mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i);
59     mapping.batch_dims.back().output = next_output_dim++;
60   }
61   for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) {
62     mapping.contracting_dims.emplace_back();
63     mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i);
64     mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i);
65     mapping.contracting_dims.back().output = -1;
66   }
67   for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) {
68     if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) ||
69         absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) {
70       continue;
71     }
72     mapping.lhs_non_contracting_dims.emplace_back();
73     mapping.lhs_non_contracting_dims.back().lhs = i;
74     mapping.lhs_non_contracting_dims.back().rhs = -1;
75     mapping.lhs_non_contracting_dims.back().output = next_output_dim++;
76   }
77   for (int64_t i = 0; i < hlo->operand(1)->shape().rank(); ++i) {
78     if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) ||
79         absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) {
80       continue;
81     }
82     mapping.rhs_non_contracting_dims.emplace_back();
83     mapping.rhs_non_contracting_dims.back().lhs = -1;
84     mapping.rhs_non_contracting_dims.back().rhs = i;
85     mapping.rhs_non_contracting_dims.back().output = next_output_dim++;
86   }
87   auto create_sharded_dot =
88       [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
89           const Window& conv_window) -> StatusOr<HloInstruction*> {
90     TF_ASSIGN_OR_RETURN(
91         auto sharded_dot_shape,
92         ShapeInference::InferDotOpShape(
93             l->shape(), r->shape(), hlo->dot_dimension_numbers(),
94             /*preferred_element_type=*/hlo->shape().element_type()));
95     return b->AddInstruction(HloInstruction::CreateDot(
96         sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
97         hlo->precision_config()));
98   };
99   return HandleDotHelper(hlo, mapping, create_sharded_dot);
100 }
101 
102 namespace {
103 
104 enum class WindowedEinsumOperand { LHS, RHS };
105 
106 struct WindowedEinsumConfig {
107   WindowedEinsumOperand windowed_op;
108   bool windowed_at_contracting_dims;
109   bool windowed_at_batch_dims;
110   bool operands_sharded_at_contracting_dims;
111 };
112 
113 struct DotDimensionIndexMapping {
114   std::vector<int64_t> lhs_to_rhs_indices;
115   std::vector<int64_t> lhs_to_output_indices;
116   std::vector<int64_t> rhs_to_lhs_indices;
117   std::vector<int64_t> rhs_to_output_indices;
118   std::vector<int64_t> output_to_lhs_indices;
119   std::vector<int64_t> output_to_rhs_indices;
120 };
121 
UpdateDDNums(DotDimensionNumbers * new_ddnums,int64_t reshaped_dim,bool lhs)122 void UpdateDDNums(DotDimensionNumbers* new_ddnums, int64_t reshaped_dim,
123                   bool lhs) {
124   auto update_dims =
125       [&reshaped_dim](tensorflow::protobuf::RepeatedField<int64_t>* dims) {
126         bool add_reshaped_dim = false;
127         if (absl::c_linear_search(*dims, reshaped_dim)) {
128           add_reshaped_dim = true;
129         }
130         for (int64_t i = 0; i < dims->size(); ++i) {
131           auto dim = dims->at(i);
132           if (reshaped_dim <= dim) {
133             dims->Set(i, dim + 1);
134           }
135         }
136         if (add_reshaped_dim) {
137           dims->Add(reshaped_dim);
138         }
139       };
140 
141   if (lhs) {
142     update_dims(new_ddnums->mutable_lhs_contracting_dimensions());
143     update_dims(new_ddnums->mutable_lhs_batch_dimensions());
144   } else {  // rhs
145     update_dims(new_ddnums->mutable_rhs_contracting_dimensions());
146     update_dims(new_ddnums->mutable_rhs_batch_dimensions());
147   }
148 }
149 
GenNewWindow(const HloInstruction * original_dot,const HloInstruction * dot_lhs,const HloInstruction * dot_rhs,int64_t lhs_concat_dim,int64_t rhs_concat_dim,bool windowed_at_contracting_dims,bool windowed_at_batch_dims)150 Window GenNewWindow(const HloInstruction* original_dot,
151                     const HloInstruction* dot_lhs,
152                     const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
153                     int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
154                     bool windowed_at_batch_dims) {
155   auto new_window = original_dot->window();
156   const ConvolutionDimensionNumbers& conv_dnums =
157       original_dot->convolution_dimension_numbers();
158   if (lhs_concat_dim != -1) {
159     for (int64_t i = 0; i < conv_dnums.input_spatial_dimensions_size(); ++i) {
160       if (conv_dnums.input_spatial_dimensions(i) == lhs_concat_dim) {
161         auto wd = new_window.mutable_dimensions(i);
162         auto lhs_size = dot_lhs->shape().dimensions(lhs_concat_dim + 1);
163         if (windowed_at_contracting_dims) {
164           wd->set_size(lhs_size);
165         }
166         if (windowed_at_batch_dims) {
167           wd->set_size(lhs_size);
168           wd->set_padding_low(0);
169           wd->set_padding_high(0);
170           wd->set_stride(std::max<int64_t>(1, lhs_size - 1));
171           wd->set_window_dilation(1);
172           wd->set_base_dilation(lhs_size);
173           wd->set_window_reversal(false);
174         }
175       }
176     }
177   }
178   if (rhs_concat_dim != -1) {
179     for (int64_t i = 0; i < conv_dnums.kernel_spatial_dimensions_size(); ++i) {
180       if (conv_dnums.kernel_spatial_dimensions(i) == rhs_concat_dim &&
181           !windowed_at_contracting_dims && !windowed_at_batch_dims &&
182           lhs_concat_dim == -1) {
183         auto wd = new_window.mutable_dimensions(i);
184         auto rhs_size = dot_rhs->shape().dimensions(rhs_concat_dim + 1);
185         wd->set_size(rhs_size);
186         wd->set_padding_low(rhs_size - 1);
187         wd->set_padding_high(rhs_size - 1);
188       }
189     }
190   }
191   // Add the extra dimension to window.
192   WindowDimension* new_dim = new_window.add_dimensions();
193   if (windowed_at_contracting_dims) {
194     new_dim->set_size(2);
195     new_dim->set_padding_low(0);
196     new_dim->set_padding_high(0);
197     new_dim->set_stride(1);
198     new_dim->set_window_dilation(1);
199     new_dim->set_base_dilation(1);
200     new_dim->set_window_reversal(false);
201   } else if (windowed_at_batch_dims) {
202     new_dim->set_size(2);
203     new_dim->set_padding_low(0);
204     new_dim->set_padding_high(0);
205     new_dim->set_stride(1);  // std::max<int64_t>(1, 2 - 1)
206     new_dim->set_window_dilation(1);
207     new_dim->set_base_dilation(2);
208     new_dim->set_window_reversal(false);
209   } else {
210     if (lhs_concat_dim != -1) {
211       new_dim->set_size(1);
212       new_dim->set_padding_low(0);
213       new_dim->set_padding_high(0);
214       new_dim->set_stride(1);
215       new_dim->set_window_dilation(1);
216       new_dim->set_base_dilation(1);
217       new_dim->set_window_reversal(false);
218     }
219     if (rhs_concat_dim != -1) {
220       new_dim->set_size(2);          // rhs_size
221       new_dim->set_padding_low(1);   // rhs_size - 1
222       new_dim->set_padding_high(1);  // rhs_size - 1
223       new_dim->set_stride(1);
224       new_dim->set_window_dilation(1);
225       new_dim->set_base_dilation(1);
226       new_dim->set_window_reversal(true);
227     }
228   }
229 
230   VLOG(2) << "new_window: " << new_window.ShortDebugString();
231   return new_window;
232 }
233 
GenNewConvDNums(const HloInstruction * original_dot,const HloInstruction * dot_lhs,const HloInstruction * dot_rhs,int64_t lhs_concat_dim,int64_t rhs_concat_dim,bool windowed_at_contracting_dims,bool windowed_at_batch_dims,const std::vector<int64_t> & lhs_to_output_indices,const std::vector<int64_t> & rhs_to_output_indices,const Shape & new_dot_shape)234 ConvolutionDimensionNumbers GenNewConvDNums(
235     const HloInstruction* original_dot, const HloInstruction* dot_lhs,
236     const HloInstruction* dot_rhs, int64_t lhs_concat_dim,
237     int64_t rhs_concat_dim, bool windowed_at_contracting_dims,
238     bool windowed_at_batch_dims,
239     const std::vector<int64_t>& lhs_to_output_indices,
240     const std::vector<int64_t>& rhs_to_output_indices,
241     const Shape& new_dot_shape) {
242   // Generate the new conv dimension numbers.
243   const ConvolutionDimensionNumbers& dnums =
244       original_dot->convolution_dimension_numbers();
245   // Handle the LHS dimension numbers.
246   int64_t input_batch_dimension = dnums.input_batch_dimension();
247   int64_t input_feature_dimension = dnums.input_feature_dimension();
248   std::vector<int64_t> input_spatial_dimensions(
249       dnums.input_spatial_dimensions().begin(),
250       dnums.input_spatial_dimensions().end());
251   if (lhs_concat_dim != -1) {
252     if (lhs_concat_dim <= input_batch_dimension) {
253       input_batch_dimension++;
254     }
255     if (lhs_concat_dim <= input_feature_dimension) {
256       input_feature_dimension++;
257     }
258     for (int64_t i = 0; i < input_spatial_dimensions.size(); ++i) {
259       if (lhs_concat_dim <= input_spatial_dimensions[i]) {
260         input_spatial_dimensions[i]++;
261       }
262     }
263     input_spatial_dimensions.push_back(lhs_concat_dim);
264   }
265   if (rhs_concat_dim != -1 && !windowed_at_contracting_dims &&
266       !windowed_at_batch_dims) {
267     input_spatial_dimensions.push_back(dot_lhs->shape().dimensions_size() - 1);
268   }
269   // Handle the RHS dimension numbers.
270   int64_t kernel_input_feature_dimension =
271       dnums.kernel_input_feature_dimension();
272   int64_t kernel_output_feature_dimension =
273       dnums.kernel_output_feature_dimension();
274   std::vector<int64_t> kernel_spatial_dimensions(
275       dnums.kernel_spatial_dimensions().begin(),
276       dnums.kernel_spatial_dimensions().end());
277   if (rhs_concat_dim != -1) {
278     if (rhs_concat_dim <= kernel_input_feature_dimension) {
279       kernel_input_feature_dimension++;
280     }
281     if (rhs_concat_dim <= kernel_output_feature_dimension) {
282       kernel_output_feature_dimension++;
283     }
284     for (int64_t i = 0; i < kernel_spatial_dimensions.size(); ++i) {
285       if (rhs_concat_dim <= kernel_spatial_dimensions[i]) {
286         kernel_spatial_dimensions[i]++;
287       }
288     }
289     kernel_spatial_dimensions.push_back(rhs_concat_dim);
290   }
291   if (lhs_concat_dim != -1 && !windowed_at_contracting_dims &&
292       !windowed_at_batch_dims) {
293     kernel_spatial_dimensions.push_back(dot_rhs->shape().dimensions_size() - 1);
294   }
295   // Handle the Output dimension numbers.
296   int64_t output_batch_dimension = dnums.output_batch_dimension();
297   int64_t output_feature_dimension = dnums.output_feature_dimension();
298   std::vector<int64_t> output_spatial_dimensions(
299       dnums.output_spatial_dimensions().begin(),
300       dnums.output_spatial_dimensions().end());
301   if (!windowed_at_contracting_dims) {
302     auto output_slice_dim = lhs_concat_dim != -1
303                                 ? lhs_to_output_indices[lhs_concat_dim]
304                                 : rhs_to_output_indices[rhs_concat_dim];
305     if (output_slice_dim <= output_batch_dimension) {
306       output_batch_dimension++;
307     }
308     if (output_slice_dim <= output_feature_dimension) {
309       output_feature_dimension++;
310     }
311     for (int64_t i = 0; i < output_spatial_dimensions.size(); ++i) {
312       if (output_slice_dim <= output_spatial_dimensions[i]) {
313         output_spatial_dimensions[i]++;
314       }
315     }
316     output_spatial_dimensions.push_back(output_slice_dim);
317   } else {
318     output_spatial_dimensions.push_back(new_dot_shape.dimensions_size() - 1);
319   }
320   // Construct the new dot dimension numbers.
321   ConvolutionDimensionNumbers new_dnums;
322   new_dnums.set_input_batch_dimension(input_batch_dimension);
323   new_dnums.set_input_feature_dimension(input_feature_dimension);
324   for (auto dim : input_spatial_dimensions) {
325     new_dnums.add_input_spatial_dimensions(dim);
326   }
327   new_dnums.set_kernel_input_feature_dimension(kernel_input_feature_dimension);
328   new_dnums.set_kernel_output_feature_dimension(
329       kernel_output_feature_dimension);
330   for (auto dim : kernel_spatial_dimensions) {
331     new_dnums.add_kernel_spatial_dimensions(dim);
332   }
333   new_dnums.set_output_batch_dimension(output_batch_dimension);
334   new_dnums.set_output_feature_dimension(output_feature_dimension);
335   for (auto dim : output_spatial_dimensions) {
336     new_dnums.add_output_spatial_dimensions(dim);
337   }
338 
339   return new_dnums;
340 }
341 
ComputeDimensionIndexMapping(const DotConvDimsMapping & dims_mapping,int64_t lhs_rank,int64_t rhs_rank,int64_t output_rank)342 DotDimensionIndexMapping ComputeDimensionIndexMapping(
343     const DotConvDimsMapping& dims_mapping, int64_t lhs_rank, int64_t rhs_rank,
344     int64_t output_rank) {
345   std::vector<int64_t> lhs_to_rhs_indices(lhs_rank, -1);
346   std::vector<int64_t> lhs_to_output_indices(lhs_rank, -1);
347   std::vector<int64_t> rhs_to_lhs_indices(rhs_rank, -1);
348   std::vector<int64_t> rhs_to_output_indices(rhs_rank, -1);
349   std::vector<int64_t> output_to_lhs_indices(output_rank, -1);
350   std::vector<int64_t> output_to_rhs_indices(output_rank, -1);
351   auto populate_indices_mapping =
352       [&](const DotConvDimsMapping::DimsMapping& mapping) {
353         if (mapping.lhs >= 0) {
354           lhs_to_rhs_indices[mapping.lhs] = mapping.rhs;
355           lhs_to_output_indices[mapping.lhs] = mapping.output;
356         }
357         if (mapping.rhs >= 0) {
358           rhs_to_lhs_indices[mapping.rhs] = mapping.lhs;
359           rhs_to_output_indices[mapping.rhs] = mapping.output;
360         }
361         if (mapping.output >= 0) {
362           output_to_lhs_indices[mapping.output] = mapping.lhs;
363           output_to_rhs_indices[mapping.output] = mapping.rhs;
364         }
365       };
366   for (const auto& mapping : dims_mapping.batch_dims) {
367     populate_indices_mapping(mapping);
368   }
369   for (const auto& mapping : dims_mapping.contracting_dims) {
370     populate_indices_mapping(mapping);
371   }
372   for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) {
373     populate_indices_mapping(mapping);
374   }
375   for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) {
376     populate_indices_mapping(mapping);
377   }
378   for (const auto& mapping : dims_mapping.conv_spatial_dims) {
379     populate_indices_mapping(mapping);
380   }
381   return DotDimensionIndexMapping{lhs_to_rhs_indices,    lhs_to_output_indices,
382                                   rhs_to_lhs_indices,    rhs_to_output_indices,
383                                   output_to_lhs_indices, output_to_rhs_indices};
384 }
385 
GetPartitionGroupsForReplication(const HloSharding & sharding,absl::Span<const int64_t> replication_dims)386 std::vector<std::vector<int64_t>> GetPartitionGroupsForReplication(
387     const HloSharding& sharding, absl::Span<const int64_t> replication_dims) {
388   int64_t group_size = 1;
389   for (int64_t i : replication_dims) {
390     group_size *= sharding.tile_assignment().dim(i);
391   }
392   std::vector<std::vector<int64_t>> partition_groups(
393       sharding.tile_assignment().num_elements() / group_size);
394   sharding.tile_assignment().Each(
395       [&](absl::Span<const int64_t> indices, int64_t partition) {
396         int64_t group_id = 0;
397         for (int64_t i = 0; i < indices.size(); ++i) {
398           if (!absl::c_linear_search(replication_dims, i)) {
399             group_id *= sharding.tile_assignment().dim(i);
400             group_id += indices[i];
401           }
402         }
403         partition_groups[group_id].push_back(partition);
404       });
405   return partition_groups;
406 }
407 
408 // Returns true iff all of the following conditions are simultaneously true:
409 // 1) 'lhs/rhs_sharding' have different partition counts on a dimension in
410 //    'dims'.
411 // 2) 'lhs/rhs_sharding' BOTH have partitions on at least one dimension in
412 //    'dims'.
RequiresTransposeSharding(const HloSharding & lhs_sharding,const HloSharding & rhs_sharding,const std::vector<DotConvDimsMapping::DimsMapping> & dims)413 bool RequiresTransposeSharding(
414     const HloSharding& lhs_sharding, const HloSharding& rhs_sharding,
415     const std::vector<DotConvDimsMapping::DimsMapping>& dims) {
416   int64_t lhs_total_partitions = 1;
417   int64_t rhs_total_partitions = 1;
418   bool has_different_lhs_rhs_dim_sharding = false;
419   for (const auto& dim : dims) {
420     int64_t lhs_dim_partitions = lhs_sharding.tile_assignment().dim(dim.lhs);
421     lhs_total_partitions *= lhs_dim_partitions;
422 
423     int64_t rhs_dim_partitions = rhs_sharding.tile_assignment().dim(dim.rhs);
424     rhs_total_partitions *= rhs_dim_partitions;
425 
426     if (lhs_dim_partitions != rhs_dim_partitions) {
427       has_different_lhs_rhs_dim_sharding = true;
428     }
429   }
430   return lhs_total_partitions > 1 && rhs_total_partitions > 1 &&
431          has_different_lhs_rhs_dim_sharding;
432 }
433 
GetWindowedEinsumConfiguration(int64_t num_partitions,int64_t output_lhs_non_contracting_partitions,int64_t output_rhs_non_contracting_partitions,int64_t rhs_contracting_partitions,int64_t rhs_non_contracting_partitions,int64_t rhs_batch_partitions,int64_t lhs_contracting_partitions,int64_t lhs_non_contracting_partitions,int64_t lhs_batch_partitions,int64_t rhs_shape_size,int64_t lhs_shape_size,int64_t output_shape_size,const SpmdPartitionerOptions & options,const std::optional<HloSharding> & output_sharding_transposed_to_match_lhs,const std::optional<HloSharding> & output_sharding_transposed_to_match_rhs,const std::optional<HloSharding> & lhs_sharding_transposed_to_match_rhs,const std::optional<HloSharding> & rhs_sharding_transposed_to_match_lhs,const HloSharding & lhs_sharding,const HloSharding & rhs_sharding,const Window & conv_window,const DotConvDimsMapping & dims_mapping,int64_t max_iterations=INT64_MAX,const HloInstruction * original_hlo=nullptr,PartitionedHlo * partitioned_lhs=nullptr,PartitionedHlo * partitioned_rhs=nullptr,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot={},SpmdBuilder * b=nullptr,HloModule * module=nullptr,SpmdPartitioningVisitor * visitor=nullptr)434 std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
435     int64_t num_partitions, int64_t output_lhs_non_contracting_partitions,
436     int64_t output_rhs_non_contracting_partitions,
437     int64_t rhs_contracting_partitions, int64_t rhs_non_contracting_partitions,
438     int64_t rhs_batch_partitions, int64_t lhs_contracting_partitions,
439     int64_t lhs_non_contracting_partitions, int64_t lhs_batch_partitions,
440     int64_t rhs_shape_size, int64_t lhs_shape_size, int64_t output_shape_size,
441     const SpmdPartitionerOptions& options,
442     const std::optional<HloSharding>& output_sharding_transposed_to_match_lhs,
443     const std::optional<HloSharding>& output_sharding_transposed_to_match_rhs,
444     const std::optional<HloSharding>& lhs_sharding_transposed_to_match_rhs,
445     const std::optional<HloSharding>& rhs_sharding_transposed_to_match_lhs,
446     const HloSharding& lhs_sharding, const HloSharding& rhs_sharding,
447     const Window& conv_window, const DotConvDimsMapping& dims_mapping,
448     int64_t max_iterations = INT64_MAX,
449     const HloInstruction* original_hlo = nullptr,
450     PartitionedHlo* partitioned_lhs = nullptr,
451     PartitionedHlo* partitioned_rhs = nullptr,
452     const std::function<StatusOr<HloInstruction*>(
453         HloInstruction*, HloInstruction*, SpmdBuilder*,
454         const Window& conv_window)>& create_sharded_dot = {},
455     SpmdBuilder* b = nullptr, HloModule* module = nullptr,
456     SpmdPartitioningVisitor* visitor = nullptr) {
457   if (num_partitions > max_iterations) {
458     return std::nullopt;
459   }
460 
461   const HloInstruction* lhs = nullptr;
462   const HloInstruction* rhs = nullptr;
463   if (original_hlo) {
464     lhs = original_hlo->operand(0);
465     rhs = original_hlo->operand(1);
466   }
467 
468   // Determine if any of the users users have the same shardings that can allow
469   // reuse of the resharding for the operand with original_hlo.
470   auto check_users_sharding = [original_hlo](
__anon991ccdc40702( const HloInstruction* to_loop_over) 471                                   const HloInstruction* to_loop_over) {
472     if (to_loop_over->users().size() <= 1) {
473       return true;
474     }
475     constexpr int kAggressiveness = 3;
476     std::optional<HloSharding> original_ideal_sharding =
477         ShardingPropagation::GetShardingFromUser(*to_loop_over, *original_hlo,
478                                                  kAggressiveness,
479                                                  /*is_spmd=*/true);
480     // Default to perform collective matmul if GetShardingFromUser() couldn't
481     // determine the sharding.
482     if (!original_ideal_sharding) {
483       return true;
484     }
485     for (const HloInstruction* user : to_loop_over->users()) {
486       if (user == original_hlo) {
487         continue;
488       }
489       std::optional<HloSharding> from_user =
490           ShardingPropagation::GetShardingFromUser(*to_loop_over, *user,
491                                                    kAggressiveness,
492                                                    /*is_spmd=*/true);
493       // Could't determine sharding. Skip to next one and pretend it wouldn't
494       // share the resharding.
495       if (!from_user) {
496         continue;
497       }
498       // This user doesn't require resharding, so even if has different sharding
499       // than original_hlo its ok to do collective matmul.
500       if (*from_user == to_loop_over->sharding()) {
501         continue;
502       }
503       // Same sharding needed, so we would share the resharding. Do not do
504       // collective matmul.
505       if (*original_ideal_sharding == *from_user) {
506         return false;
507       }
508     }
509     return true;
510   };
511 
512   // Disable windowed einsum when the overheads may overweigh the benefits.
513   // Specifically, when max(computation time, communication time after
514   // decomposition) + extra prologue or epilogue collecitve permute is longer
515   // than the sum of computation time and the original communication time
516   // which can use more communication links. This is checked with the premise
517   // that communication/computation is large enough. For super small
518   // communication/computation generated by unit tests, we always allow windowed
519   // einsum to have meaningful unit tests.
__anon991ccdc40802(bool lhs_needs_ag, bool rhs_needs_ag) 520   auto disable_windowed_einsum = [&](bool lhs_needs_ag, bool rhs_needs_ag) {
521     if (visitor == nullptr) {
522       return false;
523     }
524 
525     double computation_time_in_ms = 0.0;
526     double communication_time_in_ms = 0.0;
527     HloInstruction* dot;
528     HloInstruction* collective;
529     if (lhs_needs_ag || rhs_needs_ag) {
530       CHECK(!lhs_needs_ag || !rhs_needs_ag);
531       auto new_lhs = lhs_needs_ag
532                          ? PartitionedHlo(partitioned_lhs->hlo(),
533                                           partitioned_lhs->base_shape(),
534                                           partitioned_lhs->state())
535                                .Reshard(HloSharding::Replicate())
536                          : *partitioned_lhs;
537       auto new_rhs = rhs_needs_ag
538                          ? PartitionedHlo(partitioned_rhs->hlo(),
539                                           partitioned_rhs->base_shape(),
540                                           partitioned_rhs->state())
541                                .Reshard(HloSharding::Replicate())
542                          : *partitioned_rhs;
543       dot = create_sharded_dot(new_lhs.hlo(), new_rhs.hlo(), b, conv_window)
544                 .ValueOrDie();
545       computation_time_in_ms = visitor->GetComputationTimeInMilliSec(dot);
546 
547       collective = lhs_needs_ag ? new_lhs.hlo() : new_rhs.hlo();
548       while (collective->opcode() != HloOpcode::kAllGather &&
549              collective->opcode() != HloOpcode::kAllReduce &&
550              collective->operand_count() > 0 &&
551              collective != (lhs_needs_ag ? partitioned_lhs->hlo()
552                                          : partitioned_rhs->hlo())) {
553         collective = collective->mutable_operand(0);
554       }
555       if (collective->opcode() == HloOpcode::kAllGather ||
556           collective->opcode() == HloOpcode::kAllReduce) {
557         communication_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
558             ShapeUtil::ByteSizeOf(collective->shape()),
559             collective->replica_groups());
560       }
561     } else {
562       auto new_lhs =
563           PartitionedHlo(partitioned_lhs->hlo(), partitioned_lhs->base_shape(),
564                          partitioned_lhs->state());
565       auto new_rhs =
566           PartitionedHlo(partitioned_rhs->hlo(), partitioned_rhs->base_shape(),
567                          partitioned_rhs->state());
568 
569       // Check if contracting dimension sharding requires lhs/rhs resharding.
570       if (RequiresTransposeSharding(lhs->sharding(), rhs->sharding(),
571                                     dims_mapping.contracting_dims) &&
572           rhs_sharding_transposed_to_match_lhs.has_value() &&
573           lhs_sharding_transposed_to_match_rhs.has_value()) {
574         if (ShapeSizeInBytes(lhs->shape()) < ShapeSizeInBytes(rhs->shape())) {
575           new_lhs = new_lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
576         } else {
577           new_rhs = new_rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
578         }
579       }
580       new_lhs = new_lhs.PadWithZero();
581       new_rhs = new_rhs.PadWithZero();
582 
583       dot = create_sharded_dot(new_lhs.hlo(), new_rhs.hlo(), b, conv_window)
584                 .ValueOrDie();
585       computation_time_in_ms = visitor->GetComputationTimeInMilliSec(dot);
586 
587       std::vector<int64_t> lhs_contracting_dims;
588       lhs_contracting_dims.reserve(new_lhs.base_shape().rank());
589       for (const auto& cd : dims_mapping.contracting_dims) {
590         lhs_contracting_dims.push_back(cd.lhs);
591       }
592       collective = new_lhs.state().partitioner->AllReduceAlongShardingDims(
593           b, dot, new_lhs.sharding(), new_lhs.state().next_channel_id,
594           lhs_contracting_dims, new_lhs.state().collective_ops_creator,
595           MakeBinaryAdd(dot->shape().element_type(), module));
596       communication_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
597           ShapeUtil::ByteSizeOf(dot->shape()), collective->replica_groups());
598     }
599 
600     VLOG(2) << "collective: " << collective->ToString() << "\n"
601             << "dot: " << dot->ToString() << "\n"
602             << "num_partitions: " << num_partitions << "\n"
603             << "computation_time_in_ms: " << computation_time_in_ms
604             << " communication_time_in_ms: " << communication_time_in_ms;
605     double extra_collective_permute_time = 0.0;
606     if (communication_time_in_ms != 0.0) {
607       extra_collective_permute_time =
608           communication_time_in_ms *
609           visitor->GetCommunicationMultiplier(collective->replica_groups()) *
610           2 / num_partitions;
611     }
612     if (communication_time_in_ms > 1e-5 &&
613         (std::max(
614              computation_time_in_ms,
615              communication_time_in_ms * visitor->GetCommunicationMultiplier(
616                                             collective->replica_groups())) +
617          extra_collective_permute_time) >=
618             (computation_time_in_ms + communication_time_in_ms)) {
619       return true;
620     } else {
621       return false;
622     }
623   };
624 
625   if (output_lhs_non_contracting_partitions == num_partitions &&
626       output_sharding_transposed_to_match_lhs == lhs_sharding &&
627       rhs_shape_size >=
628           options.threshold_for_windowed_einsum_mib * 1024 * 1024 &&
629       (!rhs || check_users_sharding(rhs)) &&
630       !disable_windowed_einsum(/*lhs_needs_ag=*/false, /*rhs_needs_ag=*/true)) {
631     if (rhs_contracting_partitions == num_partitions) {
632       return WindowedEinsumConfig{
633           /*windowed_op=*/WindowedEinsumOperand::RHS,
634           /*windowed_at_contracting_dims*/ true,
635           /*windowed_at_batch_dims=*/false,
636           /*operands_sharded_at_contracting_dims=*/false};
637     }
638     if (rhs_non_contracting_partitions == num_partitions) {
639       return WindowedEinsumConfig{
640           /*windowed_op=*/WindowedEinsumOperand::RHS,
641           /*windowed_at_contracting_dims*/ false,
642           /*windowed_at_batch_dims=*/false,
643           /*operands_sharded_at_contracting_dims=*/false};
644     }
645     if (rhs_batch_partitions == num_partitions) {
646       return WindowedEinsumConfig{
647           /*windowed_op=*/WindowedEinsumOperand::RHS,
648           /*windowed_at_contracting_dims*/ false,
649           /*windowed_at_batch_dims=*/true,
650           /*operands_sharded_at_contracting_dims=*/false};
651     }
652   }
653   if (output_rhs_non_contracting_partitions == num_partitions &&
654       output_sharding_transposed_to_match_rhs == rhs_sharding &&
655       lhs_shape_size >=
656           options.threshold_for_windowed_einsum_mib * 1024 * 1024 &&
657       (!lhs || check_users_sharding(lhs)) &&
658       !disable_windowed_einsum(/*lhs_needs_ag=*/true, /*rhs_needs_ag=*/false)) {
659     if (lhs_contracting_partitions == num_partitions) {
660       return WindowedEinsumConfig{
661           /*windowed_op=*/WindowedEinsumOperand::LHS,
662           /*windowed_at_contracting_dims*/ true,
663           /*windowed_at_batch_dims=*/false,
664           /*operands_sharded_at_contracting_dims=*/false};
665     }
666     if (lhs_non_contracting_partitions == num_partitions) {
667       return WindowedEinsumConfig{
668           /*windowed_op=*/WindowedEinsumOperand::LHS,
669           /*windowed_at_contracting_dims*/ false,
670           /*windowed_at_batch_dims=*/false,
671           /*operands_sharded_at_contracting_dims=*/false};
672     }
673     if (lhs_batch_partitions == num_partitions) {
674       return WindowedEinsumConfig{
675           /*windowed_op=*/WindowedEinsumOperand::LHS,
676           /*windowed_at_contracting_dims*/ false,
677           /*windowed_at_batch_dims=*/true,
678           /*operands_sharded_at_contracting_dims=*/false};
679     }
680   }
681   if (lhs_contracting_partitions == rhs_contracting_partitions &&
682       lhs_contracting_partitions == num_partitions &&
683       (output_lhs_non_contracting_partitions == num_partitions ||
684        output_rhs_non_contracting_partitions == num_partitions) &&
685       output_shape_size >=
686           options.threshold_for_windowed_einsum_mib * 1024 * 1024 &&
687       !disable_windowed_einsum(/*lhs_needs_ag=*/false,
688                                /*rhs_needs_ag=*/false)) {
689     if (output_lhs_non_contracting_partitions == num_partitions) {
690       return WindowedEinsumConfig{
691           /*windowed_op=*/WindowedEinsumOperand::RHS,
692           /*windowed_at_contracting_dims*/ false,
693           /*windowed_at_batch_dims=*/false,
694           /*operands_sharded_at_contracting_dims=*/true};
695     }
696     if (output_rhs_non_contracting_partitions == num_partitions) {
697       return WindowedEinsumConfig{
698           /*windowed_op=*/WindowedEinsumOperand::LHS,
699           /*windowed_at_contracting_dims*/ false,
700           /*windowed_at_batch_dims=*/false,
701           /*operands_sharded_at_contracting_dims=*/true};
702     }
703   }
704   return std::nullopt;
705 }
706 
GetLoopReplicaGroups(HloInstruction * while_loop)707 std::vector<ReplicaGroup> GetLoopReplicaGroups(HloInstruction* while_loop) {
708   std::vector<ReplicaGroup> groups;
709   for (auto inst : while_loop->while_body()->instructions()) {
710     if (inst->opcode() == HloOpcode::kCollectivePermute) {
711       std::vector<std::pair<int64_t, int64_t>> st_pairs =
712           inst->source_target_pairs();
713       std::vector<int64_t> source_index(st_pairs.size());
714       for (int64_t i = 0; i < st_pairs.size(); ++i) {
715         source_index[st_pairs[i].first] = i;
716       }
717 
718       absl::flat_hash_set<int64_t> visited;
719       for (int64_t i = 0; i < st_pairs.size(); ++i) {
720         if (visited.contains(st_pairs[i].first)) {
721           continue;
722         }
723         std::vector<int64_t> replica_group;
724         int64_t source = st_pairs[i].first;
725         int64_t target = st_pairs[i].second;
726         replica_group.push_back(source);
727         replica_group.push_back(target);
728         visited.insert(source);
729         visited.insert(target);
730         while (target != source) {
731           target = st_pairs[source_index[target]].second;
732           if (target != source) {
733             replica_group.push_back(target);
734             visited.insert(target);
735           }
736         }
737         absl::c_sort(replica_group);
738         groups.emplace_back();
739         for (auto id : replica_group) {
740           groups.back().add_replica_ids(id);
741         }
742       }
743 
744       VLOG(3) << "while loop: " << while_loop->name()
745               << ", replica groups: " << ReplicaGroupsToString(groups);
746       break;
747     }
748   }
749   return groups;
750 }
751 
752 // We use a recursive approach where sets of matching dimensions are recognized
753 // one at a time. The base shapes and shardings can be changed during the
754 // recursion as we group devices together. So refer to the passed in shapes and
755 // shardings for inputs and output, and do not use shape inference.
756 
PartitionBaseCase(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64_t num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,int64_t lhs_batch_partitions,int64_t rhs_batch_partitions,int64_t output_batch_partitions,int64_t lhs_contracting_partitions,int64_t rhs_contracting_partitions,int64_t lhs_non_contracting_partitions,int64_t rhs_non_contracting_partitions,int64_t output_lhs_non_contracting_partitions,int64_t output_rhs_non_contracting_partitions,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,bool may_reshard_without_detecting_match,SpmdPartitioningVisitor * visitor)757 StatusOr<HloInstruction*> PartitionBaseCase(
758     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
759     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
760     int64_t num_partitions,
761     const std::function<StatusOr<HloInstruction*>(
762         HloInstruction*, HloInstruction*, SpmdBuilder*,
763         const Window& conv_window)>& create_sharded_dot,
764     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
765     int64_t lhs_batch_partitions, int64_t rhs_batch_partitions,
766     int64_t output_batch_partitions, int64_t lhs_contracting_partitions,
767     int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
768     int64_t rhs_non_contracting_partitions,
769     int64_t output_lhs_non_contracting_partitions,
770     int64_t output_rhs_non_contracting_partitions,
771     const SpmdPartitionerOptions& options, SpmdBuilder* b,
772     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
773         windowed_dot_general_loops,
774     bool may_reshard_without_detecting_match,
775     SpmdPartitioningVisitor* visitor) {
776   const HloSharding& lhs_sharding = lhs.sharding();
777   const HloSharding& rhs_sharding = rhs.sharding();
778   if (lhs_sharding.ReplicateOnLastTileDim() ||
779       rhs_sharding.ReplicateOnLastTileDim() ||
780       output_sharding.ReplicateOnLastTileDim()) {
781     return nullptr;
782   }
783   DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
784       dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
785       output_base_shape.rank());
786   auto lhs_sharding_transposed_to_match_rhs =
787       hlo_sharding_util::TransposeShardingWithCollapsedDims(
788           lhs_sharding, indices_map.lhs_to_rhs_indices,
789           indices_map.rhs_to_lhs_indices);
790   auto rhs_sharding_transposed_to_match_lhs =
791       hlo_sharding_util::TransposeShardingWithCollapsedDims(
792           rhs_sharding, indices_map.rhs_to_lhs_indices,
793           indices_map.lhs_to_rhs_indices);
794   auto lhs_sharding_transposed_to_match_output =
795       hlo_sharding_util::TransposeShardingWithCollapsedDims(
796           lhs_sharding, indices_map.lhs_to_output_indices,
797           indices_map.output_to_lhs_indices);
798   auto rhs_sharding_transposed_to_match_output =
799       hlo_sharding_util::TransposeShardingWithCollapsedDims(
800           rhs_sharding, indices_map.rhs_to_output_indices,
801           indices_map.output_to_rhs_indices);
802   auto output_sharding_transposed_to_match_lhs =
803       hlo_sharding_util::TransposeShardingWithCollapsedDims(
804           output_sharding, indices_map.output_to_lhs_indices,
805           indices_map.lhs_to_output_indices);
806   auto output_sharding_transposed_to_match_rhs =
807       hlo_sharding_util::TransposeShardingWithCollapsedDims(
808           output_sharding, indices_map.output_to_rhs_indices,
809           indices_map.rhs_to_output_indices);
810 
811   // LHS and RHS are partitioned the same way and only partitioned in batch
812   // dimensions.
813   if (lhs_batch_partitions == rhs_batch_partitions &&
814       rhs_batch_partitions == num_partitions &&
815       lhs_sharding_transposed_to_match_rhs == rhs_sharding) {
816     TF_ASSIGN_OR_RETURN(
817         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
818     dot->set_sharding(*lhs_sharding_transposed_to_match_output);
819     return PartitionedHlo(dot, output_base_shape, lhs.state())
820         .Reshard(output_sharding)
821         .hlo();
822   }
823 
824   // Try emit batch-partitioned einsum with one operand resharded. Returns
825   // partitioned HLO or nullptr if the attempt fails. If
826   // may_reshard_with_allreduce is false, reshard must be done using
827   // all-to-all/collective-permute; otherwise this attempt fails.
828   auto try_emit_output_batch_partitioned_einsum_with_reshard =
829       [&](bool may_reshard_with_allreduce) -> StatusOr<HloInstruction*> {
830     // LHS and output are batch partitioned in the same way.
831     if (lhs_batch_partitions == num_partitions &&
832         output_batch_partitions == num_partitions &&
833         lhs_sharding_transposed_to_match_output == output_sharding) {
834       if (!may_reshard_with_allreduce &&
835           !CanReshardWithCollectivePermute(
836               rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) &&
837           !GetReshardAllToAllSourceTargetDims(
838               rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
839         return nullptr;
840       }
841       auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
842       TF_ASSIGN_OR_RETURN(
843           auto dot,
844           create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b, conv_window));
845       return dot;
846     }
847     // RHS and output are batch partitioned in the same way.
848     if (rhs_batch_partitions == num_partitions &&
849         output_batch_partitions == num_partitions &&
850         rhs_sharding_transposed_to_match_output == output_sharding) {
851       if (!may_reshard_with_allreduce &&
852           !CanReshardWithCollectivePermute(
853               lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) &&
854           !GetReshardAllToAllSourceTargetDims(
855               lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
856         return nullptr;
857       }
858       auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
859       TF_ASSIGN_OR_RETURN(
860           auto dot,
861           create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b, conv_window));
862       return dot;
863     }
864     return nullptr;
865   };
866 
867   {
868     // Try batch-parallel by resharding one operand, and not using all-reduce.
869     TF_ASSIGN_OR_RETURN(
870         HloInstruction * partitioned_dot,
871         try_emit_output_batch_partitioned_einsum_with_reshard(false));
872     if (partitioned_dot) {
873       return partitioned_dot;
874     }
875   }
876 
877   // Try to emit windowed DotGeneral when one operand is partitioned in the same
878   // way as the output along non-contracting dimensions, but the other operand
879   // is tiled in other dimensions. Or both operands are partitioned in the same
880   // way along contracting dimensions, but the output is partitioned along
881   // non-contracting dimensions.
882   auto emit_windowed_dot_general =
883       [&](const WindowedEinsumConfig& einsum_config)
884       -> StatusOr<HloInstruction*> {
885     CHECK(!einsum_config.windowed_at_batch_dims ||
886           !einsum_config.windowed_at_contracting_dims);
887     const bool windowed_at_batch_dims = einsum_config.windowed_at_batch_dims;
888     const bool windowed_at_contracting_dims =
889         einsum_config.windowed_at_contracting_dims;
890     const bool operands_sharded_at_contracting_dims =
891         einsum_config.operands_sharded_at_contracting_dims;
892     auto unpadded_result_buffer_shape =
893         MakePartitionedShape(output_base_shape, output_sharding);
894     auto padded_result_buffer_shape = unpadded_result_buffer_shape;
895     const bool windowed_op_is_lhs =
896         einsum_config.windowed_op == WindowedEinsumOperand::LHS;
897     // For windowing at batch/non-contracting dims, we produce the result one
898     // partition at a time, so we need to pad the shape in case of uneven
899     // partitioning in order to make dynamic-update-slice in-bound.
900     if (!windowed_at_contracting_dims &&
901         !operands_sharded_at_contracting_dims) {
902       padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning(
903           padded_result_buffer_shape,
904           windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
905                              : *rhs_sharding_transposed_to_match_output);
906     }
907     // Mask the padding area of the windowed operand with zero if there is
908     // uneven partitioning.
909     if (windowed_at_contracting_dims) {
910       auto& to_mask = windowed_op_is_lhs ? lhs : rhs;
911       to_mask = to_mask.PadWithZero();
912     }
913     if (operands_sharded_at_contracting_dims) {
914       lhs = lhs.PadWithZero();
915       rhs = rhs.PadWithZero();
916     }
917 
918     // Check if contracting dimension sharding requires lhs/rhs resharding.
919     if (RequiresTransposeSharding(lhs.hlo()->sharding(), rhs.hlo()->sharding(),
920                                   dims_mapping.contracting_dims) &&
921         rhs_sharding_transposed_to_match_lhs.has_value() &&
922         lhs_sharding_transposed_to_match_rhs.has_value()) {
923       if (ShapeSizeInBytes(lhs.hlo()->shape()) <
924           ShapeSizeInBytes(rhs.hlo()->shape())) {
925         lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero();
926       } else {
927         rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero();
928       }
929     }
930 
931     // Get slice sharding, sharding dim, and lhs/rhs concat dim.
932     const HloSharding* slice_sharding;
933     if (operands_sharded_at_contracting_dims) {
934       slice_sharding = windowed_op_is_lhs
935                            ? &*output_sharding_transposed_to_match_rhs
936                            : &*output_sharding_transposed_to_match_lhs;
937     } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
938       slice_sharding = windowed_op_is_lhs
939                            ? &*lhs_sharding_transposed_to_match_rhs
940                            : &*rhs_sharding_transposed_to_match_lhs;
941     } else {
942       slice_sharding = windowed_op_is_lhs
943                            ? &*lhs_sharding_transposed_to_match_output
944                            : &*rhs_sharding_transposed_to_match_output;
945     }
946     CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
947              num_partitions);
948     int64_t slice_sharding_dim = -1;
949     for (int64_t i = 0; i < slice_sharding->tile_assignment().num_dimensions();
950          ++i) {
951       if (slice_sharding->tile_assignment().dim(i) > 1) {
952         slice_sharding_dim = i;
953         break;
954       }
955     }
956     int64_t lhs_concat_dim = -1;
957     int64_t rhs_concat_dim = -1;
958     if (operands_sharded_at_contracting_dims) {
959       if (windowed_op_is_lhs) {
960         rhs_concat_dim = slice_sharding_dim;
961       } else {
962         lhs_concat_dim = slice_sharding_dim;
963       }
964     } else if (windowed_at_contracting_dims || windowed_at_batch_dims) {
965       lhs_concat_dim = windowed_op_is_lhs
966                            ? indices_map.rhs_to_lhs_indices[slice_sharding_dim]
967                            : slice_sharding_dim;
968       rhs_concat_dim = windowed_op_is_lhs
969                            ? slice_sharding_dim
970                            : indices_map.lhs_to_rhs_indices[slice_sharding_dim];
971     } else {
972       if (windowed_op_is_lhs) {
973         lhs_concat_dim = indices_map.output_to_lhs_indices[slice_sharding_dim];
974       } else {
975         rhs_concat_dim = indices_map.output_to_rhs_indices[slice_sharding_dim];
976       }
977     }
978 
979     auto lhs_hlo = lhs.hlo();
980     auto rhs_hlo = rhs.hlo();
981     // Reshape lhs and rhs before the loop for bidirectional communication case.
982     if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
983       if (lhs_concat_dim != -1 && windowed_op_is_lhs &&
984           !operands_sharded_at_contracting_dims) {
985         std::vector<int64_t> reshaped_dims(
986             lhs_hlo->shape().dimensions().begin(),
987             lhs_hlo->shape().dimensions().end());
988         reshaped_dims.insert(reshaped_dims.begin() + lhs_concat_dim, 1);
989         lhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
990             ShapeUtil::MakeShape(lhs_hlo->shape().element_type(),
991                                  reshaped_dims),
992             lhs_hlo));
993       }
994       if (rhs_concat_dim != -1 && !windowed_op_is_lhs &&
995           !operands_sharded_at_contracting_dims) {
996         std::vector<int64_t> reshaped_dims(
997             rhs_hlo->shape().dimensions().begin(),
998             rhs_hlo->shape().dimensions().end());
999         reshaped_dims.insert(reshaped_dims.begin() + rhs_concat_dim, 1);
1000         rhs_hlo = b->AddInstruction(HloInstruction::CreateReshape(
1001             ShapeUtil::MakeShape(rhs_hlo->shape().element_type(),
1002                                  reshaped_dims),
1003             rhs_hlo));
1004       }
1005     }
1006 
1007     auto result_buffer = CreateZero(padded_result_buffer_shape, b);
1008     auto extra_buffer =
1009         (!(options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1010          operands_sharded_at_contracting_dims)
1011             ? CreateZero(padded_result_buffer_shape, b)
1012         : windowed_op_is_lhs ? lhs_hlo
1013                              : rhs_hlo;
1014 
1015     if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0 &&
1016         !operands_sharded_at_contracting_dims) {
1017       std::vector<std::pair<int64_t, int64_t>> pre_sd_pairs(num_partitions);
1018       for (int64_t source = 0; source < num_partitions; ++source) {
1019         // 0 -> 1, 1 -> 2, 2 -> 3, ...
1020         pre_sd_pairs[source] = {source, (source + 1) % num_partitions};
1021       }
1022       extra_buffer =
1023           lhs.state()
1024               .collective_ops_creator.create_cross_partition_collective_permute(
1025                   b, extra_buffer, pre_sd_pairs,
1026                   (*lhs.state().next_channel_id)++);
1027     }
1028 
1029     auto iteration = b->AddInstruction(
1030         HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32_t>(0)));
1031 
1032     // Create a while loop that computes one window per iteration. During each
1033     // iteration, each partition sends its input window to its neighbor using
1034     // collective-permute for the next iteration.
1035     SpmdBuilder body_b("windowed_dot_general_body", original_hlo);
1036 
1037     // Generate partial results used by bidirectional algorithm.
1038     auto get_partial_bid_results =
1039         [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
1040             HloInstruction* extra_inout, HloInstruction* cw_cp_output,
1041             HloInstruction* i) -> StatusOr<std::vector<HloInstruction*>> {
1042       auto partition_id =
1043           lhs.state().collective_ops_creator.create_partition_id(&body_b);
1044       auto partition_count =
1045           body_b.AddInstruction(HloInstruction::CreateConstant(
1046               LiteralUtil::CreateR0<uint32_t>(num_partitions)));
1047       auto ccw_data_partition_id =
1048           body_b.AddInstruction(HloInstruction::CreateBinary(
1049               i->shape(), HloOpcode::kAdd, i, partition_id));
1050       auto cw_data_partition_id =
1051           body_b.AddInstruction(HloInstruction::CreateBinary(
1052               i->shape(), HloOpcode::kAdd, partition_count, partition_id));
1053       if (operands_sharded_at_contracting_dims) {
1054         ccw_data_partition_id =
1055             body_b.AddInstruction(HloInstruction::CreateBinary(
1056                 i->shape(), HloOpcode::kAdd, ccw_data_partition_id,
1057                 body_b.AddInstruction(HloInstruction::CreateConstant(
1058                     LiteralUtil::CreateR0<uint32_t>(num_partitions / 2 + 1)))));
1059         cw_data_partition_id =
1060             body_b.AddInstruction(HloInstruction::CreateBinary(
1061                 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
1062                 body_b.AddInstruction(HloInstruction::CreateConstant(
1063                     LiteralUtil::CreateR0<uint32_t>(num_partitions / 2)))));
1064       } else {
1065         cw_data_partition_id =
1066             body_b.AddInstruction(HloInstruction::CreateBinary(
1067                 i->shape(), HloOpcode::kSubtract, cw_data_partition_id,
1068                 CreateOne(cw_data_partition_id->shape(), &body_b)));
1069       }
1070       ccw_data_partition_id = body_b.AddInstruction(
1071           HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1072                                        ccw_data_partition_id, partition_count));
1073       cw_data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary(
1074           i->shape(), HloOpcode::kSubtract, cw_data_partition_id, i));
1075       cw_data_partition_id = body_b.AddInstruction(
1076           HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1077                                        cw_data_partition_id, partition_count));
1078 
1079       DotDimensionNumbers new_ddnums;
1080       if (original_hlo->opcode() == HloOpcode::kDot) {
1081         new_ddnums = original_hlo->dot_dimension_numbers();
1082       }
1083 
1084       auto dot_lhs = l;
1085       auto dot_rhs = r;
1086       auto original_dot_lhs = l;
1087       auto original_dot_rhs = r;
1088       // Recover original lhs and rhs, will not be used in real computation.
1089       if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
1090         std::vector<int64_t> reshaped_dims(
1091             original_dot_lhs->shape().dimensions().begin(),
1092             original_dot_lhs->shape().dimensions().end());
1093         reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
1094         original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1095             ShapeUtil::MakeShape(original_dot_lhs->shape().element_type(),
1096                                  reshaped_dims),
1097             original_dot_lhs));
1098       }
1099       if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
1100         std::vector<int64_t> reshaped_dims(
1101             original_dot_rhs->shape().dimensions().begin(),
1102             original_dot_rhs->shape().dimensions().end());
1103         reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
1104         original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1105             ShapeUtil::MakeShape(original_dot_rhs->shape().element_type(),
1106                                  reshaped_dims),
1107             original_dot_rhs));
1108       }
1109 
1110       if (windowed_at_contracting_dims || windowed_at_batch_dims ||
1111           operands_sharded_at_contracting_dims) {
1112         // Slice the matching operand according to the partitioned dimensions
1113         // on the windowed operand or the output.
1114         auto slice_operand = !windowed_op_is_lhs ? l : r;
1115 
1116         // Pad the sharding dim first (then the concat dim) for correctness.
1117         auto sharding_dim_size =
1118             slice_operand->shape().dimensions(slice_sharding_dim);
1119         if (sharding_dim_size % num_partitions != 0) {
1120           slice_operand = PadBaseShapeBeforeUnevenTiledSharding(
1121               slice_operand, *slice_sharding, &body_b);
1122         }
1123 
1124         // We do this by treating the matching operand as replicated, and
1125         // resharding it to match the windowed operand or the output.
1126         auto gen_slice = [&](HloInstruction* data_partition_id,
1127                              bool ccw) -> HloInstruction* {
1128           std::vector<int64_t> new_dims;
1129           const int64_t dimensions_size =
1130               slice_operand->shape().dimensions_size();
1131           new_dims.reserve(dimensions_size + 1);
1132           for (int64_t i = 0; i < dimensions_size; ++i) {
1133             if (i == slice_sharding_dim) {
1134               new_dims.push_back(1);
1135             }
1136             new_dims.push_back(slice_operand->shape().dimensions(i));
1137           }
1138           auto reshaped_slice_operand =
1139               body_b.AddInstruction(HloInstruction::CreateReshape(
1140                   ShapeUtil::MakeShape(slice_operand->shape().element_type(),
1141                                        new_dims),
1142                   slice_operand));
1143           auto min = body_b.AddInstruction(
1144               HloInstruction::CreateConstant(LiteralUtil::MinValue(
1145                   reshaped_slice_operand->shape().element_type())));
1146           std::vector<int64_t> min_padding(
1147               reshaped_slice_operand->shape().rank());
1148           auto padded_slice_operand = reshaped_slice_operand;
1149           auto padded_shape = padded_slice_operand->shape();
1150           int64_t padding_dim = slice_sharding_dim;
1151           padded_shape.set_dimensions(padding_dim, 2);
1152           if (ccw) {
1153             // ccw pad high
1154             PaddingConfig ccw_pad_config =
1155                 window_util::MakeSymmetricPadding(min_padding);
1156             ccw_pad_config.mutable_dimensions(padding_dim)
1157                 ->set_edge_padding_low(0);
1158             ccw_pad_config.mutable_dimensions(padding_dim)
1159                 ->set_edge_padding_high(1);
1160             padded_slice_operand =
1161                 body_b.AddInstruction(HloInstruction::CreatePad(
1162                     padded_shape, padded_slice_operand, min, ccw_pad_config));
1163           } else {
1164             // cw pad low
1165             PaddingConfig cw_pad_config =
1166                 window_util::MakeSymmetricPadding(min_padding);
1167             cw_pad_config.mutable_dimensions(padding_dim)
1168                 ->set_edge_padding_low(1);
1169             cw_pad_config.mutable_dimensions(padding_dim)
1170                 ->set_edge_padding_high(0);
1171             padded_slice_operand =
1172                 body_b.AddInstruction(HloInstruction::CreatePad(
1173                     padded_shape, padded_slice_operand, min, cw_pad_config));
1174           }
1175 
1176           padded_slice_operand->set_sharding(HloSharding::Replicate());
1177           auto state = lhs.state();
1178           state.b = &body_b;
1179           state.partition_id = data_partition_id;
1180           state.reshard_cache->per_hlo_cache.erase(padded_slice_operand);
1181           auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding(
1182               slice_operand->shape(), reshaped_slice_operand->shape(),
1183               *slice_sharding);
1184           auto padded_slice =
1185               PartitionedHlo(padded_slice_operand,
1186                              padded_slice_operand->shape(), state)
1187                   .Reshard(*padded_slice_sharding)
1188                   .hlo();
1189           padded_slice_operand->clear_sharding();
1190           return padded_slice;
1191         };
1192 
1193         auto ccw_slice = gen_slice(ccw_data_partition_id, true);
1194         auto cw_slice = gen_slice(cw_data_partition_id, false);
1195         auto slice = body_b.AddInstruction(HloInstruction::CreateBinary(
1196             ccw_slice->shape(), HloOpcode::kMaximum, ccw_slice, cw_slice));
1197         // Reshape. The reshaped slice will not be used to produce the final
1198         // result, but used as a hint for the shape inference.
1199         std::vector<int64_t> reshaped_slice_dims;
1200         const int64_t dim_size = slice->shape().dimensions_size();
1201         reshaped_slice_dims.reserve(dim_size);
1202         for (int64_t i = 0; i < dim_size; ++i) {
1203           auto dim_size = slice->shape().dimensions(i);
1204           if (i == (slice_sharding_dim + 1)) {
1205             reshaped_slice_dims.push_back(dim_size * 2);
1206           } else if (i != slice_sharding_dim) {
1207             reshaped_slice_dims.push_back(dim_size);
1208           }
1209         }
1210         auto reshaped_slice =
1211             body_b.AddInstruction(HloInstruction::CreateReshape(
1212                 ShapeUtil::MakeShape(slice->shape().element_type(),
1213                                      reshaped_slice_dims),
1214                 slice));
1215 
1216         if (!windowed_op_is_lhs) {
1217           dot_lhs = slice;
1218           original_dot_lhs = reshaped_slice;
1219           if (original_hlo->opcode() == HloOpcode::kDot) {
1220             UpdateDDNums(&new_ddnums, slice_sharding_dim, true);
1221           }
1222         } else {
1223           dot_rhs = slice;
1224           original_dot_rhs = reshaped_slice;
1225           if (original_hlo->opcode() == HloOpcode::kDot) {
1226             UpdateDDNums(&new_ddnums, slice_sharding_dim, false);
1227           }
1228         }
1229       }
1230 
1231       auto ccw_dot_lhs = l;
1232       auto ccw_dot_rhs = r;
1233       auto cw_dot_lhs = windowed_op_is_lhs ? extra_inout : l;
1234       auto cw_dot_rhs = windowed_op_is_lhs ? r : extra_inout;
1235       if (lhs_concat_dim != -1 && windowed_op_is_lhs) {
1236         // Concat
1237         auto lhs_concat_shape = ccw_dot_lhs->shape();
1238         lhs_concat_shape.set_dimensions(lhs_concat_dim, 2);
1239         dot_lhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
1240             lhs_concat_shape, {ccw_dot_lhs, cw_dot_lhs}, lhs_concat_dim));
1241 
1242         std::vector<int64_t> reshaped_dims(
1243             ccw_dot_lhs->shape().dimensions().begin(),
1244             ccw_dot_lhs->shape().dimensions().end());
1245         reshaped_dims.erase(reshaped_dims.begin() + lhs_concat_dim);
1246         reshaped_dims[lhs_concat_dim] *= 2;
1247         original_dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1248             ShapeUtil::MakeShape(dot_lhs->shape().element_type(),
1249                                  reshaped_dims),
1250             dot_lhs));
1251 
1252         if (original_hlo->opcode() == HloOpcode::kDot) {
1253           UpdateDDNums(&new_ddnums, lhs_concat_dim, true);
1254         }
1255       }
1256       if (rhs_concat_dim != -1 && !windowed_op_is_lhs) {
1257         // Concat
1258         auto rhs_concat_shape = ccw_dot_rhs->shape();
1259         rhs_concat_shape.set_dimensions(rhs_concat_dim, 2);
1260         dot_rhs = body_b.AddInstruction(HloInstruction::CreateConcatenate(
1261             rhs_concat_shape, {ccw_dot_rhs, cw_dot_rhs}, rhs_concat_dim));
1262 
1263         std::vector<int64_t> reshaped_dims(
1264             ccw_dot_rhs->shape().dimensions().begin(),
1265             ccw_dot_rhs->shape().dimensions().end());
1266         reshaped_dims.erase(reshaped_dims.begin() + rhs_concat_dim);
1267         reshaped_dims[rhs_concat_dim] *= 2;
1268         original_dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1269             ShapeUtil::MakeShape(dot_rhs->shape().element_type(),
1270                                  reshaped_dims),
1271             dot_rhs));
1272 
1273         if (original_hlo->opcode() == HloOpcode::kDot) {
1274           UpdateDDNums(&new_ddnums, rhs_concat_dim, false);
1275         }
1276       }
1277 
1278       // The generated original dot will not be used.
1279       TF_ASSIGN_OR_RETURN(auto original_dot,
1280                           create_sharded_dot(original_dot_lhs, original_dot_rhs,
1281                                              &body_b, conv_window));
1282       VLOG(2) << original_dot->ToString();
1283 
1284       // Generate the correct shape of the new dot/conv.
1285       auto original_sharded_dot_shape = original_dot->shape();
1286       auto new_dot_shape = original_sharded_dot_shape;
1287       std::vector<int64_t> new_dims(new_dot_shape.dimensions().begin(),
1288                                     new_dot_shape.dimensions().end());
1289       if (!windowed_at_contracting_dims) {
1290         auto slice_dim =
1291             lhs_concat_dim != -1
1292                 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
1293                 : indices_map.rhs_to_output_indices[rhs_concat_dim];
1294         new_dims[slice_dim] /= 2;
1295         new_dims.insert(new_dims.begin() + slice_dim, 2);
1296       } else if (original_hlo->opcode() != HloOpcode::kDot) {
1297         new_dims.push_back(1);
1298       }
1299       new_dot_shape =
1300           ShapeUtil::MakeShape(original_hlo->shape().element_type(), new_dims);
1301 
1302       HloInstruction* dot;
1303       if (original_hlo->opcode() == HloOpcode::kDot) {
1304         dot = body_b.AddInstruction(HloInstruction::CreateDot(
1305             new_dot_shape, dot_lhs, dot_rhs, new_ddnums,
1306             original_hlo->precision_config()));
1307       } else {
1308         if (!windowed_at_contracting_dims && !windowed_at_batch_dims) {
1309           if (lhs_concat_dim != -1) {
1310             std::vector<int64_t> new_dims(dot_rhs->shape().dimensions().begin(),
1311                                           dot_rhs->shape().dimensions().end());
1312             new_dims.push_back(1);
1313             dot_rhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1314                 ShapeUtil::MakeShape(dot_rhs->shape().element_type(), new_dims),
1315                 dot_rhs));
1316           }
1317           if (rhs_concat_dim != -1) {
1318             std::vector<int64_t> new_dims(dot_lhs->shape().dimensions().begin(),
1319                                           dot_lhs->shape().dimensions().end());
1320             new_dims.push_back(1);
1321             dot_lhs = body_b.AddInstruction(HloInstruction::CreateReshape(
1322                 ShapeUtil::MakeShape(dot_lhs->shape().element_type(), new_dims),
1323                 dot_lhs));
1324           }
1325         }
1326 
1327         dot = body_b.AddInstruction(HloInstruction::CreateConvolve(
1328             new_dot_shape, dot_lhs, dot_rhs,
1329             original_dot->feature_group_count(),
1330             original_dot->batch_group_count(),
1331             GenNewWindow(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
1332                          rhs_concat_dim, windowed_at_contracting_dims,
1333                          windowed_at_batch_dims),
1334             GenNewConvDNums(original_dot, dot_lhs, dot_rhs, lhs_concat_dim,
1335                             rhs_concat_dim, windowed_at_contracting_dims,
1336                             windowed_at_batch_dims,
1337                             indices_map.lhs_to_output_indices,
1338                             indices_map.rhs_to_output_indices, new_dot_shape),
1339             original_dot->precision_config()));
1340       }
1341       VLOG(2) << dot->ToString();
1342 
1343       if (windowed_at_contracting_dims) {
1344         if (original_hlo->opcode() != HloOpcode::kDot) {
1345           // Reshape to the original sharded dot shape.
1346           dot = body_b.AddInstruction(
1347               HloInstruction::CreateReshape(original_sharded_dot_shape, dot));
1348         }
1349 
1350         // Accumulate the partial output to the result buffer.
1351         o = body_b.AddInstruction(
1352             HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1353       } else {
1354         // The windowing operand is partitioned along batch/non-contracting
1355         // dimensions, so we need a dynamic-update-slice to save the partial
1356         // output in the result buffer.
1357         auto slice_shape = dot->shape();
1358         auto slice_dim =
1359             lhs_concat_dim != -1
1360                 ? indices_map.lhs_to_output_indices[lhs_concat_dim]
1361                 : indices_map.rhs_to_output_indices[rhs_concat_dim];
1362         slice_shape.set_dimensions(slice_dim, 1);
1363         std::vector<int64_t> ccw_start_indices(dot->shape().rank(), 0);
1364         std::vector<int64_t> cw_start_indices(dot->shape().rank(), 0);
1365         cw_start_indices[slice_dim] = 1;
1366         auto ccw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1367             slice_shape, dot, ccw_start_indices, slice_shape.dimensions(),
1368             std::vector<int64_t>(dot->shape().rank(), 1)));
1369         auto cw_dot = body_b.AddInstruction(HloInstruction::CreateSlice(
1370             slice_shape, dot, cw_start_indices, dot->shape().dimensions(),
1371             std::vector<int64_t>(dot->shape().rank(), 1)));
1372 
1373         std::vector<int64_t> reshaped_dims(
1374             original_sharded_dot_shape.dimensions().begin(),
1375             original_sharded_dot_shape.dimensions().end());
1376         reshaped_dims[slice_dim] /= 2;
1377         ccw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
1378             ShapeUtil::MakeShape(ccw_dot->shape().element_type(),
1379                                  reshaped_dims),
1380             ccw_dot));
1381         cw_dot = body_b.AddInstruction(HloInstruction::CreateReshape(
1382             ShapeUtil::MakeShape(cw_dot->shape().element_type(), reshaped_dims),
1383             cw_dot));
1384 
1385         if (operands_sharded_at_contracting_dims) {
1386           // Accumulate the partial output to the result buffer.
1387           o = body_b.AddInstruction(HloInstruction::CreateBinary(
1388               o->shape(), HloOpcode::kAdd, o, ccw_dot));
1389           cw_cp_output = body_b.AddInstruction(HloInstruction::CreateBinary(
1390               o->shape(), HloOpcode::kAdd, cw_cp_output, cw_dot));
1391         } else {
1392           auto ccw_offsets = MakePartitionOffsets(
1393               o->shape(),
1394               windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1395                                  : *rhs_sharding_transposed_to_match_output,
1396               ccw_data_partition_id, &body_b);
1397           auto cw_offsets = MakePartitionOffsets(
1398               o->shape(),
1399               windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1400                                  : *rhs_sharding_transposed_to_match_output,
1401               cw_data_partition_id, &body_b);
1402           o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1403               o->shape(), o, ccw_dot, ccw_offsets));
1404           o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1405               o->shape(), o, cw_dot, cw_offsets));
1406         }
1407       }
1408 
1409       std::vector<HloInstruction*> partial_results;
1410       partial_results.push_back(o);
1411       partial_results.push_back(cw_cp_output);
1412       return partial_results;
1413     };
1414 
1415     // Generate partial result used by unidirectional algorithm.
1416     auto get_partial_unid_result =
1417         [&](HloInstruction* l, HloInstruction* r, HloInstruction* o,
1418             HloInstruction* i) -> StatusOr<HloInstruction*> {
1419       auto partition_id =
1420           lhs.state().collective_ops_creator.create_partition_id(&body_b);
1421       auto data_partition_id =
1422           body_b.AddInstruction(HloInstruction::CreateBinary(
1423               i->shape(), HloOpcode::kAdd, i, partition_id));
1424       auto partition_count =
1425           body_b.AddInstruction(HloInstruction::CreateConstant(
1426               LiteralUtil::CreateR0<uint32_t>(num_partitions)));
1427       data_partition_id = body_b.AddInstruction(
1428           HloInstruction::CreateBinary(i->shape(), HloOpcode::kRemainder,
1429                                        data_partition_id, partition_count));
1430       auto dot_lhs = l;
1431       auto dot_rhs = r;
1432       if (windowed_at_contracting_dims || windowed_at_batch_dims ||
1433           operands_sharded_at_contracting_dims) {
1434         // Slice the matching operand according to the partitioned dimensions on
1435         // the windowed operand or the output.
1436         auto slice_operand = !windowed_op_is_lhs ? l : r;
1437         // We do this by treating the matching operand as replicated, and
1438         // resharding it to match the windowed operand or the output.
1439         slice_operand->set_sharding(HloSharding::Replicate());
1440         auto state = lhs.state();
1441         state.b = &body_b;
1442         state.partition_id = data_partition_id;
1443         state.reshard_cache->per_hlo_cache.erase(slice_operand);
1444         auto slice =
1445             PartitionedHlo(slice_operand, slice_operand->shape(), state)
1446                 .Reshard(*slice_sharding)
1447                 .hlo();
1448         slice_operand->clear_sharding();
1449         if (!windowed_op_is_lhs) {
1450           dot_lhs = slice;
1451         } else {
1452           dot_rhs = slice;
1453         }
1454       }
1455       TF_ASSIGN_OR_RETURN(
1456           auto dot, create_sharded_dot(dot_lhs, dot_rhs, &body_b, conv_window));
1457       if (windowed_at_contracting_dims ||
1458           operands_sharded_at_contracting_dims) {
1459         // Accumulate the partial output to the result buffer.
1460         o = body_b.AddInstruction(
1461             HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot));
1462       } else {
1463         // The windowing operand is partitioned along batch/non-contracting
1464         // dimensions, so we need a dynamic-update-slice to save the partial
1465         // output in the result buffer.
1466         auto offsets = MakePartitionOffsets(
1467             o->shape(),
1468             windowed_op_is_lhs ? *lhs_sharding_transposed_to_match_output
1469                                : *rhs_sharding_transposed_to_match_output,
1470             data_partition_id, &body_b);
1471         o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
1472             o->shape(), o, dot, offsets));
1473       }
1474       return o;
1475     };
1476 
1477     auto param = body_b.AddInstruction(HloInstruction::CreateParameter(
1478         /*parameter_number=*/0,
1479         ShapeUtil::MakeTupleShapeWithPtrs(
1480             {&lhs_hlo->shape(), &rhs_hlo->shape(), &result_buffer->shape(),
1481              &extra_buffer->shape(), &iteration->shape()}),
1482         "param"));
1483     auto l = body_b.AddInstruction(
1484         HloInstruction::CreateGetTupleElement(lhs_hlo->shape(), param, 0));
1485     auto r = body_b.AddInstruction(
1486         HloInstruction::CreateGetTupleElement(rhs_hlo->shape(), param, 1));
1487     auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1488         result_buffer->shape(), param, 2));
1489     auto extra_inout = body_b.AddInstruction(
1490         HloInstruction::CreateGetTupleElement(extra_buffer->shape(), param, 3));
1491     auto i = body_b.AddInstruction(
1492         HloInstruction::CreateGetTupleElement(iteration->shape(), param, 4));
1493 
1494     // The bidirectional collective permute implementation has loop unrolling
1495     // of degree 2, so num_partitions is required to be a multiple of 4.
1496     if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1497       std::vector<std::pair<int64_t, int64_t>> ccw_sd_pairs(num_partitions);
1498       for (int64_t source = 0; source < num_partitions; ++source) {
1499         // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1500         ccw_sd_pairs[source] = {source,
1501                                 (source - 1 + num_partitions) % num_partitions};
1502       }
1503       std::vector<std::pair<int64_t, int64_t>> cw_sd_pairs(num_partitions);
1504       for (int64_t source = 0; source < num_partitions; ++source) {
1505         // 0 -> 1, 1 -> 2, 2 -> 3, ...
1506         cw_sd_pairs[source] = {source, (source + 1) % num_partitions};
1507       }
1508 
1509       // Even number iteration.
1510       auto next_l = l;
1511       auto next_r = r;
1512       auto ccw_cp_input = operands_sharded_at_contracting_dims ? o
1513                           : windowed_op_is_lhs                 ? l
1514                                                                : r;
1515       auto ccw_cp_output =
1516           lhs.state()
1517               .collective_ops_creator.create_cross_partition_collective_permute(
1518                   &body_b, ccw_cp_input, ccw_sd_pairs,
1519                   (*lhs.state().next_channel_id)++);
1520       if (operands_sharded_at_contracting_dims) {
1521         o = ccw_cp_output;
1522       } else if (windowed_op_is_lhs) {
1523         next_l = ccw_cp_output;
1524       } else {
1525         next_r = ccw_cp_output;
1526       }
1527       auto cw_cp_input = extra_inout;
1528       auto cw_cp_output =
1529           lhs.state()
1530               .collective_ops_creator.create_cross_partition_collective_permute(
1531                   &body_b, cw_cp_input, cw_sd_pairs,
1532                   (*lhs.state().next_channel_id)++);
1533 
1534       TF_ASSIGN_OR_RETURN(
1535           auto outputs,
1536           get_partial_bid_results(l, r, o, extra_inout, cw_cp_output, i));
1537       o = outputs[0];
1538       cw_cp_output = outputs[1];
1539 
1540       // ++i
1541       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1542           i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1543 
1544       // Odd number iteration.
1545       auto second_next_l = next_l;
1546       auto second_next_r = next_r;
1547       ccw_cp_input = operands_sharded_at_contracting_dims ? o
1548                      : windowed_op_is_lhs                 ? next_l
1549                                                           : next_r;
1550       ccw_cp_output =
1551           lhs.state()
1552               .collective_ops_creator.create_cross_partition_collective_permute(
1553                   &body_b, ccw_cp_input, ccw_sd_pairs,
1554                   (*lhs.state().next_channel_id)++);
1555       if (operands_sharded_at_contracting_dims) {
1556         o = ccw_cp_output;
1557       } else if (windowed_op_is_lhs) {
1558         second_next_l = ccw_cp_output;
1559       } else {
1560         second_next_r = ccw_cp_output;
1561       }
1562       auto next_cw_cp_input = cw_cp_output;
1563       auto next_cw_cp_output =
1564           lhs.state()
1565               .collective_ops_creator.create_cross_partition_collective_permute(
1566                   &body_b, next_cw_cp_input, cw_sd_pairs,
1567                   (*lhs.state().next_channel_id)++);
1568 
1569       TF_ASSIGN_OR_RETURN(
1570           outputs, get_partial_bid_results(next_l, next_r, o, cw_cp_output,
1571                                            next_cw_cp_output, i));
1572       o = outputs[0];
1573       next_cw_cp_output = outputs[1];
1574 
1575       // ++i
1576       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1577           i->shape(), HloOpcode::kAdd, i, CreateOne(i->shape(), &body_b)));
1578 
1579       body_b.AddInstruction(HloInstruction::CreateTuple(
1580           {second_next_l, second_next_r, o, next_cw_cp_output, i}));
1581 
1582     } else if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1583       if (operands_sharded_at_contracting_dims) {
1584         std::vector<std::pair<int64_t, int64_t>> output_sd_pairs(
1585             num_partitions);
1586         for (int64_t source = 0; source < num_partitions; ++source) {
1587           // 0 -> n-2, 1 -> n-1, 2 -> 0, ...
1588           output_sd_pairs[source] = {
1589               source, (source - 2 + num_partitions) % num_partitions};
1590         }
1591 
1592         o = lhs.state()
1593                 .collective_ops_creator
1594                 .create_cross_partition_collective_permute(
1595                     &body_b, o, output_sd_pairs,
1596                     (*lhs.state().next_channel_id)++);
1597 
1598         TF_ASSIGN_OR_RETURN(extra_inout,
1599                             get_partial_unid_result(l, r, extra_inout, i));
1600 
1601         extra_inout = lhs.state()
1602                           .collective_ops_creator
1603                           .create_cross_partition_collective_permute(
1604                               &body_b, extra_inout, output_sd_pairs,
1605                               (*lhs.state().next_channel_id)++);
1606 
1607         // i+2
1608         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1609             i->shape(), HloOpcode::kAdd, i,
1610             body_b.AddInstruction(HloInstruction::CreateConstant(
1611                 LiteralUtil::CreateR0<uint32_t>(2)))));
1612         auto real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1613             i->shape(), HloOpcode::kAdd, i,
1614             body_b.AddInstruction(HloInstruction::CreateConstant(
1615                 LiteralUtil::CreateR0<uint32_t>(1)))));
1616 
1617         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1618         body_b.AddInstruction(
1619             HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1620       } else {
1621         std::vector<std::pair<int64_t, int64_t>> sd_pairs(num_partitions);
1622         for (int64_t source = 0; source < num_partitions; ++source) {
1623           // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1624           sd_pairs[source] = {source,
1625                               (source - 1 + num_partitions) % num_partitions};
1626         }
1627 
1628         // Even number iteration.
1629         auto next_l = l;
1630         auto next_r = r;
1631         auto cp_input = windowed_op_is_lhs ? l : r;
1632         auto cp_output = lhs.state()
1633                              .collective_ops_creator
1634                              .create_cross_partition_collective_permute(
1635                                  &body_b, cp_input, sd_pairs,
1636                                  (*lhs.state().next_channel_id)++);
1637         if (windowed_op_is_lhs) {
1638           next_l = cp_output;
1639         } else {
1640           next_r = cp_output;
1641         }
1642         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, i));
1643 
1644         // ++i
1645         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1646             i->shape(), HloOpcode::kAdd, i,
1647             body_b.AddInstruction(HloInstruction::CreateConstant(
1648                 LiteralUtil::CreateR0<uint32_t>(1)))));
1649 
1650         // Odd number iteration.
1651         auto second_next_l = next_l;
1652         auto second_next_r = next_r;
1653         cp_input = windowed_op_is_lhs ? next_l : next_r;
1654         cp_output = lhs.state()
1655                         .collective_ops_creator
1656                         .create_cross_partition_collective_permute(
1657                             &body_b, cp_input, sd_pairs,
1658                             (*lhs.state().next_channel_id)++);
1659         if (windowed_op_is_lhs) {
1660           second_next_l = cp_output;
1661         } else {
1662           second_next_r = cp_output;
1663         }
1664         TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(next_l, next_r, o, i));
1665 
1666         // ++i
1667         i = body_b.AddInstruction(HloInstruction::CreateBinary(
1668             i->shape(), HloOpcode::kAdd, i,
1669             body_b.AddInstruction(HloInstruction::CreateConstant(
1670                 LiteralUtil::CreateR0<uint32_t>(1)))));
1671 
1672         body_b.AddInstruction(HloInstruction::CreateTuple(
1673             {second_next_l, second_next_r, o, extra_inout, i}));
1674       }
1675     } else {
1676       auto real_i = i;
1677       if (operands_sharded_at_contracting_dims) {
1678         // For reduce-scatter case, start from the data_partition_id + 1 to make
1679         // the data_partition_id of the final data shard in each partition the
1680         // same as the corresponding partition_id.
1681         real_i = body_b.AddInstruction(HloInstruction::CreateBinary(
1682             real_i->shape(), HloOpcode::kAdd, real_i,
1683             CreateOne(real_i->shape(), &body_b)));
1684       }
1685       TF_ASSIGN_OR_RETURN(o, get_partial_unid_result(l, r, o, real_i));
1686 
1687       // ++i
1688       i = body_b.AddInstruction(HloInstruction::CreateBinary(
1689           i->shape(), HloOpcode::kAdd, i,
1690           body_b.AddInstruction(HloInstruction::CreateConstant(
1691               LiteralUtil::CreateR0<uint32_t>(1)))));
1692       auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare(
1693           ShapeUtil::MakeShape(PRED, {}), i,
1694           body_b.AddInstruction(HloInstruction::CreateConstant(
1695               LiteralUtil::CreateR0<uint32_t>(num_partitions))),
1696           ComparisonDirection::kLt));
1697       // Collective-permute for the next window. We don't need it for the last
1698       // iteration, so we use a conditional around the collective-permute.
1699       HloInstruction* conditional;
1700       {
1701         SpmdBuilder cp_b("window_collective_permute", original_hlo);
1702         {
1703           auto p = cp_b.AddInstruction(HloInstruction::CreateParameter(
1704               0,
1705               operands_sharded_at_contracting_dims ? o->shape()
1706               : windowed_op_is_lhs                 ? l->shape()
1707                                                    : r->shape(),
1708               "window"));
1709           std::vector<std::pair<int64_t, int64_t>> sd_pairs(num_partitions);
1710           for (int64_t source = 0; source < num_partitions; ++source) {
1711             // 0 -> n-1, 1 -> 0, 2 -> 1, ...
1712             sd_pairs[source] = {source,
1713                                 (source - 1 + num_partitions) % num_partitions};
1714           }
1715           lhs.state()
1716               .collective_ops_creator.create_cross_partition_collective_permute(
1717                   &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++);
1718         }
1719         SpmdBuilder ncp_b("last_iteration_noop", original_hlo);
1720         {
1721           ncp_b.AddInstruction(HloInstruction::CreateParameter(
1722               0,
1723               operands_sharded_at_contracting_dims ? o->shape()
1724               : windowed_op_is_lhs                 ? l->shape()
1725                                                    : r->shape(),
1726               "window"));
1727         }
1728         conditional = body_b.AddInstruction(HloInstruction::CreateConditional(
1729             operands_sharded_at_contracting_dims ? o->shape()
1730             : windowed_op_is_lhs                 ? l->shape()
1731                                                  : r->shape(),
1732             has_more,
1733             operands_sharded_at_contracting_dims ? o
1734             : windowed_op_is_lhs                 ? l
1735                                                  : r,
1736             module->AddEmbeddedComputation(cp_b.Build()),
1737             operands_sharded_at_contracting_dims ? o
1738             : windowed_op_is_lhs                 ? l
1739                                                  : r,
1740             module->AddEmbeddedComputation(ncp_b.Build())));
1741       }
1742       if (operands_sharded_at_contracting_dims) {
1743         o = conditional;
1744       } else if (windowed_op_is_lhs) {
1745         l = conditional;
1746       } else {
1747         r = conditional;
1748       }
1749       body_b.AddInstruction(
1750           HloInstruction::CreateTuple({l, r, o, extra_inout, i}));
1751     }
1752 
1753     SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo);
1754     auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter(
1755         /*parameter_number=*/0,
1756         ShapeUtil::MakeTupleShapeWithPtrs(
1757             {&lhs_hlo->shape(), &rhs_hlo->shape(), &result_buffer->shape(),
1758              &extra_buffer->shape(), &iteration->shape()}),
1759         "param"));
1760     auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement(
1761         iteration->shape(), cond_param, 4));
1762     int64_t adapted_num_partitions =
1763         (options.bidirectional_windowed_einsum && num_partitions % 4 == 0)
1764             ? num_partitions / 2
1765             : num_partitions;
1766     cond_b.AddInstruction(HloInstruction::CreateCompare(
1767         ShapeUtil::MakeShape(PRED, {}), cond_i,
1768         cond_b.AddInstruction(HloInstruction::CreateConstant(
1769             LiteralUtil::CreateR0<uint32_t>(adapted_num_partitions))),
1770         ComparisonDirection::kLt));
1771     auto while_loop = b->AddInstruction(HloInstruction::CreateWhile(
1772         cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()),
1773         module->AddEmbeddedComputation(body_b.Build()),
1774         b->AddInstruction(HloInstruction::CreateTuple(
1775             {lhs_hlo, rhs_hlo, result_buffer, extra_buffer, iteration}))));
1776     windowed_dot_general_loops->push_back(
1777         {while_loop, windowed_op_is_lhs ? 0 : 1, windowed_at_contracting_dims,
1778          windowed_at_batch_dims, operands_sharded_at_contracting_dims,
1779          num_partitions, GetLoopReplicaGroups(while_loop)});
1780     auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement(
1781         result_buffer->shape(), while_loop, 2));
1782     if (((options.bidirectional_windowed_einsum && num_partitions % 4 == 0) ||
1783          (options.unroll_windowed_einsum && num_partitions % 2 == 0)) &&
1784         operands_sharded_at_contracting_dims) {
1785       std::vector<std::pair<int64_t, int64_t>> extra_sd_pairs(num_partitions);
1786       for (int64_t source = 0; source < num_partitions; ++source) {
1787         // 0 -> 1, 1 -> 2, 2 -> 3, ...
1788         extra_sd_pairs[source] = {source, (source + 1) % num_partitions};
1789       }
1790       auto extra_result =
1791           b->AddInstruction(HloInstruction::CreateGetTupleElement(
1792               extra_buffer->shape(), while_loop, 3));
1793       if (options.bidirectional_windowed_einsum && num_partitions % 4 == 0) {
1794         extra_result = lhs.state()
1795                            .collective_ops_creator
1796                            .create_cross_partition_collective_permute(
1797                                b, extra_result, extra_sd_pairs,
1798                                (*lhs.state().next_channel_id)++);
1799       }
1800       if (options.unroll_windowed_einsum && num_partitions % 2 == 0) {
1801         result = lhs.state()
1802                      .collective_ops_creator
1803                      .create_cross_partition_collective_permute(
1804                          b, result, extra_sd_pairs,
1805                          (*lhs.state().next_channel_id)++);
1806       }
1807       result = b->AddInstruction(HloInstruction::CreateBinary(
1808           result->shape(), HloOpcode::kAdd, result, extra_result));
1809     }
1810     if (!ShapeUtil::Compatible(padded_result_buffer_shape,
1811                                unpadded_result_buffer_shape)) {
1812       result = b->AddInstruction(HloInstruction::CreateSlice(
1813           unpadded_result_buffer_shape, result,
1814           std::vector<int64_t>(padded_result_buffer_shape.rank(), 0),
1815           unpadded_result_buffer_shape.dimensions(),
1816           std::vector<int64_t>(padded_result_buffer_shape.rank(), 1)));
1817     }
1818     return result;
1819   };
1820   // Hard limit on iteration count based on empirical data (above this amount
1821   // there's pretty significant overhead).
1822   constexpr int64_t kMaxIterations = 32;
1823   std::optional<WindowedEinsumConfig> e_config = GetWindowedEinsumConfiguration(
1824       num_partitions, output_lhs_non_contracting_partitions,
1825       output_rhs_non_contracting_partitions, rhs_contracting_partitions,
1826       rhs_non_contracting_partitions, rhs_batch_partitions,
1827       lhs_contracting_partitions, lhs_non_contracting_partitions,
1828       lhs_batch_partitions, ShapeSizeInBytes(rhs.base_shape()),
1829       ShapeSizeInBytes(lhs.base_shape()), ShapeSizeInBytes(output_base_shape),
1830       options, output_sharding_transposed_to_match_lhs,
1831       output_sharding_transposed_to_match_rhs,
1832       lhs_sharding_transposed_to_match_rhs,
1833       rhs_sharding_transposed_to_match_lhs, lhs_sharding, rhs_sharding,
1834       conv_window, dims_mapping, kMaxIterations, original_hlo, &lhs, &rhs,
1835       create_sharded_dot, b, module, visitor);
1836   if (e_config) {
1837     VLOG(2) << "Emit windowed dot.";
1838     return emit_windowed_dot_general(*e_config);
1839   }
1840 
1841   {
1842     // Try batch-parallel by resharding one operand, and allowing all-reduce.
1843     TF_ASSIGN_OR_RETURN(
1844         HloInstruction * partitioned_dot,
1845         try_emit_output_batch_partitioned_einsum_with_reshard(true));
1846     if (partitioned_dot) {
1847       return partitioned_dot;
1848     }
1849   }
1850 
1851   // LHS and RHS have the same partitioned contracting dimensions.
1852   if (lhs_contracting_partitions == rhs_contracting_partitions &&
1853       lhs_contracting_partitions == num_partitions) {
1854     if (!may_reshard_without_detecting_match &&
1855         !output_sharding.IsReplicated() &&
1856         output_sharding.NumTiles() != num_partitions) {
1857       // The output is not fully sliced; the recursive handling has better
1858       // pattern matching for reduce scatters in subgroups.
1859       return nullptr;
1860     }
1861     // Pad both sides with zero, since NaN at one side cannot be masked by zero
1862     // on the other side.
1863     if (ShapeSizeInBytes(lhs.base_shape()) <
1864         ShapeSizeInBytes(rhs.base_shape())) {
1865       lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero();
1866       rhs = rhs.PadWithZero();
1867     } else {
1868       lhs = lhs.PadWithZero();
1869       rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero();
1870     }
1871     TF_ASSIGN_OR_RETURN(
1872         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1873     std::vector<int64_t> lhs_contracting_dims;
1874     lhs_contracting_dims.reserve(lhs.base_shape().rank());
1875     for (const auto& cd : dims_mapping.contracting_dims) {
1876       lhs_contracting_dims.push_back(cd.lhs);
1877     }
1878     auto ar = lhs.state().partitioner->AllReduceAlongShardingDims(
1879         b, dot, lhs.sharding(), lhs.state().next_channel_id,
1880         lhs_contracting_dims, lhs.state().collective_ops_creator,
1881         MakeBinaryAdd(output_base_shape.element_type(), module));
1882     ar->set_sharding(HloSharding::Replicate());
1883     return PartitionedHlo(ar, output_base_shape, lhs.state())
1884         .Reshard(output_sharding)
1885         .hlo();
1886   }
1887 
1888   // LHS and output have the same partitioned non-contracting dimensions.
1889   if (lhs_non_contracting_partitions == num_partitions &&
1890       output_lhs_non_contracting_partitions == num_partitions &&
1891       lhs_sharding_transposed_to_match_output == output_sharding) {
1892     auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo();
1893     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs_replicated,
1894                                                      b, conv_window));
1895     return dot;
1896   }
1897 
1898   // RHS and output have the same partitioned non-contracting dimensions.
1899   if (rhs_non_contracting_partitions == num_partitions &&
1900       output_rhs_non_contracting_partitions == num_partitions &&
1901       rhs_sharding_transposed_to_match_output == output_sharding) {
1902     auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo();
1903     TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs_replicated, rhs.hlo(),
1904                                                      b, conv_window));
1905     return dot;
1906   }
1907 
1908   if (may_reshard_without_detecting_match) {
1909     // Output is batch partitioned.
1910     if (output_batch_partitions == num_partitions) {
1911       auto resharded_lhs =
1912           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1913       auto resharded_rhs =
1914           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1915       TF_ASSIGN_OR_RETURN(
1916           auto dot, create_sharded_dot(resharded_lhs.hlo(), resharded_rhs.hlo(),
1917                                        b, conv_window));
1918       return dot;
1919     }
1920     // Output is partitioned along LHS non-contracting dimensions.
1921     if (output_lhs_non_contracting_partitions == num_partitions) {
1922       auto resharded_lhs =
1923           lhs.Reshard(*output_sharding_transposed_to_match_lhs);
1924       auto replicated_rhs = rhs.Reshard(HloSharding::Replicate());
1925       TF_ASSIGN_OR_RETURN(
1926           auto dot, create_sharded_dot(resharded_lhs.hlo(),
1927                                        replicated_rhs.hlo(), b, conv_window));
1928       return dot;
1929     }
1930     // Output is partitioned along RHS non-contracting dimensions.
1931     if (output_rhs_non_contracting_partitions == num_partitions) {
1932       auto replicated_lhs = lhs.Reshard(HloSharding::Replicate());
1933       auto resharded_rhs =
1934           rhs.Reshard(*output_sharding_transposed_to_match_rhs);
1935       TF_ASSIGN_OR_RETURN(
1936           auto dot, create_sharded_dot(replicated_lhs.hlo(),
1937                                        resharded_rhs.hlo(), b, conv_window));
1938       return dot;
1939     }
1940   }
1941 
1942   // Returns true if it is beneficial to reshard the operand at `operand_idx`
1943   // across the contracting dimension.
1944   const auto should_partition_contracting_dim = [&](int64_t operand_idx) {
1945     if (!output_sharding.IsReplicated()) {
1946       return false;
1947     }
1948 
1949     if (operand_idx == 0) {
1950       // If LHS and output are replicated, we compare the cost of all-gather
1951       // on RHS vs all-reduce on the output.
1952       return (rhs_contracting_partitions == num_partitions) &&
1953              lhs.sharding().IsReplicated() &&
1954              ShapeUtil::ElementsIn(rhs.base_shape()) >
1955                  ShapeUtil::ElementsIn(output_base_shape);
1956     } else {
1957       return (lhs_contracting_partitions == num_partitions) &&
1958              rhs.sharding().IsReplicated() &&
1959              ShapeUtil::ElementsIn(lhs.base_shape()) >
1960                  ShapeUtil::ElementsIn(output_base_shape);
1961     }
1962   };
1963 
1964   // When the output is replicated and one of the operands is partitioned along
1965   // contracting dimension, align the other operand to be partitioned along
1966   // the contracting dimensions.
1967   if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) ||
1968                                          should_partition_contracting_dim(1))) {
1969     if (should_partition_contracting_dim(0)) {
1970       lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithZero();
1971       rhs = rhs.PadWithZero();
1972     } else {
1973       lhs = lhs.PadWithZero();
1974       rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithZero();
1975     }
1976     TF_ASSIGN_OR_RETURN(
1977         auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b, conv_window));
1978 
1979     std::vector<int64_t> lhs_contracting_dims;
1980     lhs_contracting_dims.reserve(lhs.base_shape().rank());
1981     for (const auto& cd : dims_mapping.contracting_dims) {
1982       lhs_contracting_dims.push_back(cd.lhs);
1983     }
1984     return lhs.state().partitioner->AllReduceAlongShardingDims(
1985         b, dot, lhs.sharding(), lhs.state().next_channel_id,
1986         lhs_contracting_dims, lhs.state().collective_ops_creator,
1987         MakeBinaryAdd(output_base_shape.element_type(), module));
1988   }
1989   return nullptr;
1990 }
1991 
1992 StatusOr<HloInstruction*> PartitionDot(
1993     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1994     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
1995     int64_t num_partitions,
1996     const std::function<StatusOr<HloInstruction*>(
1997         HloInstruction*, HloInstruction*, SpmdBuilder*,
1998         const Window& conv_window)>& create_sharded_dot,
1999     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2000     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2001     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2002         windowed_dot_general_loops,
2003     SpmdPartitioningVisitor* visitor);
2004 
PartitionDotGroupOnBatch(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64_t num_partitions,int64_t lhs_contracting_partitions,int64_t rhs_contracting_partitions,int64_t lhs_non_contracting_partitions,int64_t rhs_non_contracting_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,SpmdPartitioningVisitor * visitor)2005 StatusOr<HloInstruction*> PartitionDotGroupOnBatch(
2006     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
2007     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
2008     int64_t num_partitions, int64_t lhs_contracting_partitions,
2009     int64_t rhs_contracting_partitions, int64_t lhs_non_contracting_partitions,
2010     int64_t rhs_non_contracting_partitions,
2011     const std::function<StatusOr<HloInstruction*>(
2012         HloInstruction*, HloInstruction*, SpmdBuilder*,
2013         const Window& conv_window)>& create_sharded_dot,
2014     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2015     bool require_matching_devices_to_group,
2016     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2017     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2018         windowed_dot_general_loops,
2019     SpmdPartitioningVisitor* visitor) {
2020   std::vector<std::pair<HloInstruction*, HloSharding>>
2021       top_level_sharding_to_reset;
2022   absl::Cleanup cleaner = [&] {
2023     for (auto& to_reset : top_level_sharding_to_reset) {
2024       to_reset.first->set_sharding(to_reset.second);
2025     }
2026   };
2027   std::vector<int64_t> lhs_dims;
2028   std::vector<int64_t> rhs_dims;
2029   std::vector<int64_t> output_dims;
2030   auto lhs_sharding_dims_adjusted_to_output =
2031       lhs.sharding().IsReplicated()
2032           ? std::vector<int64_t>(lhs.base_shape().rank(), 1)
2033           : lhs.sharding().tile_assignment().dimensions();
2034   auto rhs_sharding_dims_adjusted_to_output =
2035       rhs.sharding().IsReplicated()
2036           ? std::vector<int64_t>(rhs.base_shape().rank(), 1)
2037           : rhs.sharding().tile_assignment().dimensions();
2038   auto output_sharding_dims_adjusted_to_lhs =
2039       output_sharding.tile_assignment().dimensions();
2040   bool lhs_rhs_dims_matching = true;
2041   for (const auto& dim : dims_mapping.batch_dims) {
2042     lhs_dims.push_back(dim.lhs);
2043     rhs_dims.push_back(dim.rhs);
2044     output_dims.push_back(dim.output);
2045     if (lhs_sharding_dims_adjusted_to_output[dim.lhs] !=
2046         rhs_sharding_dims_adjusted_to_output[dim.rhs]) {
2047       lhs_rhs_dims_matching = false;
2048     }
2049     lhs_sharding_dims_adjusted_to_output[dim.lhs] =
2050         output_sharding.tile_assignment().dim(dim.output);
2051     rhs_sharding_dims_adjusted_to_output[dim.rhs] =
2052         output_sharding.tile_assignment().dim(dim.output);
2053     output_sharding_dims_adjusted_to_lhs[dim.output] =
2054         lhs.sharding().tile_assignment().dim(dim.lhs);
2055   }
2056   if (require_matching_devices_to_group && lhs_rhs_dims_matching) {
2057     lhs_rhs_dims_matching =
2058         rhs.sharding() ==
2059         UngroupSharding(AlignGroupsWith(
2060             hlo_sharding_util::GroupShardingOnDims(rhs.sharding(), rhs_dims),
2061             hlo_sharding_util::GroupShardingOnDims(lhs.sharding(), lhs_dims)));
2062   }
2063   auto output_grouped =
2064       hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2065   PartitionedHlo per_group_lhs = lhs;
2066   PartitionedHlo per_group_rhs = rhs;
2067   if (lhs_rhs_dims_matching) {
2068     auto lhs_grouped =
2069         hlo_sharding_util::GroupShardingOnDims(lhs.sharding(), lhs_dims);
2070     auto rhs_grouped =
2071         hlo_sharding_util::GroupShardingOnDims(rhs.sharding(), rhs_dims);
2072     if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2073         ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2074       rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped);
2075       rhs = rhs.Reshard(UngroupSharding(rhs_grouped));
2076     } else {
2077       lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped);
2078       lhs = lhs.Reshard(UngroupSharding(lhs_grouped));
2079     }
2080     auto reshaped_output_tiling = output_sharding.tile_assignment();
2081     reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs);
2082     output_grouped = AlignGroupsWith(
2083         hlo_sharding_util::GroupShardingOnDims(
2084             output_sharding.ReplicateOnLastTileDim()
2085                 ? HloSharding::PartialTile(reshaped_output_tiling)
2086                 : HloSharding::Tile(reshaped_output_tiling),
2087             output_dims),
2088         lhs_grouped);
2089     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2090         lhs.state(), lhs_grouped.device_groups, b);
2091     top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs.sharding());
2092     lhs.hlo()->set_sharding(lhs_grouped.sharding);
2093     top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs.sharding());
2094     rhs.hlo()->set_sharding(rhs_grouped.sharding);
2095     CHECK(lhs.hlo() != rhs.hlo() ||
2096           lhs_grouped.sharding == rhs_grouped.sharding);
2097     per_group_lhs = PartitionedHlo(
2098         lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2099         per_group_partitioner_state);
2100     per_group_rhs = PartitionedHlo(
2101         rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2102         per_group_partitioner_state);
2103   } else {
2104     auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2105         lhs.state(), output_grouped.device_groups, b);
2106     auto reshard_to_output_batch =
2107         [&](PartitionedHlo operand, absl::Span<const int64_t> batch_dims,
2108             absl::Span<const int64_t> contracting_dims,
2109             absl::Span<const int64_t> non_contracting_dims,
2110             int64_t contracting_dim_partitions,
2111             int64_t non_contracting_dim_partitions,
2112             int64_t other_contracting_dim_partitions,
2113             std::vector<int64_t>* sharding_dims_adjusted_to_output)
2114         -> std::optional<PartitionedHlo> {
2115       if (operand.sharding().IsTileMaximal()) {
2116         auto partially_sharded = PerGroupSliceFromReplicated(
2117             operand.Replicate().hlo(), operand.state().partition_id,
2118             output_grouped.device_groups, batch_dims,
2119             output_grouped.group_dim_sizes, b);
2120         partially_sharded->set_sharding(HloSharding::Replicate());
2121         return PartitionedHlo(partially_sharded, partially_sharded->shape(),
2122                               per_group_partitioner_state);
2123       }
2124       auto reshaped_tiling = operand.sharding().tile_assignment();
2125       // It's possible that the operand is not initially sharded on batch
2126       // dimensions in the same way as the output, although being tiled. In that
2127       // case, the current sharding_dims_adjusted_to_output may contain more
2128       // partitions than available devices. We remove partitioning on other
2129       // dimensions.
2130       if (Product(*sharding_dims_adjusted_to_output) >
2131           reshaped_tiling.num_elements()) {
2132         if (Product(*sharding_dims_adjusted_to_output) %
2133                 reshaped_tiling.num_elements() !=
2134             0) {
2135           return std::nullopt;
2136         }
2137         int64_t ratio = Product(*sharding_dims_adjusted_to_output) /
2138                         reshaped_tiling.num_elements();
2139         if (operand.sharding().ReplicateOnLastTileDim() &&
2140             reshaped_tiling.dimensions().back() % ratio == 0) {
2141           sharding_dims_adjusted_to_output->back() /= ratio;
2142           if (sharding_dims_adjusted_to_output->back() == 1) {
2143             sharding_dims_adjusted_to_output->pop_back();
2144           }
2145         } else if (ratio == non_contracting_dim_partitions &&
2146                    (ratio != contracting_dim_partitions ||
2147                     contracting_dim_partitions ==
2148                         other_contracting_dim_partitions)) {
2149           for (int64_t dim : non_contracting_dims) {
2150             (*sharding_dims_adjusted_to_output)[dim] = 1;
2151           }
2152         } else if (ratio == contracting_dim_partitions) {
2153           for (int64_t dim : contracting_dims) {
2154             (*sharding_dims_adjusted_to_output)[dim] = 1;
2155           }
2156         } else {
2157           return std::nullopt;
2158         }
2159       }
2160       // If the operand is initially sharded more ways than the output in the
2161       // batch dimensions, sharding_dims_adjusted_to_output currently contains
2162       // fewer partitions than available devices. We do not handle this case.
2163       if (Product(*sharding_dims_adjusted_to_output) <
2164           reshaped_tiling.num_elements()) {
2165         return std::nullopt;
2166       }
2167       reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output);
2168       auto grouped =
2169           AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
2170                               operand.base_shape().rank() <
2171                                       sharding_dims_adjusted_to_output->size()
2172                                   ? HloSharding::PartialTile(reshaped_tiling)
2173                                   : HloSharding::Tile(reshaped_tiling),
2174                               batch_dims),
2175                           output_grouped);
2176       if (require_matching_devices_to_group &&
2177           operand.sharding() != UngroupSharding(grouped)) {
2178         return std::nullopt;
2179       }
2180       auto resharded = operand.Reshard(UngroupSharding(grouped));
2181       top_level_sharding_to_reset.emplace_back(resharded.hlo(),
2182                                                resharded.sharding());
2183       resharded.hlo()->set_sharding(grouped.sharding);
2184       return PartitionedHlo(resharded.hlo(),
2185                             GetPerGroupBaseShape(grouped, operand.base_shape()),
2186                             per_group_partitioner_state);
2187     };
2188     std::vector<int64_t> lhs_contracting_dims;
2189     std::vector<int64_t> rhs_contracting_dims;
2190     lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
2191     rhs_contracting_dims.reserve(dims_mapping.contracting_dims.size());
2192     for (const auto& dim : dims_mapping.contracting_dims) {
2193       lhs_contracting_dims.push_back(dim.lhs);
2194       rhs_contracting_dims.push_back(dim.rhs);
2195     }
2196     std::vector<int64_t> lhs_non_contracting_dims;
2197     std::vector<int64_t> rhs_non_contracting_dims;
2198     lhs_non_contracting_dims.reserve(
2199         dims_mapping.lhs_non_contracting_dims.size());
2200     rhs_non_contracting_dims.reserve(
2201         dims_mapping.rhs_non_contracting_dims.size());
2202     for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2203       lhs_non_contracting_dims.push_back(dim.lhs);
2204     }
2205     for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2206       rhs_non_contracting_dims.push_back(dim.rhs);
2207     }
2208     if (auto resharded = reshard_to_output_batch(
2209             lhs, lhs_dims, lhs_contracting_dims, lhs_non_contracting_dims,
2210             lhs_contracting_partitions, lhs_non_contracting_partitions,
2211             rhs_contracting_partitions,
2212             &lhs_sharding_dims_adjusted_to_output)) {
2213       per_group_lhs = *resharded;
2214     } else {
2215       return nullptr;
2216     }
2217     if (auto resharded = reshard_to_output_batch(
2218             rhs, rhs_dims, rhs_contracting_dims, rhs_non_contracting_dims,
2219             rhs_contracting_partitions, rhs_non_contracting_partitions,
2220             lhs_contracting_partitions,
2221             &rhs_sharding_dims_adjusted_to_output)) {
2222       per_group_rhs = *resharded;
2223     } else {
2224       return nullptr;
2225     }
2226     CHECK(lhs.hlo() != rhs.hlo() ||
2227           per_group_lhs.sharding() == per_group_rhs.sharding());
2228   }
2229   TF_ASSIGN_OR_RETURN(
2230       auto dot,
2231       PartitionDot(per_group_lhs, per_group_rhs,
2232                    GetPerGroupBaseShape(output_grouped, output_base_shape),
2233                    output_grouped.sharding, dims_mapping,
2234                    num_partitions / output_grouped.device_groups.size(),
2235                    create_sharded_dot, conv_window, module, original_hlo,
2236                    options, b, windowed_dot_general_loops, visitor));
2237   dot->set_sharding(UngroupSharding(output_grouped));
2238   return PartitionedHlo(dot, output_base_shape, lhs.state())
2239       .Reshard(output_sharding)
2240       .hlo();
2241 }
2242 
GetNonContractingPartitionGroupedShardingForMatchedOperand(bool lhs_matching,const HloSharding & matching_sharding,const HloSharding & output_sharding,absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims)2243 GroupedSharding GetNonContractingPartitionGroupedShardingForMatchedOperand(
2244     bool lhs_matching, const HloSharding& matching_sharding,
2245     const HloSharding& output_sharding,
2246     absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_dims) {
2247   std::vector<int64_t> matching_sharding_dims =
2248       matching_sharding.tile_assignment().dimensions();
2249   std::vector<int64_t> matching_dims;
2250   std::vector<int64_t> output_dims;
2251   // Make sure the partitioning on matching's non-contracting dimensions
2252   // defines the same device groups for both matching and output.
2253   for (const auto& dim : partitioned_dims) {
2254     int64_t md = lhs_matching ? dim.lhs : dim.rhs;
2255     matching_sharding_dims[md] =
2256         output_sharding.tile_assignment().dim(dim.output);
2257     matching_dims.push_back(md);
2258     output_dims.push_back(dim.output);
2259   }
2260   GroupedSharding output_grouped =
2261       hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2262   Array<int64_t> reshaped_matching_tiling = matching_sharding.tile_assignment();
2263   reshaped_matching_tiling.Reshape(matching_sharding_dims);
2264   return AlignGroupsWith(
2265       hlo_sharding_util::GroupShardingOnDims(
2266           matching_sharding.ReplicateOnLastTileDim()
2267               ? HloSharding::PartialTile(reshaped_matching_tiling)
2268               : HloSharding::Tile(reshaped_matching_tiling),
2269           matching_dims),
2270       output_grouped);
2271 }
2272 
2273 std::optional<GroupedSharding>
GetNonContractingPartitionGroupedShardingForOtherOperand(bool lhs_matching,const Shape & output_base_shape,const Shape & other_shape,int64_t other_contracting_partitions,int64_t other_non_contracting_partitions,int64_t matching_contracting_partitions,int64_t output_other_non_contracting_partitions,const HloSharding & other_sharding,const HloSharding & output_sharding,absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,absl::Span<const DotConvDimsMapping::DimsMapping> other_non_contracting_dims,absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims)2274 GetNonContractingPartitionGroupedShardingForOtherOperand(
2275     bool lhs_matching, const Shape& output_base_shape, const Shape& other_shape,
2276     int64_t other_contracting_partitions,
2277     int64_t other_non_contracting_partitions,
2278     int64_t matching_contracting_partitions,
2279     int64_t output_other_non_contracting_partitions,
2280     const HloSharding& other_sharding, const HloSharding& output_sharding,
2281     absl::Span<const DotConvDimsMapping::DimsMapping> matching_partitioned_dims,
2282     absl::Span<const DotConvDimsMapping::DimsMapping>
2283         other_non_contracting_dims,
2284     absl::Span<const DotConvDimsMapping::DimsMapping> other_contracting_dims) {
2285   int64_t group_count = 1;
2286   std::vector<int64_t> output_dims;
2287   output_dims.reserve(matching_partitioned_dims.size());
2288   for (const auto& dim : matching_partitioned_dims) {
2289     output_dims.push_back(dim.output);
2290     group_count *= output_sharding.tile_assignment().dim(dim.output);
2291   }
2292   GroupedSharding output_grouped =
2293       hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2294   std::vector<int64_t> other_group_dims;
2295   if (other_sharding.ReplicateOnLastTileDim() &&
2296       other_sharding.tile_assignment().dimensions().back() % group_count == 0) {
2297     other_group_dims.push_back(
2298         other_sharding.tile_assignment().num_dimensions() - 1);
2299   } else {
2300     const bool may_replicate_other_contracting_dims =
2301         (other_contracting_partitions == group_count &&
2302          other_non_contracting_partitions ==
2303              output_other_non_contracting_partitions);
2304     const bool may_replicate_other_non_contracting_dims =
2305         group_count == other_non_contracting_partitions &&
2306         matching_contracting_partitions == other_contracting_partitions;
2307     if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2308             other_sharding, output_grouped.device_groups)) {
2309       other_group_dims = std::move(*found_dims);
2310     } else if (may_replicate_other_contracting_dims &&
2311                (!may_replicate_other_non_contracting_dims ||
2312                 ShapeUtil::ByteSizeOf(other_shape)) <=
2313                    ShapeUtil::ByteSizeOf(MakePartitionedShape(
2314                        output_base_shape, output_sharding))) {
2315       for (const auto& dim : other_contracting_dims) {
2316         other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2317       }
2318     } else if (may_replicate_other_non_contracting_dims) {
2319       for (const auto& dim : other_non_contracting_dims) {
2320         other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs);
2321       }
2322     } else {
2323       return std::nullopt;
2324     }
2325   }
2326   if (other_group_dims.size() == 1 &&
2327       other_group_dims[0] ==
2328           other_sharding.tile_assignment().num_dimensions() - 1) {
2329     std::vector<int64_t> group_dim_shards = {
2330         other_sharding.tile_assignment().dimensions().back() / group_count};
2331     return AlignGroupsWith(
2332         hlo_sharding_util::GroupShardingOnDims(
2333             other_sharding, {other_group_dims[0]}, group_dim_shards),
2334         output_grouped, /*ignore_group_order=*/true);
2335 
2336   } else if (!other_sharding.IsReplicated()) {
2337     return AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
2338                                other_sharding, other_group_dims),
2339                            output_grouped,
2340                            /*ignore_group_order=*/true);
2341   }
2342   return std::nullopt;
2343 }
2344 
PartitionDotGroupOnNonContracting(bool lhs_matching,PartitionedHlo matching,PartitionedHlo other,int64_t matching_contracting_partitions,int64_t other_contracting_partitions,absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_non_contracting_dims,int64_t other_non_contracting_partitions,int64_t output_other_non_contracting_partitions,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64_t num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,SpmdPartitioningVisitor * visitor)2345 StatusOr<HloInstruction*> PartitionDotGroupOnNonContracting(
2346     bool lhs_matching, PartitionedHlo matching, PartitionedHlo other,
2347     int64_t matching_contracting_partitions,
2348     int64_t other_contracting_partitions,
2349     absl::Span<const DotConvDimsMapping::DimsMapping>
2350         partitioned_non_contracting_dims,
2351     int64_t other_non_contracting_partitions,
2352     int64_t output_other_non_contracting_partitions,
2353     const Shape& output_base_shape, const HloSharding& output_sharding,
2354     const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
2355     const std::function<StatusOr<HloInstruction*>(
2356         HloInstruction*, HloInstruction*, SpmdBuilder*,
2357         const Window& conv_window)>& create_sharded_dot,
2358     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2359     bool require_matching_devices_to_group,
2360     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2361     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2362         windowed_dot_general_loops,
2363     SpmdPartitioningVisitor* visitor) {
2364   std::vector<std::pair<HloInstruction*, HloSharding>>
2365       top_level_sharding_to_reset;
2366   absl::Cleanup cleaner = [&] {
2367     for (auto& to_reset : top_level_sharding_to_reset) {
2368       to_reset.first->set_sharding(to_reset.second);
2369     }
2370   };
2371 
2372   std::vector<int64_t> output_dims;
2373   output_dims.reserve(partitioned_non_contracting_dims.size());
2374   for (const auto& dim : partitioned_non_contracting_dims) {
2375     output_dims.push_back(dim.output);
2376   }
2377   GroupedSharding output_grouped =
2378       hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2379   GroupedSharding matching_grouped =
2380       GetNonContractingPartitionGroupedShardingForMatchedOperand(
2381           lhs_matching, matching.sharding(), output_sharding,
2382           partitioned_non_contracting_dims);
2383   if (require_matching_devices_to_group &&
2384       matching.sharding() != UngroupSharding(matching_grouped)) {
2385     return nullptr;
2386   }
2387   std::optional<GroupedSharding> other_grouped =
2388       GetNonContractingPartitionGroupedShardingForOtherOperand(
2389           lhs_matching, output_base_shape, other.hlo()->shape(),
2390           other_contracting_partitions, other_non_contracting_partitions,
2391           matching_contracting_partitions,
2392           output_other_non_contracting_partitions, other.sharding(),
2393           output_sharding, partitioned_non_contracting_dims,
2394           lhs_matching ? dims_mapping.rhs_non_contracting_dims
2395                        : dims_mapping.lhs_non_contracting_dims,
2396           dims_mapping.contracting_dims);
2397 
2398   if (!other_grouped) {
2399     other = other.Replicate();
2400   }
2401   matching = matching.Reshard(UngroupSharding(matching_grouped));
2402   auto per_group_partitioner_state = CreatePerGroupPartitioningState(
2403       matching.state(), matching_grouped.device_groups, b);
2404   top_level_sharding_to_reset.emplace_back(matching.hlo(), matching.sharding());
2405   matching.hlo()->set_sharding(matching_grouped.sharding);
2406   auto matching_p = PartitionedHlo(
2407       matching.hlo(),
2408       GetPerGroupBaseShape(matching_grouped, matching.base_shape()),
2409       per_group_partitioner_state);
2410 
2411   auto partially_replicated_other = other.hlo();
2412   if (other_grouped && other_grouped->group_dims.size() == 1 &&
2413       other_grouped->group_dims[0] == other.base_shape().rank()) {
2414     // Group on replication dim.
2415     other = other.Reshard(UngroupSharding(*other_grouped));
2416     partially_replicated_other = other.hlo();
2417     top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding());
2418     partially_replicated_other->set_sharding(other_grouped->sharding);
2419   } else if (!other.sharding().IsReplicated()) {
2420     HloSharding target_sharding = UngroupSharding(*other_grouped);
2421     GroupedSharding target_group_sharding =
2422         hlo_sharding_util::GroupShardingOnDims(target_sharding,
2423                                                other_grouped->group_dims);
2424     const bool device_group_match = hlo_sharding_util::DeviceGroupsAreMatch(
2425         target_group_sharding, *other_grouped, /*ignore_group_order=*/false);
2426 
2427     // Do not reshard for partial replicate if device group are matched.
2428     // There is a reshard to partial replicate right after this reshard. If
2429     // the device ids within each partial replicate group is the same, no need
2430     // to reshard here.
2431     if (!other.sharding().ReplicateOnLastTileDim() || !device_group_match) {
2432       other = other.Reshard(target_sharding);
2433     }
2434     partially_replicated_other =
2435         other
2436             .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2437                 other.sharding(), other_grouped->group_dims))
2438             .hlo();
2439     top_level_sharding_to_reset.emplace_back(
2440         partially_replicated_other, partially_replicated_other->sharding());
2441     partially_replicated_other->set_sharding(other_grouped->sharding);
2442   }
2443   auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(),
2444                                 per_group_partitioner_state);
2445   TF_ASSIGN_OR_RETURN(
2446       auto dot,
2447       PartitionDot(lhs_matching ? matching_p : other_p,
2448                    lhs_matching ? other_p : matching_p,
2449                    GetPerGroupBaseShape(output_grouped, output_base_shape),
2450                    output_grouped.sharding, dims_mapping,
2451                    num_partitions / matching_grouped.device_groups.size(),
2452                    create_sharded_dot, conv_window, module, original_hlo,
2453                    options, b, windowed_dot_general_loops, visitor));
2454   return dot;
2455 }
2456 
2457 std::pair<HloSharding, HloSharding>
GetDotGroupPartitionContractingOutputShardings(const DotConvDimsMapping & dims_mapping,const GroupedSharding & lhs_grouped,const Shape & output_base_shape,const HloSharding & output_sharding,int64_t group_count,int64_t output_lhs_non_contracting_partitions,int64_t output_rhs_non_contracting_partitions,int64_t output_batch_partitions,std::vector<int64_t> * output_slice_dims_out,bool * output_replicate_dim_grouped=nullptr)2458 GetDotGroupPartitionContractingOutputShardings(
2459     const DotConvDimsMapping& dims_mapping, const GroupedSharding& lhs_grouped,
2460     const Shape& output_base_shape, const HloSharding& output_sharding,
2461     int64_t group_count, int64_t output_lhs_non_contracting_partitions,
2462     int64_t output_rhs_non_contracting_partitions,
2463     int64_t output_batch_partitions,
2464     std::vector<int64_t>* output_slice_dims_out,
2465     bool* output_replicate_dim_grouped = nullptr) {
2466   HloSharding inner_output_sharding = HloSharding::Replicate();
2467   HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2468   std::vector<int64_t> output_slice_dims;
2469   if (output_sharding.ReplicateOnLastTileDim() &&
2470       output_sharding.tile_assignment().dimensions().back() % group_count ==
2471           0) {
2472     std::vector<int64_t> group_dim_shards = {
2473         output_sharding.tile_assignment().dimensions().back() / group_count};
2474     auto grouped = AlignGroupsWith(
2475         hlo_sharding_util::GroupShardingOnDims(
2476             output_sharding,
2477             {output_sharding.tile_assignment().num_dimensions() - 1},
2478             group_dim_shards),
2479         lhs_grouped,
2480         /*ignore_group_order=*/true);
2481     outer_output_tmp_sharding = UngroupSharding(grouped);
2482     inner_output_sharding = std::move(grouped.sharding);
2483   } else {
2484     if (auto found_dims = FindMatchingPartitionedDimsForGrouping(
2485             output_sharding, lhs_grouped.device_groups)) {
2486       output_slice_dims = std::move(*found_dims);
2487       if (!output_slice_dims.empty()) {
2488         // FindMatchingPartitionedDimsForGrouping already makes sure the groups
2489         // are compatible with LHS/RHS. We avoid AlignGroupsWith/UngroupSharding
2490         // because that could change the group order causing a reshard with
2491         // collective-permute, which is unnecessary since these groups will be
2492         // all-reduced upon anyway for contracting-dim sharding.
2493         auto grouped = hlo_sharding_util::GroupShardingOnDims(
2494             output_sharding, output_slice_dims);
2495         inner_output_sharding = grouped.sharding;
2496         outer_output_tmp_sharding = output_sharding;
2497       }
2498     } else if (output_lhs_non_contracting_partitions == group_count ||
2499                output_rhs_non_contracting_partitions == group_count ||
2500                output_batch_partitions == group_count) {
2501       if (output_lhs_non_contracting_partitions == group_count) {
2502         for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
2503           output_slice_dims.push_back(dim.output);
2504         }
2505       } else if (output_rhs_non_contracting_partitions == group_count) {
2506         for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
2507           output_slice_dims.push_back(dim.output);
2508         }
2509       } else {
2510         for (const auto& dim : dims_mapping.batch_dims) {
2511           output_slice_dims.push_back(dim.output);
2512         }
2513       }
2514       if (!output_slice_dims.empty()) {
2515         auto grouped = AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
2516                                            output_sharding, output_slice_dims),
2517                                        lhs_grouped);
2518         inner_output_sharding = grouped.sharding;
2519         outer_output_tmp_sharding = UngroupSharding(grouped);
2520       }
2521     }
2522   }
2523   if (output_replicate_dim_grouped) {
2524     *output_replicate_dim_grouped =
2525         absl::c_linear_search(output_slice_dims, output_base_shape.rank());
2526   }
2527   if (output_slice_dims_out) {
2528     if (output_sharding.ReplicateOnLastTileDim()) {
2529       // Remove the replication group dim.
2530       output_slice_dims.erase(
2531           std::remove_if(
2532               output_slice_dims.begin(), output_slice_dims.end(),
2533               [&](int64_t dim) { return dim == output_base_shape.rank(); }),
2534           output_slice_dims.end());
2535     }
2536     (*output_slice_dims_out) = std::move(output_slice_dims);
2537   }
2538   return std::make_pair(inner_output_sharding, outer_output_tmp_sharding);
2539 }
2540 
2541 std::pair<HloSharding, HloSharding>
GetDotGroupPartitionContractingLhsRhsShardings(const PartitionedHlo & lhs,const PartitionedHlo & rhs,absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_contracting_dims)2542 GetDotGroupPartitionContractingLhsRhsShardings(
2543     const PartitionedHlo& lhs, const PartitionedHlo& rhs,
2544     absl::Span<const DotConvDimsMapping::DimsMapping>
2545         partitioned_contracting_dims) {
2546   HloSharding lhs_sharding = lhs.sharding();
2547   HloSharding rhs_sharding = rhs.sharding();
2548   std::vector<int64_t> lhs_tile_shape =
2549       lhs_sharding.tile_assignment().dimensions();
2550   std::vector<int64_t> rhs_tile_shape =
2551       rhs_sharding.tile_assignment().dimensions();
2552   if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2553       ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2554     for (const auto& dim : partitioned_contracting_dims) {
2555       rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs];
2556     }
2557     auto new_tile = rhs.sharding().tile_assignment();
2558     new_tile.Reshape(rhs_tile_shape);
2559     rhs_sharding = rhs_sharding.ReplicateOnLastTileDim()
2560                        ? HloSharding::PartialTile(new_tile)
2561                        : HloSharding::Tile(new_tile);
2562   } else {
2563     for (const auto& dim : partitioned_contracting_dims) {
2564       lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs];
2565     }
2566     auto new_tile = lhs.sharding().tile_assignment();
2567     new_tile.Reshape(lhs_tile_shape);
2568     lhs_sharding = lhs_sharding.ReplicateOnLastTileDim()
2569                        ? HloSharding::PartialTile(new_tile)
2570                        : HloSharding::Tile(new_tile);
2571   }
2572   return std::make_pair(lhs_sharding, rhs_sharding);
2573 }
2574 
PartitionDotGroupOnContracting(PartitionedHlo lhs,PartitionedHlo rhs,absl::Span<const DotConvDimsMapping::DimsMapping> partitioned_contracting_dims,int64_t output_batch_partitions,int64_t output_lhs_non_contracting_partitions,int64_t output_rhs_non_contracting_partitions,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64_t num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,SpmdPartitioningVisitor * visitor)2575 StatusOr<HloInstruction*> PartitionDotGroupOnContracting(
2576     PartitionedHlo lhs, PartitionedHlo rhs,
2577     absl::Span<const DotConvDimsMapping::DimsMapping>
2578         partitioned_contracting_dims,
2579     int64_t output_batch_partitions,
2580     int64_t output_lhs_non_contracting_partitions,
2581     int64_t output_rhs_non_contracting_partitions,
2582     const Shape& output_base_shape, const HloSharding& output_sharding,
2583     const DotConvDimsMapping& dims_mapping, int64_t num_partitions,
2584     const std::function<StatusOr<HloInstruction*>(
2585         HloInstruction*, HloInstruction*, SpmdBuilder*,
2586         const Window& conv_window)>& create_sharded_dot,
2587     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
2588     bool require_matching_devices_to_group,
2589     const SpmdPartitionerOptions& options, SpmdBuilder* b,
2590     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
2591         windowed_dot_general_loops,
2592     SpmdPartitioningVisitor* visitor) {
2593   std::vector<std::pair<HloInstruction*, HloSharding>>
2594       top_level_sharding_to_reset;
2595   absl::Cleanup cleaner = [&] {
2596     for (auto& to_reset : top_level_sharding_to_reset) {
2597       to_reset.first->set_sharding(to_reset.second);
2598     }
2599   };
2600   std::vector<int64_t> lhs_dims;
2601   std::vector<int64_t> rhs_dims;
2602   int64_t group_count = 1;
2603   for (const auto& dim : partitioned_contracting_dims) {
2604     lhs_dims.push_back(dim.lhs);
2605     rhs_dims.push_back(dim.rhs);
2606     group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
2607   }
2608   HloSharding lhs_sharding = HloSharding::Replicate();
2609   HloSharding rhs_sharding = HloSharding::Replicate();
2610   std::tie(lhs_sharding, rhs_sharding) =
2611       GetDotGroupPartitionContractingLhsRhsShardings(
2612           lhs, rhs, partitioned_contracting_dims);
2613   auto lhs_grouped =
2614       hlo_sharding_util::GroupShardingOnDims(lhs_sharding, lhs_dims);
2615   auto rhs_grouped =
2616       hlo_sharding_util::GroupShardingOnDims(rhs_sharding, rhs_dims);
2617   if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) >
2618       ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) {
2619     rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
2620     rhs_sharding = UngroupSharding(rhs_grouped);
2621     if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
2622       return nullptr;
2623     }
2624     rhs = rhs.Reshard(rhs_sharding);
2625   } else {
2626     lhs_grouped = AlignGroupsWith(lhs_grouped, rhs_grouped);
2627     lhs_sharding = UngroupSharding(lhs_grouped);
2628     if (require_matching_devices_to_group && lhs.sharding() != lhs_sharding) {
2629       return nullptr;
2630     }
2631     lhs = lhs.Reshard(lhs_sharding);
2632   }
2633   // Mask out invalid data.
2634   std::vector<int64_t> lhs_skipped_dims;
2635   for (int64_t i = 0; i < lhs.base_shape().rank(); ++i) {
2636     if (absl::c_linear_search(lhs_dims, i)) {
2637       continue;
2638     }
2639     lhs_skipped_dims.push_back(i);
2640   }
2641   lhs = lhs.PadWithZero(
2642       /*left_padded_dims=*/{}, lhs_skipped_dims);
2643   std::vector<int64_t> rhs_skipped_dims;
2644   for (int64_t i = 0; i < rhs.base_shape().rank(); ++i) {
2645     if (absl::c_linear_search(rhs_dims, i)) {
2646       continue;
2647     }
2648     rhs_skipped_dims.push_back(i);
2649   }
2650   rhs = rhs.PadWithZero(
2651       /*left_padded_dims=*/{}, rhs_skipped_dims);
2652   top_level_sharding_to_reset.emplace_back(lhs.hlo(), lhs_sharding);
2653   lhs.hlo()->set_sharding(lhs_grouped.sharding);
2654   top_level_sharding_to_reset.emplace_back(rhs.hlo(), rhs_sharding);
2655   rhs.hlo()->set_sharding(rhs_grouped.sharding);
2656 
2657   HloSharding inner_output_sharding = HloSharding::Replicate();
2658   HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
2659   std::vector<int64_t> output_slice_dims;
2660   bool output_replicate_dim_grouped;
2661   std::tie(inner_output_sharding, outer_output_tmp_sharding) =
2662       GetDotGroupPartitionContractingOutputShardings(
2663           dims_mapping, lhs_grouped, output_base_shape, output_sharding,
2664           group_count, output_lhs_non_contracting_partitions,
2665           output_rhs_non_contracting_partitions, output_batch_partitions,
2666           &output_slice_dims, &output_replicate_dim_grouped);
2667   Shape inner_output_base_shape = output_base_shape;
2668   auto get_non_slice_dims = [&] {
2669     std::vector<int64_t> non_group_dims;
2670     for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
2671       if (!absl::c_linear_search(output_slice_dims, i)) {
2672         non_group_dims.push_back(i);
2673       }
2674     }
2675     return non_group_dims;
2676   };
2677   if (!output_slice_dims.empty()) {
2678     inner_output_base_shape = MakePartitionedShape(
2679         output_base_shape,
2680         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2681             output_sharding, get_non_slice_dims()));
2682   }
2683   std::function<StatusOr<HloInstruction*>(HloInstruction*, HloInstruction*,
2684                                           SpmdBuilder*, const Window&)>
2685       inner_creator =
2686           [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b,
2687               const Window& conv_window) -> StatusOr<HloInstruction*> {
2688     TF_ASSIGN_OR_RETURN(auto inner_dot,
2689                         create_sharded_dot(l, r, b, conv_window));
2690     HloInstruction* result = inner_dot;
2691     if (!output_slice_dims.empty()) {
2692       // Create an AllReduce along slice dims first to allow a reduce-scatter.
2693       result = lhs.state().partitioner->AllReduceAlongShardingDims(
2694           b, result, outer_output_tmp_sharding, lhs.state().next_channel_id,
2695           output_slice_dims, lhs.state().collective_ops_creator,
2696           MakeBinaryAdd(output_base_shape.element_type(), module));
2697       // Use resharding to slice the output. Use a temporary reshard cache since
2698       // we are faking with replicated sharding.
2699       PartitionedHlo::PartitioningState new_state = lhs.state();
2700       new_state.b = b;
2701       new_state.partition_id =
2702           lhs.state().collective_ops_creator.create_partition_id(b);
2703       PartitionedHlo::ReshardCache tmp_cache;
2704       new_state.reshard_cache = &tmp_cache;
2705       result->set_sharding(HloSharding::Replicate());
2706       result =
2707           PartitionedHlo(result, result->shape(), new_state)
2708               .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2709                   outer_output_tmp_sharding, get_non_slice_dims()))
2710               .hlo();
2711       // If the output has partial replication, and the last tile dim is used
2712       // for grouping, we need to do a separate allreduce after reduce-scatter.
2713       if (output_replicate_dim_grouped) {
2714         result = lhs.state().partitioner->AllReduceAlongShardingDims(
2715             b, result, outer_output_tmp_sharding, lhs.state().next_channel_id,
2716             {output_base_shape.rank()}, lhs.state().collective_ops_creator,
2717             MakeBinaryAdd(output_base_shape.element_type(), module));
2718       }
2719     } else {
2720       result = lhs.state().partitioner->AllReduceAlongShardingDims(
2721           b, result, lhs_sharding, lhs.state().next_channel_id, lhs_dims,
2722           lhs.state().collective_ops_creator,
2723           MakeBinaryAdd(output_base_shape.element_type(), module));
2724     }
2725     return result;
2726   };
2727 
2728   PartitionedHlo::PartitioningState inner_state =
2729       CreatePerGroupPartitioningState(lhs.state(), lhs_grouped.device_groups,
2730                                       b);
2731 
2732   HloInstruction* maybe_windowed_dot = nullptr;
2733 
2734   // Tentatively disables the inner reshard when the "faster windowed einsum"
2735   // flag is enabled, because the windowed einsum implementation is currently
2736   // slow with this kind of reshard happening.
2737   int original_num_windowed_loops = windowed_dot_general_loops->size();
2738   if (options.choose_faster_windowed_einsum_over_mem) {
2739     Shape predicted_inner_output_base_shape = output_base_shape;
2740     auto predicted_inner_creator = create_sharded_dot;
2741     TF_ASSIGN_OR_RETURN(
2742         maybe_windowed_dot,
2743         PartitionDot(
2744             PartitionedHlo(lhs.hlo(),
2745                            GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2746                            inner_state),
2747             PartitionedHlo(rhs.hlo(),
2748                            GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2749                            inner_state),
2750             predicted_inner_output_base_shape, inner_output_sharding,
2751             dims_mapping, num_partitions / group_count, predicted_inner_creator,
2752             conv_window, module, original_hlo, options, b,
2753             windowed_dot_general_loops, visitor));
2754   }
2755   int new_num_windowed_loops = windowed_dot_general_loops->size();
2756 
2757   TF_ASSIGN_OR_RETURN(
2758       auto inner_dot,
2759       PartitionDot(
2760           PartitionedHlo(lhs.hlo(),
2761                          GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()),
2762                          inner_state),
2763           PartitionedHlo(rhs.hlo(),
2764                          GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()),
2765                          inner_state),
2766           inner_output_base_shape, inner_output_sharding, dims_mapping,
2767           num_partitions / group_count, inner_creator, conv_window, module,
2768           original_hlo, options, b, windowed_dot_general_loops, visitor));
2769 
2770   // Reenables the inner reshard if there is an inner dot and no actual
2771   // windowed_dot_general_loops generated.
2772   if (inner_dot && (new_num_windowed_loops == original_num_windowed_loops)) {
2773     maybe_windowed_dot = inner_dot;
2774   } else if (maybe_windowed_dot) {
2775     if (options.choose_faster_windowed_einsum_over_mem) {
2776       HloInstruction* ar = lhs.state().partitioner->AllReduceAlongShardingDims(
2777           b, maybe_windowed_dot, lhs_sharding, lhs.state().next_channel_id,
2778           lhs_dims, lhs.state().collective_ops_creator,
2779           MakeBinaryAdd(output_base_shape.element_type(), module));
2780       maybe_windowed_dot = ar;
2781       outer_output_tmp_sharding =
2782           hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
2783               outer_output_tmp_sharding, output_slice_dims);
2784     }
2785   } else {
2786     return nullptr;
2787   }
2788 
2789   maybe_windowed_dot->set_sharding(outer_output_tmp_sharding);
2790   auto d = PartitionedHlo(maybe_windowed_dot, output_base_shape, lhs.state())
2791                .Reshard(output_sharding)
2792                .hlo();
2793   return d;
2794 }
2795 
ConvertDimsMappingWithFeatureGroupCount(const DotConvDimsMapping & dims_mapping,HloInstruction * original_hlo)2796 DotConvDimsMapping ConvertDimsMappingWithFeatureGroupCount(
2797     const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2798   const auto& dnums = original_hlo->convolution_dimension_numbers();
2799   DotConvDimsMapping new_dims_mapping;
2800   new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2801   new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2802   // Append batch dims.
2803   new_dims_mapping.batch_dims.emplace_back();
2804   new_dims_mapping.batch_dims.back().lhs = dnums.input_feature_dimension();
2805   new_dims_mapping.batch_dims.back().rhs =
2806       dnums.kernel_output_feature_dimension();
2807   new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2808   new_dims_mapping.batch_dims.back().spatial = -1;
2809   // Setup non contracting dims.
2810   new_dims_mapping.lhs_non_contracting_dims.emplace_back();
2811   new_dims_mapping.lhs_non_contracting_dims.back().lhs =
2812       dnums.input_batch_dimension();
2813   new_dims_mapping.rhs_non_contracting_dims.emplace_back();
2814   new_dims_mapping.rhs_non_contracting_dims.back().rhs =
2815       dnums.kernel_input_feature_dimension();
2816   return new_dims_mapping;
2817 }
2818 
ConvertDimsMappingWithBatchGroupCount(const DotConvDimsMapping & dims_mapping,HloInstruction * original_hlo)2819 DotConvDimsMapping ConvertDimsMappingWithBatchGroupCount(
2820     const DotConvDimsMapping& dims_mapping, HloInstruction* original_hlo) {
2821   const auto& dnums = original_hlo->convolution_dimension_numbers();
2822   DotConvDimsMapping new_dims_mapping;
2823   new_dims_mapping.batch_dims = dims_mapping.batch_dims;
2824   new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims;
2825   new_dims_mapping.contracting_dims = dims_mapping.contracting_dims;
2826   // Append batch dims.
2827   new_dims_mapping.batch_dims.emplace_back();
2828   new_dims_mapping.batch_dims.back().lhs = dnums.input_batch_dimension();
2829   new_dims_mapping.batch_dims.back().rhs =
2830       dnums.kernel_output_feature_dimension();
2831   new_dims_mapping.batch_dims.back().output = dnums.output_feature_dimension();
2832   new_dims_mapping.batch_dims.back().spatial = -1;
2833   return new_dims_mapping;
2834 }
2835 
2836 // Estimate the number of iterations of a subsequent windowed einsum
2837 // partitioning if its partitioned in the non-contracting dimensions.
2838 // First value returned is the estimate of the number of iterations if LHS is
2839 // matched while the second is the number of iterations if RHS is matched.
2840 std::pair<std::optional<int64_t>, std::optional<int64_t>>
EstimateWindowedEinsumIterationsForNonContractingPartitioning(const DotConvDimsMapping & dims_mapping,const PartitionedHlo & lhs,const PartitionedHlo & rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const SpmdPartitionerOptions & options,int64_t num_partitions,int64_t lhs_non_contracting_partitions,int64_t rhs_non_contracting_partitions,int64_t lhs_matching_partitions,int64_t rhs_matching_partitions,int64_t lhs_contracting_partitions,int64_t rhs_contracting_partitions,int64_t output_lhs_non_contracting_partitions,int64_t output_rhs_non_contracting_partitions,int64_t lhs_batch_partitions,int64_t rhs_batch_partitions,const Window & conv_window)2841 EstimateWindowedEinsumIterationsForNonContractingPartitioning(
2842     const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2843     const PartitionedHlo& rhs, const Shape& output_base_shape,
2844     const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2845     int64_t num_partitions, int64_t lhs_non_contracting_partitions,
2846     int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
2847     int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
2848     int64_t rhs_contracting_partitions,
2849     int64_t output_lhs_non_contracting_partitions,
2850     int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
2851     int64_t rhs_batch_partitions, const Window& conv_window) {
2852   const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
2853       dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
2854       output_base_shape.rank());
2855   auto subsequent_einsum_iterations_estimate =
2856       [&](bool assume_lhs_match) -> std::optional<int64_t> {
2857     const std::vector<DotConvDimsMapping::DimsMapping>&
2858         matching_non_contracting_dims =
2859             assume_lhs_match ? dims_mapping.lhs_non_contracting_dims
2860                              : dims_mapping.rhs_non_contracting_dims;
2861     const std::vector<DotConvDimsMapping::DimsMapping>&
2862         other_non_contracting_dims =
2863             assume_lhs_match ? dims_mapping.rhs_non_contracting_dims
2864                              : dims_mapping.lhs_non_contracting_dims;
2865     const std::vector<int64_t>& output_to_matching_indices =
2866         assume_lhs_match ? indices_map.output_to_lhs_indices
2867                          : indices_map.output_to_rhs_indices;
2868     const std::vector<int64_t>& output_to_other_indices =
2869         assume_lhs_match ? indices_map.output_to_rhs_indices
2870                          : indices_map.output_to_lhs_indices;
2871     const std::vector<int64_t>& matching_to_output_indices =
2872         assume_lhs_match ? indices_map.lhs_to_output_indices
2873                          : indices_map.rhs_to_output_indices;
2874     const std::vector<int64_t>& other_to_output_indices =
2875         assume_lhs_match ? indices_map.rhs_to_output_indices
2876                          : indices_map.lhs_to_output_indices;
2877     const HloSharding& matching_sharding =
2878         assume_lhs_match ? lhs.sharding() : rhs.sharding();
2879     const HloSharding& other_sharding =
2880         assume_lhs_match ? rhs.sharding() : lhs.sharding();
2881     const PartitionedHlo& matching_partitioned = assume_lhs_match ? lhs : rhs;
2882     const PartitionedHlo& other_partitioned = assume_lhs_match ? rhs : lhs;
2883     const int64_t matching_non_contracting_partitions =
2884         assume_lhs_match ? lhs_non_contracting_partitions
2885                          : rhs_non_contracting_partitions;
2886     const int64_t other_non_contracting_partitions =
2887         assume_lhs_match ? rhs_non_contracting_partitions
2888                          : lhs_non_contracting_partitions;
2889     const int64_t matching_contracting_partitions =
2890         assume_lhs_match ? lhs_contracting_partitions
2891                          : rhs_contracting_partitions;
2892     const int64_t other_contracting_partitions =
2893         assume_lhs_match ? rhs_contracting_partitions
2894                          : lhs_contracting_partitions;
2895     const int64_t output_matching_non_contracting_partitions =
2896         assume_lhs_match ? output_lhs_non_contracting_partitions
2897                          : output_rhs_non_contracting_partitions;
2898     const int64_t output_other_non_contracting_partitions =
2899         assume_lhs_match ? output_rhs_non_contracting_partitions
2900                          : output_lhs_non_contracting_partitions;
2901     const int64_t matching_batch_partitions =
2902         assume_lhs_match ? lhs_batch_partitions : rhs_batch_partitions;
2903     const int64_t other_batch_partitions =
2904         assume_lhs_match ? rhs_batch_partitions : lhs_batch_partitions;
2905     const int64_t matching_matched_non_contracting_partitions =
2906         assume_lhs_match ? lhs_non_contracting_partitions
2907                          : rhs_non_contracting_partitions;
2908     std::vector<int64_t> output_dims;
2909     output_dims.reserve(matching_non_contracting_dims.size());
2910     for (const DotConvDimsMapping::DimsMapping& dim :
2911          matching_non_contracting_dims) {
2912       output_dims.push_back(dim.output);
2913     }
2914     GroupedSharding output_grouped =
2915         hlo_sharding_util::GroupShardingOnDims(output_sharding, output_dims);
2916     GroupedSharding matching_grouped =
2917         GetNonContractingPartitionGroupedShardingForMatchedOperand(
2918             assume_lhs_match, matching_sharding, output_sharding,
2919             matching_non_contracting_dims);
2920     std::optional<GroupedSharding> other_grouped =
2921         GetNonContractingPartitionGroupedShardingForOtherOperand(
2922             assume_lhs_match, output_base_shape,
2923             other_partitioned.hlo()->shape(), other_contracting_partitions,
2924             other_non_contracting_partitions, matching_contracting_partitions,
2925             output_other_non_contracting_partitions, other_sharding,
2926             output_sharding, matching_non_contracting_dims,
2927             other_non_contracting_dims, dims_mapping.contracting_dims);
2928     if (!other_grouped) {
2929       return std::nullopt;
2930     }
2931     std::optional<HloSharding> output_sharding_transposed_to_match_matching =
2932         hlo_sharding_util::TransposeShardingWithCollapsedDims(
2933             output_grouped.sharding, output_to_matching_indices,
2934             matching_to_output_indices);
2935     std::optional<HloSharding> output_sharding_transposed_to_match_other =
2936         hlo_sharding_util::TransposeShardingWithCollapsedDims(
2937             output_grouped.sharding, output_to_other_indices,
2938             other_to_output_indices);
2939     auto lhs_sharding_transposed_to_match_rhs =
2940         hlo_sharding_util::TransposeShardingWithCollapsedDims(
2941             lhs.sharding(), indices_map.lhs_to_rhs_indices,
2942             indices_map.rhs_to_lhs_indices);
2943     auto rhs_sharding_transposed_to_match_lhs =
2944         hlo_sharding_util::TransposeShardingWithCollapsedDims(
2945             rhs.sharding(), indices_map.rhs_to_lhs_indices,
2946             indices_map.lhs_to_rhs_indices);
2947     const int64_t new_num_partitions =
2948         num_partitions / matching_non_contracting_partitions;
2949     std::optional<WindowedEinsumConfig> e_config =
2950         GetWindowedEinsumConfiguration(
2951             new_num_partitions, output_matching_non_contracting_partitions,
2952             output_other_non_contracting_partitions,
2953             other_contracting_partitions, other_non_contracting_partitions,
2954             other_batch_partitions, matching_contracting_partitions,
2955             matching_non_contracting_partitions /
2956                 matching_matched_non_contracting_partitions,
2957             matching_batch_partitions,
2958             ShapeSizeInBytes(other_partitioned.base_shape()),
2959             ShapeSizeInBytes(matching_partitioned.base_shape()) /
2960                 matching_non_contracting_partitions,
2961             ShapeSizeInBytes(
2962                 GetPerGroupBaseShape(output_grouped, output_base_shape)),
2963             options, output_sharding_transposed_to_match_matching,
2964             output_sharding_transposed_to_match_other,
2965             lhs_sharding_transposed_to_match_rhs,
2966             rhs_sharding_transposed_to_match_lhs, matching_grouped.sharding,
2967             other_grouped->sharding, conv_window, dims_mapping);
2968     return e_config ? new_num_partitions : std::optional<int64_t>(std::nullopt);
2969   };
2970   std::optional<int64_t> lhs_matching_iterations;
2971   if (lhs_matching_partitions != 0) {
2972     lhs_matching_iterations = subsequent_einsum_iterations_estimate(true);
2973   }
2974   std::optional<int64_t> rhs_matching_iterations;
2975   if (rhs_matching_partitions != 0) {
2976     rhs_matching_iterations = subsequent_einsum_iterations_estimate(false);
2977   }
2978   return std::make_pair(lhs_matching_iterations, rhs_matching_iterations);
2979 }
2980 
2981 // Return if we should prioritize partitioning in the contracting dimensions
2982 // first then non-contracting dimensions if we estimate that would be faster.
2983 // The general idea is similar as the one in
2984 // LhsIsBestMatchForNonContractingPartitioning with one all-gather replaced by
2985 // reduce-scatter.
PrioritizeContractingDimensionsPartitioning(const DotConvDimsMapping & dims_mapping,const PartitionedHlo & lhs,const PartitionedHlo & rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const SpmdPartitionerOptions & options,int64_t num_partitions,int64_t lhs_non_contracting_partitions,int64_t rhs_non_contracting_partitions,int64_t lhs_contracting_partitions,int64_t rhs_contracting_partitions,int64_t output_lhs_non_contracting_partitions,int64_t output_rhs_non_contracting_partitions,int64_t lhs_batch_partitions,int64_t rhs_batch_partitions,int64_t output_batch_partitions,bool require_matching_devices_to_group,SpmdBuilder * b,const Window & conv_window,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,SpmdPartitioningVisitor * visitor)2986 bool PrioritizeContractingDimensionsPartitioning(
2987     const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
2988     const PartitionedHlo& rhs, const Shape& output_base_shape,
2989     const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
2990     int64_t num_partitions, int64_t lhs_non_contracting_partitions,
2991     int64_t rhs_non_contracting_partitions, int64_t lhs_contracting_partitions,
2992     int64_t rhs_contracting_partitions,
2993     int64_t output_lhs_non_contracting_partitions,
2994     int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
2995     int64_t rhs_batch_partitions, int64_t output_batch_partitions,
2996     bool require_matching_devices_to_group, SpmdBuilder* b,
2997     const Window& conv_window,
2998     const std::function<StatusOr<HloInstruction*>(
2999         HloInstruction*, HloInstruction*, SpmdBuilder*,
3000         const Window& conv_window)>& create_sharded_dot,
3001     SpmdPartitioningVisitor* visitor) {
3002   const bool may_group_on_lhs_non_contracting =
3003       lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
3004       lhs_non_contracting_partitions > 1;
3005   const bool may_group_on_rhs_non_contracting =
3006       rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
3007       rhs_non_contracting_partitions > 1;
3008   if (!options.choose_faster_windowed_einsum_over_mem) {
3009     return false;
3010   }
3011   // Check only for perfect dimensions match for now.
3012   if (!may_group_on_lhs_non_contracting && !may_group_on_rhs_non_contracting) {
3013     return false;
3014   }
3015   std::optional<int64_t> lhs_matching_iterations;
3016   std::optional<int64_t> rhs_matching_iterations;
3017   const int64_t lhs_matching_non_contracting_partitions =
3018       may_group_on_lhs_non_contracting ? lhs_non_contracting_partitions : 0;
3019   const int64_t rhs_matching_non_contracting_partitions =
3020       may_group_on_rhs_non_contracting ? rhs_non_contracting_partitions : 0;
3021   std::tie(lhs_matching_iterations, rhs_matching_iterations) =
3022       EstimateWindowedEinsumIterationsForNonContractingPartitioning(
3023           dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3024           num_partitions, lhs_non_contracting_partitions,
3025           rhs_non_contracting_partitions,
3026           lhs_matching_non_contracting_partitions,
3027           rhs_matching_non_contracting_partitions, lhs_contracting_partitions,
3028           rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3029           output_rhs_non_contracting_partitions, lhs_batch_partitions,
3030           rhs_batch_partitions, conv_window);
3031   if (!lhs_matching_iterations && !rhs_matching_iterations) {
3032     return false;
3033   }
3034   // Be conservative and handle only case where the two partitions in rhs and
3035   // lhs match
3036   if (!(lhs_contracting_partitions == rhs_contracting_partitions &&
3037         lhs_contracting_partitions > 1)) {
3038     return false;
3039   }
3040   // Estimate the iterations in the case we perform the partitioning on the
3041   // contracting dimensions instead.
3042   std::vector<int64_t> lhs_dims;
3043   std::vector<int64_t> rhs_dims;
3044   int64_t group_count = 1;
3045   for (const auto& dim : dims_mapping.contracting_dims) {
3046     lhs_dims.push_back(dim.lhs);
3047     rhs_dims.push_back(dim.rhs);
3048     group_count *= lhs.sharding().tile_assignment().dim(dim.lhs);
3049   }
3050   HloSharding lhs_sharding = HloSharding::Replicate();
3051   HloSharding rhs_sharding = HloSharding::Replicate();
3052   std::tie(lhs_sharding, rhs_sharding) =
3053       GetDotGroupPartitionContractingLhsRhsShardings(
3054           lhs, rhs, dims_mapping.contracting_dims);
3055   auto lhs_grouped =
3056       hlo_sharding_util::GroupShardingOnDims(lhs_sharding, lhs_dims);
3057   auto rhs_grouped =
3058       hlo_sharding_util::GroupShardingOnDims(rhs_sharding, rhs_dims);
3059   rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped);
3060   rhs_sharding = UngroupSharding(rhs_grouped);
3061 
3062   if (require_matching_devices_to_group && rhs.sharding() != rhs_sharding) {
3063     return false;
3064   }
3065   const int64_t new_num_partitions =
3066       num_partitions / lhs_contracting_partitions;
3067 
3068   HloSharding inner_output_sharding = HloSharding::Replicate();
3069   HloSharding outer_output_tmp_sharding = HloSharding::Replicate();
3070   std::vector<int64_t> output_slice_dims;
3071   std::tie(inner_output_sharding, outer_output_tmp_sharding) =
3072       GetDotGroupPartitionContractingOutputShardings(
3073           dims_mapping, lhs_grouped, output_base_shape, output_sharding,
3074           group_count, output_lhs_non_contracting_partitions,
3075           output_rhs_non_contracting_partitions, output_batch_partitions,
3076           &output_slice_dims);
3077   Shape inner_output_base_shape = output_base_shape;
3078   if (!output_slice_dims.empty()) {
3079     std::vector<int64_t> non_group_dims;
3080     for (int64_t i = 0; i < output_base_shape.rank(); ++i) {
3081       if (!absl::c_linear_search(output_slice_dims, i)) {
3082         non_group_dims.push_back(i);
3083       }
3084     }
3085     inner_output_base_shape = MakePartitionedShape(
3086         output_base_shape,
3087         hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(
3088             output_sharding, non_group_dims));
3089   }
3090   int64_t new_output_lhs_non_contracting_partitions = 1;
3091   int64_t new_output_rhs_non_contracting_partitions = 1;
3092   if (!inner_output_sharding.IsTileMaximal()) {
3093     for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
3094       new_output_lhs_non_contracting_partitions *=
3095           inner_output_sharding.tile_assignment().dim(dim.output);
3096     }
3097     for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
3098       if (dim.output != -1) {
3099         new_output_rhs_non_contracting_partitions *=
3100             inner_output_sharding.tile_assignment().dim(dim.output);
3101       }
3102     }
3103   }
3104 
3105   const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
3106       dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
3107       inner_output_base_shape.rank());
3108   std::optional<HloSharding> output_sharding_transposed_to_match_lhs =
3109       hlo_sharding_util::TransposeShardingWithCollapsedDims(
3110           inner_output_sharding, indices_map.output_to_lhs_indices,
3111           indices_map.lhs_to_output_indices);
3112   std::optional<HloSharding> output_sharding_transposed_to_match_rhs =
3113       hlo_sharding_util::TransposeShardingWithCollapsedDims(
3114           inner_output_sharding, indices_map.output_to_rhs_indices,
3115           indices_map.rhs_to_output_indices);
3116   auto lhs_sharding_transposed_to_match_rhs =
3117       hlo_sharding_util::TransposeShardingWithCollapsedDims(
3118           lhs_sharding, indices_map.lhs_to_rhs_indices,
3119           indices_map.rhs_to_lhs_indices);
3120   auto rhs_sharding_transposed_to_match_lhs =
3121       hlo_sharding_util::TransposeShardingWithCollapsedDims(
3122           rhs_sharding, indices_map.rhs_to_lhs_indices,
3123           indices_map.lhs_to_rhs_indices);
3124   std::optional<WindowedEinsumConfig> e_config = GetWindowedEinsumConfiguration(
3125       new_num_partitions, new_output_lhs_non_contracting_partitions,
3126       new_output_rhs_non_contracting_partitions, 1,
3127       rhs_non_contracting_partitions, rhs_batch_partitions, 1,
3128       lhs_non_contracting_partitions, lhs_batch_partitions,
3129       ShapeSizeInBytes(GetPerGroupBaseShape(rhs_grouped, rhs.base_shape())),
3130       ShapeSizeInBytes(GetPerGroupBaseShape(lhs_grouped, lhs.base_shape())),
3131       ShapeSizeInBytes(inner_output_base_shape), options,
3132       output_sharding_transposed_to_match_lhs,
3133       output_sharding_transposed_to_match_rhs,
3134       lhs_sharding_transposed_to_match_rhs,
3135       rhs_sharding_transposed_to_match_lhs, lhs_grouped.sharding,
3136       rhs_grouped.sharding, conv_window, dims_mapping);
3137   if (!e_config) {
3138     return false;
3139   }
3140 
3141   int64_t num_iterations = lhs_matching_iterations ? *lhs_matching_iterations
3142                                                    : *rhs_matching_iterations;
3143   HloInstruction* other_hlo = lhs_matching_iterations ? rhs.hlo() : lhs.hlo();
3144   auto other_non_contracting_dims = lhs_matching_iterations
3145                                         ? dims_mapping.rhs_non_contracting_dims
3146                                         : dims_mapping.lhs_non_contracting_dims;
3147   auto other_sharding =
3148       lhs_matching_iterations ? rhs.sharding() : lhs.sharding();
3149   auto other_grouped = lhs_matching_iterations ? rhs_grouped : lhs_grouped;
3150   Shape other_base_shape =
3151       lhs_matching_iterations ? rhs.base_shape() : lhs.base_shape();
3152 
3153   const int64_t all_gather_bytes =
3154       ShapeUtil::ByteSizeOf(other_hlo->shape()) * new_num_partitions;
3155   const int64_t reduce_scatter_bytes =
3156       ShapeUtil::ByteSizeOf(inner_output_base_shape) / new_num_partitions *
3157       num_iterations;
3158   std::vector<int64_t> ag_replication_dims;
3159   ag_replication_dims.reserve(other_non_contracting_dims.size());
3160   for (const DotConvDimsMapping::DimsMapping& dim :
3161        other_non_contracting_dims) {
3162     ag_replication_dims.push_back(lhs_matching_iterations ? dim.rhs : dim.lhs);
3163   }
3164   auto all_gather_subgroups =
3165       GetPartitionGroupsForReplication(other_sharding, ag_replication_dims);
3166   auto reduce_scatter_subgroups = GetPartitionGroupsForReplication(
3167       outer_output_tmp_sharding, output_slice_dims);
3168   const double all_gather_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
3169       all_gather_bytes, visitor->CreateReplicaGroups(all_gather_subgroups));
3170   const double reduce_scatter_time_in_ms =
3171       visitor->GetCommunicationTimeInMilliSec(
3172           reduce_scatter_bytes,
3173           visitor->CreateReplicaGroups(reduce_scatter_subgroups));
3174 
3175   Shape other_original_shape = other_hlo->shape();
3176   *other_hlo->mutable_shape() =
3177       GetPerGroupBaseShape(other_grouped, other_base_shape);
3178   HloInstruction* dot =
3179       create_sharded_dot(lhs_matching_iterations ? lhs.hlo() : other_hlo,
3180                          lhs_matching_iterations ? other_hlo : rhs.hlo(), b,
3181                          conv_window)
3182           .ValueOrDie();
3183   const double computation_time_in_ms =
3184       visitor->GetComputationTimeInMilliSec(dot);
3185   *other_hlo->mutable_shape() = other_original_shape;
3186 
3187   VLOG(2) << "lhs: " << lhs.hlo()->ToString() << "\n"
3188           << "rhs: " << rhs.hlo()->ToString() << "\n"
3189           << "new_num_partitions: " << new_num_partitions
3190           << " num_iterations: " << num_iterations << "\n"
3191           << "all_gather_bytes: " << all_gather_bytes
3192           << " reduce_scatter_bytes: " << reduce_scatter_bytes << "\n"
3193           << "all_gather_time_in_ms: " << all_gather_time_in_ms
3194           << " reduce_scatter_time_in_ms: " << reduce_scatter_time_in_ms << "\n"
3195           << "dot: " << dot->ToString() << "\n"
3196           << "computation_time_in_ms: " << computation_time_in_ms;
3197   if (computation_time_in_ms == 0.0 || all_gather_time_in_ms == 0.0 ||
3198       reduce_scatter_time_in_ms == 0.0) {
3199     const int64_t min_nc_iterations = std::min(
3200         lhs_matching_iterations ? *lhs_matching_iterations : INT64_MAX,
3201         rhs_matching_iterations ? *rhs_matching_iterations : INT64_MAX);
3202     return min_nc_iterations > new_num_partitions;
3203   } else if ((computation_time_in_ms <= all_gather_time_in_ms) &&
3204              (computation_time_in_ms <= reduce_scatter_time_in_ms)) {
3205     return all_gather_bytes / new_num_partitions <
3206            reduce_scatter_bytes / num_iterations;
3207   } else {
3208     return all_gather_time_in_ms > reduce_scatter_time_in_ms;
3209   }
3210 }
3211 
3212 // Return if it would be better to match the LHS operand or RHS operand
3213 // of a dot for non-contracting partitioning.
LhsIsBestMatchForNonContractingPartitioning(const DotConvDimsMapping & dims_mapping,const PartitionedHlo & lhs,const PartitionedHlo & rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const SpmdPartitionerOptions & options,int64_t num_partitions,int64_t lhs_non_contracting_partitions,int64_t rhs_non_contracting_partitions,int64_t lhs_matching_partitions,int64_t rhs_matching_partitions,int64_t lhs_contracting_partitions,int64_t rhs_contracting_partitions,int64_t output_lhs_non_contracting_partitions,int64_t output_rhs_non_contracting_partitions,int64_t lhs_batch_partitions,int64_t rhs_batch_partitions,SpmdBuilder * b,const Window & conv_window,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,SpmdPartitioningVisitor * visitor)3214 bool LhsIsBestMatchForNonContractingPartitioning(
3215     const DotConvDimsMapping& dims_mapping, const PartitionedHlo& lhs,
3216     const PartitionedHlo& rhs, const Shape& output_base_shape,
3217     const HloSharding& output_sharding, const SpmdPartitionerOptions& options,
3218     int64_t num_partitions, int64_t lhs_non_contracting_partitions,
3219     int64_t rhs_non_contracting_partitions, int64_t lhs_matching_partitions,
3220     int64_t rhs_matching_partitions, int64_t lhs_contracting_partitions,
3221     int64_t rhs_contracting_partitions,
3222     int64_t output_lhs_non_contracting_partitions,
3223     int64_t output_rhs_non_contracting_partitions, int64_t lhs_batch_partitions,
3224     int64_t rhs_batch_partitions, SpmdBuilder* b, const Window& conv_window,
3225     const std::function<StatusOr<HloInstruction*>(
3226         HloInstruction*, HloInstruction*, SpmdBuilder*,
3227         const Window& conv_window)>& create_sharded_dot,
3228     SpmdPartitioningVisitor* visitor) {
3229   const bool may_group_on_lhs_non_contracting =
3230       lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
3231       lhs_non_contracting_partitions > 1;
3232   const bool may_group_on_rhs_non_contracting =
3233       rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
3234       rhs_non_contracting_partitions > 1;
3235   // If both match output non-contracting dimensions, choose the one which
3236   // will result in smaller replication of the other operand.
3237   bool lhs_matching = may_group_on_lhs_non_contracting &&
3238                       (!may_group_on_rhs_non_contracting ||
3239                        lhs_non_contracting_partitions *
3240                                ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <
3241                            rhs_non_contracting_partitions *
3242                                ShapeUtil::ByteSizeOf(lhs.hlo()->shape()));
3243   // If both grouping are available and the option to choose faster windowed
3244   // einsums vs saving memory is enabled then try to determine which of the
3245   // operands will have more overlapping benefits for the windowed einsum
3246   // when matched (if a windowed einsum is gonna be generated at all).
3247   // 1) When computation is shorter than both all_gathers, we choose to overlap
3248   // with the smaller all_gather as it has potentially smaller extra
3249   // collective-permute overhead outside of the while loop; 2) Otherwise, we
3250   // choose the all_gather with longer runtime to overlap with.
3251   if (may_group_on_lhs_non_contracting && may_group_on_rhs_non_contracting &&
3252       options.choose_faster_windowed_einsum_over_mem) {
3253     const DotDimensionIndexMapping indices_map = ComputeDimensionIndexMapping(
3254         dims_mapping, lhs.base_shape().rank(), rhs.base_shape().rank(),
3255         output_base_shape.rank());
3256     std::optional<int64_t> lhs_matching_iterations;
3257     std::optional<int64_t> rhs_matching_iterations;
3258     std::tie(lhs_matching_iterations, rhs_matching_iterations) =
3259         EstimateWindowedEinsumIterationsForNonContractingPartitioning(
3260             dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3261             num_partitions, lhs_non_contracting_partitions,
3262             rhs_non_contracting_partitions, lhs_matching_partitions,
3263             rhs_matching_partitions, lhs_contracting_partitions,
3264             rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3265             output_rhs_non_contracting_partitions, lhs_batch_partitions,
3266             rhs_batch_partitions, conv_window);
3267     if (lhs_matching_iterations && rhs_matching_iterations) {
3268       const int64_t lhs_all_gather_bytes =
3269           ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) *
3270           rhs_non_contracting_partitions;
3271       const int64_t rhs_all_gather_bytes =
3272           ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) *
3273           lhs_non_contracting_partitions;
3274       auto lhs_grouped =
3275           GetNonContractingPartitionGroupedShardingForMatchedOperand(
3276               /*lhs_matching=*/true, lhs.sharding(), output_sharding,
3277               dims_mapping.lhs_non_contracting_dims);
3278       auto lhs_all_gather_subgroups = lhs_grouped.device_groups;
3279       auto rhs_grouped =
3280           GetNonContractingPartitionGroupedShardingForMatchedOperand(
3281               /*lhs_matching=*/false, rhs.sharding(), output_sharding,
3282               dims_mapping.rhs_non_contracting_dims);
3283       auto rhs_all_gather_subgroups = rhs_grouped.device_groups;
3284       const double lhs_all_gather_time_in_ms =
3285           visitor->GetCommunicationTimeInMilliSec(
3286               lhs_all_gather_bytes,
3287               visitor->CreateReplicaGroups(lhs_all_gather_subgroups));
3288       const double rhs_all_gather_time_in_ms =
3289           visitor->GetCommunicationTimeInMilliSec(
3290               rhs_all_gather_bytes,
3291               visitor->CreateReplicaGroups(rhs_all_gather_subgroups));
3292 
3293       HloInstruction* compute_lhs = lhs.hlo();
3294       Shape lhs_original_shape = compute_lhs->shape();
3295       *compute_lhs->mutable_shape() =
3296           GetPerGroupBaseShape(lhs_grouped, lhs.base_shape());
3297       HloInstruction* compute_rhs = rhs.hlo();
3298       Shape rhs_original_shape = compute_rhs->shape();
3299       *compute_rhs->mutable_shape() =
3300           GetPerGroupBaseShape(rhs_grouped, rhs.base_shape());
3301       HloInstruction* dot =
3302           create_sharded_dot(compute_lhs, compute_rhs, b, conv_window)
3303               .ValueOrDie();
3304       const double computation_time_in_ms =
3305           visitor->GetComputationTimeInMilliSec(dot);
3306       *compute_lhs->mutable_shape() = lhs_original_shape;
3307       *compute_rhs->mutable_shape() = rhs_original_shape;
3308 
3309       VLOG(2) << "lhs: " << lhs.hlo()->ToString() << "\n"
3310               << "rhs: " << rhs.hlo()->ToString() << "\n"
3311               << "lhs_non_contracting_partitions: "
3312               << lhs_non_contracting_partitions
3313               << " rhs_non_contracting_partitions: "
3314               << rhs_non_contracting_partitions << "\n"
3315               << "lhs_matching_iterations: " << *lhs_matching_iterations
3316               << " rhs_matching_iterations: " << *rhs_matching_iterations
3317               << "\n"
3318               << "lhs_all_gather_bytes: " << lhs_all_gather_bytes
3319               << " rhs_all_gather_bytes: " << rhs_all_gather_bytes << "\n"
3320               << "lhs_all_gather_time_in_ms: " << lhs_all_gather_time_in_ms
3321               << " rhs_all_gather_time_in_ms: " << rhs_all_gather_time_in_ms
3322               << "\n"
3323               << "dot: " << dot->ToString() << "\n"
3324               << "computation_time_in_ms: " << computation_time_in_ms;
3325       if (computation_time_in_ms == 0.0 || lhs_all_gather_time_in_ms == 0.0 ||
3326           rhs_all_gather_time_in_ms == 0.0) {
3327         lhs_matching = *lhs_matching_iterations < *rhs_matching_iterations;
3328       } else if ((computation_time_in_ms <= lhs_all_gather_time_in_ms) &&
3329                  (computation_time_in_ms <= rhs_all_gather_time_in_ms)) {
3330         lhs_matching = lhs_all_gather_bytes / rhs_non_contracting_partitions >
3331                        rhs_all_gather_bytes / lhs_non_contracting_partitions;
3332       } else {
3333         lhs_matching = lhs_all_gather_time_in_ms > rhs_all_gather_time_in_ms;
3334       }
3335     }
3336   }
3337   return lhs_matching;
3338 }
3339 
3340 // Recursive partitioning function. If there are partial dimensions matching
3341 // in the operands and output, group the devices and recursively partition
3342 // the in-group dot.
PartitionDot(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64_t num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,bool require_matching_devices_to_group,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,SpmdPartitioningVisitor * visitor)3343 StatusOr<HloInstruction*> PartitionDot(
3344     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
3345     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
3346     int64_t num_partitions,
3347     const std::function<StatusOr<HloInstruction*>(
3348         HloInstruction*, HloInstruction*, SpmdBuilder*,
3349         const Window& conv_window)>& create_sharded_dot,
3350     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
3351     bool require_matching_devices_to_group,
3352     const SpmdPartitionerOptions& options, SpmdBuilder* b,
3353     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
3354         windowed_dot_general_loops,
3355     SpmdPartitioningVisitor* visitor) {
3356   // If lhs‘ hlo and rhs' hlo are identical, make a copy for rhs.
3357   if (lhs.hlo() == rhs.hlo()) {
3358     auto copy_hlo = b->AddInstruction(HloInstruction::CreateUnary(
3359         rhs.hlo()->shape(), HloOpcode::kCopy, rhs.hlo()));
3360     copy_hlo->set_sharding(rhs.sharding());
3361     rhs = PartitionedHlo(copy_hlo, rhs.base_shape(), rhs.state());
3362   }
3363 
3364   // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output.
3365   auto get_partitions_for_dims =
3366       [&](const HloSharding& sharding,
3367           absl::Span<const DotConvDimsMapping::DimsMapping> dims,
3368           int lhs_rhs_or_output) {
3369         int64_t partitions = 1;
3370         if (sharding.IsTileMaximal()) {
3371           return partitions;
3372         }
3373         for (const auto& dim : dims) {
3374           if (lhs_rhs_or_output == 0) {
3375             partitions *= sharding.tile_assignment().dim(dim.lhs);
3376           } else if (lhs_rhs_or_output == 1) {
3377             partitions *= sharding.tile_assignment().dim(dim.rhs);
3378           } else {
3379             CHECK_EQ(lhs_rhs_or_output, 2);
3380             partitions *= sharding.tile_assignment().dim(dim.output);
3381           }
3382         }
3383         return partitions;
3384       };
3385   const int64_t lhs_batch_partitions =
3386       get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0);
3387   const int64_t rhs_batch_partitions =
3388       get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1);
3389   const int64_t output_batch_partitions =
3390       get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2);
3391   const int64_t lhs_contracting_partitions =
3392       get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0);
3393   const int64_t rhs_contracting_partitions =
3394       get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1);
3395   const int64_t lhs_non_contracting_partitions = get_partitions_for_dims(
3396       lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0);
3397   const int64_t rhs_non_contracting_partitions = get_partitions_for_dims(
3398       rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1);
3399   const int64_t output_lhs_non_contracting_partitions = get_partitions_for_dims(
3400       output_sharding, dims_mapping.lhs_non_contracting_dims, 2);
3401   const int64_t output_rhs_non_contracting_partitions = get_partitions_for_dims(
3402       output_sharding, dims_mapping.rhs_non_contracting_dims, 2);
3403   const int64_t lhs_conv_spatial_partitions = get_partitions_for_dims(
3404       lhs.sharding(), dims_mapping.conv_spatial_dims, 0);
3405   const int64_t rhs_conv_spatial_partitions = get_partitions_for_dims(
3406       rhs.sharding(), dims_mapping.conv_spatial_dims, 1);
3407   const int64_t output_conv_spatial_partitions = get_partitions_for_dims(
3408       output_sharding, dims_mapping.conv_spatial_dims, 2);
3409   // Before we find partial matches along the dimensions, invoke base case
3410   // again without may_reshard_without_detecting_match.
3411 
3412   // Try partition the purely spatially-partitioned convolution with
3413   // convolution spatial dimension partitioned or depthwise parallel
3414   // dimension partitioned.
3415   bool is_conv_spatial_dim_partitioned =
3416       (lhs_conv_spatial_partitions > 1 || rhs_conv_spatial_partitions > 1 ||
3417        output_conv_spatial_partitions > 1);
3418   bool is_conv_batch_or_contracting_dim_partitioned =
3419       (lhs_batch_partitions > 1 || rhs_batch_partitions > 1 ||
3420        output_batch_partitions > 1 ||
3421        (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1));
3422   if ((!dims_mapping.conv_spatial_dims.empty() &&
3423        is_conv_spatial_dim_partitioned &&
3424        !is_conv_batch_or_contracting_dim_partitioned) ||
3425       (original_hlo->opcode() == HloOpcode::kConvolution &&
3426        (original_hlo->batch_group_count() > 1 ||
3427         original_hlo->feature_group_count() > 1))) {
3428     // Partition with kernel_input_feature_dim > 1 and feature_group_count >
3429     // 1 is not supported.
3430     const auto& dnums = original_hlo->convolution_dimension_numbers();
3431     if (original_hlo->feature_group_count() > 1 &&
3432         rhs.hlo()->shape().dimensions(dnums.kernel_input_feature_dimension()) >
3433             1) {
3434       return nullptr;
3435     }
3436 
3437     TF_ASSIGN_OR_RETURN(
3438         auto partitioned_conv,
3439         PartitionConvolution(lhs, rhs, output_base_shape, output_sharding,
3440                              dims_mapping, create_sharded_dot, conv_window,
3441                              original_hlo, num_partitions, options,
3442                              lhs.state().partition_id, module, b));
3443 
3444     if (partitioned_conv) {
3445       return partitioned_conv;
3446     }
3447 
3448     // Recursively partition on different types of dimensions for
3449     // convolution. Case 0.a: Group partitions by feature group count.
3450     if (original_hlo->feature_group_count() > 1 ||
3451         original_hlo->batch_group_count() > 1) {
3452       std::optional<DotConvDimsMapping> new_dims_mapping;
3453       if (original_hlo->feature_group_count() > 1) {
3454         const int64_t input_feature_dim =
3455             original_hlo->convolution_dimension_numbers()
3456                 .input_feature_dimension();
3457         const int64_t kernel_output_feature_dim =
3458             original_hlo->convolution_dimension_numbers()
3459                 .kernel_output_feature_dimension();
3460         // If the input and output feature dims are not equal, we require the
3461         // feature_group_count to be evenly partitioned; otherwise, there will
3462         // be different padding in the input/output.
3463         // TODO(xla): Use halo exchange to solve this problem. Can be a
3464         // preprocessing that uses padding/slicing to make the shape evenly
3465         // shardable.
3466         if (lhs.base_shape().dimensions(input_feature_dim) ==
3467                 rhs.base_shape().dimensions(kernel_output_feature_dim) ||
3468             (lhs.sharding().IsTiled() &&
3469              original_hlo->feature_group_count() %
3470                      ShardCountAtDim(lhs.sharding(), input_feature_dim) ==
3471                  0)) {
3472           new_dims_mapping = ConvertDimsMappingWithFeatureGroupCount(
3473               dims_mapping, original_hlo);
3474         }
3475       }
3476 
3477       if (original_hlo->batch_group_count() > 1) {
3478         const int64_t input_batch_dim =
3479             original_hlo->convolution_dimension_numbers()
3480                 .input_batch_dimension();
3481         const int64_t kernel_output_feature_dim =
3482             original_hlo->convolution_dimension_numbers()
3483                 .kernel_output_feature_dimension();
3484         if (lhs.base_shape().dimensions(input_batch_dim) ==
3485                 rhs.base_shape().dimensions(kernel_output_feature_dim) ||
3486             (lhs.sharding().IsTiled() &&
3487              original_hlo->batch_group_count() %
3488                      ShardCountAtDim(lhs.sharding(), input_batch_dim) ==
3489                  0)) {
3490           new_dims_mapping =
3491               ConvertDimsMappingWithBatchGroupCount(dims_mapping, original_hlo);
3492         }
3493       }
3494       if (!new_dims_mapping.has_value()) {
3495         return nullptr;
3496       }
3497 
3498       const int64_t conv_lhs_contracting_partitions = get_partitions_for_dims(
3499           lhs.sharding(), new_dims_mapping->contracting_dims, 0);
3500       const int64_t conv_rhs_contracting_partitions = get_partitions_for_dims(
3501           rhs.sharding(), new_dims_mapping->contracting_dims, 1);
3502       const int64_t conv_lhs_non_contracting_partitions =
3503           get_partitions_for_dims(
3504               lhs.sharding(), new_dims_mapping->lhs_non_contracting_dims, 0);
3505       const int64_t conv_rhs_non_contracting_partitions =
3506           get_partitions_for_dims(
3507               rhs.sharding(), new_dims_mapping->rhs_non_contracting_dims, 1);
3508       const int64_t conv_lhs_batch_partitions = get_partitions_for_dims(
3509           lhs.sharding(), new_dims_mapping->batch_dims, 0);
3510       const int64_t conv_rhs_batch_partitions = get_partitions_for_dims(
3511           rhs.sharding(), new_dims_mapping->batch_dims, 1);
3512       const int64_t conv_output_batch_partitions = get_partitions_for_dims(
3513           output_sharding, new_dims_mapping->batch_dims, 2);
3514       if ((conv_lhs_batch_partitions == conv_output_batch_partitions ||
3515            conv_rhs_batch_partitions == conv_output_batch_partitions) &&
3516           conv_output_batch_partitions > 1) {
3517         TF_ASSIGN_OR_RETURN(
3518             auto try_partitioned_conv,
3519             PartitionDotGroupOnBatch(
3520                 lhs, rhs, output_base_shape, output_sharding, *new_dims_mapping,
3521                 num_partitions, conv_lhs_contracting_partitions,
3522                 conv_rhs_contracting_partitions,
3523                 conv_lhs_non_contracting_partitions,
3524                 conv_rhs_non_contracting_partitions, create_sharded_dot,
3525                 conv_window, module, original_hlo,
3526                 require_matching_devices_to_group, options, b,
3527                 windowed_dot_general_loops, visitor));
3528         if (try_partitioned_conv) {
3529           return try_partitioned_conv;
3530         }
3531       }
3532       return nullptr;
3533     }
3534   }
3535 
3536   TF_ASSIGN_OR_RETURN(
3537       auto try_partitioned_dot,
3538       PartitionBaseCase(
3539           lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3540           num_partitions, create_sharded_dot, conv_window, module, original_hlo,
3541           lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
3542           lhs_contracting_partitions, rhs_contracting_partitions,
3543           lhs_non_contracting_partitions, rhs_non_contracting_partitions,
3544           output_lhs_non_contracting_partitions,
3545           output_rhs_non_contracting_partitions, options, b,
3546           windowed_dot_general_loops,
3547           /*may_reshard_without_detecting_match=*/false, visitor));
3548   if (try_partitioned_dot) {
3549     return try_partitioned_dot;
3550   }
3551 
3552   // Recursively partition on different types of dimensions.
3553   //
3554   // Case 1: Group partitions by batch.
3555   if ((lhs_batch_partitions == output_batch_partitions ||
3556        rhs_batch_partitions == output_batch_partitions) &&
3557       output_batch_partitions > 1) {
3558     TF_ASSIGN_OR_RETURN(
3559         auto dot,
3560         PartitionDotGroupOnBatch(
3561             lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3562             num_partitions, lhs_contracting_partitions,
3563             rhs_contracting_partitions, lhs_non_contracting_partitions,
3564             rhs_non_contracting_partitions, create_sharded_dot, conv_window,
3565             module, original_hlo, require_matching_devices_to_group, options, b,
3566             windowed_dot_general_loops, visitor));
3567     if (dot) {
3568       return dot;
3569     }
3570   }
3571 
3572   // Case 2: Group partitions by non-contracting dimensions.
3573   const bool may_group_on_lhs_non_contracting =
3574       lhs_non_contracting_partitions == output_lhs_non_contracting_partitions &&
3575       lhs_non_contracting_partitions > 1;
3576   const bool may_group_on_rhs_non_contracting =
3577       rhs_non_contracting_partitions == output_rhs_non_contracting_partitions &&
3578       rhs_non_contracting_partitions > 1;
3579   bool lhs_matching = false;
3580   std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
3581   if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) {
3582     lhs_matching = LhsIsBestMatchForNonContractingPartitioning(
3583         dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3584         num_partitions, lhs_non_contracting_partitions,
3585         rhs_non_contracting_partitions, lhs_non_contracting_partitions,
3586         rhs_non_contracting_partitions, lhs_contracting_partitions,
3587         rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3588         output_rhs_non_contracting_partitions, lhs_batch_partitions,
3589         rhs_batch_partitions, b, conv_window, create_sharded_dot, visitor);
3590     matching_dims = lhs_matching ? dims_mapping.lhs_non_contracting_dims
3591                                  : dims_mapping.rhs_non_contracting_dims;
3592   } else if (lhs_non_contracting_partitions > 1 &&
3593              output_lhs_non_contracting_partitions > 1) {
3594     lhs_matching = true;
3595     for (const auto& dim : dims_mapping.lhs_non_contracting_dims) {
3596       int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
3597       if (lhs_partitions > 1 &&
3598           lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
3599         matching_dims.push_back(dim);
3600       }
3601     }
3602   } else if (rhs_non_contracting_partitions > 1 &&
3603              output_rhs_non_contracting_partitions > 1) {
3604     lhs_matching = false;
3605     for (const auto& dim : dims_mapping.rhs_non_contracting_dims) {
3606       int64_t rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs);
3607       if (rhs_partitions > 1 &&
3608           rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) {
3609         matching_dims.push_back(dim);
3610       }
3611     }
3612   }
3613   const bool prioritize_contracting_for_faster_windowed_einsum =
3614       PrioritizeContractingDimensionsPartitioning(
3615           dims_mapping, lhs, rhs, output_base_shape, output_sharding, options,
3616           num_partitions, lhs_non_contracting_partitions,
3617           rhs_non_contracting_partitions, lhs_contracting_partitions,
3618           rhs_contracting_partitions, output_lhs_non_contracting_partitions,
3619           output_rhs_non_contracting_partitions, lhs_batch_partitions,
3620           rhs_batch_partitions, output_batch_partitions,
3621           require_matching_devices_to_group, b, conv_window, create_sharded_dot,
3622           visitor);
3623   if (!(matching_dims.empty() ||
3624         prioritize_contracting_for_faster_windowed_einsum)) {
3625     TF_ASSIGN_OR_RETURN(
3626         auto dot,
3627         PartitionDotGroupOnNonContracting(
3628             lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs,
3629             lhs_matching ? lhs_contracting_partitions
3630                          : rhs_contracting_partitions,
3631             lhs_matching ? rhs_contracting_partitions
3632                          : lhs_contracting_partitions,
3633             matching_dims,
3634             lhs_matching ? rhs_non_contracting_partitions
3635                          : lhs_non_contracting_partitions,
3636             lhs_matching ? output_rhs_non_contracting_partitions
3637                          : output_lhs_non_contracting_partitions,
3638             output_base_shape, output_sharding, dims_mapping, num_partitions,
3639             create_sharded_dot, conv_window, module, original_hlo,
3640             require_matching_devices_to_group, options, b,
3641             windowed_dot_general_loops, visitor));
3642     if (dot) {
3643       return dot;
3644     }
3645   }
3646 
3647   // Case 3: Group partitions by contracting dimensions.
3648   if (lhs_contracting_partitions == rhs_contracting_partitions &&
3649       lhs_contracting_partitions > 1) {
3650     TF_ASSIGN_OR_RETURN(
3651         auto dot,
3652         PartitionDotGroupOnContracting(
3653             lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions,
3654             output_lhs_non_contracting_partitions,
3655             output_rhs_non_contracting_partitions, output_base_shape,
3656             output_sharding, dims_mapping, num_partitions, create_sharded_dot,
3657             conv_window, module, original_hlo,
3658             require_matching_devices_to_group, options, b,
3659             windowed_dot_general_loops, visitor));
3660     if (dot) {
3661       return dot;
3662     }
3663   }
3664   if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) {
3665     // If part of contracting dims match, try them.
3666     std::vector<DotConvDimsMapping::DimsMapping> matching_dims;
3667     for (const auto& dim : dims_mapping.contracting_dims) {
3668       int64_t lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs);
3669       if (lhs_partitions > 1 &&
3670           lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) {
3671         matching_dims.push_back(dim);
3672       }
3673     }
3674     if (!matching_dims.empty()) {
3675       TF_ASSIGN_OR_RETURN(
3676           auto dot, PartitionDotGroupOnContracting(
3677                         lhs, rhs, matching_dims, output_batch_partitions,
3678                         output_lhs_non_contracting_partitions,
3679                         output_rhs_non_contracting_partitions,
3680                         output_base_shape, output_sharding, dims_mapping,
3681                         num_partitions, create_sharded_dot, conv_window, module,
3682                         original_hlo, require_matching_devices_to_group,
3683                         options, b, windowed_dot_general_loops, visitor));
3684       if (dot) {
3685         return dot;
3686       }
3687     }
3688   }
3689 
3690   // Case 4: If operands are replicated but output is partially replicated,
3691   // recursive call with partial replication removed.
3692   if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() &&
3693       output_sharding.ReplicateOnLastTileDim()) {
3694     auto grouped_output = hlo_sharding_util::GroupShardingOnDims(
3695         output_sharding, {output_base_shape.rank()});
3696     auto inner_state = CreatePerGroupPartitioningState(
3697         lhs.state(), grouped_output.device_groups, b);
3698     TF_ASSIGN_OR_RETURN(
3699         auto dot,
3700         PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state),
3701                      PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state),
3702                      output_base_shape, grouped_output.sharding, dims_mapping,
3703                      output_sharding.NumTiles(), create_sharded_dot,
3704                      conv_window, module, original_hlo, options, b,
3705                      windowed_dot_general_loops, visitor));
3706     if (dot) {
3707       return dot;
3708     }
3709   }
3710 
3711   // We failed to find partial matches, invoke base case again with
3712   // may_reshard_without_detecting_match.
3713   TF_ASSIGN_OR_RETURN(
3714       auto dot,
3715       PartitionBaseCase(
3716           lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3717           num_partitions, create_sharded_dot, conv_window, module, original_hlo,
3718           lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions,
3719           lhs_contracting_partitions, rhs_contracting_partitions,
3720           lhs_non_contracting_partitions, rhs_non_contracting_partitions,
3721           output_lhs_non_contracting_partitions,
3722           output_rhs_non_contracting_partitions, options, b,
3723           windowed_dot_general_loops,
3724           /*may_reshard_without_detecting_match=*/true, visitor));
3725   if (dot) {
3726     return dot;
3727   }
3728   return nullptr;
3729 }
3730 
PartitionDot(PartitionedHlo lhs,PartitionedHlo rhs,const Shape & output_base_shape,const HloSharding & output_sharding,const DotConvDimsMapping & dims_mapping,int64_t num_partitions,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot,const Window & conv_window,HloModule * module,HloInstruction * original_hlo,const SpmdPartitionerOptions & options,SpmdBuilder * b,std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop> * windowed_dot_general_loops,SpmdPartitioningVisitor * visitor)3731 StatusOr<HloInstruction*> PartitionDot(
3732     PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
3733     const HloSharding& output_sharding, const DotConvDimsMapping& dims_mapping,
3734     int64_t num_partitions,
3735     const std::function<StatusOr<HloInstruction*>(
3736         HloInstruction*, HloInstruction*, SpmdBuilder*,
3737         const Window& conv_window)>& create_sharded_dot,
3738     const Window& conv_window, HloModule* module, HloInstruction* original_hlo,
3739     const SpmdPartitionerOptions& options, SpmdBuilder* b,
3740     std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
3741         windowed_dot_general_loops,
3742     SpmdPartitioningVisitor* visitor) {
3743   // First try partitioning without resharding the groups, then try allow
3744   // resharding the groups.
3745   for (bool require_matching_devices_to_group : {true, false}) {
3746     TF_ASSIGN_OR_RETURN(
3747         auto try_partition,
3748         PartitionDot(lhs, rhs, output_base_shape, output_sharding, dims_mapping,
3749                      num_partitions, create_sharded_dot, conv_window, module,
3750                      original_hlo, require_matching_devices_to_group, options,
3751                      b, windowed_dot_general_loops, visitor));
3752     if (try_partition) {
3753       return try_partition;
3754     }
3755   }
3756 
3757   // Default action.
3758   TF_ASSIGN_OR_RETURN(
3759       auto dot, create_sharded_dot(lhs.Replicate().hlo(), rhs.Replicate().hlo(),
3760                                    b, conv_window));
3761   dot->set_sharding(HloSharding::Replicate());
3762   return PartitionedHlo(dot, output_base_shape, lhs.state())
3763       .Reshard(output_sharding)
3764       .hlo();
3765 }
3766 
3767 }  // namespace
3768 
HandleDotHelper(HloInstruction * hlo,const DotConvDimsMapping & dims_mapping,const std::function<StatusOr<HloInstruction * > (HloInstruction *,HloInstruction *,SpmdBuilder *,const Window & conv_window)> & create_sharded_dot)3769 Status SpmdPartitioningVisitor::HandleDotHelper(
3770     HloInstruction* hlo, const DotConvDimsMapping& dims_mapping,
3771     const std::function<StatusOr<HloInstruction*>(
3772         HloInstruction*, HloInstruction*, SpmdBuilder*,
3773         const Window& conv_window)>& create_sharded_dot) {
3774   if (hlo->sharding().HasUniqueDevice()) {
3775     return DefaultAction(hlo);
3776   }
3777   auto& lhs = GetPartitionedHlo(hlo->operand(0));
3778   auto& rhs = GetPartitionedHlo(hlo->operand(1));
3779   Window conv_window;
3780   if (hlo->opcode() == HloOpcode::kConvolution) {
3781     conv_window = hlo->window();
3782   }
3783 
3784   TF_ASSIGN_OR_RETURN(
3785       auto partitioned_dot,
3786       PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping,
3787                    num_partitions_, create_sharded_dot, conv_window, module_,
3788                    hlo, options_, &b_, &windowed_dot_general_loops_, this));
3789   SetPartitionedHlo(hlo, [&] { return partitioned_dot; });
3790   return OkStatus();
3791 }
3792 
3793 namespace {
3794 
3795 // Finds a cluster of nodes that produce the inputs for `hlo` which only
3796 // depend on small operands, which means the cluster should start with
3797 // broadcasts, constants and iotas. All other internal nodes must be
3798 // non-side-effecting elemntwise ops. Returns the set of nodes, and the small
3799 // operands. E.g., for the following graph,
3800 //
3801 //     a -> broadcast -> multiply
3802 //     iota  ---> add--/
3803 //     constant/
3804 //
3805 // FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return
3806 //    <{broadcast, iota, constant, add, multiply}, [a]>.
3807 std::pair<absl::flat_hash_set<HloInstruction*>, std::vector<HloInstruction*>>
FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction * hlo)3808 FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) {
3809   absl::flat_hash_set<HloInstruction*> nodes_found;
3810   std::vector<HloInstruction*> new_operands;
3811   absl::flat_hash_set<const HloInstruction*> new_operands_set;
3812   std::vector<HloInstruction*> worklist;
3813   worklist.push_back(hlo);
3814   while (!worklist.empty()) {
3815     auto inst = worklist.back();
3816     worklist.pop_back();
3817     if (nodes_found.count(inst) > 0) {
3818       continue;
3819     }
3820     if (inst->opcode() == HloOpcode::kBroadcast ||
3821         inst->opcode() == HloOpcode::kConstant ||
3822         inst->opcode() == HloOpcode::kIota) {
3823       nodes_found.insert(inst);
3824       for (auto o : inst->operands()) {
3825         auto res = new_operands_set.emplace(o);
3826         if (res.second) {
3827           new_operands.push_back(o);
3828         }
3829       }
3830     } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() &&
3831                absl::c_all_of(inst->operands(),
3832                               [inst](const HloInstruction* o) {
3833                                 return ShapeUtil::CompatibleIgnoringElementType(
3834                                     o->shape(), inst->shape());
3835                               })) {
3836       nodes_found.insert(inst);
3837       for (auto o : inst->operands()) {
3838         worklist.push_back(o);
3839       }
3840     } else {
3841       nodes_found.clear();
3842       new_operands.clear();
3843       break;
3844     }
3845   }
3846   return {std::move(nodes_found), std::move(new_operands)};
3847 }
3848 
3849 // Moves a cluster of memory-reducing nodes into the windowed dot-general loop
3850 // on contracting dimensions. Such a loop has a dynamic slice on the
3851 // non-windowed operand. If we move the input nodes into the loop, the
3852 // dynamic-slice could be merged with them by later optimization passes, which
3853 // reduces memory.
3854 //
3855 // small_operands             small_operands
3856 //        |                          |
3857 // input_nodes                loop { |
3858 //        |          =>         input_nodes
3859 // loop { |                          |
3860 //    dynamic-slice             dynamic-slice
3861 //    ...                       ...
3862 // }                          }
3863 //
3864 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3865 // with the input nodes.
SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(HloInstruction * loop,int64_t non_windowed_operand_index)3866 Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
3867     HloInstruction* loop, int64_t non_windowed_operand_index) {
3868   auto input_tuple = loop->mutable_operand(0);
3869   auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index);
3870   auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand);
3871   auto to_sink = std::move(input_nodes.first);
3872   auto new_operands = std::move(input_nodes.second);
3873   if (to_sink.empty()) {
3874     return OkStatus();
3875   }
3876   auto computation = loop->parent();
3877   // Replace the old operand with a tuple of the found small operands.
3878   auto new_input_subtuple =
3879       computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
3880   TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape(
3881       non_windowed_operand_index, new_input_subtuple));
3882 
3883   auto body = loop->while_body();
3884   auto body_param = body->parameter_instruction(0);
3885   auto old_body_param_users = body_param->users();
3886   // Update all tuple shapes.
3887   for (auto tuple : std::vector<HloInstruction*>{
3888            input_tuple, loop, loop->while_condition()->parameter_instruction(0),
3889            body_param, body->root_instruction()}) {
3890     *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(),
3891                                    {non_windowed_operand_index}) =
3892         new_input_subtuple->shape();
3893   }
3894   // Now update the loop body.
3895   auto new_operand_tuple_inside =
3896       body->AddInstruction(HloInstruction::CreateGetTupleElement(
3897           new_input_subtuple->shape(), body_param, non_windowed_operand_index));
3898   TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape(
3899       non_windowed_operand_index, new_operand_tuple_inside));
3900 
3901   // Create nodes inside the loop body.
3902   std::vector<HloInstruction*> worklist;
3903   absl::flat_hash_map<const HloInstruction*, HloInstruction*> outside_to_inside;
3904   auto add_users_if_available = [&](HloInstruction* inst) {
3905     for (auto u : inst->users()) {
3906       if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 &&
3907           absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
3908             return outside_to_inside.count(o) > 0;
3909           })) {
3910         worklist.push_back(u);
3911       }
3912     }
3913   };
3914   for (int64_t i = 0; i < new_operands.size(); ++i) {
3915     outside_to_inside[new_operands[i]] =
3916         body->AddInstruction(HloInstruction::CreateGetTupleElement(
3917             new_operands[i]->shape(), new_operand_tuple_inside, i));
3918     add_users_if_available(new_operands[i]);
3919   }
3920   // HLOs to sink without operands.
3921   std::vector<HloInstruction*> nullaries_to_sink;
3922   for (auto inst : to_sink) {
3923     if (inst->operand_count() == 0) {
3924       nullaries_to_sink.push_back(inst);
3925     }
3926   }
3927   // Sort nullaries_to_sink to make it deterministic.
3928   absl::c_sort(nullaries_to_sink,
3929                [](const HloInstruction* a, const HloInstruction* b) {
3930                  return a->unique_id() < b->unique_id();
3931                });
3932   worklist.reserve(nullaries_to_sink.size());
3933   for (auto inst : nullaries_to_sink) {
3934     worklist.push_back(inst);
3935   }
3936   while (!worklist.empty()) {
3937     auto inst = worklist.back();
3938     worklist.pop_back();
3939     std::vector<HloInstruction*> inst_new_operands(inst->operand_count());
3940     for (int64_t i = 0; i < inst->operand_count(); ++i) {
3941       inst_new_operands[i] = outside_to_inside[inst->operand(i)];
3942     }
3943     outside_to_inside[inst] = body->AddInstruction(
3944         inst->CloneWithNewOperands(inst->shape(), inst_new_operands));
3945     add_users_if_available(inst);
3946   }
3947   TF_RET_CHECK(outside_to_inside.count(old_operand) > 0);
3948   for (auto ou : old_body_param_users) {
3949     if (ou->opcode() == HloOpcode::kGetTupleElement &&
3950         ou->tuple_index() == non_windowed_operand_index) {
3951       TF_RETURN_IF_ERROR(
3952           ou->ReplaceAllUsesWith(outside_to_inside[old_operand]));
3953       TF_RETURN_IF_ERROR(body->RemoveInstruction(ou));
3954     }
3955   }
3956   return OkStatus();
3957 }
3958 
3959 // Checks a condition holds true for all recursive operands of an hlo.
CheckOperandsRecursive(const HloInstruction * hlo,std::function<bool (const HloInstruction *)> check)3960 bool CheckOperandsRecursive(const HloInstruction* hlo,
3961                             std::function<bool(const HloInstruction*)> check) {
3962   std::deque<const HloInstruction*> worklist;
3963   worklist.push_front(hlo);
3964   while (!worklist.empty()) {
3965     auto inst = worklist.back();
3966     worklist.pop_back();
3967     for (HloInstruction* operand : inst->operands()) {
3968       if (!check(operand)) {
3969         return false;
3970       }
3971       worklist.push_front(operand);
3972     }
3973   }
3974   return true;
3975 }
3976 
3977 // Moves a cluster of memory-reducing nodes (with reduce nodes at the end)
3978 // into the windowed dot-general loop on non-contracting dimensions. Such a
3979 // loop has a dynamic-update-slice at the output. If we move the user nodes
3980 // into the loop and before the dynamic-update-slice, the user nodes can
3981 // operate on smaller shapes, which reduces memory.
3982 //
3983 // small_operands                   small_operands
3984 //  | |                 =>                  | |
3985 //  | |  loop {                     loop {  | |
3986 //  | |    conv                             | broadcast      conv
3987 //  | |      |                              |     |           /
3988 //  | | dynamic-update-slice                |  dynamic-slice /
3989 //  | |         |                           |     |         /
3990 //  | |  }      |                           |  multiply-----
3991 //  |broadcast  /                           |    /
3992 //  | |        /                            reduce
3993 //  |multiply--                             |
3994 //  \ |                                dynamic-update-slice
3995 //   reduce                         }
3996 //
3997 // Later optimization passes (TpuPadSliceMover) will merge the dynamic slice
3998 // with the input nodes (broadcast).
MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(HloInstruction * loop,const SpmdPartitionerOptions & options)3999 Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
4000     HloInstruction* loop, const SpmdPartitionerOptions& options) {
4001   CHECK_EQ(loop->user_count(), 1);
4002   // There should be a single direct user of the while loop, which is the
4003   // gte for element 2, i.e., the dot output.
4004   auto* user_gte = loop->users().front();
4005   CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement);
4006   CHECK_EQ(user_gte->tuple_index(), 2);
4007   auto* computation = loop->parent();
4008 
4009   // Find the reduce outputs and the input nodes they depend on, if input
4010   // nodes only have small operands.
4011   absl::flat_hash_set<HloInstruction*> to_move;
4012   std::vector<HloInstruction*> new_operands;
4013   absl::flat_hash_set<const HloInstruction*> new_operands_set;
4014   std::vector<HloInstruction*> reduce_outputs;
4015   std::vector<HloInstruction*> worklist;
4016   Shape padded_shape = user_gte->shape();
4017   Shape unpadded_shape = user_gte->shape();
4018   auto* original_output = user_gte;
4019 
4020   if (user_gte->user_count() == 1 &&
4021       user_gte->users().back()->opcode() == HloOpcode::kSlice) {
4022     original_output = user_gte->users().back();
4023     unpadded_shape = original_output->shape();
4024   }
4025   for (auto* u : original_output->users()) {
4026     worklist.push_back(u);
4027   }
4028   to_move.insert(original_output);
4029   while (!worklist.empty()) {
4030     auto* inst = worklist.back();
4031     worklist.pop_back();
4032     if (to_move.count(inst) > 0) {
4033       continue;
4034     }
4035     // We only support reduces with simple reduction function, since we may
4036     // need to accumulate across iterations manually.
4037     if (inst->opcode() == HloOpcode::kReduce &&
4038         inst->to_apply()->instruction_count() == 3 &&
4039         inst->to_apply()->num_parameters() == 2 &&
4040         inst->to_apply()->root_instruction()->IsElementwise()) {
4041       to_move.insert(inst);
4042       auto* other_operand = inst->mutable_operand(1);
4043       auto res = new_operands_set.emplace(other_operand);
4044       if (res.second) {
4045         new_operands.push_back(other_operand);
4046       }
4047       reduce_outputs.push_back(inst);
4048     } else if (inst != computation->root_instruction() &&
4049                inst->user_count() > 0 && inst->IsElementwise() &&
4050                !inst->HasSideEffectNoRecurse() &&
4051                absl::c_all_of(inst->operands(),
4052                               [inst](const HloInstruction* o) {
4053                                 return ShapeUtil::CompatibleIgnoringElementType(
4054                                     o->shape(), inst->shape());
4055                               })) {
4056       // For an elementwise op, we need to make sure that they depend on only
4057       // nodes already in to_move and nodes with small operands.
4058       bool can_include = true;
4059       for (auto* operand : inst->operands()) {
4060         if (to_move.count(operand) > 0) {
4061           continue;
4062         }
4063         auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand);
4064         if (find_result.first.empty()) {
4065           can_include = false;
4066           break;
4067         }
4068         for (auto* n : find_result.first) {
4069           to_move.insert(n);
4070         }
4071         for (auto* new_operand : find_result.second) {
4072           auto res = new_operands_set.insert(new_operand);
4073           if (res.second) {
4074             new_operands.push_back(new_operand);
4075           }
4076         }
4077       }
4078       if (!can_include) {
4079         to_move.clear();
4080         break;
4081       }
4082       to_move.insert(inst);
4083       for (auto* u : inst->users()) {
4084         worklist.push_back(u);
4085       }
4086     } else {
4087       to_move.clear();
4088       break;
4089     }
4090   }
4091   // If nothing is found, to_move could contain only original_output, or
4092   // cleared by the above code.
4093   if (to_move.size() <= 1) {
4094     return OkStatus();
4095   }
4096 
4097   // If there is a reduce that's dependent of another reduce, then we can't do
4098   // code motion, as it will create a circular dependencies.
4099   if (reduce_outputs.size() > 10) {
4100     // When there are many reduces, it might be faster to build a reachibility
4101     // map, and then check pair-wise reachability.
4102     auto reachability = HloReachabilityMap::Build(computation);
4103     for (const HloInstruction* reduce : reduce_outputs) {
4104       for (const HloInstruction* other_reduce : reduce_outputs) {
4105         if (reduce != other_reduce &&
4106             reachability->IsReachable(reduce, other_reduce)) {
4107           return OkStatus();
4108         }
4109       }
4110     }
4111   } else if (reduce_outputs.size() > 1) {
4112     // When there are only few reduces, we can do traversal to check dependency.
4113     for (const HloInstruction* reduce : reduce_outputs) {
4114       auto reduce_outputs_do_not_contain = [&](const HloInstruction* inst) {
4115         return !absl::c_linear_search(reduce_outputs, inst);
4116       };
4117       if (!CheckOperandsRecursive(reduce, reduce_outputs_do_not_contain)) {
4118         return OkStatus();
4119       }
4120     }
4121   }
4122 
4123   // We will replace the original loop output with reduce-shape outputs.
4124   // Create the initial buffers before the loop.
4125   for (auto* out : reduce_outputs) {
4126     Shape padded_out_shape = out->shape();
4127     int64_t operand_dim = 0;
4128     int64_t output_dim = 0;
4129     while (output_dim < padded_out_shape.rank()) {
4130       if (absl::c_linear_search(out->dimensions(), operand_dim)) {
4131         // Dimension colapsed.
4132         ++operand_dim;
4133         continue;
4134       }
4135       // Kept dimensions have the same size of the padded shape.
4136       padded_out_shape.set_dimensions(output_dim,
4137                                       padded_shape.dimensions(operand_dim));
4138       ++operand_dim;
4139       ++output_dim;
4140     }
4141     auto* broadcast =
4142         computation->AddInstruction(HloInstruction::CreateBroadcast(
4143             padded_out_shape,
4144             computation->AddInstruction(HloInstruction::CreateConstant(
4145                 LiteralUtil::Zero(out->shape().element_type()))),
4146             {}));
4147     new_operands.push_back(broadcast);
4148   }
4149 
4150   auto* input_tuple = loop->mutable_operand(0);
4151   // Create the new input subtuple that contains the small operands and the
4152   // reduce-shape result buffers.
4153   auto* new_input_subtuple =
4154       computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
4155   TF_RETURN_IF_ERROR(
4156       input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple));
4157   auto* body = loop->while_body();
4158   auto* body_param = body->parameter_instruction(0);
4159   auto* body_root = body->root_instruction();
4160   CHECK_EQ(body_root->opcode(), HloOpcode::kTuple);
4161   // Update tuple shapes.
4162   for (auto* tuple : std::vector<HloInstruction*>{
4163            input_tuple, loop, loop->while_condition()->parameter_instruction(0),
4164            body_param, body_root}) {
4165     *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) =
4166         new_input_subtuple->shape();
4167   }
4168   auto* new_loop_input =
4169       body->AddInstruction(HloInstruction::CreateGetTupleElement(
4170           new_input_subtuple->shape(), body_param, 2));
4171 
4172   // Represents a cluster that's associated to a single dynamic-update-slice op,
4173   // which should be moved to inside of the windowed dot-general loop. There
4174   // might be multiple clusters associated with multiple dynamic-update-slice
4175   // ops which all need moving.
4176   struct MotionCluster {
4177     HloInstruction* dus;
4178     absl::flat_hash_map<const HloInstruction*, HloInstruction*>
4179         outside_to_inside;
4180     std::vector<HloInstruction*> slice_offsets;
4181   };
4182 
4183   std::vector<MotionCluster> motion_clusters;
4184 
4185   // The elementwise nodes will be created with sliced shape. The original
4186   // loop output corresponds to the dynamic-update-slice's update slice.
4187   {
4188     HloInstruction* dus = body_root->mutable_operand(2);
4189     while (dus->opcode() == HloOpcode::kDynamicUpdateSlice) {
4190       motion_clusters.emplace_back();
4191       motion_clusters.back().dus = dus;
4192       motion_clusters.back().outside_to_inside[original_output] =
4193           dus->mutable_operand(1);
4194       motion_clusters.back().slice_offsets.reserve(padded_shape.rank());
4195       for (int64_t i = 0; i < padded_shape.rank(); ++i) {
4196         motion_clusters.back().slice_offsets.push_back(
4197             motion_clusters.back().dus->mutable_operand(i + 2));
4198       }
4199       dus = dus->mutable_operand(0);
4200     }
4201   }
4202   // This is at least one cluster that needs moving.
4203   CHECK_GE(motion_clusters.size(), 1);
4204   MotionCluster& base_motion_cluster = motion_clusters[0];
4205 
4206   worklist.clear();
4207   auto add_users_if_available = [&](HloInstruction* inst) {
4208     for (auto* u : inst->users()) {
4209       if (base_motion_cluster.outside_to_inside.count(u) == 0 &&
4210           to_move.count(u) > 0 &&
4211           absl::c_all_of(u->operands(), [&](const HloInstruction* o) {
4212             return base_motion_cluster.outside_to_inside.count(o) > 0;
4213           })) {
4214         worklist.push_back(u);
4215       }
4216     }
4217   };
4218 
4219   for (int64_t i = 0; i < new_operands.size(); ++i) {
4220     auto* operand_gte =
4221         body->AddInstruction(HloInstruction::CreateGetTupleElement(
4222             new_operands[i]->shape(), new_loop_input, i));
4223     for (MotionCluster& motion_cluster : motion_clusters) {
4224       motion_cluster.outside_to_inside[new_operands[i]] = operand_gte;
4225     }
4226     add_users_if_available(new_operands[i]);
4227   }
4228   add_users_if_available(original_output);
4229 
4230   // Now create the moved nodes inside the loop body.
4231   auto get_slice = [&](HloInstruction* padded,
4232                        absl::Span<HloInstruction* const> slice_offsets,
4233                        HloInstruction* dus) {
4234     return body->AddInstruction(HloInstruction::CreateDynamicSlice(
4235         ShapeUtil::ChangeElementType(dus->operand(1)->shape(),
4236                                      padded->shape().element_type()),
4237         padded, slice_offsets, dus->operand(1)->shape().dimensions()));
4238   };
4239   // Helper functions to create nodes with small operands.
4240   auto add_broadcast = [&](const HloInstruction* broadcast) {
4241     Shape padded_operand_shape = broadcast->operand(0)->shape();
4242     for (int64_t i = 0; i < broadcast->dimensions().size(); ++i) {
4243       padded_operand_shape.set_dimensions(
4244           i, padded_shape.dimensions(broadcast->dimensions(i)));
4245     }
4246     auto* padded_operand =
4247         PadToShape(base_motion_cluster.outside_to_inside[broadcast->operand(0)],
4248                    padded_operand_shape, body);
4249     auto* inside_broadcast =
4250         body->AddInstruction(broadcast->CloneWithNewOperands(
4251             ShapeUtil::ChangeElementType(padded_shape,
4252                                          padded_operand_shape.element_type()),
4253             {padded_operand}));
4254     for (MotionCluster& motion_cluster : motion_clusters) {
4255       motion_cluster.outside_to_inside[broadcast] = get_slice(
4256           inside_broadcast, motion_cluster.slice_offsets, motion_cluster.dus);
4257     }
4258   };
4259   auto add_iota = [&](const HloInstruction* iota) {
4260     auto* inside_iota = body->AddInstruction(iota->CloneWithNewOperands(
4261         ShapeUtil::ChangeElementType(padded_shape,
4262                                      iota->shape().element_type()),
4263         {}));
4264     for (MotionCluster& motion_cluster : motion_clusters) {
4265       motion_cluster.outside_to_inside[iota] = get_slice(
4266           inside_iota, motion_cluster.slice_offsets, motion_cluster.dus);
4267     }
4268   };
4269   auto add_constant = [&](const HloInstruction* constant) {
4270     auto* constant_clone = body->AddInstruction(constant->Clone());
4271     auto* inside_constant =
4272         PadToShape(constant_clone,
4273                    ShapeUtil::ChangeElementType(
4274                        padded_shape, constant->shape().element_type()),
4275                    body);
4276     for (MotionCluster& motion_cluster : motion_clusters) {
4277       motion_cluster.outside_to_inside[constant] = get_slice(
4278           inside_constant, motion_cluster.slice_offsets, motion_cluster.dus);
4279     }
4280   };
4281   auto add_other_inst = [&](const HloInstruction* inst) {
4282     std::vector<HloInstruction*> operands_inside(inst->operand_count());
4283     for (MotionCluster& motion_cluster : motion_clusters) {
4284       for (int64_t i = 0; i < operands_inside.size(); ++i) {
4285         operands_inside[i] = motion_cluster.outside_to_inside[inst->operand(i)];
4286       }
4287       motion_cluster.outside_to_inside[inst] =
4288           body->AddInstruction(inst->CloneWithNewOperands(
4289               ShapeUtil::ChangeElementType(
4290                   motion_cluster.dus->operand(1)->shape(),
4291                   inst->shape().element_type()),
4292               operands_inside));
4293     }
4294   };
4295 
4296   while (!worklist.empty()) {
4297     auto* inst = worklist.back();
4298     worklist.pop_back();
4299     if (absl::c_all_of(
4300             motion_clusters, [inst](const MotionCluster& motion_cluster) {
4301               return motion_cluster.outside_to_inside.count(inst) > 0;
4302             })) {
4303       continue;
4304     }
4305     if (inst->opcode() == HloOpcode::kBroadcast) {
4306       add_broadcast(inst);
4307     } else if (inst->opcode() == HloOpcode::kIota) {
4308       add_iota(inst);
4309     } else if (inst->opcode() == HloOpcode::kConstant) {
4310       add_constant(inst);
4311     } else if (inst->opcode() == HloOpcode::kReduce) {
4312       // This is an output, for which we has special handling later.
4313     } else if (inst->IsElementwise()) {
4314       add_other_inst(inst);
4315     } else {
4316       // Skip cloning other non-elementwise ops.
4317     }
4318     add_users_if_available(inst);
4319   }
4320   std::vector<HloInstruction*> new_outputs_inside(new_operands.size());
4321   for (int64_t i = 0; i < new_outputs_inside.size(); ++i) {
4322     new_outputs_inside[i] =
4323         base_motion_cluster.outside_to_inside[new_operands[i]];
4324   }
4325 
4326   // Now create the reduce outputs inside of the loop.
4327   for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
4328     auto* reduce_outside = reduce_outputs[i];
4329     CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce);
4330     int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
4331     auto* last_iter_result =
4332         base_motion_cluster.outside_to_inside[new_operands[index_in_operand]];
4333 
4334     auto create_inside_reduce =
4335         [&](absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
4336                 outside_to_inside,
4337             absl::Span<HloInstruction* const> slice_offsets,
4338             HloInstruction* last_iter_result) -> StatusOr<HloInstruction*> {
4339       HloInstruction* operand0 = outside_to_inside[reduce_outside->operand(0)];
4340       HloInstruction* operand1 = outside_to_inside[reduce_outside->operand(1)];
4341       TF_ASSIGN_OR_RETURN(
4342           Shape reduce_shape,
4343           ShapeInference::InferReduceShape(
4344               {&operand0->shape(), &operand1->shape()},
4345               reduce_outside->dimensions(),
4346               reduce_outside->to_apply()->ComputeProgramShape()));
4347       *reduce_shape.mutable_layout() = reduce_outside->shape().layout();
4348       std::vector<HloInstruction*> reduce_dus_offsets;
4349       // If any collapsed dimension is windowed, we need to accumulate with last
4350       // iteration's result. If such a dimension has padding, we also need to
4351       // mask off invalid data.
4352       bool needs_accumulate = false;
4353       std::vector<int64_t> dims_to_mask;
4354       for (int64_t i = 0; i < slice_offsets.size(); ++i) {
4355         if (absl::c_linear_search(reduce_outside->dimensions(), i)) {
4356           if (reduce_outside->operand(0)->shape().dimensions(i) !=
4357               operand0->shape().dimensions(i)) {
4358             needs_accumulate = true;
4359             if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) {
4360               dims_to_mask.push_back(i);
4361             }
4362           }
4363           continue;
4364         }
4365         reduce_dus_offsets.push_back(slice_offsets[i]);
4366       }
4367       // Mask off invalid data in collapsed dimensions.
4368       for (int64_t dim : dims_to_mask) {
4369         auto* iota = body->AddInstruction(HloInstruction::CreateIota(
4370             ShapeUtil::ChangeElementType(operand0->shape(), S32), dim));
4371         auto* add = body->AddInstruction(HloInstruction::CreateBinary(
4372             iota->shape(), HloOpcode::kAdd, iota,
4373             body->AddInstruction(HloInstruction::CreateBroadcast(
4374                 iota->shape(), slice_offsets[dim], {}))));
4375         auto* limit = body->AddInstruction(HloInstruction::CreateBroadcast(
4376             iota->shape(),
4377             body->AddInstruction(
4378                 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(
4379                     reduce_outside->operand(0)->shape().dimensions(dim)))),
4380             {}));
4381         auto* compare = body->AddInstruction(HloInstruction::CreateCompare(
4382             ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit,
4383             ComparisonDirection::kLt));
4384         operand0 = body->AddInstruction(HloInstruction::CreateTernary(
4385             operand0->shape(), HloOpcode::kSelect, compare, operand0,
4386             body->AddInstruction(HloInstruction::CreateBroadcast(
4387                 operand0->shape(), operand1, {}))));
4388       }
4389       auto* output_inside =
4390           body->AddInstruction(reduce_outside->CloneWithNewOperands(
4391               reduce_shape, {operand0, operand1}));
4392       // Accumulate with previous results if needed.
4393       if (needs_accumulate) {
4394         auto* input_slice =
4395             body->AddInstruction(HloInstruction::CreateDynamicSlice(
4396                 output_inside->shape(), last_iter_result, reduce_dus_offsets,
4397                 output_inside->shape().dimensions()));
4398         output_inside = body->AddInstruction(HloInstruction::CreateBinary(
4399             output_inside->shape(),
4400             reduce_outside->to_apply()->root_instruction()->opcode(),
4401             output_inside, input_slice));
4402       }
4403       // Dynamic-update-slice if needed.
4404       if (!ShapeUtil::Compatible(output_inside->shape(),
4405                                  last_iter_result->shape())) {
4406         output_inside =
4407             body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
4408                 last_iter_result->shape(), last_iter_result, output_inside,
4409                 reduce_dus_offsets));
4410       }
4411       return output_inside;
4412     };
4413     for (MotionCluster& motion_cluster : motion_clusters) {
4414       TF_ASSIGN_OR_RETURN(
4415           last_iter_result,
4416           create_inside_reduce(motion_cluster.outside_to_inside,
4417                                motion_cluster.slice_offsets, last_iter_result));
4418     }
4419     new_outputs_inside[index_in_operand] = last_iter_result;
4420   }
4421 
4422   // Body output.
4423   auto* new_output_inside =
4424       body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside));
4425   TF_RETURN_IF_ERROR(
4426       body_root->ReplaceOperandWithDifferentShape(2, new_output_inside));
4427   TF_RETURN_IF_ERROR(
4428       body->RemoveInstructionAndUnusedOperands(base_motion_cluster.dus));
4429   // Replace uses of the reduces outside the loop.
4430   auto* new_output_gte =
4431       computation->AddInstruction(HloInstruction::CreateGetTupleElement(
4432           new_output_inside->shape(), loop, 2));
4433   for (int64_t i = 0; i < reduce_outputs.size(); ++i) {
4434     int64_t index_in_operand = new_operands.size() - reduce_outputs.size() + i;
4435     auto* new_output =
4436         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
4437             new_outputs_inside[index_in_operand]->shape(), new_output_gte,
4438             index_in_operand));
4439     if (!ShapeUtil::Compatible(new_output->shape(),
4440                                reduce_outputs[i]->shape())) {
4441       new_output = computation->AddInstruction(HloInstruction::CreateSlice(
4442           reduce_outputs[i]->shape(), new_output,
4443           std::vector<int64_t>(new_output->shape().rank(), 0),
4444           reduce_outputs[i]->shape().dimensions(),
4445           std::vector<int64_t>(new_output->shape().rank(), 1)));
4446     }
4447     TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output));
4448     TF_RETURN_IF_ERROR(
4449         computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i]));
4450   }
4451   return OkStatus();
4452 }
4453 
4454 }  // namespace
4455 
DoCodeMotionForWindowedDotGeneralLoops(HloComputation * computation,const SpmdPartitionerOptions & options)4456 Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops(
4457     HloComputation* computation, const SpmdPartitionerOptions& options) {
4458   for (auto& loop : windowed_dot_general_loops_) {
4459     if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims ||
4460         loop.operands_sharded_at_contracting_dims) {
4461       // We have a dynamic-slice for the non-windowed operand in
4462       // batch/contracting-dim/noncontracting-dim windowed dot-general. So
4463       // moving the broadcast/iota/elementwise ops into the loop could help
4464       // reduce memory via fusion.
4465       TF_RETURN_IF_ERROR(
4466           SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions(
4467               loop.while_loop, 1 - loop.windowed_operand));
4468     }
4469     // Currently unrolled loop does not support this optimization.
4470     if (!loop.windowed_in_contracting_dims &&
4471         !loop.operands_sharded_at_contracting_dims) {
4472       // We have a dynamic-update-slice for the output in
4473       // batch/non-contracting-dim windowed dot-general. So moving reduce ops
4474       // into the loop could help reduce memory.
4475       TF_RETURN_IF_ERROR(
4476           MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions(
4477               loop.while_loop, options));
4478     }
4479   }
4480   return OkStatus();
4481 }
4482 
4483 }  // namespace spmd
4484 }  // namespace xla
4485